Merge branch 'master' of https://gitea.xiaowang.eu.org/wx44wx/AI-VideoAssistant
This commit is contained in:
@@ -13,37 +13,38 @@ event-driven design.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Callable, Awaitable, Dict, Any
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from core.transports import BaseTransport
|
from app.config import settings
|
||||||
from core.conversation import ConversationManager, ConversationState
|
from core.conversation import ConversationManager, ConversationState
|
||||||
from core.events import get_event_bus
|
from core.events import get_event_bus
|
||||||
from processors.vad import VADProcessor, SileroVAD
|
from core.transports import BaseTransport
|
||||||
from processors.eou import EouDetector
|
|
||||||
from services.base import BaseLLMService, BaseTTSService, BaseASRService
|
|
||||||
from services.llm import OpenAILLMService, MockLLMService
|
|
||||||
from services.tts import EdgeTTSService, MockTTSService
|
|
||||||
from services.asr import BufferedASRService
|
|
||||||
from services.siliconflow_tts import SiliconFlowTTSService
|
|
||||||
from services.siliconflow_asr import SiliconFlowASRService
|
|
||||||
from app.config import settings
|
|
||||||
from models.ws_v1 import ev
|
from models.ws_v1 import ev
|
||||||
|
from processors.eou import EouDetector
|
||||||
|
from processors.vad import SileroVAD, VADProcessor
|
||||||
|
from services.asr import BufferedASRService
|
||||||
|
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
||||||
|
from services.llm import MockLLMService, OpenAILLMService
|
||||||
|
from services.siliconflow_asr import SiliconFlowASRService
|
||||||
|
from services.siliconflow_tts import SiliconFlowTTSService
|
||||||
|
from services.tts import EdgeTTSService, MockTTSService
|
||||||
|
|
||||||
|
|
||||||
class DuplexPipeline:
|
class DuplexPipeline:
|
||||||
"""
|
"""
|
||||||
Full duplex audio pipeline for AI voice conversation.
|
Full duplex audio pipeline for AI voice conversation.
|
||||||
|
|
||||||
Handles bidirectional audio flow with:
|
Handles bidirectional audio flow with:
|
||||||
- User speech detection and transcription
|
- User speech detection and transcription
|
||||||
- AI response generation
|
- AI response generation
|
||||||
- Text-to-speech synthesis
|
- Text-to-speech synthesis
|
||||||
- Barge-in (interruption) support
|
- Barge-in (interruption) support
|
||||||
|
|
||||||
Architecture (inspired by pipecat):
|
Architecture (inspired by pipecat):
|
||||||
|
|
||||||
User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
|
User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
|
||||||
↓
|
↓
|
||||||
Barge-in Detection → Interrupt
|
Barge-in Detection → Interrupt
|
||||||
@@ -52,7 +53,7 @@ class DuplexPipeline:
|
|||||||
_SENTENCE_END_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "\n"})
|
_SENTENCE_END_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "\n"})
|
||||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
transport: BaseTransport,
|
transport: BaseTransport,
|
||||||
@@ -65,7 +66,7 @@ class DuplexPipeline:
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize duplex pipeline.
|
Initialize duplex pipeline.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
transport: Transport for sending audio/events
|
transport: Transport for sending audio/events
|
||||||
session_id: Session identifier
|
session_id: Session identifier
|
||||||
@@ -78,7 +79,7 @@ class DuplexPipeline:
|
|||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.event_bus = get_event_bus()
|
self.event_bus = get_event_bus()
|
||||||
|
|
||||||
# Initialize VAD
|
# Initialize VAD
|
||||||
self.vad_model = SileroVAD(
|
self.vad_model = SileroVAD(
|
||||||
model_path=settings.vad_model_path,
|
model_path=settings.vad_model_path,
|
||||||
@@ -88,27 +89,27 @@ class DuplexPipeline:
|
|||||||
vad_model=self.vad_model,
|
vad_model=self.vad_model,
|
||||||
threshold=settings.vad_threshold
|
threshold=settings.vad_threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize EOU detector
|
# Initialize EOU detector
|
||||||
self.eou_detector = EouDetector(
|
self.eou_detector = EouDetector(
|
||||||
silence_threshold_ms=settings.vad_eou_threshold_ms,
|
silence_threshold_ms=settings.vad_eou_threshold_ms,
|
||||||
min_speech_duration_ms=settings.vad_min_speech_duration_ms
|
min_speech_duration_ms=settings.vad_min_speech_duration_ms
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize services
|
# Initialize services
|
||||||
self.llm_service = llm_service
|
self.llm_service = llm_service
|
||||||
self.tts_service = tts_service
|
self.tts_service = tts_service
|
||||||
self.asr_service = asr_service # Will be initialized in start()
|
self.asr_service = asr_service # Will be initialized in start()
|
||||||
|
|
||||||
# Track last sent transcript to avoid duplicates
|
# Track last sent transcript to avoid duplicates
|
||||||
self._last_sent_transcript = ""
|
self._last_sent_transcript = ""
|
||||||
|
|
||||||
# Conversation manager
|
# Conversation manager
|
||||||
self.conversation = ConversationManager(
|
self.conversation = ConversationManager(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
greeting=greeting
|
greeting=greeting
|
||||||
)
|
)
|
||||||
|
|
||||||
# State
|
# State
|
||||||
self._running = True
|
self._running = True
|
||||||
self._is_bot_speaking = False
|
self._is_bot_speaking = False
|
||||||
@@ -118,14 +119,14 @@ class DuplexPipeline:
|
|||||||
self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds)
|
self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds)
|
||||||
self._last_vad_status: str = "Silence"
|
self._last_vad_status: str = "Silence"
|
||||||
self._process_lock = asyncio.Lock()
|
self._process_lock = asyncio.Lock()
|
||||||
|
|
||||||
# Interruption handling
|
# Interruption handling
|
||||||
self._interrupt_event = asyncio.Event()
|
self._interrupt_event = asyncio.Event()
|
||||||
|
|
||||||
# Latency tracking - TTFB (Time to First Byte)
|
# Latency tracking - TTFB (Time to First Byte)
|
||||||
self._turn_start_time: Optional[float] = None
|
self._turn_start_time: Optional[float] = None
|
||||||
self._first_audio_sent: bool = False
|
self._first_audio_sent: bool = False
|
||||||
|
|
||||||
# Barge-in filtering - require minimum speech duration to interrupt
|
# Barge-in filtering - require minimum speech duration to interrupt
|
||||||
self._barge_in_speech_start_time: Optional[float] = None
|
self._barge_in_speech_start_time: Optional[float] = None
|
||||||
self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms if hasattr(settings, 'barge_in_min_duration_ms') else 50
|
self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms if hasattr(settings, 'barge_in_min_duration_ms') else 50
|
||||||
@@ -139,7 +140,7 @@ class DuplexPipeline:
|
|||||||
self._runtime_tts: Dict[str, Any] = {}
|
self._runtime_tts: Dict[str, Any] = {}
|
||||||
self._runtime_system_prompt: Optional[str] = None
|
self._runtime_system_prompt: Optional[str] = None
|
||||||
self._runtime_greeting: Optional[str] = None
|
self._runtime_greeting: Optional[str] = None
|
||||||
|
|
||||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||||
|
|
||||||
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
||||||
@@ -176,7 +177,7 @@ class DuplexPipeline:
|
|||||||
self._runtime_asr = services["asr"]
|
self._runtime_asr = services["asr"]
|
||||||
if isinstance(services.get("tts"), dict):
|
if isinstance(services.get("tts"), dict):
|
||||||
self._runtime_tts = services["tts"]
|
self._runtime_tts = services["tts"]
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the pipeline and connect services."""
|
"""Start the pipeline and connect services."""
|
||||||
try:
|
try:
|
||||||
@@ -196,9 +197,9 @@ class DuplexPipeline:
|
|||||||
else:
|
else:
|
||||||
logger.warning("No OpenAI API key - using mock LLM")
|
logger.warning("No OpenAI API key - using mock LLM")
|
||||||
self.llm_service = MockLLMService()
|
self.llm_service = MockLLMService()
|
||||||
|
|
||||||
await self.llm_service.connect()
|
await self.llm_service.connect()
|
||||||
|
|
||||||
# Connect TTS service
|
# Connect TTS service
|
||||||
if not self.tts_service:
|
if not self.tts_service:
|
||||||
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||||
@@ -231,7 +232,7 @@ class DuplexPipeline:
|
|||||||
sample_rate=settings.sample_rate
|
sample_rate=settings.sample_rate
|
||||||
)
|
)
|
||||||
await self.tts_service.connect()
|
await self.tts_service.connect()
|
||||||
|
|
||||||
# Connect ASR service
|
# Connect ASR service
|
||||||
if not self.asr_service:
|
if not self.asr_service:
|
||||||
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
||||||
@@ -255,41 +256,41 @@ class DuplexPipeline:
|
|||||||
sample_rate=settings.sample_rate
|
sample_rate=settings.sample_rate
|
||||||
)
|
)
|
||||||
logger.info("Using Buffered ASR service (no real transcription)")
|
logger.info("Using Buffered ASR service (no real transcription)")
|
||||||
|
|
||||||
await self.asr_service.connect()
|
await self.asr_service.connect()
|
||||||
|
|
||||||
logger.info("DuplexPipeline services connected")
|
logger.info("DuplexPipeline services connected")
|
||||||
|
|
||||||
# Speak greeting if configured
|
# Speak greeting if configured
|
||||||
if self.conversation.greeting:
|
if self.conversation.greeting:
|
||||||
await self._speak(self.conversation.greeting)
|
await self._speak(self.conversation.greeting)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start pipeline: {e}")
|
logger.error(f"Failed to start pipeline: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def process_audio(self, pcm_bytes: bytes) -> None:
|
async def process_audio(self, pcm_bytes: bytes) -> None:
|
||||||
"""
|
"""
|
||||||
Process incoming audio chunk.
|
Process incoming audio chunk.
|
||||||
|
|
||||||
This is the main entry point for audio from the user.
|
This is the main entry point for audio from the user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
|
pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
|
||||||
"""
|
"""
|
||||||
if not self._running:
|
if not self._running:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._process_lock:
|
async with self._process_lock:
|
||||||
# 1. Process through VAD
|
# 1. Process through VAD
|
||||||
vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)
|
vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)
|
||||||
|
|
||||||
vad_status = "Silence"
|
vad_status = "Silence"
|
||||||
if vad_result:
|
if vad_result:
|
||||||
event_type, probability = vad_result
|
event_type, probability = vad_result
|
||||||
vad_status = "Speech" if event_type == "speaking" else "Silence"
|
vad_status = "Speech" if event_type == "speaking" else "Silence"
|
||||||
|
|
||||||
# Emit VAD event
|
# Emit VAD event
|
||||||
await self.event_bus.publish(event_type, {
|
await self.event_bus.publish(event_type, {
|
||||||
"trackId": self.session_id,
|
"trackId": self.session_id,
|
||||||
@@ -305,20 +306,20 @@ class DuplexPipeline:
|
|||||||
else:
|
else:
|
||||||
# No state change - keep previous status
|
# No state change - keep previous status
|
||||||
vad_status = self._last_vad_status
|
vad_status = self._last_vad_status
|
||||||
|
|
||||||
# Update state based on VAD
|
# Update state based on VAD
|
||||||
if vad_status == "Speech" and self._last_vad_status != "Speech":
|
if vad_status == "Speech" and self._last_vad_status != "Speech":
|
||||||
await self._on_speech_start()
|
await self._on_speech_start()
|
||||||
|
|
||||||
self._last_vad_status = vad_status
|
self._last_vad_status = vad_status
|
||||||
|
|
||||||
# 2. Check for barge-in (user speaking while bot speaking)
|
# 2. Check for barge-in (user speaking while bot speaking)
|
||||||
# Filter false interruptions by requiring minimum speech duration
|
# Filter false interruptions by requiring minimum speech duration
|
||||||
if self._is_bot_speaking:
|
if self._is_bot_speaking:
|
||||||
if vad_status == "Speech":
|
if vad_status == "Speech":
|
||||||
# User is speaking while bot is speaking
|
# User is speaking while bot is speaking
|
||||||
self._barge_in_silence_frames = 0 # Reset silence counter
|
self._barge_in_silence_frames = 0 # Reset silence counter
|
||||||
|
|
||||||
if self._barge_in_speech_start_time is None:
|
if self._barge_in_speech_start_time is None:
|
||||||
# Start tracking speech duration
|
# Start tracking speech duration
|
||||||
self._barge_in_speech_start_time = time.time()
|
self._barge_in_speech_start_time = time.time()
|
||||||
@@ -342,7 +343,7 @@ class DuplexPipeline:
|
|||||||
self._barge_in_speech_start_time = None
|
self._barge_in_speech_start_time = None
|
||||||
self._barge_in_speech_frames = 0
|
self._barge_in_speech_frames = 0
|
||||||
self._barge_in_silence_frames = 0
|
self._barge_in_silence_frames = 0
|
||||||
|
|
||||||
# 3. Buffer audio for ASR
|
# 3. Buffer audio for ASR
|
||||||
if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
|
if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
|
||||||
self._audio_buffer += pcm_bytes
|
self._audio_buffer += pcm_bytes
|
||||||
@@ -350,48 +351,48 @@ class DuplexPipeline:
|
|||||||
# Keep only the most recent audio to cap memory usage
|
# Keep only the most recent audio to cap memory usage
|
||||||
self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:]
|
self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:]
|
||||||
await self.asr_service.send_audio(pcm_bytes)
|
await self.asr_service.send_audio(pcm_bytes)
|
||||||
|
|
||||||
# For SiliconFlow ASR, trigger interim transcription periodically
|
# For SiliconFlow ASR, trigger interim transcription periodically
|
||||||
# The service handles timing internally via start_interim_transcription()
|
# The service handles timing internally via start_interim_transcription()
|
||||||
|
|
||||||
# 4. Check for End of Utterance - this triggers LLM response
|
# 4. Check for End of Utterance - this triggers LLM response
|
||||||
if self.eou_detector.process(vad_status):
|
if self.eou_detector.process(vad_status):
|
||||||
await self._on_end_of_utterance()
|
await self._on_end_of_utterance()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Pipeline audio processing error: {e}", exc_info=True)
|
logger.error(f"Pipeline audio processing error: {e}", exc_info=True)
|
||||||
|
|
||||||
async def process_text(self, text: str) -> None:
|
async def process_text(self, text: str) -> None:
|
||||||
"""
|
"""
|
||||||
Process text input (chat command).
|
Process text input (chat command).
|
||||||
|
|
||||||
Allows direct text input to bypass ASR.
|
Allows direct text input to bypass ASR.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: User text input
|
text: User text input
|
||||||
"""
|
"""
|
||||||
if not self._running:
|
if not self._running:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Processing text input: {text[:50]}...")
|
logger.info(f"Processing text input: {text[:50]}...")
|
||||||
|
|
||||||
# Cancel any current speaking
|
# Cancel any current speaking
|
||||||
await self._stop_current_speech()
|
await self._stop_current_speech()
|
||||||
|
|
||||||
# Start new turn
|
# Start new turn
|
||||||
await self.conversation.end_user_turn(text)
|
await self.conversation.end_user_turn(text)
|
||||||
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
|
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
|
||||||
|
|
||||||
async def interrupt(self) -> None:
|
async def interrupt(self) -> None:
|
||||||
"""Interrupt current bot speech (manual interrupt command)."""
|
"""Interrupt current bot speech (manual interrupt command)."""
|
||||||
await self._handle_barge_in()
|
await self._handle_barge_in()
|
||||||
|
|
||||||
async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
|
async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Callback for ASR transcription results.
|
Callback for ASR transcription results.
|
||||||
|
|
||||||
Streams transcription to client for display.
|
Streams transcription to client for display.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Transcribed text
|
text: Transcribed text
|
||||||
is_final: Whether this is the final transcription
|
is_final: Whether this is the final transcription
|
||||||
@@ -399,9 +400,9 @@ class DuplexPipeline:
|
|||||||
# Avoid sending duplicate transcripts
|
# Avoid sending duplicate transcripts
|
||||||
if text == self._last_sent_transcript and not is_final:
|
if text == self._last_sent_transcript and not is_final:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._last_sent_transcript = text
|
self._last_sent_transcript = text
|
||||||
|
|
||||||
# Send transcript event to client
|
# Send transcript event to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
**ev(
|
**ev(
|
||||||
@@ -410,9 +411,9 @@ class DuplexPipeline:
|
|||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
||||||
|
|
||||||
async def _on_speech_start(self) -> None:
|
async def _on_speech_start(self) -> None:
|
||||||
"""Handle user starting to speak."""
|
"""Handle user starting to speak."""
|
||||||
if self.conversation.state == ConversationState.IDLE:
|
if self.conversation.state == ConversationState.IDLE:
|
||||||
@@ -420,34 +421,34 @@ class DuplexPipeline:
|
|||||||
self._audio_buffer = b""
|
self._audio_buffer = b""
|
||||||
self._last_sent_transcript = ""
|
self._last_sent_transcript = ""
|
||||||
self.eou_detector.reset()
|
self.eou_detector.reset()
|
||||||
|
|
||||||
# Clear ASR buffer and start interim transcriptions
|
# Clear ASR buffer and start interim transcriptions
|
||||||
if hasattr(self.asr_service, 'clear_buffer'):
|
if hasattr(self.asr_service, 'clear_buffer'):
|
||||||
self.asr_service.clear_buffer()
|
self.asr_service.clear_buffer()
|
||||||
if hasattr(self.asr_service, 'start_interim_transcription'):
|
if hasattr(self.asr_service, 'start_interim_transcription'):
|
||||||
await self.asr_service.start_interim_transcription()
|
await self.asr_service.start_interim_transcription()
|
||||||
|
|
||||||
logger.debug("User speech started")
|
logger.debug("User speech started")
|
||||||
|
|
||||||
async def _on_end_of_utterance(self) -> None:
|
async def _on_end_of_utterance(self) -> None:
|
||||||
"""Handle end of user utterance."""
|
"""Handle end of user utterance."""
|
||||||
if self.conversation.state != ConversationState.LISTENING:
|
if self.conversation.state != ConversationState.LISTENING:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Stop interim transcriptions
|
# Stop interim transcriptions
|
||||||
if hasattr(self.asr_service, 'stop_interim_transcription'):
|
if hasattr(self.asr_service, 'stop_interim_transcription'):
|
||||||
await self.asr_service.stop_interim_transcription()
|
await self.asr_service.stop_interim_transcription()
|
||||||
|
|
||||||
# Get final transcription from ASR service
|
# Get final transcription from ASR service
|
||||||
user_text = ""
|
user_text = ""
|
||||||
|
|
||||||
if hasattr(self.asr_service, 'get_final_transcription'):
|
if hasattr(self.asr_service, 'get_final_transcription'):
|
||||||
# SiliconFlow ASR - get final transcription
|
# SiliconFlow ASR - get final transcription
|
||||||
user_text = await self.asr_service.get_final_transcription()
|
user_text = await self.asr_service.get_final_transcription()
|
||||||
elif hasattr(self.asr_service, 'get_and_clear_text'):
|
elif hasattr(self.asr_service, 'get_and_clear_text'):
|
||||||
# Buffered ASR - get accumulated text
|
# Buffered ASR - get accumulated text
|
||||||
user_text = self.asr_service.get_and_clear_text()
|
user_text = self.asr_service.get_and_clear_text()
|
||||||
|
|
||||||
# Skip if no meaningful text
|
# Skip if no meaningful text
|
||||||
if not user_text or not user_text.strip():
|
if not user_text or not user_text.strip():
|
||||||
logger.debug("EOU detected but no transcription - skipping")
|
logger.debug("EOU detected but no transcription - skipping")
|
||||||
@@ -457,9 +458,9 @@ class DuplexPipeline:
|
|||||||
# Return to idle; don't force LISTENING which causes buffering on silence
|
# Return to idle; don't force LISTENING which causes buffering on silence
|
||||||
await self.conversation.set_state(ConversationState.IDLE)
|
await self.conversation.set_state(ConversationState.IDLE)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"EOU detected - user said: {user_text[:100]}...")
|
logger.info(f"EOU detected - user said: {user_text[:100]}...")
|
||||||
|
|
||||||
# Send final transcription to client
|
# Send final transcription to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
**ev(
|
**ev(
|
||||||
@@ -468,23 +469,23 @@ class DuplexPipeline:
|
|||||||
text=user_text,
|
text=user_text,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Clear buffers
|
# Clear buffers
|
||||||
self._audio_buffer = b""
|
self._audio_buffer = b""
|
||||||
self._last_sent_transcript = ""
|
self._last_sent_transcript = ""
|
||||||
|
|
||||||
# Process the turn - trigger LLM response
|
# Process the turn - trigger LLM response
|
||||||
# Cancel any existing turn to avoid overlapping assistant responses
|
# Cancel any existing turn to avoid overlapping assistant responses
|
||||||
await self._stop_current_speech()
|
await self._stop_current_speech()
|
||||||
await self.conversation.end_user_turn(user_text)
|
await self.conversation.end_user_turn(user_text)
|
||||||
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||||||
|
|
||||||
async def _handle_turn(self, user_text: str) -> None:
|
async def _handle_turn(self, user_text: str) -> None:
|
||||||
"""
|
"""
|
||||||
Handle a complete conversation turn.
|
Handle a complete conversation turn.
|
||||||
|
|
||||||
Uses sentence-by-sentence streaming TTS for lower latency.
|
Uses sentence-by-sentence streaming TTS for lower latency.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_text: User's transcribed text
|
user_text: User's transcribed text
|
||||||
"""
|
"""
|
||||||
@@ -492,30 +493,30 @@ class DuplexPipeline:
|
|||||||
# Start latency tracking
|
# Start latency tracking
|
||||||
self._turn_start_time = time.time()
|
self._turn_start_time = time.time()
|
||||||
self._first_audio_sent = False
|
self._first_audio_sent = False
|
||||||
|
|
||||||
# Get AI response (streaming)
|
# Get AI response (streaming)
|
||||||
messages = self.conversation.get_messages()
|
messages = self.conversation.get_messages()
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
await self.conversation.start_assistant_turn()
|
await self.conversation.start_assistant_turn()
|
||||||
self._is_bot_speaking = True
|
self._is_bot_speaking = True
|
||||||
self._interrupt_event.clear()
|
self._interrupt_event.clear()
|
||||||
|
|
||||||
# Sentence buffer for streaming TTS
|
# Sentence buffer for streaming TTS
|
||||||
sentence_buffer = ""
|
sentence_buffer = ""
|
||||||
pending_punctuation = ""
|
pending_punctuation = ""
|
||||||
first_audio_sent = False
|
first_audio_sent = False
|
||||||
spoken_sentence_count = 0
|
spoken_sentence_count = 0
|
||||||
|
|
||||||
# Stream LLM response and TTS sentence by sentence
|
# Stream LLM response and TTS sentence by sentence
|
||||||
async for text_chunk in self.llm_service.generate_stream(messages):
|
async for text_chunk in self.llm_service.generate_stream(messages):
|
||||||
if self._interrupt_event.is_set():
|
if self._interrupt_event.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
full_response += text_chunk
|
full_response += text_chunk
|
||||||
sentence_buffer += text_chunk
|
sentence_buffer += text_chunk
|
||||||
await self.conversation.update_assistant_text(text_chunk)
|
await self.conversation.update_assistant_text(text_chunk)
|
||||||
|
|
||||||
# Send LLM response streaming event to client
|
# Send LLM response streaming event to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
**ev(
|
**ev(
|
||||||
@@ -524,7 +525,7 @@ class DuplexPipeline:
|
|||||||
text=text_chunk,
|
text=text_chunk,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Check for sentence completion - synthesize immediately for low latency
|
# Check for sentence completion - synthesize immediately for low latency
|
||||||
while True:
|
while True:
|
||||||
split_result = self._extract_tts_sentence(sentence_buffer, force=False)
|
split_result = self._extract_tts_sentence(sentence_buffer, force=False)
|
||||||
@@ -561,7 +562,7 @@ class DuplexPipeline:
|
|||||||
fade_out_ms=8,
|
fade_out_ms=8,
|
||||||
)
|
)
|
||||||
spoken_sentence_count += 1
|
spoken_sentence_count += 1
|
||||||
|
|
||||||
# Send final LLM response event
|
# Send final LLM response event
|
||||||
if full_response and not self._interrupt_event.is_set():
|
if full_response and not self._interrupt_event.is_set():
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
@@ -571,7 +572,7 @@ class DuplexPipeline:
|
|||||||
text=full_response,
|
text=full_response,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Speak any remaining text
|
# Speak any remaining text
|
||||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||||
if remaining_text and self._has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
if remaining_text and self._has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
||||||
@@ -588,7 +589,7 @@ class DuplexPipeline:
|
|||||||
fade_in_ms=0,
|
fade_in_ms=0,
|
||||||
fade_out_ms=8,
|
fade_out_ms=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send track end
|
# Send track end
|
||||||
if first_audio_sent:
|
if first_audio_sent:
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
@@ -597,12 +598,12 @@ class DuplexPipeline:
|
|||||||
trackId=self.session_id,
|
trackId=self.session_id,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# End assistant turn
|
# End assistant turn
|
||||||
await self.conversation.end_assistant_turn(
|
await self.conversation.end_assistant_turn(
|
||||||
was_interrupted=self._interrupt_event.is_set()
|
was_interrupted=self._interrupt_event.is_set()
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Turn handling cancelled")
|
logger.info("Turn handling cancelled")
|
||||||
await self.conversation.end_assistant_turn(was_interrupted=True)
|
await self.conversation.end_assistant_turn(was_interrupted=True)
|
||||||
@@ -615,7 +616,7 @@ class DuplexPipeline:
|
|||||||
self._barge_in_speech_start_time = None
|
self._barge_in_speech_start_time = None
|
||||||
self._barge_in_speech_frames = 0
|
self._barge_in_speech_frames = 0
|
||||||
self._barge_in_silence_frames = 0
|
self._barge_in_silence_frames = 0
|
||||||
|
|
||||||
def _extract_tts_sentence(self, text_buffer: str, force: bool = False) -> Optional[tuple[str, str]]:
|
def _extract_tts_sentence(self, text_buffer: str, force: bool = False) -> Optional[tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
Extract one TTS sentence from the buffer.
|
Extract one TTS sentence from the buffer.
|
||||||
@@ -683,7 +684,7 @@ class DuplexPipeline:
|
|||||||
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
|
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
|
||||||
"""
|
"""
|
||||||
Synthesize and send a single sentence.
|
Synthesize and send a single sentence.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Sentence to speak
|
text: Sentence to speak
|
||||||
fade_in_ms: Fade-in duration for sentence start chunks
|
fade_in_ms: Fade-in duration for sentence start chunks
|
||||||
@@ -691,8 +692,9 @@ class DuplexPipeline:
|
|||||||
"""
|
"""
|
||||||
if not text.strip() or self._interrupt_event.is_set():
|
if not text.strip() or self._interrupt_event.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"[TTS] split sentence: {text!r}")
|
logger.info(f"[TTS] split sentence: {text!r}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
is_first_chunk = True
|
is_first_chunk = True
|
||||||
async for chunk in self.tts_service.synthesize_stream(text):
|
async for chunk in self.tts_service.synthesize_stream(text):
|
||||||
@@ -700,13 +702,13 @@ class DuplexPipeline:
|
|||||||
if self._interrupt_event.is_set():
|
if self._interrupt_event.is_set():
|
||||||
logger.debug("TTS sentence interrupted")
|
logger.debug("TTS sentence interrupted")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Track and log first audio packet latency (TTFB)
|
# Track and log first audio packet latency (TTFB)
|
||||||
if not self._first_audio_sent and self._turn_start_time:
|
if not self._first_audio_sent and self._turn_start_time:
|
||||||
ttfb_ms = (time.time() - self._turn_start_time) * 1000
|
ttfb_ms = (time.time() - self._turn_start_time) * 1000
|
||||||
self._first_audio_sent = True
|
self._first_audio_sent = True
|
||||||
logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")
|
logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")
|
||||||
|
|
||||||
# Send TTFB event to client
|
# Send TTFB event to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
**ev(
|
**ev(
|
||||||
@@ -715,7 +717,7 @@ class DuplexPipeline:
|
|||||||
latencyMs=round(ttfb_ms),
|
latencyMs=round(ttfb_ms),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Double-check interrupt right before sending audio
|
# Double-check interrupt right before sending audio
|
||||||
if self._interrupt_event.is_set():
|
if self._interrupt_event.is_set():
|
||||||
break
|
break
|
||||||
@@ -767,22 +769,22 @@ class DuplexPipeline:
|
|||||||
except Exception:
|
except Exception:
|
||||||
# Fallback: never block audio delivery on smoothing failure.
|
# Fallback: never block audio delivery on smoothing failure.
|
||||||
return pcm_bytes
|
return pcm_bytes
|
||||||
|
|
||||||
async def _speak(self, text: str) -> None:
|
async def _speak(self, text: str) -> None:
|
||||||
"""
|
"""
|
||||||
Synthesize and send speech.
|
Synthesize and send speech.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Text to speak
|
text: Text to speak
|
||||||
"""
|
"""
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Start latency tracking for greeting
|
# Start latency tracking for greeting
|
||||||
speak_start_time = time.time()
|
speak_start_time = time.time()
|
||||||
first_audio_sent = False
|
first_audio_sent = False
|
||||||
|
|
||||||
# Send track start event
|
# Send track start event
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
**ev(
|
**ev(
|
||||||
@@ -790,21 +792,21 @@ class DuplexPipeline:
|
|||||||
trackId=self.session_id,
|
trackId=self.session_id,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
self._is_bot_speaking = True
|
self._is_bot_speaking = True
|
||||||
|
|
||||||
# Stream TTS audio
|
# Stream TTS audio
|
||||||
async for chunk in self.tts_service.synthesize_stream(text):
|
async for chunk in self.tts_service.synthesize_stream(text):
|
||||||
if self._interrupt_event.is_set():
|
if self._interrupt_event.is_set():
|
||||||
logger.info("TTS interrupted by barge-in")
|
logger.info("TTS interrupted by barge-in")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Track and log first audio packet latency (TTFB)
|
# Track and log first audio packet latency (TTFB)
|
||||||
if not first_audio_sent:
|
if not first_audio_sent:
|
||||||
ttfb_ms = (time.time() - speak_start_time) * 1000
|
ttfb_ms = (time.time() - speak_start_time) * 1000
|
||||||
first_audio_sent = True
|
first_audio_sent = True
|
||||||
logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")
|
logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")
|
||||||
|
|
||||||
# Send TTFB event to client
|
# Send TTFB event to client
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
**ev(
|
**ev(
|
||||||
@@ -813,13 +815,13 @@ class DuplexPipeline:
|
|||||||
latencyMs=round(ttfb_ms),
|
latencyMs=round(ttfb_ms),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Send audio to client
|
# Send audio to client
|
||||||
await self.transport.send_audio(chunk.audio)
|
await self.transport.send_audio(chunk.audio)
|
||||||
|
|
||||||
# Small delay to prevent flooding
|
# Small delay to prevent flooding
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
# Send track end event
|
# Send track end event
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
**ev(
|
**ev(
|
||||||
@@ -827,7 +829,7 @@ class DuplexPipeline:
|
|||||||
trackId=self.session_id,
|
trackId=self.session_id,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("TTS cancelled")
|
logger.info("TTS cancelled")
|
||||||
raise
|
raise
|
||||||
@@ -835,23 +837,23 @@ class DuplexPipeline:
|
|||||||
logger.error(f"TTS error: {e}")
|
logger.error(f"TTS error: {e}")
|
||||||
finally:
|
finally:
|
||||||
self._is_bot_speaking = False
|
self._is_bot_speaking = False
|
||||||
|
|
||||||
async def _handle_barge_in(self) -> None:
|
async def _handle_barge_in(self) -> None:
|
||||||
"""Handle user barge-in (interruption)."""
|
"""Handle user barge-in (interruption)."""
|
||||||
if not self._is_bot_speaking:
|
if not self._is_bot_speaking:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Barge-in detected - interrupting bot speech")
|
logger.info("Barge-in detected - interrupting bot speech")
|
||||||
|
|
||||||
# Reset barge-in tracking
|
# Reset barge-in tracking
|
||||||
self._barge_in_speech_start_time = None
|
self._barge_in_speech_start_time = None
|
||||||
self._barge_in_speech_frames = 0
|
self._barge_in_speech_frames = 0
|
||||||
self._barge_in_silence_frames = 0
|
self._barge_in_silence_frames = 0
|
||||||
|
|
||||||
# IMPORTANT: Signal interruption FIRST to stop audio sending
|
# IMPORTANT: Signal interruption FIRST to stop audio sending
|
||||||
self._interrupt_event.set()
|
self._interrupt_event.set()
|
||||||
self._is_bot_speaking = False
|
self._is_bot_speaking = False
|
||||||
|
|
||||||
# Send interrupt event to client IMMEDIATELY
|
# Send interrupt event to client IMMEDIATELY
|
||||||
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
|
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
|
||||||
await self.transport.send_event({
|
await self.transport.send_event({
|
||||||
@@ -860,25 +862,25 @@ class DuplexPipeline:
|
|||||||
trackId=self.session_id,
|
trackId=self.session_id,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Cancel TTS
|
# Cancel TTS
|
||||||
if self.tts_service:
|
if self.tts_service:
|
||||||
await self.tts_service.cancel()
|
await self.tts_service.cancel()
|
||||||
|
|
||||||
# Cancel LLM
|
# Cancel LLM
|
||||||
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
||||||
self.llm_service.cancel()
|
self.llm_service.cancel()
|
||||||
|
|
||||||
# Interrupt conversation only if there is no active turn task.
|
# Interrupt conversation only if there is no active turn task.
|
||||||
# When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks.
|
# When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks.
|
||||||
if not (self._current_turn_task and not self._current_turn_task.done()):
|
if not (self._current_turn_task and not self._current_turn_task.done()):
|
||||||
await self.conversation.interrupt()
|
await self.conversation.interrupt()
|
||||||
|
|
||||||
# Reset for new user turn
|
# Reset for new user turn
|
||||||
await self.conversation.start_user_turn()
|
await self.conversation.start_user_turn()
|
||||||
self._audio_buffer = b""
|
self._audio_buffer = b""
|
||||||
self.eou_detector.reset()
|
self.eou_detector.reset()
|
||||||
|
|
||||||
async def _stop_current_speech(self) -> None:
|
async def _stop_current_speech(self) -> None:
|
||||||
"""Stop any current speech task."""
|
"""Stop any current speech task."""
|
||||||
if self._current_turn_task and not self._current_turn_task.done():
|
if self._current_turn_task and not self._current_turn_task.done():
|
||||||
@@ -894,17 +896,17 @@ class DuplexPipeline:
|
|||||||
await self.tts_service.cancel()
|
await self.tts_service.cancel()
|
||||||
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
||||||
self.llm_service.cancel()
|
self.llm_service.cancel()
|
||||||
|
|
||||||
self._is_bot_speaking = False
|
self._is_bot_speaking = False
|
||||||
self._interrupt_event.clear()
|
self._interrupt_event.clear()
|
||||||
|
|
||||||
async def cleanup(self) -> None:
|
async def cleanup(self) -> None:
|
||||||
"""Cleanup pipeline resources."""
|
"""Cleanup pipeline resources."""
|
||||||
logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")
|
logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
await self._stop_current_speech()
|
await self._stop_current_speech()
|
||||||
|
|
||||||
# Disconnect services
|
# Disconnect services
|
||||||
if self.llm_service:
|
if self.llm_service:
|
||||||
await self.llm_service.disconnect()
|
await self.llm_service.disconnect()
|
||||||
@@ -912,17 +914,17 @@ class DuplexPipeline:
|
|||||||
await self.tts_service.disconnect()
|
await self.tts_service.disconnect()
|
||||||
if self.asr_service:
|
if self.asr_service:
|
||||||
await self.asr_service.disconnect()
|
await self.asr_service.disconnect()
|
||||||
|
|
||||||
def _get_timestamp_ms(self) -> int:
|
def _get_timestamp_ms(self) -> int:
|
||||||
"""Get current timestamp in milliseconds."""
|
"""Get current timestamp in milliseconds."""
|
||||||
import time
|
import time
|
||||||
return int(time.time() * 1000)
|
return int(time.time() * 1000)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_speaking(self) -> bool:
|
def is_speaking(self) -> bool:
|
||||||
"""Check if bot is currently speaking."""
|
"""Check if bot is currently speaking."""
|
||||||
return self._is_bot_speaking
|
return self._is_bot_speaking
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self) -> ConversationState:
|
def state(self) -> ConversationState:
|
||||||
"""Get current conversation state."""
|
"""Get current conversation state."""
|
||||||
|
|||||||
Reference in New Issue
Block a user