Clean up duplex pipeline

Xin Wang
2026-02-09 16:00:32 +08:00
parent cb5c08d84d
commit ed044bd8ad


@@ -13,37 +13,38 @@ event-driven design.
import asyncio
import time
from typing import Any, Dict, Optional

import numpy as np
from loguru import logger

from app.config import settings
from core.conversation import ConversationManager, ConversationState
from core.events import get_event_bus
from core.transports import BaseTransport
from models.ws_v1 import ev
from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor
from services.asr import BufferedASRService
from services.base import BaseASRService, BaseLLMService, BaseTTSService
from services.llm import MockLLMService, OpenAILLMService
from services.siliconflow_asr import SiliconFlowASRService
from services.siliconflow_tts import SiliconFlowTTSService
from services.tts import EdgeTTSService, MockTTSService


class DuplexPipeline:
    """
    Full duplex audio pipeline for AI voice conversation.

    Handles bidirectional audio flow with:
    - User speech detection and transcription
    - AI response generation
    - Text-to-speech synthesis
    - Barge-in (interruption) support

    Architecture (inspired by pipecat):
        User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
                     Barge-in Detection → Interrupt
    """
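
A minimal sketch of how this class is wired up and driven, assuming a caller-supplied transport and keyword arguments that mirror the attributes set in __init__; none of this wiring is part of the commit:

# Hypothetical wiring sketch. The mock services are this module's own
# no-key fallbacks, but passing them explicitly here is an assumption.
async def run_demo(transport, pcm_chunks):
    pipeline = DuplexPipeline(
        transport=transport,
        session_id="demo-session",
        llm_service=MockLLMService(),
        tts_service=MockTTSService(),
    )
    await pipeline.start()                   # connects services, speaks greeting
    for chunk in pcm_chunks:                 # 16-bit mono 16 kHz PCM frames
        await pipeline.process_audio(chunk)
    await pipeline.cleanup()
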
@@ -52,7 +53,7 @@ class DuplexPipeline:
    _SENTENCE_END_CHARS = frozenset({"。", "！", "？", ".", "!", "?", "\n"})
    _SENTENCE_TRAILING_CHARS = frozenset({"。", "！", "？", ".", "!", "?", "…", "~", "～", "\n"})
    _SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", "）", "】", "」", "』", "〉"})
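
These character sets drive the sentence-by-sentence TTS streaming below. A simplified, self-contained sketch of the boundary check they enable (the real _extract_tts_sentence also handles trailing closers and pending punctuation via the other two sets):

# Simplified sketch; the production logic below is richer.
SENTENCE_END_CHARS = {"。", "！", "？", ".", "!", "?", "\n"}

def split_first_sentence(buffer: str):
    """Return (sentence, rest) at the first end-of-sentence char, else None."""
    for i, ch in enumerate(buffer):
        if ch in SENTENCE_END_CHARS:
            return buffer[: i + 1], buffer[i + 1 :]
    return None

# e.g. split_first_sentence("你好。今天怎么样") -> ("你好。", "今天怎么样")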

    def __init__(
        self,
        transport: BaseTransport,
@@ -65,7 +66,7 @@ class DuplexPipeline:
    ):
        """
        Initialize duplex pipeline.

        Args:
            transport: Transport for sending audio/events
            session_id: Session identifier
@@ -78,7 +79,7 @@ class DuplexPipeline:
        self.transport = transport
        self.session_id = session_id
        self.event_bus = get_event_bus()

        # Initialize VAD
        self.vad_model = SileroVAD(
            model_path=settings.vad_model_path,
@@ -88,27 +89,27 @@ class DuplexPipeline:
            vad_model=self.vad_model,
            threshold=settings.vad_threshold
        )

        # Initialize EOU detector
        self.eou_detector = EouDetector(
            silence_threshold_ms=settings.vad_eou_threshold_ms,
            min_speech_duration_ms=settings.vad_min_speech_duration_ms
        )

        # Initialize services
        self.llm_service = llm_service
        self.tts_service = tts_service
        self.asr_service = asr_service  # Will be initialized in start()

        # Track last sent transcript to avoid duplicates
        self._last_sent_transcript = ""

        # Conversation manager
        self.conversation = ConversationManager(
            system_prompt=system_prompt,
            greeting=greeting
        )

        # State
        self._running = True
        self._is_bot_speaking = False
@@ -118,14 +119,14 @@ class DuplexPipeline:
        self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds)
        self._last_vad_status: str = "Silence"
        self._process_lock = asyncio.Lock()

        # Interruption handling
        self._interrupt_event = asyncio.Event()

        # Latency tracking - TTFB (Time to First Byte)
        self._turn_start_time: Optional[float] = None
        self._first_audio_sent: bool = False

        # Barge-in filtering - require minimum speech duration to interrupt
        self._barge_in_speech_start_time: Optional[float] = None
        self._barge_in_min_duration_ms: int = getattr(settings, 'barge_in_min_duration_ms', 50)
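
The pipeline treats overlapping user speech as a barge-in only once it has persisted past this threshold. The same duration gate as a self-contained sketch (hypothetical stand-alone class; the real bookkeeping lives inline in process_audio and also counts frames):

import time
from typing import Optional

class BargeInGate:
    """Duration gate: speech must persist before it counts as an interrupt."""

    def __init__(self, min_duration_ms: int = 50):
        self.min_duration_ms = min_duration_ms
        self._speech_start: Optional[float] = None

    def update(self, is_speech: bool) -> bool:
        if not is_speech:
            self._speech_start = None  # silence resets the gate
            return False
        if self._speech_start is None:
            self._speech_start = time.time()
        return (time.time() - self._speech_start) * 1000 >= self.min_duration_ms
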
@@ -139,7 +140,7 @@ class DuplexPipeline:
        self._runtime_tts: Dict[str, Any] = {}
        self._runtime_system_prompt: Optional[str] = None
        self._runtime_greeting: Optional[str] = None

        logger.info(f"DuplexPipeline initialized for session {session_id}")

    def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
@@ -176,7 +177,7 @@ class DuplexPipeline:
            self._runtime_asr = services["asr"]
        if isinstance(services.get("tts"), dict):
            self._runtime_tts = services["tts"]

    async def start(self) -> None:
        """Start the pipeline and connect services."""
        try:
@@ -196,9 +197,9 @@ class DuplexPipeline:
            else:
                logger.warning("No OpenAI API key - using mock LLM")
                self.llm_service = MockLLMService()

            await self.llm_service.connect()

            # Connect TTS service
            if not self.tts_service:
                tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
@@ -231,7 +232,7 @@ class DuplexPipeline:
                    sample_rate=settings.sample_rate
                )
            await self.tts_service.connect()

            # Connect ASR service
            if not self.asr_service:
                asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
@@ -255,41 +256,41 @@ class DuplexPipeline:
                    sample_rate=settings.sample_rate
                )
                logger.info("Using Buffered ASR service (no real transcription)")
            await self.asr_service.connect()

            logger.info("DuplexPipeline services connected")

            # Speak greeting if configured
            if self.conversation.greeting:
                await self._speak(self.conversation.greeting)

        except Exception as e:
            logger.error(f"Failed to start pipeline: {e}")
            raise

    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio chunk.

        This is the main entry point for audio from the user.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if not self._running:
            return

        try:
            async with self._process_lock:
                # 1. Process through VAD
                vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)

                vad_status = "Silence"
                if vad_result:
                    event_type, probability = vad_result
                    vad_status = "Speech" if event_type == "speaking" else "Silence"

                    # Emit VAD event
                    await self.event_bus.publish(event_type, {
                        "trackId": self.session_id,
@@ -305,20 +306,20 @@ class DuplexPipeline:
                else:
                    # No state change - keep previous status
                    vad_status = self._last_vad_status

                # Update state based on VAD
                if vad_status == "Speech" and self._last_vad_status != "Speech":
                    await self._on_speech_start()

                self._last_vad_status = vad_status

                # 2. Check for barge-in (user speaking while bot speaking)
                # Filter false interruptions by requiring minimum speech duration
                if self._is_bot_speaking:
                    if vad_status == "Speech":
                        # User is speaking while bot is speaking
                        self._barge_in_silence_frames = 0  # Reset silence counter

                        if self._barge_in_speech_start_time is None:
                            # Start tracking speech duration
                            self._barge_in_speech_start_time = time.time()
@@ -342,7 +343,7 @@ class DuplexPipeline:
                        self._barge_in_speech_start_time = None
                        self._barge_in_speech_frames = 0
                        self._barge_in_silence_frames = 0

                # 3. Buffer audio for ASR
                if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
                    self._audio_buffer += pcm_bytes
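
Audio arrives as raw 16-bit mono 16 kHz PCM and is buffered as bytes. For a float-based VAD or ASR front end, the conversion looks like this (sketch; SileroVAD's actual input format is not shown in this diff):

import numpy as np

def pcm16_to_float32(pcm_bytes: bytes) -> np.ndarray:
    """Convert 16-bit little-endian mono PCM to float32 in [-1.0, 1.0]."""
    samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    return samples.astype(np.float32) / 32768.0

# A 20 ms chunk at 16 kHz is 320 samples, i.e. 640 bytes:
# pcm16_to_float32(b"\x00\x00" * 320).shape == (320,)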
@@ -350,48 +351,48 @@ class DuplexPipeline:
                        # Keep only the most recent audio to cap memory usage
                        self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:]
                    await self.asr_service.send_audio(pcm_bytes)

                    # For SiliconFlow ASR, trigger interim transcription periodically
                    # The service handles timing internally via start_interim_transcription()

                # 4. Check for End of Utterance - this triggers LLM response
                if self.eou_detector.process(vad_status):
                    await self._on_end_of_utterance()

        except Exception as e:
            logger.error(f"Pipeline audio processing error: {e}", exc_info=True)

    async def process_text(self, text: str) -> None:
        """
        Process text input (chat command).

        Allows direct text input to bypass ASR.

        Args:
            text: User text input
        """
        if not self._running:
            return

        logger.info(f"Processing text input: {text[:50]}...")

        # Cancel any current speaking
        await self._stop_current_speech()

        # Start new turn
        await self.conversation.end_user_turn(text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(text))
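
process_text reuses the same single-active-turn discipline as the audio path: stop the old turn, then schedule the new one. The pattern in isolation (hypothetical stand-alone helper; the pipeline's _stop_current_speech additionally cancels the TTS/LLM services):

import asyncio
from typing import Optional

async def replace_turn(current: Optional[asyncio.Task], coro) -> asyncio.Task:
    """Cancel the previous turn task, await its teardown, start the next."""
    if current and not current.done():
        current.cancel()
        try:
            await current
        except asyncio.CancelledError:
            pass
    return asyncio.create_task(coro)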

    async def interrupt(self) -> None:
        """Interrupt current bot speech (manual interrupt command)."""
        await self._handle_barge_in()

    async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
        """
        Callback for ASR transcription results.

        Streams transcription to client for display.

        Args:
            text: Transcribed text
            is_final: Whether this is the final transcription
@@ -399,9 +400,9 @@ class DuplexPipeline:
        # Avoid sending duplicate transcripts
        if text == self._last_sent_transcript and not is_final:
            return

        self._last_sent_transcript = text

        # Send transcript event to client
        await self.transport.send_event({
            **ev(
@@ -410,9 +411,9 @@ class DuplexPipeline:
                text=text,
            )
        })

        logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")

    async def _on_speech_start(self) -> None:
        """Handle user starting to speak."""
        if self.conversation.state == ConversationState.IDLE:
@@ -420,34 +421,34 @@ class DuplexPipeline:
self._audio_buffer = b"" self._audio_buffer = b""
self._last_sent_transcript = "" self._last_sent_transcript = ""
self.eou_detector.reset() self.eou_detector.reset()
# Clear ASR buffer and start interim transcriptions # Clear ASR buffer and start interim transcriptions
if hasattr(self.asr_service, 'clear_buffer'): if hasattr(self.asr_service, 'clear_buffer'):
self.asr_service.clear_buffer() self.asr_service.clear_buffer()
if hasattr(self.asr_service, 'start_interim_transcription'): if hasattr(self.asr_service, 'start_interim_transcription'):
await self.asr_service.start_interim_transcription() await self.asr_service.start_interim_transcription()
logger.debug("User speech started") logger.debug("User speech started")
async def _on_end_of_utterance(self) -> None: async def _on_end_of_utterance(self) -> None:
"""Handle end of user utterance.""" """Handle end of user utterance."""
if self.conversation.state != ConversationState.LISTENING: if self.conversation.state != ConversationState.LISTENING:
return return
# Stop interim transcriptions # Stop interim transcriptions
if hasattr(self.asr_service, 'stop_interim_transcription'): if hasattr(self.asr_service, 'stop_interim_transcription'):
await self.asr_service.stop_interim_transcription() await self.asr_service.stop_interim_transcription()
# Get final transcription from ASR service # Get final transcription from ASR service
user_text = "" user_text = ""
if hasattr(self.asr_service, 'get_final_transcription'): if hasattr(self.asr_service, 'get_final_transcription'):
# SiliconFlow ASR - get final transcription # SiliconFlow ASR - get final transcription
user_text = await self.asr_service.get_final_transcription() user_text = await self.asr_service.get_final_transcription()
elif hasattr(self.asr_service, 'get_and_clear_text'): elif hasattr(self.asr_service, 'get_and_clear_text'):
# Buffered ASR - get accumulated text # Buffered ASR - get accumulated text
user_text = self.asr_service.get_and_clear_text() user_text = self.asr_service.get_and_clear_text()
# Skip if no meaningful text # Skip if no meaningful text
if not user_text or not user_text.strip(): if not user_text or not user_text.strip():
logger.debug("EOU detected but no transcription - skipping") logger.debug("EOU detected but no transcription - skipping")
@@ -457,9 +458,9 @@ class DuplexPipeline:
            # Return to idle; don't force LISTENING which causes buffering on silence
            await self.conversation.set_state(ConversationState.IDLE)
            return

        logger.info(f"EOU detected - user said: {user_text[:100]}...")

        # Send final transcription to client
        await self.transport.send_event({
            **ev(
@@ -468,23 +469,23 @@ class DuplexPipeline:
                text=user_text,
            )
        })

        # Clear buffers
        self._audio_buffer = b""
        self._last_sent_transcript = ""

        # Process the turn - trigger LLM response
        # Cancel any existing turn to avoid overlapping assistant responses
        await self._stop_current_speech()
        await self.conversation.end_user_turn(user_text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))

    async def _handle_turn(self, user_text: str) -> None:
        """
        Handle a complete conversation turn.

        Uses sentence-by-sentence streaming TTS for lower latency.

        Args:
            user_text: User's transcribed text
        """
@@ -492,30 +493,30 @@ class DuplexPipeline:
            # Start latency tracking
            self._turn_start_time = time.time()
            self._first_audio_sent = False

            # Get AI response (streaming)
            messages = self.conversation.get_messages()
            full_response = ""

            await self.conversation.start_assistant_turn()
            self._is_bot_speaking = True
            self._interrupt_event.clear()

            # Sentence buffer for streaming TTS
            sentence_buffer = ""
            pending_punctuation = ""
            first_audio_sent = False
            spoken_sentence_count = 0

            # Stream LLM response and TTS sentence by sentence
            async for text_chunk in self.llm_service.generate_stream(messages):
                if self._interrupt_event.is_set():
                    break

                full_response += text_chunk
                sentence_buffer += text_chunk
                await self.conversation.update_assistant_text(text_chunk)

                # Send LLM response streaming event to client
                await self.transport.send_event({
                    **ev(
@@ -524,7 +525,7 @@ class DuplexPipeline:
                        text=text_chunk,
                    )
                })

                # Check for sentence completion - synthesize immediately for low latency
                while True:
                    split_result = self._extract_tts_sentence(sentence_buffer, force=False)
@@ -561,7 +562,7 @@ class DuplexPipeline:
                        fade_out_ms=8,
                    )
                    spoken_sentence_count += 1

            # Send final LLM response event
            if full_response and not self._interrupt_event.is_set():
                await self.transport.send_event({
@@ -571,7 +572,7 @@ class DuplexPipeline:
                        text=full_response,
                    )
                })

            # Speak any remaining text
            remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
            if remaining_text and self._has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
@@ -588,7 +589,7 @@ class DuplexPipeline:
                    fade_in_ms=0,
                    fade_out_ms=8,
                )

            # Send track end
            if first_audio_sent:
                await self.transport.send_event({
@@ -597,12 +598,12 @@ class DuplexPipeline:
                        trackId=self.session_id,
                    )
                })

            # End assistant turn
            await self.conversation.end_assistant_turn(
                was_interrupted=self._interrupt_event.is_set()
            )

        except asyncio.CancelledError:
            logger.info("Turn handling cancelled")
            await self.conversation.end_assistant_turn(was_interrupted=True)
@@ -615,7 +616,7 @@ class DuplexPipeline:
            self._barge_in_speech_start_time = None
            self._barge_in_speech_frames = 0
            self._barge_in_silence_frames = 0
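
_handle_turn overlaps LLM generation with TTS by flushing each completed sentence as soon as it appears in the buffer. The core of that loop, condensed (hypothetical llm_stream/speak/split callables; the real code also handles punctuation carry-over, interrupts, and client events):

async def stream_sentences(llm_stream, speak, split_first_sentence):
    buffer = ""
    async for chunk in llm_stream:
        buffer += chunk
        while (split := split_first_sentence(buffer)) is not None:
            sentence, buffer = split
            await speak(sentence)   # TTS starts before the LLM finishes
    if buffer.strip():
        await speak(buffer)         # flush the tail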

    def _extract_tts_sentence(self, text_buffer: str, force: bool = False) -> Optional[tuple[str, str]]:
        """
        Extract one TTS sentence from the buffer.
@@ -683,7 +684,7 @@ class DuplexPipeline:
    async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
        """
        Synthesize and send a single sentence.

        Args:
            text: Sentence to speak
            fade_in_ms: Fade-in duration for sentence start chunks
@@ -691,7 +692,7 @@ class DuplexPipeline:
""" """
if not text.strip() or self._interrupt_event.is_set(): if not text.strip() or self._interrupt_event.is_set():
return return
try: try:
is_first_chunk = True is_first_chunk = True
async for chunk in self.tts_service.synthesize_stream(text): async for chunk in self.tts_service.synthesize_stream(text):
@@ -699,13 +700,13 @@ class DuplexPipeline:
                if self._interrupt_event.is_set():
                    logger.debug("TTS sentence interrupted")
                    break

                # Track and log first audio packet latency (TTFB)
                if not self._first_audio_sent and self._turn_start_time:
                    ttfb_ms = (time.time() - self._turn_start_time) * 1000
                    self._first_audio_sent = True
                    logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                    # Send TTFB event to client
                    await self.transport.send_event({
                        **ev(
@@ -714,7 +715,7 @@ class DuplexPipeline:
                            latencyMs=round(ttfb_ms),
                        )
                    })

                # Double-check interrupt right before sending audio
                if self._interrupt_event.is_set():
                    break
@@ -766,22 +767,22 @@ class DuplexPipeline:
        except Exception:
            # Fallback: never block audio delivery on smoothing failure.
            return pcm_bytes

    async def _speak(self, text: str) -> None:
        """
        Synthesize and send speech.

        Args:
            text: Text to speak
        """
        if not text.strip():
            return

        try:
            # Start latency tracking for greeting
            speak_start_time = time.time()
            first_audio_sent = False

            # Send track start event
            await self.transport.send_event({
                **ev(
@@ -789,21 +790,21 @@ class DuplexPipeline:
                    trackId=self.session_id,
                )
            })

            self._is_bot_speaking = True

            # Stream TTS audio
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._interrupt_event.is_set():
                    logger.info("TTS interrupted by barge-in")
                    break

                # Track and log first audio packet latency (TTFB)
                if not first_audio_sent:
                    ttfb_ms = (time.time() - speak_start_time) * 1000
                    first_audio_sent = True
                    logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                    # Send TTFB event to client
                    await self.transport.send_event({
                        **ev(
@@ -812,13 +813,13 @@ class DuplexPipeline:
                            latencyMs=round(ttfb_ms),
                        )
                    })

                # Send audio to client
                await self.transport.send_audio(chunk.audio)

                # Small delay to prevent flooding
                await asyncio.sleep(0.01)

            # Send track end event
            await self.transport.send_event({
                **ev(
@@ -826,7 +827,7 @@ class DuplexPipeline:
                    trackId=self.session_id,
                )
            })

        except asyncio.CancelledError:
            logger.info("TTS cancelled")
            raise
@@ -834,23 +835,23 @@ class DuplexPipeline:
logger.error(f"TTS error: {e}") logger.error(f"TTS error: {e}")
finally: finally:
self._is_bot_speaking = False self._is_bot_speaking = False
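
Both _speak and _speak_sentence measure TTFB the same way: timestamp the turn start, then log the delta when the first TTS chunk arrives. The pattern in isolation (hypothetical synthesize/send stand-ins):

import time
from typing import Optional

async def stream_with_ttfb(synthesize, send, text: str) -> Optional[float]:
    """Send audio chunks; return time-to-first-byte in ms, or None if no audio."""
    start, ttfb_ms = time.time(), None
    async for chunk in synthesize(text):
        if ttfb_ms is None:
            ttfb_ms = (time.time() - start) * 1000  # first-chunk latency
        await send(chunk)
    return ttfb_ms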

    async def _handle_barge_in(self) -> None:
        """Handle user barge-in (interruption)."""
        if not self._is_bot_speaking:
            return

        logger.info("Barge-in detected - interrupting bot speech")

        # Reset barge-in tracking
        self._barge_in_speech_start_time = None
        self._barge_in_speech_frames = 0
        self._barge_in_silence_frames = 0

        # IMPORTANT: Signal interruption FIRST to stop audio sending
        self._interrupt_event.set()
        self._is_bot_speaking = False

        # Send interrupt event to client IMMEDIATELY
        # This must happen BEFORE canceling services, so client knows to discard in-flight audio
        await self.transport.send_event({
            **ev(
@@ -859,25 +860,25 @@ class DuplexPipeline:
                trackId=self.session_id,
            )
        })

        # Cancel TTS
        if self.tts_service:
            await self.tts_service.cancel()

        # Cancel LLM
        if self.llm_service and hasattr(self.llm_service, 'cancel'):
            self.llm_service.cancel()

        # Interrupt conversation only if there is no active turn task.
        # When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks.
        if not (self._current_turn_task and not self._current_turn_task.done()):
            await self.conversation.interrupt()

        # Reset for new user turn
        await self.conversation.start_user_turn()
        self._audio_buffer = b""
        self.eou_detector.reset()
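
The ordering above is the essence of clean barge-in: set the interrupt flag so the send loops stop, notify the client so it discards in-flight audio, and only then tear down the producers. Condensed (hypothetical notify/cancel callables):

import asyncio

async def barge_in(interrupt: asyncio.Event, notify_client, cancel_tts, cancel_llm):
    interrupt.set()        # 1. stop the audio send loops first
    await notify_client()  # 2. client drops any in-flight audio
    await cancel_tts()     # 3. only now cancel the producers
    cancel_llm()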

    async def _stop_current_speech(self) -> None:
        """Stop any current speech task."""
        if self._current_turn_task and not self._current_turn_task.done():
@@ -893,17 +894,17 @@ class DuplexPipeline:
            await self.tts_service.cancel()
        if self.llm_service and hasattr(self.llm_service, 'cancel'):
            self.llm_service.cancel()

        self._is_bot_speaking = False
        self._interrupt_event.clear()

    async def cleanup(self) -> None:
        """Cleanup pipeline resources."""
        logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")

        self._running = False
        await self._stop_current_speech()

        # Disconnect services
        if self.llm_service:
            await self.llm_service.disconnect()
@@ -911,17 +912,17 @@ class DuplexPipeline:
            await self.tts_service.disconnect()
        if self.asr_service:
            await self.asr_service.disconnect()

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    @property
    def is_speaking(self) -> bool:
        """Check if bot is currently speaking."""
        return self._is_bot_speaking

    @property
    def state(self) -> ConversationState:
        """Get current conversation state."""