882 lines
34 KiB
Python
882 lines
34 KiB
Python
"""Full duplex audio pipeline for AI voice conversation.
|
||
|
||
This module implements the core duplex pipeline that orchestrates:
|
||
- VAD (Voice Activity Detection)
|
||
- EOU (End of Utterance) Detection
|
||
- ASR (Automatic Speech Recognition) - optional
|
||
- LLM (Language Model)
|
||
- TTS (Text-to-Speech)
|
||
|
||
Inspired by pipecat's frame-based architecture and active-call's
|
||
event-driven design.
|
||
"""
|
||
|
||
import asyncio
|
||
import time
|
||
from typing import Any, Dict, Optional
|
||
|
||
import numpy as np
|
||
from loguru import logger
|
||
|
||
from app.config import settings
|
||
from core.conversation import ConversationManager, ConversationState
|
||
from core.events import get_event_bus
|
||
from core.transports import BaseTransport
|
||
from models.ws_v1 import ev
|
||
from processors.eou import EouDetector
|
||
from processors.vad import SileroVAD, VADProcessor
|
||
from services.asr import BufferedASRService
|
||
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
||
from services.llm import MockLLMService, OpenAILLMService
|
||
from services.siliconflow_asr import SiliconFlowASRService
|
||
from services.siliconflow_tts import SiliconFlowTTSService
|
||
from services.streaming_text import extract_tts_sentence, has_spoken_content
|
||
from services.tts import EdgeTTSService, MockTTSService
|
||
|
||
|
||
class DuplexPipeline:
    """
    Full duplex audio pipeline for AI voice conversation.

    Handles bidirectional audio flow with:
    - User speech detection and transcription
    - AI response generation
    - Text-to-speech synthesis
    - Barge-in (interruption) support

    Architecture (inspired by pipecat):

    User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
                              ↓
                    Barge-in Detection → Interrupt
    """

    # Characters that terminate a sentence when splitting streamed LLM text
    # for incremental TTS (covers both ASCII and full-width CJK forms).
    _SENTENCE_END_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "\n"})
    # Trailing punctuation that may be absorbed onto the end of a sentence.
    _SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
    # Closing quotes/brackets that must stay attached to the sentence they close.
    _SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
    # Minimum speakable characters required before taking a split point.
    _MIN_SPLIT_SPOKEN_CHARS = 6
|
||
|
||
def __init__(
|
||
self,
|
||
transport: BaseTransport,
|
||
session_id: str,
|
||
llm_service: Optional[BaseLLMService] = None,
|
||
tts_service: Optional[BaseTTSService] = None,
|
||
asr_service: Optional[BaseASRService] = None,
|
||
system_prompt: Optional[str] = None,
|
||
greeting: Optional[str] = None
|
||
):
|
||
"""
|
||
Initialize duplex pipeline.
|
||
|
||
Args:
|
||
transport: Transport for sending audio/events
|
||
session_id: Session identifier
|
||
llm_service: LLM service (defaults to OpenAI)
|
||
tts_service: TTS service (defaults to EdgeTTS)
|
||
asr_service: ASR service (optional)
|
||
system_prompt: System prompt for LLM
|
||
greeting: Optional greeting to speak on start
|
||
"""
|
||
self.transport = transport
|
||
self.session_id = session_id
|
||
self.event_bus = get_event_bus()
|
||
|
||
# Initialize VAD
|
||
self.vad_model = SileroVAD(
|
||
model_path=settings.vad_model_path,
|
||
sample_rate=settings.sample_rate
|
||
)
|
||
self.vad_processor = VADProcessor(
|
||
vad_model=self.vad_model,
|
||
threshold=settings.vad_threshold
|
||
)
|
||
|
||
# Initialize EOU detector
|
||
self.eou_detector = EouDetector(
|
||
silence_threshold_ms=settings.vad_eou_threshold_ms,
|
||
min_speech_duration_ms=settings.vad_min_speech_duration_ms
|
||
)
|
||
|
||
# Initialize services
|
||
self.llm_service = llm_service
|
||
self.tts_service = tts_service
|
||
self.asr_service = asr_service # Will be initialized in start()
|
||
|
||
# Track last sent transcript to avoid duplicates
|
||
self._last_sent_transcript = ""
|
||
|
||
# Conversation manager
|
||
self.conversation = ConversationManager(
|
||
system_prompt=system_prompt,
|
||
greeting=greeting
|
||
)
|
||
|
||
# State
|
||
self._running = True
|
||
self._is_bot_speaking = False
|
||
self._current_turn_task: Optional[asyncio.Task] = None
|
||
self._audio_buffer: bytes = b""
|
||
max_buffer_seconds = settings.max_audio_buffer_seconds if hasattr(settings, "max_audio_buffer_seconds") else 30
|
||
self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds)
|
||
self._last_vad_status: str = "Silence"
|
||
self._process_lock = asyncio.Lock()
|
||
|
||
# Interruption handling
|
||
self._interrupt_event = asyncio.Event()
|
||
|
||
# Latency tracking - TTFB (Time to First Byte)
|
||
self._turn_start_time: Optional[float] = None
|
||
self._first_audio_sent: bool = False
|
||
|
||
# Barge-in filtering - require minimum speech duration to interrupt
|
||
self._barge_in_speech_start_time: Optional[float] = None
|
||
self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms if hasattr(settings, 'barge_in_min_duration_ms') else 50
|
||
self._barge_in_speech_frames: int = 0 # Count speech frames
|
||
self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in
|
||
self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks)
|
||
|
||
# Runtime overrides injected from session.start metadata
|
||
self._runtime_llm: Dict[str, Any] = {}
|
||
self._runtime_asr: Dict[str, Any] = {}
|
||
self._runtime_tts: Dict[str, Any] = {}
|
||
self._runtime_system_prompt: Optional[str] = None
|
||
self._runtime_greeting: Optional[str] = None
|
||
|
||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||
|
||
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
||
"""
|
||
Apply runtime overrides from WS session.start metadata.
|
||
|
||
Expected metadata shape:
|
||
{
|
||
"systemPrompt": "...",
|
||
"greeting": "...",
|
||
"services": {
|
||
"llm": {...},
|
||
"asr": {...},
|
||
"tts": {...}
|
||
}
|
||
}
|
||
"""
|
||
if not metadata:
|
||
return
|
||
|
||
if "systemPrompt" in metadata:
|
||
self._runtime_system_prompt = str(metadata.get("systemPrompt") or "")
|
||
if self._runtime_system_prompt:
|
||
self.conversation.system_prompt = self._runtime_system_prompt
|
||
if "greeting" in metadata:
|
||
self._runtime_greeting = str(metadata.get("greeting") or "")
|
||
self.conversation.greeting = self._runtime_greeting or None
|
||
|
||
services = metadata.get("services") or {}
|
||
if isinstance(services, dict):
|
||
if isinstance(services.get("llm"), dict):
|
||
self._runtime_llm = services["llm"]
|
||
if isinstance(services.get("asr"), dict):
|
||
self._runtime_asr = services["asr"]
|
||
if isinstance(services.get("tts"), dict):
|
||
self._runtime_tts = services["tts"]
|
||
|
||
    async def start(self) -> None:
        """
        Start the pipeline and connect services.

        Resolves concrete LLM/TTS/ASR backends (runtime overrides from
        apply_runtime_overrides() win over static settings), connects them,
        then speaks the configured greeting if any.

        Raises:
            Exception: re-raised if LLM/ASR startup fails; TTS failures fall
                back to MockTTS instead of raising.
        """
        try:
            # Connect LLM service
            if not self.llm_service:
                llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key
                llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url
                llm_model = self._runtime_llm.get("model") or settings.llm_model
                llm_provider = (self._runtime_llm.get("provider") or "openai").lower()

                if llm_provider == "openai" and llm_api_key:
                    self.llm_service = OpenAILLMService(
                        api_key=llm_api_key,
                        base_url=llm_base_url,
                        model=llm_model
                    )
                else:
                    logger.warning("No OpenAI API key - using mock LLM")
                    self.llm_service = MockLLMService()

            await self.llm_service.connect()

            # Connect TTS service
            if not self.tts_service:
                tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
                tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key
                tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
                tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model
                tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)

                if tts_provider == "siliconflow" and tts_api_key:
                    self.tts_service = SiliconFlowTTSService(
                        api_key=tts_api_key,
                        voice=tts_voice,
                        model=tts_model,
                        sample_rate=settings.sample_rate,
                        speed=tts_speed
                    )
                    logger.info("Using SiliconFlow TTS service")
                else:
                    self.tts_service = EdgeTTSService(
                        voice=tts_voice,
                        sample_rate=settings.sample_rate
                    )
                    logger.info("Using Edge TTS service")

            # TTS is non-critical: degrade to MockTTS rather than failing start.
            try:
                await self.tts_service.connect()
            except Exception as e:
                logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS")
                self.tts_service = MockTTSService(
                    sample_rate=settings.sample_rate
                )
                await self.tts_service.connect()

            # Connect ASR service
            if not self.asr_service:
                asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
                asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key
                asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model
                asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
                asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)

                if asr_provider == "siliconflow" and asr_api_key:
                    self.asr_service = SiliconFlowASRService(
                        api_key=asr_api_key,
                        model=asr_model,
                        sample_rate=settings.sample_rate,
                        interim_interval_ms=asr_interim_interval,
                        min_audio_for_interim_ms=asr_min_audio_ms,
                        # Interim/final transcripts stream back through this callback.
                        on_transcript=self._on_transcript_callback
                    )
                    logger.info("Using SiliconFlow ASR service")
                else:
                    self.asr_service = BufferedASRService(
                        sample_rate=settings.sample_rate
                    )
                    logger.info("Using Buffered ASR service (no real transcription)")

            await self.asr_service.connect()

            logger.info("DuplexPipeline services connected")

            # Speak greeting if configured
            if self.conversation.greeting:
                await self._speak(self.conversation.greeting)

        except Exception as e:
            logger.error(f"Failed to start pipeline: {e}")
            raise
|
||
|
||
    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio chunk.

        This is the main entry point for audio from the user. Runs under a
        lock to keep chunk handling ordered: VAD → barge-in check → ASR
        buffering → EOU detection (which triggers the assistant turn).

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if not self._running:
            return

        try:
            async with self._process_lock:
                # 1. Process through VAD
                vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)

                vad_status = "Silence"
                if vad_result:
                    # VAD yields a result only on a state transition.
                    event_type, probability = vad_result
                    vad_status = "Speech" if event_type == "speaking" else "Silence"

                    # Emit VAD event
                    await self.event_bus.publish(event_type, {
                        "trackId": self.session_id,
                        "probability": probability
                    })
                    await self.transport.send_event(
                        ev(
                            "input.speech_started" if event_type == "speaking" else "input.speech_stopped",
                            trackId=self.session_id,
                            probability=probability,
                        )
                    )
                else:
                    # No state change - keep previous status
                    vad_status = self._last_vad_status

                # Update state based on VAD (rising edge starts a user turn)
                if vad_status == "Speech" and self._last_vad_status != "Speech":
                    await self._on_speech_start()

                self._last_vad_status = vad_status

                # 2. Check for barge-in (user speaking while bot speaking)
                # Filter false interruptions by requiring minimum speech duration
                if self._is_bot_speaking:
                    if vad_status == "Speech":
                        # User is speaking while bot is speaking
                        self._barge_in_silence_frames = 0  # Reset silence counter

                        if self._barge_in_speech_start_time is None:
                            # Start tracking speech duration
                            self._barge_in_speech_start_time = time.time()
                            self._barge_in_speech_frames = 1
                            logger.debug("Potential barge-in detected, tracking duration...")
                        else:
                            self._barge_in_speech_frames += 1
                            # Check if speech duration exceeds threshold
                            speech_duration_ms = (time.time() - self._barge_in_speech_start_time) * 1000
                            if speech_duration_ms >= self._barge_in_min_duration_ms:
                                logger.info(f"Barge-in confirmed after {speech_duration_ms:.0f}ms of speech ({self._barge_in_speech_frames} frames)")
                                await self._handle_barge_in()
                    else:
                        # Silence frame during potential barge-in
                        if self._barge_in_speech_start_time is not None:
                            self._barge_in_silence_frames += 1
                            # Allow brief silence gaps (VAD flickering)
                            if self._barge_in_silence_frames > self._barge_in_silence_tolerance:
                                # Too much silence - reset barge-in tracking
                                logger.debug(f"Barge-in cancelled after {self._barge_in_silence_frames} silence frames")
                                self._barge_in_speech_start_time = None
                                self._barge_in_speech_frames = 0
                                self._barge_in_silence_frames = 0

                # 3. Buffer audio for ASR
                if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
                    self._audio_buffer += pcm_bytes
                    if len(self._audio_buffer) > self._max_audio_buffer_bytes:
                        # Keep only the most recent audio to cap memory usage
                        self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:]
                    await self.asr_service.send_audio(pcm_bytes)

                    # For SiliconFlow ASR, trigger interim transcription periodically
                    # The service handles timing internally via start_interim_transcription()

                # 4. Check for End of Utterance - this triggers LLM response
                if self.eou_detector.process(vad_status):
                    await self._on_end_of_utterance()

        except Exception as e:
            logger.error(f"Pipeline audio processing error: {e}", exc_info=True)
|
||
|
||
    async def process_text(self, text: str) -> None:
        """
        Process text input (chat command).

        Allows direct text input to bypass ASR. Any in-flight assistant
        speech is stopped first so the new turn starts cleanly.

        Args:
            text: User text input
        """
        if not self._running:
            return

        logger.info(f"Processing text input: {text[:50]}...")

        # Cancel any current speaking
        await self._stop_current_speech()

        # Start new turn
        await self.conversation.end_user_turn(text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(text))
|
||
|
||
    async def interrupt(self) -> None:
        """Interrupt current bot speech (manual interrupt command).

        Delegates to the same path as VAD-detected barge-in; a no-op when
        the bot is not speaking.
        """
        await self._handle_barge_in()
|
||
|
||
async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
|
||
"""
|
||
Callback for ASR transcription results.
|
||
|
||
Streams transcription to client for display.
|
||
|
||
Args:
|
||
text: Transcribed text
|
||
is_final: Whether this is the final transcription
|
||
"""
|
||
# Avoid sending duplicate transcripts
|
||
if text == self._last_sent_transcript and not is_final:
|
||
return
|
||
|
||
self._last_sent_transcript = text
|
||
|
||
# Send transcript event to client
|
||
await self.transport.send_event({
|
||
**ev(
|
||
"transcript.final" if is_final else "transcript.delta",
|
||
trackId=self.session_id,
|
||
text=text,
|
||
)
|
||
})
|
||
|
||
if not is_final:
|
||
logger.info(f"ASR interim: {text[:100]}")
|
||
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
||
|
||
async def _on_speech_start(self) -> None:
|
||
"""Handle user starting to speak."""
|
||
if self.conversation.state == ConversationState.IDLE:
|
||
await self.conversation.start_user_turn()
|
||
self._audio_buffer = b""
|
||
self._last_sent_transcript = ""
|
||
self.eou_detector.reset()
|
||
|
||
# Clear ASR buffer and start interim transcriptions
|
||
if hasattr(self.asr_service, 'clear_buffer'):
|
||
self.asr_service.clear_buffer()
|
||
if hasattr(self.asr_service, 'start_interim_transcription'):
|
||
await self.asr_service.start_interim_transcription()
|
||
|
||
logger.debug("User speech started")
|
||
|
||
    async def _on_end_of_utterance(self) -> None:
        """Handle end of user utterance.

        Collects the final ASR transcription, emits transcript.final if the
        callback has not already sent it, then cancels any in-flight turn
        and schedules a new assistant turn for the transcribed text.
        """
        if self.conversation.state != ConversationState.LISTENING:
            return

        # Stop interim transcriptions
        if hasattr(self.asr_service, 'stop_interim_transcription'):
            await self.asr_service.stop_interim_transcription()

        # Get final transcription from ASR service
        user_text = ""

        if hasattr(self.asr_service, 'get_final_transcription'):
            # SiliconFlow ASR - get final transcription
            user_text = await self.asr_service.get_final_transcription()
        elif hasattr(self.asr_service, 'get_and_clear_text'):
            # Buffered ASR - get accumulated text
            user_text = self.asr_service.get_and_clear_text()

        # Skip if no meaningful text
        if not user_text or not user_text.strip():
            logger.debug("EOU detected but no transcription - skipping")
            # Reset for next utterance
            self._audio_buffer = b""
            self._last_sent_transcript = ""
            # Return to idle; don't force LISTENING which causes buffering on silence
            await self.conversation.set_state(ConversationState.IDLE)
            return

        logger.info(f"EOU detected - user said: {user_text[:100]}...")

        # For ASR backends that already emitted final via callback,
        # avoid duplicating transcript.final on EOU.
        if user_text != self._last_sent_transcript:
            await self.transport.send_event({
                **ev(
                    "transcript.final",
                    trackId=self.session_id,
                    text=user_text,
                )
            })

        # Clear buffers
        self._audio_buffer = b""
        self._last_sent_transcript = ""

        # Process the turn - trigger LLM response
        # Cancel any existing turn to avoid overlapping assistant responses
        await self._stop_current_speech()
        await self.conversation.end_user_turn(user_text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||
|
||
    async def _handle_turn(self, user_text: str) -> None:
        """
        Handle a complete conversation turn.

        Uses sentence-by-sentence streaming TTS for lower latency: LLM
        deltas accumulate in a buffer, complete sentences are synthesized
        as soon as they appear, and any remainder is spoken at the end.
        The interrupt event is checked throughout so barge-in aborts early.

        Args:
            user_text: User's transcribed text
        """
        try:
            # Start latency tracking
            self._turn_start_time = time.time()
            self._first_audio_sent = False

            # Get AI response (streaming)
            messages = self.conversation.get_messages()
            full_response = ""

            await self.conversation.start_assistant_turn()
            self._is_bot_speaking = True
            self._interrupt_event.clear()

            # Sentence buffer for streaming TTS
            sentence_buffer = ""
            # Punctuation-only fragments are held and prepended to the next sentence.
            pending_punctuation = ""
            first_audio_sent = False
            spoken_sentence_count = 0

            # Stream LLM response and TTS sentence by sentence
            async for text_chunk in self.llm_service.generate_stream(messages):
                if self._interrupt_event.is_set():
                    break

                full_response += text_chunk
                sentence_buffer += text_chunk
                await self.conversation.update_assistant_text(text_chunk)

                # Send LLM response streaming event to client
                await self.transport.send_event({
                    **ev(
                        "assistant.response.delta",
                        trackId=self.session_id,
                        text=text_chunk,
                    )
                })

                # Check for sentence completion - synthesize immediately for low latency
                while True:
                    split_result = extract_tts_sentence(
                        sentence_buffer,
                        end_chars=self._SENTENCE_END_CHARS,
                        trailing_chars=self._SENTENCE_TRAILING_CHARS,
                        closers=self._SENTENCE_CLOSERS,
                        min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
                        hold_trailing_at_buffer_end=True,
                        force=False,
                    )
                    if not split_result:
                        break
                    sentence, sentence_buffer = split_result
                    if not sentence:
                        continue

                    sentence = f"{pending_punctuation}{sentence}".strip()
                    pending_punctuation = ""
                    if not sentence:
                        continue

                    # Avoid synthesizing punctuation-only fragments (e.g. standalone "!")
                    if not has_spoken_content(sentence):
                        pending_punctuation = sentence
                        continue

                    if not self._interrupt_event.is_set():
                        # Send track start on first audio
                        if not first_audio_sent:
                            await self.transport.send_event({
                                **ev(
                                    "output.audio.start",
                                    trackId=self.session_id,
                                )
                            })
                            first_audio_sent = True

                        await self._speak_sentence(
                            sentence,
                            fade_in_ms=0,
                            fade_out_ms=8,
                        )
                        spoken_sentence_count += 1

            # Send final LLM response event
            if full_response and not self._interrupt_event.is_set():
                await self.transport.send_event({
                    **ev(
                        "assistant.response.final",
                        trackId=self.session_id,
                        text=full_response,
                    )
                })

            # Speak any remaining text
            remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
            if remaining_text and has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
                if not first_audio_sent:
                    await self.transport.send_event({
                        **ev(
                            "output.audio.start",
                            trackId=self.session_id,
                        )
                    })
                    first_audio_sent = True
                await self._speak_sentence(
                    remaining_text,
                    fade_in_ms=0,
                    fade_out_ms=8,
                )

            # Send track end (only if a matching start was sent)
            if first_audio_sent:
                await self.transport.send_event({
                    **ev(
                        "output.audio.end",
                        trackId=self.session_id,
                    )
                })

            # End assistant turn
            await self.conversation.end_assistant_turn(
                was_interrupted=self._interrupt_event.is_set()
            )

        except asyncio.CancelledError:
            logger.info("Turn handling cancelled")
            await self.conversation.end_assistant_turn(was_interrupted=True)
        except Exception as e:
            logger.error(f"Turn handling error: {e}", exc_info=True)
            await self.conversation.end_assistant_turn(was_interrupted=True)
        finally:
            self._is_bot_speaking = False
            # Reset barge-in tracking when bot finishes speaking
            self._barge_in_speech_start_time = None
            self._barge_in_speech_frames = 0
            self._barge_in_silence_frames = 0
|
||
|
||
    async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
        """
        Synthesize and send a single sentence.

        Checks the interrupt event before and during streaming so barge-in
        stops delivery mid-sentence. Errors are logged, never raised, so one
        failed sentence does not abort the whole turn.

        Args:
            text: Sentence to speak
            fade_in_ms: Fade-in duration for sentence start chunks
            fade_out_ms: Fade-out duration for sentence end chunks
        """
        if not text.strip() or self._interrupt_event.is_set():
            return

        logger.info(f"[TTS] split sentence: {text!r}")

        try:
            is_first_chunk = True
            async for chunk in self.tts_service.synthesize_stream(text):
                # Check interrupt at the start of each iteration
                if self._interrupt_event.is_set():
                    logger.debug("TTS sentence interrupted")
                    break

                # Track and log first audio packet latency (TTFB)
                if not self._first_audio_sent and self._turn_start_time:
                    ttfb_ms = (time.time() - self._turn_start_time) * 1000
                    self._first_audio_sent = True
                    logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                    # Send TTFB event to client
                    await self.transport.send_event({
                        **ev(
                            "metrics.ttfb",
                            trackId=self.session_id,
                            latencyMs=round(ttfb_ms),
                        )
                    })

                # Double-check interrupt right before sending audio
                if self._interrupt_event.is_set():
                    break

                # Soften chunk edges to avoid clicks at sentence boundaries.
                smoothed_audio = self._apply_edge_fade(
                    pcm_bytes=chunk.audio,
                    sample_rate=chunk.sample_rate,
                    fade_in=is_first_chunk,
                    fade_out=bool(chunk.is_final),
                    fade_in_ms=fade_in_ms,
                    fade_out_ms=fade_out_ms,
                )
                is_first_chunk = False

                await self.transport.send_audio(smoothed_audio)
        except asyncio.CancelledError:
            logger.debug("TTS sentence cancelled")
        except Exception as e:
            logger.error(f"TTS sentence error: {e}")
|
||
|
||
def _apply_edge_fade(
|
||
self,
|
||
pcm_bytes: bytes,
|
||
sample_rate: int,
|
||
fade_in: bool = False,
|
||
fade_out: bool = False,
|
||
fade_in_ms: int = 0,
|
||
fade_out_ms: int = 8,
|
||
) -> bytes:
|
||
"""Apply short edge fades to reduce click/pop at sentence boundaries."""
|
||
if not pcm_bytes or (not fade_in and not fade_out):
|
||
return pcm_bytes
|
||
|
||
try:
|
||
samples = np.frombuffer(pcm_bytes, dtype="<i2").astype(np.float32)
|
||
if samples.size == 0:
|
||
return pcm_bytes
|
||
|
||
if fade_in and fade_in_ms > 0:
|
||
fade_in_samples = int(sample_rate * (fade_in_ms / 1000.0))
|
||
fade_in_samples = max(1, min(fade_in_samples, samples.size))
|
||
samples[:fade_in_samples] *= np.linspace(0.0, 1.0, fade_in_samples, endpoint=True)
|
||
if fade_out:
|
||
fade_out_samples = int(sample_rate * (fade_out_ms / 1000.0))
|
||
fade_out_samples = max(1, min(fade_out_samples, samples.size))
|
||
samples[-fade_out_samples:] *= np.linspace(1.0, 0.0, fade_out_samples, endpoint=True)
|
||
|
||
return np.clip(samples, -32768, 32767).astype("<i2").tobytes()
|
||
except Exception:
|
||
# Fallback: never block audio delivery on smoothing failure.
|
||
return pcm_bytes
|
||
|
||
    async def _speak(self, text: str) -> None:
        """
        Synthesize and send speech.

        Used for one-off utterances (e.g. the greeting) outside the normal
        turn flow; honors the interrupt event for barge-in.

        Args:
            text: Text to speak
        """
        if not text.strip():
            return

        try:
            # Start latency tracking for greeting
            speak_start_time = time.time()
            first_audio_sent = False

            # Send track start event
            await self.transport.send_event({
                **ev(
                    "output.audio.start",
                    trackId=self.session_id,
                )
            })

            self._is_bot_speaking = True

            # Stream TTS audio
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._interrupt_event.is_set():
                    logger.info("TTS interrupted by barge-in")
                    break

                # Track and log first audio packet latency (TTFB)
                if not first_audio_sent:
                    ttfb_ms = (time.time() - speak_start_time) * 1000
                    first_audio_sent = True
                    logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                    # Send TTFB event to client
                    await self.transport.send_event({
                        **ev(
                            "metrics.ttfb",
                            trackId=self.session_id,
                            latencyMs=round(ttfb_ms),
                        )
                    })

                # Send audio to client
                await self.transport.send_audio(chunk.audio)

                # Small delay to prevent flooding
                await asyncio.sleep(0.01)

            # Send track end event
            await self.transport.send_event({
                **ev(
                    "output.audio.end",
                    trackId=self.session_id,
                )
            })

        except asyncio.CancelledError:
            logger.info("TTS cancelled")
            raise
        except Exception as e:
            logger.error(f"TTS error: {e}")
        finally:
            self._is_bot_speaking = False
|
||
|
||
    async def _handle_barge_in(self) -> None:
        """Handle user barge-in (interruption).

        Ordering matters here: the interrupt event and client notification
        go out before cancelling services so no stale audio is played.
        """
        if not self._is_bot_speaking:
            return

        logger.info("Barge-in detected - interrupting bot speech")

        # Reset barge-in tracking
        self._barge_in_speech_start_time = None
        self._barge_in_speech_frames = 0
        self._barge_in_silence_frames = 0

        # IMPORTANT: Signal interruption FIRST to stop audio sending
        self._interrupt_event.set()
        self._is_bot_speaking = False

        # Send interrupt event to client IMMEDIATELY
        # This must happen BEFORE canceling services, so client knows to discard in-flight audio
        await self.transport.send_event({
            **ev(
                "response.interrupted",
                trackId=self.session_id,
            )
        })

        # Cancel TTS
        if self.tts_service:
            await self.tts_service.cancel()

        # Cancel LLM (cancel() is optional on LLM backends)
        if self.llm_service and hasattr(self.llm_service, 'cancel'):
            self.llm_service.cancel()

        # Interrupt conversation only if there is no active turn task.
        # When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks.
        if not (self._current_turn_task and not self._current_turn_task.done()):
            await self.conversation.interrupt()

        # Reset for new user turn
        await self.conversation.start_user_turn()
        self._audio_buffer = b""
        self.eou_detector.reset()
|
||
|
||
async def _stop_current_speech(self) -> None:
|
||
"""Stop any current speech task."""
|
||
if self._current_turn_task and not self._current_turn_task.done():
|
||
self._interrupt_event.set()
|
||
self._current_turn_task.cancel()
|
||
try:
|
||
await self._current_turn_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
# Ensure underlying services are cancelled to avoid leaking work/audio
|
||
if self.tts_service:
|
||
await self.tts_service.cancel()
|
||
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
||
self.llm_service.cancel()
|
||
|
||
self._is_bot_speaking = False
|
||
self._interrupt_event.clear()
|
||
|
||
async def cleanup(self) -> None:
|
||
"""Cleanup pipeline resources."""
|
||
logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")
|
||
|
||
self._running = False
|
||
await self._stop_current_speech()
|
||
|
||
# Disconnect services
|
||
if self.llm_service:
|
||
await self.llm_service.disconnect()
|
||
if self.tts_service:
|
||
await self.tts_service.disconnect()
|
||
if self.asr_service:
|
||
await self.asr_service.disconnect()
|
||
|
||
def _get_timestamp_ms(self) -> int:
|
||
"""Get current timestamp in milliseconds."""
|
||
import time
|
||
return int(time.time() * 1000)
|
||
|
||
    @property
    def is_speaking(self) -> bool:
        """Check if bot is currently speaking (TTS audio in flight)."""
        return self._is_bot_speaking

    @property
    def state(self) -> ConversationState:
        """Get current conversation state (delegates to the conversation manager)."""
        return self.conversation.state