Support text input for audio responses, with barge-in
.env.example (new file, 30 lines)
@@ -0,0 +1,30 @@
# Server Configuration
HOST=0.0.0.0
PORT=8000

# Audio Configuration
SAMPLE_RATE=16000
CHUNK_SIZE_MS=20

# VAD Configuration
VAD_THRESHOLD=0.5
VAD_EOU_THRESHOLD_MS=400

# OpenAI / LLM Configuration (required for duplex voice)
OPENAI_API_KEY=sk-your-openai-api-key-here
# OPENAI_API_URL=https://api.openai.com/v1  # Optional: for Azure or compatible APIs
LLM_MODEL=gpt-4o-mini
LLM_TEMPERATURE=0.7

# TTS Configuration
TTS_VOICE=en-US-JennyNeural
TTS_SPEED=1.0

# Duplex Pipeline Configuration
DUPLEX_ENABLED=true
# DUPLEX_GREETING=Hello! How can I help you today?
DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.

# Logging
LOG_LEVEL=INFO
LOG_FORMAT=text
app/config.py
@@ -33,6 +33,29 @@ class Settings(BaseSettings):
    vad_min_speech_duration_ms: int = Field(default=250, description="Minimum speech duration in milliseconds")
    vad_eou_threshold_ms: int = Field(default=400, description="End of utterance (silence) threshold in milliseconds")

    # OpenAI / LLM Configuration
    openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
    openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)")
    llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
    llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")

    # TTS Configuration
    tts_provider: str = Field(default="siliconflow", description="TTS provider (edge, siliconflow)")
    tts_voice: str = Field(default="anna", description="TTS voice name")
    tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")

    # SiliconFlow Configuration
    siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
    siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")

    # Duplex Pipeline Configuration
    duplex_enabled: bool = Field(default=True, description="Enable duplex voice pipeline")
    duplex_greeting: Optional[str] = Field(default=None, description="Optional greeting message")
    duplex_system_prompt: Optional[str] = Field(
        default="You are a helpful, friendly voice assistant. Keep your responses concise and conversational.",
        description="System prompt for LLM"
    )

    # Logging
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(default="json", description="Log format (json or text)")
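Since Settings extends pydantic's BaseSettings, each field above can be overridden by the matching upper-case environment variable or a .env entry (as in .env.example). A minimal sketch of that loading behaviour follows; the field names mirror the real Settings class, but the env_file wiring is an assumption since app/config.py's model_config is not shown in this diff.

# Sketch only - not the actual app/config.py; the env_file setting is assumed.
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env")  # assumption

    tts_provider: str = Field(default="siliconflow")
    tts_voice: str = Field(default="anna")
    siliconflow_api_key: Optional[str] = Field(default=None)

settings = Settings()
# With TTS_VOICE=en-US-JennyNeural in .env (as in .env.example),
# settings.tts_voice becomes "en-US-JennyNeural" instead of the "anna" default.

Note that .env.example pairs the Edge-style voice en-US-JennyNeural with a default tts_provider of "siliconflow" (voice "anna"); when SiliconFlow is the intended provider, SILICONFLOW_API_KEY and a matching TTS_VOICE need to be set as well.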
core/__init__.py
@@ -1 +1,22 @@
"""Core Components Package"""

from core.events import EventBus, get_event_bus
from core.transports import BaseTransport, SocketTransport, WebRtcTransport
from core.pipeline import AudioPipeline
from core.session import Session
from core.conversation import ConversationManager, ConversationState, ConversationTurn
from core.duplex_pipeline import DuplexPipeline

__all__ = [
    "EventBus",
    "get_event_bus",
    "BaseTransport",
    "SocketTransport",
    "WebRtcTransport",
    "AudioPipeline",
    "Session",
    "ConversationManager",
    "ConversationState",
    "ConversationTurn",
    "DuplexPipeline",
]
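With these re-exports in place, callers can import the main building blocks from the package root rather than the individual modules, for example:

from core import Session, DuplexPipeline, ConversationManager, get_event_bus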
core/conversation.py (new file, 255 lines)
@@ -0,0 +1,255 @@
"""Conversation management for voice AI.
|
||||
|
||||
Handles conversation context, turn-taking, and message history
|
||||
for multi-turn voice conversations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
|
||||
from services.base import LLMMessage
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""State of the conversation."""
|
||||
IDLE = "idle" # Waiting for user input
|
||||
LISTENING = "listening" # User is speaking
|
||||
PROCESSING = "processing" # Processing user input (LLM)
|
||||
SPEAKING = "speaking" # Bot is speaking
|
||||
INTERRUPTED = "interrupted" # Bot was interrupted
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationTurn:
|
||||
"""A single turn in the conversation."""
|
||||
role: str # "user" or "assistant"
|
||||
text: str
|
||||
audio_duration_ms: Optional[int] = None
|
||||
timestamp: float = field(default_factory=lambda: asyncio.get_event_loop().time())
|
||||
was_interrupted: bool = False
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""
|
||||
Manages conversation state and history.
|
||||
|
||||
Provides:
|
||||
- Message history for LLM context
|
||||
- Turn management
|
||||
- State tracking
|
||||
- Event callbacks for state changes
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_prompt: Optional[str] = None,
|
||||
max_history: int = 20,
|
||||
greeting: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize conversation manager.
|
||||
|
||||
Args:
|
||||
system_prompt: System prompt for LLM
|
||||
max_history: Maximum number of turns to keep
|
||||
greeting: Optional greeting message when conversation starts
|
||||
"""
|
||||
self.system_prompt = system_prompt or (
|
||||
"You are a helpful, friendly voice assistant. "
|
||||
"Keep your responses concise and conversational. "
|
||||
"Respond naturally as if having a phone conversation. "
|
||||
"If you don't understand something, ask for clarification."
|
||||
)
|
||||
self.max_history = max_history
|
||||
self.greeting = greeting
|
||||
|
||||
# State
|
||||
self.state = ConversationState.IDLE
|
||||
self.turns: List[ConversationTurn] = []
|
||||
|
||||
# Callbacks
|
||||
self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
|
||||
self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []
|
||||
|
||||
# Current turn tracking
|
||||
self._current_user_text: str = ""
|
||||
self._current_assistant_text: str = ""
|
||||
|
||||
logger.info("ConversationManager initialized")
|
||||
|
||||
def on_state_change(
|
||||
self,
|
||||
callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Register callback for state changes."""
|
||||
self._state_callbacks.append(callback)
|
||||
|
||||
def on_turn_complete(
|
||||
self,
|
||||
callback: Callable[[ConversationTurn], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Register callback for turn completion."""
|
||||
self._turn_callbacks.append(callback)
|
||||
|
||||
async def set_state(self, new_state: ConversationState) -> None:
|
||||
"""Set conversation state and notify listeners."""
|
||||
if new_state != self.state:
|
||||
old_state = self.state
|
||||
self.state = new_state
|
||||
logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")
|
||||
|
||||
for callback in self._state_callbacks:
|
||||
try:
|
||||
await callback(old_state, new_state)
|
||||
except Exception as e:
|
||||
logger.error(f"State callback error: {e}")
|
||||
|
||||
def get_messages(self) -> List[LLMMessage]:
|
||||
"""
|
||||
Get conversation history as LLM messages.
|
||||
|
||||
Returns:
|
||||
List of LLMMessage objects including system prompt
|
||||
"""
|
||||
messages = [LLMMessage(role="system", content=self.system_prompt)]
|
||||
|
||||
# Add conversation history
|
||||
for turn in self.turns[-self.max_history:]:
|
||||
messages.append(LLMMessage(role=turn.role, content=turn.text))
|
||||
|
||||
# Add current user text if any
|
||||
if self._current_user_text:
|
||||
messages.append(LLMMessage(role="user", content=self._current_user_text))
|
||||
|
||||
return messages
|
||||
|
||||
async def start_user_turn(self) -> None:
|
||||
"""Signal that user has started speaking."""
|
||||
await self.set_state(ConversationState.LISTENING)
|
||||
self._current_user_text = ""
|
||||
|
||||
async def update_user_text(self, text: str, is_final: bool = False) -> None:
|
||||
"""
|
||||
Update current user text (from ASR).
|
||||
|
||||
Args:
|
||||
text: Transcribed text
|
||||
is_final: Whether this is the final transcript
|
||||
"""
|
||||
self._current_user_text = text
|
||||
|
||||
async def end_user_turn(self, text: str) -> None:
|
||||
"""
|
||||
End user turn and add to history.
|
||||
|
||||
Args:
|
||||
text: Final user text
|
||||
"""
|
||||
if text.strip():
|
||||
turn = ConversationTurn(role="user", text=text.strip())
|
||||
self.turns.append(turn)
|
||||
|
||||
for callback in self._turn_callbacks:
|
||||
try:
|
||||
await callback(turn)
|
||||
except Exception as e:
|
||||
logger.error(f"Turn callback error: {e}")
|
||||
|
||||
logger.info(f"User: {text[:50]}...")
|
||||
|
||||
self._current_user_text = ""
|
||||
await self.set_state(ConversationState.PROCESSING)
|
||||
|
||||
async def start_assistant_turn(self) -> None:
|
||||
"""Signal that assistant has started speaking."""
|
||||
await self.set_state(ConversationState.SPEAKING)
|
||||
self._current_assistant_text = ""
|
||||
|
||||
async def update_assistant_text(self, text: str) -> None:
|
||||
"""
|
||||
Update current assistant text (streaming).
|
||||
|
||||
Args:
|
||||
text: Text chunk from LLM
|
||||
"""
|
||||
self._current_assistant_text += text
|
||||
|
||||
async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
|
||||
"""
|
||||
End assistant turn and add to history.
|
||||
|
||||
Args:
|
||||
was_interrupted: Whether the turn was interrupted by user
|
||||
"""
|
||||
text = self._current_assistant_text.strip()
|
||||
if text:
|
||||
turn = ConversationTurn(
|
||||
role="assistant",
|
||||
text=text,
|
||||
was_interrupted=was_interrupted
|
||||
)
|
||||
self.turns.append(turn)
|
||||
|
||||
for callback in self._turn_callbacks:
|
||||
try:
|
||||
await callback(turn)
|
||||
except Exception as e:
|
||||
logger.error(f"Turn callback error: {e}")
|
||||
|
||||
status = " (interrupted)" if was_interrupted else ""
|
||||
logger.info(f"Assistant{status}: {text[:50]}...")
|
||||
|
||||
self._current_assistant_text = ""
|
||||
|
||||
if was_interrupted:
|
||||
await self.set_state(ConversationState.INTERRUPTED)
|
||||
else:
|
||||
await self.set_state(ConversationState.IDLE)
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Handle interruption (barge-in)."""
|
||||
if self.state == ConversationState.SPEAKING:
|
||||
await self.end_assistant_turn(was_interrupted=True)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset conversation history."""
|
||||
self.turns = []
|
||||
self._current_user_text = ""
|
||||
self._current_assistant_text = ""
|
||||
self.state = ConversationState.IDLE
|
||||
logger.info("Conversation reset")
|
||||
|
||||
@property
|
||||
def turn_count(self) -> int:
|
||||
"""Get number of turns in conversation."""
|
||||
return len(self.turns)
|
||||
|
||||
@property
|
||||
def last_user_text(self) -> Optional[str]:
|
||||
"""Get last user text."""
|
||||
for turn in reversed(self.turns):
|
||||
if turn.role == "user":
|
||||
return turn.text
|
||||
return None
|
||||
|
||||
@property
|
||||
def last_assistant_text(self) -> Optional[str]:
|
||||
"""Get last assistant text."""
|
||||
for turn in reversed(self.turns):
|
||||
if turn.role == "assistant":
|
||||
return turn.text
|
||||
return None
|
||||
|
||||
def get_context_summary(self) -> Dict[str, Any]:
|
||||
"""Get a summary of conversation context."""
|
||||
return {
|
||||
"state": self.state.value,
|
||||
"turn_count": self.turn_count,
|
||||
"last_user": self.last_user_text,
|
||||
"last_assistant": self.last_assistant_text,
|
||||
"current_user": self._current_user_text or None,
|
||||
"current_assistant": self._current_assistant_text or None
|
||||
}
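A minimal usage sketch of the ConversationManager API above (hypothetical driver code, not part of the commit; it assumes LLMMessage exposes role/content attributes, as its constructor calls suggest):

import asyncio
from core.conversation import ConversationManager, ConversationTurn

async def demo() -> None:
    convo = ConversationManager(max_history=10)

    async def on_turn(turn: ConversationTurn) -> None:
        print(f"turn complete: {turn.role}: {turn.text}")

    convo.on_turn_complete(on_turn)

    # User speaks, ASR finishes, and the turn is committed to history.
    await convo.start_user_turn()
    await convo.end_user_turn("What's the weather like?")

    # Assistant streams a reply; chunks accumulate until the turn ends.
    await convo.start_assistant_turn()
    await convo.update_assistant_text("It looks sunny ")
    await convo.update_assistant_text("this afternoon.")
    await convo.end_assistant_turn()

    # System prompt plus history, ready to pass to an LLM service.
    for msg in convo.get_messages():
        print(msg.role, msg.content)

asyncio.run(demo())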
core/duplex_pipeline.py (new file, 509 lines)
@@ -0,0 +1,509 @@
"""Full duplex audio pipeline for AI voice conversation.
|
||||
|
||||
This module implements the core duplex pipeline that orchestrates:
|
||||
- VAD (Voice Activity Detection)
|
||||
- EOU (End of Utterance) Detection
|
||||
- ASR (Automatic Speech Recognition) - optional
|
||||
- LLM (Language Model)
|
||||
- TTS (Text-to-Speech)
|
||||
|
||||
Inspired by pipecat's frame-based architecture and active-call's
|
||||
event-driven design.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Callable, Awaitable
|
||||
from loguru import logger
|
||||
|
||||
from core.transports import BaseTransport
|
||||
from core.conversation import ConversationManager, ConversationState
|
||||
from core.events import get_event_bus
|
||||
from processors.vad import VADProcessor, SileroVAD
|
||||
from processors.eou import EouDetector
|
||||
from services.base import BaseLLMService, BaseTTSService, BaseASRService
|
||||
from services.llm import OpenAILLMService, MockLLMService
|
||||
from services.tts import EdgeTTSService, MockTTSService
|
||||
from services.asr import BufferedASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class DuplexPipeline:
|
||||
"""
|
||||
Full duplex audio pipeline for AI voice conversation.
|
||||
|
||||
Handles bidirectional audio flow with:
|
||||
- User speech detection and transcription
|
||||
- AI response generation
|
||||
- Text-to-speech synthesis
|
||||
- Barge-in (interruption) support
|
||||
|
||||
Architecture (inspired by pipecat):
|
||||
|
||||
User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
|
||||
↓
|
||||
Barge-in Detection → Interrupt
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: BaseTransport,
|
||||
session_id: str,
|
||||
llm_service: Optional[BaseLLMService] = None,
|
||||
tts_service: Optional[BaseTTSService] = None,
|
||||
asr_service: Optional[BaseASRService] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
greeting: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize duplex pipeline.
|
||||
|
||||
Args:
|
||||
transport: Transport for sending audio/events
|
||||
session_id: Session identifier
|
||||
llm_service: LLM service (defaults to OpenAI)
|
||||
tts_service: TTS service (defaults to EdgeTTS)
|
||||
asr_service: ASR service (optional)
|
||||
system_prompt: System prompt for LLM
|
||||
greeting: Optional greeting to speak on start
|
||||
"""
|
||||
self.transport = transport
|
||||
self.session_id = session_id
|
||||
self.event_bus = get_event_bus()
|
||||
|
||||
# Initialize VAD
|
||||
self.vad_model = SileroVAD(
|
||||
model_path=settings.vad_model_path,
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
self.vad_processor = VADProcessor(
|
||||
vad_model=self.vad_model,
|
||||
threshold=settings.vad_threshold
|
||||
)
|
||||
|
||||
# Initialize EOU detector
|
||||
self.eou_detector = EouDetector(
|
||||
silence_threshold_ms=600,
|
||||
min_speech_duration_ms=200
|
||||
)
|
||||
|
||||
# Initialize services
|
||||
self.llm_service = llm_service
|
||||
self.tts_service = tts_service
|
||||
self.asr_service = asr_service or BufferedASRService()
|
||||
|
||||
# Conversation manager
|
||||
self.conversation = ConversationManager(
|
||||
system_prompt=system_prompt,
|
||||
greeting=greeting
|
||||
)
|
||||
|
||||
# State
|
||||
self._running = True
|
||||
self._is_bot_speaking = False
|
||||
self._current_turn_task: Optional[asyncio.Task] = None
|
||||
self._audio_buffer: bytes = b""
|
||||
self._last_vad_status: str = "Silence"
|
||||
|
||||
# Interruption handling
|
||||
self._interrupt_event = asyncio.Event()
|
||||
|
||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline and connect services."""
|
||||
try:
|
||||
# Connect LLM service
|
||||
if not self.llm_service:
|
||||
if settings.openai_api_key:
|
||||
self.llm_service = OpenAILLMService(
|
||||
api_key=settings.openai_api_key,
|
||||
base_url=settings.openai_api_url,
|
||||
model=settings.llm_model
|
||||
)
|
||||
else:
|
||||
logger.warning("No OpenAI API key - using mock LLM")
|
||||
self.llm_service = MockLLMService()
|
||||
|
||||
await self.llm_service.connect()
|
||||
|
||||
# Connect TTS service
|
||||
if not self.tts_service:
|
||||
if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key:
|
||||
self.tts_service = SiliconFlowTTSService(
|
||||
api_key=settings.siliconflow_api_key,
|
||||
voice=settings.tts_voice,
|
||||
model=settings.siliconflow_tts_model,
|
||||
sample_rate=settings.sample_rate,
|
||||
speed=settings.tts_speed
|
||||
)
|
||||
logger.info("Using SiliconFlow TTS service")
|
||||
else:
|
||||
self.tts_service = EdgeTTSService(
|
||||
voice=settings.tts_voice,
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
logger.info("Using Edge TTS service")
|
||||
|
||||
await self.tts_service.connect()
|
||||
|
||||
# Connect ASR service
|
||||
await self.asr_service.connect()
|
||||
|
||||
logger.info("DuplexPipeline services connected")
|
||||
|
||||
# Speak greeting if configured
|
||||
if self.conversation.greeting:
|
||||
await self._speak(self.conversation.greeting)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start pipeline: {e}")
|
||||
raise
|
||||
|
||||
async def process_audio(self, pcm_bytes: bytes) -> None:
|
||||
"""
|
||||
Process incoming audio chunk.
|
||||
|
||||
This is the main entry point for audio from the user.
|
||||
|
||||
Args:
|
||||
pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
# 1. Process through VAD
|
||||
vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)
|
||||
|
||||
vad_status = "Silence"
|
||||
if vad_result:
|
||||
event_type, probability = vad_result
|
||||
vad_status = "Speech" if event_type == "speaking" else "Silence"
|
||||
|
||||
# Emit VAD event
|
||||
await self.event_bus.publish(event_type, {
|
||||
"trackId": self.session_id,
|
||||
"probability": probability
|
||||
})
|
||||
else:
|
||||
# No state change - keep previous status
|
||||
vad_status = self._last_vad_status
|
||||
|
||||
# Update state based on VAD
|
||||
if vad_status == "Speech" and self._last_vad_status != "Speech":
|
||||
await self._on_speech_start()
|
||||
|
||||
self._last_vad_status = vad_status
|
||||
|
||||
# 2. Check for barge-in (user speaking while bot speaking)
|
||||
if self._is_bot_speaking and vad_status == "Speech":
|
||||
await self._handle_barge_in()
|
||||
|
||||
# 3. Buffer audio for ASR
|
||||
if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
|
||||
self._audio_buffer += pcm_bytes
|
||||
await self.asr_service.send_audio(pcm_bytes)
|
||||
|
||||
# 4. Check for End of Utterance
|
||||
if self.eou_detector.process(vad_status):
|
||||
await self._on_end_of_utterance()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline audio processing error: {e}", exc_info=True)
|
||||
|
||||
async def process_text(self, text: str) -> None:
|
||||
"""
|
||||
Process text input (chat command).
|
||||
|
||||
Allows direct text input to bypass ASR.
|
||||
|
||||
Args:
|
||||
text: User text input
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
logger.info(f"Processing text input: {text[:50]}...")
|
||||
|
||||
# Cancel any current speaking
|
||||
await self._stop_current_speech()
|
||||
|
||||
# Start new turn
|
||||
await self.conversation.end_user_turn(text)
|
||||
self._current_turn_task = asyncio.create_task(self._handle_turn(text))
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Interrupt current bot speech (manual interrupt command)."""
|
||||
await self._handle_barge_in()
|
||||
|
||||
async def _on_speech_start(self) -> None:
|
||||
"""Handle user starting to speak."""
|
||||
if self.conversation.state == ConversationState.IDLE:
|
||||
await self.conversation.start_user_turn()
|
||||
self._audio_buffer = b""
|
||||
self.eou_detector.reset()
|
||||
logger.debug("User speech started")
|
||||
|
||||
async def _on_end_of_utterance(self) -> None:
|
||||
"""Handle end of user utterance."""
|
||||
if self.conversation.state != ConversationState.LISTENING:
|
||||
return
|
||||
|
||||
# Get transcribed text (if using ASR that provides it)
|
||||
user_text = ""
|
||||
if hasattr(self.asr_service, 'get_and_clear_text'):
|
||||
user_text = self.asr_service.get_and_clear_text()
|
||||
|
||||
# If no ASR text, we could use the audio buffer for external ASR
|
||||
# For now, just use placeholder if no ASR text
|
||||
if not user_text:
|
||||
# In a real implementation, you'd send audio_buffer to ASR here
|
||||
# For demo purposes, use mock text
|
||||
user_text = "[User speech detected]"
|
||||
logger.warning("No ASR text available - using placeholder")
|
||||
|
||||
logger.info(f"EOU detected - user said: {user_text[:50]}...")
|
||||
|
||||
# Clear buffers
|
||||
self._audio_buffer = b""
|
||||
|
||||
# Process the turn
|
||||
await self.conversation.end_user_turn(user_text)
|
||||
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||||
|
||||
async def _handle_turn(self, user_text: str) -> None:
|
||||
"""
|
||||
Handle a complete conversation turn.
|
||||
|
||||
Uses sentence-by-sentence streaming TTS for lower latency.
|
||||
|
||||
Args:
|
||||
user_text: User's transcribed text
|
||||
"""
|
||||
try:
|
||||
# Get AI response (streaming)
|
||||
messages = self.conversation.get_messages()
|
||||
full_response = ""
|
||||
|
||||
await self.conversation.start_assistant_turn()
|
||||
self._is_bot_speaking = True
|
||||
self._interrupt_event.clear()
|
||||
|
||||
# Sentence buffer for streaming TTS
|
||||
sentence_buffer = ""
|
||||
sentence_ends = {'.', '!', '?', '。', '!', '?', ';', '\n'}
|
||||
first_audio_sent = False
|
||||
|
||||
# Stream LLM response and TTS sentence by sentence
|
||||
async for text_chunk in self.llm_service.generate_stream(messages):
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
|
||||
full_response += text_chunk
|
||||
sentence_buffer += text_chunk
|
||||
await self.conversation.update_assistant_text(text_chunk)
|
||||
|
||||
# Check for sentence completion - synthesize immediately for low latency
|
||||
while any(end in sentence_buffer for end in sentence_ends):
|
||||
# Find first sentence end
|
||||
min_idx = len(sentence_buffer)
|
||||
for end in sentence_ends:
|
||||
idx = sentence_buffer.find(end)
|
||||
if idx != -1 and idx < min_idx:
|
||||
min_idx = idx
|
||||
|
||||
if min_idx < len(sentence_buffer):
|
||||
sentence = sentence_buffer[:min_idx + 1].strip()
|
||||
sentence_buffer = sentence_buffer[min_idx + 1:]
|
||||
|
||||
if sentence and not self._interrupt_event.is_set():
|
||||
# Send track start on first audio
|
||||
if not first_audio_sent:
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
first_audio_sent = True
|
||||
|
||||
# Synthesize and send this sentence immediately
|
||||
await self._speak_sentence(sentence)
|
||||
else:
|
||||
break
|
||||
|
||||
# Speak any remaining text
|
||||
if sentence_buffer.strip() and not self._interrupt_event.is_set():
|
||||
if not first_audio_sent:
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
first_audio_sent = True
|
||||
await self._speak_sentence(sentence_buffer.strip())
|
||||
|
||||
# Send track end
|
||||
if first_audio_sent:
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
# End assistant turn
|
||||
await self.conversation.end_assistant_turn(
|
||||
was_interrupted=self._interrupt_event.is_set()
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Turn handling cancelled")
|
||||
await self.conversation.end_assistant_turn(was_interrupted=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Turn handling error: {e}", exc_info=True)
|
||||
await self.conversation.end_assistant_turn(was_interrupted=True)
|
||||
finally:
|
||||
self._is_bot_speaking = False
|
||||
|
||||
async def _speak_sentence(self, text: str) -> None:
|
||||
"""
|
||||
Synthesize and send a single sentence.
|
||||
|
||||
Args:
|
||||
text: Sentence to speak
|
||||
"""
|
||||
if not text.strip() or self._interrupt_event.is_set():
|
||||
return
|
||||
|
||||
try:
|
||||
async for chunk in self.tts_service.synthesize_stream(text):
|
||||
if self._interrupt_event.is_set():
|
||||
break
|
||||
await self.transport.send_audio(chunk.audio)
|
||||
await asyncio.sleep(0.005) # Small delay to prevent flooding
|
||||
except Exception as e:
|
||||
logger.error(f"TTS sentence error: {e}")
|
||||
|
||||
async def _speak(self, text: str) -> None:
|
||||
"""
|
||||
Synthesize and send speech.
|
||||
|
||||
Args:
|
||||
text: Text to speak
|
||||
"""
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
try:
|
||||
# Send track start event
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
self._is_bot_speaking = True
|
||||
|
||||
# Stream TTS audio
|
||||
async for chunk in self.tts_service.synthesize_stream(text):
|
||||
if self._interrupt_event.is_set():
|
||||
logger.info("TTS interrupted by barge-in")
|
||||
break
|
||||
|
||||
# Send audio to client
|
||||
await self.transport.send_audio(chunk.audio)
|
||||
|
||||
# Small delay to prevent flooding
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Send track end event
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("TTS cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"TTS error: {e}")
|
||||
finally:
|
||||
self._is_bot_speaking = False
|
||||
|
||||
async def _handle_barge_in(self) -> None:
|
||||
"""Handle user barge-in (interruption)."""
|
||||
if not self._is_bot_speaking:
|
||||
return
|
||||
|
||||
logger.info("Barge-in detected - interrupting bot speech")
|
||||
|
||||
# Signal interruption
|
||||
self._interrupt_event.set()
|
||||
|
||||
# Cancel TTS
|
||||
if self.tts_service:
|
||||
await self.tts_service.cancel()
|
||||
|
||||
# Cancel LLM
|
||||
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
||||
self.llm_service.cancel()
|
||||
|
||||
# Interrupt conversation
|
||||
await self.conversation.interrupt()
|
||||
|
||||
# Send interrupt event to client
|
||||
await self.transport.send_event({
|
||||
"event": "interrupt",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
# Reset for new user turn
|
||||
self._is_bot_speaking = False
|
||||
await self.conversation.start_user_turn()
|
||||
self._audio_buffer = b""
|
||||
self.eou_detector.reset()
|
||||
|
||||
async def _stop_current_speech(self) -> None:
|
||||
"""Stop any current speech task."""
|
||||
if self._current_turn_task and not self._current_turn_task.done():
|
||||
self._interrupt_event.set()
|
||||
self._current_turn_task.cancel()
|
||||
try:
|
||||
await self._current_turn_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._is_bot_speaking = False
|
||||
self._interrupt_event.clear()
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup pipeline resources."""
|
||||
logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")
|
||||
|
||||
self._running = False
|
||||
await self._stop_current_speech()
|
||||
|
||||
# Disconnect services
|
||||
if self.llm_service:
|
||||
await self.llm_service.disconnect()
|
||||
if self.tts_service:
|
||||
await self.tts_service.disconnect()
|
||||
if self.asr_service:
|
||||
await self.asr_service.disconnect()
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
|
||||
@property
|
||||
def is_speaking(self) -> bool:
|
||||
"""Check if bot is currently speaking."""
|
||||
return self._is_bot_speaking
|
||||
|
||||
@property
|
||||
def state(self) -> ConversationState:
|
||||
"""Get current conversation state."""
|
||||
return self.conversation.state
|
||||
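A sketch of how the pipeline above is typically driven (hypothetical glue code; the real wiring lives in Session below). It assumes a transport has already been constructed; the 20 ms / 16 kHz framing and the method names come from the code above.

# Hypothetical driver code, not part of the commit.
import asyncio
from core.duplex_pipeline import DuplexPipeline
from core.transports import SocketTransport

async def run_call(transport: SocketTransport) -> None:
    pipeline = DuplexPipeline(
        transport=transport,
        session_id="demo-session",
        greeting="Hello!",
    )
    await pipeline.start()          # connects LLM/TTS/ASR and speaks the greeting
    try:
        # Text path used by the chat command: skips VAD/ASR and goes straight to LLM -> TTS.
        await pipeline.process_text("Tell me a short joke.")

        # Audio path: 20 ms of 16 kHz mono 16-bit PCM is 640 bytes per chunk.
        silence = b"\x00" * 640
        for _ in range(50):         # ~1 s of silence, just to exercise VAD/EOU
            await pipeline.process_audio(silence)
            await asyncio.sleep(0.02)
    finally:
        await pipeline.cleanup()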
core/session.py
@@ -8,6 +8,7 @@ from loguru import logger
|
||||
from core.transports import BaseTransport
|
||||
from core.pipeline import AudioPipeline
|
||||
from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class Session:
|
||||
@@ -15,28 +16,44 @@ class Session:
|
||||
Manages a single call session.
|
||||
|
||||
Handles command routing, audio processing, and session lifecycle.
|
||||
Supports both basic audio pipeline and full duplex voice conversation.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, transport: BaseTransport):
|
||||
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
|
||||
"""
|
||||
Initialize session.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier
|
||||
transport: Transport instance for communication
|
||||
use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
|
||||
"""
|
||||
self.id = session_id
|
||||
self.transport = transport
|
||||
self.pipeline = AudioPipeline(transport, session_id)
|
||||
|
||||
# Determine pipeline mode
|
||||
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
|
||||
|
||||
if self.use_duplex:
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
self.pipeline = DuplexPipeline(
|
||||
transport=transport,
|
||||
session_id=session_id,
|
||||
system_prompt=settings.duplex_system_prompt,
|
||||
greeting=settings.duplex_greeting
|
||||
)
|
||||
else:
|
||||
self.pipeline = AudioPipeline(transport, session_id)
|
||||
|
||||
# Session state
|
||||
self.created_at = None
|
||||
self.state = "created" # created, invited, accepted, ringing, hungup
|
||||
self._pipeline_started = False
|
||||
|
||||
# Track IDs
|
||||
self.current_track_id: Optional[str] = str(uuid.uuid4())
|
||||
|
||||
logger.info(f"Session {self.id} created")
|
||||
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
||||
|
||||
async def handle_text(self, text_data: str) -> None:
|
||||
"""
|
||||
@@ -112,7 +129,10 @@ class Session:
|
||||
audio_bytes: PCM audio data
|
||||
"""
|
||||
try:
|
||||
await self.pipeline.process_input(audio_bytes)
|
||||
if self.use_duplex:
|
||||
await self.pipeline.process_audio(audio_bytes)
|
||||
else:
|
||||
await self.pipeline.process_input(audio_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||
|
||||
@@ -128,6 +148,15 @@ class Session:
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
# Start duplex pipeline if enabled
|
||||
if self.use_duplex and not self._pipeline_started:
|
||||
try:
|
||||
await self.pipeline.start()
|
||||
self._pipeline_started = True
|
||||
logger.info(f"Session {self.id} duplex pipeline started")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start duplex pipeline: {e}")
|
||||
|
||||
logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}")
|
||||
|
||||
async def _handle_accept(self, data: Dict[str, Any]) -> None:
|
||||
@@ -199,7 +228,10 @@ class Session:
|
||||
logger.info(f"Session {self.id} graceful interrupt")
|
||||
else:
|
||||
logger.info(f"Session {self.id} immediate interrupt")
|
||||
await self.pipeline.interrupt()
|
||||
if self.use_duplex:
|
||||
await self.pipeline.interrupt()
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
|
||||
async def _handle_pause(self) -> None:
|
||||
"""Handle pause command."""
|
||||
@@ -236,7 +268,10 @@ class Session:
|
||||
"""Handle chat command."""
|
||||
logger.info(f"Session {self.id} chat: {command.text[:50]}...")
|
||||
# Process text input through pipeline
|
||||
await self.pipeline.process_text_input(command.text)
|
||||
if self.use_duplex:
|
||||
await self.pipeline.process_text(command.text)
|
||||
else:
|
||||
await self.pipeline.process_text_input(command.text)
|
||||
|
||||
async def _send_error(self, sender: str, error_message: str) -> None:
|
||||
"""
|
||||
|
||||
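For orientation before the two example clients below: the WebSocket protocol handled by Session uses JSON text frames for commands and events, interleaved with raw 16-bit PCM binary frames for audio. A condensed summary, with message shapes taken from the client code in this commit (event payloads show only the fields the clients actually read):

Client -> server (JSON commands, plus binary PCM microphone audio):
    {"command": "invite", "option": {"codec": "pcm", "sampleRate": 16000}}
    {"command": "chat", "text": "Hello!"}
    {"command": "interrupt"}
    {"command": "hangup", "reason": "User quit"}

Server -> client (JSON events, plus binary PCM TTS audio):
    {"event": "answer"}                              # session ready
    {"event": "speaking"} / {"event": "silence"}     # VAD state changes
    {"event": "trackStart"} / {"event": "trackEnd"}  # bot speech boundaries (with trackId, timestamp)
    {"event": "interrupt"}                           # barge-in acknowledged
    {"event": "error", "error": "..."} / {"event": "hangup", "reason": "..."}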
examples/mic_client.py
@@ -1,137 +1,517 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Microphone WebSocket Client
|
||||
Microphone client for testing duplex voice conversation.
|
||||
|
||||
Connects to the backend WebSocket endpoint and streams audio from the microphone.
|
||||
Used to test VAD and EOU detection.
|
||||
This client captures audio from the microphone, sends it to the server,
|
||||
and plays back the AI's voice response through the speakers.
|
||||
|
||||
Dependencies:
|
||||
pip install pyaudio aiohttp
|
||||
Usage:
|
||||
python examples/mic_client.py --url ws://localhost:8000/ws
|
||||
python examples/mic_client.py --url ws://localhost:8000/ws --chat "Hello!"
|
||||
|
||||
Requirements:
|
||||
pip install sounddevice soundfile websockets numpy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import pyaudio
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import queue
|
||||
from pathlib import Path
|
||||
|
||||
# Configuration
|
||||
SERVER_URL = "ws://localhost:8000/ws"
|
||||
SAMPLE_RATE = 16000
|
||||
CHANNELS = 1
|
||||
CHUNK_DURATION_MS = 20
|
||||
CHUNK_SIZE = int(SAMPLE_RATE * (CHUNK_DURATION_MS / 1000.0)) # 320 samples for 20ms
|
||||
FORMAT = pyaudio.paInt16
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
print("Please install numpy: pip install numpy")
|
||||
sys.exit(1)
|
||||
|
||||
async def send_audio_loop(ws, stream):
|
||||
"""Read from microphone and send to WebSocket."""
|
||||
print("🎙️ Microphone streaming started...")
|
||||
try:
|
||||
while True:
|
||||
# Read non-blocking? PyAudio read is blocking, so run in executor or use specialized async lib.
|
||||
# For simplicity in this script, we'll just read. It might block the event loop slightly
|
||||
# but for 20ms chunks it's usually acceptable for a test script.
|
||||
# To be proper async, we should run_in_executor.
|
||||
data = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: stream.read(CHUNK_SIZE, exception_on_overflow=False)
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError:
|
||||
print("Please install sounddevice: pip install sounddevice")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Please install websockets: pip install websockets")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class MicrophoneClient:
|
||||
"""
|
||||
Full-duplex microphone client for voice conversation.
|
||||
|
||||
Features:
|
||||
- Real-time microphone capture
|
||||
- Real-time speaker playback
|
||||
- WebSocket communication
|
||||
- Text chat support
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
sample_rate: int = 16000,
|
||||
chunk_duration_ms: int = 20,
|
||||
input_device: int = None,
|
||||
output_device: int = None
|
||||
):
|
||||
"""
|
||||
Initialize microphone client.
|
||||
|
||||
Args:
|
||||
url: WebSocket server URL
|
||||
sample_rate: Audio sample rate (Hz)
|
||||
chunk_duration_ms: Audio chunk duration (ms)
|
||||
input_device: Input device ID (None for default)
|
||||
output_device: Output device ID (None for default)
|
||||
"""
|
||||
self.url = url
|
||||
self.sample_rate = sample_rate
|
||||
self.chunk_duration_ms = chunk_duration_ms
|
||||
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
|
||||
self.input_device = input_device
|
||||
self.output_device = output_device
|
||||
|
||||
# WebSocket connection
|
||||
self.ws = None
|
||||
self.running = False
|
||||
|
||||
# Audio buffers
|
||||
self.audio_input_queue = queue.Queue()
|
||||
self.audio_output_buffer = b"" # Continuous buffer for smooth playback
|
||||
self.audio_output_lock = threading.Lock()
|
||||
|
||||
# Statistics
|
||||
self.bytes_sent = 0
|
||||
self.bytes_received = 0
|
||||
|
||||
# State
|
||||
self.is_recording = True
|
||||
self.is_playing = True
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to WebSocket server."""
|
||||
print(f"Connecting to {self.url}...")
|
||||
self.ws = await websockets.connect(self.url)
|
||||
self.running = True
|
||||
print("Connected!")
|
||||
|
||||
# Send invite command
|
||||
await self.send_command({
|
||||
"command": "invite",
|
||||
"option": {
|
||||
"codec": "pcm",
|
||||
"sampleRate": self.sample_rate
|
||||
}
|
||||
})
|
||||
|
||||
async def send_command(self, cmd: dict) -> None:
|
||||
"""Send JSON command to server."""
|
||||
if self.ws:
|
||||
await self.ws.send(json.dumps(cmd))
|
||||
print(f"→ Command: {cmd.get('command', 'unknown')}")
|
||||
|
||||
async def send_chat(self, text: str) -> None:
|
||||
"""Send chat message (text input)."""
|
||||
await self.send_command({
|
||||
"command": "chat",
|
||||
"text": text
|
||||
})
|
||||
print(f"→ Chat: {text}")
|
||||
|
||||
async def send_interrupt(self) -> None:
|
||||
"""Send interrupt command."""
|
||||
await self.send_command({
|
||||
"command": "interrupt"
|
||||
})
|
||||
|
||||
async def send_hangup(self, reason: str = "User quit") -> None:
|
||||
"""Send hangup command."""
|
||||
await self.send_command({
|
||||
"command": "hangup",
|
||||
"reason": reason
|
||||
})
|
||||
|
||||
def _audio_input_callback(self, indata, frames, time, status):
|
||||
"""Callback for audio input (microphone)."""
|
||||
if status:
|
||||
print(f"Input status: {status}")
|
||||
|
||||
if self.is_recording and self.running:
|
||||
# Convert to 16-bit PCM
|
||||
audio_data = (indata[:, 0] * 32767).astype(np.int16).tobytes()
|
||||
self.audio_input_queue.put(audio_data)
|
||||
|
||||
def _add_audio_to_buffer(self, audio_data: bytes):
|
||||
"""Add audio data to playback buffer."""
|
||||
with self.audio_output_lock:
|
||||
self.audio_output_buffer += audio_data
|
||||
|
||||
async def _playback_task(self):
|
||||
"""Background task to play buffered audio smoothly using output stream."""
|
||||
# Use a continuous output stream for smooth playback
|
||||
chunk_samples = int(self.sample_rate * 0.05) # 50ms chunks
|
||||
chunk_bytes = chunk_samples * 2 # 16-bit = 2 bytes per sample
|
||||
|
||||
def output_callback(outdata, frames, time_info, status):
|
||||
"""Audio output callback."""
|
||||
if status:
|
||||
print(f"Output status: {status}")
|
||||
|
||||
bytes_needed = frames * 2
|
||||
with self.audio_output_lock:
|
||||
if len(self.audio_output_buffer) >= bytes_needed:
|
||||
audio_data = self.audio_output_buffer[:bytes_needed]
|
||||
self.audio_output_buffer = self.audio_output_buffer[bytes_needed:]
|
||||
samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32767.0
|
||||
outdata[:, 0] = samples
|
||||
else:
|
||||
outdata.fill(0)
|
||||
|
||||
# Create and start output stream
|
||||
try:
|
||||
output_stream = sd.OutputStream(
|
||||
samplerate=self.sample_rate,
|
||||
channels=1,
|
||||
dtype=np.float32,
|
||||
blocksize=chunk_samples,
|
||||
device=self.output_device,
|
||||
callback=output_callback,
|
||||
latency='low'
|
||||
)
|
||||
output_stream.start()
|
||||
print(f"Audio output stream started (device: {self.output_device or 'default'})")
|
||||
|
||||
# Keep stream running while client is active
|
||||
while self.running:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
output_stream.stop()
|
||||
output_stream.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Playback error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
async def audio_sender(self) -> None:
|
||||
"""Send audio from microphone to server."""
|
||||
while self.running:
|
||||
try:
|
||||
# Get audio from queue with timeout
|
||||
try:
|
||||
audio_data = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.audio_input_queue.get(timeout=0.1)
|
||||
)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# Send to server
|
||||
if self.ws and self.is_recording:
|
||||
await self.ws.send(audio_data)
|
||||
self.bytes_sent += len(audio_data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Audio sender error: {e}")
|
||||
break
|
||||
|
||||
async def receiver(self) -> None:
|
||||
"""Receive messages from server."""
|
||||
try:
|
||||
while self.running:
|
||||
try:
|
||||
message = await asyncio.wait_for(self.ws.recv(), timeout=0.1)
|
||||
|
||||
if isinstance(message, bytes):
|
||||
# Audio data received
|
||||
self.bytes_received += len(message)
|
||||
|
||||
if self.is_playing:
|
||||
self._add_audio_to_buffer(message)
|
||||
|
||||
# Show progress (less verbose)
|
||||
with self.audio_output_lock:
|
||||
buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000
|
||||
duration_ms = len(message) / (self.sample_rate * 2) * 1000
|
||||
print(f"← Audio: {duration_ms:.0f}ms (buffer: {buffer_ms:.0f}ms)")
|
||||
|
||||
else:
|
||||
# JSON event
|
||||
event = json.loads(message)
|
||||
await self._handle_event(event)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except websockets.ConnectionClosed:
|
||||
print("Connection closed")
|
||||
self.running = False
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Receiver error: {e}")
|
||||
self.running = False
|
||||
|
||||
async def _handle_event(self, event: dict) -> None:
|
||||
"""Handle incoming event."""
|
||||
event_type = event.get("event", "unknown")
|
||||
|
||||
if event_type == "answer":
|
||||
print("← Session ready!")
|
||||
elif event_type == "speaking":
|
||||
print("← User speech detected")
|
||||
elif event_type == "silence":
|
||||
print("← User silence detected")
|
||||
elif event_type == "trackStart":
|
||||
print("← Bot started speaking")
|
||||
# Clear any old audio in buffer
|
||||
with self.audio_output_lock:
|
||||
self.audio_output_buffer = b""
|
||||
elif event_type == "trackEnd":
|
||||
print("← Bot finished speaking")
|
||||
elif event_type == "interrupt":
|
||||
print("← Bot interrupted!")
|
||||
elif event_type == "error":
|
||||
print(f"← Error: {event.get('error')}")
|
||||
elif event_type == "hangup":
|
||||
print(f"← Hangup: {event.get('reason')}")
|
||||
self.running = False
|
||||
else:
|
||||
print(f"← Event: {event_type}")
|
||||
|
||||
async def interactive_mode(self) -> None:
|
||||
"""Run interactive mode for text chat."""
|
||||
print("\n" + "=" * 50)
|
||||
print("Voice Conversation Client")
|
||||
print("=" * 50)
|
||||
print("Speak into your microphone to talk to the AI.")
|
||||
print("Or type messages to send text.")
|
||||
print("")
|
||||
print("Commands:")
|
||||
print(" /quit - End conversation")
|
||||
print(" /mute - Mute microphone")
|
||||
print(" /unmute - Unmute microphone")
|
||||
print(" /interrupt - Interrupt AI speech")
|
||||
print(" /stats - Show statistics")
|
||||
print("=" * 50 + "\n")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
user_input = await asyncio.get_event_loop().run_in_executor(
|
||||
None, input, ""
|
||||
)
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Handle commands
|
||||
if user_input.startswith("/"):
|
||||
cmd = user_input.lower().strip()
|
||||
|
||||
if cmd == "/quit":
|
||||
await self.send_hangup("User quit")
|
||||
break
|
||||
elif cmd == "/mute":
|
||||
self.is_recording = False
|
||||
print("Microphone muted")
|
||||
elif cmd == "/unmute":
|
||||
self.is_recording = True
|
||||
print("Microphone unmuted")
|
||||
elif cmd == "/interrupt":
|
||||
await self.send_interrupt()
|
||||
elif cmd == "/stats":
|
||||
print(f"Sent: {self.bytes_sent / 1024:.1f} KB")
|
||||
print(f"Received: {self.bytes_received / 1024:.1f} KB")
|
||||
else:
|
||||
print(f"Unknown command: {cmd}")
|
||||
else:
|
||||
# Send as chat message
|
||||
await self.send_chat(user_input)
|
||||
|
||||
except EOFError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Input error: {e}")
|
||||
|
||||
async def run(self, chat_message: str = None, interactive: bool = True) -> None:
|
||||
"""
|
||||
Run the client.
|
||||
|
||||
Args:
|
||||
chat_message: Optional single chat message to send
|
||||
interactive: Whether to run in interactive mode
|
||||
"""
|
||||
try:
|
||||
await self.connect()
|
||||
|
||||
# Wait for answer
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Start audio input stream
|
||||
print("Starting audio streams...")
|
||||
|
||||
input_stream = sd.InputStream(
|
||||
samplerate=self.sample_rate,
|
||||
channels=1,
|
||||
dtype=np.float32,
|
||||
blocksize=self.chunk_samples,
|
||||
device=self.input_device,
|
||||
callback=self._audio_input_callback
|
||||
)
|
||||
|
||||
await ws.send_bytes(data)
|
||||
# No sleep needed here as microphone dictates the timing
|
||||
input_stream.start()
|
||||
print("Audio streams started")
|
||||
|
||||
# Start background tasks
|
||||
sender_task = asyncio.create_task(self.audio_sender())
|
||||
receiver_task = asyncio.create_task(self.receiver())
|
||||
playback_task = asyncio.create_task(self._playback_task())
|
||||
|
||||
if chat_message:
|
||||
# Send single message and wait
|
||||
await self.send_chat(chat_message)
|
||||
await asyncio.sleep(15)
|
||||
elif interactive:
|
||||
# Run interactive mode
|
||||
await self.interactive_mode()
|
||||
else:
|
||||
# Just wait
|
||||
while self.running:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Cleanup
|
||||
self.running = False
|
||||
sender_task.cancel()
|
||||
receiver_task.cancel()
|
||||
playback_task.cancel()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error in send loop: {e}")
|
||||
|
||||
async def receive_loop(ws):
|
||||
"""Listen for VAD/EOU events."""
|
||||
print("👂 Listening for server events...")
|
||||
async for msg in ws:
|
||||
timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
event = data.get('event')
|
||||
|
||||
# Highlight VAD/EOU events
|
||||
if event == 'speaking':
|
||||
print(f"[{timestamp}] 🗣️ SPEAKING STARTED")
|
||||
elif event == 'silence':
|
||||
print(f"[{timestamp}] 🤫 SILENCE DETECTED")
|
||||
elif event == 'eou':
|
||||
print(f"[{timestamp}] ✅ END OF UTTERANCE (EOU)")
|
||||
elif event == 'error':
|
||||
print(f"[{timestamp}] ❌ ERROR: {data.get('error')}")
|
||||
else:
|
||||
print(f"[{timestamp}] 📩 {event}: {str(data)[:100]}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"[{timestamp}] 📄 Text: {msg.data}")
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
print("❌ Connection closed")
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
print("❌ Connection error")
|
||||
break
|
||||
await sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await receiver_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await playback_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
input_stream.stop()
|
||||
|
||||
except ConnectionRefusedError:
|
||||
print(f"Error: Could not connect to {self.url}")
|
||||
print("Make sure the server is running.")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
finally:
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
self.running = False
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
print(f"\nSession ended")
|
||||
print(f" Total sent: {self.bytes_sent / 1024:.1f} KB")
|
||||
print(f" Total received: {self.bytes_received / 1024:.1f} KB")
|
||||
|
||||
|
||||
def list_devices():
|
||||
"""List available audio devices."""
|
||||
print("\nAvailable audio devices:")
|
||||
print("-" * 60)
|
||||
devices = sd.query_devices()
|
||||
for i, device in enumerate(devices):
|
||||
direction = []
|
||||
if device['max_input_channels'] > 0:
|
||||
direction.append("IN")
|
||||
if device['max_output_channels'] > 0:
|
||||
direction.append("OUT")
|
||||
direction_str = "/".join(direction) if direction else "N/A"
|
||||
|
||||
default = ""
|
||||
if i == sd.default.device[0]:
|
||||
default += " [DEFAULT INPUT]"
|
||||
if i == sd.default.device[1]:
|
||||
default += " [DEFAULT OUTPUT]"
|
||||
|
||||
print(f" {i:2d}: {device['name'][:40]:40s} ({direction_str}){default}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
async def main():
|
||||
p = pyaudio.PyAudio()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Microphone client for duplex voice conversation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
default="ws://localhost:8000/ws",
|
||||
help="WebSocket server URL"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat",
|
||||
help="Send a single chat message instead of using microphone"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="Audio sample rate (default: 16000)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-device",
|
||||
type=int,
|
||||
help="Input device ID"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-device",
|
||||
type=int,
|
||||
help="Output device ID"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-devices",
|
||||
action="store_true",
|
||||
help="List available audio devices and exit"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-interactive",
|
||||
action="store_true",
|
||||
help="Disable interactive mode"
|
||||
)
|
||||
|
||||
# Check for input devices
|
||||
info = p.get_host_api_info_by_index(0)
|
||||
numdevices = info.get('deviceCount')
|
||||
if numdevices == 0:
|
||||
print("❌ No audio input devices found")
|
||||
return
|
||||
|
||||
# Open microphone stream
|
||||
try:
|
||||
stream = p.open(format=FORMAT,
|
||||
channels=CHANNELS,
|
||||
rate=SAMPLE_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK_SIZE)
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to open microphone: {e}")
|
||||
return
|
||||
|
||||
session = aiohttp.ClientSession()
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
print(f"🔌 Connecting to {SERVER_URL}...")
|
||||
async with session.ws_connect(SERVER_URL) as ws:
|
||||
print("✅ Connected!")
|
||||
if args.list_devices:
|
||||
list_devices()
|
||||
return
|
||||
|
||||
client = MicrophoneClient(
|
||||
url=args.url,
|
||||
sample_rate=args.sample_rate,
|
||||
input_device=args.input_device,
|
||||
output_device=args.output_device
|
||||
)
|
||||
|
||||
await client.run(
|
||||
chat_message=args.chat,
|
||||
interactive=not args.no_interactive
|
||||
)
|
||||
|
||||
# 1. Send Invite
|
||||
invite_msg = {
|
||||
"command": "invite",
|
||||
"option": {
|
||||
"codec": "pcm",
|
||||
"samplerate": SAMPLE_RATE
|
||||
}
|
||||
}
|
||||
await ws.send_json(invite_msg)
|
||||
print("📤 Sent Invite")
|
||||
|
||||
# 2. Run loops
|
||||
await asyncio.gather(
|
||||
receive_loop(ws),
|
||||
send_audio_loop(ws, stream)
|
||||
)
|
||||
|
||||
except aiohttp.ClientConnectorError:
|
||||
print(f"❌ Failed to connect to {SERVER_URL}. Is the server running?")
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Stopping...")
|
||||
finally:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
await session.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
print("\nInterrupted by user")
|
||||
|
||||
examples/simple_client.py (new file, 239 lines)
@@ -0,0 +1,239 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple WebSocket client for testing voice conversation.
|
||||
Uses PyAudio for more reliable audio playback on Windows.
|
||||
|
||||
Usage:
|
||||
python examples/simple_client.py
|
||||
python examples/simple_client.py --text "Hello"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import wave
|
||||
import io
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
print("pip install numpy")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("pip install websockets")
|
||||
sys.exit(1)
|
||||
|
||||
# Try PyAudio first (more reliable on Windows)
|
||||
try:
|
||||
import pyaudio
|
||||
PYAUDIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYAUDIO_AVAILABLE = False
|
||||
print("PyAudio not available, trying sounddevice...")
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
SD_AVAILABLE = True
|
||||
except ImportError:
|
||||
SD_AVAILABLE = False
|
||||
|
||||
if not PYAUDIO_AVAILABLE and not SD_AVAILABLE:
|
||||
print("Please install pyaudio or sounddevice:")
|
||||
print(" pip install pyaudio")
|
||||
print(" or: pip install sounddevice")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class SimpleVoiceClient:
|
||||
"""Simple voice client with reliable audio playback."""
|
||||
|
||||
def __init__(self, url: str, sample_rate: int = 16000):
|
||||
self.url = url
|
||||
self.sample_rate = sample_rate
|
||||
self.ws = None
|
||||
self.running = False
|
||||
|
||||
# Audio buffer
|
||||
self.audio_buffer = b""
|
||||
|
||||
# PyAudio setup
|
||||
if PYAUDIO_AVAILABLE:
|
||||
self.pa = pyaudio.PyAudio()
|
||||
self.stream = None
|
||||
|
||||
# Stats
|
||||
self.bytes_received = 0
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to server."""
|
||||
print(f"Connecting to {self.url}...")
|
||||
self.ws = await websockets.connect(self.url)
|
||||
self.running = True
|
||||
print("Connected!")
|
||||
|
||||
# Send invite
|
||||
await self.ws.send(json.dumps({
|
||||
"command": "invite",
|
||||
"option": {"codec": "pcm", "sampleRate": self.sample_rate}
|
||||
}))
|
||||
print("-> invite")
|
||||
|
||||
async def send_chat(self, text: str):
|
||||
"""Send chat message."""
|
||||
await self.ws.send(json.dumps({"command": "chat", "text": text}))
|
||||
print(f"-> chat: {text}")
|
||||
|
||||
def play_audio(self, audio_data: bytes):
|
||||
"""Play audio data immediately."""
|
||||
if len(audio_data) == 0:
|
||||
return
|
||||
|
||||
if PYAUDIO_AVAILABLE:
|
||||
# Use PyAudio - more reliable on Windows
|
||||
if self.stream is None:
|
||||
self.stream = self.pa.open(
|
||||
format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=self.sample_rate,
|
||||
output=True,
|
||||
frames_per_buffer=1024
|
||||
)
|
||||
self.stream.write(audio_data)
|
||||
elif SD_AVAILABLE:
|
||||
# Use sounddevice
|
||||
samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32767.0
|
||||
sd.play(samples, self.sample_rate, blocking=True)
|
||||
|
||||
async def receive_loop(self):
|
||||
"""Receive and play audio."""
|
||||
print("\nWaiting for response...")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
msg = await asyncio.wait_for(self.ws.recv(), timeout=0.1)
|
||||
|
||||
if isinstance(msg, bytes):
|
||||
# Audio data
|
||||
self.bytes_received += len(msg)
|
||||
duration_ms = len(msg) / (self.sample_rate * 2) * 1000
|
||||
print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms)")
|
||||
|
||||
# Play immediately in executor to not block
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self.play_audio, msg)
|
||||
else:
|
||||
# JSON event
|
||||
event = json.loads(msg)
|
||||
etype = event.get("event", "?")
|
||||
print(f"<- {etype}")
|
||||
|
||||
if etype == "hangup":
|
||||
self.running = False
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except websockets.ConnectionClosed:
|
||||
print("Connection closed")
|
||||
self.running = False
|
||||
break
|
||||
|
||||
async def run(self, text: str = None):
|
||||
"""Run the client."""
|
||||
try:
|
||||
await self.connect()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Start receiver
|
||||
recv_task = asyncio.create_task(self.receive_loop())
|
||||
|
||||
if text:
|
||||
await self.send_chat(text)
|
||||
# Wait for response
|
||||
await asyncio.sleep(30)
|
||||
else:
|
||||
# Interactive mode
|
||||
print("\nType a message and press Enter (or 'quit' to exit):")
|
||||
while self.running:
|
||||
try:
|
||||
user_input = await asyncio.get_running_loop().run_in_executor(
|
||||
None, input, "> "
|
||||
)
|
||||
if user_input.lower() == 'quit':
|
||||
break
|
||||
if user_input.strip():
|
||||
await self.send_chat(user_input)
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
self.running = False
|
||||
recv_task.cancel()
|
||||
try:
|
||||
await recv_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
finally:
|
||||
await self.close()
|
||||
|
||||
async def close(self):
|
||||
"""Close connections."""
|
||||
self.running = False
|
||||
|
||||
if PYAUDIO_AVAILABLE:
|
||||
if self.stream:
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
self.pa.terminate()
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
print(f"\nTotal audio received: {self.bytes_received / 1024:.1f} KB")
|
||||
|
||||
|
||||
def list_audio_devices():
|
||||
"""List available audio devices."""
|
||||
print("\n=== Audio Devices ===")
|
||||
|
||||
if PYAUDIO_AVAILABLE:
|
||||
pa = pyaudio.PyAudio()
|
||||
print("\nPyAudio devices:")
|
||||
for i in range(pa.get_device_count()):
|
||||
info = pa.get_device_info_by_index(i)
|
||||
if info['maxOutputChannels'] > 0:
|
||||
default = " [DEFAULT]" if i == pa.get_default_output_device_info()['index'] else ""
|
||||
print(f" {i}: {info['name']}{default}")
|
||||
pa.terminate()
|
||||
|
||||
if SD_AVAILABLE:
|
||||
print("\nSounddevice devices:")
|
||||
for i, d in enumerate(sd.query_devices()):
|
||||
if d['max_output_channels'] > 0:
|
||||
default = " [DEFAULT]" if i == sd.default.device[1] else ""
|
||||
print(f" {i}: {d['name']}{default}")
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(description="Simple voice client")
|
||||
parser.add_argument("--url", default="ws://localhost:8000/ws")
|
||||
parser.add_argument("--text", help="Send text and play response")
|
||||
parser.add_argument("--list-devices", action="store_true")
|
||||
parser.add_argument("--sample-rate", type=int, default=16000)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_devices:
|
||||
list_audio_devices()
|
||||
return
|
||||
|
||||
client = SimpleVoiceClient(args.url, args.sample_rate)
|
||||
await client.run(args.text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
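For reference, a minimal sketch of driving the client programmatically instead of through the argparse entry point. The module name simple_voice_client is an assumption (this commit does not show the filename), and it expects a server listening on ws://localhost:8000/ws.

import asyncio
from simple_voice_client import SimpleVoiceClient  # module name is hypothetical

async def demo() -> None:
    client = SimpleVoiceClient("ws://localhost:8000/ws", sample_rate=16000)
    # Sends the invite, streams back the TTS reply for the chat text, then exits.
    await client.run("What's the weather like today?")

asyncio.run(demo())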
@@ -0,0 +1,37 @@
# Web Framework
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
websockets>=12.0
python-multipart>=0.0.6

# WebRTC (optional - for WebRTC transport)
aiortc>=1.6.0

# Audio Processing
av>=12.1.0
numpy>=1.26.3
onnxruntime>=1.16.3

# Configuration
pydantic>=2.5.3
pydantic-settings>=2.1.0
python-dotenv>=1.0.0
toml>=0.10.2

# Logging
loguru>=0.7.2

# HTTP Client
aiohttp>=3.9.1

# AI Services - LLM
openai>=1.0.0

# AI Services - TTS
edge-tts>=6.1.0
pydub>=0.25.0  # For audio format conversion

# Microphone client dependencies
sounddevice>=0.4.6
soundfile>=0.12.1
pyaudio>=0.2.13  # More reliable audio on Windows
42
services/__init__.py
Normal file
42
services/__init__.py
Normal file
@@ -0,0 +1,42 @@
"""AI Services package.

Provides ASR, LLM, TTS, and Realtime API services for voice conversation.
"""

from services.base import (
    ServiceState,
    ASRResult,
    LLMMessage,
    TTSChunk,
    BaseASRService,
    BaseLLMService,
    BaseTTSService,
)
from services.llm import OpenAILLMService, MockLLMService
from services.tts import EdgeTTSService, MockTTSService
from services.asr import BufferedASRService, MockASRService
from services.realtime import RealtimeService, RealtimeConfig, RealtimePipeline

__all__ = [
    # Base classes
    "ServiceState",
    "ASRResult",
    "LLMMessage",
    "TTSChunk",
    "BaseASRService",
    "BaseLLMService",
    "BaseTTSService",
    # LLM
    "OpenAILLMService",
    "MockLLMService",
    # TTS
    "EdgeTTSService",
    "MockTTSService",
    # ASR
    "BufferedASRService",
    "MockASRService",
    # Realtime
    "RealtimeService",
    "RealtimeConfig",
    "RealtimePipeline",
]
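Taken together, these exports are enough to assemble a single text-in, audio-out turn without the full duplex pipeline. A minimal sketch, assuming a valid OPENAI_API_KEY in the environment and a working edge-tts install:

import asyncio
from services import OpenAILLMService, EdgeTTSService, LLMMessage

async def one_turn(user_text: str) -> bytes:
    llm = OpenAILLMService(model="gpt-4o-mini")
    tts = EdgeTTSService(voice="en-US-JennyNeural", sample_rate=16000)
    await llm.connect()
    await tts.connect()
    try:
        reply = await llm.generate([LLMMessage(role="user", content=user_text)])
        return await tts.synthesize(reply)  # 16-bit mono PCM at 16 kHz
    finally:
        await llm.disconnect()
        await tts.disconnect()

pcm = asyncio.run(one_turn("Give me a one-sentence fun fact."))
print(f"Synthesized {len(pcm)} bytes of PCM")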
147
services/asr.py
Normal file
147
services/asr.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""ASR (Automatic Speech Recognition) Service implementations.
|
||||
|
||||
Provides speech-to-text capabilities with streaming support.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncIterator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseASRService, ASRResult, ServiceState
|
||||
|
||||
# Try to import websockets for streaming ASR
|
||||
try:
|
||||
import websockets
|
||||
WEBSOCKETS_AVAILABLE = True
|
||||
except ImportError:
|
||||
WEBSOCKETS_AVAILABLE = False
|
||||
|
||||
|
||||
class BufferedASRService(BaseASRService):
|
||||
"""
|
||||
Buffered ASR service that accumulates audio and provides
|
||||
a simple text accumulator for use with EOU detection.
|
||||
|
||||
This is a lightweight implementation that works with the
|
||||
existing VAD + EOU pattern without requiring external ASR.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int = 16000,
|
||||
language: str = "en"
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
|
||||
self._audio_buffer: bytes = b""
|
||||
self._current_text: str = ""
|
||||
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""No connection needed for buffered ASR."""
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info("Buffered ASR service connected")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Clear buffers on disconnect."""
|
||||
self._audio_buffer = b""
|
||||
self._current_text = ""
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Buffered ASR service disconnected")
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
"""Buffer audio for later processing."""
|
||||
self._audio_buffer += audio
|
||||
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
|
||||
"""Yield transcription results."""
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._transcript_queue.get(),
|
||||
timeout=0.1
|
||||
)
|
||||
yield result
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def set_text(self, text: str) -> None:
|
||||
"""
|
||||
Set the current transcript text directly.
|
||||
|
||||
This allows external integration (e.g., Whisper, other ASR)
|
||||
to provide transcripts.
|
||||
"""
|
||||
self._current_text = text
|
||||
result = ASRResult(text=text, is_final=False)
|
||||
asyncio.create_task(self._transcript_queue.put(result))
|
||||
|
||||
def get_and_clear_text(self) -> str:
|
||||
"""Get accumulated text and clear buffer."""
|
||||
text = self._current_text
|
||||
self._current_text = ""
|
||||
self._audio_buffer = b""
|
||||
return text
|
||||
|
||||
def get_audio_buffer(self) -> bytes:
|
||||
"""Get accumulated audio buffer."""
|
||||
return self._audio_buffer
|
||||
|
||||
def clear_audio_buffer(self) -> None:
|
||||
"""Clear audio buffer."""
|
||||
self._audio_buffer = b""
|
||||
|
||||
|
||||
class MockASRService(BaseASRService):
|
||||
"""
|
||||
Mock ASR service for testing without actual recognition.
|
||||
"""
|
||||
|
||||
def __init__(self, sample_rate: int = 16000, language: str = "en"):
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
|
||||
self._mock_texts = [
|
||||
"Hello, how are you?",
|
||||
"That's interesting.",
|
||||
"Tell me more about that.",
|
||||
"I understand.",
|
||||
]
|
||||
self._text_index = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info("Mock ASR service connected")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Mock ASR service disconnected")
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
"""Mock audio processing - generates fake transcripts periodically."""
|
||||
pass
|
||||
|
||||
def trigger_transcript(self) -> None:
|
||||
"""Manually trigger a transcript (for testing)."""
|
||||
text = self._mock_texts[self._text_index % len(self._mock_texts)]
|
||||
self._text_index += 1
|
||||
|
||||
result = ASRResult(text=text, is_final=True, confidence=0.95)
|
||||
asyncio.create_task(self._transcript_queue.put(result))
|
||||
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
|
||||
"""Yield transcription results."""
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._transcript_queue.get(),
|
||||
timeout=0.1
|
||||
)
|
||||
yield result
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
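Because BufferedASRService never produces transcripts on its own, an external recognizer has to push text in with set_text(), and the turn handler pulls it back out with get_and_clear_text() once the VAD reports end of utterance. A rough sketch of that wiring; the external transcribe call is hypothetical and stands in for whatever ASR backend is plugged in:

import asyncio
from services.asr import BufferedASRService

async def handle_turn(asr: BufferedASRService, pcm_frames: list[bytes]) -> str:
    # Feed audio as it arrives (e.g. 20 ms PCM frames from the transport).
    for frame in pcm_frames:
        await asr.send_audio(frame)

    # On end-of-utterance, hand the buffered audio to a real recognizer.
    audio = asr.get_audio_buffer()
    text = my_whisper_transcribe(audio)  # hypothetical external ASR call
    asr.set_text(text)

    # Collect the finished utterance and reset the buffers for the next turn.
    return asr.get_and_clear_text()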
244
services/base.py
Normal file
244
services/base.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Base classes for AI services.
|
||||
|
||||
Defines abstract interfaces for ASR, LLM, and TTS services,
|
||||
inspired by pipecat's service architecture and active-call's
|
||||
StreamEngine pattern.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ServiceState(Enum):
|
||||
"""Service connection state."""
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ASRResult:
|
||||
"""ASR transcription result."""
|
||||
text: str
|
||||
is_final: bool = False
|
||||
confidence: float = 1.0
|
||||
language: Optional[str] = None
|
||||
start_time: Optional[float] = None
|
||||
end_time: Optional[float] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
status = "FINAL" if self.is_final else "PARTIAL"
|
||||
return f"[{status}] {self.text}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessage:
|
||||
"""LLM conversation message."""
|
||||
role: str # "system", "user", "assistant", "function"
|
||||
content: str
|
||||
name: Optional[str] = None # For function calls
|
||||
function_call: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to API-compatible dict."""
|
||||
d = {"role": self.role, "content": self.content}
|
||||
if self.name:
|
||||
d["name"] = self.name
|
||||
if self.function_call:
|
||||
d["function_call"] = self.function_call
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSChunk:
|
||||
"""TTS audio chunk."""
|
||||
audio: bytes # PCM audio data
|
||||
sample_rate: int = 16000
|
||||
channels: int = 1
|
||||
bits_per_sample: int = 16
|
||||
is_final: bool = False
|
||||
text_offset: Optional[int] = None # Character offset in original text
|
||||
|
||||
|
||||
class BaseASRService(ABC):
|
||||
"""
|
||||
Abstract base class for ASR (Speech-to-Text) services.
|
||||
|
||||
Supports both streaming and non-streaming transcription.
|
||||
"""
|
||||
|
||||
def __init__(self, sample_rate: int = 16000, language: str = "en"):
|
||||
self.sample_rate = sample_rate
|
||||
self.language = language
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to ASR service."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Close connection to ASR service."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
"""
|
||||
Send audio chunk for transcription.
|
||||
|
||||
Args:
|
||||
audio: PCM audio data (16-bit, mono)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
|
||||
"""
|
||||
Receive transcription results.
|
||||
|
||||
Yields:
|
||||
ASRResult objects as they become available
|
||||
"""
|
||||
pass
|
||||
|
||||
async def transcribe(self, audio: bytes) -> ASRResult:
|
||||
"""
|
||||
Transcribe a complete audio buffer (non-streaming).
|
||||
|
||||
Args:
|
||||
audio: Complete PCM audio data
|
||||
|
||||
Returns:
|
||||
Final ASRResult
|
||||
"""
|
||||
# Default implementation using streaming
|
||||
await self.send_audio(audio)
|
||||
async for result in self.receive_transcripts():
|
||||
if result.is_final:
|
||||
return result
|
||||
return ASRResult(text="", is_final=True)
|
||||
|
||||
|
||||
class BaseLLMService(ABC):
|
||||
"""
|
||||
Abstract base class for LLM (Language Model) services.
|
||||
|
||||
Supports streaming responses for real-time conversation.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "gpt-4"):
|
||||
self.model = model
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Initialize LLM service connection."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Close LLM service connection."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a complete response.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Returns:
|
||||
Complete response text
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(
|
||||
self,
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""
|
||||
Generate response in streaming mode.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Yields:
|
||||
Text chunks as they are generated
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseTTSService(ABC):
|
||||
"""
|
||||
Abstract base class for TTS (Text-to-Speech) services.
|
||||
|
||||
Supports streaming audio synthesis for low-latency playback.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
voice: str = "default",
|
||||
sample_rate: int = 16000,
|
||||
speed: float = 1.0
|
||||
):
|
||||
self.voice = voice
|
||||
self.sample_rate = sample_rate
|
||||
self.speed = speed
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Initialize TTS service connection."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Close TTS service connection."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def synthesize(self, text: str) -> bytes:
|
||||
"""
|
||||
Synthesize complete audio for text (non-streaming).
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
|
||||
Returns:
|
||||
Complete PCM audio data
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
|
||||
"""
|
||||
Synthesize audio in streaming mode.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
|
||||
Yields:
|
||||
TTSChunk objects as audio is generated
|
||||
"""
|
||||
pass
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel ongoing synthesis (for barge-in support)."""
|
||||
pass
|
||||
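A small illustration of how the LLMMessage dataclass above maps onto the wire format expected by chat-completion style APIs:

from services.base import LLMMessage

history = [
    LLMMessage(role="system", content="You are a concise voice assistant."),
    LLMMessage(role="user", content="What's 2 + 2?"),
    LLMMessage(role="assistant", content="Four."),
]

payload = [m.to_dict() for m in history]
# -> [{"role": "system", "content": "..."}, {"role": "user", ...}, ...]
# Optional fields (name, function_call) are only included when they are set.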
239
services/llm.py
Normal file
239
services/llm.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""LLM (Large Language Model) Service implementations.
|
||||
|
||||
Provides OpenAI-compatible LLM integration with streaming support
|
||||
for real-time voice conversation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseLLMService, LLMMessage, ServiceState
|
||||
|
||||
# Try to import openai
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
OPENAI_AVAILABLE = False
|
||||
logger.warning("openai package not available - LLM service will be disabled")
|
||||
|
||||
|
||||
class OpenAILLMService(BaseLLMService):
|
||||
"""
|
||||
OpenAI-compatible LLM service.
|
||||
|
||||
Supports streaming responses for low-latency voice conversation.
|
||||
Works with OpenAI API, Azure OpenAI, and compatible APIs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o-mini",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize OpenAI LLM service.
|
||||
|
||||
Args:
|
||||
model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
|
||||
api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
|
||||
base_url: Custom API base URL (for Azure or compatible APIs)
|
||||
system_prompt: Default system prompt for conversations
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("OPENAI_API_URL")
|
||||
self.system_prompt = system_prompt or (
|
||||
"You are a helpful, friendly voice assistant. "
|
||||
"Keep your responses concise and conversational. "
|
||||
"Respond naturally as if having a phone conversation."
|
||||
)
|
||||
|
||||
self.client: Optional[AsyncOpenAI] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize OpenAI client."""
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise RuntimeError("openai package not installed")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key not provided")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(f"OpenAI LLM service connected: model={self.model}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Close OpenAI client."""
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
self.client = None
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("OpenAI LLM service disconnected")
|
||||
|
||||
def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
|
||||
"""Prepare messages list with system prompt."""
|
||||
result = []
|
||||
|
||||
# Add system prompt if not already present
|
||||
has_system = any(m.role == "system" for m in messages)
|
||||
if not has_system and self.system_prompt:
|
||||
result.append({"role": "system", "content": self.system_prompt})
|
||||
|
||||
# Add all messages
|
||||
for msg in messages:
|
||||
result.append(msg.to_dict())
|
||||
|
||||
return result
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a complete response.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Returns:
|
||||
Complete response text
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("LLM service not connected")
|
||||
|
||||
prepared = self._prepare_messages(messages)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=prepared,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content or ""
|
||||
logger.debug(f"LLM response: {content[:100]}...")
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM generation error: {e}")
|
||||
raise
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""
|
||||
Generate response in streaming mode.
|
||||
|
||||
Args:
|
||||
messages: Conversation history
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Yields:
|
||||
Text chunks as they are generated
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("LLM service not connected")
|
||||
|
||||
prepared = self._prepare_messages(messages)
|
||||
self._cancel_event.clear()
|
||||
|
||||
try:
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=prepared,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
# Check for cancellation
|
||||
if self._cancel_event.is_set():
|
||||
logger.info("LLM stream cancelled")
|
||||
break
|
||||
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
yield content
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("LLM stream cancelled via asyncio")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"LLM streaming error: {e}")
|
||||
raise
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel ongoing generation."""
|
||||
self._cancel_event.set()
|
||||
|
||||
|
||||
class MockLLMService(BaseLLMService):
|
||||
"""
|
||||
Mock LLM service for testing without API calls.
|
||||
"""
|
||||
|
||||
def __init__(self, response_delay: float = 0.5):
|
||||
super().__init__(model="mock")
|
||||
self.response_delay = response_delay
|
||||
self.responses = [
|
||||
"Hello! How can I help you today?",
|
||||
"That's an interesting question. Let me think about it.",
|
||||
"I understand. Is there anything else you'd like to know?",
|
||||
"Great! I'm here if you need anything else.",
|
||||
]
|
||||
self._response_index = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info("Mock LLM service connected")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Mock LLM service disconnected")
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> str:
|
||||
await asyncio.sleep(self.response_delay)
|
||||
response = self.responses[self._response_index % len(self.responses)]
|
||||
self._response_index += 1
|
||||
return response
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
response = await self.generate(messages, temperature, max_tokens)
|
||||
|
||||
# Stream word by word
|
||||
words = response.split()
|
||||
for i, word in enumerate(words):
|
||||
if i > 0:
|
||||
yield " "
|
||||
yield word
|
||||
await asyncio.sleep(0.05) # Simulate streaming delay
|
||||
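A minimal sketch of consuming the streaming interface above, assuming OPENAI_API_KEY is set in the environment. MockLLMService can be swapped in for offline tests since it exposes the same methods:

import asyncio
from services.llm import OpenAILLMService
from services.base import LLMMessage

async def stream_reply() -> None:
    llm = OpenAILLMService(model="gpt-4o-mini")
    await llm.connect()
    try:
        messages = [LLMMessage(role="user", content="Name three uses for a brick.")]
        async for chunk in llm.generate_stream(messages, temperature=0.7):
            print(chunk, end="", flush=True)  # in the real pipeline, forward chunks to TTS
        print()
    finally:
        await llm.disconnect()

asyncio.run(stream_reply())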
548
services/realtime.py
Normal file
548
services/realtime.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""OpenAI Realtime API Service.
|
||||
|
||||
Provides true duplex voice conversation using OpenAI's Realtime API,
|
||||
similar to active-call's RealtimeProcessor. This bypasses the need for
|
||||
separate ASR/LLM/TTS services by handling everything server-side.
|
||||
|
||||
The Realtime API provides:
|
||||
- Server-side VAD with turn detection
|
||||
- Streaming speech-to-text
|
||||
- Streaming LLM responses
|
||||
- Streaming text-to-speech
|
||||
- Function calling support
|
||||
- Barge-in/interruption handling
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
from typing import Optional, Dict, Any, Callable, Awaitable, List
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import websockets
|
||||
WEBSOCKETS_AVAILABLE = True
|
||||
except ImportError:
|
||||
WEBSOCKETS_AVAILABLE = False
|
||||
logger.warning("websockets not available - Realtime API will be disabled")
|
||||
|
||||
|
||||
class RealtimeState(Enum):
|
||||
"""Realtime API connection state."""
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealtimeConfig:
|
||||
"""Configuration for OpenAI Realtime API."""
|
||||
|
||||
# API Configuration
|
||||
api_key: Optional[str] = None
|
||||
model: str = "gpt-4o-realtime-preview"
|
||||
endpoint: Optional[str] = None # For Azure or custom endpoints
|
||||
|
||||
# Voice Configuration
|
||||
voice: str = "alloy" # alloy, echo, shimmer, etc.
|
||||
instructions: str = (
|
||||
"You are a helpful, friendly voice assistant. "
|
||||
"Keep your responses concise and conversational."
|
||||
)
|
||||
|
||||
# Turn Detection (Server-side VAD)
|
||||
turn_detection: Optional[Dict[str, Any]] = field(default_factory=lambda: {
|
||||
"type": "server_vad",
|
||||
"threshold": 0.5,
|
||||
"prefix_padding_ms": 300,
|
||||
"silence_duration_ms": 500
|
||||
})
|
||||
|
||||
# Audio Configuration
|
||||
input_audio_format: str = "pcm16"
|
||||
output_audio_format: str = "pcm16"
|
||||
|
||||
# Tools/Functions
|
||||
tools: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class RealtimeService:
|
||||
"""
|
||||
OpenAI Realtime API service for true duplex voice conversation.
|
||||
|
||||
This service handles the entire voice conversation pipeline:
|
||||
1. Audio input → Server-side VAD → Speech-to-text
|
||||
2. Text → LLM processing → Response generation
|
||||
3. Response → Text-to-speech → Audio output
|
||||
|
||||
Events emitted:
|
||||
- on_audio: Audio output from the assistant
|
||||
- on_transcript: Text transcript (user or assistant)
|
||||
- on_speech_started: User started speaking
|
||||
- on_speech_stopped: User stopped speaking
|
||||
- on_response_started: Assistant started responding
|
||||
- on_response_done: Assistant finished responding
|
||||
- on_function_call: Function call requested
|
||||
- on_error: Error occurred
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[RealtimeConfig] = None):
|
||||
"""
|
||||
Initialize Realtime API service.
|
||||
|
||||
Args:
|
||||
config: Realtime configuration (uses defaults if not provided)
|
||||
"""
|
||||
self.config = config or RealtimeConfig()
|
||||
self.config.api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
self.state = RealtimeState.DISCONNECTED
|
||||
self._ws = None
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
|
||||
# Event callbacks
|
||||
self._callbacks: Dict[str, List[Callable]] = {
|
||||
"on_audio": [],
|
||||
"on_transcript": [],
|
||||
"on_speech_started": [],
|
||||
"on_speech_stopped": [],
|
||||
"on_response_started": [],
|
||||
"on_response_done": [],
|
||||
"on_function_call": [],
|
||||
"on_error": [],
|
||||
"on_interrupted": [],
|
||||
}
|
||||
|
||||
logger.debug(f"RealtimeService initialized with model={self.config.model}")
|
||||
|
||||
def on(self, event: str, callback: Callable[..., Awaitable[None]]) -> None:
|
||||
"""
|
||||
Register event callback.
|
||||
|
||||
Args:
|
||||
event: Event name
|
||||
callback: Async callback function
|
||||
"""
|
||||
if event in self._callbacks:
|
||||
self._callbacks[event].append(callback)
|
||||
|
||||
async def _emit(self, event: str, *args, **kwargs) -> None:
|
||||
"""Emit event to all registered callbacks."""
|
||||
for callback in self._callbacks.get(event, []):
|
||||
try:
|
||||
await callback(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Event callback error ({event}): {e}")
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to OpenAI Realtime API."""
|
||||
if not WEBSOCKETS_AVAILABLE:
|
||||
raise RuntimeError("websockets package not installed")
|
||||
|
||||
if not self.config.api_key:
|
||||
raise ValueError("OpenAI API key not provided")
|
||||
|
||||
self.state = RealtimeState.CONNECTING
|
||||
|
||||
# Build URL
|
||||
if self.config.endpoint:
|
||||
# Azure or custom endpoint
|
||||
url = f"{self.config.endpoint}/openai/realtime?api-version=2024-10-01-preview&deployment={self.config.model}"
|
||||
else:
|
||||
# OpenAI endpoint
|
||||
url = f"wss://api.openai.com/v1/realtime?model={self.config.model}"
|
||||
|
||||
# Build headers
|
||||
headers = {}
|
||||
if self.config.endpoint:
|
||||
headers["api-key"] = self.config.api_key
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {self.config.api_key}"
|
||||
headers["OpenAI-Beta"] = "realtime=v1"
|
||||
|
||||
try:
|
||||
logger.info(f"Connecting to Realtime API: {url}")
|
||||
self._ws = await websockets.connect(url, extra_headers=headers)
|
||||
|
||||
# Send session configuration
|
||||
await self._configure_session()
|
||||
|
||||
# Start receive loop
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
self.state = RealtimeState.CONNECTED
|
||||
logger.info("Realtime API connected successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.state = RealtimeState.ERROR
|
||||
logger.error(f"Realtime API connection failed: {e}")
|
||||
raise
|
||||
|
||||
async def _configure_session(self) -> None:
|
||||
"""Send session configuration to server."""
|
||||
session_config = {
|
||||
"type": "session.update",
|
||||
"session": {
|
||||
"modalities": ["text", "audio"],
|
||||
"instructions": self.config.instructions,
|
||||
"voice": self.config.voice,
|
||||
"input_audio_format": self.config.input_audio_format,
|
||||
"output_audio_format": self.config.output_audio_format,
|
||||
"turn_detection": self.config.turn_detection,
|
||||
}
|
||||
}
|
||||
|
||||
if self.config.tools:
|
||||
session_config["session"]["tools"] = self.config.tools
|
||||
|
||||
await self._send(session_config)
|
||||
logger.debug("Session configuration sent")
|
||||
|
||||
async def _send(self, data: Dict[str, Any]) -> None:
|
||||
"""Send JSON data to server."""
|
||||
if self._ws:
|
||||
await self._ws.send(json.dumps(data))
|
||||
|
||||
async def send_audio(self, audio_bytes: bytes) -> None:
|
||||
"""
|
||||
Send audio to the Realtime API.
|
||||
|
||||
Args:
|
||||
audio_bytes: PCM audio data (16-bit, mono, 24kHz by default)
|
||||
"""
|
||||
if self.state != RealtimeState.CONNECTED:
|
||||
return
|
||||
|
||||
# Encode audio as base64
|
||||
audio_b64 = base64.standard_b64encode(audio_bytes).decode()
|
||||
|
||||
await self._send({
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": audio_b64
|
||||
})
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""
|
||||
Send text input (bypassing audio).
|
||||
|
||||
Args:
|
||||
text: User text input
|
||||
"""
|
||||
if self.state != RealtimeState.CONNECTED:
|
||||
return
|
||||
|
||||
# Create a conversation item with user text
|
||||
await self._send({
|
||||
"type": "conversation.item.create",
|
||||
"item": {
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": text}]
|
||||
}
|
||||
})
|
||||
|
||||
# Trigger response
|
||||
await self._send({"type": "response.create"})
|
||||
|
||||
async def cancel_response(self) -> None:
|
||||
"""Cancel the current response (for barge-in)."""
|
||||
if self.state != RealtimeState.CONNECTED:
|
||||
return
|
||||
|
||||
await self._send({"type": "response.cancel"})
|
||||
logger.debug("Response cancelled")
|
||||
|
||||
async def commit_audio(self) -> None:
|
||||
"""Commit the audio buffer and trigger response."""
|
||||
if self.state != RealtimeState.CONNECTED:
|
||||
return
|
||||
|
||||
await self._send({"type": "input_audio_buffer.commit"})
|
||||
await self._send({"type": "response.create"})
|
||||
|
||||
async def clear_audio_buffer(self) -> None:
|
||||
"""Clear the input audio buffer."""
|
||||
if self.state != RealtimeState.CONNECTED:
|
||||
return
|
||||
|
||||
await self._send({"type": "input_audio_buffer.clear"})
|
||||
|
||||
async def submit_function_result(self, call_id: str, result: str) -> None:
|
||||
"""
|
||||
Submit function call result.
|
||||
|
||||
Args:
|
||||
call_id: The function call ID
|
||||
result: JSON string result
|
||||
"""
|
||||
if self.state != RealtimeState.CONNECTED:
|
||||
return
|
||||
|
||||
await self._send({
|
||||
"type": "conversation.item.create",
|
||||
"item": {
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": result
|
||||
}
|
||||
})
|
||||
|
||||
# Trigger response with the function result
|
||||
await self._send({"type": "response.create"})
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Receive and process messages from the Realtime API."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
try:
|
||||
async for message in self._ws:
|
||||
try:
|
||||
data = json.loads(message)
|
||||
await self._handle_event(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON received: {message[:100]}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Receive loop cancelled")
|
||||
except websockets.ConnectionClosed as e:
|
||||
logger.info(f"WebSocket closed: {e}")
|
||||
self.state = RealtimeState.DISCONNECTED
|
||||
except Exception as e:
|
||||
logger.error(f"Receive loop error: {e}")
|
||||
self.state = RealtimeState.ERROR
|
||||
|
||||
async def _handle_event(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle incoming event from Realtime API."""
|
||||
event_type = data.get("type", "unknown")
|
||||
|
||||
# Audio delta - streaming audio output
|
||||
if event_type == "response.audio.delta":
|
||||
if "delta" in data:
|
||||
audio_bytes = base64.standard_b64decode(data["delta"])
|
||||
await self._emit("on_audio", audio_bytes)
|
||||
|
||||
# Audio transcript delta - streaming text
|
||||
elif event_type == "response.audio_transcript.delta":
|
||||
if "delta" in data:
|
||||
await self._emit("on_transcript", data["delta"], "assistant", False)
|
||||
|
||||
# Audio transcript done
|
||||
elif event_type == "response.audio_transcript.done":
|
||||
if "transcript" in data:
|
||||
await self._emit("on_transcript", data["transcript"], "assistant", True)
|
||||
|
||||
# Input audio transcript (user speech)
|
||||
elif event_type == "conversation.item.input_audio_transcription.completed":
|
||||
if "transcript" in data:
|
||||
await self._emit("on_transcript", data["transcript"], "user", True)
|
||||
|
||||
# Speech started (server VAD detected speech)
|
||||
elif event_type == "input_audio_buffer.speech_started":
|
||||
await self._emit("on_speech_started", data.get("audio_start_ms", 0))
|
||||
|
||||
# Speech stopped
|
||||
elif event_type == "input_audio_buffer.speech_stopped":
|
||||
await self._emit("on_speech_stopped", data.get("audio_end_ms", 0))
|
||||
|
||||
# Response started
|
||||
elif event_type == "response.created":
|
||||
await self._emit("on_response_started", data.get("response", {}))
|
||||
|
||||
# Response done
|
||||
elif event_type == "response.done":
|
||||
await self._emit("on_response_done", data.get("response", {}))
|
||||
|
||||
# Function call
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = data.get("call_id")
|
||||
name = data.get("name")
|
||||
arguments = data.get("arguments", "{}")
|
||||
await self._emit("on_function_call", call_id, name, arguments)
|
||||
|
||||
# Error
|
||||
elif event_type == "error":
|
||||
error = data.get("error", {})
|
||||
logger.error(f"Realtime API error: {error}")
|
||||
await self._emit("on_error", error)
|
||||
|
||||
# Session events
|
||||
elif event_type == "session.created":
|
||||
logger.info("Session created")
|
||||
elif event_type == "session.updated":
|
||||
logger.debug("Session updated")
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled event type: {event_type}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Realtime API."""
|
||||
self._cancel_event.set()
|
||||
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
self.state = RealtimeState.DISCONNECTED
|
||||
logger.info("Realtime API disconnected")
|
||||
|
||||
|
||||
class RealtimePipeline:
|
||||
"""
|
||||
Pipeline adapter for RealtimeService.
|
||||
|
||||
Provides a compatible interface with DuplexPipeline but uses
|
||||
OpenAI Realtime API for all processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport,
|
||||
session_id: str,
|
||||
config: Optional[RealtimeConfig] = None
|
||||
):
|
||||
"""
|
||||
Initialize Realtime pipeline.
|
||||
|
||||
Args:
|
||||
transport: Transport for sending audio/events
|
||||
session_id: Session identifier
|
||||
config: Realtime configuration
|
||||
"""
|
||||
self.transport = transport
|
||||
self.session_id = session_id
|
||||
|
||||
self.service = RealtimeService(config)
|
||||
|
||||
# Register callbacks
|
||||
self.service.on("on_audio", self._on_audio)
|
||||
self.service.on("on_transcript", self._on_transcript)
|
||||
self.service.on("on_speech_started", self._on_speech_started)
|
||||
self.service.on("on_speech_stopped", self._on_speech_stopped)
|
||||
self.service.on("on_response_started", self._on_response_started)
|
||||
self.service.on("on_response_done", self._on_response_done)
|
||||
self.service.on("on_error", self._on_error)
|
||||
|
||||
self._is_speaking = False
|
||||
self._running = True
|
||||
|
||||
logger.info(f"RealtimePipeline initialized for session {session_id}")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline."""
|
||||
await self.service.connect()
|
||||
|
||||
async def process_audio(self, pcm_bytes: bytes) -> None:
|
||||
"""
|
||||
Process incoming audio.
|
||||
|
||||
Note: Realtime API expects 24kHz audio by default.
|
||||
You may need to resample from 16kHz.
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
# TODO: Resample from 16kHz to 24kHz if needed
|
||||
await self.service.send_audio(pcm_bytes)
|
||||
|
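# Note (illustrative sketch, not part of this commit): one way to satisfy the TODO
# above is a linear-interpolation resample of 16 kHz PCM16 up to the 24 kHz the
# Realtime API expects by default. A production path might prefer a proper
# polyphase resampler (e.g. scipy.signal.resample_poly).
#
#   import numpy as np
#
#   def resample_pcm16(pcm: bytes, src_hz: int = 16000, dst_hz: int = 24000) -> bytes:
#       samples = np.frombuffer(pcm, dtype=np.int16).astype(np.float32)
#       dst_len = int(len(samples) * dst_hz / src_hz)
#       positions = np.linspace(0, len(samples) - 1, num=dst_len)
#       resampled = np.interp(positions, np.arange(len(samples)), samples)
#       return np.clip(resampled, -32768, 32767).astype(np.int16).tobytes()
#
#   # ...then in process_audio:
#   # await self.service.send_audio(resample_pcm16(pcm_bytes))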
||||
async def process_text(self, text: str) -> None:
|
||||
"""Process text input."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
await self.service.send_text(text)
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Interrupt current response."""
|
||||
await self.service.cancel_response()
|
||||
await self.transport.send_event({
|
||||
"event": "interrupt",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup resources."""
|
||||
self._running = False
|
||||
await self.service.disconnect()
|
||||
|
||||
# Event handlers
|
||||
|
||||
async def _on_audio(self, audio_bytes: bytes) -> None:
|
||||
"""Handle audio output."""
|
||||
await self.transport.send_audio(audio_bytes)
|
||||
|
||||
async def _on_transcript(self, text: str, role: str, is_final: bool) -> None:
|
||||
"""Handle transcript."""
|
||||
logger.info(f"[{role.upper()}] {text[:50]}..." if len(text) > 50 else f"[{role.upper()}] {text}")
|
||||
|
||||
async def _on_speech_started(self, start_ms: int) -> None:
|
||||
"""Handle user speech start."""
|
||||
self._is_speaking = True
|
||||
await self.transport.send_event({
|
||||
"event": "speaking",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"startTime": start_ms
|
||||
})
|
||||
|
||||
# Cancel any ongoing response (barge-in)
|
||||
await self.service.cancel_response()
|
||||
|
||||
async def _on_speech_stopped(self, end_ms: int) -> None:
|
||||
"""Handle user speech stop."""
|
||||
self._is_speaking = False
|
||||
await self.transport.send_event({
|
||||
"event": "silence",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"duration": end_ms
|
||||
})
|
||||
|
||||
async def _on_response_started(self, response: Dict) -> None:
|
||||
"""Handle response start."""
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
async def _on_response_done(self, response: Dict) -> None:
|
||||
"""Handle response complete."""
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
async def _on_error(self, error: Dict) -> None:
|
||||
"""Handle error."""
|
||||
await self.transport.send_event({
|
||||
"event": "error",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"sender": "realtime",
|
||||
"error": str(error)
|
||||
})
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
|
||||
@property
|
||||
def is_speaking(self) -> bool:
|
||||
"""Check if user is speaking."""
|
||||
return self._is_speaking
|
||||
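For completeness, a rough sketch of using RealtimeService directly (outside RealtimePipeline), assuming OPENAI_API_KEY is set. The callbacks mirror the events listed in the class docstring:

import asyncio
from services.realtime import RealtimeService, RealtimeConfig

async def quick_demo() -> None:
    service = RealtimeService(RealtimeConfig(voice="alloy"))

    async def on_audio(pcm: bytes) -> None:
        print(f"got {len(pcm)} bytes of assistant audio")  # feed to playback/transport

    async def on_transcript(text: str, role: str, is_final: bool) -> None:
        if is_final:
            print(f"[{role}] {text}")

    service.on("on_audio", on_audio)
    service.on("on_transcript", on_transcript)

    await service.connect()
    try:
        await service.send_text("Say hello in five words or fewer.")
        await asyncio.sleep(10)  # let the streamed response arrive
    finally:
        await service.disconnect()

asyncio.run(quick_demo())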
255
services/siliconflow_tts.py
Normal file
255
services/siliconflow_tts.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""SiliconFlow TTS Service with streaming support.
|
||||
|
||||
Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency
|
||||
text-to-speech synthesis with streaming.
|
||||
|
||||
API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from typing import AsyncIterator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseTTSService, TTSChunk, ServiceState
|
||||
|
||||
|
||||
class SiliconFlowTTSService(BaseTTSService):
|
||||
"""
|
||||
SiliconFlow TTS service with streaming support.
|
||||
|
||||
Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models.
|
||||
"""
|
||||
|
||||
# Available voices
|
||||
VOICES = {
|
||||
"alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
|
||||
"anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
|
||||
"bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
|
||||
"benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
|
||||
"charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
|
||||
"claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
|
||||
"david": "FunAudioLLM/CosyVoice2-0.5B:david",
|
||||
"diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
voice: str = "anna",
|
||||
model: str = "FunAudioLLM/CosyVoice2-0.5B",
|
||||
sample_rate: int = 16000,
|
||||
speed: float = 1.0
|
||||
):
|
||||
"""
|
||||
Initialize SiliconFlow TTS service.
|
||||
|
||||
Args:
|
||||
api_key: SiliconFlow API key (defaults to SILICONFLOW_API_KEY env var)
|
||||
voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
|
||||
model: Model name
|
||||
sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
|
||||
speed: Speech speed (0.25 to 4.0)
|
||||
"""
|
||||
# Resolve voice name
|
||||
if voice in self.VOICES:
|
||||
full_voice = self.VOICES[voice]
|
||||
else:
|
||||
full_voice = voice
|
||||
|
||||
super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)
|
||||
|
||||
self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
|
||||
self.model = model
|
||||
self.api_url = "https://api.siliconflow.cn/v1/audio/speech"
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize HTTP session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Close HTTP session."""
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("SiliconFlow TTS service disconnected")
|
||||
|
||||
async def synthesize(self, text: str) -> bytes:
|
||||
"""Synthesize complete audio for text."""
|
||||
audio_data = b""
|
||||
async for chunk in self.synthesize_stream(text):
|
||||
audio_data += chunk.audio
|
||||
return audio_data
|
||||
|
||||
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
|
||||
"""
|
||||
Synthesize audio in streaming mode.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
|
||||
Yields:
|
||||
TTSChunk objects with PCM audio
|
||||
"""
|
||||
if not self._session:
|
||||
raise RuntimeError("TTS service not connected")
|
||||
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
self._cancel_event.clear()
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": text,
|
||||
"voice": self.voice,
|
||||
"response_format": "pcm",
|
||||
"sample_rate": self.sample_rate,
|
||||
"stream": True,
|
||||
"speed": self.speed
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(self.api_url, json=payload) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}")
|
||||
return
|
||||
|
||||
# Stream audio chunks
|
||||
chunk_size = self.sample_rate * 2 // 10 # 100ms chunks
|
||||
buffer = b""
|
||||
|
||||
async for chunk in response.content.iter_any():
|
||||
if self._cancel_event.is_set():
|
||||
logger.info("TTS synthesis cancelled")
|
||||
return
|
||||
|
||||
buffer += chunk
|
||||
|
||||
# Yield complete chunks
|
||||
while len(buffer) >= chunk_size:
|
||||
audio_chunk = buffer[:chunk_size]
|
||||
buffer = buffer[chunk_size:]
|
||||
|
||||
yield TTSChunk(
|
||||
audio=audio_chunk,
|
||||
sample_rate=self.sample_rate,
|
||||
is_final=False
|
||||
)
|
||||
|
||||
# Yield remaining buffer
|
||||
if buffer:
|
||||
yield TTSChunk(
|
||||
audio=buffer,
|
||||
sample_rate=self.sample_rate,
|
||||
is_final=True
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("TTS synthesis cancelled via asyncio")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"TTS synthesis error: {e}")
|
||||
raise
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel ongoing synthesis."""
|
||||
self._cancel_event.set()
|
||||
|
||||
|
||||
class StreamingTTSAdapter:
|
||||
"""
|
||||
Adapter for streaming LLM text to TTS with sentence-level chunking.
|
||||
|
||||
This reduces latency by starting TTS as soon as a complete sentence
|
||||
is received from the LLM, rather than waiting for the full response.
|
||||
"""
|
||||
|
||||
# Sentence delimiters
|
||||
SENTENCE_ENDS = {'.', '!', '?', '。', '!', '?', ';', '\n'}
|
||||
|
||||
def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
|
||||
self.tts_service = tts_service
|
||||
self.transport = transport
|
||||
self.session_id = session_id
|
||||
self._buffer = ""
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._is_speaking = False
|
||||
|
||||
async def process_text_chunk(self, text_chunk: str) -> None:
|
||||
"""
|
||||
Process a text chunk from LLM and trigger TTS when sentence is complete.
|
||||
|
||||
Args:
|
||||
text_chunk: Text chunk from LLM streaming
|
||||
"""
|
||||
if self._cancel_event.is_set():
|
||||
return
|
||||
|
||||
self._buffer += text_chunk
|
||||
|
||||
# Check for sentence completion
|
||||
for i, char in enumerate(self._buffer):
|
||||
if char in self.SENTENCE_ENDS:
|
||||
# Found sentence end, synthesize up to this point
|
||||
sentence = self._buffer[:i+1].strip()
|
||||
self._buffer = self._buffer[i+1:]
|
||||
|
||||
if sentence:
|
||||
await self._speak_sentence(sentence)
|
||||
break
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Flush remaining buffer."""
|
||||
if self._buffer.strip() and not self._cancel_event.is_set():
|
||||
await self._speak_sentence(self._buffer.strip())
|
||||
self._buffer = ""
|
||||
|
||||
async def _speak_sentence(self, text: str) -> None:
|
||||
"""Synthesize and send a sentence."""
|
||||
if not text or self._cancel_event.is_set():
|
||||
return
|
||||
|
||||
self._is_speaking = True
|
||||
|
||||
try:
|
||||
async for chunk in self.tts_service.synthesize_stream(text):
|
||||
if self._cancel_event.is_set():
|
||||
break
|
||||
await self.transport.send_audio(chunk.audio)
|
||||
await asyncio.sleep(0.01) # Prevent flooding
|
||||
except Exception as e:
|
||||
logger.error(f"TTS speak error: {e}")
|
||||
finally:
|
||||
self._is_speaking = False
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel ongoing speech."""
|
||||
self._cancel_event.set()
|
||||
self._buffer = ""
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset for new turn."""
|
||||
self._cancel_event.clear()
|
||||
self._buffer = ""
|
||||
self._is_speaking = False
|
||||
|
||||
@property
|
||||
def is_speaking(self) -> bool:
|
||||
return self._is_speaking
|
||||
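A sketch of how the adapter is meant to sit between a streaming LLM and the transport. FakeTransport is a stand-in for whichever transport the session actually uses, and the SiliconFlow service needs SILICONFLOW_API_KEY in the environment:

import asyncio
from services.llm import MockLLMService
from services.siliconflow_tts import SiliconFlowTTSService, StreamingTTSAdapter
from services.base import LLMMessage

class FakeTransport:
    async def send_audio(self, pcm: bytes) -> None:
        print(f"-> {len(pcm)} bytes of PCM")  # stand-in for the real transport

async def speak_streamed_reply() -> None:
    tts = SiliconFlowTTSService(voice="anna")
    await tts.connect()
    adapter = StreamingTTSAdapter(tts, FakeTransport(), session_id="demo")
    llm = MockLLMService()
    await llm.connect()
    try:
        messages = [LLMMessage(role="user", content="Hi there!")]
        async for chunk in llm.generate_stream(messages):
            await adapter.process_text_chunk(chunk)  # speaks each finished sentence
        await adapter.flush()  # speak whatever text is left in the buffer
    finally:
        await llm.disconnect()
        await tts.disconnect()

asyncio.run(speak_streamed_reply())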
271
services/tts.py
Normal file
271
services/tts.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""TTS (Text-to-Speech) Service implementations.
|
||||
|
||||
Provides multiple TTS backend options including edge-tts (free)
|
||||
and placeholder for cloud services.
|
||||
"""
|
||||
|
||||
import os
|
||||
import io
|
||||
import asyncio
|
||||
import struct
|
||||
from typing import AsyncIterator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseTTSService, TTSChunk, ServiceState
|
||||
|
||||
# Try to import edge-tts
|
||||
try:
|
||||
import edge_tts
|
||||
EDGE_TTS_AVAILABLE = True
|
||||
except ImportError:
|
||||
EDGE_TTS_AVAILABLE = False
|
||||
logger.warning("edge-tts not available - EdgeTTS service will be disabled")
|
||||
|
||||
|
||||
class EdgeTTSService(BaseTTSService):
|
||||
"""
|
||||
Microsoft Edge TTS service.
|
||||
|
||||
Uses edge-tts library for free, high-quality speech synthesis.
|
||||
Supports streaming for low-latency playback.
|
||||
"""
|
||||
|
||||
# Voice mapping for common languages
|
||||
VOICE_MAP = {
|
||||
"en": "en-US-JennyNeural",
|
||||
"en-US": "en-US-JennyNeural",
|
||||
"en-GB": "en-GB-SoniaNeural",
|
||||
"zh": "zh-CN-XiaoxiaoNeural",
|
||||
"zh-CN": "zh-CN-XiaoxiaoNeural",
|
||||
"zh-TW": "zh-TW-HsiaoChenNeural",
|
||||
"ja": "ja-JP-NanamiNeural",
|
||||
"ko": "ko-KR-SunHiNeural",
|
||||
"fr": "fr-FR-DeniseNeural",
|
||||
"de": "de-DE-KatjaNeural",
|
||||
"es": "es-ES-ElviraNeural",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
voice: str = "en-US-JennyNeural",
|
||||
sample_rate: int = 16000,
|
||||
speed: float = 1.0
|
||||
):
|
||||
"""
|
||||
Initialize Edge TTS service.
|
||||
|
||||
Args:
|
||||
voice: Voice name (e.g., "en-US-JennyNeural") or language code (e.g., "en")
|
||||
sample_rate: Target sample rate (will be resampled)
|
||||
speed: Speech speed multiplier
|
||||
"""
|
||||
# Resolve voice from language code if needed
|
||||
if voice in self.VOICE_MAP:
|
||||
voice = self.VOICE_MAP[voice]
|
||||
|
||||
super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
|
||||
self._cancel_event = asyncio.Event()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Edge TTS doesn't require explicit connection."""
|
||||
if not EDGE_TTS_AVAILABLE:
|
||||
raise RuntimeError("edge-tts package not installed")
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info(f"Edge TTS service ready: voice={self.voice}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Edge TTS doesn't require explicit disconnection."""
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Edge TTS service disconnected")
|
||||
|
||||
def _get_rate_string(self) -> str:
|
||||
"""Convert speed to rate string for edge-tts."""
|
||||
# edge-tts uses percentage format: "+0%", "-10%", "+20%"
|
||||
percentage = int((self.speed - 1.0) * 100)
|
||||
if percentage >= 0:
|
||||
return f"+{percentage}%"
|
||||
return f"{percentage}%"
|
||||
|
||||
async def synthesize(self, text: str) -> bytes:
|
||||
"""
|
||||
Synthesize complete audio for text.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
|
||||
Returns:
|
||||
PCM audio data (16-bit, mono, 16kHz)
|
||||
"""
|
||||
if not EDGE_TTS_AVAILABLE:
|
||||
raise RuntimeError("edge-tts not available")
|
||||
|
||||
# Collect all chunks
|
||||
audio_data = b""
|
||||
async for chunk in self.synthesize_stream(text):
|
||||
audio_data += chunk.audio
|
||||
|
||||
return audio_data
|
||||
|
||||
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
|
||||
"""
|
||||
Synthesize audio in streaming mode.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
|
||||
Yields:
|
||||
TTSChunk objects with PCM audio
|
||||
"""
|
||||
if not EDGE_TTS_AVAILABLE:
|
||||
raise RuntimeError("edge-tts not available")
|
||||
|
||||
self._cancel_event.clear()
|
||||
|
||||
try:
|
||||
communicate = edge_tts.Communicate(
|
||||
text,
|
||||
voice=self.voice,
|
||||
rate=self._get_rate_string()
|
||||
)
|
||||
|
||||
# edge-tts outputs MP3, we need to decode to PCM
|
||||
# For now, collect MP3 chunks and yield after conversion
|
||||
mp3_data = b""
|
||||
|
||||
async for chunk in communicate.stream():
|
||||
# Check for cancellation
|
||||
if self._cancel_event.is_set():
|
||||
logger.info("TTS synthesis cancelled")
|
||||
return
|
||||
|
||||
if chunk["type"] == "audio":
|
||||
mp3_data += chunk["data"]
|
||||
|
||||
# Convert MP3 to PCM
|
||||
if mp3_data:
|
||||
pcm_data = await self._convert_mp3_to_pcm(mp3_data)
|
||||
if pcm_data:
|
||||
# Yield in chunks for streaming playback
|
||||
chunk_size = self.sample_rate * 2 // 10 # 100ms chunks
|
||||
for i in range(0, len(pcm_data), chunk_size):
|
||||
if self._cancel_event.is_set():
|
||||
return
|
||||
|
||||
chunk_data = pcm_data[i:i + chunk_size]
|
||||
yield TTSChunk(
|
||||
audio=chunk_data,
|
||||
sample_rate=self.sample_rate,
|
||||
is_final=(i + chunk_size >= len(pcm_data))
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("TTS synthesis cancelled via asyncio")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"TTS synthesis error: {e}")
|
||||
raise
|
||||
|
||||
async def _convert_mp3_to_pcm(self, mp3_data: bytes) -> bytes:
|
||||
"""
|
||||
Convert MP3 audio to PCM.
|
||||
|
||||
Uses pydub or ffmpeg for conversion.
|
||||
"""
|
||||
try:
|
||||
# Try using pydub (requires ffmpeg)
|
||||
from pydub import AudioSegment
|
||||
|
||||
# Load MP3 from bytes
|
||||
audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))
|
||||
|
||||
# Convert to target format
|
||||
audio = audio.set_frame_rate(self.sample_rate)
|
||||
audio = audio.set_channels(1)
|
||||
audio = audio.set_sample_width(2) # 16-bit
|
||||
|
||||
# Export as raw PCM
|
||||
return audio.raw_data
|
||||
|
||||
except ImportError:
|
||||
logger.warning("pydub not available, trying fallback")
|
||||
# Fallback: Use subprocess to call ffmpeg directly
|
||||
return await self._ffmpeg_convert(mp3_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Audio conversion error: {e}")
|
||||
return b""
|
||||
|
||||
async def _ffmpeg_convert(self, mp3_data: bytes) -> bytes:
|
||||
"""Convert MP3 to PCM using ffmpeg subprocess."""
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-i", "pipe:0",
|
||||
"-f", "s16le",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(self.sample_rate),
|
||||
"-ac", "1",
|
||||
"pipe:1",
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.DEVNULL
|
||||
)
|
||||
|
||||
stdout, _ = await process.communicate(input=mp3_data)
|
||||
return stdout
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ffmpeg conversion error: {e}")
|
||||
return b""
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel ongoing synthesis."""
|
||||
self._cancel_event.set()
|
||||
|
||||
|
||||
class MockTTSService(BaseTTSService):
|
||||
"""
|
||||
Mock TTS service for testing without actual synthesis.
|
||||
|
||||
Generates silence or simple tones.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
voice: str = "mock",
|
||||
sample_rate: int = 16000,
|
||||
speed: float = 1.0
|
||||
):
|
||||
super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info("Mock TTS service connected")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("Mock TTS service disconnected")
|
||||
|
||||
async def synthesize(self, text: str) -> bytes:
|
||||
"""Generate silence based on text length."""
|
||||
# Approximate: 100ms per word
|
||||
word_count = len(text.split())
|
||||
duration_ms = word_count * 100
|
||||
samples = int(self.sample_rate * duration_ms / 1000)
|
||||
|
||||
# Generate silence (zeros)
|
||||
return bytes(samples * 2) # 16-bit = 2 bytes per sample
|
||||
|
||||
async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
|
||||
"""Generate silence chunks."""
|
||||
audio = await self.synthesize(text)
|
||||
|
||||
# Yield in 100ms chunks
|
||||
chunk_size = self.sample_rate * 2 // 10
|
||||
for i in range(0, len(audio), chunk_size):
|
||||
chunk_data = audio[i:i + chunk_size]
|
||||
yield TTSChunk(
|
||||
audio=chunk_data,
|
||||
sample_rate=self.sample_rate,
|
||||
is_final=(i + chunk_size >= len(audio))
|
||||
)
|
||||
await asyncio.sleep(0.05) # Simulate processing time
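Finally, a small sketch of exercising EdgeTTSService on its own and writing the result to a WAV file for inspection; it assumes edge-tts is installed along with pydub or ffmpeg for the MP3-to-PCM step:

import asyncio
import wave
from services.tts import EdgeTTSService

async def synth_to_wav(text: str, path: str = "out.wav") -> None:
    tts = EdgeTTSService(voice="en-US-JennyNeural", sample_rate=16000)
    await tts.connect()
    try:
        pcm = await tts.synthesize(text)  # 16-bit mono PCM at 16 kHz
    finally:
        await tts.disconnect()

    with wave.open(path, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(16000)
        wav.writeframes(pcm)

asyncio.run(synth_to_wav("Hello from the duplex voice server."))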