I can use text to get an audio response and barge in
core/duplex_pipeline.py · 509 lines · new file
@@ -0,0 +1,509 @@
"""Full duplex audio pipeline for AI voice conversation.

This module implements the core duplex pipeline that orchestrates:
- VAD (Voice Activity Detection)
- EOU (End of Utterance) Detection
- ASR (Automatic Speech Recognition) - optional
- LLM (Language Model)
- TTS (Text-to-Speech)

Inspired by pipecat's frame-based architecture and active-call's
event-driven design.
"""

import asyncio
import time
from typing import Optional

from loguru import logger

from core.transports import BaseTransport
from core.conversation import ConversationManager, ConversationState
from core.events import get_event_bus
from processors.vad import VADProcessor, SileroVAD
from processors.eou import EouDetector
from services.base import BaseLLMService, BaseTTSService, BaseASRService
from services.llm import OpenAILLMService, MockLLMService
from services.tts import EdgeTTSService, MockTTSService
from services.asr import BufferedASRService
from services.siliconflow_tts import SiliconFlowTTSService
from app.config import settings


class DuplexPipeline:
    """
    Full duplex audio pipeline for AI voice conversation.

    Handles bidirectional audio flow with:
    - User speech detection and transcription
    - AI response generation
    - Text-to-speech synthesis
    - Barge-in (interruption) support

    Architecture (inspired by pipecat):

        User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
                      ↓
            Barge-in Detection → Interrupt
    """

    def __init__(
        self,
        transport: BaseTransport,
        session_id: str,
        llm_service: Optional[BaseLLMService] = None,
        tts_service: Optional[BaseTTSService] = None,
        asr_service: Optional[BaseASRService] = None,
        system_prompt: Optional[str] = None,
        greeting: Optional[str] = None
    ):
        """
        Initialize duplex pipeline.

        Args:
            transport: Transport for sending audio/events
            session_id: Session identifier
            llm_service: LLM service (defaults to OpenAI)
            tts_service: TTS service (defaults to EdgeTTS)
            asr_service: ASR service (optional)
            system_prompt: System prompt for LLM
            greeting: Optional greeting to speak on start
        """
        self.transport = transport
        self.session_id = session_id
        self.event_bus = get_event_bus()

        # Initialize VAD
        self.vad_model = SileroVAD(
            model_path=settings.vad_model_path,
            sample_rate=settings.sample_rate
        )
        self.vad_processor = VADProcessor(
            vad_model=self.vad_model,
            threshold=settings.vad_threshold
        )

        # Initialize EOU detector
        self.eou_detector = EouDetector(
            silence_threshold_ms=600,
            min_speech_duration_ms=200
        )

        # Initialize services
        self.llm_service = llm_service
        self.tts_service = tts_service
        self.asr_service = asr_service or BufferedASRService()

        # Conversation manager
        self.conversation = ConversationManager(
            system_prompt=system_prompt,
            greeting=greeting
        )

        # State
        self._running = True
        self._is_bot_speaking = False
        self._current_turn_task: Optional[asyncio.Task] = None
        self._audio_buffer: bytes = b""
        self._last_vad_status: str = "Silence"

        # Interruption handling
        self._interrupt_event = asyncio.Event()

        logger.info(f"DuplexPipeline initialized for session {session_id}")

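    # Usage sketch - how a session owner is expected to drive this class
    # (MyTransport is hypothetical; concrete transports live in core.transports):
    #
    #   pipeline = DuplexPipeline(transport=MyTransport(), session_id="abc123",
    #                             greeting="Hi! How can I help?")
    #   await pipeline.start()
    #   ... feed PCM chunks via pipeline.process_audio(pcm) ...
    #   await pipeline.cleanup()
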
    async def start(self) -> None:
        """Start the pipeline and connect services."""
        try:
            # Connect LLM service
            if not self.llm_service:
                if settings.openai_api_key:
                    self.llm_service = OpenAILLMService(
                        api_key=settings.openai_api_key,
                        base_url=settings.openai_api_url,
                        model=settings.llm_model
                    )
                else:
                    logger.warning("No OpenAI API key - using mock LLM")
                    self.llm_service = MockLLMService()

            await self.llm_service.connect()

            # Connect TTS service
            if not self.tts_service:
                if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key:
                    self.tts_service = SiliconFlowTTSService(
                        api_key=settings.siliconflow_api_key,
                        voice=settings.tts_voice,
                        model=settings.siliconflow_tts_model,
                        sample_rate=settings.sample_rate,
                        speed=settings.tts_speed
                    )
                    logger.info("Using SiliconFlow TTS service")
                else:
                    self.tts_service = EdgeTTSService(
                        voice=settings.tts_voice,
                        sample_rate=settings.sample_rate
                    )
                    logger.info("Using Edge TTS service")

            await self.tts_service.connect()

            # Connect ASR service
            await self.asr_service.connect()

            logger.info("DuplexPipeline services connected")

            # Speak greeting if configured
            if self.conversation.greeting:
                await self._speak(self.conversation.greeting)

        except Exception as e:
            logger.error(f"Failed to start pipeline: {e}")
            raise

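    # Fallback summary (mirrors the branches above): settings.tts_provider ==
    # "siliconflow" with settings.siliconflow_api_key set selects SiliconFlow
    # TTS, anything else falls back to Edge TTS; likewise a missing
    # settings.openai_api_key swaps in MockLLMService (with a warning), so a
    # dev box with no keys still answers, just with canned text and mock audio.
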
    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio chunk.

        This is the main entry point for audio from the user.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if not self._running:
            return

        try:
            # 1. Process through VAD
            vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)

            vad_status = "Silence"
            if vad_result:
                event_type, probability = vad_result
                vad_status = "Speech" if event_type == "speaking" else "Silence"

                # Emit VAD event
                await self.event_bus.publish(event_type, {
                    "trackId": self.session_id,
                    "probability": probability
                })
            else:
                # No state change - keep previous status
                vad_status = self._last_vad_status

            # Update state based on VAD
            if vad_status == "Speech" and self._last_vad_status != "Speech":
                await self._on_speech_start()

            self._last_vad_status = vad_status

            # 2. Check for barge-in (user speaking while bot speaking)
            if self._is_bot_speaking and vad_status == "Speech":
                await self._handle_barge_in()

            # 3. Buffer audio for ASR
            if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
                self._audio_buffer += pcm_bytes
                await self.asr_service.send_audio(pcm_bytes)

            # 4. Check for End of Utterance
            if self.eou_detector.process(vad_status):
                await self._on_end_of_utterance()

        except Exception as e:
            logger.error(f"Pipeline audio processing error: {e}", exc_info=True)

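    # Sizing note for callers (per the format documented above: 16-bit mono
    # PCM at 16 kHz): one chunk covers settings.chunk_size_ms of audio, i.e.
    #   16000 samples/s * 2 bytes/sample * chunk_size_ms / 1000
    # so a 20 ms chunk - a typical value - is 640 bytes of pcm_bytes.
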
    async def process_text(self, text: str) -> None:
        """
        Process text input (chat command).

        Allows direct text input to bypass ASR.

        Args:
            text: User text input
        """
        if not self._running:
            return

        logger.info(f"Processing text input: {text[:50]}...")

        # Cancel any current speaking
        await self._stop_current_speech()

        # Start new turn
        await self.conversation.end_user_turn(text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(text))

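    # Text-path sketch: this is the "use text to get an audio response and
    # barge in" flow from the commit message - a chat message sent while the
    # bot is mid-sentence stops the current speech and starts a fresh turn:
    #
    #   await pipeline.process_text("Actually, make it shorter.")
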
    async def interrupt(self) -> None:
        """Interrupt current bot speech (manual interrupt command)."""
        await self._handle_barge_in()

    async def _on_speech_start(self) -> None:
        """Handle user starting to speak."""
        if self.conversation.state == ConversationState.IDLE:
            await self.conversation.start_user_turn()
            self._audio_buffer = b""
            self.eou_detector.reset()
            logger.debug("User speech started")

    async def _on_end_of_utterance(self) -> None:
        """Handle end of user utterance."""
        if self.conversation.state != ConversationState.LISTENING:
            return

        # Get transcribed text (if using an ASR service that provides it)
        user_text = ""
        if hasattr(self.asr_service, 'get_and_clear_text'):
            user_text = self.asr_service.get_and_clear_text()

        # If no ASR text is available, the buffered audio could be sent to an
        # external ASR; for now, fall back to a placeholder.
        if not user_text:
            # In a real implementation, you'd send _audio_buffer to ASR here.
            # For demo purposes, use mock text.
            user_text = "[User speech detected]"
            logger.warning("No ASR text available - using placeholder")

        logger.info(f"EOU detected - user said: {user_text[:50]}...")

        # Clear buffers
        self._audio_buffer = b""

        # Process the turn
        await self.conversation.end_user_turn(user_text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))

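    # EOU timing sketch (thresholds from __init__, semantics going by the
    # parameter names): with 20 ms chunks the detector wants roughly >= 200 ms
    # of "Speech" (10 chunks) followed by >= 600 ms of "Silence" (30 chunks)
    # before process() returns True and the user's turn ends.
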
    async def _handle_turn(self, user_text: str) -> None:
        """
        Handle a complete conversation turn.

        Uses sentence-by-sentence streaming TTS for lower latency.

        Args:
            user_text: User's transcribed text
        """
        try:
            # Get AI response (streaming)
            messages = self.conversation.get_messages()
            full_response = ""

            await self.conversation.start_assistant_turn()
            self._is_bot_speaking = True
            self._interrupt_event.clear()

            # Sentence buffer for streaming TTS
            sentence_buffer = ""
            sentence_ends = {'.', '!', '?', '。', '!', '?', ';', '\n'}
            first_audio_sent = False

            # Stream LLM response and run TTS sentence by sentence
            async for text_chunk in self.llm_service.generate_stream(messages):
                if self._interrupt_event.is_set():
                    break

                full_response += text_chunk
                sentence_buffer += text_chunk
                await self.conversation.update_assistant_text(text_chunk)

                # Check for sentence completion - synthesize immediately for low latency
                while any(end in sentence_buffer for end in sentence_ends):
                    # Find the first sentence end
                    min_idx = len(sentence_buffer)
                    for end in sentence_ends:
                        idx = sentence_buffer.find(end)
                        if idx != -1 and idx < min_idx:
                            min_idx = idx

                    if min_idx < len(sentence_buffer):
                        sentence = sentence_buffer[:min_idx + 1].strip()
                        sentence_buffer = sentence_buffer[min_idx + 1:]

                        if sentence and not self._interrupt_event.is_set():
                            # Send track start on first audio
                            if not first_audio_sent:
                                await self.transport.send_event({
                                    "event": "trackStart",
                                    "trackId": self.session_id,
                                    "timestamp": self._get_timestamp_ms()
                                })
                                first_audio_sent = True

                            # Synthesize and send this sentence immediately
                            await self._speak_sentence(sentence)
                    else:
                        break

            # Speak any remaining text
            if sentence_buffer.strip() and not self._interrupt_event.is_set():
                if not first_audio_sent:
                    await self.transport.send_event({
                        "event": "trackStart",
                        "trackId": self.session_id,
                        "timestamp": self._get_timestamp_ms()
                    })
                    first_audio_sent = True
                await self._speak_sentence(sentence_buffer.strip())

            # Send track end
            if first_audio_sent:
                await self.transport.send_event({
                    "event": "trackEnd",
                    "trackId": self.session_id,
                    "timestamp": self._get_timestamp_ms()
                })

            # End assistant turn
            await self.conversation.end_assistant_turn(
                was_interrupted=self._interrupt_event.is_set()
            )

        except asyncio.CancelledError:
            logger.info("Turn handling cancelled")
            await self.conversation.end_assistant_turn(was_interrupted=True)
        except Exception as e:
            logger.error(f"Turn handling error: {e}", exc_info=True)
            await self.conversation.end_assistant_turn(was_interrupted=True)
        finally:
            self._is_bot_speaking = False

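    # Splitting sketch - what the loop above does to a token stream
    # (illustrative chunks, not real LLM output):
    #
    #   chunks arrive:  "Sure", ", one sec.", " Done!"
    #   at the '.':     _speak_sentence("Sure, one sec.")
    #   at the '!':     _speak_sentence("Done!")
    #
    # so the first audio is on the wire before the LLM finishes generating.
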
    async def _speak_sentence(self, text: str) -> None:
        """
        Synthesize and send a single sentence.

        Args:
            text: Sentence to speak
        """
        if not text.strip() or self._interrupt_event.is_set():
            return

        try:
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._interrupt_event.is_set():
                    break
                await self.transport.send_audio(chunk.audio)
                await asyncio.sleep(0.005)  # Small delay to prevent flooding
        except Exception as e:
            logger.error(f"TTS sentence error: {e}")

    async def _speak(self, text: str) -> None:
        """
        Synthesize and send speech.

        Args:
            text: Text to speak
        """
        if not text.strip():
            return

        try:
            # Send track start event
            await self.transport.send_event({
                "event": "trackStart",
                "trackId": self.session_id,
                "timestamp": self._get_timestamp_ms()
            })

            self._is_bot_speaking = True

            # Stream TTS audio
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._interrupt_event.is_set():
                    logger.info("TTS interrupted by barge-in")
                    break

                # Send audio to client
                await self.transport.send_audio(chunk.audio)

                # Small delay to prevent flooding
                await asyncio.sleep(0.01)

            # Send track end event
            await self.transport.send_event({
                "event": "trackEnd",
                "trackId": self.session_id,
                "timestamp": self._get_timestamp_ms()
            })

        except asyncio.CancelledError:
            logger.info("TTS cancelled")
            raise
        except Exception as e:
            logger.error(f"TTS error: {e}")
        finally:
            self._is_bot_speaking = False

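    # Pacing note: at 16 kHz, 16-bit mono, one second of audio is
    # 16000 * 2 = 32,000 bytes, so the 5-10 ms sleeps per TTS chunk above are
    # transport flow control, not real-time pacing - the client is expected to
    # buffer and clock playback itself.
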
    async def _handle_barge_in(self) -> None:
        """Handle user barge-in (interruption)."""
        if not self._is_bot_speaking:
            return

        logger.info("Barge-in detected - interrupting bot speech")

        # Signal interruption
        self._interrupt_event.set()

        # Cancel TTS
        if self.tts_service:
            await self.tts_service.cancel()

        # Cancel LLM
        if self.llm_service and hasattr(self.llm_service, 'cancel'):
            self.llm_service.cancel()

        # Interrupt conversation
        await self.conversation.interrupt()

        # Send interrupt event to client
        await self.transport.send_event({
            "event": "interrupt",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

        # Reset for new user turn
        self._is_bot_speaking = False
        await self.conversation.start_user_turn()
        self._audio_buffer = b""
        self.eou_detector.reset()

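    # Wire sequence a client sees on barge-in (event names as emitted above):
    #
    #   trackStart -> audio chunks ... -> user speaks -> "interrupt" -> silence
    #
    # and by the time "interrupt" arrives the pipeline is already listening,
    # buffering the user's new utterance toward the next EOU.
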
    async def _stop_current_speech(self) -> None:
        """Stop any current speech task."""
        if self._current_turn_task and not self._current_turn_task.done():
            self._interrupt_event.set()
            self._current_turn_task.cancel()
            try:
                await self._current_turn_task
            except asyncio.CancelledError:
                pass

        self._is_bot_speaking = False
        self._interrupt_event.clear()

    async def cleanup(self) -> None:
        """Cleanup pipeline resources."""
        logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")

        self._running = False
        await self._stop_current_speech()

        # Disconnect services
        if self.llm_service:
            await self.llm_service.disconnect()
        if self.tts_service:
            await self.tts_service.disconnect()
        if self.asr_service:
            await self.asr_service.disconnect()

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    @property
    def is_speaking(self) -> bool:
        """Check if bot is currently speaking."""
        return self._is_bot_speaking

    @property
    def state(self) -> ConversationState:
        """Get current conversation state."""
        return self.conversation.state
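
# Smoke-test sketch (EchoTransport is hypothetical - substitute a concrete
# BaseTransport from core.transports; the Mock services come from the imports
# above, so no API keys are needed):
#
#   async def main():
#       pipeline = DuplexPipeline(
#           transport=EchoTransport(),
#           session_id="local-test",
#           llm_service=MockLLMService(),
#           tts_service=MockTTSService(),
#       )
#       await pipeline.start()
#       await pipeline.process_text("hello")   # text in, audio out
#       await pipeline.cleanup()
#
#   asyncio.run(main())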