Add backend API and engine
engine/core/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Core Components Package"""

from core.events import EventBus, get_event_bus
from core.transports import BaseTransport, SocketTransport, WebRtcTransport
from core.session import Session
from core.conversation import ConversationManager, ConversationState, ConversationTurn
from core.duplex_pipeline import DuplexPipeline

__all__ = [
    "EventBus",
    "get_event_bus",
    "BaseTransport",
    "SocketTransport",
    "WebRtcTransport",
    "Session",
    "ConversationManager",
    "ConversationState",
    "ConversationTurn",
    "DuplexPipeline",
]
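The package flattens its public surface, so callers can pull the core types in one line. A minimal import sketch, assuming the `engine/` directory is on `PYTHONPATH` (which the absolute `core.*` imports above already require):

```python
# Import sketch - assumes engine/ is on PYTHONPATH, matching the
# absolute "core.*" imports used throughout this package.
from core import DuplexPipeline, Session, SocketTransport, get_event_bus
```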
engine/core/conversation.py (new file, 255 lines)
@@ -0,0 +1,255 @@
"""Conversation management for voice AI.

Handles conversation context, turn-taking, and message history
for multi-turn voice conversations.
"""

import asyncio
from typing import List, Optional, Dict, Any, Callable, Awaitable
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger

from services.base import LLMMessage


class ConversationState(Enum):
    """State of the conversation."""
    IDLE = "idle"                # Waiting for user input
    LISTENING = "listening"      # User is speaking
    PROCESSING = "processing"    # Processing user input (LLM)
    SPEAKING = "speaking"        # Bot is speaking
    INTERRUPTED = "interrupted"  # Bot was interrupted


@dataclass
class ConversationTurn:
    """A single turn in the conversation."""
    role: str  # "user" or "assistant"
    text: str
    audio_duration_ms: Optional[int] = None
    timestamp: float = field(default_factory=lambda: asyncio.get_event_loop().time())
    was_interrupted: bool = False


class ConversationManager:
    """
    Manages conversation state and history.

    Provides:
    - Message history for LLM context
    - Turn management
    - State tracking
    - Event callbacks for state changes
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_history: int = 20,
        greeting: Optional[str] = None
    ):
        """
        Initialize conversation manager.

        Args:
            system_prompt: System prompt for LLM
            max_history: Maximum number of recent turns to include in LLM context
            greeting: Optional greeting message when conversation starts
        """
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation. "
            "If you don't understand something, ask for clarification."
        )
        self.max_history = max_history
        self.greeting = greeting

        # State
        self.state = ConversationState.IDLE
        self.turns: List[ConversationTurn] = []

        # Callbacks
        self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
        self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []

        # Current turn tracking
        self._current_user_text: str = ""
        self._current_assistant_text: str = ""

        logger.info("ConversationManager initialized")

    def on_state_change(
        self,
        callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
    ) -> None:
        """Register callback for state changes."""
        self._state_callbacks.append(callback)

    def on_turn_complete(
        self,
        callback: Callable[[ConversationTurn], Awaitable[None]]
    ) -> None:
        """Register callback for turn completion."""
        self._turn_callbacks.append(callback)

    async def set_state(self, new_state: ConversationState) -> None:
        """Set conversation state and notify listeners."""
        if new_state != self.state:
            old_state = self.state
            self.state = new_state
            logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")

            for callback in self._state_callbacks:
                try:
                    await callback(old_state, new_state)
                except Exception as e:
                    logger.error(f"State callback error: {e}")

    def get_messages(self) -> List[LLMMessage]:
        """
        Get conversation history as LLM messages.

        Returns:
            List of LLMMessage objects including system prompt
        """
        messages = [LLMMessage(role="system", content=self.system_prompt)]

        # Add conversation history
        for turn in self.turns[-self.max_history:]:
            messages.append(LLMMessage(role=turn.role, content=turn.text))

        # Add current user text if any
        if self._current_user_text:
            messages.append(LLMMessage(role="user", content=self._current_user_text))

        return messages

    async def start_user_turn(self) -> None:
        """Signal that user has started speaking."""
        await self.set_state(ConversationState.LISTENING)
        self._current_user_text = ""

    async def update_user_text(self, text: str, is_final: bool = False) -> None:
        """
        Update current user text (from ASR).

        Args:
            text: Transcribed text
            is_final: Whether this is the final transcript
        """
        self._current_user_text = text

    async def end_user_turn(self, text: str) -> None:
        """
        End user turn and add to history.

        Args:
            text: Final user text
        """
        if text.strip():
            turn = ConversationTurn(role="user", text=text.strip())
            self.turns.append(turn)

            for callback in self._turn_callbacks:
                try:
                    await callback(turn)
                except Exception as e:
                    logger.error(f"Turn callback error: {e}")

            logger.info(f"User: {text[:50]}...")

        self._current_user_text = ""
        await self.set_state(ConversationState.PROCESSING)

    async def start_assistant_turn(self) -> None:
        """Signal that assistant has started speaking."""
        await self.set_state(ConversationState.SPEAKING)
        self._current_assistant_text = ""

    async def update_assistant_text(self, text: str) -> None:
        """
        Update current assistant text (streaming).

        Args:
            text: Text chunk from LLM
        """
        self._current_assistant_text += text

    async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
        """
        End assistant turn and add to history.

        Args:
            was_interrupted: Whether the turn was interrupted by user
        """
        text = self._current_assistant_text.strip()
        if text:
            turn = ConversationTurn(
                role="assistant",
                text=text,
                was_interrupted=was_interrupted
            )
            self.turns.append(turn)

            for callback in self._turn_callbacks:
                try:
                    await callback(turn)
                except Exception as e:
                    logger.error(f"Turn callback error: {e}")

            status = " (interrupted)" if was_interrupted else ""
            logger.info(f"Assistant{status}: {text[:50]}...")

        self._current_assistant_text = ""

        if was_interrupted:
            await self.set_state(ConversationState.INTERRUPTED)
        else:
            await self.set_state(ConversationState.IDLE)

    async def interrupt(self) -> None:
        """Handle interruption (barge-in)."""
        if self.state == ConversationState.SPEAKING:
            await self.end_assistant_turn(was_interrupted=True)

    def reset(self) -> None:
        """Reset conversation history."""
        self.turns = []
        self._current_user_text = ""
        self._current_assistant_text = ""
        self.state = ConversationState.IDLE
        logger.info("Conversation reset")

    @property
    def turn_count(self) -> int:
        """Get number of turns in conversation."""
        return len(self.turns)

    @property
    def last_user_text(self) -> Optional[str]:
        """Get last user text."""
        for turn in reversed(self.turns):
            if turn.role == "user":
                return turn.text
        return None

    @property
    def last_assistant_text(self) -> Optional[str]:
        """Get last assistant text."""
        for turn in reversed(self.turns):
            if turn.role == "assistant":
                return turn.text
        return None

    def get_context_summary(self) -> Dict[str, Any]:
        """Get a summary of conversation context."""
        return {
            "state": self.state.value,
            "turn_count": self.turn_count,
            "last_user": self.last_user_text,
            "last_assistant": self.last_assistant_text,
            "current_user": self._current_user_text or None,
            "current_assistant": self._current_assistant_text or None
        }
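A quick sketch of the turn lifecycle this class drives. It assumes `services.base.LLMMessage` exposes `role` and `content` attributes, which the constructor calls above suggest but this commit does not show:

```python
import asyncio

from core.conversation import ConversationManager, ConversationState, ConversationTurn


async def demo() -> None:
    convo = ConversationManager(greeting="Hi there!")

    # Callbacks are awaited by the manager, so coroutines are the safe choice.
    async def on_state(old: ConversationState, new: ConversationState) -> None:
        print(f"state: {old.value} -> {new.value}")

    async def on_turn(turn: ConversationTurn) -> None:
        print(f"turn ({turn.role}): {turn.text}")

    convo.on_state_change(on_state)
    convo.on_turn_complete(on_turn)

    # One user turn: IDLE -> LISTENING -> PROCESSING, and the turn lands in history.
    await convo.start_user_turn()
    await convo.update_user_text("What's the weather", is_final=False)
    await convo.end_user_turn("What's the weather like today?")

    # Context sent to the LLM: system prompt plus recent turns.
    for msg in convo.get_messages():
        print(msg.role, ":", msg.content)


asyncio.run(demo())
```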
engine/core/duplex_pipeline.py (new file, 719 lines)
@@ -0,0 +1,719 @@
"""Full duplex audio pipeline for AI voice conversation.

This module implements the core duplex pipeline that orchestrates:
- VAD (Voice Activity Detection)
- EOU (End of Utterance) Detection
- ASR (Automatic Speech Recognition) - optional
- LLM (Language Model)
- TTS (Text-to-Speech)

Inspired by pipecat's frame-based architecture and active-call's
event-driven design.
"""

import asyncio
import time
from typing import Optional, Callable, Awaitable
from loguru import logger

from core.transports import BaseTransport
from core.conversation import ConversationManager, ConversationState
from core.events import get_event_bus
from processors.vad import VADProcessor, SileroVAD
from processors.eou import EouDetector
from services.base import BaseLLMService, BaseTTSService, BaseASRService
from services.llm import OpenAILLMService, MockLLMService
from services.tts import EdgeTTSService, MockTTSService
from services.asr import BufferedASRService
from services.siliconflow_tts import SiliconFlowTTSService
from services.siliconflow_asr import SiliconFlowASRService
from app.config import settings


class DuplexPipeline:
    """
    Full duplex audio pipeline for AI voice conversation.

    Handles bidirectional audio flow with:
    - User speech detection and transcription
    - AI response generation
    - Text-to-speech synthesis
    - Barge-in (interruption) support

    Architecture (inspired by pipecat):

    User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
                  ↓
          Barge-in Detection → Interrupt
    """

    def __init__(
        self,
        transport: BaseTransport,
        session_id: str,
        llm_service: Optional[BaseLLMService] = None,
        tts_service: Optional[BaseTTSService] = None,
        asr_service: Optional[BaseASRService] = None,
        system_prompt: Optional[str] = None,
        greeting: Optional[str] = None
    ):
        """
        Initialize duplex pipeline.

        Args:
            transport: Transport for sending audio/events
            session_id: Session identifier
            llm_service: LLM service (defaults to OpenAI)
            tts_service: TTS service (defaults to EdgeTTS)
            asr_service: ASR service (optional)
            system_prompt: System prompt for LLM
            greeting: Optional greeting to speak on start
        """
        self.transport = transport
        self.session_id = session_id
        self.event_bus = get_event_bus()

        # Initialize VAD
        self.vad_model = SileroVAD(
            model_path=settings.vad_model_path,
            sample_rate=settings.sample_rate
        )
        self.vad_processor = VADProcessor(
            vad_model=self.vad_model,
            threshold=settings.vad_threshold
        )

        # Initialize EOU detector
        self.eou_detector = EouDetector(
            silence_threshold_ms=settings.vad_eou_threshold_ms,
            min_speech_duration_ms=settings.vad_min_speech_duration_ms
        )

        # Initialize services
        self.llm_service = llm_service
        self.tts_service = tts_service
        self.asr_service = asr_service  # Will be initialized in start()

        # Track last sent transcript to avoid duplicates
        self._last_sent_transcript = ""

        # Conversation manager
        self.conversation = ConversationManager(
            system_prompt=system_prompt,
            greeting=greeting
        )

        # State
        self._running = True
        self._is_bot_speaking = False
        self._current_turn_task: Optional[asyncio.Task] = None
        self._audio_buffer: bytes = b""
        max_buffer_seconds = getattr(settings, "max_audio_buffer_seconds", 30)
        self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds)
        self._last_vad_status: str = "Silence"
        self._process_lock = asyncio.Lock()

        # Interruption handling
        self._interrupt_event = asyncio.Event()

        # Latency tracking - TTFB (Time to First Byte)
        self._turn_start_time: Optional[float] = None
        self._first_audio_sent: bool = False

        # Barge-in filtering - require minimum speech duration to interrupt
        self._barge_in_speech_start_time: Optional[float] = None
        self._barge_in_min_duration_ms: int = getattr(settings, "barge_in_min_duration_ms", 50)
        self._barge_in_speech_frames: int = 0  # Count speech frames
        self._barge_in_silence_frames: int = 0  # Count silence frames during potential barge-in
        self._barge_in_silence_tolerance: int = 3  # Allow up to 3 silence frames (60ms at 20ms chunks)

        logger.info(f"DuplexPipeline initialized for session {session_id}")

    async def start(self) -> None:
        """Start the pipeline and connect services."""
        try:
            # Connect LLM service
            if not self.llm_service:
                if settings.openai_api_key:
                    self.llm_service = OpenAILLMService(
                        api_key=settings.openai_api_key,
                        base_url=settings.openai_api_url,
                        model=settings.llm_model
                    )
                else:
                    logger.warning("No OpenAI API key - using mock LLM")
                    self.llm_service = MockLLMService()

            await self.llm_service.connect()

            # Connect TTS service
            if not self.tts_service:
                if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key:
                    self.tts_service = SiliconFlowTTSService(
                        api_key=settings.siliconflow_api_key,
                        voice=settings.tts_voice,
                        model=settings.siliconflow_tts_model,
                        sample_rate=settings.sample_rate,
                        speed=settings.tts_speed
                    )
                    logger.info("Using SiliconFlow TTS service")
                else:
                    self.tts_service = EdgeTTSService(
                        voice=settings.tts_voice,
                        sample_rate=settings.sample_rate
                    )
                    logger.info("Using Edge TTS service")

            await self.tts_service.connect()

            # Connect ASR service
            if not self.asr_service:
                if settings.asr_provider == "siliconflow" and settings.siliconflow_api_key:
                    self.asr_service = SiliconFlowASRService(
                        api_key=settings.siliconflow_api_key,
                        model=settings.siliconflow_asr_model,
                        sample_rate=settings.sample_rate,
                        interim_interval_ms=settings.asr_interim_interval_ms,
                        min_audio_for_interim_ms=settings.asr_min_audio_ms,
                        on_transcript=self._on_transcript_callback
                    )
                    logger.info("Using SiliconFlow ASR service")
                else:
                    self.asr_service = BufferedASRService(
                        sample_rate=settings.sample_rate
                    )
                    logger.info("Using Buffered ASR service (no real transcription)")

            await self.asr_service.connect()

            logger.info("DuplexPipeline services connected")

            # Speak greeting if configured
            if self.conversation.greeting:
                await self._speak(self.conversation.greeting)

        except Exception as e:
            logger.error(f"Failed to start pipeline: {e}")
            raise

    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio chunk.

        This is the main entry point for audio from the user.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if not self._running:
            return

        try:
            async with self._process_lock:
                # 1. Process through VAD
                vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)

                vad_status = "Silence"
                if vad_result:
                    event_type, probability = vad_result
                    vad_status = "Speech" if event_type == "speaking" else "Silence"

                    # Emit VAD event
                    await self.event_bus.publish(event_type, {
                        "trackId": self.session_id,
                        "probability": probability
                    })
                else:
                    # No state change - keep previous status
                    vad_status = self._last_vad_status

                # Update state based on VAD
                if vad_status == "Speech" and self._last_vad_status != "Speech":
                    await self._on_speech_start()

                self._last_vad_status = vad_status

                # 2. Check for barge-in (user speaking while bot speaking)
                # Filter false interruptions by requiring minimum speech duration
                if self._is_bot_speaking:
                    if vad_status == "Speech":
                        # User is speaking while bot is speaking
                        self._barge_in_silence_frames = 0  # Reset silence counter

                        if self._barge_in_speech_start_time is None:
                            # Start tracking speech duration
                            self._barge_in_speech_start_time = time.time()
                            self._barge_in_speech_frames = 1
                            logger.debug("Potential barge-in detected, tracking duration...")
                        else:
                            self._barge_in_speech_frames += 1
                            # Check if speech duration exceeds threshold
                            speech_duration_ms = (time.time() - self._barge_in_speech_start_time) * 1000
                            if speech_duration_ms >= self._barge_in_min_duration_ms:
                                logger.info(f"Barge-in confirmed after {speech_duration_ms:.0f}ms of speech ({self._barge_in_speech_frames} frames)")
                                await self._handle_barge_in()
                    else:
                        # Silence frame during potential barge-in
                        if self._barge_in_speech_start_time is not None:
                            self._barge_in_silence_frames += 1
                            # Allow brief silence gaps (VAD flickering)
                            if self._barge_in_silence_frames > self._barge_in_silence_tolerance:
                                # Too much silence - reset barge-in tracking
                                logger.debug(f"Barge-in cancelled after {self._barge_in_silence_frames} silence frames")
                                self._barge_in_speech_start_time = None
                                self._barge_in_speech_frames = 0
                                self._barge_in_silence_frames = 0

                # 3. Buffer audio for ASR
                if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
                    self._audio_buffer += pcm_bytes
                    if len(self._audio_buffer) > self._max_audio_buffer_bytes:
                        # Keep only the most recent audio to cap memory usage
                        self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:]
                    await self.asr_service.send_audio(pcm_bytes)

                    # For SiliconFlow ASR, trigger interim transcription periodically
                    # The service handles timing internally via start_interim_transcription()

                # 4. Check for End of Utterance - this triggers LLM response
                if self.eou_detector.process(vad_status):
                    await self._on_end_of_utterance()

        except Exception as e:
            logger.error(f"Pipeline audio processing error: {e}", exc_info=True)

    async def process_text(self, text: str) -> None:
        """
        Process text input (chat command).

        Allows direct text input to bypass ASR.

        Args:
            text: User text input
        """
        if not self._running:
            return

        logger.info(f"Processing text input: {text[:50]}...")

        # Cancel any current speaking
        await self._stop_current_speech()

        # Start new turn
        await self.conversation.end_user_turn(text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(text))

    async def interrupt(self) -> None:
        """Interrupt current bot speech (manual interrupt command)."""
        await self._handle_barge_in()

    async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
        """
        Callback for ASR transcription results.

        Streams transcription to client for display.

        Args:
            text: Transcribed text
            is_final: Whether this is the final transcription
        """
        # Avoid sending duplicate transcripts
        if text == self._last_sent_transcript and not is_final:
            return

        self._last_sent_transcript = text

        # Send transcript event to client
        await self.transport.send_event({
            "event": "transcript",
            "trackId": self.session_id,
            "text": text,
            "isFinal": is_final,
            "timestamp": self._get_timestamp_ms()
        })

        logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")

    async def _on_speech_start(self) -> None:
        """Handle user starting to speak."""
        if self.conversation.state == ConversationState.IDLE:
            await self.conversation.start_user_turn()
            self._audio_buffer = b""
            self._last_sent_transcript = ""
            self.eou_detector.reset()

            # Clear ASR buffer and start interim transcriptions
            if hasattr(self.asr_service, 'clear_buffer'):
                self.asr_service.clear_buffer()
            if hasattr(self.asr_service, 'start_interim_transcription'):
                await self.asr_service.start_interim_transcription()

            logger.debug("User speech started")

    async def _on_end_of_utterance(self) -> None:
        """Handle end of user utterance."""
        if self.conversation.state != ConversationState.LISTENING:
            return

        # Stop interim transcriptions
        if hasattr(self.asr_service, 'stop_interim_transcription'):
            await self.asr_service.stop_interim_transcription()

        # Get final transcription from ASR service
        user_text = ""

        if hasattr(self.asr_service, 'get_final_transcription'):
            # SiliconFlow ASR - get final transcription
            user_text = await self.asr_service.get_final_transcription()
        elif hasattr(self.asr_service, 'get_and_clear_text'):
            # Buffered ASR - get accumulated text
            user_text = self.asr_service.get_and_clear_text()

        # Skip if no meaningful text
        if not user_text or not user_text.strip():
            logger.debug("EOU detected but no transcription - skipping")
            # Reset for next utterance
            self._audio_buffer = b""
            self._last_sent_transcript = ""
            # Return to idle; don't force LISTENING which causes buffering on silence
            await self.conversation.set_state(ConversationState.IDLE)
            return

        logger.info(f"EOU detected - user said: {user_text[:100]}...")

        # Send final transcription to client
        await self.transport.send_event({
            "event": "transcript",
            "trackId": self.session_id,
            "text": user_text,
            "isFinal": True,
            "timestamp": self._get_timestamp_ms()
        })

        # Clear buffers
        self._audio_buffer = b""
        self._last_sent_transcript = ""

        # Process the turn - trigger LLM response
        # Cancel any existing turn to avoid overlapping assistant responses
        await self._stop_current_speech()
        await self.conversation.end_user_turn(user_text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))

    async def _handle_turn(self, user_text: str) -> None:
        """
        Handle a complete conversation turn.

        Uses sentence-by-sentence streaming TTS for lower latency.

        Args:
            user_text: User's transcribed text
        """
        try:
            # Start latency tracking
            self._turn_start_time = time.time()
            self._first_audio_sent = False

            # Get AI response (streaming)
            messages = self.conversation.get_messages()
            full_response = ""

            await self.conversation.start_assistant_turn()
            self._is_bot_speaking = True
            self._interrupt_event.clear()

            # Sentence buffer for streaming TTS
            sentence_buffer = ""
            sentence_ends = {',', '。', '!', '?', '\n'}
            first_audio_sent = False

            # Stream LLM response and TTS sentence by sentence
            async for text_chunk in self.llm_service.generate_stream(messages):
                if self._interrupt_event.is_set():
                    break

                full_response += text_chunk
                sentence_buffer += text_chunk
                await self.conversation.update_assistant_text(text_chunk)

                # Send LLM response streaming event to client
                await self.transport.send_event({
                    "event": "llmResponse",
                    "trackId": self.session_id,
                    "text": text_chunk,
                    "isFinal": False,
                    "timestamp": self._get_timestamp_ms()
                })

                # Check for sentence completion - synthesize immediately for low latency
                while any(end in sentence_buffer for end in sentence_ends):
                    # Find first sentence end
                    min_idx = len(sentence_buffer)
                    for end in sentence_ends:
                        idx = sentence_buffer.find(end)
                        if idx != -1 and idx < min_idx:
                            min_idx = idx

                    if min_idx < len(sentence_buffer):
                        sentence = sentence_buffer[:min_idx + 1].strip()
                        sentence_buffer = sentence_buffer[min_idx + 1:]

                        if sentence and not self._interrupt_event.is_set():
                            # Send track start on first audio
                            if not first_audio_sent:
                                await self.transport.send_event({
                                    "event": "trackStart",
                                    "trackId": self.session_id,
                                    "timestamp": self._get_timestamp_ms()
                                })
                                first_audio_sent = True

                            # Synthesize and send this sentence immediately
                            await self._speak_sentence(sentence)
                    else:
                        break

            # Send final LLM response event
            if full_response and not self._interrupt_event.is_set():
                await self.transport.send_event({
                    "event": "llmResponse",
                    "trackId": self.session_id,
                    "text": full_response,
                    "isFinal": True,
                    "timestamp": self._get_timestamp_ms()
                })

            # Speak any remaining text
            if sentence_buffer.strip() and not self._interrupt_event.is_set():
                if not first_audio_sent:
                    await self.transport.send_event({
                        "event": "trackStart",
                        "trackId": self.session_id,
                        "timestamp": self._get_timestamp_ms()
                    })
                    first_audio_sent = True
                await self._speak_sentence(sentence_buffer.strip())

            # Send track end
            if first_audio_sent:
                await self.transport.send_event({
                    "event": "trackEnd",
                    "trackId": self.session_id,
                    "timestamp": self._get_timestamp_ms()
                })

            # End assistant turn
            await self.conversation.end_assistant_turn(
                was_interrupted=self._interrupt_event.is_set()
            )

        except asyncio.CancelledError:
            logger.info("Turn handling cancelled")
            await self.conversation.end_assistant_turn(was_interrupted=True)
        except Exception as e:
            logger.error(f"Turn handling error: {e}", exc_info=True)
            await self.conversation.end_assistant_turn(was_interrupted=True)
        finally:
            self._is_bot_speaking = False
            # Reset barge-in tracking when bot finishes speaking
            self._barge_in_speech_start_time = None
            self._barge_in_speech_frames = 0
            self._barge_in_silence_frames = 0

    async def _speak_sentence(self, text: str) -> None:
        """
        Synthesize and send a single sentence.

        Args:
            text: Sentence to speak
        """
        if not text.strip() or self._interrupt_event.is_set():
            return

        try:
            async for chunk in self.tts_service.synthesize_stream(text):
                # Check interrupt at the start of each iteration
                if self._interrupt_event.is_set():
                    logger.debug("TTS sentence interrupted")
                    break

                # Track and log first audio packet latency (TTFB)
                if not self._first_audio_sent and self._turn_start_time:
                    ttfb_ms = (time.time() - self._turn_start_time) * 1000
                    self._first_audio_sent = True
                    logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                    # Send TTFB event to client
                    await self.transport.send_event({
                        "event": "ttfb",
                        "trackId": self.session_id,
                        "timestamp": self._get_timestamp_ms(),
                        "latencyMs": round(ttfb_ms)
                    })

                # Double-check interrupt right before sending audio
                if self._interrupt_event.is_set():
                    break

                await self.transport.send_audio(chunk.audio)
                await asyncio.sleep(0.005)  # Small delay to prevent flooding
        except asyncio.CancelledError:
            logger.debug("TTS sentence cancelled")
        except Exception as e:
            logger.error(f"TTS sentence error: {e}")

    async def _speak(self, text: str) -> None:
        """
        Synthesize and send speech.

        Args:
            text: Text to speak
        """
        if not text.strip():
            return

        try:
            # Start latency tracking for greeting
            speak_start_time = time.time()
            first_audio_sent = False

            # Send track start event
            await self.transport.send_event({
                "event": "trackStart",
                "trackId": self.session_id,
                "timestamp": self._get_timestamp_ms()
            })

            self._is_bot_speaking = True

            # Stream TTS audio
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._interrupt_event.is_set():
                    logger.info("TTS interrupted by barge-in")
                    break

                # Track and log first audio packet latency (TTFB)
                if not first_audio_sent:
                    ttfb_ms = (time.time() - speak_start_time) * 1000
                    first_audio_sent = True
                    logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                    # Send TTFB event to client
                    await self.transport.send_event({
                        "event": "ttfb",
                        "trackId": self.session_id,
                        "timestamp": self._get_timestamp_ms(),
                        "latencyMs": round(ttfb_ms)
                    })

                # Send audio to client
                await self.transport.send_audio(chunk.audio)

                # Small delay to prevent flooding
                await asyncio.sleep(0.01)

            # Send track end event
            await self.transport.send_event({
                "event": "trackEnd",
                "trackId": self.session_id,
                "timestamp": self._get_timestamp_ms()
            })

        except asyncio.CancelledError:
            logger.info("TTS cancelled")
            raise
        except Exception as e:
            logger.error(f"TTS error: {e}")
        finally:
            self._is_bot_speaking = False

    async def _handle_barge_in(self) -> None:
        """Handle user barge-in (interruption)."""
        if not self._is_bot_speaking:
            return

        logger.info("Barge-in detected - interrupting bot speech")

        # Reset barge-in tracking
        self._barge_in_speech_start_time = None
        self._barge_in_speech_frames = 0
        self._barge_in_silence_frames = 0

        # IMPORTANT: Signal interruption FIRST to stop audio sending
        self._interrupt_event.set()
        self._is_bot_speaking = False

        # Send interrupt event to client IMMEDIATELY
        # This must happen BEFORE canceling services, so client knows to discard in-flight audio
        await self.transport.send_event({
            "event": "interrupt",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

        # Cancel TTS
        if self.tts_service:
            await self.tts_service.cancel()

        # Cancel LLM
        if self.llm_service and hasattr(self.llm_service, 'cancel'):
            self.llm_service.cancel()

        # Interrupt conversation only if there is no active turn task.
        # When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks.
        if not (self._current_turn_task and not self._current_turn_task.done()):
            await self.conversation.interrupt()

        # Reset for new user turn
        await self.conversation.start_user_turn()
        self._audio_buffer = b""
        self.eou_detector.reset()

    async def _stop_current_speech(self) -> None:
        """Stop any current speech task."""
        if self._current_turn_task and not self._current_turn_task.done():
            self._interrupt_event.set()
            self._current_turn_task.cancel()
            try:
                await self._current_turn_task
            except asyncio.CancelledError:
                pass

        # Ensure underlying services are cancelled to avoid leaking work/audio
        if self.tts_service:
            await self.tts_service.cancel()
        if self.llm_service and hasattr(self.llm_service, 'cancel'):
            self.llm_service.cancel()

        self._is_bot_speaking = False
        self._interrupt_event.clear()

    async def cleanup(self) -> None:
        """Cleanup pipeline resources."""
        logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")

        self._running = False
        await self._stop_current_speech()

        # Disconnect services
        if self.llm_service:
            await self.llm_service.disconnect()
        if self.tts_service:
            await self.tts_service.disconnect()
        if self.asr_service:
            await self.asr_service.disconnect()

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    @property
    def is_speaking(self) -> bool:
        """Check if bot is currently speaking."""
        return self._is_bot_speaking

    @property
    def state(self) -> ConversationState:
        """Get current conversation state."""
        return self.conversation.state
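Wiring the pipeline end to end needs a transport and, optionally, injected services. A local smoke-test sketch using the mock services imported above; it assumes `MockLLMService`/`MockTTSService` take no required arguments and that the Silero VAD model at `settings.vad_model_path` is present, since the constructor always builds the VAD:

```python
import asyncio

from core.duplex_pipeline import DuplexPipeline
from core.transports import BaseTransport
from services.llm import MockLLMService
from services.tts import MockTTSService


class PrintTransport(BaseTransport):
    """Hypothetical transport that just logs, for local experiments."""

    async def send_event(self, event: dict) -> None:
        print("event:", event.get("event"))

    async def send_audio(self, pcm_bytes: bytes) -> None:
        print(f"audio: {len(pcm_bytes)} bytes")

    async def close(self) -> None:
        pass


async def demo() -> None:
    pipeline = DuplexPipeline(
        transport=PrintTransport(),
        session_id="demo-session",
        llm_service=MockLLMService(),  # avoids the OpenAI key requirement
        tts_service=MockTTSService(),  # avoids real synthesis
        greeting="Hello!",
    )
    await pipeline.start()  # connects services and speaks the greeting
    await pipeline.process_text("Tell me a joke")  # bypasses VAD/ASR entirely
    await asyncio.sleep(2)  # let the streamed turn finish
    await pipeline.cleanup()


asyncio.run(demo())
```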
engine/core/events.py (new file, 134 lines)
@@ -0,0 +1,134 @@
"""Event bus for pub/sub communication between components."""

import asyncio
from typing import Callable, Dict, List, Any, Optional
from collections import defaultdict
from loguru import logger


class EventBus:
    """
    Async event bus for pub/sub communication.

    Similar to the original Rust implementation's broadcast channel.
    Components can subscribe to specific event types and receive events asynchronously.
    """

    def __init__(self):
        """Initialize the event bus."""
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        self._lock = asyncio.Lock()
        self._running = True

    def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Subscribe to an event type.

        Args:
            event_type: Type of event to subscribe to (e.g., "speaking", "silence")
            callback: Sync or async callback function that receives event data
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
            return

        self._subscribers[event_type].append(callback)
        logger.debug(f"Subscribed to event type: {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Unsubscribe from an event type.

        Args:
            event_type: Type of event to unsubscribe from
            callback: Callback function to remove
        """
        if callback in self._subscribers[event_type]:
            self._subscribers[event_type].remove(callback)
            logger.debug(f"Unsubscribed from event type: {event_type}")

    async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribers.

        Args:
            event_type: Type of event to publish
            event_data: Event data to send to subscribers
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
            return

        # Get subscribers for this event type
        subscribers = self._subscribers.get(event_type, [])

        if not subscribers:
            logger.debug(f"No subscribers for event type: {event_type}")
            return

        # Notify all subscribers concurrently
        tasks = []
        for callback in subscribers:
            try:
                # Create task for each subscriber
                task = asyncio.create_task(self._call_subscriber(callback, event_data))
                tasks.append(task)
            except Exception as e:
                logger.error(f"Error creating task for subscriber: {e}")

        # Wait for all subscribers to complete
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")

    async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
        """
        Call a subscriber callback with error handling.

        Args:
            callback: Subscriber callback function
            event_data: Event data to pass to callback
        """
        try:
            # Check if callback is a coroutine function
            if asyncio.iscoroutinefunction(callback):
                await callback(event_data)
            else:
                callback(event_data)
        except Exception as e:
            logger.error(f"Error in subscriber callback: {e}", exc_info=True)

    async def close(self) -> None:
        """Close the event bus and stop processing events."""
        self._running = False
        self._subscribers.clear()
        logger.info("Event bus closed")

    @property
    def is_running(self) -> bool:
        """Check if the event bus is running."""
        return self._running


# Global event bus instance
_event_bus: Optional[EventBus] = None


def get_event_bus() -> EventBus:
    """
    Get the global event bus instance.

    Returns:
        EventBus instance
    """
    global _event_bus
    if _event_bus is None:
        _event_bus = EventBus()
    return _event_bus


def reset_event_bus() -> None:
    """Reset the global event bus (mainly for testing)."""
    global _event_bus
    _event_bus = None
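A small sketch of the pub/sub flow defined above; note that `publish()` accepts both plain and coroutine callbacks and runs them concurrently as tasks:

```python
import asyncio

from core.events import get_event_bus


async def demo() -> None:
    bus = get_event_bus()

    async def on_speaking(event: dict) -> None:
        print("async subscriber:", event)

    def on_speaking_sync(event: dict) -> None:
        print("sync subscriber:", event)

    bus.subscribe("speaking", on_speaking)
    bus.subscribe("speaking", on_speaking_sync)

    # Both callbacks are scheduled as tasks; publish() gathers them all.
    await bus.publish("speaking", {"trackId": "demo", "probability": 0.93})

    bus.unsubscribe("speaking", on_speaking_sync)
    await bus.close()


asyncio.run(demo())
```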
engine/core/session.py (new file, 285 lines)
@@ -0,0 +1,285 @@
"""Session management for active calls."""

import uuid
import json
import time
from typing import Optional, Dict, Any
from loguru import logger

from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand
from app.config import settings


class Session:
    """
    Manages a single call session.

    Handles command routing, audio processing, and session lifecycle.
    Uses full duplex voice conversation pipeline.
    """

    def __init__(self, session_id: str, transport: BaseTransport, use_duplex: Optional[bool] = None):
        """
        Initialize session.

        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
            use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
        """
        self.id = session_id
        self.transport = transport
        self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled

        self.pipeline = DuplexPipeline(
            transport=transport,
            session_id=session_id,
            system_prompt=settings.duplex_system_prompt,
            greeting=settings.duplex_greeting
        )

        # Session state
        self.created_at = None
        self.state = "created"  # created, invited, accepted, rejected, ringing, hungup
        self._pipeline_started = False

        # Track IDs
        self.current_track_id: Optional[str] = str(uuid.uuid4())

        logger.info(f"Session {self.id} created (duplex={self.use_duplex})")

    async def handle_text(self, text_data: str) -> None:
        """
        Handle incoming text data (JSON commands).

        Args:
            text_data: JSON text data
        """
        try:
            data = json.loads(text_data)
            command = parse_command(data)
            command_type = command.command

            logger.info(f"Session {self.id} received command: {command_type}")

            # Route command to appropriate handler
            if command_type == "invite":
                await self._handle_invite(data)

            elif command_type == "accept":
                await self._handle_accept(data)

            elif command_type == "reject":
                await self._handle_reject(data)

            elif command_type == "ringing":
                await self._handle_ringing(data)

            elif command_type == "tts":
                await self._handle_tts(command)

            elif command_type == "play":
                await self._handle_play(data)

            elif command_type == "interrupt":
                await self._handle_interrupt(command)

            elif command_type == "pause":
                await self._handle_pause()

            elif command_type == "resume":
                await self._handle_resume()

            elif command_type == "hangup":
                await self._handle_hangup(command)

            elif command_type == "history":
                await self._handle_history(data)

            elif command_type == "chat":
                await self._handle_chat(command)

            else:
                logger.warning(f"Session {self.id} unknown command: {command_type}")

        except json.JSONDecodeError as e:
            logger.error(f"Session {self.id} JSON decode error: {e}")
            await self._send_error("client", f"Invalid JSON: {e}")

        except ValueError as e:
            logger.error(f"Session {self.id} command parse error: {e}")
            await self._send_error("client", f"Invalid command: {e}")

        except Exception as e:
            logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
            await self._send_error("server", f"Internal error: {e}")

    async def handle_audio(self, audio_bytes: bytes) -> None:
        """
        Handle incoming audio data.

        Args:
            audio_bytes: PCM audio data
        """
        try:
            await self.pipeline.process_audio(audio_bytes)
        except Exception as e:
            logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)

    async def _handle_invite(self, data: Dict[str, Any]) -> None:
        """Handle invite command."""
        self.state = "invited"
        option = data.get("option", {})

        # Send answer event
        await self.transport.send_event({
            "event": "answer",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms()
        })

        # Start duplex pipeline
        if not self._pipeline_started:
            try:
                await self.pipeline.start()
                self._pipeline_started = True
                logger.info(f"Session {self.id} duplex pipeline started")
            except Exception as e:
                logger.error(f"Failed to start duplex pipeline: {e}")

        logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}")

    async def _handle_accept(self, data: Dict[str, Any]) -> None:
        """Handle accept command."""
        self.state = "accepted"
        logger.info(f"Session {self.id} accepted")

    async def _handle_reject(self, data: Dict[str, Any]) -> None:
        """Handle reject command."""
        self.state = "rejected"
        reason = data.get("reason", "Rejected")
        logger.info(f"Session {self.id} rejected: {reason}")

    async def _handle_ringing(self, data: Dict[str, Any]) -> None:
        """Handle ringing command."""
        self.state = "ringing"
        logger.info(f"Session {self.id} ringing")

    async def _handle_tts(self, command: TTSCommand) -> None:
        """Handle TTS command."""
        logger.info(f"Session {self.id} TTS: {command.text[:50]}...")

        # Send track start event
        await self.transport.send_event({
            "event": "trackStart",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "playId": command.play_id
        })

        # TODO: Implement actual TTS synthesis
        # For now, just send track end event
        await self.transport.send_event({
            "event": "trackEnd",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "duration": 1000,
            "ssrc": 0,
            "playId": command.play_id
        })

    async def _handle_play(self, data: Dict[str, Any]) -> None:
        """Handle play command."""
        url = data.get("url", "")
        logger.info(f"Session {self.id} play: {url}")

        # Send track start event
        await self.transport.send_event({
            "event": "trackStart",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "playId": url
        })

        # TODO: Implement actual audio playback
        # For now, just send track end event
        await self.transport.send_event({
            "event": "trackEnd",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "duration": 1000,
            "ssrc": 0,
            "playId": url
        })

    async def _handle_interrupt(self, command: InterruptCommand) -> None:
        """Handle interrupt command."""
        if command.graceful:
            logger.info(f"Session {self.id} graceful interrupt")
        else:
            logger.info(f"Session {self.id} immediate interrupt")
        await self.pipeline.interrupt()

    async def _handle_pause(self) -> None:
        """Handle pause command."""
        logger.info(f"Session {self.id} paused")

    async def _handle_resume(self) -> None:
        """Handle resume command."""
        logger.info(f"Session {self.id} resumed")

    async def _handle_hangup(self, command: HangupCommand) -> None:
        """Handle hangup command."""
        self.state = "hungup"
        reason = command.reason or "User requested"
        logger.info(f"Session {self.id} hung up: {reason}")

        # Send hangup event
        await self.transport.send_event({
            "event": "hangup",
            "timestamp": self._get_timestamp_ms(),
            "reason": reason,
            "initiator": command.initiator or "user"
        })

        # Close transport
        await self.transport.close()

    async def _handle_history(self, data: Dict[str, Any]) -> None:
        """Handle history command."""
        speaker = data.get("speaker", "unknown")
        text = data.get("text", "")
        logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...")

    async def _handle_chat(self, command: ChatCommand) -> None:
        """Handle chat command."""
        logger.info(f"Session {self.id} chat: {command.text[:50]}...")
        await self.pipeline.process_text(command.text)

    async def _send_error(self, sender: str, error_message: str) -> None:
        """
        Send error event to client.

        Args:
            sender: Component that generated the error
            error_message: Error message
        """
        await self.transport.send_event({
            "event": "error",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "sender": sender,
            "error": error_message
        })

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    async def cleanup(self) -> None:
        """Cleanup session resources."""
        logger.info(f"Session {self.id} cleaning up")
        await self.pipeline.cleanup()
        await self.transport.close()
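For reference, the command flow a WebSocket handler would drive. The exact payload schemas live in `models/commands.py` (not part of this commit), so the fields below are inferred from the handlers above:

```python
import asyncio
import json

from core.session import Session
from core.transports import BaseTransport


class NullTransport(BaseTransport):
    """Hypothetical no-op transport for exercising command routing."""

    async def send_event(self, event: dict) -> None:
        print("->", event["event"])

    async def send_audio(self, pcm_bytes: bytes) -> None:
        pass

    async def close(self) -> None:
        print("-> closed")


async def demo() -> None:
    session = Session("demo", NullTransport())
    # "invite" answers and starts the duplex pipeline; "chat" feeds text
    # straight into it; "hangup" emits a hangup event and closes the transport.
    await session.handle_text(json.dumps({"command": "invite", "option": {"codec": "pcm"}}))
    await session.handle_text(json.dumps({"command": "chat", "text": "Hello"}))
    await session.handle_text(json.dumps({"command": "hangup", "reason": "done"}))
    await session.cleanup()


asyncio.run(demo())
```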
engine/core/transports.py (new file, 207 lines)
@@ -0,0 +1,207 @@
"""Transport layer for WebSocket and WebRTC communication."""

import asyncio
import json
from abc import ABC, abstractmethod
from typing import Optional
from fastapi import WebSocket
from loguru import logger

# Try to import aiortc (optional for WebRTC functionality)
try:
    from aiortc import RTCPeerConnection
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    RTCPeerConnection = None  # Type hint placeholder


class BaseTransport(ABC):
    """
    Abstract base class for transports.

    All transports must implement send_event and send_audio methods.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event to the client.

        Args:
            event: Event data as dictionary
        """
        pass

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the transport and cleanup resources."""
        pass


class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Handles mixed text/binary frames over WebSocket connection.
    Uses asyncio.Lock to prevent frame interleaving.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        async with self.lock:
            try:
                await self.ws.send_text(json.dumps(event))
                logger.debug(f"Sent event: {event.get('event', 'unknown')}")
            except Exception as e:
                logger.error(f"Error sending event: {e}")
                self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        async with self.lock:
            try:
                await self.ws.send_bytes(pcm_bytes)
            except Exception as e:
                logger.error(f"Error sending audio: {e}")
                self._closed = True

    async def close(self) -> None:
        """Close the WebSocket connection."""
        self._closed = True
        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed


class WebRtcTransport(BaseTransport):
    """
    WebRTC transport for WebRTC audio streaming.

    Uses WebSocket for signaling and RTCPeerConnection for media.
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")

        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        try:
            await self.ws.send_text(json.dumps(event))
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track.

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")

        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """Close the WebRTC connection."""
        self._closed = True
        try:
            await self.pc.close()
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")
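The `send_audio` placeholder in `WebRtcTransport` ultimately needs a queue-backed track that aiortc can pull frames from; the commented-out TODO already names an `add_frame` method. A sketch of that missing piece, assuming `numpy` and `PyAV` are installed alongside `aiortc` (this follows the standard aiortc custom-track pattern and is not code from this commit):

```python
import asyncio
import fractions

import numpy as np
from aiortc.mediastreams import MediaStreamTrack
from av import AudioFrame

SAMPLE_RATE = 16000


class PcmOutboundTrack(MediaStreamTrack):
    """Queue-backed outbound audio track: send_audio() would enqueue PCM
    via add_frame(), and aiortc pulls frames through recv()."""

    kind = "audio"

    def __init__(self) -> None:
        super().__init__()
        self._queue: asyncio.Queue = asyncio.Queue()
        self._timestamp = 0  # running sample count used as pts

    async def add_frame(self, pcm_bytes: bytes) -> None:
        await self._queue.put(pcm_bytes)

    async def recv(self) -> AudioFrame:
        pcm = await self._queue.get()
        # 16-bit mono PCM -> packed s16 frame; shape (1, n_samples).
        samples = np.frombuffer(pcm, dtype=np.int16).reshape(1, -1)
        frame = AudioFrame.from_ndarray(samples, format="s16", layout="mono")
        frame.sample_rate = SAMPLE_RATE
        frame.pts = self._timestamp
        frame.time_base = fractions.Fraction(1, SAMPLE_RATE)
        self._timestamp += samples.shape[1]
        return frame
```

`WebRtcTransport.set_outbound_track()` would receive an instance of this, and adding it to the peer connection during signaling (`pc.addTrack(track)`) is what starts aiortc pulling from `recv()`.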