Merge branch 'master' of https://gitea.xiaowang.eu.org/wx44wx/AI-VideoAssistant
This commit is contained in:
@@ -13,37 +13,38 @@ event-driven design.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, Callable, Awaitable, Dict, Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from core.transports import BaseTransport
|
||||
from app.config import settings
|
||||
from core.conversation import ConversationManager, ConversationState
|
||||
from core.events import get_event_bus
|
||||
from processors.vad import VADProcessor, SileroVAD
|
||||
from processors.eou import EouDetector
|
||||
from services.base import BaseLLMService, BaseTTSService, BaseASRService
|
||||
from services.llm import OpenAILLMService, MockLLMService
|
||||
from services.tts import EdgeTTSService, MockTTSService
|
||||
from services.asr import BufferedASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from app.config import settings
|
||||
from core.transports import BaseTransport
|
||||
from models.ws_v1 import ev
|
||||
from processors.eou import EouDetector
|
||||
from processors.vad import SileroVAD, VADProcessor
|
||||
from services.asr import BufferedASRService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
||||
from services.llm import MockLLMService, OpenAILLMService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
from services.tts import EdgeTTSService, MockTTSService
|
||||
|
||||
|
||||
class DuplexPipeline:
|
||||
"""
|
||||
Full duplex audio pipeline for AI voice conversation.
|
||||
|
||||
|
||||
Handles bidirectional audio flow with:
|
||||
- User speech detection and transcription
|
||||
- AI response generation
|
||||
- Text-to-speech synthesis
|
||||
- Barge-in (interruption) support
|
||||
|
||||
|
||||
Architecture (inspired by pipecat):
|
||||
|
||||
|
||||
User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
|
||||
↓
|
||||
Barge-in Detection → Interrupt
|
||||
@@ -52,7 +53,7 @@ class DuplexPipeline:
|
||||
_SENTENCE_END_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "\n"})
|
||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
@@ -65,7 +66,7 @@ class DuplexPipeline:
|
||||
):
|
||||
"""
|
||||
Initialize duplex pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
transport: Transport for sending audio/events
|
||||
session_id: Session identifier
|
||||
@@ -78,7 +79,7 @@ class DuplexPipeline:
|
||||
self.transport = transport
|
||||
self.session_id = session_id
|
||||
self.event_bus = get_event_bus()
|
||||
|
||||
|
||||
# Initialize VAD
|
||||
self.vad_model = SileroVAD(
|
||||
model_path=settings.vad_model_path,
|
||||
@@ -88,27 +89,27 @@ class DuplexPipeline:
|
||||
vad_model=self.vad_model,
|
||||
threshold=settings.vad_threshold
|
||||
)
|
||||
|
||||
|
||||
# Initialize EOU detector
|
||||
self.eou_detector = EouDetector(
|
||||
silence_threshold_ms=settings.vad_eou_threshold_ms,
|
||||
min_speech_duration_ms=settings.vad_min_speech_duration_ms
|
||||
)
|
||||
|
||||
|
||||
# Initialize services
|
||||
self.llm_service = llm_service
|
||||
self.tts_service = tts_service
|
||||
self.asr_service = asr_service # Will be initialized in start()
|
||||
|
||||
|
||||
# Track last sent transcript to avoid duplicates
|
||||
self._last_sent_transcript = ""
|
||||
|
||||
|
||||
# Conversation manager
|
||||
self.conversation = ConversationManager(
|
||||
system_prompt=system_prompt,
|
||||
greeting=greeting
|
||||
)
|
||||
|
||||
|
||||
# State
|
||||
self._running = True
|
||||
self._is_bot_speaking = False
|
||||
@@ -118,14 +119,14 @@ class DuplexPipeline:
|
||||
self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds)
|
||||
self._last_vad_status: str = "Silence"
|
||||
self._process_lock = asyncio.Lock()
|
||||
|
||||
|
||||
# Interruption handling
|
||||
self._interrupt_event = asyncio.Event()
|
||||
|
||||
|
||||
# Latency tracking - TTFB (Time to First Byte)
|
||||
self._turn_start_time: Optional[float] = None
|
||||
self._first_audio_sent: bool = False
|
||||
|
||||
|
||||
# Barge-in filtering - require minimum speech duration to interrupt
|
||||
self._barge_in_speech_start_time: Optional[float] = None
|
||||
self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms if hasattr(settings, 'barge_in_min_duration_ms') else 50
|
||||
@@ -139,7 +140,7 @@ class DuplexPipeline:
|
||||
self._runtime_tts: Dict[str, Any] = {}
|
||||
self._runtime_system_prompt: Optional[str] = None
|
||||
self._runtime_greeting: Optional[str] = None
|
||||
|
||||
|
||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||
|
||||
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
||||
@@ -176,7 +177,7 @@ class DuplexPipeline:
|
||||
self._runtime_asr = services["asr"]
|
||||
if isinstance(services.get("tts"), dict):
|
||||
self._runtime_tts = services["tts"]
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline and connect services."""
|
||||
try:
|
||||
@@ -196,9 +197,9 @@ class DuplexPipeline:
|
||||
else:
|
||||
logger.warning("No OpenAI API key - using mock LLM")
|
||||
self.llm_service = MockLLMService()
|
||||
|
||||
|
||||
await self.llm_service.connect()
|
||||
|
||||
|
||||
# Connect TTS service
|
||||
if not self.tts_service:
|
||||
tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
|
||||
@@ -231,7 +232,7 @@ class DuplexPipeline:
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
await self.tts_service.connect()
|
||||
|
||||
|
||||
# Connect ASR service
|
||||
if not self.asr_service:
|
||||
asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
|
||||
@@ -255,41 +256,41 @@ class DuplexPipeline:
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
logger.info("Using Buffered ASR service (no real transcription)")
|
||||
|
||||
|
||||
await self.asr_service.connect()
|
||||
|
||||
|
||||
logger.info("DuplexPipeline services connected")
|
||||
|
||||
|
||||
# Speak greeting if configured
|
||||
if self.conversation.greeting:
|
||||
await self._speak(self.conversation.greeting)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start pipeline: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def process_audio(self, pcm_bytes: bytes) -> None:
|
||||
"""
|
||||
Process incoming audio chunk.
|
||||
|
||||
|
||||
This is the main entry point for audio from the user.
|
||||
|
||||
|
||||
Args:
|
||||
pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
async with self._process_lock:
|
||||
# 1. Process through VAD
|
||||
vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)
|
||||
|
||||
|
||||
vad_status = "Silence"
|
||||
if vad_result:
|
||||
event_type, probability = vad_result
|
||||
vad_status = "Speech" if event_type == "speaking" else "Silence"
|
||||
|
||||
|
||||
# Emit VAD event
|
||||
await self.event_bus.publish(event_type, {
|
||||
"trackId": self.session_id,
|
||||
@@ -305,20 +306,20 @@ class DuplexPipeline:
|
||||
else:
|
||||
# No state change - keep previous status
|
||||
vad_status = self._last_vad_status
|
||||
|
||||
|
||||
# Update state based on VAD
|
||||
if vad_status == "Speech" and self._last_vad_status != "Speech":
|
||||
await self._on_speech_start()
|
||||
|
||||
|
||||
self._last_vad_status = vad_status
|
||||
|
||||
|
||||
# 2. Check for barge-in (user speaking while bot speaking)
|
||||
# Filter false interruptions by requiring minimum speech duration
|
||||
if self._is_bot_speaking:
|
||||
if vad_status == "Speech":
|
||||
# User is speaking while bot is speaking
|
||||
self._barge_in_silence_frames = 0 # Reset silence counter
|
||||
|
||||
|
||||
if self._barge_in_speech_start_time is None:
|
||||
# Start tracking speech duration
|
||||
self._barge_in_speech_start_time = time.time()
|
||||
@@ -342,7 +343,7 @@ class DuplexPipeline:
|
||||
self._barge_in_speech_start_time = None
|
||||
self._barge_in_speech_frames = 0
|
||||
self._barge_in_silence_frames = 0
|
||||
|
||||
|
||||
# 3. Buffer audio for ASR
|
||||
if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
|
||||
self._audio_buffer += pcm_bytes
|
||||
@@ -350,48 +351,48 @@ class DuplexPipeline:
|
||||
# Keep only the most recent audio to cap memory usage
|
||||
self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:]
|
||||
await self.asr_service.send_audio(pcm_bytes)
|
||||
|
||||
|
||||
# For SiliconFlow ASR, trigger interim transcription periodically
|
||||
# The service handles timing internally via start_interim_transcription()
|
||||
|
||||
|
||||
# 4. Check for End of Utterance - this triggers LLM response
|
||||
if self.eou_detector.process(vad_status):
|
||||
await self._on_end_of_utterance()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline audio processing error: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def process_text(self, text: str) -> None:
|
||||
"""
|
||||
Process text input (chat command).
|
||||
|
||||
|
||||
Allows direct text input to bypass ASR.
|
||||
|
||||
|
||||
Args:
|
||||
text: User text input
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"Processing text input: {text[:50]}...")
|
||||
|
||||
|
||||
# Cancel any current speaking
|
||||
await self._stop_current_speech()
|
||||
|
||||
|
||||
# Start new turn
|
||||
await self.conversation.end_user_turn(text)
|
||||
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
|
||||
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Interrupt current bot speech (manual interrupt command)."""
|
||||
await self._handle_barge_in()
|
||||
|
||||
|
||||
async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
|
||||
"""
|
||||
Callback for ASR transcription results.
|
||||
|
||||
|
||||
Streams transcription to client for display.
|
||||
|
||||
|
||||
Args:
|
||||
text: Transcribed text
|
||||
is_final: Whether this is the final transcription
|
||||
@@ -399,9 +400,9 @@ class DuplexPipeline:
|
||||
# Avoid sending duplicate transcripts
|
||||
if text == self._last_sent_transcript and not is_final:
|
||||
return
|
||||
|
||||
|
||||
self._last_sent_transcript = text
|
||||
|
||||
|
||||
# Send transcript event to client
|
||||
await self.transport.send_event({
|
||||
**ev(
|
||||
@@ -410,9 +411,9 @@ class DuplexPipeline:
|
||||
text=text,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
||||
|
||||
|
||||
async def _on_speech_start(self) -> None:
|
||||
"""Handle user starting to speak."""
|
||||
if self.conversation.state == ConversationState.IDLE:
|
||||
@@ -420,34 +421,34 @@ class DuplexPipeline:
|
||||
self._audio_buffer = b""
|
||||
self._last_sent_transcript = ""
|
||||
self.eou_detector.reset()
|
||||
|
||||
|
||||
# Clear ASR buffer and start interim transcriptions
|
||||
if hasattr(self.asr_service, 'clear_buffer'):
|
||||
self.asr_service.clear_buffer()
|
||||
if hasattr(self.asr_service, 'start_interim_transcription'):
|
||||
await self.asr_service.start_interim_transcription()
|
||||
|
||||
|
||||
logger.debug("User speech started")
|
||||
|
||||
|
||||
async def _on_end_of_utterance(self) -> None:
|
||||
"""Handle end of user utterance."""
|
||||
if self.conversation.state != ConversationState.LISTENING:
|
||||
return
|
||||
|
||||
|
||||
# Stop interim transcriptions
|
||||
if hasattr(self.asr_service, 'stop_interim_transcription'):
|
||||
await self.asr_service.stop_interim_transcription()
|
||||
|
||||
|
||||
# Get final transcription from ASR service
|
||||
user_text = ""
|
||||
|
||||
|
||||
if hasattr(self.asr_service, 'get_final_transcription'):
|
||||
# SiliconFlow ASR - get final transcription
|
||||
user_text = await self.asr_service.get_final_transcription()
|
||||
elif hasattr(self.asr_service, 'get_and_clear_text'):
|
||||
# Buffered ASR - get accumulated text
|
||||
user_text = self.asr_service.get_and_clear_text()
|
||||
|
||||
|
||||
# Skip if no meaningful text
|
||||
if not user_text or not user_text.strip():
|
||||
logger.debug("EOU detected but no transcription - skipping")
|
||||
@@ -457,9 +458,9 @@ class DuplexPipeline:
|
||||
# Return to idle; don't force LISTENING which causes buffering on silence
|
||||
await self.conversation.set_state(ConversationState.IDLE)
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"EOU detected - user said: {user_text[:100]}...")
|
||||
|
||||
|
||||
# Send final transcription to client
|
||||
await self.transport.send_event({
|
||||
**ev(
|
||||
@@ -468,23 +469,23 @@ class DuplexPipeline:
|
||||
text=user_text,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# Clear buffers
|
||||
self._audio_buffer = b""
|
||||
self._last_sent_transcript = ""
|
||||
|
||||
|
||||
# Process the turn - trigger LLM response
|
||||
# Cancel any existing turn to avoid overlapping assistant responses
|
||||
await self._stop_current_speech()
|
||||
await self.conversation.end_user_turn(user_text)
|
||||
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||||
|
||||
|
||||
async def _handle_turn(self, user_text: str) -> None:
|
||||
"""
|
||||
Handle a complete conversation turn.
|
||||
|
||||
|
||||
Uses sentence-by-sentence streaming TTS for lower latency.
|
||||
|
||||
|
||||
Args:
|
||||
user_text: User's transcribed text
|
||||
"""
|
||||
@@ -492,30 +493,30 @@ class DuplexPipeline:
|
||||
# Start latency tracking
|
||||
self._turn_start_time = time.time()
|
||||
self._first_audio_sent = False
|
||||
|
||||
|
||||
# Get AI response (streaming)
|
||||
messages = self.conversation.get_messages()
|
||||
full_response = ""
|
||||
|
||||
|
||||
await self.conversation.start_assistant_turn()
|
||||
self._is_bot_speaking = True
|
||||
self._interrupt_event.clear()
|
||||
|
||||
|
||||
# Sentence buffer for streaming TTS
|
||||
sentence_buffer = ""
|
||||
pending_punctuation = ""
|
||||
first_audio_sent = False
|
||||
spoken_sentence_count = 0
|
||||
|
||||
|
||||
# Stream LLM response and TTS sentence by sentence
|
||||
async for text_chunk in self.llm_service.generate_stream(messages):
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
|
||||
|
||||
full_response += text_chunk
|
||||
sentence_buffer += text_chunk
|
||||
await self.conversation.update_assistant_text(text_chunk)
|
||||
|
||||
|
||||
# Send LLM response streaming event to client
|
||||
await self.transport.send_event({
|
||||
**ev(
|
||||
@@ -524,7 +525,7 @@ class DuplexPipeline:
|
||||
text=text_chunk,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# Check for sentence completion - synthesize immediately for low latency
|
||||
while True:
|
||||
split_result = self._extract_tts_sentence(sentence_buffer, force=False)
|
||||
@@ -561,7 +562,7 @@ class DuplexPipeline:
|
||||
fade_out_ms=8,
|
||||
)
|
||||
spoken_sentence_count += 1
|
||||
|
||||
|
||||
# Send final LLM response event
|
||||
if full_response and not self._interrupt_event.is_set():
|
||||
await self.transport.send_event({
|
||||
@@ -571,7 +572,7 @@ class DuplexPipeline:
|
||||
text=full_response,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# Speak any remaining text
|
||||
remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
|
||||
if remaining_text and self._has_spoken_content(remaining_text) and not self._interrupt_event.is_set():
|
||||
@@ -588,7 +589,7 @@ class DuplexPipeline:
|
||||
fade_in_ms=0,
|
||||
fade_out_ms=8,
|
||||
)
|
||||
|
||||
|
||||
# Send track end
|
||||
if first_audio_sent:
|
||||
await self.transport.send_event({
|
||||
@@ -597,12 +598,12 @@ class DuplexPipeline:
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# End assistant turn
|
||||
await self.conversation.end_assistant_turn(
|
||||
was_interrupted=self._interrupt_event.is_set()
|
||||
)
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Turn handling cancelled")
|
||||
await self.conversation.end_assistant_turn(was_interrupted=True)
|
||||
@@ -615,7 +616,7 @@ class DuplexPipeline:
|
||||
self._barge_in_speech_start_time = None
|
||||
self._barge_in_speech_frames = 0
|
||||
self._barge_in_silence_frames = 0
|
||||
|
||||
|
||||
def _extract_tts_sentence(self, text_buffer: str, force: bool = False) -> Optional[tuple[str, str]]:
|
||||
"""
|
||||
Extract one TTS sentence from the buffer.
|
||||
@@ -683,7 +684,7 @@ class DuplexPipeline:
|
||||
async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None:
|
||||
"""
|
||||
Synthesize and send a single sentence.
|
||||
|
||||
|
||||
Args:
|
||||
text: Sentence to speak
|
||||
fade_in_ms: Fade-in duration for sentence start chunks
|
||||
@@ -691,8 +692,9 @@ class DuplexPipeline:
|
||||
"""
|
||||
if not text.strip() or self._interrupt_event.is_set():
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"[TTS] split sentence: {text!r}")
|
||||
|
||||
try:
|
||||
is_first_chunk = True
|
||||
async for chunk in self.tts_service.synthesize_stream(text):
|
||||
@@ -700,13 +702,13 @@ class DuplexPipeline:
|
||||
if self._interrupt_event.is_set():
|
||||
logger.debug("TTS sentence interrupted")
|
||||
break
|
||||
|
||||
|
||||
# Track and log first audio packet latency (TTFB)
|
||||
if not self._first_audio_sent and self._turn_start_time:
|
||||
ttfb_ms = (time.time() - self._turn_start_time) * 1000
|
||||
self._first_audio_sent = True
|
||||
logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")
|
||||
|
||||
|
||||
# Send TTFB event to client
|
||||
await self.transport.send_event({
|
||||
**ev(
|
||||
@@ -715,7 +717,7 @@ class DuplexPipeline:
|
||||
latencyMs=round(ttfb_ms),
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# Double-check interrupt right before sending audio
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
@@ -767,22 +769,22 @@ class DuplexPipeline:
|
||||
except Exception:
|
||||
# Fallback: never block audio delivery on smoothing failure.
|
||||
return pcm_bytes
|
||||
|
||||
|
||||
async def _speak(self, text: str) -> None:
|
||||
"""
|
||||
Synthesize and send speech.
|
||||
|
||||
|
||||
Args:
|
||||
text: Text to speak
|
||||
"""
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
# Start latency tracking for greeting
|
||||
speak_start_time = time.time()
|
||||
first_audio_sent = False
|
||||
|
||||
|
||||
# Send track start event
|
||||
await self.transport.send_event({
|
||||
**ev(
|
||||
@@ -790,21 +792,21 @@ class DuplexPipeline:
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
self._is_bot_speaking = True
|
||||
|
||||
|
||||
# Stream TTS audio
|
||||
async for chunk in self.tts_service.synthesize_stream(text):
|
||||
if self._interrupt_event.is_set():
|
||||
logger.info("TTS interrupted by barge-in")
|
||||
break
|
||||
|
||||
|
||||
# Track and log first audio packet latency (TTFB)
|
||||
if not first_audio_sent:
|
||||
ttfb_ms = (time.time() - speak_start_time) * 1000
|
||||
first_audio_sent = True
|
||||
logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")
|
||||
|
||||
|
||||
# Send TTFB event to client
|
||||
await self.transport.send_event({
|
||||
**ev(
|
||||
@@ -813,13 +815,13 @@ class DuplexPipeline:
|
||||
latencyMs=round(ttfb_ms),
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# Send audio to client
|
||||
await self.transport.send_audio(chunk.audio)
|
||||
|
||||
|
||||
# Small delay to prevent flooding
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
# Send track end event
|
||||
await self.transport.send_event({
|
||||
**ev(
|
||||
@@ -827,7 +829,7 @@ class DuplexPipeline:
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("TTS cancelled")
|
||||
raise
|
||||
@@ -835,23 +837,23 @@ class DuplexPipeline:
|
||||
logger.error(f"TTS error: {e}")
|
||||
finally:
|
||||
self._is_bot_speaking = False
|
||||
|
||||
|
||||
async def _handle_barge_in(self) -> None:
|
||||
"""Handle user barge-in (interruption)."""
|
||||
if not self._is_bot_speaking:
|
||||
return
|
||||
|
||||
|
||||
logger.info("Barge-in detected - interrupting bot speech")
|
||||
|
||||
|
||||
# Reset barge-in tracking
|
||||
self._barge_in_speech_start_time = None
|
||||
self._barge_in_speech_frames = 0
|
||||
self._barge_in_silence_frames = 0
|
||||
|
||||
|
||||
# IMPORTANT: Signal interruption FIRST to stop audio sending
|
||||
self._interrupt_event.set()
|
||||
self._is_bot_speaking = False
|
||||
|
||||
|
||||
# Send interrupt event to client IMMEDIATELY
|
||||
# This must happen BEFORE canceling services, so client knows to discard in-flight audio
|
||||
await self.transport.send_event({
|
||||
@@ -860,25 +862,25 @@ class DuplexPipeline:
|
||||
trackId=self.session_id,
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# Cancel TTS
|
||||
if self.tts_service:
|
||||
await self.tts_service.cancel()
|
||||
|
||||
|
||||
# Cancel LLM
|
||||
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
||||
self.llm_service.cancel()
|
||||
|
||||
|
||||
# Interrupt conversation only if there is no active turn task.
|
||||
# When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks.
|
||||
if not (self._current_turn_task and not self._current_turn_task.done()):
|
||||
await self.conversation.interrupt()
|
||||
|
||||
|
||||
# Reset for new user turn
|
||||
await self.conversation.start_user_turn()
|
||||
self._audio_buffer = b""
|
||||
self.eou_detector.reset()
|
||||
|
||||
|
||||
async def _stop_current_speech(self) -> None:
|
||||
"""Stop any current speech task."""
|
||||
if self._current_turn_task and not self._current_turn_task.done():
|
||||
@@ -894,17 +896,17 @@ class DuplexPipeline:
|
||||
await self.tts_service.cancel()
|
||||
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
||||
self.llm_service.cancel()
|
||||
|
||||
|
||||
self._is_bot_speaking = False
|
||||
self._interrupt_event.clear()
|
||||
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup pipeline resources."""
|
||||
logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")
|
||||
|
||||
|
||||
self._running = False
|
||||
await self._stop_current_speech()
|
||||
|
||||
|
||||
# Disconnect services
|
||||
if self.llm_service:
|
||||
await self.llm_service.disconnect()
|
||||
@@ -912,17 +914,17 @@ class DuplexPipeline:
|
||||
await self.tts_service.disconnect()
|
||||
if self.asr_service:
|
||||
await self.asr_service.disconnect()
|
||||
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
@property
|
||||
def is_speaking(self) -> bool:
|
||||
"""Check if bot is currently speaking."""
|
||||
return self._is_bot_speaking
|
||||
|
||||
|
||||
@property
|
||||
def state(self) -> ConversationState:
|
||||
"""Get current conversation state."""
|
||||
|
||||
Reference in New Issue
Block a user