Compare commits
5 Commits
4cb267a288
...
add-readme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71b7e32563 | ||
|
|
aa4316de6f | ||
|
|
d6d0ade33e | ||
|
|
ac0c76e6e8 | ||
|
|
cd90b4fb37 |
30
.env.example
Normal file
30
.env.example
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Server Configuration
|
||||||
|
HOST=0.0.0.0
|
||||||
|
PORT=8000
|
||||||
|
|
||||||
|
# Audio Configuration
|
||||||
|
SAMPLE_RATE=16000
|
||||||
|
CHUNK_SIZE_MS=20
|
||||||
|
|
||||||
|
# VAD Configuration
|
||||||
|
VAD_THRESHOLD=0.5
|
||||||
|
VAD_EOU_THRESHOLD_MS=400
|
||||||
|
|
||||||
|
# OpenAI / LLM Configuration (required for duplex voice)
|
||||||
|
OPENAI_API_KEY=sk-your-openai-api-key-here
|
||||||
|
# OPENAI_API_URL=https://api.openai.com/v1 # Optional: for Azure or compatible APIs
|
||||||
|
LLM_MODEL=gpt-4o-mini
|
||||||
|
LLM_TEMPERATURE=0.7
|
||||||
|
|
||||||
|
# TTS Configuration
|
||||||
|
TTS_VOICE=en-US-JennyNeural
|
||||||
|
TTS_SPEED=1.0
|
||||||
|
|
||||||
|
# Duplex Pipeline Configuration
|
||||||
|
DUPLEX_ENABLED=true
|
||||||
|
# DUPLEX_GREETING=Hello! How can I help you today?
|
||||||
|
DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
LOG_FORMAT=text
|
||||||
7
README.md
Normal file
7
README.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# py-active-call-cc
|
||||||
|
|
||||||
|
Python Active-Call: real-time audio streaming with WebSocket and WebRTC.
|
||||||
|
|
||||||
|
This repo contains a Python 3.11+ codebase for building low-latency voice
|
||||||
|
pipelines (capture, stream, and process audio) using WebRTC and WebSockets.
|
||||||
|
It is currently in an early, experimental stage.
|
||||||
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Active-Call Application Package"""
|
||||||
116
app/config.py
Normal file
116
app/config.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""Configuration management using Pydantic settings."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Application settings, populated from the environment / a `.env` file.

    Field names map (case-insensitively) to environment variables; unknown
    variables are ignored rather than rejected.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore"
    )

    # --- Server ---
    host: str = Field(default="0.0.0.0", description="Server host address")
    port: int = Field(default=8000, description="Server port")
    external_ip: Optional[str] = Field(default=None, description="External IP for NAT traversal")

    # --- Audio ---
    sample_rate: int = Field(default=16000, description="Audio sample rate in Hz")
    chunk_size_ms: int = Field(default=20, description="Audio chunk duration in milliseconds")
    default_codec: str = Field(default="pcm", description="Default audio codec")

    # --- VAD ---
    vad_type: str = Field(default="silero", description="VAD algorithm type")
    vad_model_path: str = Field(default="data/vad/silero_vad.onnx", description="Path to VAD model")
    vad_threshold: float = Field(default=0.5, description="VAD detection threshold")
    vad_min_speech_duration_ms: int = Field(default=250, description="Minimum speech duration in milliseconds")
    vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")

    # --- OpenAI / LLM ---
    openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
    openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)")
    llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
    llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")

    # --- TTS ---
    tts_provider: str = Field(default="siliconflow", description="TTS provider (edge, siliconflow)")
    tts_voice: str = Field(default="anna", description="TTS voice name")
    tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")

    # --- SiliconFlow ---
    siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
    siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")

    # --- ASR ---
    asr_provider: str = Field(default="siliconflow", description="ASR provider (siliconflow, buffered)")
    siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model")
    asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
    asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")

    # --- Duplex pipeline ---
    duplex_enabled: bool = Field(default=True, description="Enable duplex voice pipeline")
    duplex_greeting: Optional[str] = Field(default=None, description="Optional greeting message")
    duplex_system_prompt: Optional[str] = Field(
        default="You are a helpful, friendly voice assistant. Keep your responses concise and conversational.",
        description="System prompt for LLM"
    )

    # --- Barge-in (interruption) ---
    barge_in_min_duration_ms: int = Field(
        default=50,
        description="Minimum speech duration (ms) required to trigger barge-in. 50-100ms recommended."
    )

    # --- Logging ---
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(default="json", description="Log format (json or text)")

    # --- CORS ---
    # Stored as a JSON string so it can come straight from an env var.
    cors_origins: str = Field(
        default='["http://localhost:3000", "http://localhost:8080"]',
        description="CORS allowed origins"
    )

    # --- ICE servers (WebRTC) ---
    # Also a JSON string; parsed on demand by ice_servers_list.
    ice_servers: str = Field(
        default='[{"urls": "stun:stun.l.google.com:19302"}]',
        description="ICE servers configuration"
    )

    @staticmethod
    def _parse_json(raw: str, fallback):
        """Decode *raw* as JSON, returning *fallback* when it is malformed."""
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return fallback

    @property
    def chunk_size_bytes(self) -> int:
        """Size in bytes of one audio chunk (16-bit samples, mono channel)."""
        return int(self.sample_rate * 2 * (self.chunk_size_ms / 1000.0))

    @property
    def cors_origins_list(self) -> List[str]:
        """CORS origins decoded from the JSON-string setting."""
        return self._parse_json(
            self.cors_origins,
            ["http://localhost:3000", "http://localhost:8080"]
        )

    @property
    def ice_servers_list(self) -> List[dict]:
        """ICE server entries decoded from the JSON-string setting."""
        return self._parse_json(
            self.ice_servers,
            [{"urls": "stun:stun.l.google.com:19302"}]
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Module-wide singleton: every importer shares one configuration object.
settings = Settings()


def get_settings() -> Settings:
    """Return the shared :class:`Settings` instance."""
    return settings
|
||||||
290
app/main.py
Normal file
290
app/main.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
"""FastAPI application with WebSocket and WebRTC endpoints."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
# Try to import aiortc (optional for WebRTC functionality)
|
||||||
|
try:
|
||||||
|
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||||
|
AIORTC_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
AIORTC_AVAILABLE = False
|
||||||
|
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from core.transports import SocketTransport, WebRtcTransport
|
||||||
|
from core.session import Session
|
||||||
|
from processors.tracks import Resampled16kTrack
|
||||||
|
from core.events import get_event_bus, reset_event_bus
|
||||||
|
|
||||||
|
# FastAPI application object
app = FastAPI(title="Python Active-Call", version="0.1.0")

# Let browser clients from the configured origins reach the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Registry of live sessions, keyed by session id.
active_sessions: Dict[str, Session] = {}

# Replace loguru's default sink with a rotating file sink plus a console sink.
logger.remove()
logger.add(
    # NOTE(review): this writes under "../logs/" while __main__ creates
    # "logs/" — confirm which directory is intended.
    "../logs/active_call_{time}.log",
    rotation="1 day",
    retention="7 days",
    level=settings.log_level,
    format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
)
logger.add(
    lambda msg: print(msg, end=""),
    level=settings.log_level,
    format="{time:HH:mm:ss} | {level: <8} | {message}"
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
async def health_check():
    """Liveness probe: reports status plus the current session count."""
    payload = {"status": "healthy", "sessions": len(active_sessions)}
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/iceservers")
async def get_ice_servers():
    """Expose the configured ICE servers for WebRTC clients."""
    return settings.ice_servers_list
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/call/lists")
async def list_calls():
    """List all active calls."""
    calls = []
    for call_id, call_session in active_sessions.items():
        calls.append({
            "id": call_id,
            "state": call_session.state,
            "created_at": call_session.created_at,
        })
    return {"calls": calls}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
    """Tear down one active call by id; 404 when the session is unknown."""
    session = active_sessions.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # Release pipeline resources, then drop the registry entry.
    await session.cleanup()
    del active_sessions[session_id]

    return True
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for raw audio streaming.

    Accepts mixed text/binary frames:
    - Text frames: JSON commands
    - Binary frames: PCM audio data (16kHz, 16-bit, mono)
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create transport and session, and register it for /call endpoints.
    transport = SocketTransport(websocket)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebSocket connection established: {session_id}")

    try:
        # Receive loop: dispatch each frame by payload type.
        while True:
            message = await websocket.receive()

            if "bytes" in message:
                # Binary frame: raw PCM audio.
                await session.handle_audio(message["bytes"])
            elif "text" in message:
                # Text frame: JSON command.
                await session.handle_text(message["text"])

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")

    except Exception as e:
        # FIX: loguru does not honor stdlib logging's `exc_info` kwarg, so the
        # original `logger.error(..., exc_info=True)` dropped the traceback.
        # logger.exception() is loguru's supported way to attach it.
        logger.exception(f"WebSocket error: {e}")

    finally:
        # Always release resources and unregister the session.
        if session_id in active_sessions:
            await session.cleanup()
            del active_sessions[session_id]

        logger.info(f"Session {session_id} removed")
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
    """
    WebRTC endpoint for WebRTC audio streaming.

    Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
    """
    # FIX: this handler uses asyncio (create_task, loop time) but the module
    # never imported it, so the original raised NameError at runtime.
    import asyncio

    # Check if aiortc is available
    if not AIORTC_AVAILABLE:
        await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
        logger.warning("WebRTC connection attempted but aiortc is not available")
        return

    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create WebRTC peer connection
    pc = RTCPeerConnection()

    # Create transport and session, and register it for /call endpoints.
    transport = WebRtcTransport(websocket, pc)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebRTC connection established: {session_id}")

    # Track handler for incoming audio
    @pc.on("track")
    def on_track(track):
        logger.info(f"Track received: {track.kind}")

        if track.kind == "audio":
            # Wrap track with resampler so downstream sees 16 kHz PCM.
            wrapped_track = Resampled16kTrack(track)

            async def pull_audio():
                """Continuously pull frames from the track into the session."""
                try:
                    while True:
                        frame = await wrapped_track.recv()
                        pcm_bytes = frame.to_ndarray().tobytes()
                        await session.handle_audio(pcm_bytes)
                except Exception as e:
                    logger.error(f"Error pulling audio from track: {e}")

            asyncio.create_task(pull_audio())

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"Connection state: {pc.connectionState}")
        if pc.connectionState == "failed" or pc.connectionState == "closed":
            await session.cleanup()

    try:
        # Signaling loop: SDP and commands arrive as JSON text frames.
        while True:
            message = await websocket.receive()

            if "text" not in message:
                continue

            data = json.loads(message["text"])

            # Handle SDP offer/answer
            if "sdp" in data and "type" in data:
                logger.info(f"Received SDP {data['type']}")

                offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
                await pc.setRemoteDescription(offer)

                if data["type"] == "offer":
                    answer = await pc.createAnswer()
                    await pc.setLocalDescription(answer)

                    # get_running_loop() is the non-deprecated accessor from
                    # inside a coroutine; timestamp is loop time in ms.
                    await websocket.send_text(json.dumps({
                        "event": "answer",
                        "trackId": session_id,
                        "timestamp": int(asyncio.get_running_loop().time() * 1000),
                        "sdp": pc.localDescription.sdp
                    }))

                    logger.info("Sent SDP answer")

            else:
                # Non-SDP messages are forwarded to the session handler.
                await session.handle_text(message["text"])

    except WebSocketDisconnect:
        logger.info(f"WebRTC WebSocket disconnected: {session_id}")

    except Exception as e:
        # FIX: loguru ignores stdlib's `exc_info` kwarg; use exception()
        # so the traceback is actually recorded.
        logger.exception(f"WebRTC error: {e}")

    finally:
        # Close the peer connection and unregister the session.
        await pc.close()
        if session_id in active_sessions:
            await session.cleanup()
            del active_sessions[session_id]

        logger.info(f"WebRTC session {session_id} removed")
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
async def startup_event():
    """Log the effective configuration when the server comes up."""
    logger.info("Starting Python Active-Call server")
    logger.info(f"Server: {settings.host}:{settings.port}")
    logger.info(f"Sample rate: {settings.sample_rate} Hz")
    logger.info(f"VAD model: {settings.vad_model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
async def shutdown_event():
    """Release sessions and the event bus before the process exits."""
    logger.info("Shutting down Python Active-Call server")

    # Clean up every live session (keys are not needed here).
    for live_session in active_sessions.values():
        await live_session.cleanup()

    # Close and reset the shared event bus.
    bus = get_event_bus()
    await bus.close()
    reset_event_bus()

    logger.info("Server shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import os

    import uvicorn

    # Ensure the log directory exists before the server starts.
    # NOTE(review): the file sink is configured for "../logs/" — confirm
    # which path is actually intended.
    os.makedirs("logs", exist_ok=True)

    # Development entry point with auto-reload enabled.
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=True,
        log_level=settings.log_level.lower()
    )
|
||||||
22
core/__init__.py
Normal file
22
core/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""Core Components Package"""
|
||||||
|
|
||||||
|
from core.events import EventBus, get_event_bus
|
||||||
|
from core.transports import BaseTransport, SocketTransport, WebRtcTransport
|
||||||
|
from core.pipeline import AudioPipeline
|
||||||
|
from core.session import Session
|
||||||
|
from core.conversation import ConversationManager, ConversationState, ConversationTurn
|
||||||
|
from core.duplex_pipeline import DuplexPipeline
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EventBus",
|
||||||
|
"get_event_bus",
|
||||||
|
"BaseTransport",
|
||||||
|
"SocketTransport",
|
||||||
|
"WebRtcTransport",
|
||||||
|
"AudioPipeline",
|
||||||
|
"Session",
|
||||||
|
"ConversationManager",
|
||||||
|
"ConversationState",
|
||||||
|
"ConversationTurn",
|
||||||
|
"DuplexPipeline",
|
||||||
|
]
|
||||||
255
core/conversation.py
Normal file
255
core/conversation.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""Conversation management for voice AI.
|
||||||
|
|
||||||
|
Handles conversation context, turn-taking, and message history
|
||||||
|
for multi-turn voice conversations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from services.base import LLMMessage
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationState(Enum):
    """Lifecycle states of a voice conversation."""

    IDLE = "idle"                # no activity; waiting on the user
    LISTENING = "listening"      # user audio is being captured
    PROCESSING = "processing"    # user utterance handed to the LLM
    SPEAKING = "speaking"        # assistant audio is playing out
    INTERRUPTED = "interrupted"  # assistant was cut off by barge-in
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ConversationTurn:
    """One utterance committed to the dialogue history."""

    role: str                                # "user" or "assistant"
    text: str                                # transcript / generated text
    audio_duration_ms: Optional[int] = None  # spoken-audio length, if known
    # Event-loop clock reading at creation time (set explicitly to override).
    timestamp: float = field(default_factory=lambda: asyncio.get_event_loop().time())
    was_interrupted: bool = False            # True when barge-in cut the turn short
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationManager:
    """
    Tracks dialogue state and history for a voice session.

    Responsibilities:
    - keep the turn history used as LLM context
    - manage user/assistant turn boundaries
    - expose the current ConversationState
    - fan out state/turn events to registered async callbacks
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_history: int = 20,
        greeting: Optional[str] = None
    ):
        """
        Args:
            system_prompt: System prompt for LLM (a default is used when None)
            max_history: Maximum number of turns to include as LLM context
            greeting: Optional greeting message when conversation starts
        """
        default_prompt = (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation. "
            "If you don't understand something, ask for clarification."
        )
        self.system_prompt = system_prompt or default_prompt
        self.max_history = max_history
        self.greeting = greeting

        # Dialogue state
        self.state = ConversationState.IDLE
        self.turns: List[ConversationTurn] = []

        # Async observers
        self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
        self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []

        # In-flight text for the turn currently being built
        self._current_user_text: str = ""
        self._current_assistant_text: str = ""

        logger.info("ConversationManager initialized")

    def on_state_change(
        self,
        callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
    ) -> None:
        """Register an async callback invoked as (old_state, new_state)."""
        self._state_callbacks.append(callback)

    def on_turn_complete(
        self,
        callback: Callable[[ConversationTurn], Awaitable[None]]
    ) -> None:
        """Register an async callback invoked with each committed turn."""
        self._turn_callbacks.append(callback)

    async def set_state(self, new_state: ConversationState) -> None:
        """Transition to *new_state* (no-op if unchanged) and notify observers."""
        if new_state == self.state:
            return

        old_state, self.state = self.state, new_state
        logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")

        for callback in self._state_callbacks:
            try:
                await callback(old_state, new_state)
            except Exception as e:
                logger.error(f"State callback error: {e}")

    async def _emit_turn(self, turn: ConversationTurn) -> None:
        """Notify every turn observer, isolating callback failures."""
        for callback in self._turn_callbacks:
            try:
                await callback(turn)
            except Exception as e:
                logger.error(f"Turn callback error: {e}")

    def get_messages(self) -> List[LLMMessage]:
        """
        Build the LLM message list: system prompt, recent history, then
        any not-yet-committed user text.
        """
        messages = [LLMMessage(role="system", content=self.system_prompt)]

        recent = self.turns[-self.max_history:]
        messages.extend(LLMMessage(role=t.role, content=t.text) for t in recent)

        if self._current_user_text:
            messages.append(LLMMessage(role="user", content=self._current_user_text))

        return messages

    async def start_user_turn(self) -> None:
        """Signal that the user has started speaking."""
        await self.set_state(ConversationState.LISTENING)
        self._current_user_text = ""

    async def update_user_text(self, text: str, is_final: bool = False) -> None:
        """Replace the in-flight user transcript (from ASR)."""
        self._current_user_text = text

    async def end_user_turn(self, text: str) -> None:
        """Commit the user's final text to history and move to PROCESSING."""
        stripped = text.strip()
        if stripped:
            turn = ConversationTurn(role="user", text=stripped)
            self.turns.append(turn)
            await self._emit_turn(turn)
            logger.info(f"User: {text[:50]}...")

        self._current_user_text = ""
        await self.set_state(ConversationState.PROCESSING)

    async def start_assistant_turn(self) -> None:
        """Signal that the assistant has started speaking."""
        await self.set_state(ConversationState.SPEAKING)
        self._current_assistant_text = ""

    async def update_assistant_text(self, text: str) -> None:
        """Append a streamed LLM text chunk to the in-flight assistant turn."""
        self._current_assistant_text += text

    async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
        """Commit the assistant's turn; final state reflects interruption."""
        text = self._current_assistant_text.strip()
        if text:
            turn = ConversationTurn(
                role="assistant",
                text=text,
                was_interrupted=was_interrupted
            )
            self.turns.append(turn)
            await self._emit_turn(turn)

            status = " (interrupted)" if was_interrupted else ""
            logger.info(f"Assistant{status}: {text[:50]}...")

        self._current_assistant_text = ""

        next_state = (
            ConversationState.INTERRUPTED if was_interrupted
            else ConversationState.IDLE
        )
        await self.set_state(next_state)

    async def interrupt(self) -> None:
        """Handle interruption (barge-in) while the assistant is speaking."""
        if self.state == ConversationState.SPEAKING:
            await self.end_assistant_turn(was_interrupted=True)

    def reset(self) -> None:
        """Drop all history and return to IDLE."""
        self.turns = []
        self._current_user_text = ""
        self._current_assistant_text = ""
        self.state = ConversationState.IDLE
        logger.info("Conversation reset")

    @property
    def turn_count(self) -> int:
        """Number of committed turns."""
        return len(self.turns)

    @property
    def last_user_text(self) -> Optional[str]:
        """Most recent committed user text, if any."""
        return next(
            (t.text for t in reversed(self.turns) if t.role == "user"),
            None
        )

    @property
    def last_assistant_text(self) -> Optional[str]:
        """Most recent committed assistant text, if any."""
        return next(
            (t.text for t in reversed(self.turns) if t.role == "assistant"),
            None
        )

    def get_context_summary(self) -> Dict[str, Any]:
        """Snapshot of conversation context for diagnostics."""
        return {
            "state": self.state.value,
            "turn_count": self.turn_count,
            "last_user": self.last_user_text,
            "last_assistant": self.last_assistant_text,
            "current_user": self._current_user_text or None,
            "current_assistant": self._current_assistant_text or None
        }
|
||||||
632
core/duplex_pipeline.py
Normal file
632
core/duplex_pipeline.py
Normal file
@@ -0,0 +1,632 @@
|
|||||||
|
"""Full duplex audio pipeline for AI voice conversation.
|
||||||
|
|
||||||
|
This module implements the core duplex pipeline that orchestrates:
|
||||||
|
- VAD (Voice Activity Detection)
|
||||||
|
- EOU (End of Utterance) Detection
|
||||||
|
- ASR (Automatic Speech Recognition) - optional
|
||||||
|
- LLM (Language Model)
|
||||||
|
- TTS (Text-to-Speech)
|
||||||
|
|
||||||
|
Inspired by pipecat's frame-based architecture and active-call's
|
||||||
|
event-driven design.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Optional, Callable, Awaitable
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from core.transports import BaseTransport
|
||||||
|
from core.conversation import ConversationManager, ConversationState
|
||||||
|
from core.events import get_event_bus
|
||||||
|
from processors.vad import VADProcessor, SileroVAD
|
||||||
|
from processors.eou import EouDetector
|
||||||
|
from services.base import BaseLLMService, BaseTTSService, BaseASRService
|
||||||
|
from services.llm import OpenAILLMService, MockLLMService
|
||||||
|
from services.tts import EdgeTTSService, MockTTSService
|
||||||
|
from services.asr import BufferedASRService
|
||||||
|
from services.siliconflow_tts import SiliconFlowTTSService
|
||||||
|
from services.siliconflow_asr import SiliconFlowASRService
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class DuplexPipeline:
|
||||||
|
"""
|
||||||
|
Full duplex audio pipeline for AI voice conversation.
|
||||||
|
|
||||||
|
Handles bidirectional audio flow with:
|
||||||
|
- User speech detection and transcription
|
||||||
|
- AI response generation
|
||||||
|
- Text-to-speech synthesis
|
||||||
|
- Barge-in (interruption) support
|
||||||
|
|
||||||
|
Architecture (inspired by pipecat):
|
||||||
|
|
||||||
|
User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
|
||||||
|
↓
|
||||||
|
Barge-in Detection → Interrupt
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transport: BaseTransport,
|
||||||
|
session_id: str,
|
||||||
|
llm_service: Optional[BaseLLMService] = None,
|
||||||
|
tts_service: Optional[BaseTTSService] = None,
|
||||||
|
asr_service: Optional[BaseASRService] = None,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
greeting: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize duplex pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transport: Transport for sending audio/events
|
||||||
|
session_id: Session identifier
|
||||||
|
llm_service: LLM service (defaults to OpenAI)
|
||||||
|
tts_service: TTS service (defaults to EdgeTTS)
|
||||||
|
asr_service: ASR service (optional)
|
||||||
|
system_prompt: System prompt for LLM
|
||||||
|
greeting: Optional greeting to speak on start
|
||||||
|
"""
|
||||||
|
self.transport = transport
|
||||||
|
self.session_id = session_id
|
||||||
|
self.event_bus = get_event_bus()
|
||||||
|
|
||||||
|
# Initialize VAD
|
||||||
|
self.vad_model = SileroVAD(
|
||||||
|
model_path=settings.vad_model_path,
|
||||||
|
sample_rate=settings.sample_rate
|
||||||
|
)
|
||||||
|
self.vad_processor = VADProcessor(
|
||||||
|
vad_model=self.vad_model,
|
||||||
|
threshold=settings.vad_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize EOU detector
|
||||||
|
self.eou_detector = EouDetector(
|
||||||
|
silence_threshold_ms=600,
|
||||||
|
min_speech_duration_ms=200
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize services
|
||||||
|
self.llm_service = llm_service
|
||||||
|
self.tts_service = tts_service
|
||||||
|
self.asr_service = asr_service # Will be initialized in start()
|
||||||
|
|
||||||
|
# Track last sent transcript to avoid duplicates
|
||||||
|
self._last_sent_transcript = ""
|
||||||
|
|
||||||
|
# Conversation manager
|
||||||
|
self.conversation = ConversationManager(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
greeting=greeting
|
||||||
|
)
|
||||||
|
|
||||||
|
# State
|
||||||
|
self._running = True
|
||||||
|
self._is_bot_speaking = False
|
||||||
|
self._current_turn_task: Optional[asyncio.Task] = None
|
||||||
|
self._audio_buffer: bytes = b""
|
||||||
|
self._last_vad_status: str = "Silence"
|
||||||
|
|
||||||
|
# Interruption handling
|
||||||
|
self._interrupt_event = asyncio.Event()
|
||||||
|
|
||||||
|
# Barge-in filtering - require minimum speech duration to interrupt
|
||||||
|
self._barge_in_speech_start_time: Optional[float] = None
|
||||||
|
self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms if hasattr(settings, 'barge_in_min_duration_ms') else 50
|
||||||
|
self._barge_in_speech_frames: int = 0 # Count speech frames
|
||||||
|
self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in
|
||||||
|
self._barge_in_silence_tolerance: int = 3 # Allow up to 3 silence frames (60ms at 20ms chunks)
|
||||||
|
|
||||||
|
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||||
|
|
||||||
|
    async def start(self) -> None:
        """Start the pipeline: connect LLM, TTS, and ASR, then speak the greeting.

        Default-service selection (only when no instance was injected):
          * LLM: OpenAI when ``settings.openai_api_key`` is set, else a mock.
          * TTS: SiliconFlow when selected and keyed, else Edge TTS.
          * ASR: SiliconFlow when selected and keyed, else a buffered stub
            that performs no real transcription.

        Raises:
            Exception: re-raised after logging when any service fails to
                construct or connect.
        """
        try:
            # Connect LLM service (fall back to a mock when no API key is set)
            if not self.llm_service:
                if settings.openai_api_key:
                    self.llm_service = OpenAILLMService(
                        api_key=settings.openai_api_key,
                        base_url=settings.openai_api_url,
                        model=settings.llm_model
                    )
                else:
                    logger.warning("No OpenAI API key - using mock LLM")
                    self.llm_service = MockLLMService()

            await self.llm_service.connect()

            # Connect TTS service
            if not self.tts_service:
                if settings.tts_provider == "siliconflow" and settings.siliconflow_api_key:
                    self.tts_service = SiliconFlowTTSService(
                        api_key=settings.siliconflow_api_key,
                        voice=settings.tts_voice,
                        model=settings.siliconflow_tts_model,
                        sample_rate=settings.sample_rate,
                        speed=settings.tts_speed
                    )
                    logger.info("Using SiliconFlow TTS service")
                else:
                    self.tts_service = EdgeTTSService(
                        voice=settings.tts_voice,
                        sample_rate=settings.sample_rate
                    )
                    logger.info("Using Edge TTS service")

            await self.tts_service.connect()

            # Connect ASR service; the SiliconFlow variant streams interim
            # transcripts back through self._on_transcript_callback
            if not self.asr_service:
                if settings.asr_provider == "siliconflow" and settings.siliconflow_api_key:
                    self.asr_service = SiliconFlowASRService(
                        api_key=settings.siliconflow_api_key,
                        model=settings.siliconflow_asr_model,
                        sample_rate=settings.sample_rate,
                        interim_interval_ms=settings.asr_interim_interval_ms,
                        min_audio_for_interim_ms=settings.asr_min_audio_ms,
                        on_transcript=self._on_transcript_callback
                    )
                    logger.info("Using SiliconFlow ASR service")
                else:
                    self.asr_service = BufferedASRService(
                        sample_rate=settings.sample_rate
                    )
                    logger.info("Using Buffered ASR service (no real transcription)")

            await self.asr_service.connect()

            logger.info("DuplexPipeline services connected")

            # Speak greeting if configured
            if self.conversation.greeting:
                await self._speak(self.conversation.greeting)

        except Exception as e:
            logger.error(f"Failed to start pipeline: {e}")
            raise
|
||||||
|
|
||||||
|
    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio chunk.

        This is the main entry point for audio from the user. Each chunk is:
          1. classified by VAD (speech/silence, with events on transitions),
          2. checked for barge-in while the bot is speaking,
          3. buffered/forwarded to ASR while the user is talking,
          4. checked for end-of-utterance, which triggers the LLM turn.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if not self._running:
            return

        try:
            # 1. Process through VAD
            vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)

            vad_status = "Silence"
            if vad_result:
                # VAD only returns a value on a state transition
                event_type, probability = vad_result
                vad_status = "Speech" if event_type == "speaking" else "Silence"

                # Emit VAD event on the internal bus
                await self.event_bus.publish(event_type, {
                    "trackId": self.session_id,
                    "probability": probability
                })
            else:
                # No state change - keep previous status
                vad_status = self._last_vad_status

            # Detect the Silence -> Speech edge to (re)start a user turn
            if vad_status == "Speech" and self._last_vad_status != "Speech":
                await self._on_speech_start()

            self._last_vad_status = vad_status

            # 2. Check for barge-in (user speaking while bot speaking).
            # False interruptions are filtered by requiring a minimum speech
            # duration, while tolerating brief VAD flicker to silence.
            if self._is_bot_speaking:
                if vad_status == "Speech":
                    # User is speaking while bot is speaking
                    self._barge_in_silence_frames = 0  # Reset silence counter

                    if self._barge_in_speech_start_time is None:
                        # Start tracking speech duration
                        self._barge_in_speech_start_time = time.time()
                        self._barge_in_speech_frames = 1
                        logger.debug("Potential barge-in detected, tracking duration...")
                    else:
                        self._barge_in_speech_frames += 1
                        # Interrupt only once speech has lasted long enough
                        speech_duration_ms = (time.time() - self._barge_in_speech_start_time) * 1000
                        if speech_duration_ms >= self._barge_in_min_duration_ms:
                            logger.info(f"Barge-in confirmed after {speech_duration_ms:.0f}ms of speech ({self._barge_in_speech_frames} frames)")
                            await self._handle_barge_in()
                else:
                    # Silence frame during potential barge-in
                    if self._barge_in_speech_start_time is not None:
                        self._barge_in_silence_frames += 1
                        # Allow brief silence gaps (VAD flickering)
                        if self._barge_in_silence_frames > self._barge_in_silence_tolerance:
                            # Too much silence - abandon this barge-in attempt
                            logger.debug(f"Barge-in cancelled after {self._barge_in_silence_frames} silence frames")
                            self._barge_in_speech_start_time = None
                            self._barge_in_speech_frames = 0
                            self._barge_in_silence_frames = 0

            # 3. Buffer audio for ASR while the user is speaking or a turn is open
            if vad_status == "Speech" or self.conversation.state == ConversationState.LISTENING:
                self._audio_buffer += pcm_bytes
                await self.asr_service.send_audio(pcm_bytes)

            # For SiliconFlow ASR, interim transcription timing is handled by
            # the service itself via start_interim_transcription()

            # 4. Check for End of Utterance - this triggers the LLM response
            if self.eou_detector.process(vad_status):
                await self._on_end_of_utterance()

        except Exception as e:
            logger.error(f"Pipeline audio processing error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
    async def process_text(self, text: str) -> None:
        """
        Process text input (chat command), bypassing ASR entirely.

        Any in-flight bot speech is stopped first, then the text is committed
        as the user's turn and a new assistant turn task is started.

        Args:
            text: User text input
        """
        if not self._running:
            return

        logger.info(f"Processing text input: {text[:50]}...")

        # Cancel any current speaking before starting a new turn
        await self._stop_current_speech()

        # Commit the user turn and kick off the assistant response as a task
        await self.conversation.end_user_turn(text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(text))
|
||||||
|
|
||||||
|
    async def interrupt(self) -> None:
        """Interrupt current bot speech (manual interrupt command).

        Delegates to the same path as a detected barge-in; a no-op when the
        bot is not speaking.
        """
        await self._handle_barge_in()
|
||||||
|
|
||||||
|
async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
|
||||||
|
"""
|
||||||
|
Callback for ASR transcription results.
|
||||||
|
|
||||||
|
Streams transcription to client for display.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Transcribed text
|
||||||
|
is_final: Whether this is the final transcription
|
||||||
|
"""
|
||||||
|
# Avoid sending duplicate transcripts
|
||||||
|
if text == self._last_sent_transcript and not is_final:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._last_sent_transcript = text
|
||||||
|
|
||||||
|
# Send transcript event to client
|
||||||
|
await self.transport.send_event({
|
||||||
|
"event": "transcript",
|
||||||
|
"trackId": self.session_id,
|
||||||
|
"text": text,
|
||||||
|
"isFinal": is_final,
|
||||||
|
"timestamp": self._get_timestamp_ms()
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...")
|
||||||
|
|
||||||
|
    async def _on_speech_start(self) -> None:
        """Handle the user starting to speak (Silence -> Speech VAD edge).

        When the conversation is idle, opens a new user turn and resets the
        per-utterance state (audio buffer, transcript dedupe, EOU detector),
        then prepares the ASR service for a fresh utterance.
        """
        if self.conversation.state == ConversationState.IDLE:
            await self.conversation.start_user_turn()
            self._audio_buffer = b""
            self._last_sent_transcript = ""
            self.eou_detector.reset()

            # Clear ASR buffer and start interim transcriptions
            # (hasattr guards: only the SiliconFlow ASR provides these hooks)
            if hasattr(self.asr_service, 'clear_buffer'):
                self.asr_service.clear_buffer()
            if hasattr(self.asr_service, 'start_interim_transcription'):
                await self.asr_service.start_interim_transcription()

        logger.debug("User speech started")
|
||||||
|
|
||||||
|
    async def _on_end_of_utterance(self) -> None:
        """Handle end of user utterance.

        Stops interim transcription, collects the final transcript from
        whichever ASR variant is in use, forwards it to the client, and
        starts the assistant turn. Utterances with no transcribed text are
        discarded and the user turn is reopened.
        """
        if self.conversation.state != ConversationState.LISTENING:
            return

        # Stop interim transcriptions (SiliconFlow ASR only)
        if hasattr(self.asr_service, 'stop_interim_transcription'):
            await self.asr_service.stop_interim_transcription()

        # Get final transcription from whichever ASR service is active
        user_text = ""

        if hasattr(self.asr_service, 'get_final_transcription'):
            # SiliconFlow ASR - get final transcription
            user_text = await self.asr_service.get_final_transcription()
        elif hasattr(self.asr_service, 'get_and_clear_text'):
            # Buffered ASR - get accumulated text
            user_text = self.asr_service.get_and_clear_text()

        # Skip if no meaningful text
        if not user_text or not user_text.strip():
            logger.debug("EOU detected but no transcription - skipping")
            # Reset for next utterance and reopen the user turn
            self._audio_buffer = b""
            self._last_sent_transcript = ""
            await self.conversation.start_user_turn()
            return

        logger.info(f"EOU detected - user said: {user_text[:100]}...")

        # Send final transcription to client
        await self.transport.send_event({
            "event": "transcript",
            "trackId": self.session_id,
            "text": user_text,
            "isFinal": True,
            "timestamp": self._get_timestamp_ms()
        })

        # Clear per-utterance buffers
        self._audio_buffer = b""
        self._last_sent_transcript = ""

        # Commit the user turn and trigger the LLM response as a task
        await self.conversation.end_user_turn(user_text)
        self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||||||
|
|
||||||
|
    async def _handle_turn(self, user_text: str) -> None:
        """
        Handle a complete conversation turn.

        Streams the LLM response and synthesizes it sentence-by-sentence for
        lower time-to-first-audio. The interrupt event is checked between LLM
        chunks and before each sentence so a barge-in stops output quickly.
        Track start/end events bracket the audio sent to the client.

        Args:
            user_text: User's transcribed text
        """
        try:
            # Get AI response (streaming)
            messages = self.conversation.get_messages()
            full_response = ""

            await self.conversation.start_assistant_turn()
            self._is_bot_speaking = True
            self._interrupt_event.clear()

            # Sentence buffer for streaming TTS; includes CJK terminators
            sentence_buffer = ""
            sentence_ends = {'.', '!', '?', '。', '!', '?', ';', '\n'}
            first_audio_sent = False

            # Stream LLM response and TTS sentence by sentence
            async for text_chunk in self.llm_service.generate_stream(messages):
                if self._interrupt_event.is_set():
                    break

                full_response += text_chunk
                sentence_buffer += text_chunk
                await self.conversation.update_assistant_text(text_chunk)

                # Drain every complete sentence currently in the buffer,
                # synthesizing each immediately for low latency
                while any(end in sentence_buffer for end in sentence_ends):
                    # Find the earliest sentence terminator
                    min_idx = len(sentence_buffer)
                    for end in sentence_ends:
                        idx = sentence_buffer.find(end)
                        if idx != -1 and idx < min_idx:
                            min_idx = idx

                    if min_idx < len(sentence_buffer):
                        # Split off the sentence including its terminator
                        sentence = sentence_buffer[:min_idx + 1].strip()
                        sentence_buffer = sentence_buffer[min_idx + 1:]

                        if sentence and not self._interrupt_event.is_set():
                            # Send track start before the first audio chunk
                            if not first_audio_sent:
                                await self.transport.send_event({
                                    "event": "trackStart",
                                    "trackId": self.session_id,
                                    "timestamp": self._get_timestamp_ms()
                                })
                                first_audio_sent = True

                            # Synthesize and send this sentence immediately
                            await self._speak_sentence(sentence)
                    else:
                        break

            # Speak any trailing text without a terminator
            if sentence_buffer.strip() and not self._interrupt_event.is_set():
                if not first_audio_sent:
                    await self.transport.send_event({
                        "event": "trackStart",
                        "trackId": self.session_id,
                        "timestamp": self._get_timestamp_ms()
                    })
                    first_audio_sent = True
                await self._speak_sentence(sentence_buffer.strip())

            # Close the audio track if anything was sent
            if first_audio_sent:
                await self.transport.send_event({
                    "event": "trackEnd",
                    "trackId": self.session_id,
                    "timestamp": self._get_timestamp_ms()
                })

            # End assistant turn, recording whether it was interrupted
            await self.conversation.end_assistant_turn(
                was_interrupted=self._interrupt_event.is_set()
            )

        except asyncio.CancelledError:
            logger.info("Turn handling cancelled")
            await self.conversation.end_assistant_turn(was_interrupted=True)
        except Exception as e:
            logger.error(f"Turn handling error: {e}", exc_info=True)
            await self.conversation.end_assistant_turn(was_interrupted=True)
        finally:
            self._is_bot_speaking = False
            # Reset barge-in tracking when the bot finishes speaking
            self._barge_in_speech_start_time = None
            self._barge_in_speech_frames = 0
            self._barge_in_silence_frames = 0
|
||||||
|
|
||||||
|
async def _speak_sentence(self, text: str) -> None:
|
||||||
|
"""
|
||||||
|
Synthesize and send a single sentence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Sentence to speak
|
||||||
|
"""
|
||||||
|
if not text.strip() or self._interrupt_event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for chunk in self.tts_service.synthesize_stream(text):
|
||||||
|
if self._interrupt_event.is_set():
|
||||||
|
break
|
||||||
|
await self.transport.send_audio(chunk.audio)
|
||||||
|
await asyncio.sleep(0.005) # Small delay to prevent flooding
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TTS sentence error: {e}")
|
||||||
|
|
||||||
|
    async def _speak(self, text: str) -> None:
        """
        Synthesize and send speech as a full track.

        Unlike _speak_sentence, this wraps the audio in its own
        trackStart/trackEnd events and manages _is_bot_speaking itself;
        used for standalone utterances such as the greeting.

        Args:
            text: Text to speak
        """
        if not text.strip():
            return

        try:
            # Send track start event
            await self.transport.send_event({
                "event": "trackStart",
                "trackId": self.session_id,
                "timestamp": self._get_timestamp_ms()
            })

            self._is_bot_speaking = True

            # Stream TTS audio, honoring barge-in between chunks
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._interrupt_event.is_set():
                    logger.info("TTS interrupted by barge-in")
                    break

                # Send audio to client
                await self.transport.send_audio(chunk.audio)

                # Small delay to prevent flooding
                await asyncio.sleep(0.01)

            # Send track end event
            await self.transport.send_event({
                "event": "trackEnd",
                "trackId": self.session_id,
                "timestamp": self._get_timestamp_ms()
            })

        except asyncio.CancelledError:
            logger.info("TTS cancelled")
            raise
        except Exception as e:
            logger.error(f"TTS error: {e}")
        finally:
            self._is_bot_speaking = False
|
||||||
|
|
||||||
|
    async def _handle_barge_in(self) -> None:
        """Handle user barge-in (interruption).

        No-op unless the bot is speaking. Otherwise: resets barge-in
        tracking, signals the interrupt event (observed by the turn/TTS
        loops), cancels TTS and LLM generation, notifies the client, and
        reopens a user turn so the interrupting speech is captured.
        """
        if not self._is_bot_speaking:
            return

        logger.info("Barge-in detected - interrupting bot speech")

        # Reset barge-in tracking
        self._barge_in_speech_start_time = None
        self._barge_in_speech_frames = 0
        self._barge_in_silence_frames = 0

        # Signal interruption to the streaming loops
        self._interrupt_event.set()

        # Cancel TTS
        if self.tts_service:
            await self.tts_service.cancel()

        # Cancel LLM (only some implementations expose cancel)
        if self.llm_service and hasattr(self.llm_service, 'cancel'):
            self.llm_service.cancel()

        # Mark the conversation turn as interrupted
        await self.conversation.interrupt()

        # Send interrupt event to client
        await self.transport.send_event({
            "event": "interrupt",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

        # Reset for the new user turn
        self._is_bot_speaking = False
        await self.conversation.start_user_turn()
        self._audio_buffer = b""
        self.eou_detector.reset()
|
||||||
|
|
||||||
|
async def _stop_current_speech(self) -> None:
|
||||||
|
"""Stop any current speech task."""
|
||||||
|
if self._current_turn_task and not self._current_turn_task.done():
|
||||||
|
self._interrupt_event.set()
|
||||||
|
self._current_turn_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._current_turn_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._is_bot_speaking = False
|
||||||
|
self._interrupt_event.clear()
|
||||||
|
|
||||||
|
async def cleanup(self) -> None:
|
||||||
|
"""Cleanup pipeline resources."""
|
||||||
|
logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")
|
||||||
|
|
||||||
|
self._running = False
|
||||||
|
await self._stop_current_speech()
|
||||||
|
|
||||||
|
# Disconnect services
|
||||||
|
if self.llm_service:
|
||||||
|
await self.llm_service.disconnect()
|
||||||
|
if self.tts_service:
|
||||||
|
await self.tts_service.disconnect()
|
||||||
|
if self.asr_service:
|
||||||
|
await self.asr_service.disconnect()
|
||||||
|
|
||||||
|
def _get_timestamp_ms(self) -> int:
|
||||||
|
"""Get current timestamp in milliseconds."""
|
||||||
|
import time
|
||||||
|
return int(time.time() * 1000)
|
||||||
|
|
||||||
|
    @property
    def is_speaking(self) -> bool:
        """True while the bot is actively streaming TTS audio."""
        return self._is_bot_speaking
|
||||||
|
|
||||||
|
    @property
    def state(self) -> ConversationState:
        """Current state of the underlying conversation manager."""
        return self.conversation.state
|
||||||
134
core/events.py
Normal file
134
core/events.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""Event bus for pub/sub communication between components."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Callable, Dict, List, Any, Optional
|
||||||
|
from collections import defaultdict
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class EventBus:
    """
    Async event bus for pub/sub communication.

    Similar to the original Rust implementation's broadcast channel:
    components subscribe to named event types and are notified
    asynchronously when a matching event is published.
    """

    def __init__(self):
        """Initialize the event bus."""
        # event_type -> list of callbacks (sync or async callables)
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        # NOTE(review): lock is currently unused by any method here; kept for
        # compatibility with external code -- confirm before removing.
        self._lock = asyncio.Lock()
        self._running = True

    def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Subscribe to an event type.

        Args:
            event_type: Type of event to subscribe to (e.g., "speaking", "silence")
            callback: Sync or async callback function that receives event data
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
            return

        self._subscribers[event_type].append(callback)
        logger.debug(f"Subscribed to event type: {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Unsubscribe from an event type; a no-op if the callback is not registered.

        Args:
            event_type: Type of event to unsubscribe from
            callback: Callback function to remove
        """
        # Use .get() so looking up an unknown event type does not create an
        # empty defaultdict entry (the previous indexing leaked empty lists).
        callbacks = self._subscribers.get(event_type)
        if callbacks and callback in callbacks:
            callbacks.remove(callback)
            logger.debug(f"Unsubscribed from event type: {event_type}")

    async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribers.

        Subscribers run concurrently; exceptions in callbacks are logged and
        never propagate to the publisher.

        Args:
            event_type: Type of event to publish
            event_data: Event data to send to subscribers
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
            return

        # Snapshot the subscriber list so callbacks may (un)subscribe
        # during delivery without mutating the list we iterate.
        subscribers = list(self._subscribers.get(event_type, []))

        if not subscribers:
            logger.debug(f"No subscribers for event type: {event_type}")
            return

        # Notify all subscribers concurrently
        tasks = []
        for callback in subscribers:
            try:
                tasks.append(asyncio.create_task(self._call_subscriber(callback, event_data)))
            except Exception as e:
                logger.error(f"Error creating task for subscriber: {e}")

        # Wait for all subscribers to complete; exceptions are captured
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")

    async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
        """
        Call a subscriber callback with error handling.

        Args:
            callback: Subscriber callback function
            event_data: Event data to pass to callback
        """
        try:
            # Await coroutine callbacks; call plain callables directly
            if asyncio.iscoroutinefunction(callback):
                await callback(event_data)
            else:
                callback(event_data)
        except Exception as e:
            logger.error(f"Error in subscriber callback: {e}", exc_info=True)

    async def close(self) -> None:
        """Close the event bus: drop all subscribers and refuse further events."""
        self._running = False
        self._subscribers.clear()
        logger.info("Event bus closed")

    @property
    def is_running(self) -> bool:
        """Check if the event bus is running."""
        return self._running
|
||||||
|
|
||||||
|
|
||||||
|
# Global event bus instance
|
||||||
|
_event_bus: Optional[EventBus] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_event_bus() -> EventBus:
    """
    Return the process-wide EventBus, creating it on first use.

    Returns:
        EventBus instance
    """
    global _event_bus
    bus = _event_bus
    if bus is None:
        bus = EventBus()
        _event_bus = bus
    return bus
|
||||||
|
|
||||||
|
|
||||||
|
def reset_event_bus() -> None:
    """Drop the cached global EventBus so the next get_event_bus() call
    builds a fresh instance (mainly for testing)."""
    global _event_bus
    _event_bus = None
|
||||||
131
core/pipeline.py
Normal file
131
core/pipeline.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Audio processing pipeline."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from core.transports import BaseTransport
|
||||||
|
from core.events import EventBus, get_event_bus
|
||||||
|
from processors.vad import VADProcessor, SileroVAD
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPipeline:
    """
    Audio processing pipeline.

    Processes incoming audio through VAD and emits events on both the
    internal event bus and the connected client transport.
    """

    def __init__(self, transport: BaseTransport, session_id: str):
        """
        Initialize audio pipeline.

        Args:
            transport: Transport instance for sending events/audio
            session_id: Session identifier for event tracking
        """
        self.transport = transport
        self.session_id = session_id
        self.event_bus = get_event_bus()

        # Initialize VAD
        self.vad_model = SileroVAD(
            model_path=settings.vad_model_path,
            sample_rate=settings.sample_rate
        )
        self.vad_processor = VADProcessor(
            vad_model=self.vad_model,
            threshold=settings.vad_threshold,
            silence_threshold_ms=settings.vad_eou_threshold_ms,
            min_speech_duration_ms=settings.vad_min_speech_duration_ms
        )

        # State
        self.is_bot_speaking = False
        self.interrupt_signal = asyncio.Event()
        self._running = True

        logger.info(f"Audio pipeline initialized for session {session_id}")

    async def process_input(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio chunk.

        Runs the chunk through VAD; on a state change, publishes the event on
        the bus and forwards a matching event to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if not self._running:
            return

        try:
            # Process through VAD; a result is only produced on state change
            result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)
            if not result:
                return

            event_type, probability = result

            # Emit event through event bus
            await self.event_bus.publish(event_type, {
                "trackId": self.session_id,
                "probability": probability
            })

            # Read the clock once so "timestamp" and "startTime" always match
            # within one event (previously two separate reads could differ).
            now = self._get_timestamp_ms()
            payload = {
                "event": event_type,
                "trackId": self.session_id,
                "timestamp": now,
            }

            if event_type == "speaking":
                logger.info(f"User speaking started (session {self.session_id})")
                payload["startTime"] = now
            elif event_type == "silence":
                logger.info(f"User speaking stopped (session {self.session_id})")
                payload["startTime"] = now
                payload["duration"] = 0  # TODO: Calculate actual duration
            elif event_type == "eou":
                logger.info(f"EOU detected (session {self.session_id})")
            else:
                # Unknown VAD event types go to the bus only, not the client
                return

            await self.transport.send_event(payload)

        except Exception as e:
            logger.error(f"Pipeline processing error: {e}", exc_info=True)

    async def process_text_input(self, text: str) -> None:
        """
        Process text input (chat command).

        Args:
            text: Text input
        """
        logger.info(f"Processing text input: {text[:50]}...")
        # TODO: Implement text processing (LLM integration, etc.)
        # For now, just log it

    async def interrupt(self) -> None:
        """Interrupt current audio playback."""
        if self.is_bot_speaking:
            self.interrupt_signal.set()
            logger.info(f"Pipeline interrupted for session {self.session_id}")

    async def cleanup(self) -> None:
        """Cleanup pipeline resources."""
        logger.info(f"Cleaning up pipeline for session {self.session_id}")
        self._running = False
        self.interrupt_signal.set()

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        import time  # local import: this module does not import time at top level
        return int(time.time() * 1000)
|
||||||
301
core/session.py
Normal file
301
core/session.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""Session management for active calls."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from core.transports import BaseTransport
|
||||||
|
from core.pipeline import AudioPipeline
|
||||||
|
from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class Session:
|
||||||
|
"""
|
||||||
|
Manages a single call session.
|
||||||
|
|
||||||
|
Handles command routing, audio processing, and session lifecycle.
|
||||||
|
Supports both basic audio pipeline and full duplex voice conversation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(self, session_id: str, transport: BaseTransport, use_duplex: Optional[bool] = None):
        """
        Initialize session.

        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
            use_duplex: Whether to use the duplex pipeline; when None,
                falls back to settings.duplex_enabled
        """
        self.id = session_id
        self.transport = transport

        # Determine pipeline mode (explicit argument wins over settings)
        self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled

        if self.use_duplex:
            # Imported lazily so non-duplex sessions avoid duplex dependencies
            from core.duplex_pipeline import DuplexPipeline
            self.pipeline = DuplexPipeline(
                transport=transport,
                session_id=session_id,
                system_prompt=settings.duplex_system_prompt,
                greeting=settings.duplex_greeting
            )
        else:
            self.pipeline = AudioPipeline(transport, session_id)

        # Session state
        self.created_at = None  # None until set -- TODO confirm where this is assigned
        self.state = "created"  # created, invited, accepted, ringing, hungup
        self._pipeline_started = False

        # Track IDs
        self.current_track_id: Optional[str] = str(uuid.uuid4())

        logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
||||||
|
|
||||||
|
async def handle_text(self, text_data: str) -> None:
|
||||||
|
"""
|
||||||
|
Handle incoming text data (JSON commands).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_data: JSON text data
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(text_data)
|
||||||
|
command = parse_command(data)
|
||||||
|
command_type = command.command
|
||||||
|
|
||||||
|
logger.info(f"Session {self.id} received command: {command_type}")
|
||||||
|
|
||||||
|
# Route command to appropriate handler
|
||||||
|
if command_type == "invite":
|
||||||
|
await self._handle_invite(data)
|
||||||
|
|
||||||
|
elif command_type == "accept":
|
||||||
|
await self._handle_accept(data)
|
||||||
|
|
||||||
|
elif command_type == "reject":
|
||||||
|
await self._handle_reject(data)
|
||||||
|
|
||||||
|
elif command_type == "ringing":
|
||||||
|
await self._handle_ringing(data)
|
||||||
|
|
||||||
|
elif command_type == "tts":
|
||||||
|
await self._handle_tts(command)
|
||||||
|
|
||||||
|
elif command_type == "play":
|
||||||
|
await self._handle_play(data)
|
||||||
|
|
||||||
|
elif command_type == "interrupt":
|
||||||
|
await self._handle_interrupt(command)
|
||||||
|
|
||||||
|
elif command_type == "pause":
|
||||||
|
await self._handle_pause()
|
||||||
|
|
||||||
|
elif command_type == "resume":
|
||||||
|
await self._handle_resume()
|
||||||
|
|
||||||
|
elif command_type == "hangup":
|
||||||
|
await self._handle_hangup(command)
|
||||||
|
|
||||||
|
elif command_type == "history":
|
||||||
|
await self._handle_history(data)
|
||||||
|
|
||||||
|
elif command_type == "chat":
|
||||||
|
await self._handle_chat(command)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"Session {self.id} unknown command: {command_type}")
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Session {self.id} JSON decode error: {e}")
|
||||||
|
await self._send_error("client", f"Invalid JSON: {e}")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Session {self.id} command parse error: {e}")
|
||||||
|
await self._send_error("client", f"Invalid command: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
|
||||||
|
await self._send_error("server", f"Internal error: {e}")
|
||||||
|
|
||||||
|
async def handle_audio(self, audio_bytes: bytes) -> None:
|
||||||
|
"""
|
||||||
|
Handle incoming audio data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_bytes: PCM audio data
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.use_duplex:
|
||||||
|
await self.pipeline.process_audio(audio_bytes)
|
||||||
|
else:
|
||||||
|
await self.pipeline.process_input(audio_bytes)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _handle_invite(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle invite command."""
|
||||||
|
self.state = "invited"
|
||||||
|
option = data.get("option", {})
|
||||||
|
|
||||||
|
# Send answer event
|
||||||
|
await self.transport.send_event({
|
||||||
|
"event": "answer",
|
||||||
|
"trackId": self.current_track_id,
|
||||||
|
"timestamp": self._get_timestamp_ms()
|
||||||
|
})
|
||||||
|
|
||||||
|
# Start duplex pipeline if enabled
|
||||||
|
if self.use_duplex and not self._pipeline_started:
|
||||||
|
try:
|
||||||
|
await self.pipeline.start()
|
||||||
|
self._pipeline_started = True
|
||||||
|
logger.info(f"Session {self.id} duplex pipeline started")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start duplex pipeline: {e}")
|
||||||
|
|
||||||
|
logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}")
|
||||||
|
|
||||||
|
async def _handle_accept(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle accept command."""
|
||||||
|
self.state = "accepted"
|
||||||
|
logger.info(f"Session {self.id} accepted")
|
||||||
|
|
||||||
|
async def _handle_reject(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle reject command."""
|
||||||
|
self.state = "rejected"
|
||||||
|
reason = data.get("reason", "Rejected")
|
||||||
|
logger.info(f"Session {self.id} rejected: {reason}")
|
||||||
|
|
||||||
|
async def _handle_ringing(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle ringing command."""
|
||||||
|
self.state = "ringing"
|
||||||
|
logger.info(f"Session {self.id} ringing")
|
||||||
|
|
||||||
|
async def _handle_tts(self, command: TTSCommand) -> None:
|
||||||
|
"""Handle TTS command."""
|
||||||
|
logger.info(f"Session {self.id} TTS: {command.text[:50]}...")
|
||||||
|
|
||||||
|
# Send track start event
|
||||||
|
await self.transport.send_event({
|
||||||
|
"event": "trackStart",
|
||||||
|
"trackId": self.current_track_id,
|
||||||
|
"timestamp": self._get_timestamp_ms(),
|
||||||
|
"playId": command.play_id
|
||||||
|
})
|
||||||
|
|
||||||
|
# TODO: Implement actual TTS synthesis
|
||||||
|
# For now, just send track end event
|
||||||
|
await self.transport.send_event({
|
||||||
|
"event": "trackEnd",
|
||||||
|
"trackId": self.current_track_id,
|
||||||
|
"timestamp": self._get_timestamp_ms(),
|
||||||
|
"duration": 1000,
|
||||||
|
"ssrc": 0,
|
||||||
|
"playId": command.play_id
|
||||||
|
})
|
||||||
|
|
||||||
|
async def _handle_play(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle play command."""
|
||||||
|
url = data.get("url", "")
|
||||||
|
logger.info(f"Session {self.id} play: {url}")
|
||||||
|
|
||||||
|
# Send track start event
|
||||||
|
await self.transport.send_event({
|
||||||
|
"event": "trackStart",
|
||||||
|
"trackId": self.current_track_id,
|
||||||
|
"timestamp": self._get_timestamp_ms(),
|
||||||
|
"playId": url
|
||||||
|
})
|
||||||
|
|
||||||
|
# TODO: Implement actual audio playback
|
||||||
|
# For now, just send track end event
|
||||||
|
await self.transport.send_event({
|
||||||
|
"event": "trackEnd",
|
||||||
|
"trackId": self.current_track_id,
|
||||||
|
"timestamp": self._get_timestamp_ms(),
|
||||||
|
"duration": 1000,
|
||||||
|
"ssrc": 0,
|
||||||
|
"playId": url
|
||||||
|
})
|
||||||
|
|
||||||
|
async def _handle_interrupt(self, command: InterruptCommand) -> None:
|
||||||
|
"""Handle interrupt command."""
|
||||||
|
if command.graceful:
|
||||||
|
logger.info(f"Session {self.id} graceful interrupt")
|
||||||
|
else:
|
||||||
|
logger.info(f"Session {self.id} immediate interrupt")
|
||||||
|
if self.use_duplex:
|
||||||
|
await self.pipeline.interrupt()
|
||||||
|
else:
|
||||||
|
await self.pipeline.interrupt()
|
||||||
|
|
||||||
|
async def _handle_pause(self) -> None:
|
||||||
|
"""Handle pause command."""
|
||||||
|
logger.info(f"Session {self.id} paused")
|
||||||
|
|
||||||
|
async def _handle_resume(self) -> None:
|
||||||
|
"""Handle resume command."""
|
||||||
|
logger.info(f"Session {self.id} resumed")
|
||||||
|
|
||||||
|
async def _handle_hangup(self, command: HangupCommand) -> None:
|
||||||
|
"""Handle hangup command."""
|
||||||
|
self.state = "hungup"
|
||||||
|
reason = command.reason or "User requested"
|
||||||
|
logger.info(f"Session {self.id} hung up: {reason}")
|
||||||
|
|
||||||
|
# Send hangup event
|
||||||
|
await self.transport.send_event({
|
||||||
|
"event": "hangup",
|
||||||
|
"timestamp": self._get_timestamp_ms(),
|
||||||
|
"reason": reason,
|
||||||
|
"initiator": command.initiator or "user"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Close transport
|
||||||
|
await self.transport.close()
|
||||||
|
|
||||||
|
async def _handle_history(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle history command."""
|
||||||
|
speaker = data.get("speaker", "unknown")
|
||||||
|
text = data.get("text", "")
|
||||||
|
logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...")
|
||||||
|
|
||||||
|
async def _handle_chat(self, command: ChatCommand) -> None:
|
||||||
|
"""Handle chat command."""
|
||||||
|
logger.info(f"Session {self.id} chat: {command.text[:50]}...")
|
||||||
|
# Process text input through pipeline
|
||||||
|
if self.use_duplex:
|
||||||
|
await self.pipeline.process_text(command.text)
|
||||||
|
else:
|
||||||
|
await self.pipeline.process_text_input(command.text)
|
||||||
|
|
||||||
|
async def _send_error(self, sender: str, error_message: str) -> None:
|
||||||
|
"""
|
||||||
|
Send error event to client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sender: Component that generated the error
|
||||||
|
error_message: Error message
|
||||||
|
"""
|
||||||
|
await self.transport.send_event({
|
||||||
|
"event": "error",
|
||||||
|
"trackId": self.current_track_id,
|
||||||
|
"timestamp": self._get_timestamp_ms(),
|
||||||
|
"sender": sender,
|
||||||
|
"error": error_message
|
||||||
|
})
|
||||||
|
|
||||||
|
def _get_timestamp_ms(self) -> int:
|
||||||
|
"""Get current timestamp in milliseconds."""
|
||||||
|
import time
|
||||||
|
return int(time.time() * 1000)
|
||||||
|
|
||||||
|
async def cleanup(self) -> None:
|
||||||
|
"""Cleanup session resources."""
|
||||||
|
logger.info(f"Session {self.id} cleaning up")
|
||||||
|
await self.pipeline.cleanup()
|
||||||
|
await self.transport.close()
|
||||||
207
core/transports.py
Normal file
207
core/transports.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""Transport layer for WebSocket and WebRTC communication."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import WebSocket
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
# Try to import aiortc (optional for WebRTC functionality)
|
||||||
|
try:
|
||||||
|
from aiortc import RTCPeerConnection
|
||||||
|
AIORTC_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
AIORTC_AVAILABLE = False
|
||||||
|
RTCPeerConnection = None # Type hint placeholder
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTransport(ABC):
    """
    Common interface for all client transports.

    Concrete implementations must provide JSON event delivery, raw audio
    delivery, and an explicit shutdown path.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Deliver a JSON event to the connected client.

        Args:
            event: Event data as dictionary
        """

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Deliver raw audio to the connected client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """

    @abstractmethod
    async def close(self) -> None:
        """Shut down the transport and release its resources."""
|
||||||
|
|
||||||
|
|
||||||
|
class SocketTransport(BaseTransport):
    """
    Raw-audio transport over a single WebSocket connection.

    Text frames carry JSON events and binary frames carry PCM audio.
    A shared asyncio.Lock serializes all writes so that frames from
    concurrent senders never interleave on the wire.
    """

    def __init__(self, websocket: WebSocket):
        """
        Create a transport bound to an accepted FastAPI WebSocket.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # serializes outbound frames
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Serialize *event* to JSON and push it as a text frame.

        A send failure marks the transport closed; subsequent sends are
        silently dropped (with a warning).

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        async with self.lock:
            try:
                payload = json.dumps(event)
                await self.ws.send_text(payload)
                logger.debug(f"Sent event: {event.get('event', 'unknown')}")
            except Exception as e:
                logger.error(f"Error sending event: {e}")
                self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Push PCM audio to the client as a binary frame.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        async with self.lock:
            try:
                await self.ws.send_bytes(pcm_bytes)
            except Exception as e:
                logger.error(f"Error sending audio: {e}")
                self._closed = True

    async def close(self) -> None:
        """Mark the transport closed and shut down the WebSocket."""
        self._closed = True
        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """True once the transport has been closed or a send has failed."""
        return self._closed
|
||||||
|
|
||||||
|
|
||||||
|
class WebRtcTransport(BaseTransport):
    """
    Transport that signals over WebSocket and streams media via WebRTC.

    JSON events travel over the signaling WebSocket; audio is intended to
    flow through an aiortc RTCPeerConnection media track.
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport

        Raises:
            RuntimeError: if aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")

        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Push a JSON event over the signaling WebSocket.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        try:
            payload = json.dumps(event)
            await self.ws.send_text(payload)
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Queue audio for delivery over the WebRTC media track.

        WebRTC media is frame-based: bytes are not written directly but
        pushed through a MediaStreamTrack that the peer connection reads.
        That track is not implemented yet, so this is currently a logging
        placeholder.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")

        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """Tear down the peer connection, then the signaling socket."""
        self._closed = True
        try:
            await self.pc.close()
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """True once the transport has been closed or a send has failed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Attach the MediaStreamTrack used for server-to-client audio.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")
|
||||||
BIN
data/vad/silero_vad.onnx
Normal file
BIN
data/vad/silero_vad.onnx
Normal file
Binary file not shown.
538
examples/mic_client.py
Normal file
538
examples/mic_client.py
Normal file
@@ -0,0 +1,538 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Microphone client for testing duplex voice conversation.
|
||||||
|
|
||||||
|
This client captures audio from the microphone, sends it to the server,
|
||||||
|
and plays back the AI's voice response through the speakers.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python examples/mic_client.py --url ws://localhost:8000/ws
|
||||||
|
python examples/mic_client.py --url ws://localhost:8000/ws --chat "Hello!"
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
pip install sounddevice soundfile websockets numpy
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
print("Please install numpy: pip install numpy")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
except ImportError:
|
||||||
|
print("Please install sounddevice: pip install sounddevice")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
except ImportError:
|
||||||
|
print("Please install websockets: pip install websockets")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
class MicrophoneClient:
    """
    Full-duplex microphone client for voice conversation.

    Features:
    - Real-time microphone capture
    - Real-time speaker playback
    - WebSocket communication
    - Text chat support
    """

    # Concurrency model: sounddevice invokes _audio_input_callback on its own
    # audio thread; _playback_thread_func runs in a dedicated daemon thread;
    # everything else runs on the asyncio event loop. The mic->server path is
    # decoupled via audio_input_queue, and the server->speaker path via
    # audio_output_buffer guarded by audio_output_lock.

    def __init__(
        self,
        url: str,
        sample_rate: int = 16000,
        chunk_duration_ms: int = 20,
        input_device: int = None,
        output_device: int = None
    ):
        """
        Initialize microphone client.

        Args:
            url: WebSocket server URL
            sample_rate: Audio sample rate (Hz)
            chunk_duration_ms: Audio chunk duration (ms)
            input_device: Input device ID (None for default)
            output_device: Output device ID (None for default)
        """
        self.url = url
        self.sample_rate = sample_rate
        self.chunk_duration_ms = chunk_duration_ms
        # Samples per microphone chunk (e.g. 16000 Hz * 20 ms = 320 samples)
        self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
        self.input_device = input_device
        self.output_device = output_device

        # WebSocket connection
        self.ws = None
        self.running = False

        # Audio buffers
        self.audio_input_queue = queue.Queue()  # mic thread -> sender task
        self.audio_output_buffer = b""  # Continuous buffer for smooth playback
        self.audio_output_lock = threading.Lock()  # guards audio_output_buffer

        # Statistics
        self.bytes_sent = 0
        self.bytes_received = 0

        # State
        self.is_recording = True  # mic capture enabled (/mute toggles this)
        self.is_playing = True    # speaker playback enabled

    async def connect(self) -> None:
        """Connect to WebSocket server and send the initial invite."""
        print(f"Connecting to {self.url}...")
        self.ws = await websockets.connect(self.url)
        self.running = True
        print("Connected!")

        # Send invite command
        await self.send_command({
            "command": "invite",
            "option": {
                "codec": "pcm",
                "sampleRate": self.sample_rate
            }
        })

    async def send_command(self, cmd: dict) -> None:
        """Send JSON command to server (no-op if not connected)."""
        if self.ws:
            await self.ws.send(json.dumps(cmd))
            print(f"→ Command: {cmd.get('command', 'unknown')}")

    async def send_chat(self, text: str) -> None:
        """Send chat message (text input)."""
        await self.send_command({
            "command": "chat",
            "text": text
        })
        print(f"→ Chat: {text}")

    async def send_interrupt(self) -> None:
        """Send interrupt command."""
        await self.send_command({
            "command": "interrupt"
        })

    async def send_hangup(self, reason: str = "User quit") -> None:
        """Send hangup command."""
        await self.send_command({
            "command": "hangup",
            "reason": reason
        })

    def _audio_input_callback(self, indata, frames, time, status):
        """Callback for audio input (microphone).

        Runs on the sounddevice audio thread; must stay fast and
        allocation-light. Converts float samples to 16-bit PCM and hands
        them to the asyncio sender via a thread-safe queue.
        """
        if status:
            print(f"Input status: {status}")

        if self.is_recording and self.running:
            # Convert to 16-bit PCM (float32 [-1, 1] -> int16)
            audio_data = (indata[:, 0] * 32767).astype(np.int16).tobytes()
            self.audio_input_queue.put(audio_data)

    def _add_audio_to_buffer(self, audio_data: bytes):
        """Add audio data to playback buffer (thread-safe)."""
        with self.audio_output_lock:
            self.audio_output_buffer += audio_data

    def _playback_thread_func(self):
        """Thread function for continuous audio playback.

        Writes fixed-size chunks from audio_output_buffer to the output
        stream, padding with silence when the buffer runs dry so the
        stream clock never stalls.
        """
        import time  # NOTE(review): appears unused in this function — confirm and remove

        # Chunk size: 50ms of audio
        chunk_samples = int(self.sample_rate * 0.05)
        chunk_bytes = chunk_samples * 2  # 2 bytes per int16 sample

        print(f"Audio playback thread started (device: {self.output_device or 'default'})")

        try:
            # Create output stream with callback
            with sd.OutputStream(
                samplerate=self.sample_rate,
                channels=1,
                dtype='int16',
                blocksize=chunk_samples,
                device=self.output_device,
                latency='low'
            ) as stream:
                while self.running:
                    # Get audio from buffer (hold the lock only while slicing)
                    with self.audio_output_lock:
                        if len(self.audio_output_buffer) >= chunk_bytes:
                            audio_data = self.audio_output_buffer[:chunk_bytes]
                            self.audio_output_buffer = self.audio_output_buffer[chunk_bytes:]
                        else:
                            # Not enough audio - output silence
                            audio_data = b'\x00' * chunk_bytes

                    # Convert to numpy array and write to stream
                    # (stream.write blocks, which paces this loop)
                    samples = np.frombuffer(audio_data, dtype=np.int16).reshape(-1, 1)
                    stream.write(samples)

        except Exception as e:
            print(f"Playback thread error: {e}")
            import traceback
            traceback.print_exc()

    async def _playback_task(self):
        """Start playback thread and monitor it."""
        # Run playback in a dedicated thread for reliable timing
        playback_thread = threading.Thread(target=self._playback_thread_func, daemon=True)
        playback_thread.start()

        # Wait for client to stop
        while self.running and playback_thread.is_alive():
            await asyncio.sleep(0.1)

        print("Audio playback stopped")

    async def audio_sender(self) -> None:
        """Send audio from microphone to server.

        Drains audio_input_queue (fed by the sounddevice callback) and
        forwards each chunk over the WebSocket while recording is enabled.
        """
        while self.running:
            try:
                # Get audio from queue with timeout (run in executor so the
                # blocking queue.get doesn't stall the event loop)
                try:
                    audio_data = await asyncio.get_event_loop().run_in_executor(
                        None, lambda: self.audio_input_queue.get(timeout=0.1)
                    )
                except queue.Empty:
                    continue

                # Send to server
                if self.ws and self.is_recording:
                    await self.ws.send(audio_data)
                    self.bytes_sent += len(audio_data)

            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"Audio sender error: {e}")
                break

    async def receiver(self) -> None:
        """Receive messages from server.

        Binary frames are PCM audio and go to the playback buffer; text
        frames are JSON events routed through _handle_event.
        """
        try:
            while self.running:
                try:
                    message = await asyncio.wait_for(self.ws.recv(), timeout=0.1)

                    if isinstance(message, bytes):
                        # Audio data received
                        self.bytes_received += len(message)

                        if self.is_playing:
                            self._add_audio_to_buffer(message)

                        # Show progress (less verbose)
                        with self.audio_output_lock:
                            buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000
                            duration_ms = len(message) / (self.sample_rate * 2) * 1000
                            print(f"← Audio: {duration_ms:.0f}ms (buffer: {buffer_ms:.0f}ms)")

                    else:
                        # JSON event
                        event = json.loads(message)
                        await self._handle_event(event)

                except asyncio.TimeoutError:
                    continue
                except websockets.ConnectionClosed:
                    print("Connection closed")
                    self.running = False
                    break

        except asyncio.CancelledError:
            pass
        except Exception as e:
            print(f"Receiver error: {e}")
            self.running = False

    async def _handle_event(self, event: dict) -> None:
        """Handle incoming event and update console/playback state."""
        event_type = event.get("event", "unknown")

        if event_type == "answer":
            print("← Session ready!")
        elif event_type == "speaking":
            print("← User speech detected")
        elif event_type == "silence":
            print("← User silence detected")
        elif event_type == "transcript":
            # Display user speech transcription
            text = event.get("text", "")
            is_final = event.get("isFinal", False)
            if is_final:
                # Clear the interim line and print final
                print(" " * 80, end="\r")  # Clear previous interim text
                print(f"→ You: {text}")
            else:
                # Interim result - show with indicator (overwrite same line)
                display_text = text[:60] + "..." if len(text) > 60 else text
                print(f"  [listening] {display_text}".ljust(80), end="\r")
        elif event_type == "trackStart":
            print("← Bot started speaking")
            # Clear any old audio in buffer
            with self.audio_output_lock:
                self.audio_output_buffer = b""
        elif event_type == "trackEnd":
            print("← Bot finished speaking")
        elif event_type == "interrupt":
            print("← Bot interrupted!")
            # IMPORTANT: Clear audio buffer immediately on interrupt
            with self.audio_output_lock:
                buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000
                self.audio_output_buffer = b""
            print(f"  (cleared {buffer_ms:.0f}ms of buffered audio)")
        elif event_type == "error":
            print(f"← Error: {event.get('error')}")
        elif event_type == "hangup":
            print(f"← Hangup: {event.get('reason')}")
            self.running = False
        else:
            print(f"← Event: {event_type}")

    async def interactive_mode(self) -> None:
        """Run interactive mode for text chat.

        Reads stdin via an executor so blocking input() doesn't stall the
        event loop; lines starting with '/' are client-side commands.
        """
        print("\n" + "=" * 50)
        print("Voice Conversation Client")
        print("=" * 50)
        print("Speak into your microphone to talk to the AI.")
        print("Or type messages to send text.")
        print("")
        print("Commands:")
        print("  /quit      - End conversation")
        print("  /mute      - Mute microphone")
        print("  /unmute    - Unmute microphone")
        print("  /interrupt - Interrupt AI speech")
        print("  /stats     - Show statistics")
        print("=" * 50 + "\n")

        while self.running:
            try:
                user_input = await asyncio.get_event_loop().run_in_executor(
                    None, input, ""
                )

                if not user_input:
                    continue

                # Handle commands
                if user_input.startswith("/"):
                    cmd = user_input.lower().strip()

                    if cmd == "/quit":
                        await self.send_hangup("User quit")
                        break
                    elif cmd == "/mute":
                        self.is_recording = False
                        print("Microphone muted")
                    elif cmd == "/unmute":
                        self.is_recording = True
                        print("Microphone unmuted")
                    elif cmd == "/interrupt":
                        await self.send_interrupt()
                    elif cmd == "/stats":
                        print(f"Sent: {self.bytes_sent / 1024:.1f} KB")
                        print(f"Received: {self.bytes_received / 1024:.1f} KB")
                    else:
                        print(f"Unknown command: {cmd}")
                else:
                    # Send as chat message
                    await self.send_chat(user_input)

            except EOFError:
                break
            except Exception as e:
                print(f"Input error: {e}")

    async def run(self, chat_message: str = None, interactive: bool = True) -> None:
        """
        Run the client.

        Connects, starts the mic input stream plus the sender/receiver/
        playback tasks, then either sends one chat message, runs the
        interactive prompt, or just waits until stopped.

        Args:
            chat_message: Optional single chat message to send
            interactive: Whether to run in interactive mode
        """
        try:
            await self.connect()

            # Wait for answer
            await asyncio.sleep(0.5)

            # Start audio input stream
            print("Starting audio streams...")

            input_stream = sd.InputStream(
                samplerate=self.sample_rate,
                channels=1,
                dtype=np.float32,
                blocksize=self.chunk_samples,
                device=self.input_device,
                callback=self._audio_input_callback
            )

            input_stream.start()
            print("Audio streams started")

            # Start background tasks
            sender_task = asyncio.create_task(self.audio_sender())
            receiver_task = asyncio.create_task(self.receiver())
            playback_task = asyncio.create_task(self._playback_task())

            if chat_message:
                # Send single message and wait
                await self.send_chat(chat_message)
                await asyncio.sleep(15)
            elif interactive:
                # Run interactive mode
                await self.interactive_mode()
            else:
                # Just wait
                while self.running:
                    await asyncio.sleep(0.1)

            # Cleanup: stop tasks, then await each so cancellation completes
            self.running = False
            sender_task.cancel()
            receiver_task.cancel()
            playback_task.cancel()

            try:
                await sender_task
            except asyncio.CancelledError:
                pass

            try:
                await receiver_task
            except asyncio.CancelledError:
                pass

            try:
                await playback_task
            except asyncio.CancelledError:
                pass

            # NOTE(review): skipped if an exception is raised before this
            # point — consider moving the stop into the finally block
            input_stream.stop()

        except ConnectionRefusedError:
            print(f"Error: Could not connect to {self.url}")
            print("Make sure the server is running.")
        except Exception as e:
            print(f"Error: {e}")
        finally:
            await self.close()

    async def close(self) -> None:
        """Close the connection and print session statistics."""
        self.running = False
        if self.ws:
            await self.ws.close()

        print(f"\nSession ended")
        print(f"  Total sent: {self.bytes_sent / 1024:.1f} KB")
        print(f"  Total received: {self.bytes_received / 1024:.1f} KB")
|
||||||
|
|
||||||
|
|
||||||
|
def list_devices():
    """Print every audio device sounddevice can see, marking the defaults."""
    print("\nAvailable audio devices:")
    print("-" * 60)
    # Hoist the default input/output device indices out of the loop.
    default_input = sd.default.device[0]
    default_output = sd.default.device[1]
    for i, device in enumerate(sd.query_devices()):
        # Collect IN/OUT capability tags for this device.
        tags = []
        if device['max_input_channels'] > 0:
            tags.append("IN")
        if device['max_output_channels'] > 0:
            tags.append("OUT")
        direction_str = "/".join(tags) if tags else "N/A"

        default = ""
        if i == default_input:
            default += " [DEFAULT INPUT]"
        if i == default_output:
            default += " [DEFAULT OUTPUT]"

        print(f"  {i:2d}: {device['name'][:40]:40s} ({direction_str}){default}")
    print("-" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point for the microphone client.

    Parses command-line arguments, optionally lists audio devices, then
    constructs a MicrophoneClient and runs it in single-message, interactive,
    or passive mode.
    """
    parser = argparse.ArgumentParser(
        description="Microphone client for duplex voice conversation"
    )
    parser.add_argument(
        "--url",
        default="ws://localhost:8000/ws",
        help="WebSocket server URL"
    )
    parser.add_argument(
        "--chat",
        help="Send a single chat message instead of using microphone"
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="Audio sample rate (default: 16000)"
    )
    parser.add_argument(
        "--input-device",
        type=int,
        help="Input device ID"
    )
    parser.add_argument(
        "--output-device",
        type=int,
        help="Output device ID"
    )
    parser.add_argument(
        "--list-devices",
        action="store_true",
        help="List available audio devices and exit"
    )
    parser.add_argument(
        "--no-interactive",
        action="store_true",
        help="Disable interactive mode"
    )

    args = parser.parse_args()

    # Device listing is a standalone action; do it and exit early.
    if args.list_devices:
        list_devices()
        return

    client = MicrophoneClient(
        url=args.url,
        sample_rate=args.sample_rate,
        input_device=args.input_device,
        output_device=args.output_device
    )

    # --chat takes precedence over interactive mode inside client.run().
    await client.run(
        chat_message=args.chat,
        interactive=not args.no_interactive
    )


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to end the session; exit quietly.
        print("\nInterrupted by user")
|
||||||
249
examples/simple_client.py
Normal file
249
examples/simple_client.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple WebSocket client for testing voice conversation.
|
||||||
|
Uses PyAudio for more reliable audio playback on Windows.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python examples/simple_client.py
|
||||||
|
python examples/simple_client.py --text "Hello"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import wave
|
||||||
|
import io
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
print("pip install numpy")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
except ImportError:
|
||||||
|
print("pip install websockets")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Try PyAudio first (more reliable on Windows)
|
||||||
|
try:
|
||||||
|
import pyaudio
|
||||||
|
PYAUDIO_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PYAUDIO_AVAILABLE = False
|
||||||
|
print("PyAudio not available, trying sounddevice...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
SD_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
SD_AVAILABLE = False
|
||||||
|
|
||||||
|
if not PYAUDIO_AVAILABLE and not SD_AVAILABLE:
|
||||||
|
print("Please install pyaudio or sounddevice:")
|
||||||
|
print(" pip install pyaudio")
|
||||||
|
print(" or: pip install sounddevice")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleVoiceClient:
    """Simple voice client with reliable audio playback.

    Connects to the voice server over WebSocket, sends text ``chat``
    commands, and plays back the binary PCM audio frames the server
    streams in response.  Prefers PyAudio for playback when available,
    falling back to sounddevice.
    """

    def __init__(self, url: str, sample_rate: int = 16000):
        # WebSocket endpoint and the PCM sample rate negotiated in the invite.
        self.url = url
        self.sample_rate = sample_rate
        self.ws = None
        self.running = False

        # Audio buffer (accumulates raw PCM bytes; currently unused by playback)
        self.audio_buffer = b""

        # PyAudio setup — the output stream is opened lazily on first playback.
        if PYAUDIO_AVAILABLE:
            self.pa = pyaudio.PyAudio()
            self.stream = None

        # Stats
        self.bytes_received = 0

    async def connect(self):
        """Open the WebSocket and send the initial ``invite`` command."""
        print(f"Connecting to {self.url}...")
        self.ws = await websockets.connect(self.url)
        self.running = True
        print("Connected!")

        # Send invite announcing the codec and sample rate we expect.
        await self.ws.send(json.dumps({
            "command": "invite",
            "option": {"codec": "pcm", "sampleRate": self.sample_rate}
        }))
        print("-> invite")

    async def send_chat(self, text: str):
        """Send a ``chat`` command containing *text* to the server."""
        await self.ws.send(json.dumps({"command": "chat", "text": text}))
        print(f"-> chat: {text}")

    def play_audio(self, audio_data: bytes):
        """Play raw 16-bit mono PCM *audio_data* synchronously.

        Blocking call; receive_loop() runs it in an executor so the
        event loop keeps draining the socket while audio plays.
        """
        if len(audio_data) == 0:
            return

        if PYAUDIO_AVAILABLE:
            # Use PyAudio - more reliable on Windows
            if self.stream is None:
                # Open the output stream once and reuse it for all chunks.
                self.stream = self.pa.open(
                    format=pyaudio.paInt16,
                    channels=1,
                    rate=self.sample_rate,
                    output=True,
                    frames_per_buffer=1024
                )
            self.stream.write(audio_data)
        elif SD_AVAILABLE:
            # Use sounddevice: convert int16 PCM to float32 in [-1, 1].
            samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32767.0
            sd.play(samples, self.sample_rate, blocking=True)

    async def receive_loop(self):
        """Receive frames until hangup/disconnect: play audio, print events.

        Binary frames are PCM audio; text frames are JSON events
        (``transcript``, ``hangup``, and anything else is just logged).
        """
        print("\nWaiting for response...")

        while self.running:
            try:
                # Short timeout so the loop re-checks self.running regularly.
                msg = await asyncio.wait_for(self.ws.recv(), timeout=0.1)

                if isinstance(msg, bytes):
                    # Audio data: 16-bit mono PCM, 2 bytes per sample.
                    self.bytes_received += len(msg)
                    duration_ms = len(msg) / (self.sample_rate * 2) * 1000
                    print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms)")

                    # Play immediately in executor to not block the event loop.
                    loop = asyncio.get_event_loop()
                    await loop.run_in_executor(None, self.play_audio, msg)
                else:
                    # JSON event
                    event = json.loads(msg)
                    etype = event.get("event", "?")

                    if etype == "transcript":
                        # User speech transcription; partials are overwritten
                        # in place via the carriage return.
                        text = event.get("text", "")
                        is_final = event.get("isFinal", False)
                        if is_final:
                            print(f"<- You said: {text}")
                        else:
                            print(f"<- [listening] {text}", end="\r")
                    elif etype == "hangup":
                        print(f"<- {etype}")
                        self.running = False
                        break
                    else:
                        print(f"<- {etype}")

            except asyncio.TimeoutError:
                # No frame within 0.1s — loop back and check self.running.
                continue
            except websockets.ConnectionClosed:
                print("Connection closed")
                self.running = False
                break

    async def run(self, text: str = None):
        """Run the client until done.

        With *text*, sends one chat message and waits 30s for the reply;
        otherwise reads messages from stdin until 'quit' or EOF.
        NOTE(review): default is None despite the bare ``str`` annotation.
        """
        try:
            await self.connect()
            await asyncio.sleep(0.5)

            # Start receiver
            recv_task = asyncio.create_task(self.receive_loop())

            if text:
                await self.send_chat(text)
                # Wait for response (fixed 30s window, then shut down).
                await asyncio.sleep(30)
            else:
                # Interactive mode: blocking input() runs in an executor so
                # the receive loop stays live while waiting for the user.
                print("\nType a message and press Enter (or 'quit' to exit):")
                while self.running:
                    try:
                        user_input = await asyncio.get_event_loop().run_in_executor(
                            None, input, "> "
                        )
                        if user_input.lower() == 'quit':
                            break
                        if user_input.strip():
                            await self.send_chat(user_input)
                    except EOFError:
                        break

            # Tear down the receiver task before closing.
            self.running = False
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass

        finally:
            await self.close()

    async def close(self):
        """Close audio resources and the WebSocket, then print stats."""
        self.running = False

        if PYAUDIO_AVAILABLE:
            if self.stream:
                self.stream.stop_stream()
                self.stream.close()
            self.pa.terminate()

        if self.ws:
            await self.ws.close()

        print(f"\nTotal audio received: {self.bytes_received / 1024:.1f} KB")
|
||||||
|
|
||||||
|
|
||||||
|
def list_audio_devices():
    """Print output-capable audio devices for each available backend."""
    print("\n=== Audio Devices ===")

    if PYAUDIO_AVAILABLE:
        pa = pyaudio.PyAudio()
        print("\nPyAudio devices:")
        # The default output index is fixed for the run; look it up once.
        pa_default = pa.get_default_output_device_info()['index']
        for i in range(pa.get_device_count()):
            info = pa.get_device_info_by_index(i)
            if info['maxOutputChannels'] <= 0:
                continue
            default = " [DEFAULT]" if i == pa_default else ""
            print(f"  {i}: {info['name']}{default}")
        pa.terminate()

    if SD_AVAILABLE:
        print("\nSounddevice devices:")
        sd_default = sd.default.device[1]
        for i, d in enumerate(sd.query_devices()):
            if d['max_output_channels'] <= 0:
                continue
            default = " [DEFAULT]" if i == sd_default else ""
            print(f"  {i}: {d['name']}{default}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point: parse arguments, then run the simple voice client."""
    parser = argparse.ArgumentParser(description="Simple voice client")
    parser.add_argument("--url", default="ws://localhost:8000/ws")
    parser.add_argument("--text", help="Send text and play response")
    parser.add_argument("--list-devices", action="store_true")
    parser.add_argument("--sample-rate", type=int, default=16000)

    args = parser.parse_args()

    # Device listing is a standalone action; print and exit early.
    if args.list_devices:
        list_audio_devices()
        return

    client = SimpleVoiceClient(args.url, args.sample_rate)
    # args.text is None without --text, which puts run() in interactive mode.
    await client.run(args.text)


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
137
main.py
137
main.py
@@ -1,137 +0,0 @@
|
|||||||
"""
|
|
||||||
Step 1: Minimal WebSocket Echo Server
|
|
||||||
|
|
||||||
This is the simplest possible WebSocket audio server.
|
|
||||||
It accepts connections and echoes back events.
|
|
||||||
|
|
||||||
What you'll learn:
|
|
||||||
- How to create a FastAPI WebSocket endpoint
|
|
||||||
- How to handle mixed text/binary frames
|
|
||||||
- Basic event sending
|
|
||||||
|
|
||||||
Test with:
|
|
||||||
python main.py
|
|
||||||
python test_client.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from fastapi import FastAPI, WebSocket
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logger.remove()
|
|
||||||
logger.add(lambda msg: print(msg, end=""), level="INFO", format="<green>{time:HH:mm:ss}</green> | {level} | {message}")
|
|
||||||
|
|
||||||
# Create FastAPI app
|
|
||||||
app = FastAPI(title="Voice Gateway - Step 1")
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
async def health_check():
    """Liveness probe; also reports which tutorial step this server is."""
    payload = {"status": "healthy", "step": "1_minimal_echo"}
    return payload
|
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for audio streaming.

    This is a minimal echo server that:
    1. Accepts WebSocket connections
    2. Sends a welcome event
    3. Receives text commands and binary audio
    4. Echoes speaking events back

    Binary frames are treated as audio and answered with a ``speaking``
    event; text frames are parsed as JSON commands (``invite``,
    ``hangup``, ``ping``).
    """
    await websocket.accept()

    # Generate unique session ID (used as trackId in every outgoing event).
    session_id = str(uuid.uuid4())
    logger.info(f"[{session_id}] Client connected")

    try:
        # Send welcome event (answer) immediately, before any client command.
        await websocket.send_json({
            "event": "answer",
            "trackId": session_id,
            "timestamp": _get_timestamp_ms()
        })
        logger.info(f"[{session_id}] Sent answer event")

        # Message receive loop — runs until hangup or an error.
        while True:
            message = await websocket.receive()

            # Handle binary audio data
            if "bytes" in message:
                audio_bytes = message["bytes"]
                logger.info(f"[{session_id}] Received audio: {len(audio_bytes)} bytes")

                # Send speaking event (echo back)
                await websocket.send_json({
                    "event": "speaking",
                    "trackId": session_id,
                    "timestamp": _get_timestamp_ms(),
                    "startTime": _get_timestamp_ms()
                })

            # Handle text commands
            elif "text" in message:
                text_data = message["text"]
                logger.info(f"[{session_id}] Received text: {text_data[:100]}...")

                try:
                    data = json.loads(text_data)
                    command = data.get("command", "unknown")
                    logger.info(f"[{session_id}] Command: {command}")

                    # Handle basic commands
                    if command == "invite":
                        # Re-answer: clients may invite after connecting.
                        await websocket.send_json({
                            "event": "answer",
                            "trackId": session_id,
                            "timestamp": _get_timestamp_ms()
                        })
                        logger.info(f"[{session_id}] Responded to invite")

                    elif command == "hangup":
                        # Leave the receive loop; cleanup happens in finally.
                        logger.info(f"[{session_id}] Hangup requested")
                        break

                    elif command == "ping":
                        await websocket.send_json({
                            "event": "pong",
                            "timestamp": _get_timestamp_ms()
                        })

                except json.JSONDecodeError as e:
                    # Malformed JSON is logged but does not kill the session.
                    logger.error(f"[{session_id}] Invalid JSON: {e}")

    except Exception as e:
        # Catch-all boundary: also fires on normal client disconnect.
        logger.error(f"[{session_id}] Error: {e}")

    finally:
        logger.info(f"[{session_id}] Connection closed")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_timestamp_ms() -> int:
|
|
||||||
"""Get current timestamp in milliseconds."""
|
|
||||||
import time
|
|
||||||
return int(time.time() * 1000)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # uvicorn is only needed when running this module as a script.
    import uvicorn

    logger.info("🚀 Starting Step 1: Minimal WebSocket Echo Server")
    logger.info("📡 Server: ws://localhost:8000/ws")
    logger.info("🩺 Health: http://localhost:8000/health")

    # Bind on all interfaces so external test clients can reach the server.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info"
    )
|
|
||||||
1
models/__init__.py
Normal file
1
models/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Data Models Package"""
|
||||||
143
models/commands.py
Normal file
143
models/commands.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""Protocol command models matching the original active-call API."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# --- Call lifecycle commands -------------------------------------------------

class InviteCommand(BaseModel):
    """Invite command to initiate a call."""

    command: str = Field(default="invite", description="Command type")
    option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options")


class AcceptCommand(BaseModel):
    """Accept command to accept an incoming call."""

    command: str = Field(default="accept", description="Command type")
    option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options")


class RejectCommand(BaseModel):
    """Reject command to reject an incoming call."""

    command: str = Field(default="reject", description="Command type")
    reason: str = Field(default="", description="Reason for rejection")
    code: Optional[int] = Field(default=None, description="SIP response code")


class RingingCommand(BaseModel):
    """Ringing command to send ringing response."""

    command: str = Field(default="ringing", description="Command type")
    recorder: Optional[Dict[str, Any]] = Field(default=None, description="Call recording configuration")
    early_media: bool = Field(default=False, description="Enable early media")
    ringtone: Optional[str] = Field(default=None, description="Custom ringtone URL")


# --- Media playback commands -------------------------------------------------

class TTSCommand(BaseModel):
    """TTS command to convert text to speech."""

    command: str = Field(default="tts", description="Command type")
    text: str = Field(..., description="Text to synthesize")
    speaker: Optional[str] = Field(default=None, description="Speaker voice name")
    play_id: Optional[str] = Field(default=None, description="Unique identifier for this TTS session")
    auto_hangup: bool = Field(default=False, description="Auto hangup after TTS completion")
    streaming: bool = Field(default=False, description="Streaming text input")
    end_of_stream: bool = Field(default=False, description="End of streaming input")
    wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)")
    option: Optional[Dict[str, Any]] = Field(default=None, description="TTS provider specific options")


class PlayCommand(BaseModel):
    """Play command to play audio from URL."""

    command: str = Field(default="play", description="Command type")
    url: str = Field(..., description="URL of audio file to play")
    auto_hangup: bool = Field(default=False, description="Auto hangup after playback")
    wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)")


class InterruptCommand(BaseModel):
    """Interrupt command to interrupt current playback."""

    command: str = Field(default="interrupt", description="Command type")
    graceful: bool = Field(default=False, description="Wait for current TTS to complete")


class PauseCommand(BaseModel):
    """Pause command to pause current playback."""

    command: str = Field(default="pause", description="Command type")


class ResumeCommand(BaseModel):
    """Resume command to resume paused playback."""

    command: str = Field(default="resume", description="Command type")


# --- Session control and conversation commands -------------------------------

class HangupCommand(BaseModel):
    """Hangup command to end the call."""

    command: str = Field(default="hangup", description="Command type")
    reason: Optional[str] = Field(default=None, description="Reason for hangup")
    initiator: Optional[str] = Field(default=None, description="Who initiated the hangup")


class HistoryCommand(BaseModel):
    """History command to add conversation history."""

    command: str = Field(default="history", description="Command type")
    speaker: str = Field(..., description="Speaker identifier")
    text: str = Field(..., description="Conversation text")


class ChatCommand(BaseModel):
    """Chat command for text-based conversation."""

    command: str = Field(default="chat", description="Command type")
    text: str = Field(..., description="Chat text message")


# Command type mapping: wire-protocol "command" value -> model class.
# parse_command() below uses this to dispatch incoming JSON.
COMMAND_TYPES = {
    "invite": InviteCommand,
    "accept": AcceptCommand,
    "reject": RejectCommand,
    "ringing": RingingCommand,
    "tts": TTSCommand,
    "play": PlayCommand,
    "interrupt": InterruptCommand,
    "pause": PauseCommand,
    "resume": ResumeCommand,
    "hangup": HangupCommand,
    "history": HistoryCommand,
    "chat": ChatCommand,
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_command(data: Dict[str, Any]) -> BaseModel:
    """
    Parse a command from JSON data.

    Dispatches on the ``command`` field via COMMAND_TYPES and validates
    the payload against the matching Pydantic model.

    Args:
        data: JSON data as dictionary

    Returns:
        Parsed command model

    Raises:
        ValueError: If command type is unknown
    """
    name = data.get("command")
    if not name:
        raise ValueError("Missing 'command' field")

    # EAFP: look the class up once; a miss means an unknown command.
    try:
        model_cls = COMMAND_TYPES[name]
    except KeyError:
        raise ValueError(f"Unknown command type: {name}") from None

    return model_cls(**data)
|
||||||
126
models/config.py
Normal file
126
models/config.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""Configuration models for call options."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# --- Per-component option models ---------------------------------------------

class VADOption(BaseModel):
    """Voice Activity Detection configuration."""

    type: str = Field(default="silero", description="VAD algorithm type (silero, webrtc)")
    samplerate: int = Field(default=16000, description="Audio sample rate for VAD")
    speech_padding: int = Field(default=250, description="Speech padding in milliseconds")
    silence_padding: int = Field(default=100, description="Silence padding in milliseconds")
    ratio: float = Field(default=0.5, description="Voice detection ratio threshold")
    voice_threshold: float = Field(default=0.5, description="Voice energy threshold")
    max_buffer_duration_secs: int = Field(default=50, description="Maximum buffer duration in seconds")
    silence_timeout: Optional[int] = Field(default=None, description="Silence timeout in milliseconds")
    # Remote-VAD credentials are only relevant when endpoint is set.
    endpoint: Optional[str] = Field(default=None, description="Custom VAD service endpoint")
    secret_key: Optional[str] = Field(default=None, description="VAD service secret key")
    secret_id: Optional[str] = Field(default=None, description="VAD service secret ID")


class ASROption(BaseModel):
    """Automatic Speech Recognition configuration."""

    provider: str = Field(..., description="ASR provider (tencent, aliyun, openai, etc.)")
    language: Optional[str] = Field(default=None, description="Language code (zh-CN, en-US)")
    app_id: Optional[str] = Field(default=None, description="Application ID")
    secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
    secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
    model_type: Optional[str] = Field(default=None, description="ASR model type (16k_zh, 8k_en)")
    buffer_size: Optional[int] = Field(default=None, description="Audio buffer size in bytes")
    samplerate: Optional[int] = Field(default=None, description="Audio sample rate")
    endpoint: Optional[str] = Field(default=None, description="Custom ASR service endpoint")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
    start_when_answer: bool = Field(default=False, description="Start ASR when call is answered")


class TTSOption(BaseModel):
    """Text-to-Speech configuration."""

    samplerate: Optional[int] = Field(default=None, description="TTS output sample rate")
    provider: str = Field(default="msedge", description="TTS provider (tencent, aliyun, deepgram, msedge)")
    speed: float = Field(default=1.0, description="Speech speed multiplier")
    app_id: Optional[str] = Field(default=None, description="Application ID")
    secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
    secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
    volume: Optional[int] = Field(default=None, description="Speech volume level (1-10)")
    speaker: Optional[str] = Field(default=None, description="Voice speaker name")
    codec: Optional[str] = Field(default=None, description="Audio codec")
    subtitle: bool = Field(default=False, description="Enable subtitle generation")
    emotion: Optional[str] = Field(default=None, description="Speech emotion")
    endpoint: Optional[str] = Field(default=None, description="Custom TTS service endpoint")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
    max_concurrent_tasks: Optional[int] = Field(default=None, description="Max concurrent tasks")


class RecorderOption(BaseModel):
    """Call recording configuration."""

    recorder_file: str = Field(..., description="Path to recording file")
    samplerate: int = Field(default=16000, description="Recording sample rate")
    ptime: int = Field(default=200, description="Packet time in milliseconds")


class MediaPassOption(BaseModel):
    """Media pass-through configuration for external audio processing."""

    url: str = Field(..., description="WebSocket URL for media streaming")
    input_sample_rate: int = Field(default=16000, description="Sample rate of audio received from WebSocket")
    output_sample_rate: int = Field(default=16000, description="Sample rate of audio sent to WebSocket")
    packet_size: int = Field(default=2560, description="Packet size in bytes")
    ptime: Optional[int] = Field(default=None, description="Buffered playback period in milliseconds")


class SipOption(BaseModel):
    """SIP protocol configuration."""

    username: Optional[str] = Field(default=None, description="SIP username")
    password: Optional[str] = Field(default=None, description="SIP password")
    realm: Optional[str] = Field(default=None, description="SIP realm/domain")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional SIP headers")


class HandlerRule(BaseModel):
    """Handler routing rule."""

    caller: Optional[str] = Field(default=None, description="Caller pattern (regex)")
    callee: Optional[str] = Field(default=None, description="Callee pattern (regex)")
    playbook: Optional[str] = Field(default=None, description="Playbook file path")
    webhook: Optional[str] = Field(default=None, description="Webhook URL")


# --- Aggregate call options --------------------------------------------------

class CallOption(BaseModel):
    """Comprehensive call configuration options."""

    # Basic options
    denoise: bool = Field(default=False, description="Enable noise reduction")
    offer: Optional[str] = Field(default=None, description="SDP offer string")
    callee: Optional[str] = Field(default=None, description="Callee SIP URI or phone number")
    caller: Optional[str] = Field(default=None, description="Caller SIP URI or phone number")

    # Audio codec
    codec: str = Field(default="pcm", description="Audio codec (pcm, pcma, pcmu, g722)")

    # Component configurations (each nested model defined above)
    recorder: Optional[RecorderOption] = Field(default=None, description="Call recording config")
    asr: Optional[ASROption] = Field(default=None, description="ASR configuration")
    vad: Optional[VADOption] = Field(default=None, description="VAD configuration")
    tts: Optional[TTSOption] = Field(default=None, description="TTS configuration")
    media_pass: Optional[MediaPassOption] = Field(default=None, description="Media pass-through config")
    sip: Optional[SipOption] = Field(default=None, description="SIP configuration")

    # Timeouts and networking
    handshake_timeout: Optional[int] = Field(default=None, description="Handshake timeout in seconds")
    enable_ipv6: bool = Field(default=False, description="Enable IPv6 support")
    inactivity_timeout: Optional[int] = Field(default=None, description="Inactivity timeout in seconds")

    # EOU configuration
    eou: Optional[Dict[str, Any]] = Field(default=None, description="End of utterance detection config")

    # Extra parameters
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional custom parameters")

    class Config:
        # Allow populating fields by attribute name as well as alias.
        populate_by_name = True
|
||||||
223
models/events.py
Normal file
223
models/events.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""Protocol event models matching the original active-call API."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def current_timestamp_ms() -> int:
    """Get current timestamp in milliseconds."""
    now = datetime.now()
    return int(now.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
# Base Event Model
|
||||||
|
class BaseEvent(BaseModel):
|
||||||
|
"""Base event model."""
|
||||||
|
|
||||||
|
event: str = Field(..., description="Event type")
|
||||||
|
track_id: str = Field(..., description="Unique track identifier")
|
||||||
|
timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")
|
||||||
|
|
||||||
|
|
||||||
|
# Lifecycle Events
|
||||||
|
class IncomingEvent(BaseEvent):
    """Incoming call event (SIP only): announces an inbound call and its SDP offer."""

    event: str = Field(default="incoming", description="Event type")
    caller: Optional[str] = Field(default=None, description="Caller's SIP URI")
    callee: Optional[str] = Field(default=None, description="Callee's SIP URI")
    sdp: Optional[str] = Field(default=None, description="SDP offer from caller")


class AnswerEvent(BaseEvent):
    """Call answered event; carries the server's SDP answer when present."""

    event: str = Field(default="answer", description="Event type")
    sdp: Optional[str] = Field(default=None, description="SDP answer from server")


class RejectEvent(BaseEvent):
    """Call rejected event with an optional reason and SIP response code."""

    event: str = Field(default="reject", description="Event type")
    reason: Optional[str] = Field(default=None, description="Rejection reason")
    code: Optional[int] = Field(default=None, description="SIP response code")


class RingingEvent(BaseEvent):
    """Call ringing event; early_media indicates audio is available before answer."""

    event: str = Field(default="ringing", description="Event type")
    early_media: bool = Field(default=False, description="Early media available")
|
||||||
|
|
||||||
|
|
||||||
|
class HangupEvent(BaseModel):
    """Call hangup event.

    NOTE: inherits BaseModel rather than BaseEvent, so it carries no
    track_id. The wire field ``from`` is exposed as ``from_`` via an
    alias (``from`` is a Python keyword).
    """

    event: str = Field(default="hangup", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    reason: Optional[str] = Field(default=None, description="Hangup reason")
    initiator: Optional[str] = Field(default=None, description="Who initiated hangup")
    start_time: Optional[str] = Field(default=None, description="Call start time (ISO 8601)")
    hangup_time: Optional[str] = Field(default=None, description="Hangup time (ISO 8601)")
    answer_time: Optional[str] = Field(default=None, description="Answer time (ISO 8601)")
    ringing_time: Optional[str] = Field(default=None, description="Ringing time (ISO 8601)")
    from_: Optional[Dict[str, Any]] = Field(default=None, alias="from", description="Caller info")
    to: Optional[Dict[str, Any]] = Field(default=None, description="Callee info")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")

    class Config:
        # Allow constructing with either the field name (from_) or the alias (from).
        populate_by_name = True
|
||||||
|
|
||||||
|
|
||||||
|
# VAD Events
|
||||||
|
class SpeakingEvent(BaseEvent):
    """Speech detected event: VAD transitioned from silence to speech."""

    event: str = Field(default="speaking", description="Event type")
    start_time: int = Field(default_factory=current_timestamp_ms, description="Speech start time")


class SilenceEvent(BaseEvent):
    """Silence detected event: VAD transitioned from speech to silence."""

    event: str = Field(default="silence", description="Event type")
    start_time: int = Field(default_factory=current_timestamp_ms, description="Silence start time")
    duration: int = Field(default=0, description="Silence duration in milliseconds")
|
||||||
|
|
||||||
|
|
||||||
|
# AI/ASR Events
|
||||||
|
class AsrFinalEvent(BaseEvent):
    """ASR final transcription event: the committed text for one utterance."""

    event: str = Field(default="asrFinal", description="Event type")
    index: int = Field(..., description="ASR result sequence number")
    start_time: Optional[int] = Field(default=None, description="Speech start time")
    end_time: Optional[int] = Field(default=None, description="Speech end time")
    text: str = Field(..., description="Transcribed text")


class AsrDeltaEvent(BaseEvent):
    """ASR partial transcription event (streaming); text may still be revised."""

    event: str = Field(default="asrDelta", description="Event type")
    index: int = Field(..., description="ASR result sequence number")
    start_time: Optional[int] = Field(default=None, description="Speech start time")
    end_time: Optional[int] = Field(default=None, description="Speech end time")
    text: str = Field(..., description="Partial transcribed text")


class EouEvent(BaseEvent):
    """End of utterance detection event."""

    event: str = Field(default="eou", description="Event type")
    completed: bool = Field(default=True, description="Whether utterance was completed")
|
||||||
|
|
||||||
|
|
||||||
|
# Audio Track Events
|
||||||
|
class TrackStartEvent(BaseEvent):
    """Audio track start event (playback of TTS/file audio began)."""

    event: str = Field(default="trackStart", description="Event type")
    play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")


class TrackEndEvent(BaseEvent):
    """Audio track end event (playback finished)."""

    event: str = Field(default="trackEnd", description="Event type")
    duration: int = Field(..., description="Track duration in milliseconds")
    ssrc: int = Field(..., description="RTP SSRC identifier")
    play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")


class InterruptionEvent(BaseEvent):
    """Playback interruption event: playback stopped before completion."""

    event: str = Field(default="interruption", description="Event type")
    play_id: Optional[str] = Field(default=None, description="Play ID that was interrupted")
    subtitle: Optional[str] = Field(default=None, description="TTS text being played")
    position: Optional[int] = Field(default=None, description="Word index position")
    total_duration: Optional[int] = Field(default=None, description="Total TTS duration")
    current: Optional[int] = Field(default=None, description="Elapsed time when interrupted")
|
||||||
|
|
||||||
|
|
||||||
|
# System Events
|
||||||
|
class ErrorEvent(BaseEvent):
    """Error event reported by a pipeline component."""

    event: str = Field(default="error", description="Event type")
    sender: str = Field(..., description="Component that generated the error")
    error: str = Field(..., description="Error message")
    code: Optional[int] = Field(default=None, description="Error code")


class MetricsEvent(BaseModel):
    """Performance metrics event.

    NOTE: inherits BaseModel (not BaseEvent) — metrics carry no track_id.
    """

    event: str = Field(default="metrics", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    key: str = Field(..., description="Metric key")
    duration: int = Field(..., description="Duration in milliseconds")
    data: Optional[Dict[str, Any]] = Field(default=None, description="Additional metric data")


class AddHistoryEvent(BaseModel):
    """Conversation history entry added event.

    NOTE: inherits BaseModel (not BaseEvent) — no track_id field.
    """

    event: str = Field(default="addHistory", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    sender: Optional[str] = Field(default=None, description="Component that added history")
    speaker: str = Field(..., description="Speaker identifier")
    text: str = Field(..., description="Conversation text")


class DTMFEvent(BaseEvent):
    """DTMF tone detected event."""

    event: str = Field(default="dtmf", description="Event type")
    digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)")
|
||||||
|
|
||||||
|
|
||||||
|
# Event type mapping
|
||||||
|
# Maps wire-protocol event names to their Pydantic model classes.
# Used by create_event() to construct the right model dynamically.
EVENT_TYPES = {
    "incoming": IncomingEvent,
    "answer": AnswerEvent,
    "reject": RejectEvent,
    "ringing": RingingEvent,
    "hangup": HangupEvent,
    "speaking": SpeakingEvent,
    "silence": SilenceEvent,
    "asrFinal": AsrFinalEvent,
    "asrDelta": AsrDeltaEvent,
    "eou": EouEvent,
    "trackStart": TrackStartEvent,
    "trackEnd": TrackEndEvent,
    "interruption": InterruptionEvent,
    "error": ErrorEvent,
    "metrics": MetricsEvent,
    "addHistory": AddHistoryEvent,
    "dtmf": DTMFEvent,
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_event(event_type: str, **kwargs) -> BaseModel:
    """
    Create an event model by its wire-protocol name.

    Args:
        event_type: Type of event to create (a key of EVENT_TYPES)
        **kwargs: Event fields; a redundant "event" key is ignored

    Returns:
        Event model instance

    Raises:
        ValueError: If event type is unknown
    """
    event_class = EVENT_TYPES.get(event_type)

    if not event_class:
        raise ValueError(f"Unknown event type: {event_type}")

    # Drop a redundant "event" kwarg so it cannot collide with the explicit
    # event=event_type below (would raise TypeError: multiple values for
    # keyword argument 'event').
    kwargs.pop("event", None)

    return event_class(event=event_type, **kwargs)
|
||||||
6
processors/__init__.py
Normal file
6
processors/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Audio Processors Package"""
|
||||||
|
|
||||||
|
from processors.eou import EouDetector
|
||||||
|
from processors.vad import SileroVAD, VADProcessor
|
||||||
|
|
||||||
|
__all__ = ["EouDetector", "SileroVAD", "VADProcessor"]
|
||||||
80
processors/eou.py
Normal file
80
processors/eou.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""End-of-Utterance Detection."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class EouDetector:
    """
    End-of-utterance detector.

    Emits a single EOU per speaking turn: the detector arms itself when
    speech begins and fires only after the VAD has reported continuous
    silence for ``silence_threshold_ms``. Speech that resumes during a
    short pause clears the silence clock, so mid-sentence breaks never
    trigger an EOU.
    """

    def __init__(self, silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize EOU detector.

        Args:
            silence_threshold_ms: Continuous silence (ms) required to fire an EOU (default 1000ms)
            min_speech_duration_ms: Speech shorter than this (ms) is treated as noise (default 250ms)
        """
        # Thresholds converted to seconds for comparison with time.time() deltas.
        self.threshold = silence_threshold_ms / 1000.0
        self.min_speech = min_speech_duration_ms / 1000.0
        # Original millisecond values, kept so callers can rebuild the detector.
        self._silence_threshold_ms = silence_threshold_ms
        self._min_speech_duration_ms = min_speech_duration_ms

        # Mutable per-turn state.
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time: Optional[float] = None
        self.triggered = False

    def process(self, vad_status: str) -> bool:
        """
        Feed one VAD label and report whether the utterance just ended.

        Input: "Speech" or "Silence" (from VAD); anything else is ignored.
        Output: True exactly once per turn, when silence has lasted at
        least the configured threshold after valid speech.
        """
        now = time.time()

        if vad_status == "Speech":
            # Entering (or continuing) a speaking turn: re-arm the detector
            # and wipe any pending silence measurement so short pauses
            # followed by more speech stay one utterance.
            if not self.is_speaking:
                self.is_speaking = True
                self.speech_start_time = now
                self.triggered = False
            self.silence_start_time = None
            return False

        if vad_status != "Silence":
            # Unknown label: no state change.
            return False

        if not self.is_speaking:
            return False

        if self.silence_start_time is None:
            self.silence_start_time = now

        # Discard blips shorter than the minimum speech duration.
        if self.silence_start_time - self.speech_start_time < self.min_speech:
            self.is_speaking = False
            self.silence_start_time = None
            return False

        if now - self.silence_start_time >= self.threshold and not self.triggered:
            # Fire once and return to idle.
            self.triggered = True
            self.is_speaking = False
            self.silence_start_time = None
            return True

        return False

    def reset(self) -> None:
        """Return the detector to its idle (not-speaking) state."""
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time = None
        self.triggered = False
|
||||||
168
processors/tracks.py
Normal file
168
processors/tracks.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Audio track processing for WebRTC."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import fractions
|
||||||
|
from typing import Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
# Try to import aiortc (optional for WebRTC functionality)
|
||||||
|
try:
|
||||||
|
from aiortc import AudioStreamTrack
|
||||||
|
AIORTC_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
AIORTC_AVAILABLE = False
|
||||||
|
AudioStreamTrack = object # Dummy class for type hints
|
||||||
|
|
||||||
|
# Try to import PyAV (optional for audio resampling)
|
||||||
|
try:
|
||||||
|
from av import AudioFrame, AudioResampler
|
||||||
|
AV_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
AV_AVAILABLE = False
|
||||||
|
# Create dummy classes for type hints
|
||||||
|
class AudioFrame:
|
||||||
|
pass
|
||||||
|
class AudioResampler:
|
||||||
|
pass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Audio track that resamples input to 16kHz mono PCM.

    Wraps an existing MediaStreamTrack and converts its output
    to 16kHz mono 16-bit PCM format for the pipeline.
    """

    def __init__(self, track, target_sample_rate: int = 16000):
        """
        Initialize resampled track.

        Args:
            track: Source MediaStreamTrack
            target_sample_rate: Target sample rate (default: 16000)

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used")

        super().__init__()
        self.track = track
        self.target_sample_rate = target_sample_rate

        if AV_AVAILABLE:
            # s16/mono at the target rate matches the 16-bit PCM the pipeline expects.
            self.resampler = AudioResampler(
                format="s16",
                layout="mono",
                rate=target_sample_rate
            )
        else:
            logger.warning("PyAV not available, audio resampling disabled")
            self.resampler = None

        self._closed = False

    async def recv(self):
        """
        Receive and resample the next audio frame.

        Returns:
            AudioFrame at 16kHz mono (or the source frame unchanged when
            PyAV is unavailable)

        Raises:
            RuntimeError: If the track has been closed.
        """
        while True:
            if self._closed:
                raise RuntimeError("Track is closed")

            # Get frame from source track
            frame = await self.track.recv()

            if not (AV_AVAILABLE and self.resampler):
                # No resampler available: pass the source frame through untouched.
                return frame

            # PyAV >= 9: AudioResampler.resample() returns a *list* of frames
            # (possibly empty while the resampler buffers input). Older
            # versions returned a single frame; handle both shapes.
            resampled = self.resampler.resample(frame)
            if not isinstance(resampled, (list, tuple)):
                resampled = [resampled]

            if resampled:
                out = resampled[0]
                # NOTE(review): if the resampler ever emits more than one
                # frame per input chunk, frames beyond the first are dropped
                # here — confirm against the source track's frame sizes.
                out.sample_rate = self.target_sample_rate
                return out
            # Resampler produced nothing yet; loop and pull another source frame.

    async def stop(self) -> None:
        """Stop the track and cleanup resources."""
        self._closed = True
        if hasattr(self, 'resampler') and self.resampler:
            del self.resampler
        logger.debug("Resampled track stopped")
|
||||||
|
|
||||||
|
|
||||||
|
class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Synthetic audio track that generates a sine wave.

    Useful for testing without requiring real audio input.

    NOTE(review): recv() returns immediately with no pacing delay, so a
    consumer polling in a tight loop will produce audio faster than real
    time — confirm callers rely on external throttling.
    """

    def __init__(self, sample_rate: int = 16000, frequency: int = 440):
        """
        Initialize sine wave track.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            frequency: Sine wave frequency in Hz (default: 440)

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - SineWaveTrack cannot be used")

        super().__init__()
        self.sample_rate = sample_rate
        self.frequency = frequency
        # Running sample counter; doubles as the PTS of the next frame.
        self.counter = 0
        self._stopped = False

    async def recv(self):
        """
        Generate next audio frame with sine wave.

        Returns:
            AudioFrame with sine wave data (or a plain dict when PyAV is
            not installed)

        Raises:
            RuntimeError: If the track has been stopped.
        """
        if self._stopped:
            raise RuntimeError("Track is stopped")

        # Generate 20ms of audio
        samples = int(self.sample_rate * 0.02)
        pts = self.counter
        # PTS is counted in samples, so the time base is 1/sample_rate.
        time_base = fractions.Fraction(1, self.sample_rate)

        # Generate sine wave
        # The time axis continues from the previous frame, keeping the wave
        # phase-continuous across frames.
        t = np.linspace(
            self.counter / self.sample_rate,
            (self.counter + samples) / self.sample_rate,
            samples,
            endpoint=False
        )

        # Generate sine wave (Int16 PCM) at half amplitude (0.5 * 32767)
        data = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16)

        # Update counter
        self.counter += samples

        # Create AudioFrame if AV is available
        if AV_AVAILABLE:
            frame = AudioFrame.from_ndarray(data.reshape(1, -1), format='s16', layout='mono')
            frame.pts = pts
            frame.time_base = time_base
            frame.sample_rate = self.sample_rate
            return frame
        else:
            # Return simple data structure if AV is not available
            return {
                'data': data,
                'sample_rate': self.sample_rate,
                'pts': pts,
                'time_base': time_base
            }

    def stop(self) -> None:
        """Stop the track."""
        self._stopped = True
|
||||||
213
processors/vad.py
Normal file
213
processors/vad.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
"""Voice Activity Detection using Silero VAD."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
import numpy as np
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from processors.eou import EouDetector
|
||||||
|
|
||||||
|
# Try to import onnxruntime (optional for VAD functionality)
|
||||||
|
try:
|
||||||
|
import onnxruntime as ort
|
||||||
|
ONNX_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
ONNX_AVAILABLE = False
|
||||||
|
ort = None
|
||||||
|
logger.warning("onnxruntime not available - VAD will be disabled")
|
||||||
|
|
||||||
|
|
||||||
|
class SileroVAD:
    """
    Voice Activity Detection using Silero VAD model.

    Detects speech in audio chunks using the Silero VAD ONNX model.
    Returns "Speech" or "Silence" for each audio chunk. When the model
    file or onnxruntime is missing, the detector is disabled and
    process_audio() always reports speech.
    """

    def __init__(self, model_path: str = "data/vad/silero_vad.onnx", sample_rate: int = 16000):
        """
        Initialize Silero VAD.

        Args:
            model_path: Path to Silero VAD ONNX model
            sample_rate: Audio sample rate (must be 16kHz for Silero VAD)
        """
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.session = None

        # Initialize all runtime state up front so the instance is fully
        # usable (reset(), last_label, buffer, ...) even when model loading
        # fails below. The original early returns left these attributes
        # undefined on a disabled instance.
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)  # pending float32 samples
        self.min_chunk_size = 512  # Silero expects exactly 512 samples @ 16kHz
        self.last_label = "Silence"
        self.last_probability = 0.0

        # Check if model exists
        if not os.path.exists(model_path):
            logger.warning(f"VAD model not found at {model_path}. VAD will be disabled.")
            return

        # Check if onnxruntime is available
        if not ONNX_AVAILABLE:
            logger.warning("onnxruntime not available - VAD will be disabled")
            return

        # Load ONNX model
        try:
            self.session = ort.InferenceSession(model_path)
            logger.info(f"Loaded Silero VAD model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load VAD model: {e}")
            self.session = None

    def _reset_state(self):
        """Reset the recurrent model state that is fed back into each inference."""
        # Silero VAD V4+ expects state shape [2, 1, 128]
        self._state = np.zeros((2, 1, 128), dtype=np.float32)
        self._sr = np.array([self.sample_rate], dtype=np.int64)

    def process_audio(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Tuple[str, float]:
        """
        Process audio chunk and detect speech.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
            chunk_size_ms: Chunk duration in milliseconds (ignored for buffering logic)

        Returns:
            Tuple of (label, probability) where label is "Speech" or "Silence".
            When the detector is disabled, always ("Speech", 1.0) so that
            downstream processing is not starved.
        """
        if self.session is None or not ONNX_AVAILABLE:
            # If model not loaded or onnxruntime not available, assume speech
            return "Speech", 1.0

        # Convert bytes to numpy array of int16
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)

        # Normalize to float32 (-1.0 to 1.0)
        audio_float = audio_int16.astype(np.float32) / 32768.0

        # Accumulate; Silero only accepts fixed 512-sample windows.
        self.buffer = np.concatenate((self.buffer, audio_float))

        # Process all complete 512-sample windows currently in the buffer.
        while len(self.buffer) >= self.min_chunk_size:
            # Slice exactly 512 samples
            chunk = self.buffer[:self.min_chunk_size]
            self.buffer = self.buffer[self.min_chunk_size:]

            # Input tensor shape: [batch, samples] -> [1, 512]
            input_tensor = chunk.reshape(1, -1)

            # Run inference
            try:
                ort_inputs = {
                    'input': input_tensor,
                    'state': self._state,
                    'sr': self._sr
                }

                # Outputs: probability, recurrent state (fed back next iteration)
                out, self._state = self.session.run(None, ort_inputs)

                # Get probability
                self.last_probability = float(out[0][0])
                self.last_label = "Speech" if self.last_probability >= 0.5 else "Silence"

            except Exception as e:
                logger.error(f"VAD inference error: {e}")
                # Log the model's expected input names to aid debugging.
                try:
                    inputs = [x.name for x in self.session.get_inputs()]
                    logger.error(f"Model expects inputs: {inputs}")
                except Exception:
                    pass
                # Fail open: report speech so audio keeps flowing downstream.
                return "Speech", 1.0

        return self.last_label, self.last_probability

    def reset(self) -> None:
        """Reset VAD internal state (recurrent state, sample buffer, last result)."""
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.last_label = "Silence"
        self.last_probability = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class VADProcessor:
    """
    High-level VAD processor with state management.

    Tracks speech/silence state and emits events on transitions, and
    delegates end-of-utterance detection to an EouDetector.
    """

    def __init__(self, vad_model: SileroVAD, threshold: float = 0.5,
                 silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize VAD processor.

        Args:
            vad_model: Silero VAD model instance
            threshold: Speech detection threshold
            silence_threshold_ms: EOU silence threshold in ms (longer = one EOU across short pauses)
            min_speech_duration_ms: EOU min speech duration in ms (ignore very short noises)
        """
        self.vad = vad_model
        self.threshold = threshold
        self._eou_silence_ms = silence_threshold_ms
        self._eou_min_speech_ms = min_speech_duration_ms
        self.is_speaking = False
        self.speech_start_time: Optional[float] = None
        self.silence_start_time: Optional[float] = None
        self.eou_detector = EouDetector(silence_threshold_ms, min_speech_duration_ms)

    def process(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Optional[Tuple[str, float]]:
        """
        Process audio chunk and detect state changes.

        Args:
            pcm_bytes: PCM audio data
            chunk_size_ms: Chunk duration in milliseconds

        Returns:
            Tuple of (event_type, probability) if state changed — event_type
            is "eou", "speaking" or "silence" — None otherwise
        """
        # Local import: the module does not import time at top level. A
        # monotonic clock replaces asyncio.get_event_loop().time(), which is
        # deprecated (and can raise) when no event loop is running in this
        # thread; the timestamps are only stored internally.
        import time

        label, probability = self.vad.process_audio(pcm_bytes, chunk_size_ms)

        # Apply the configurable threshold rather than the model's built-in 0.5.
        is_speech = probability >= self.threshold

        # End-of-utterance takes priority over plain state transitions.
        if self.eou_detector.process("Speech" if is_speech else "Silence"):
            return ("eou", probability)

        # State transition: Silence -> Speech
        if is_speech and not self.is_speaking:
            self.is_speaking = True
            self.speech_start_time = time.monotonic()
            self.silence_start_time = None
            return ("speaking", probability)

        # State transition: Speech -> Silence
        if not is_speech and self.is_speaking:
            self.is_speaking = False
            self.silence_start_time = time.monotonic()
            self.speech_start_time = None
            return ("silence", probability)

        return None

    def reset(self) -> None:
        """Reset VAD state and the EOU detector."""
        self.vad.reset()
        self.is_speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        # Reset the existing detector in place instead of rebuilding it.
        self.eou_detector.reset()
|
||||||
134
pyproject.toml
Normal file
134
pyproject.toml
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "py-active-call-cc"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Python Active-Call: Real-time audio streaming with WebSocket and WebRTC"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
license = {text = "MIT"}
|
||||||
|
authors = [
|
||||||
|
{name = "Your Name", email = "your.email@example.com"}
|
||||||
|
]
|
||||||
|
keywords = ["webrtc", "websocket", "audio", "voip", "real-time"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 3 - Alpha",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"Topic :: Communications :: Telephony",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/yourusername/py-active-call-cc"
|
||||||
|
Documentation = "https://github.com/yourusername/py-active-call-cc/blob/main/README.md"
|
||||||
|
Repository = "https://github.com/yourusername/py-active-call-cc.git"
|
||||||
|
Issues = "https://github.com/yourusername/py-active-call-cc/issues"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["."]
|
||||||
|
include = ["app*"]
|
||||||
|
exclude = ["tests*", "scripts*", "reference*"]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 100
|
||||||
|
target-version = ['py311']
|
||||||
|
include = '\.pyi?$'
|
||||||
|
extend-exclude = '''
|
||||||
|
/(
|
||||||
|
# directories
|
||||||
|
\.eggs
|
||||||
|
| \.git
|
||||||
|
| \.hg
|
||||||
|
| \.mypy_cache
|
||||||
|
| \.tox
|
||||||
|
| \.venv
|
||||||
|
| build
|
||||||
|
| dist
|
||||||
|
| reference
|
||||||
|
)/
|
||||||
|
'''
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py311"
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"UP", # pyupgrade
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by black)
|
||||||
|
"B008", # do not perform function calls in argument defaults
|
||||||
|
]
|
||||||
|
exclude = [
|
||||||
|
".bzr",
|
||||||
|
".direnv",
|
||||||
|
".eggs",
|
||||||
|
".git",
|
||||||
|
".hg",
|
||||||
|
".mypy_cache",
|
||||||
|
".nox",
|
||||||
|
".pants.d",
|
||||||
|
".ruff_cache",
|
||||||
|
".svn",
|
||||||
|
".tox",
|
||||||
|
".venv",
|
||||||
|
"__pypackages__",
|
||||||
|
"_build",
|
||||||
|
"buck-out",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"node_modules",
|
||||||
|
"venv",
|
||||||
|
"reference",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.per-file-ignores]
|
||||||
|
"__init__.py" = ["F401"] # unused imports
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
||||||
|
disallow_untyped_defs = false
|
||||||
|
disallow_incomplete_defs = false
|
||||||
|
check_untyped_defs = true
|
||||||
|
no_implicit_optional = true
|
||||||
|
warn_redundant_casts = true
|
||||||
|
warn_unused_ignores = true
|
||||||
|
warn_no_return = true
|
||||||
|
strict_equality = true
|
||||||
|
exclude = [
|
||||||
|
"venv",
|
||||||
|
"reference",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = [
|
||||||
|
"aiortc.*",
|
||||||
|
"av.*",
|
||||||
|
"onnxruntime.*",
|
||||||
|
]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
minversion = "7.0"
|
||||||
|
addopts = "-ra -q --strict-markers --strict-config"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
pythonpath = ["."]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
markers = [
|
||||||
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
|
"integration: marks tests as integration tests",
|
||||||
|
]
|
||||||
37
requirements.txt
Normal file
37
requirements.txt
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Web Framework
|
||||||
|
fastapi>=0.109.0
|
||||||
|
uvicorn[standard]>=0.27.0
|
||||||
|
websockets>=12.0
|
||||||
|
python-multipart>=0.0.6
|
||||||
|
|
||||||
|
# WebRTC (optional - for WebRTC transport)
|
||||||
|
aiortc>=1.6.0
|
||||||
|
|
||||||
|
# Audio Processing
|
||||||
|
av>=12.1.0
|
||||||
|
numpy>=1.26.3
|
||||||
|
onnxruntime>=1.16.3
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
pydantic>=2.5.3
|
||||||
|
pydantic-settings>=2.1.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
toml>=0.10.2
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
loguru>=0.7.2
|
||||||
|
|
||||||
|
# HTTP Client
|
||||||
|
aiohttp>=3.9.1
|
||||||
|
|
||||||
|
# AI Services - LLM
|
||||||
|
openai>=1.0.0
|
||||||
|
|
||||||
|
# AI Services - TTS
|
||||||
|
edge-tts>=6.1.0
|
||||||
|
pydub>=0.25.0 # For audio format conversion
|
||||||
|
|
||||||
|
# Microphone client dependencies
|
||||||
|
sounddevice>=0.4.6
|
||||||
|
soundfile>=0.12.1
|
||||||
|
pyaudio>=0.2.13 # More reliable audio on Windows
|
||||||
166
scripts/test_websocket.py
Normal file
166
scripts/test_websocket.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""WebSocket endpoint test client.
|
||||||
|
|
||||||
|
Tests the /ws endpoint with sine wave or file audio streaming.
|
||||||
|
Based on reference/py-active-call/exec/test_ws_endpoint/test_ws.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
import struct
|
||||||
|
import math
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
SERVER_URL = "ws://localhost:8000/ws"
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
FREQUENCY = 440 # 440Hz Sine Wave
|
||||||
|
CHUNK_DURATION_MS = 20
|
||||||
|
# 16kHz * 16-bit (2 bytes) * 20ms = 640 bytes per chunk
|
||||||
|
CHUNK_SIZE_BYTES = int(SAMPLE_RATE * 2 * (CHUNK_DURATION_MS / 1000.0))
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sine_wave(duration_ms=1000, sample_rate=None, frequency=None):
    """Generate a mono 16-bit little-endian PCM sine wave.

    Args:
        duration_ms: Length of the tone in milliseconds.
        sample_rate: Samples per second; defaults to the module-level
            SAMPLE_RATE when None (keeps the original call signature working).
        frequency: Tone frequency in Hz; defaults to the module-level
            FREQUENCY when None.

    Returns:
        bytearray of packed int16 samples (2 bytes per sample).
    """
    sr = SAMPLE_RATE if sample_rate is None else sample_rate
    freq = FREQUENCY if frequency is None else frequency

    num_samples = int(sr * (duration_ms / 1000.0))
    audio_data = bytearray()

    for x in range(num_samples):
        # Scale to the full int16 range and pack little-endian.
        value = int(32767.0 * math.sin(2 * math.pi * freq * x / sr))
        audio_data.extend(struct.pack('<h', value))

    return audio_data
||||||
|
|
||||||
|
|
||||||
|
async def receive_loop(ws):
    """Consume every message the server sends and pretty-print it.

    Text frames are decoded as JSON events when possible; binary frames
    are treated as returned audio. Exits when the socket closes or errors.
    """
    print("👂 Listening for server responses...")
    async for msg in ws:
        now = datetime.now().strftime("%H:%M:%S")

        if msg.type == aiohttp.WSMsgType.TEXT:
            try:
                payload = json.loads(msg.data)
            except json.JSONDecodeError:
                # Not JSON — show the raw text instead.
                print(f"[{now}] 📨 Text: {msg.data[:100]}...")
            else:
                event_type = payload.get('event', 'Unknown')
                print(f"[{now}] 📨 Event: {event_type} | {msg.data[:150]}...")

        elif msg.type == aiohttp.WSMsgType.BINARY:
            # Received audio chunk back (e.g., TTS or echo); overwrite the line.
            print(f"[{now}] 🔊 Audio: {len(msg.data)} bytes", end="\r")

        elif msg.type == aiohttp.WSMsgType.CLOSED:
            print(f"\n[{now}] ❌ Socket Closed")
            break

        elif msg.type == aiohttp.WSMsgType.ERROR:
            print(f"\n[{now}] ⚠️ Socket Error")
            break
||||||
|
|
||||||
|
|
||||||
|
async def send_file_loop(ws, file_path):
    """Stream a raw PCM/WAV file to the server in real-time-sized chunks.

    Each chunk is CHUNK_SIZE_BYTES long and is followed by a sleep of
    CHUNK_DURATION_MS so the stream is paced like live capture.
    """
    if not os.path.exists(file_path):
        print(f"❌ Error: File '{file_path}' not found.")
        return

    print(f"📂 Streaming file: {file_path} ...")

    with open(file_path, "rb") as audio_file:
        if file_path.endswith('.wav'):
            # Assumes a canonical 44-byte WAV header — TODO confirm for
            # files with extra RIFF chunks.
            audio_file.read(44)

        while chunk := audio_file.read(CHUNK_SIZE_BYTES):
            # Send binary frame, then pace to simulate real-time playback.
            await ws.send_bytes(chunk)
            await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print(f"\n✅ Finished streaming {file_path}")
||||||
|
|
||||||
|
|
||||||
|
async def send_sine_loop(ws):
    """Stream a generated sine-wave tone to the server chunk by chunk."""
    print("🎙️ Starting Audio Stream (Sine Wave)...")

    # Pre-generate a 5-second tone, then slice it into real-time frames.
    audio_buffer = generate_sine_wave(5000)

    for offset in range(0, len(audio_buffer), CHUNK_SIZE_BYTES):
        frame = audio_buffer[offset:offset + CHUNK_SIZE_BYTES]
        if not frame:
            break

        await ws.send_bytes(frame)
        # Pace each frame like live capture.
        await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print("\n✅ Finished streaming test audio.")
||||||
|
|
||||||
|
|
||||||
|
async def run_client(url, file_path=None, use_sine=False):
    """Run the WebSocket test client.

    Connects to *url*, sends the initial invite command, then streams audio
    (sine wave by default, or *file_path*) while concurrently printing every
    response the server sends back.
    """
    session = aiohttp.ClientSession()
    try:
        print(f"🔌 Connecting to {url}...")
        async with session.ws_connect(url) as ws:
            print("✅ Connected!")

            # Handshake: tell the server what audio format is coming.
            await ws.send_json({
                "command": "invite",
                "option": {
                    "codec": "pcm",
                    "samplerate": SAMPLE_RATE
                }
            })
            print("📤 Sent Invite Command")

            # Choose the audio producer; --sine wins, sine is also the default.
            if file_path and not use_sine:
                producer = send_file_loop(ws, file_path)
            else:
                producer = send_sine_loop(ws)

            # Pump audio out while printing server events in parallel.
            await asyncio.gather(
                receive_loop(ws),
                producer
            )

    except aiohttp.ClientConnectorError:
        print(f"❌ Connection Failed. Is the server running at {url}?")
    except Exception as e:
        print(f"❌ Error: {e}")
    finally:
        await session.close()
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI entry point: pick the endpoint and audio source, then run the client.
    parser = argparse.ArgumentParser(description="WebSocket Audio Test Client")
    parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
    parser.add_argument("--file", help="Path to PCM/WAV file to stream")
    parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
    cli_args = parser.parse_args()

    try:
        asyncio.run(run_client(cli_args.url, cli_args.file, cli_args.sine))
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the client; exit quietly.
        print("\n👋 Client stopped.")
||||||
47
services/__init__.py
Normal file
47
services/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""AI Services package.
|
||||||
|
|
||||||
|
Provides ASR, LLM, TTS, and Realtime API services for voice conversation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from services.base import (
|
||||||
|
ServiceState,
|
||||||
|
ASRResult,
|
||||||
|
LLMMessage,
|
||||||
|
TTSChunk,
|
||||||
|
BaseASRService,
|
||||||
|
BaseLLMService,
|
||||||
|
BaseTTSService,
|
||||||
|
)
|
||||||
|
from services.llm import OpenAILLMService, MockLLMService
|
||||||
|
from services.tts import EdgeTTSService, MockTTSService
|
||||||
|
from services.asr import BufferedASRService, MockASRService
|
||||||
|
from services.siliconflow_asr import SiliconFlowASRService
|
||||||
|
from services.siliconflow_tts import SiliconFlowTTSService
|
||||||
|
from services.realtime import RealtimeService, RealtimeConfig, RealtimePipeline
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Base classes
|
||||||
|
"ServiceState",
|
||||||
|
"ASRResult",
|
||||||
|
"LLMMessage",
|
||||||
|
"TTSChunk",
|
||||||
|
"BaseASRService",
|
||||||
|
"BaseLLMService",
|
||||||
|
"BaseTTSService",
|
||||||
|
# LLM
|
||||||
|
"OpenAILLMService",
|
||||||
|
"MockLLMService",
|
||||||
|
# TTS
|
||||||
|
"EdgeTTSService",
|
||||||
|
"MockTTSService",
|
||||||
|
# ASR
|
||||||
|
"BufferedASRService",
|
||||||
|
"MockASRService",
|
||||||
|
"SiliconFlowASRService",
|
||||||
|
# TTS (SiliconFlow)
|
||||||
|
"SiliconFlowTTSService",
|
||||||
|
# Realtime
|
||||||
|
"RealtimeService",
|
||||||
|
"RealtimeConfig",
|
||||||
|
"RealtimePipeline",
|
||||||
|
]
|
||||||
147
services/asr.py
Normal file
147
services/asr.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""ASR (Automatic Speech Recognition) Service implementations.
|
||||||
|
|
||||||
|
Provides speech-to-text capabilities with streaming support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import AsyncIterator, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from services.base import BaseASRService, ASRResult, ServiceState
|
||||||
|
|
||||||
|
# Try to import websockets for streaming ASR
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
WEBSOCKETS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
WEBSOCKETS_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class BufferedASRService(BaseASRService):
    """
    Buffered ASR service that accumulates audio and provides
    a simple text accumulator for use with EOU detection.

    This is a lightweight implementation that works with the
    existing VAD + EOU pattern without requiring external ASR:
    an external recognizer pushes text via set_text(), and the
    pipeline drains it with get_and_clear_text() or via
    receive_transcripts().
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        language: str = "en"
    ):
        super().__init__(sample_rate=sample_rate, language=language)

        self._audio_buffer: bytes = b""   # raw PCM collected by send_audio()
        self._current_text: str = ""      # latest externally supplied transcript
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()

    async def connect(self) -> None:
        """No connection needed for buffered ASR; just flip the state."""
        self.state = ServiceState.CONNECTED
        logger.info("Buffered ASR service connected")

    async def disconnect(self) -> None:
        """Clear buffers on disconnect."""
        self._audio_buffer = b""
        self._current_text = ""
        self.state = ServiceState.DISCONNECTED
        logger.info("Buffered ASR service disconnected")

    async def send_audio(self, audio: bytes) -> None:
        """Buffer audio for later processing."""
        self._audio_buffer += audio

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield transcription results as they are queued.

        Polls with a short timeout so the generator stays responsive
        to task cancellation instead of blocking forever on the queue.
        """
        while True:
            try:
                result = await asyncio.wait_for(
                    self._transcript_queue.get(),
                    timeout=0.1
                )
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

    def set_text(self, text: str) -> None:
        """
        Set the current transcript text directly.

        This allows external integration (e.g., Whisper, other ASR)
        to provide transcripts.
        """
        self._current_text = text
        result = ASRResult(text=text, is_final=False)
        # put_nowait never blocks on an unbounded queue and — unlike the
        # previous fire-and-forget asyncio.create_task() — works without a
        # running event loop, preserves enqueue order, and cannot be lost
        # to garbage collection of an unreferenced task.
        self._transcript_queue.put_nowait(result)

    def get_and_clear_text(self) -> str:
        """Return the accumulated text, clearing both text and audio buffers."""
        text = self._current_text
        self._current_text = ""
        self._audio_buffer = b""
        return text

    def get_audio_buffer(self) -> bytes:
        """Get accumulated audio buffer."""
        return self._audio_buffer

    def clear_audio_buffer(self) -> None:
        """Clear audio buffer."""
        self._audio_buffer = b""
||||||
|
|
||||||
|
|
||||||
|
class MockASRService(BaseASRService):
    """
    Mock ASR service for testing without actual recognition.

    Transcripts are produced on demand via trigger_transcript(),
    cycling through a fixed list of canned phrases.
    """

    def __init__(self, sample_rate: int = 16000, language: str = "en"):
        super().__init__(sample_rate=sample_rate, language=language)
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
        self._mock_texts = [
            "Hello, how are you?",
            "That's interesting.",
            "Tell me more about that.",
            "I understand.",
        ]
        self._text_index = 0  # cycles through _mock_texts

    async def connect(self) -> None:
        self.state = ServiceState.CONNECTED
        logger.info("Mock ASR service connected")

    async def disconnect(self) -> None:
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock ASR service disconnected")

    async def send_audio(self, audio: bytes) -> None:
        """Mock audio processing - audio is intentionally discarded."""
        pass

    def trigger_transcript(self) -> None:
        """Manually trigger a transcript (for testing)."""
        text = self._mock_texts[self._text_index % len(self._mock_texts)]
        self._text_index += 1

        result = ASRResult(text=text, is_final=True, confidence=0.95)
        # put_nowait avoids the RuntimeError that asyncio.create_task()
        # raises when no event loop is running, and the queued item cannot
        # be dropped by GC the way an unreferenced task can.
        self._transcript_queue.put_nowait(result)

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield transcription results queued by trigger_transcript().

        Polls with a short timeout so the generator stays responsive
        to task cancellation.
        """
        while True:
            try:
                result = await asyncio.wait_for(
                    self._transcript_queue.get(),
                    timeout=0.1
                )
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
||||||
244
services/base.py
Normal file
244
services/base.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""Base classes for AI services.
|
||||||
|
|
||||||
|
Defines abstract interfaces for ASR, LLM, and TTS services,
|
||||||
|
inspired by pipecat's service architecture and active-call's
|
||||||
|
StreamEngine pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceState(Enum):
    """Service connection state shared by the ASR/LLM/TTS service classes."""
    DISCONNECTED = "disconnected"  # initial state / after disconnect()
    CONNECTING = "connecting"
    CONNECTED = "connected"        # ready for use
    ERROR = "error"
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ASRResult:
    """A single speech-to-text result produced by an ASR service.

    Only ``text`` is required; timing and confidence fields are filled
    in by services that can supply them.
    """
    text: str
    is_final: bool = False
    confidence: float = 1.0
    language: Optional[str] = None
    start_time: Optional[float] = None
    end_time: Optional[float] = None

    def __str__(self) -> str:
        """Render as '[FINAL] ...' or '[PARTIAL] ...' for logging."""
        return f"[{'FINAL' if self.is_final else 'PARTIAL'}] {self.text}"
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LLMMessage:
    """One message in an LLM conversation history."""
    role: str  # "system", "user", "assistant", "function"
    content: str
    name: Optional[str] = None  # For function calls
    function_call: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to an API-compatible dict, omitting unset optional fields."""
        payload: Dict[str, Any] = {"role": self.role, "content": self.content}
        if self.name:
            payload["name"] = self.name
        if self.function_call:
            payload["function_call"] = self.function_call
        return payload
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TTSChunk:
    """TTS audio chunk.

    One unit of synthesized speech streamed from a TTS service; the
    format fields default to the pipeline's 16 kHz mono 16-bit PCM.
    """
    audio: bytes  # PCM audio data
    sample_rate: int = 16000
    channels: int = 1
    bits_per_sample: int = 16
    is_final: bool = False  # True on the last chunk of an utterance
    text_offset: Optional[int] = None  # Character offset in original text
||||||
|
|
||||||
|
|
||||||
|
class BaseASRService(ABC):
    """
    Abstract base class for ASR (Speech-to-Text) services.

    Supports both streaming and non-streaming transcription.
    """

    def __init__(self, sample_rate: int = 16000, language: str = "en"):
        self.sample_rate = sample_rate  # expected input PCM sample rate
        self.language = language        # BCP-47-style language hint
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Establish connection to ASR service."""
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        """Close connection to ASR service."""
        pass

    @abstractmethod
    async def send_audio(self, audio: bytes) -> None:
        """
        Send audio chunk for transcription.

        Args:
            audio: PCM audio data (16-bit, mono)
        """
        pass

    @abstractmethod
    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """
        Receive transcription results.

        Yields:
            ASRResult objects as they become available
        """
        pass

    async def transcribe(self, audio: bytes) -> ASRResult:
        """
        Transcribe a complete audio buffer (non-streaming).

        Args:
            audio: Complete PCM audio data

        Returns:
            Final ASRResult

        NOTE(review): this loops until a result with is_final=True is
        yielded; implementations whose receive_transcripts() never emits
        a final result will make this await indefinitely.
        """
        # Default implementation using streaming
        await self.send_audio(audio)
        async for result in self.receive_transcripts():
            if result.is_final:
                return result
        # Stream ended without a final result — return an empty final.
        return ASRResult(text="", is_final=True)
||||||
|
|
||||||
|
|
||||||
|
class BaseLLMService(ABC):
    """
    Abstract base class for LLM (Language Model) services.

    Supports streaming responses for real-time conversation.
    """

    def __init__(self, model: str = "gpt-4"):
        self.model = model  # provider-specific model identifier
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Initialize LLM service connection."""
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        """Close LLM service connection."""
        pass

    @abstractmethod
    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Complete response text
        """
        pass

    @abstractmethod
    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """
        Generate response in streaming mode.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Yields:
            Text chunks as they are generated
        """
        pass
||||||
|
|
||||||
|
|
||||||
|
class BaseTTSService(ABC):
    """
    Abstract base class for TTS (Text-to-Speech) services.

    Supports streaming audio synthesis for low-latency playback.
    """

    def __init__(
        self,
        voice: str = "default",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        self.voice = voice              # provider-specific voice identifier
        self.sample_rate = sample_rate  # output PCM sample rate
        self.speed = speed              # playback-rate multiplier (1.0 = normal)
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Initialize TTS service connection."""
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        """Close TTS service connection."""
        pass

    @abstractmethod
    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio for text (non-streaming).

        Args:
            text: Text to synthesize

        Returns:
            Complete PCM audio data
        """
        pass

    @abstractmethod
    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects as audio is generated
        """
        pass

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (for barge-in support).

        No-op by default; subclasses that stream override this.
        """
        pass
||||||
239
services/llm.py
Normal file
239
services/llm.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
"""LLM (Large Language Model) Service implementations.
|
||||||
|
|
||||||
|
Provides OpenAI-compatible LLM integration with streaming support
|
||||||
|
for real-time voice conversation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from services.base import BaseLLMService, LLMMessage, ServiceState
|
||||||
|
|
||||||
|
# Try to import openai
|
||||||
|
try:
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
OPENAI_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
OPENAI_AVAILABLE = False
|
||||||
|
logger.warning("openai package not available - LLM service will be disabled")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAILLMService(BaseLLMService):
    """
    OpenAI-compatible LLM service.

    Supports streaming responses for low-latency voice conversation.
    Works with OpenAI API, Azure OpenAI, and compatible APIs.
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        system_prompt: Optional[str] = None
    ):
        """
        Initialize OpenAI LLM service.

        Args:
            model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            base_url: Custom API base URL (for Azure or compatible APIs)
            system_prompt: Default system prompt for conversations
        """
        super().__init__(model=model)

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.base_url = base_url or os.getenv("OPENAI_API_URL")
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation."
        )

        # Client is created lazily in connect(); None while disconnected.
        self.client: Optional[AsyncOpenAI] = None
        # Set by cancel() to stop an in-flight generate_stream().
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Initialize OpenAI client.

        Raises:
            RuntimeError: if the openai package is not installed.
            ValueError: if no API key was provided or found in the env.
        """
        if not OPENAI_AVAILABLE:
            raise RuntimeError("openai package not installed")

        if not self.api_key:
            raise ValueError("OpenAI API key not provided")

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"OpenAI LLM service connected: model={self.model}")

    async def disconnect(self) -> None:
        """Close OpenAI client."""
        if self.client:
            await self.client.close()
            self.client = None
        self.state = ServiceState.DISCONNECTED
        logger.info("OpenAI LLM service disconnected")

    def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Prepare the API payload, prepending the default system prompt
        only when the caller's history does not already contain one."""
        result = []

        # Add system prompt if not already present
        has_system = any(m.role == "system" for m in messages)
        if not has_system and self.system_prompt:
            result.append({"role": "system", "content": self.system_prompt})

        # Add all messages
        for msg in messages:
            result.append(msg.to_dict())

        return result

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Complete response text

        Raises:
            RuntimeError: if called before connect().
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")

        prepared = self._prepare_messages(messages)

        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # content can be None (e.g. function-call responses) — coerce to "".
            content = response.choices[0].message.content or ""
            logger.debug(f"LLM response: {content[:100]}...")
            return content

        except Exception as e:
            logger.error(f"LLM generation error: {e}")
            raise

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """
        Generate response in streaming mode.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Yields:
            Text chunks as they are generated

        Raises:
            RuntimeError: if called before connect().
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")

        prepared = self._prepare_messages(messages)
        # Reset so a cancel() from a previous turn doesn't stop this one.
        self._cancel_event.clear()

        try:
            stream = await self.client.chat.completions.create(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True
            )

            async for chunk in stream:
                # Check for cancellation (barge-in); takes effect between chunks.
                if self._cancel_event.is_set():
                    logger.info("LLM stream cancelled")
                    break

                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield content

        except asyncio.CancelledError:
            logger.info("LLM stream cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise

    def cancel(self) -> None:
        """Cancel ongoing generation.

        Sets a flag checked between streamed chunks, so cancellation is
        cooperative rather than immediate.
        """
        self._cancel_event.set()
||||||
|
|
||||||
|
|
||||||
|
class MockLLMService(BaseLLMService):
    """
    Mock LLM service for testing without API calls.

    Cycles through a fixed list of canned responses, with an artificial
    delay to mimic network latency.
    """

    def __init__(self, response_delay: float = 0.5):
        super().__init__(model="mock")
        self.response_delay = response_delay  # seconds before each reply
        self.responses = [
            "Hello! How can I help you today?",
            "That's an interesting question. Let me think about it.",
            "I understand. Is there anything else you'd like to know?",
            "Great! I'm here if you need anything else.",
        ]
        self._response_index = 0  # cycles through self.responses

    async def connect(self) -> None:
        """Mark the mock service as connected."""
        self.state = ServiceState.CONNECTED
        logger.info("Mock LLM service connected")

    async def disconnect(self) -> None:
        """Mark the mock service as disconnected."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock LLM service disconnected")

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """Return the next canned response after the simulated delay."""
        await asyncio.sleep(self.response_delay)
        reply = self.responses[self._response_index % len(self.responses)]
        self._response_index += 1
        return reply

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Yield the canned response one word at a time, spaces as
        separate chunks, with a small delay per word."""
        full_text = await self.generate(messages, temperature, max_tokens)

        for position, word in enumerate(full_text.split()):
            if position:
                yield " "
            yield word
            await asyncio.sleep(0.05)  # Simulate streaming delay
||||||
548
services/realtime.py
Normal file
548
services/realtime.py
Normal file
@@ -0,0 +1,548 @@
|
|||||||
|
"""OpenAI Realtime API Service.
|
||||||
|
|
||||||
|
Provides true duplex voice conversation using OpenAI's Realtime API,
|
||||||
|
similar to active-call's RealtimeProcessor. This bypasses the need for
|
||||||
|
separate ASR/LLM/TTS services by handling everything server-side.
|
||||||
|
|
||||||
|
The Realtime API provides:
|
||||||
|
- Server-side VAD with turn detection
|
||||||
|
- Streaming speech-to-text
|
||||||
|
- Streaming LLM responses
|
||||||
|
- Streaming text-to-speech
|
||||||
|
- Function calling support
|
||||||
|
- Barge-in/interruption handling
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
from typing import Optional, Dict, Any, Callable, Awaitable, List
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
WEBSOCKETS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
WEBSOCKETS_AVAILABLE = False
|
||||||
|
logger.warning("websockets not available - Realtime API will be disabled")
|
||||||
|
|
||||||
|
|
||||||
|
class RealtimeState(Enum):
    """Realtime API connection state."""
    DISCONNECTED = "disconnected"  # initial state / after teardown
    CONNECTING = "connecting"
    CONNECTED = "connected"        # WebSocket session established
    ERROR = "error"
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RealtimeConfig:
    """Configuration for OpenAI Realtime API.

    Defaults target the public OpenAI endpoint; api_key falls back to
    the OPENAI_API_KEY environment variable in RealtimeService.
    """

    # API Configuration
    api_key: Optional[str] = None
    model: str = "gpt-4o-realtime-preview"
    endpoint: Optional[str] = None  # For Azure or custom endpoints

    # Voice Configuration
    voice: str = "alloy"  # alloy, echo, shimmer, etc.
    instructions: str = (
        "You are a helpful, friendly voice assistant. "
        "Keep your responses concise and conversational."
    )

    # Turn Detection (Server-side VAD); set to None to disable server VAD.
    turn_detection: Optional[Dict[str, Any]] = field(default_factory=lambda: {
        "type": "server_vad",
        "threshold": 0.5,
        "prefix_padding_ms": 300,
        "silence_duration_ms": 500
    })

    # Audio Configuration
    input_audio_format: str = "pcm16"
    output_audio_format: str = "pcm16"

    # Tools/Functions exposed to the model for function calling.
    tools: List[Dict[str, Any]] = field(default_factory=list)
||||||
|
|
||||||
|
|
||||||
|
class RealtimeService:
    """
    OpenAI Realtime API service for true duplex voice conversation.

    This service handles the entire voice conversation pipeline:
    1. Audio input → Server-side VAD → Speech-to-text
    2. Text → LLM processing → Response generation
    3. Response → Text-to-speech → Audio output

    Events emitted:
    - on_audio: Audio output from the assistant (raw PCM bytes)
    - on_transcript: Text transcript (user or assistant)
    - on_speech_started: User started speaking
    - on_speech_stopped: User stopped speaking
    - on_response_started: Assistant started responding
    - on_response_done: Assistant finished responding
    - on_function_call: Function call requested
    - on_error: Error occurred
    """

    def __init__(self, config: Optional[RealtimeConfig] = None):
        """
        Initialize Realtime API service.

        Args:
            config: Realtime configuration (uses defaults if not provided)
        """
        self.config = config or RealtimeConfig()
        # Fall back to the environment when the config carries no key.
        self.config.api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")

        self.state = RealtimeState.DISCONNECTED
        self._ws = None  # active websocket connection; set by connect()
        self._receive_task: Optional[asyncio.Task] = None
        self._cancel_event = asyncio.Event()  # set by disconnect()

        # Event callbacks: event name -> list of async handlers.
        # NOTE(review): "on_interrupted" is registered here but nothing in
        # this class emits it — confirm whether callers depend on it.
        self._callbacks: Dict[str, List[Callable]] = {
            "on_audio": [],
            "on_transcript": [],
            "on_speech_started": [],
            "on_speech_stopped": [],
            "on_response_started": [],
            "on_response_done": [],
            "on_function_call": [],
            "on_error": [],
            "on_interrupted": [],
        }

        logger.debug(f"RealtimeService initialized with model={self.config.model}")

    def on(self, event: str, callback: Callable[..., Awaitable[None]]) -> None:
        """
        Register event callback.

        Unknown event names are silently ignored (only the keys declared in
        ``__init__`` are accepted).

        Args:
            event: Event name
            callback: Async callback function
        """
        if event in self._callbacks:
            self._callbacks[event].append(callback)

    async def _emit(self, event: str, *args, **kwargs) -> None:
        """Emit event to all registered callbacks.

        Callback exceptions are logged and swallowed so one failing handler
        cannot break the receive loop or other handlers.
        """
        for callback in self._callbacks.get(event, []):
            try:
                await callback(*args, **kwargs)
            except Exception as e:
                logger.error(f"Event callback error ({event}): {e}")

    async def connect(self) -> None:
        """Connect to OpenAI Realtime API, configure the session and start
        the background receive loop.

        Raises:
            RuntimeError: if the ``websockets`` package is not installed.
            ValueError: if no API key is configured.
            Exception: any connection failure is re-raised after setting
                ``state`` to ``ERROR``.
        """
        if not WEBSOCKETS_AVAILABLE:
            raise RuntimeError("websockets package not installed")

        if not self.config.api_key:
            raise ValueError("OpenAI API key not provided")

        self.state = RealtimeState.CONNECTING

        # Build URL
        if self.config.endpoint:
            # Azure or custom endpoint: the model is passed as a deployment name.
            url = f"{self.config.endpoint}/openai/realtime?api-version=2024-10-01-preview&deployment={self.config.model}"
        else:
            # OpenAI endpoint
            url = f"wss://api.openai.com/v1/realtime?model={self.config.model}"

        # Build headers. Azure-style endpoints authenticate with an "api-key"
        # header; plain OpenAI uses a Bearer token plus the beta opt-in header.
        headers = {}
        if self.config.endpoint:
            headers["api-key"] = self.config.api_key
        else:
            headers["Authorization"] = f"Bearer {self.config.api_key}"
            headers["OpenAI-Beta"] = "realtime=v1"

        try:
            logger.info(f"Connecting to Realtime API: {url}")
            # NOTE(review): websockets >= 14 renamed `extra_headers` to
            # `additional_headers` — confirm the pinned websockets version.
            self._ws = await websockets.connect(url, extra_headers=headers)

            # Send session configuration
            await self._configure_session()

            # Start receive loop
            self._receive_task = asyncio.create_task(self._receive_loop())

            self.state = RealtimeState.CONNECTED
            logger.info("Realtime API connected successfully")

        except Exception as e:
            self.state = RealtimeState.ERROR
            logger.error(f"Realtime API connection failed: {e}")
            raise

    async def _configure_session(self) -> None:
        """Send session configuration (``session.update``) to the server.

        Pushes modalities, instructions, voice, audio formats, VAD settings
        and — when configured — the tool definitions.
        """
        session_config = {
            "type": "session.update",
            "session": {
                "modalities": ["text", "audio"],
                "instructions": self.config.instructions,
                "voice": self.config.voice,
                "input_audio_format": self.config.input_audio_format,
                "output_audio_format": self.config.output_audio_format,
                "turn_detection": self.config.turn_detection,
            }
        }

        if self.config.tools:
            session_config["session"]["tools"] = self.config.tools

        await self._send(session_config)
        logger.debug("Session configuration sent")

    async def _send(self, data: Dict[str, Any]) -> None:
        """Send JSON data to server.

        Silently does nothing when there is no open websocket.
        """
        if self._ws:
            await self._ws.send(json.dumps(data))

    async def send_audio(self, audio_bytes: bytes) -> None:
        """
        Send audio to the Realtime API (``input_audio_buffer.append``).

        No-op unless the service is CONNECTED.

        Args:
            audio_bytes: PCM audio data (16-bit, mono, 24kHz by default)
        """
        if self.state != RealtimeState.CONNECTED:
            return

        # The wire protocol carries audio as base64 inside JSON.
        audio_b64 = base64.standard_b64encode(audio_bytes).decode()

        await self._send({
            "type": "input_audio_buffer.append",
            "audio": audio_b64
        })

    async def send_text(self, text: str) -> None:
        """
        Send text input (bypassing audio) and request a response.

        No-op unless the service is CONNECTED.

        Args:
            text: User text input
        """
        if self.state != RealtimeState.CONNECTED:
            return

        # Create a conversation item with user text
        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": text}]
            }
        })

        # Trigger response
        await self._send({"type": "response.create"})

    async def cancel_response(self) -> None:
        """Cancel the current response (for barge-in). No-op when not connected."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "response.cancel"})
        logger.debug("Response cancelled")

    async def commit_audio(self) -> None:
        """Commit the audio buffer and trigger a response.

        Useful when server-side VAD is disabled and the caller decides
        end-of-utterance itself. No-op when not connected.
        """
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "input_audio_buffer.commit"})
        await self._send({"type": "response.create"})

    async def clear_audio_buffer(self) -> None:
        """Clear the input audio buffer on the server. No-op when not connected."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "input_audio_buffer.clear"})

    async def submit_function_result(self, call_id: str, result: str) -> None:
        """
        Submit function call result and trigger a follow-up response.

        No-op unless the service is CONNECTED.

        Args:
            call_id: The function call ID
            result: JSON string result
        """
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result
            }
        })

        # Trigger response with the function result
        await self._send({"type": "response.create"})

    async def _receive_loop(self) -> None:
        """Receive and process messages from the Realtime API.

        Runs until cancelled or the websocket closes; updates ``state`` on
        close/error so callers can observe disconnection.
        """
        if not self._ws:
            return

        try:
            async for message in self._ws:
                try:
                    data = json.loads(message)
                    await self._handle_event(data)
                except json.JSONDecodeError:
                    # Malformed frames are logged (truncated) and skipped.
                    logger.warning(f"Invalid JSON received: {message[:100]}")

        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
        except websockets.ConnectionClosed as e:
            logger.info(f"WebSocket closed: {e}")
            self.state = RealtimeState.DISCONNECTED
        except Exception as e:
            logger.error(f"Receive loop error: {e}")
            self.state = RealtimeState.ERROR

    async def _handle_event(self, data: Dict[str, Any]) -> None:
        """Handle one incoming event from the Realtime API and dispatch it
        to the matching registered callbacks."""
        event_type = data.get("type", "unknown")

        # Audio delta - streaming audio output
        if event_type == "response.audio.delta":
            if "delta" in data:
                audio_bytes = base64.standard_b64decode(data["delta"])
                await self._emit("on_audio", audio_bytes)

        # Audio transcript delta - streaming text (is_final=False)
        elif event_type == "response.audio_transcript.delta":
            if "delta" in data:
                await self._emit("on_transcript", data["delta"], "assistant", False)

        # Audio transcript done (is_final=True)
        elif event_type == "response.audio_transcript.done":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "assistant", True)

        # Input audio transcript (user speech)
        elif event_type == "conversation.item.input_audio_transcription.completed":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "user", True)

        # Speech started (server VAD detected speech)
        elif event_type == "input_audio_buffer.speech_started":
            await self._emit("on_speech_started", data.get("audio_start_ms", 0))

        # Speech stopped
        elif event_type == "input_audio_buffer.speech_stopped":
            await self._emit("on_speech_stopped", data.get("audio_end_ms", 0))

        # Response started
        elif event_type == "response.created":
            await self._emit("on_response_started", data.get("response", {}))

        # Response done
        elif event_type == "response.done":
            await self._emit("on_response_done", data.get("response", {}))

        # Function call — arguments are forwarded as the raw JSON string.
        elif event_type == "response.function_call_arguments.done":
            call_id = data.get("call_id")
            name = data.get("name")
            arguments = data.get("arguments", "{}")
            await self._emit("on_function_call", call_id, name, arguments)

        # Error
        elif event_type == "error":
            error = data.get("error", {})
            logger.error(f"Realtime API error: {error}")
            await self._emit("on_error", error)

        # Session events
        elif event_type == "session.created":
            logger.info("Session created")
        elif event_type == "session.updated":
            logger.debug("Session updated")

        else:
            logger.debug(f"Unhandled event type: {event_type}")

    async def disconnect(self) -> None:
        """Disconnect from Realtime API.

        Cancels the receive loop, closes the websocket and resets ``state``
        to DISCONNECTED. Safe to call when already disconnected.
        """
        self._cancel_event.set()

        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass

        if self._ws:
            await self._ws.close()
            self._ws = None

        self.state = RealtimeState.DISCONNECTED
        logger.info("Realtime API disconnected")
|
||||||
|
|
||||||
|
|
||||||
|
class RealtimePipeline:
    """
    Pipeline adapter around RealtimeService.

    Exposes a DuplexPipeline-compatible surface (start / process_audio /
    process_text / interrupt / cleanup) while delegating all ASR/LLM/TTS
    work to the OpenAI Realtime API, and relays service events back to
    the transport as pipeline events.
    """

    def __init__(
        self,
        transport,
        session_id: str,
        config: Optional[RealtimeConfig] = None
    ):
        """
        Initialize Realtime pipeline.

        Args:
            transport: Transport for sending audio/events
            session_id: Session identifier
            config: Realtime configuration
        """
        self.transport = transport
        self.session_id = session_id

        self.service = RealtimeService(config)

        # Wire the service's events to our handlers.
        wiring = {
            "on_audio": self._on_audio,
            "on_transcript": self._on_transcript,
            "on_speech_started": self._on_speech_started,
            "on_speech_stopped": self._on_speech_stopped,
            "on_response_started": self._on_response_started,
            "on_response_done": self._on_response_done,
            "on_error": self._on_error,
        }
        for event_name, handler in wiring.items():
            self.service.on(event_name, handler)

        self._is_speaking = False
        self._running = True

        logger.info(f"RealtimePipeline initialized for session {session_id}")

    async def start(self) -> None:
        """Connect the underlying Realtime service."""
        await self.service.connect()

    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Forward captured audio to the Realtime service.

        Note: Realtime API expects 24kHz audio by default.
        You may need to resample from 16kHz.
        """
        if not self._running:
            return

        # TODO: Resample from 16kHz to 24kHz if needed
        await self.service.send_audio(pcm_bytes)

    async def process_text(self, text: str) -> None:
        """Forward a text turn to the Realtime service."""
        if not self._running:
            return

        await self.service.send_text(text)

    async def interrupt(self) -> None:
        """Cancel the in-flight response and notify the client."""
        await self.service.cancel_response()
        await self.transport.send_event(self._event_payload("interrupt"))

    async def cleanup(self) -> None:
        """Stop forwarding and release the service connection."""
        self._running = False
        await self.service.disconnect()

    # Event handlers

    def _event_payload(self, name: str, **extra) -> Dict:
        """Build a transport event carrying the common envelope fields
        (event name, trackId, timestamp) plus any extra keys."""
        payload = {
            "event": name,
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
        }
        payload.update(extra)
        return payload

    async def _on_audio(self, audio_bytes: bytes) -> None:
        """Relay synthesized audio straight to the transport."""
        await self.transport.send_audio(audio_bytes)

    async def _on_transcript(self, text: str, role: str, is_final: bool) -> None:
        """Log transcripts, truncating long ones for readability."""
        preview = f"{text[:50]}..." if len(text) > 50 else text
        logger.info(f"[{role.upper()}] {preview}")

    async def _on_speech_started(self, start_ms: int) -> None:
        """Handle user speech start: notify client, then barge-in."""
        self._is_speaking = True
        await self.transport.send_event(
            self._event_payload("speaking", startTime=start_ms)
        )

        # Cancel any ongoing response (barge-in)
        await self.service.cancel_response()

    async def _on_speech_stopped(self, end_ms: int) -> None:
        """Handle user speech stop: notify client of silence."""
        self._is_speaking = False
        await self.transport.send_event(
            self._event_payload("silence", duration=end_ms)
        )

    async def _on_response_started(self, response: Dict) -> None:
        """Signal the client that an assistant track is starting."""
        await self.transport.send_event(self._event_payload("trackStart"))

    async def _on_response_done(self, response: Dict) -> None:
        """Signal the client that the assistant track finished."""
        await self.transport.send_event(self._event_payload("trackEnd"))

    async def _on_error(self, error: Dict) -> None:
        """Forward a service error to the client."""
        await self.transport.send_event(
            self._event_payload("error", sender="realtime", error=str(error))
        )

    def _get_timestamp_ms(self) -> int:
        """Current wall-clock time in milliseconds since the epoch."""
        import time
        return int(time.time() * 1000)

    @property
    def is_speaking(self) -> bool:
        """Whether server-side VAD currently reports the user speaking."""
        return self._is_speaking
|
||||||
317
services/siliconflow_asr.py
Normal file
317
services/siliconflow_asr.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
"""SiliconFlow ASR (Automatic Speech Recognition) Service.
|
||||||
|
|
||||||
|
Uses the SiliconFlow API for speech-to-text transcription.
|
||||||
|
API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcriptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import wave
|
||||||
|
from typing import AsyncIterator, Optional, Callable, Awaitable
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aiohttp
|
||||||
|
AIOHTTP_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
AIOHTTP_AVAILABLE = False
|
||||||
|
logger.warning("aiohttp not available - SiliconFlowASRService will not work")
|
||||||
|
|
||||||
|
from services.base import BaseASRService, ASRResult, ServiceState
|
||||||
|
|
||||||
|
|
||||||
|
class SiliconFlowASRService(BaseASRService):
    """
    SiliconFlow ASR service for speech-to-text transcription.

    Features:
    - Buffers incoming audio chunks
    - Provides interim transcriptions periodically (for streaming to client)
    - Final transcription on EOU

    API Details:
    - Endpoint: POST https://api.siliconflow.cn/v1/audio/transcriptions
    - Models: FunAudioLLM/SenseVoiceSmall (default), TeleAI/TeleSpeechASR
    - Input: Audio file (multipart/form-data)
    - Output: {"text": "transcribed text"}
    """

    # Supported models: short aliases mapped to full API model names.
    MODELS = {
        "sensevoice": "FunAudioLLM/SenseVoiceSmall",
        "telespeech": "TeleAI/TeleSpeechASR",
    }

    API_URL = "https://api.siliconflow.cn/v1/audio/transcriptions"

    def __init__(
        self,
        api_key: str,
        model: str = "FunAudioLLM/SenseVoiceSmall",
        sample_rate: int = 16000,
        language: str = "auto",
        interim_interval_ms: int = 500,  # How often to send interim results
        min_audio_for_interim_ms: int = 300,  # Min audio before first interim
        on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None
    ):
        """
        Initialize SiliconFlow ASR service.

        Args:
            api_key: SiliconFlow API key
            model: ASR model name or alias (see ``MODELS``)
            sample_rate: Audio sample rate (16000 recommended)
            language: Language code (auto for automatic detection)
            interim_interval_ms: How often to generate interim transcriptions
            min_audio_for_interim_ms: Minimum audio duration before first interim
            on_transcript: Callback for transcription results (text, is_final)

        Raises:
            RuntimeError: if aiohttp is not installed.
        """
        super().__init__(sample_rate=sample_rate, language=language)

        if not AIOHTTP_AVAILABLE:
            raise RuntimeError("aiohttp is required for SiliconFlowASRService")

        self.api_key = api_key
        # Resolve alias to full model name; unknown names pass through as-is.
        self.model = self.MODELS.get(model.lower(), model)
        self.interim_interval_ms = interim_interval_ms
        self.min_audio_for_interim_ms = min_audio_for_interim_ms
        self.on_transcript = on_transcript

        # HTTP session (created in connect())
        self._session: Optional[aiohttp.ClientSession] = None

        # Audio buffer: raw 16-bit mono PCM accumulated via send_audio().
        self._audio_buffer: bytes = b""
        self._current_text: str = ""        # latest transcription of the buffer
        self._last_interim_time: float = 0  # time.time() of last interim pass

        # Transcript queue for async iteration (consumed by receive_transcripts)
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()

        # Background task for interim results
        self._interim_task: Optional[asyncio.Task] = None
        self._running = False

        logger.info(f"SiliconFlowASRService initialized with model: {self.model}")

    async def connect(self) -> None:
        """Create the HTTP session with auth headers and mark the service running."""
        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}"
            }
        )
        self._running = True
        self.state = ServiceState.CONNECTED
        logger.info("SiliconFlowASRService connected")

    async def disconnect(self) -> None:
        """Disconnect and cleanup: stop the interim task, close the HTTP
        session, and drop any buffered audio/text."""
        self._running = False

        if self._interim_task:
            self._interim_task.cancel()
            try:
                await self._interim_task
            except asyncio.CancelledError:
                pass
            self._interim_task = None

        if self._session:
            await self._session.close()
            self._session = None

        self._audio_buffer = b""
        self._current_text = ""
        self.state = ServiceState.DISCONNECTED
        logger.info("SiliconFlowASRService disconnected")

    async def send_audio(self, audio: bytes) -> None:
        """
        Buffer incoming audio data.

        Args:
            audio: PCM audio data (16-bit, mono)
        """
        self._audio_buffer += audio

    async def transcribe_buffer(self, is_final: bool = False) -> Optional[str]:
        """
        Transcribe current audio buffer via the SiliconFlow API.

        The buffer is NOT cleared here — see get_final_transcription().
        On success the text is stored in ``_current_text``, forwarded to the
        ``on_transcript`` callback, and queued for receive_transcripts().

        Args:
            is_final: Whether this is the final transcription

        Returns:
            Transcribed text, or None if not enough audio / not connected /
            API error.
        """
        if not self._session:
            logger.warning("ASR session not connected")
            return None

        # Check minimum audio duration (2 bytes per 16-bit mono sample).
        audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000

        if not is_final and audio_duration_ms < self.min_audio_for_interim_ms:
            return None

        if audio_duration_ms < 100:  # Less than 100ms - too short
            return None

        try:
            # Convert raw PCM to WAV in memory (API expects an audio file).
            wav_buffer = io.BytesIO()
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(self.sample_rate)
                wav_file.writeframes(self._audio_buffer)

            wav_buffer.seek(0)
            wav_data = wav_buffer.read()

            # Send to API as multipart/form-data.
            # NOTE(review): self.language is stored by the base class but is
            # never sent in the request — confirm whether the endpoint
            # accepts a language field.
            form_data = aiohttp.FormData()
            form_data.add_field(
                'file',
                wav_data,
                filename='audio.wav',
                content_type='audio/wav'
            )
            form_data.add_field('model', self.model)

            async with self._session.post(self.API_URL, data=form_data) as response:
                if response.status == 200:
                    result = await response.json()
                    text = result.get("text", "").strip()

                    if text:
                        self._current_text = text

                        # Notify via callback
                        if self.on_transcript:
                            await self.on_transcript(text, is_final)

                        # Queue result for async iteration
                        await self._transcript_queue.put(
                            ASRResult(text=text, is_final=is_final)
                        )

                        logger.debug(f"ASR {'final' if is_final else 'interim'}: {text[:50]}...")
                    return text
                else:
                    error_text = await response.text()
                    logger.error(f"ASR API error {response.status}: {error_text}")
                    return None

        except Exception as e:
            logger.error(f"ASR transcription error: {e}")
            return None

    async def get_final_transcription(self) -> str:
        """
        Get final transcription and clear buffer.

        Call this when EOU is detected. Falls back to the last interim text
        if the final API call yields nothing.

        Returns:
            Final transcribed text (may be empty)
        """
        # Transcribe full buffer as final
        text = await self.transcribe_buffer(is_final=True)

        # Clear buffer
        result = text or self._current_text
        self._audio_buffer = b""
        self._current_text = ""

        return result

    def get_and_clear_text(self) -> str:
        """
        Get accumulated text and clear buffers (no API call).

        Compatible with BufferedASRService interface.
        """
        text = self._current_text
        self._current_text = ""
        self._audio_buffer = b""
        return text

    def get_audio_buffer(self) -> bytes:
        """Get current audio buffer (raw PCM)."""
        return self._audio_buffer

    def get_audio_duration_ms(self) -> float:
        """Get current audio buffer duration in milliseconds
        (2 bytes per 16-bit mono sample)."""
        return len(self._audio_buffer) / (self.sample_rate * 2) * 1000

    def clear_buffer(self) -> None:
        """Clear audio and text buffers."""
        self._audio_buffer = b""
        self._current_text = ""

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """
        Async iterator for transcription results.

        Polls the internal queue with a short timeout so the loop can
        observe ``_running`` going False and terminate.

        Yields:
            ASRResult with text and is_final flag
        """
        while self._running:
            try:
                result = await asyncio.wait_for(
                    self._transcript_queue.get(),
                    timeout=0.1
                )
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

    async def start_interim_transcription(self) -> None:
        """
        Start background task for interim transcriptions.

        This periodically transcribes buffered audio for
        real-time feedback to the user. Idempotent: does nothing if a
        task is already running.
        """
        if self._interim_task and not self._interim_task.done():
            return

        self._interim_task = asyncio.create_task(self._interim_loop())

    async def stop_interim_transcription(self) -> None:
        """Stop interim transcription task (if any)."""
        if self._interim_task:
            self._interim_task.cancel()
            try:
                await self._interim_task
            except asyncio.CancelledError:
                pass
            self._interim_task = None

    async def _interim_loop(self) -> None:
        """Background loop for interim transcriptions.

        Every ``interim_interval_ms`` it re-transcribes the buffer (when at
        least ``min_audio_for_interim_ms`` of audio has accumulated).
        Errors are logged and the loop keeps running.
        """
        import time

        while self._running:
            try:
                await asyncio.sleep(self.interim_interval_ms / 1000)

                # Check if we have enough new audio
                current_time = time.time()
                time_since_last = (current_time - self._last_interim_time) * 1000

                if time_since_last >= self.interim_interval_ms:
                    audio_duration = self.get_audio_duration_ms()

                    if audio_duration >= self.min_audio_for_interim_ms:
                        await self.transcribe_buffer(is_final=False)
                        self._last_interim_time = current_time

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Interim transcription error: {e}")
|
||||||
255
services/siliconflow_tts.py
Normal file
255
services/siliconflow_tts.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""SiliconFlow TTS Service with streaming support.
|
||||||
|
|
||||||
|
Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency
|
||||||
|
text-to-speech synthesis with streaming.
|
||||||
|
|
||||||
|
API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from typing import AsyncIterator, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from services.base import BaseTTSService, TTSChunk, ServiceState
|
||||||
|
|
||||||
|
|
||||||
|
class SiliconFlowTTSService(BaseTTSService):
    """SiliconFlow TTS service with streaming support.

    Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models.
    """

    # Short voice aliases mapped to full SiliconFlow voice identifiers.
    VOICES = {
        "alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
        "anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
        "bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
        "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
        "charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
        "claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
        "david": "FunAudioLLM/CosyVoice2-0.5B:david",
        "diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        voice: str = "anna",
        model: str = "FunAudioLLM/CosyVoice2-0.5B",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize SiliconFlow TTS service.

        Args:
            api_key: SiliconFlow API key (defaults to SILICONFLOW_API_KEY env var)
            voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
                or a full voice identifier, which is passed through unchanged
            model: Model name
            sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
            speed: Speech speed (0.25 to 4.0)
        """
        # Resolve short alias to the full voice identifier; unknown names pass through.
        full_voice = self.VOICES.get(voice, voice)
        super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)

        self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
        self.model = model
        self.api_url = "https://api.siliconflow.cn/v1/audio/speech"

        self._session: Optional[aiohttp.ClientSession] = None
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Initialize HTTP session.

        Raises:
            ValueError: If no API key was supplied or found in the environment.
        """
        if not self.api_key:
            raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")

        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}")

    async def disconnect(self) -> None:
        """Close HTTP session."""
        if self._session:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("SiliconFlow TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize complete audio for text by draining the streaming API."""
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio
        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio; the last yielded chunk always has
            is_final=True.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        if not self._session:
            raise RuntimeError("TTS service not connected")

        if not text.strip():
            return

        self._cancel_event.clear()

        payload = {
            "model": self.model,
            "input": text,
            "voice": self.voice,
            "response_format": "pcm",
            "sample_rate": self.sample_rate,
            "stream": True,
            "speed": self.speed
        }

        try:
            async with self._session.post(self.api_url, json=payload) as response:
                if response.status != 200:
                    # Best-effort: log and produce no audio rather than crash the turn.
                    error_text = await response.text()
                    logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}")
                    return

                # 100ms of 16-bit mono PCM per chunk.
                chunk_size = self.sample_rate * 2 // 10
                buffer = b""
                # Hold one full chunk back so the last chunk can be marked
                # is_final=True even when the stream length is an exact multiple
                # of chunk_size (previously no chunk was marked final in that case).
                pending: Optional[bytes] = None

                async for chunk in response.content.iter_any():
                    if self._cancel_event.is_set():
                        logger.info("TTS synthesis cancelled")
                        return

                    buffer += chunk

                    while len(buffer) >= chunk_size:
                        if pending is not None:
                            yield TTSChunk(
                                audio=pending,
                                sample_rate=self.sample_rate,
                                is_final=False
                            )
                        pending = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]

                # Flush: the held-back chunk is final unless a partial remainder follows.
                if pending is not None:
                    yield TTSChunk(
                        audio=pending,
                        sample_rate=self.sample_rate,
                        is_final=not buffer
                    )
                if buffer:
                    yield TTSChunk(
                        audio=buffer,
                        sample_rate=self.sample_rate,
                        is_final=True
                    )

        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (takes effect at the next stream chunk)."""
        self._cancel_event.set()
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingTTSAdapter:
|
||||||
|
"""
|
||||||
|
Adapter for streaming LLM text to TTS with sentence-level chunking.
|
||||||
|
|
||||||
|
This reduces latency by starting TTS as soon as a complete sentence
|
||||||
|
is received from the LLM, rather than waiting for the full response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Sentence delimiters
|
||||||
|
SENTENCE_ENDS = {'.', '!', '?', '。', '!', '?', ';', '\n'}
|
||||||
|
|
||||||
|
def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
|
||||||
|
self.tts_service = tts_service
|
||||||
|
self.transport = transport
|
||||||
|
self.session_id = session_id
|
||||||
|
self._buffer = ""
|
||||||
|
self._cancel_event = asyncio.Event()
|
||||||
|
self._is_speaking = False
|
||||||
|
|
||||||
|
async def process_text_chunk(self, text_chunk: str) -> None:
|
||||||
|
"""
|
||||||
|
Process a text chunk from LLM and trigger TTS when sentence is complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_chunk: Text chunk from LLM streaming
|
||||||
|
"""
|
||||||
|
if self._cancel_event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
self._buffer += text_chunk
|
||||||
|
|
||||||
|
# Check for sentence completion
|
||||||
|
for i, char in enumerate(self._buffer):
|
||||||
|
if char in self.SENTENCE_ENDS:
|
||||||
|
# Found sentence end, synthesize up to this point
|
||||||
|
sentence = self._buffer[:i+1].strip()
|
||||||
|
self._buffer = self._buffer[i+1:]
|
||||||
|
|
||||||
|
if sentence:
|
||||||
|
await self._speak_sentence(sentence)
|
||||||
|
break
|
||||||
|
|
||||||
|
async def flush(self) -> None:
|
||||||
|
"""Flush remaining buffer."""
|
||||||
|
if self._buffer.strip() and not self._cancel_event.is_set():
|
||||||
|
await self._speak_sentence(self._buffer.strip())
|
||||||
|
self._buffer = ""
|
||||||
|
|
||||||
|
async def _speak_sentence(self, text: str) -> None:
|
||||||
|
"""Synthesize and send a sentence."""
|
||||||
|
if not text or self._cancel_event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
self._is_speaking = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for chunk in self.tts_service.synthesize_stream(text):
|
||||||
|
if self._cancel_event.is_set():
|
||||||
|
break
|
||||||
|
await self.transport.send_audio(chunk.audio)
|
||||||
|
await asyncio.sleep(0.01) # Prevent flooding
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TTS speak error: {e}")
|
||||||
|
finally:
|
||||||
|
self._is_speaking = False
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
"""Cancel ongoing speech."""
|
||||||
|
self._cancel_event.set()
|
||||||
|
self._buffer = ""
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset for new turn."""
|
||||||
|
self._cancel_event.clear()
|
||||||
|
self._buffer = ""
|
||||||
|
self._is_speaking = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_speaking(self) -> bool:
|
||||||
|
return self._is_speaking
|
||||||
271
services/tts.py
Normal file
271
services/tts.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
"""TTS (Text-to-Speech) Service implementations.
|
||||||
|
|
||||||
|
Provides multiple TTS backend options including edge-tts (free)
|
||||||
|
and placeholder for cloud services.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import asyncio
|
||||||
|
import struct
|
||||||
|
from typing import AsyncIterator, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from services.base import BaseTTSService, TTSChunk, ServiceState
|
||||||
|
|
||||||
|
# Try to import edge-tts
|
||||||
|
try:
|
||||||
|
import edge_tts
|
||||||
|
EDGE_TTS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
EDGE_TTS_AVAILABLE = False
|
||||||
|
logger.warning("edge-tts not available - EdgeTTS service will be disabled")
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeTTSService(BaseTTSService):
    """
    Microsoft Edge TTS service.

    Uses edge-tts library for free, high-quality speech synthesis.
    Supports streaming for low-latency playback.
    """

    # Voice mapping for common language codes -> default neural voice.
    VOICE_MAP = {
        "en": "en-US-JennyNeural",
        "en-US": "en-US-JennyNeural",
        "en-GB": "en-GB-SoniaNeural",
        "zh": "zh-CN-XiaoxiaoNeural",
        "zh-CN": "zh-CN-XiaoxiaoNeural",
        "zh-TW": "zh-TW-HsiaoChenNeural",
        "ja": "ja-JP-NanamiNeural",
        "ko": "ko-KR-SunHiNeural",
        "fr": "fr-FR-DeniseNeural",
        "de": "de-DE-KatjaNeural",
        "es": "es-ES-ElviraNeural",
    }

    def __init__(
        self,
        voice: str = "en-US-JennyNeural",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize Edge TTS service.

        Args:
            voice: Voice name (e.g., "en-US-JennyNeural") or language code (e.g., "en")
            sample_rate: Target sample rate (will be resampled)
            speed: Speech speed multiplier
        """
        # Resolve voice from language code if needed
        if voice in self.VOICE_MAP:
            voice = self.VOICE_MAP[voice]

        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Edge TTS doesn't require explicit connection.

        Raises:
            RuntimeError: If the edge-tts package is not installed.
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts package not installed")
        self.state = ServiceState.CONNECTED
        logger.info(f"Edge TTS service ready: voice={self.voice}")

    async def disconnect(self) -> None:
        """Edge TTS doesn't require explicit disconnection."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Edge TTS service disconnected")

    def _get_rate_string(self) -> str:
        """Convert speed multiplier to edge-tts rate string ("+0%", "-10%", "+20%")."""
        percentage = int((self.speed - 1.0) * 100)
        if percentage >= 0:
            return f"+{percentage}%"
        return f"{percentage}%"

    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio for text.

        Args:
            text: Text to synthesize

        Returns:
            PCM audio data (16-bit, mono, at self.sample_rate)
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts not available")

        # Collect all chunks from the streaming path.
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio

        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        NOTE: edge-tts outputs MP3; the full MP3 is collected before decoding,
        so first-audio latency is bounded by the whole synthesis, not a chunk.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts not available")

        self._cancel_event.clear()

        try:
            communicate = edge_tts.Communicate(
                text,
                voice=self.voice,
                rate=self._get_rate_string()
            )

            # Collect MP3 chunks; conversion to PCM happens below.
            mp3_data = b""

            async for chunk in communicate.stream():
                # Check for cancellation between network chunks.
                if self._cancel_event.is_set():
                    logger.info("TTS synthesis cancelled")
                    return

                if chunk["type"] == "audio":
                    mp3_data += chunk["data"]

            # Convert MP3 to PCM
            if mp3_data:
                pcm_data = await self._convert_mp3_to_pcm(mp3_data)
                if pcm_data:
                    # Yield in 100ms chunks for streaming playback.
                    chunk_size = self.sample_rate * 2 // 10
                    for i in range(0, len(pcm_data), chunk_size):
                        if self._cancel_event.is_set():
                            return

                        chunk_data = pcm_data[i:i + chunk_size]
                        yield TTSChunk(
                            audio=chunk_data,
                            sample_rate=self.sample_rate,
                            is_final=(i + chunk_size >= len(pcm_data))
                        )

        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def _convert_mp3_to_pcm(self, mp3_data: bytes) -> bytes:
        """
        Convert MP3 audio to 16-bit mono PCM at self.sample_rate.

        Tries pydub first (requires ffmpeg), then falls back to invoking
        ffmpeg directly. Returns b"" on conversion failure.
        """
        try:
            from pydub import AudioSegment

            # Load MP3 from bytes and normalize to the target PCM format.
            audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))
            audio = audio.set_frame_rate(self.sample_rate)
            audio = audio.set_channels(1)
            audio = audio.set_sample_width(2)  # 16-bit

            return audio.raw_data

        except ImportError:
            logger.warning("pydub not available, trying fallback")
            return await self._ffmpeg_convert(mp3_data)
        except Exception as e:
            logger.error(f"Audio conversion error: {e}")
            return b""

    async def _ffmpeg_convert(self, mp3_data: bytes) -> bytes:
        """Convert MP3 to PCM using an ffmpeg subprocess. Returns b"" on failure."""
        try:
            process = await asyncio.create_subprocess_exec(
                "ffmpeg",
                "-i", "pipe:0",
                "-f", "s16le",
                "-acodec", "pcm_s16le",
                "-ar", str(self.sample_rate),
                "-ac", "1",
                "pipe:1",
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.DEVNULL
            )

            stdout, _ = await process.communicate(input=mp3_data)

            # Previously the exit status was ignored, so a failed ffmpeg run
            # could return empty/partial bytes as if they were valid PCM.
            if process.returncode != 0:
                logger.error(f"ffmpeg exited with code {process.returncode}")
                return b""
            return stdout

        except Exception as e:
            logger.error(f"ffmpeg conversion error: {e}")
            return b""

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (takes effect at the next checkpoint)."""
        self._cancel_event.set()
|
||||||
|
|
||||||
|
|
||||||
|
class MockTTSService(BaseTTSService):
    """
    Mock TTS service for testing without actual synthesis.

    Produces silence whose duration is proportional to the text length.
    """

    def __init__(
        self,
        voice: str = "mock",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)

    async def connect(self) -> None:
        # No real backend to reach; just flip the state.
        self.state = ServiceState.CONNECTED
        logger.info("Mock TTS service connected")

    async def disconnect(self) -> None:
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Generate silence based on text length (~100ms per word)."""
        word_total = len(text.split())
        sample_total = int(self.sample_rate * (word_total * 100) / 1000)
        # 16-bit mono silence: two zero bytes per sample.
        return b"\x00" * (sample_total * 2)

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Yield the synthesized silence in 100ms chunks."""
        pcm = await self.synthesize(text)

        step = self.sample_rate * 2 // 10  # bytes per 100ms chunk
        for start in range(0, len(pcm), step):
            piece = pcm[start:start + step]
            yield TTSChunk(
                audio=piece,
                sample_rate=self.sample_rate,
                is_final=(start + step >= len(pcm))
            )
            await asyncio.sleep(0.05)  # Simulate processing time
|
||||||
160
test_client.py
160
test_client.py
@@ -1,160 +0,0 @@
|
|||||||
"""
|
|
||||||
WebSocket Test Client
|
|
||||||
|
|
||||||
Tests the WebSocket server with sine wave audio generation.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python test_client.py
|
|
||||||
python test_client.py --url ws://localhost:8000/ws
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import aiohttp
|
|
||||||
import json
|
|
||||||
import struct
|
|
||||||
import math
|
|
||||||
import argparse
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# Configuration
SERVER_URL = "ws://localhost:8000/ws"
SAMPLE_RATE = 16000
FREQUENCY = 440  # 440Hz sine wave
CHUNK_DURATION_MS = 20
CHUNK_SIZE_BYTES = int(SAMPLE_RATE * 2 * (CHUNK_DURATION_MS / 1000.0))  # 640 bytes


def generate_sine_wave(duration_ms=1000, frequency=FREQUENCY, sample_rate=SAMPLE_RATE):
    """
    Generate sine wave audio data.

    Format: mono, 16-bit little-endian PCM at *sample_rate*.

    Args:
        duration_ms: Length of the tone in milliseconds.
        frequency: Tone frequency in Hz (default: module FREQUENCY).
        sample_rate: Samples per second (default: module SAMPLE_RATE).

    Returns:
        bytearray of signed 16-bit LE samples (2 bytes per sample).
    """
    num_samples = int(sample_rate * (duration_ms / 1000.0))
    audio_data = bytearray()

    for x in range(num_samples):
        # Amplitude scaled to the 16-bit signed range; int() truncates toward zero.
        value = int(32767.0 * math.sin(2 * math.pi * frequency * x / sample_rate))
        # Pack as little-endian 16-bit signed integer
        audio_data.extend(struct.pack('<h', value))

    return audio_data
|
|
||||||
|
|
||||||
|
|
||||||
async def receive_loop(ws, session_id):
    """
    Listen for incoming messages from the server and print them.

    Handles text (JSON events), binary (audio), close and error frames.
    """
    print("👂 Listening for server responses...")
    async for msg in ws:
        stamp = datetime.now().strftime("%H:%M:%S")

        if msg.type == aiohttp.WSMsgType.TEXT:
            try:
                payload = json.loads(msg.data)
            except json.JSONDecodeError:
                print(f"[{stamp}] 📨 Text: {msg.data[:100]}...")
            else:
                print(f"[{stamp}] 📨 Event: {payload.get('event', 'Unknown')}")
                print(f" {json.dumps(payload, indent=2)}")

        elif msg.type == aiohttp.WSMsgType.BINARY:
            # Received audio chunk back from the server.
            print(f"[{stamp}] 🔊 Audio: {len(msg.data)} bytes")

        elif msg.type == aiohttp.WSMsgType.CLOSED:
            print(f"\n[{stamp}] ❌ Connection closed")
            break

        elif msg.type == aiohttp.WSMsgType.ERROR:
            print(f"\n[{stamp}] ⚠️ WebSocket error")
            break
|
|
||||||
|
|
||||||
|
|
||||||
async def send_audio_loop(ws):
    """
    Stream sine wave audio to the server in real-time-paced chunks.
    """
    print("🎙️ Starting audio stream (sine wave)...")

    # Pre-generate 5 seconds of audio, then send it one chunk at a time.
    pcm = generate_sine_wave(5000)
    pause_s = CHUNK_DURATION_MS / 1000.0

    for offset in range(0, len(pcm), CHUNK_SIZE_BYTES):
        piece = pcm[offset:offset + CHUNK_SIZE_BYTES]

        await ws.send_bytes(piece)
        print(f"📤 Sent audio chunk: {len(piece)} bytes", end="\r")

        # Sleep to simulate real-time (20ms per chunk).
        await asyncio.sleep(pause_s)

    print("\n✅ Finished streaming audio")
|
|
||||||
|
|
||||||
|
|
||||||
async def run_client(url):
    """
    Run the WebSocket test client: handshake, ping, then duplex audio.
    """
    session = aiohttp.ClientSession()

    try:
        print(f"🔌 Connecting to {url}...")

        async with session.ws_connect(url) as ws:
            print("✅ Connected!")
            print()

            # Handshake: tell the server which codec/sample rate we will send.
            await ws.send_json({
                "command": "invite",
                "option": {
                    "codec": "pcm",
                    "samplerate": SAMPLE_RATE
                }
            })
            print("📤 Sent invite command")
            print()

            # Liveness check.
            await ws.send_json({"command": "ping"})
            print("📤 Sent ping command")
            print()

            # Give the server a moment to answer the control messages.
            await asyncio.sleep(1)

            # Stream audio and listen for replies concurrently.
            await asyncio.gather(
                receive_loop(ws, None),
                send_audio_loop(ws)
            )

    except aiohttp.ClientConnectorError:
        print(f"❌ Connection failed. Is the server running at {url}?")
        print(f" Start server with: python main.py")

    except Exception as e:
        print(f"❌ Error: {e}")

    finally:
        await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="WebSocket Audio Test Client")
|
|
||||||
parser.add_argument("--url", default=SERVER_URL, help="WebSocket URL")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
|
||||||
asyncio.run(run_client(args.url))
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n👋 Client stopped.")
|
|
||||||
1
utils/__init__.py
Normal file
1
utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Utilities Package"""
|
||||||
83
utils/logging.py
Normal file
83
utils/logging.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Logging configuration utilities."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from loguru import logger
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(
    log_level: str = "INFO",
    log_format: str = "text",
    log_to_file: bool = True,
    log_dir: str = "logs"
):
    """
    Configure structured logging with loguru.

    Installs a console sink and, optionally, a daily-rotating file sink,
    replacing loguru's default handler.

    Args:
        log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
        log_format: Format type (json or text)
        log_to_file: Whether to log to file
        log_dir: Directory for log files

    Returns:
        The configured loguru logger.
    """
    # Drop loguru's default stderr handler before adding ours.
    logger.remove()

    json_mode = log_format == "json"

    # Console sink: serialized JSON or colorized text.
    if json_mode:
        logger.add(
            sys.stdout,
            format="{message}",
            level=log_level,
            serialize=True,
            colorize=False
        )
    else:
        logger.add(
            sys.stdout,
            format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
            level=log_level,
            colorize=True
        )

    # File sink: one file per day, kept a week, compressed on rotation.
    if log_to_file:
        log_path = Path(log_dir)
        log_path.mkdir(exist_ok=True)

        retention_opts = {
            "rotation": "1 day",
            "retention": "7 days",
            "compression": "zip",
        }
        sink_file = log_path / "active_call_{time:YYYY-MM-DD}.log"

        if json_mode:
            logger.add(
                sink_file,
                format="{message}",
                level=log_level,
                serialize=True,
                **retention_opts
            )
        else:
            logger.add(
                sink_file,
                format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
                level=log_level,
                **retention_opts
            )

    return logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str | None = None):
    """
    Get a logger instance.

    Args:
        name: Logger name (optional); when given, it is bound as extra
            context on the returned logger.

    Returns:
        Logger instance
    """
    # NOTE: annotation fixed from `str = None` — None is the default, so the
    # parameter type is `str | None` (repo targets Python 3.11+).
    if name:
        return logger.bind(name=name)
    return logger
|
||||||
Reference in New Issue
Block a user