Init commit

This commit is contained in:
Xin Wang
2026-02-17 10:39:23 +08:00
commit 30eb4397c2
56 changed files with 11983 additions and 0 deletions

92
.env.example Normal file
View File

@@ -0,0 +1,92 @@
# -----------------------------------------------------------------------------
# Engine .env example (safe template)
# Notes:
# - Never commit real API keys.
# - Start with defaults below, then tune from logs.
# -----------------------------------------------------------------------------
# Server
HOST=0.0.0.0
PORT=8000
# EXTERNAL_IP=1.2.3.4
# Backend bridge (optional)
BACKEND_URL=http://127.0.0.1:8100
BACKEND_TIMEOUT_SEC=10
HISTORY_DEFAULT_USER_ID=1
# Audio
SAMPLE_RATE=16000
# 20ms is recommended for VAD stability and latency.
# 100ms works but usually worsens start-of-speech accuracy.
CHUNK_SIZE_MS=20
DEFAULT_CODEC=pcm
MAX_AUDIO_BUFFER_SECONDS=30
# VAD / EOU
VAD_TYPE=silero
VAD_MODEL_PATH=data/vad/silero_vad.onnx
# Higher = stricter speech detection (fewer false positives, more misses).
VAD_THRESHOLD=0.5
# Require this much continuous speech before utterance can be valid.
VAD_MIN_SPEECH_DURATION_MS=100
# Silence duration required to finalize one user turn.
VAD_EOU_THRESHOLD_MS=800
# LLM
OPENAI_API_KEY=your_openai_api_key_here
# Optional for OpenAI-compatible providers.
# OPENAI_API_URL=https://api.openai.com/v1
LLM_MODEL=gpt-4o-mini
LLM_TEMPERATURE=0.7
# TTS
# edge: no API key needed
# openai_compatible: compatible with SiliconFlow-style endpoints
TTS_PROVIDER=openai_compatible
TTS_VOICE=anna
TTS_SPEED=1.0
# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible)
SILICONFLOW_API_KEY=your_siliconflow_api_key_here
SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B
SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall
# ASR
ASR_PROVIDER=openai_compatible
# Interim cadence and minimum audio before interim decode.
ASR_INTERIM_INTERVAL_MS=500
ASR_MIN_AUDIO_MS=300
# ASR start gate: ignore micro-noise, then commit to one turn once started.
ASR_START_MIN_SPEECH_MS=160
# Pre-roll protects beginning phonemes.
ASR_PRE_SPEECH_MS=240
# Tail silence protects ending phonemes.
ASR_FINAL_TAIL_MS=120
# Duplex behavior
DUPLEX_ENABLED=true
# DUPLEX_GREETING=Hello! How can I help you today?
DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
# Barge-in (user interrupting assistant)
# Min user speech duration needed to interrupt assistant audio.
BARGE_IN_MIN_DURATION_MS=200
# Allowed silence during potential barge-in (ms) before reset.
BARGE_IN_SILENCE_TOLERANCE_MS=60
# Logging
LOG_LEVEL=INFO
# json is better for production/observability; text is easier locally.
LOG_FORMAT=json
# WebSocket behavior
INACTIVITY_TIMEOUT_SEC=60
HEARTBEAT_INTERVAL_SEC=50
WS_PROTOCOL_VERSION=v1
# WS_API_KEY=replace_with_shared_secret
WS_REQUIRE_AUTH=false
# CORS / ICE (JSON strings)
CORS_ORIGINS=["http://localhost:3000","http://localhost:8080"]
ICE_SERVERS=[{"urls":"stun:stun.l.google.com:19302"}]

148
.gitignore vendored Normal file
View File

@@ -0,0 +1,148 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# poetry
poetry.lock
# pdm
.pdm.toml
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
# Project specific
recordings/
logs/
running/

31
README.md Normal file
View File

@@ -0,0 +1,31 @@
# py-active-call-cc
Python Active-Call: real-time audio streaming with WebSocket and WebRTC.
This repo contains a Python 3.11+ codebase for building low-latency voice
pipelines (capture, stream, and process audio) using WebRTC and WebSockets.
It is currently in an early, experimental stage.
# Usage
Start the server:
```
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
Test with the example clients:
```
python examples/test_websocket.py
```
```
python mic_client.py
```
## WS Protocol
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
See `docs/ws_v1_schema.md`.

1
app/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Active-Call Application Package"""

211
app/backend_client.py Normal file
View File

@@ -0,0 +1,211 @@
"""Backend API client for assistant config and history persistence."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import aiohttp
from loguru import logger
from app.config import settings
async def fetch_assistant_config(assistant_id: str) -> Optional[Dict[str, Any]]:
    """Fetch assistant config payload from backend API.

    Expected response shape:
    {
        "assistant": {...},
        "voice": {...} | null
    }
    Returns None when the backend is unconfigured, the assistant is unknown,
    the payload is malformed, or any network/HTTP error occurs.
    """
    if not settings.backend_url:
        logger.warning("BACKEND_URL not set; skipping assistant config fetch")
        return None
    base = settings.backend_url.rstrip('/')
    url = f"{base}/api/assistants/{assistant_id}/config"
    timeout = aiohttp.ClientTimeout(total=settings.backend_timeout_sec)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as http:
            async with http.get(url) as resp:
                if resp.status == 404:
                    logger.warning(f"Assistant config not found: {assistant_id}")
                    return None
                resp.raise_for_status()
                payload = await resp.json()
                if isinstance(payload, dict):
                    return payload
                logger.warning("Assistant config payload is not a dict; ignoring")
                return None
    except Exception as exc:
        logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
        return None
def _backend_base_url() -> Optional[str]:
    """Return the backend base URL without a trailing slash, or None if unset."""
    raw = settings.backend_url
    return raw.rstrip("/") if raw else None
def _timeout() -> aiohttp.ClientTimeout:
    # Shared per-request timeout for all backend bridge calls
    # (total seconds, taken from BACKEND_TIMEOUT_SEC).
    return aiohttp.ClientTimeout(total=settings.backend_timeout_sec)
async def create_history_call_record(
    *,
    user_id: int,
    assistant_id: Optional[str],
    source: str = "debug",
) -> Optional[str]:
    """Create a call record via backend history API and return call_id.

    Returns None when the backend is unconfigured, the response carries no
    id, or any network/HTTP error occurs (logged as a warning, never raised).
    """
    base_url = _backend_base_url()
    if not base_url:
        return None
    body: Dict[str, Any] = {
        "user_id": user_id,
        "assistant_id": assistant_id,
        "source": source,
        "status": "connected",
    }
    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.post(f"{base_url}/api/history", json=body) as resp:
                resp.raise_for_status()
                data = await resp.json()
                call_id = str((data or {}).get("id") or "")
                return call_id or None
    except Exception as exc:
        logger.warning(f"Failed to create history call record: {exc}")
        return None
async def add_history_transcript(
    *,
    call_id: str,
    turn_index: int,
    speaker: str,
    content: str,
    start_ms: int,
    end_ms: int,
    confidence: Optional[float] = None,
    duration_ms: Optional[int] = None,
) -> bool:
    """Append a transcript segment to backend history.

    Args:
        call_id: Backend call record id (from create_history_call_record).
        turn_index: Turn ordinal within the call — semantics set by the
            caller; not validated here.
        speaker: Speaker label (presumably "user"/"assistant" — not
            validated here).
        content: Transcript text for the segment.
        start_ms: Segment start offset in milliseconds.
        end_ms: Segment end offset in milliseconds.
        confidence: Optional ASR confidence score.
        duration_ms: Optional explicit duration in milliseconds.

    Returns:
        True if the backend accepted the segment; False on missing
        config/call_id or any network/HTTP failure (logged, never raised).
    """
    base_url = _backend_base_url()
    if not base_url or not call_id:
        return False
    url = f"{base_url}/api/history/{call_id}/transcripts"
    payload: Dict[str, Any] = {
        "turn_index": turn_index,
        "speaker": speaker,
        "content": content,
        "confidence": confidence,
        "start_ms": start_ms,
        "end_ms": end_ms,
        "duration_ms": duration_ms,
    }
    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as session:
            async with session.post(url, json=payload) as resp:
                resp.raise_for_status()
                return True
    except Exception as exc:
        logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}")
        return False
async def finalize_history_call_record(
    *,
    call_id: str,
    status: str,
    duration_seconds: int,
) -> bool:
    """Finalize a call record with status and duration.

    Returns True on success; False when backend/call_id is missing or the
    request fails (failure is logged, never raised).
    """
    base_url = _backend_base_url()
    if not (base_url and call_id):
        return False
    body: Dict[str, Any] = {
        "status": status,
        "duration_seconds": duration_seconds,
    }
    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.put(f"{base_url}/api/history/{call_id}", json=body) as resp:
                resp.raise_for_status()
                return True
    except Exception as exc:
        logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
        return False
async def search_knowledge_context(
    *,
    kb_id: str,
    query: str,
    n_results: int = 5,
) -> List[Dict[str, Any]]:
    """Search backend knowledge base and return retrieval results.

    Always returns a (possibly empty) list of dicts; malformed payloads,
    unknown knowledge bases, and network errors all degrade to [].
    """
    base_url = _backend_base_url()
    if not base_url:
        return []
    if not kb_id or not query.strip():
        return []
    # Clamp to a sane positive count; non-numeric input falls back to 5.
    try:
        safe_n_results = max(1, int(n_results))
    except (TypeError, ValueError):
        safe_n_results = 5
    body: Dict[str, Any] = {
        "kb_id": kb_id,
        "query": query,
        "nResults": safe_n_results,
    }
    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.post(f"{base_url}/api/knowledge/search", json=body) as resp:
                if resp.status == 404:
                    logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
                    return []
                resp.raise_for_status()
                data = await resp.json()
                if not isinstance(data, dict):
                    return []
                hits = data.get("results", [])
                if not isinstance(hits, list):
                    return []
                return [hit for hit in hits if isinstance(hit, dict)]
    except Exception as exc:
        logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
        return []
async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]:
    """Fetch tool resource configuration from backend API.

    Returns the payload dict, or None when the backend is unconfigured, the
    tool is unknown (404), the payload is not a dict, or the request fails.
    """
    base_url = _backend_base_url()
    if not (base_url and tool_id):
        return None
    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.get(f"{base_url}/api/tools/resources/{tool_id}") as resp:
                if resp.status == 404:
                    return None
                resp.raise_for_status()
                data = await resp.json()
                return data if isinstance(data, dict) else None
    except Exception as exc:
        logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}")
        return None

154
app/config.py Normal file
View File

@@ -0,0 +1,154 @@
"""Configuration management using Pydantic settings."""
from typing import List, Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
import json
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values come from the process environment and an optional `.env` file;
    matching is case-insensitive and unknown variables are ignored.
    """
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore"
    )
    # Server Configuration
    host: str = Field(default="0.0.0.0", description="Server host address")
    port: int = Field(default=8000, description="Server port")
    external_ip: Optional[str] = Field(default=None, description="External IP for NAT traversal")
    # Audio Configuration
    sample_rate: int = Field(default=16000, description="Audio sample rate in Hz")
    chunk_size_ms: int = Field(default=20, description="Audio chunk duration in milliseconds")
    default_codec: str = Field(default="pcm", description="Default audio codec")
    max_audio_buffer_seconds: int = Field(
        default=30,
        description="Maximum buffered user audio duration kept in memory for current turn"
    )
    # VAD Configuration
    vad_type: str = Field(default="silero", description="VAD algorithm type")
    vad_model_path: str = Field(default="data/vad/silero_vad.onnx", description="Path to VAD model")
    vad_threshold: float = Field(default=0.5, description="VAD detection threshold")
    vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
    vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")
    # OpenAI / LLM Configuration
    openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
    openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)")
    llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
    llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")
    # TTS Configuration
    tts_provider: str = Field(
        default="openai_compatible",
        description="TTS provider (edge, openai_compatible; siliconflow alias supported)"
    )
    tts_voice: str = Field(default="anna", description="TTS voice name")
    tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")
    # SiliconFlow Configuration
    siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
    siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")
    # ASR Configuration
    asr_provider: str = Field(
        default="openai_compatible",
        description="ASR provider (openai_compatible, buffered; siliconflow alias supported)"
    )
    siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model")
    asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
    asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
    asr_start_min_speech_ms: int = Field(
        default=160,
        description="Minimum continuous speech duration before ASR capture starts"
    )
    asr_pre_speech_ms: int = Field(
        default=240,
        description="Audio context (ms) prepended before detected speech to avoid clipping first phoneme"
    )
    asr_final_tail_ms: int = Field(
        default=120,
        description="Silence tail (ms) appended before final ASR decode to protect utterance ending"
    )
    # Duplex Pipeline Configuration
    duplex_enabled: bool = Field(default=True, description="Enable duplex voice pipeline")
    duplex_greeting: Optional[str] = Field(default=None, description="Optional greeting message")
    duplex_system_prompt: Optional[str] = Field(
        default="You are a helpful, friendly voice assistant. Keep your responses concise and conversational.",
        description="System prompt for LLM"
    )
    # Barge-in (interruption) Configuration
    barge_in_min_duration_ms: int = Field(
        default=200,
        description="Minimum speech duration (ms) required to trigger barge-in. Lower=more sensitive."
    )
    barge_in_silence_tolerance_ms: int = Field(
        default=60,
        description="How much silence (ms) is tolerated during potential barge-in before reset"
    )
    # Logging
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(default="json", description="Log format (json or text)")
    # CORS
    cors_origins: str = Field(
        default='["http://localhost:3000", "http://localhost:8080"]',
        description="CORS allowed origins"
    )
    # ICE Servers (WebRTC)
    ice_servers: str = Field(
        default='[{"urls": "stun:stun.l.google.com:19302"}]',
        description="ICE servers configuration"
    )
    # WebSocket heartbeat and inactivity
    inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
    heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
    ws_protocol_version: str = Field(default="v1", description="Public WS protocol version")
    ws_api_key: Optional[str] = Field(default=None, description="Optional API key required for WS hello auth")
    ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set")
    # Backend bridge configuration (for call/transcript persistence)
    backend_url: Optional[str] = Field(default=None, description="Backend API base URL (e.g. http://localhost:8787)")
    backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
    history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")

    @property
    def chunk_size_bytes(self) -> int:
        """Calculate chunk size in bytes based on sample rate and duration."""
        # 16-bit (2 bytes) per sample, mono channel
        return int(self.sample_rate * 2 * (self.chunk_size_ms / 1000.0))

    @property
    def cors_origins_list(self) -> List[str]:
        """Parse CORS origins from JSON string.

        Falls back to the localhost defaults when CORS_ORIGINS is not valid
        JSON *or* decodes to something other than a list (e.g. '"foo"' would
        previously leak a plain string to the CORS middleware).
        """
        fallback = ["http://localhost:3000", "http://localhost:8080"]
        try:
            parsed = json.loads(self.cors_origins)
        except json.JSONDecodeError:
            return fallback
        return parsed if isinstance(parsed, list) else fallback

    @property
    def ice_servers_list(self) -> List[dict]:
        """Parse ICE servers from JSON string.

        Falls back to a public STUN server when ICE_SERVERS is not valid
        JSON or decodes to something other than a list.
        """
        fallback = [{"urls": "stun:stun.l.google.com:19302"}]
        try:
            parsed = json.loads(self.ice_servers)
        except json.JSONDecodeError:
            return fallback
        return parsed if isinstance(parsed, list) else fallback
# Global settings instance
# NOTE: constructed at import time, so environment/.env parsing happens once
# per process.
settings = Settings()
def get_settings() -> Settings:
    """Get application settings instance (module-level singleton)."""
    return settings

396
app/main.py Normal file
View File

@@ -0,0 +1,396 @@
"""FastAPI application with WebSocket and WebRTC endpoints."""
import asyncio
import json
import time
import uuid
from pathlib import Path
from typing import Dict, Any, Optional, List
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from loguru import logger
# Try to import aiortc (optional for WebRTC functionality)
try:
from aiortc import RTCPeerConnection, RTCSessionDescription
AIORTC_AVAILABLE = True
except ImportError:
AIORTC_AVAILABLE = False
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
from app.config import settings
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session
from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus
from models.ws_v1 import ev
# Check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
async def heartbeat_and_timeout_task(
    transport: BaseTransport,
    session: Session,
    session_id: str,
    last_received_at: List[float],
    last_heartbeat_at: List[float],
    inactivity_timeout_sec: int,
    heartbeat_interval_sec: int,
) -> None:
    """
    Background watchdog for one connection.

    Sends a heartbeat event roughly every `heartbeat_interval_sec` seconds
    and closes the session when nothing has been received from the client
    for `inactivity_timeout_sec` seconds. The single-element lists are
    mutable cells updated by the receive loop.
    """
    while True:
        await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC)
        if transport.is_closed:
            break
        now = time.monotonic()
        idle_for = now - last_received_at[0]
        if idle_for > inactivity_timeout_sec:
            logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
            await session.cleanup()
            break
        if now - last_heartbeat_at[0] < heartbeat_interval_sec:
            continue
        try:
            await transport.send_event({**ev("heartbeat")})
            last_heartbeat_at[0] = now
        except Exception as e:
            logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
            break
# Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0")
# Bundled single-page debug client served at "/" and "/client".
_WEB_CLIENT_PATH = Path(__file__).resolve().parent.parent / "examples" / "web_client.html"
# Configure CORS (origins come from the CORS_ORIGINS env JSON)
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Active sessions storage (session_id -> Session)
active_sessions: Dict[str, Session] = {}
# Configure logging: drop loguru's default stderr sink, then add a rotating
# file sink plus a plain console sink.
logger.remove()
logger.add(
    "./logs/active_call_{time}.log",
    rotation="1 day",
    retention="7 days",
    level=settings.log_level,
    format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
)
logger.add(
    lambda msg: print(msg, end=""),
    level=settings.log_level,
    format="{time:HH:mm:ss} | {level: <8} | {message}"
)
@app.get("/health")
async def health_check():
    """Liveness probe: reports status plus the current session count."""
    return {
        "status": "healthy",
        "sessions": len(active_sessions),
    }
@app.get("/")
async def web_client_root():
    """Serve the bundled web client page."""
    if _WEB_CLIENT_PATH.exists():
        return FileResponse(_WEB_CLIENT_PATH)
    raise HTTPException(status_code=404, detail="Web client not found")
@app.get("/client")
async def web_client_alias():
    """Alias route serving the same web client page as "/"."""
    if _WEB_CLIENT_PATH.exists():
        return FileResponse(_WEB_CLIENT_PATH)
    raise HTTPException(status_code=404, detail="Web client not found")
@app.get("/iceservers")
async def get_ice_servers():
    """Get ICE servers configuration for WebRTC (parsed from ICE_SERVERS env)."""
    return settings.ice_servers_list
@app.get("/call/lists")
async def list_calls():
    """List all active calls with their state and creation time."""
    calls = []
    for sid, sess in active_sessions.items():
        calls.append({
            "id": sid,
            "state": sess.state,
            "created_at": sess.created_at,
        })
    return {"calls": calls}
@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
    """Kill a specific active call.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    # Atomically claim the session with pop(): `await session.cleanup()`
    # yields control, so the websocket handler's finally-block could remove
    # the entry concurrently and the original check-then-del pattern would
    # raise KeyError on the later `del`.
    session = active_sessions.pop(session_id, None)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    await session.cleanup()
    return True
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for raw audio streaming.
    Accepts mixed text/binary frames:
    - Text frames: JSON commands
    - Binary frames: PCM audio data (16kHz, 16-bit, mono)
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())
    # Create transport and session
    transport = SocketTransport(websocket)
    session = Session(session_id, transport)
    active_sessions[session_id] = session
    logger.info(f"WebSocket connection established: {session_id}")
    # Single-element lists are mutable cells shared with the heartbeat task,
    # so the receive loop can update timestamps the watchdog reads.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )
    try:
        # Receive loop
        while True:
            message = await websocket.receive()
            message_type = message.get("type")
            if message_type == "websocket.disconnect":
                logger.info(f"WebSocket disconnected: {session_id}")
                break
            last_received_at[0] = time.monotonic()
            # Handle binary audio data
            if "bytes" in message:
                await session.handle_audio(message["bytes"])
            # Handle text commands
            elif "text" in message:
                await session.handle_text(message["text"])
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")
    except Exception as e:
        logger.error(f"WebSocket error: {e}", exc_info=True)
    finally:
        # Stop the heartbeat watchdog before tearing the session down.
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        # Cleanup session (may already be gone if /call/kill removed it)
        if session_id in active_sessions:
            await session.cleanup()
            del active_sessions[session_id]
            logger.info(f"Session {session_id} removed")
@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
    """
    WebRTC endpoint for WebRTC audio streaming.
    Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
    """
    # Check if aiortc is available (optional dependency)
    if not AIORTC_AVAILABLE:
        await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
        logger.warning("WebRTC connection attempted but aiortc is not available")
        return
    await websocket.accept()
    session_id = str(uuid.uuid4())
    # Create WebRTC peer connection
    pc = RTCPeerConnection()
    # Create transport and session
    transport = WebRtcTransport(websocket, pc)
    session = Session(session_id, transport)
    active_sessions[session_id] = session
    logger.info(f"WebRTC connection established: {session_id}")
    # Mutable cells shared with the heartbeat watchdog task.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )
    # Track handler for incoming audio
    @pc.on("track")
    def on_track(track):
        logger.info(f"Track received: {track.kind}")
        if track.kind == "audio":
            # Wrap track with resampler
            wrapped_track = Resampled16kTrack(track)
            # Create task to pull audio from track
            async def pull_audio():
                try:
                    while True:
                        frame = await wrapped_track.recv()
                        # Convert frame to bytes
                        pcm_bytes = frame.to_ndarray().tobytes()
                        # Feed to session
                        await session.handle_audio(pcm_bytes)
                except Exception as e:
                    logger.error(f"Error pulling audio from track: {e}")
            asyncio.create_task(pull_audio())
    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"Connection state: {pc.connectionState}")
        if pc.connectionState == "failed" or pc.connectionState == "closed":
            await session.cleanup()
    try:
        # Signaling loop: only text frames matter here; media flows via aiortc.
        while True:
            message = await websocket.receive()
            if "text" not in message:
                continue
            last_received_at[0] = time.monotonic()
            data = json.loads(message["text"])
            # Handle SDP offer/answer
            if "sdp" in data and "type" in data:
                logger.info(f"Received SDP {data['type']}")
                # Set remote description
                offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
                await pc.setRemoteDescription(offer)
                # Create and set local description
                if data["type"] == "offer":
                    answer = await pc.createAnswer()
                    await pc.setLocalDescription(answer)
                    # Send answer back
                    await websocket.send_text(json.dumps({
                        "event": "answer",
                        "trackId": session_id,
                        "timestamp": int(asyncio.get_event_loop().time() * 1000),
                        "sdp": pc.localDescription.sdp
                    }))
                    logger.info(f"Sent SDP answer")
            else:
                # Handle other commands
                await session.handle_text(message["text"])
    except WebSocketDisconnect:
        logger.info(f"WebRTC WebSocket disconnected: {session_id}")
    except Exception as e:
        logger.error(f"WebRTC error: {e}", exc_info=True)
    finally:
        # Stop the heartbeat watchdog before tearing everything down.
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        # Cleanup: close the peer connection, then the session record.
        await pc.close()
        if session_id in active_sessions:
            await session.cleanup()
            del active_sessions[session_id]
            logger.info(f"WebRTC session {session_id} removed")
@app.on_event("startup")
async def startup_event():
    """Run on application startup: log the effective configuration."""
    logger.info("Starting Python Active-Call server")
    logger.info(f"Server: {settings.host}:{settings.port}")
    logger.info(f"Sample rate: {settings.sample_rate} Hz")
    logger.info(f"VAD model: {settings.vad_model_path}")
@app.on_event("shutdown")
async def shutdown_event():
    """Run on application shutdown: tear down sessions and the event bus."""
    logger.info("Shutting down Python Active-Call server")
    # Cleanup all sessions (ids are not needed here)
    for session in active_sessions.values():
        await session.cleanup()
    # Close event bus
    bus = get_event_bus()
    await bus.close()
    reset_event_bus()
    logger.info("Server shutdown complete")
if __name__ == "__main__":
    import uvicorn
    # Create logs directory (the file sink above expects ./logs to exist)
    import os
    os.makedirs("logs", exist_ok=True)
    # Run server. NOTE(review): reload=True is a development setting — confirm
    # it is disabled for production deployments.
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=True,
        log_level=settings.log_level.lower()
    )

20
core/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
"""Core Components Package"""
from core.events import EventBus, get_event_bus
from core.transports import BaseTransport, SocketTransport, WebRtcTransport
from core.session import Session
from core.conversation import ConversationManager, ConversationState, ConversationTurn
from core.duplex_pipeline import DuplexPipeline
__all__ = [
"EventBus",
"get_event_bus",
"BaseTransport",
"SocketTransport",
"WebRtcTransport",
"Session",
"ConversationManager",
"ConversationState",
"ConversationTurn",
"DuplexPipeline",
]

279
core/conversation.py Normal file
View File

@@ -0,0 +1,279 @@
"""Conversation management for voice AI.
Handles conversation context, turn-taking, and message history
for multi-turn voice conversations.
"""
import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Optional

from loguru import logger

from services.base import LLMMessage
class ConversationState(Enum):
    """State of the conversation (turn-taking state machine for one call)."""
    IDLE = "idle"                 # Waiting for user input
    LISTENING = "listening"       # User is speaking
    PROCESSING = "processing"     # Processing user input (LLM)
    SPEAKING = "speaking"         # Bot is speaking
    INTERRUPTED = "interrupted"   # Bot was interrupted
@dataclass
class ConversationTurn:
    """A single turn in the conversation."""
    role: str  # "user" or "assistant"
    text: str
    # Duration of the spoken audio for this turn, when known.
    audio_duration_ms: Optional[int] = None
    # Capture time on a monotonic clock. time.monotonic() is the default
    # asyncio event-loop clock, but unlike the previous
    # asyncio.get_event_loop().time() default it does not rely on the
    # deprecated implicit-loop lookup and works with no loop running.
    timestamp: float = field(default_factory=time.monotonic)
    was_interrupted: bool = False
class ConversationManager:
"""
Manages conversation state and history.
Provides:
- Message history for LLM context
- Turn management
- State tracking
- Event callbacks for state changes
"""
    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_history: int = 20,
        greeting: Optional[str] = None
    ):
        """
        Initialize conversation manager.
        Args:
            system_prompt: System prompt for LLM (a conversational default
                is substituted when None or empty)
            max_history: Maximum number of turns to keep
            greeting: Optional greeting message when conversation starts
        """
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation. "
            "If you don't understand something, ask for clarification."
        )
        self.max_history = max_history
        self.greeting = greeting
        # State
        self.state = ConversationState.IDLE
        self.turns: List[ConversationTurn] = []
        # Callbacks (async listeners; failures are logged, not raised)
        self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
        self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []
        # Current turn tracking: in-flight text before a turn is committed
        self._current_user_text: str = ""
        self._current_assistant_text: str = ""
        logger.info("ConversationManager initialized")
def on_state_change(
self,
callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
) -> None:
"""Register callback for state changes."""
self._state_callbacks.append(callback)
def on_turn_complete(
self,
callback: Callable[[ConversationTurn], Awaitable[None]]
) -> None:
"""Register callback for turn completion."""
self._turn_callbacks.append(callback)
async def set_state(self, new_state: ConversationState) -> None:
"""Set conversation state and notify listeners."""
if new_state != self.state:
old_state = self.state
self.state = new_state
logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")
for callback in self._state_callbacks:
try:
await callback(old_state, new_state)
except Exception as e:
logger.error(f"State callback error: {e}")
def get_messages(self) -> List[LLMMessage]:
"""
Get conversation history as LLM messages.
Returns:
List of LLMMessage objects including system prompt
"""
messages = [LLMMessage(role="system", content=self.system_prompt)]
# Add conversation history
for turn in self.turns[-self.max_history:]:
messages.append(LLMMessage(role=turn.role, content=turn.text))
# Add current user text if any
if self._current_user_text:
messages.append(LLMMessage(role="user", content=self._current_user_text))
return messages
    async def start_user_turn(self) -> None:
        """Signal that user has started speaking."""
        # State change first: LISTENING callbacks run before the text buffer
        # is reset for the new utterance.
        await self.set_state(ConversationState.LISTENING)
        self._current_user_text = ""
    async def update_user_text(self, text: str, is_final: bool = False) -> None:
        """
        Update current user text (from ASR).
        Args:
            text: Transcribed text (replaces, not appends, the buffer)
            is_final: Whether this is the final transcript.
                NOTE(review): accepted but currently unused here — finality
                appears to be handled by end_user_turn; confirm with callers.
        """
        self._current_user_text = text
async def end_user_turn(self, text: str) -> None:
"""
End user turn and add to history.
Args:
text: Final user text
"""
if text.strip():
turn = ConversationTurn(role="user", text=text.strip())
self.turns.append(turn)
for callback in self._turn_callbacks:
try:
await callback(turn)
except Exception as e:
logger.error(f"Turn callback error: {e}")
logger.info(f"User: {text[:50]}...")
self._current_user_text = ""
await self.set_state(ConversationState.PROCESSING)
    async def start_assistant_turn(self) -> None:
        """Signal that assistant has started speaking."""
        # State change first: SPEAKING callbacks run before the streaming
        # text buffer is reset.
        await self.set_state(ConversationState.SPEAKING)
        self._current_assistant_text = ""
async def update_assistant_text(self, text: str) -> None:
"""
Update current assistant text (streaming).
Args:
text: Text chunk from LLM
"""
self._current_assistant_text += text
    async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
        """
        End assistant turn and add to history.
        Finalizes the accumulated streaming text (if any), notifies turn
        callbacks, then transitions state: INTERRUPTED on barge-in (unless a
        new user turn is already LISTENING), otherwise IDLE.
        Args:
            was_interrupted: Whether the turn was interrupted by user
        """
        text = self._current_assistant_text.strip()
        if text:
            turn = ConversationTurn(
                role="assistant",
                text=text,
                was_interrupted=was_interrupted
            )
            self.turns.append(turn)
            # A failing listener must not break turn bookkeeping.
            for callback in self._turn_callbacks:
                try:
                    await callback(turn)
                except Exception as e:
                    logger.error(f"Turn callback error: {e}")
            status = " (interrupted)" if was_interrupted else ""
            logger.info(f"Assistant{status}: {text[:50]}...")
        self._current_assistant_text = ""
        if was_interrupted:
            # A new user turn may already be active (LISTENING) when interrupted.
            # Avoid overriding it back to INTERRUPTED, which can stall EOU flow.
            if self.state != ConversationState.LISTENING:
                await self.set_state(ConversationState.INTERRUPTED)
        else:
            await self.set_state(ConversationState.IDLE)
async def add_assistant_turn(self, text: str, was_interrupted: bool = False) -> None:
"""Append an assistant turn directly without mutating conversation state."""
content = text.strip()
if not content:
return
turn = ConversationTurn(
role="assistant",
text=content,
was_interrupted=was_interrupted,
)
self.turns.append(turn)
for callback in self._turn_callbacks:
try:
await callback(turn)
except Exception as e:
logger.error(f"Turn callback error: {e}")
logger.info(f"Assistant (injected): {content[:50]}...")
async def interrupt(self) -> None:
"""Handle interruption (barge-in)."""
if self.state == ConversationState.SPEAKING:
await self.end_assistant_turn(was_interrupted=True)
def reset(self) -> None:
"""Reset conversation history."""
self.turns = []
self._current_user_text = ""
self._current_assistant_text = ""
self.state = ConversationState.IDLE
logger.info("Conversation reset")
    @property
    def turn_count(self) -> int:
        """Get number of turns in conversation."""
        # Counts only finalized turns; in-flight user/assistant text is excluded.
        return len(self.turns)
@property
def last_user_text(self) -> Optional[str]:
"""Get last user text."""
for turn in reversed(self.turns):
if turn.role == "user":
return turn.text
return None
@property
def last_assistant_text(self) -> Optional[str]:
"""Get last assistant text."""
for turn in reversed(self.turns):
if turn.role == "assistant":
return turn.text
return None
def get_context_summary(self) -> Dict[str, Any]:
"""Get a summary of conversation context."""
return {
"state": self.state.value,
"turn_count": self.turn_count,
"last_user": self.last_user_text,
"last_assistant": self.last_assistant_text,
"current_user": self._current_user_text or None,
"current_assistant": self._current_assistant_text or None
}

1507
core/duplex_pipeline.py Normal file

File diff suppressed because it is too large Load Diff

134
core/events.py Normal file
View File

@@ -0,0 +1,134 @@
"""Event bus for pub/sub communication between components."""
import asyncio
from typing import Callable, Dict, List, Any, Optional
from collections import defaultdict
from loguru import logger
class EventBus:
    """
    Async event bus for pub/sub communication.
    Similar to the original Rust implementation's broadcast channel.
    Components can subscribe to specific event types and receive events asynchronously.
    Publishing fans events out to all subscribers concurrently; a failing
    subscriber is logged and isolated so it cannot block the others.
    """
    def __init__(self):
        """Initialize the event bus."""
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        # NOTE(review): this lock is currently unused by the implementation;
        # kept for backward compatibility with any code that references it.
        self._lock = asyncio.Lock()
        self._running = True
    def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Subscribe to an event type.
        Args:
            event_type: Type of event to subscribe to (e.g., "speaking", "silence")
            callback: Sync or async callback that receives the event data dict
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
            return
        self._subscribers[event_type].append(callback)
        logger.debug(f"Subscribed to event type: {event_type}")
    def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Unsubscribe from an event type.
        Args:
            event_type: Type of event to unsubscribe from
            callback: Callback function to remove
        """
        # Use .get() so probing an unknown event type does not create an empty
        # entry in the defaultdict (direct indexing grew the mapping on every miss).
        callbacks = self._subscribers.get(event_type)
        if callbacks is not None and callback in callbacks:
            callbacks.remove(callback)
            logger.debug(f"Unsubscribed from event type: {event_type}")
    async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribers.
        Args:
            event_type: Type of event to publish
            event_data: Event data to send to subscribers
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
            return
        # Get subscribers for this event type
        subscribers = self._subscribers.get(event_type, [])
        if not subscribers:
            logger.debug(f"No subscribers for event type: {event_type}")
            return
        # Notify all subscribers concurrently. Iterate over a snapshot so a
        # callback that (un)subscribes during delivery cannot mutate the list
        # we are walking.
        tasks = []
        for callback in list(subscribers):
            try:
                # Create task for each subscriber
                task = asyncio.create_task(self._call_subscriber(callback, event_data))
                tasks.append(task)
            except Exception as e:
                logger.error(f"Error creating task for subscriber: {e}")
        # Wait for all subscribers to complete; exceptions are already handled
        # inside _call_subscriber, and return_exceptions=True guards gather.
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
            logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")
    async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
        """
        Call a subscriber callback with error handling.
        Args:
            callback: Subscriber callback function
            event_data: Event data to pass to callback
        """
        try:
            # Check if callback is a coroutine function
            if asyncio.iscoroutinefunction(callback):
                await callback(event_data)
            else:
                callback(event_data)
        except Exception as e:
            logger.error(f"Error in subscriber callback: {e}", exc_info=True)
    async def close(self) -> None:
        """Close the event bus and stop processing events."""
        self._running = False
        self._subscribers.clear()
        logger.info("Event bus closed")
    @property
    def is_running(self) -> bool:
        """Check if the event bus is running."""
        return self._running
# Global event bus instance (lazily created by get_event_bus)
_event_bus: Optional[EventBus] = None
def get_event_bus() -> EventBus:
    """
    Return the process-wide event bus, creating it lazily on first use.
    Returns:
        EventBus instance
    """
    global _event_bus
    if _event_bus is None:
        _event_bus = EventBus()
    return _event_bus
def reset_event_bus() -> None:
    """Discard the global event bus so the next get_event_bus() builds a fresh one (testing aid)."""
    global _event_bus
    _event_bus = None

648
core/session.py Normal file
View File

@@ -0,0 +1,648 @@
"""Session management for active calls."""
import asyncio
import uuid
import json
import time
import re
from enum import Enum
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
create_history_call_record,
add_history_transcript,
finalize_history_call_record,
)
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from services.base import LLMMessage
from models.ws_v1 import (
parse_client_message,
ev,
HelloMessage,
SessionStartMessage,
SessionStopMessage,
InputTextMessage,
ResponseCancelMessage,
ToolCallResultsMessage,
)
class WsSessionState(str, Enum):
    """Lifecycle states of the WS v1 control protocol for a session."""
    WAIT_HELLO = "wait_hello"    # awaiting client hello/auth handshake
    WAIT_START = "wait_start"    # hello accepted, awaiting session.start
    ACTIVE = "active"            # session running; audio and control accepted
    STOPPED = "stopped"          # session ended; no further messages processed
class Session:
    """
    Manages a single call session.
    Handles command routing, audio processing, and session lifecycle.
    Uses full duplex voice conversation pipeline.
    Protocol flow: WAIT_HELLO -(hello)-> WAIT_START -(session.start)-> ACTIVE -> STOPPED.
    Optionally bridges completed turns to a backend history record and drives
    a workflow graph (node/edge transitions) from finalized assistant turns.
    """
    def __init__(self, session_id: str, transport: BaseTransport, use_duplex: Optional[bool] = None):
        """
        Initialize session.
        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
            use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
        """
        self.id = session_id
        self.transport = transport
        self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
        self.pipeline = DuplexPipeline(
            transport=transport,
            session_id=session_id,
            system_prompt=settings.duplex_system_prompt,
            greeting=settings.duplex_greeting
        )
        # Session state
        self.created_at = None
        self.state = "created"  # Legacy call state for /call/lists
        self.ws_state = WsSessionState.WAIT_HELLO
        self._pipeline_started = False
        self.protocol_version: Optional[str] = None
        self.authenticated: bool = False
        # Track IDs
        self.current_track_id: Optional[str] = str(uuid.uuid4())
        # Backend history bridge bookkeeping (see _start_history_bridge / _finalize_history).
        self._history_call_id: Optional[str] = None
        self._history_turn_index: int = 0
        self._history_call_started_mono: Optional[float] = None
        self._history_finalized: bool = False
        # Guard so cleanup() is idempotent even under concurrent invocation.
        self._cleanup_lock = asyncio.Lock()
        self._cleaned_up = False
        # Optional workflow graph driving prompt/service overrides per node.
        self.workflow_runner: Optional[WorkflowRunner] = None
        self._workflow_last_user_text: str = ""
        self._workflow_initial_node: Optional[WorkflowNodeDef] = None
        # Persist turns / advance workflow whenever a conversation turn finalizes.
        self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
        logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
    async def handle_text(self, text_data: str) -> None:
        """
        Handle incoming text data (WS v1 JSON control messages).
        Args:
            text_data: JSON text data
        """
        try:
            data = json.loads(text_data)
            message = parse_client_message(data)
            await self._handle_v1_message(message)
        except json.JSONDecodeError as e:
            logger.error(f"Session {self.id} JSON decode error: {e}")
            await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
        except ValueError as e:
            # parse_client_message raises ValueError for schema/shape problems.
            logger.error(f"Session {self.id} command parse error: {e}")
            await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
        except Exception as e:
            logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
            await self._send_error("server", f"Internal error: {e}", "server.internal")
    async def handle_audio(self, audio_bytes: bytes) -> None:
        """
        Handle incoming audio data.
        Args:
            audio_bytes: PCM audio data
        """
        # Audio is only valid once the protocol handshake has completed.
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                "Audio received before session.start",
                "protocol.order",
            )
            return
        try:
            await self.pipeline.process_audio(audio_bytes)
        except Exception as e:
            logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
    async def _handle_v1_message(self, message: Any) -> None:
        """Route validated WS v1 message to handlers."""
        msg_type = message.type
        logger.info(f"Session {self.id} received message: {msg_type}")
        if isinstance(message, HelloMessage):
            await self._handle_hello(message)
            return
        # All messages below require hello handshake first
        if self.ws_state == WsSessionState.WAIT_HELLO:
            await self._send_error(
                "client",
                "Expected hello message first",
                "protocol.order",
            )
            return
        if isinstance(message, SessionStartMessage):
            await self._handle_session_start(message)
            return
        # All messages below require active session
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                f"Message '{msg_type}' requires active session",
                "protocol.order",
            )
            return
        if isinstance(message, InputTextMessage):
            await self.pipeline.process_text(message.text)
        elif isinstance(message, ResponseCancelMessage):
            # Graceful cancel is acknowledged but does not hard-interrupt playback.
            if message.graceful:
                logger.info(f"Session {self.id} graceful response.cancel")
            else:
                await self.pipeline.interrupt()
        elif isinstance(message, ToolCallResultsMessage):
            await self.pipeline.handle_tool_call_results(message.results)
        elif isinstance(message, SessionStopMessage):
            await self._handle_session_stop(message.reason)
        else:
            await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
    async def _handle_hello(self, message: HelloMessage) -> None:
        """Handle initial hello/auth/version negotiation."""
        if self.ws_state != WsSessionState.WAIT_HELLO:
            await self._send_error("client", "Duplicate hello", "protocol.order")
            return
        if message.version != settings.ws_protocol_version:
            await self._send_error(
                "client",
                f"Unsupported protocol version '{message.version}'",
                "protocol.version_unsupported",
            )
            await self.transport.close()
            self.ws_state = WsSessionState.STOPPED
            return
        auth_payload = message.auth or {}
        api_key = auth_payload.get("apiKey")
        jwt = auth_payload.get("jwt")
        # A configured server-side API key takes precedence; otherwise a
        # generic "any credential present" check applies when auth is required.
        if settings.ws_api_key:
            if api_key != settings.ws_api_key:
                await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
                await self.transport.close()
                self.ws_state = WsSessionState.STOPPED
                return
        elif settings.ws_require_auth and not (api_key or jwt):
            await self._send_error("auth", "Authentication required", "auth.required")
            await self.transport.close()
            self.ws_state = WsSessionState.STOPPED
            return
        self.authenticated = True
        self.protocol_version = message.version
        self.ws_state = WsSessionState.WAIT_START
        await self.transport.send_event(
            ev(
                "hello.ack",
                sessionId=self.id,
                version=self.protocol_version,
            )
        )
    async def _handle_session_start(self, message: SessionStartMessage) -> None:
        """Handle explicit session start after successful hello."""
        if self.ws_state != WsSessionState.WAIT_START:
            await self._send_error("client", "Duplicate session.start", "protocol.order")
            return
        metadata = message.metadata or {}
        # Workflow bootstrap may yield node-level overrides that are merged
        # into the session metadata before any services are configured.
        metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
        # Create history call record early so later turn callbacks can append transcripts.
        await self._start_history_bridge(metadata)
        # Apply runtime service/prompt overrides from backend if provided
        self.pipeline.apply_runtime_overrides(metadata)
        # Start duplex pipeline
        if not self._pipeline_started:
            await self.pipeline.start()
            self._pipeline_started = True
            logger.info(f"Session {self.id} duplex pipeline started")
        self.state = "accepted"
        self.ws_state = WsSessionState.ACTIVE
        await self.transport.send_event(
            ev(
                "session.started",
                sessionId=self.id,
                trackId=self.current_track_id,
                audio=message.audio or {},
            )
        )
        # Announce workflow start + initial node entry if a workflow is active.
        if self.workflow_runner and self._workflow_initial_node:
            await self.transport.send_event(
                ev(
                    "workflow.started",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    workflowName=self.workflow_runner.name,
                    nodeId=self._workflow_initial_node.id,
                )
            )
            await self.transport.send_event(
                ev(
                    "workflow.node.entered",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=self._workflow_initial_node.id,
                    nodeName=self._workflow_initial_node.name,
                    nodeType=self._workflow_initial_node.node_type,
                )
            )
    async def _handle_session_stop(self, reason: Optional[str]) -> None:
        """Handle session stop."""
        if self.ws_state == WsSessionState.STOPPED:
            return
        stop_reason = reason or "client_requested"
        self.state = "hungup"
        self.ws_state = WsSessionState.STOPPED
        await self.transport.send_event(
            ev(
                "session.stopped",
                sessionId=self.id,
                reason=stop_reason,
            )
        )
        # NOTE(review): history is always finalized with status "connected"
        # here and in cleanup(); confirm whether failure statuses are intended.
        await self._finalize_history(status="connected")
        await self.transport.close()
    async def _send_error(self, sender: str, error_message: str, code: str) -> None:
        """
        Send error event to client.
        Args:
            sender: Component that generated the error
            error_message: Error message
            code: Machine-readable error code
        """
        await self.transport.send_event(
            ev(
                "error",
                sender=sender,
                code=code,
                message=error_message,
                trackId=self.current_track_id,
            )
        )
    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        # Local import is redundant (time is imported at module level) but harmless.
        import time
        return int(time.time() * 1000)
    async def cleanup(self) -> None:
        """Cleanup session resources."""
        # Lock + flag make cleanup idempotent across transport close and stop paths.
        async with self._cleanup_lock:
            if self._cleaned_up:
                return
            self._cleaned_up = True
        logger.info(f"Session {self.id} cleaning up")
        await self._finalize_history(status="connected")
        await self.pipeline.cleanup()
        await self.transport.close()
    async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
        """Initialize backend history call record for this session."""
        if self._history_call_id:
            return
        history_meta: Dict[str, Any] = {}
        if isinstance(metadata.get("history"), dict):
            history_meta = metadata["history"]
        # userId may live in metadata.history, metadata, or fall back to settings.
        raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id))
        try:
            user_id = int(raw_user_id)
        except (TypeError, ValueError):
            user_id = settings.history_default_user_id
        assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
        source = str(history_meta.get("source", metadata.get("source", "debug")))
        call_id = await create_history_call_record(
            user_id=user_id,
            assistant_id=str(assistant_id) if assistant_id else None,
            source=source,
        )
        # No call_id means the backend bridge is unavailable; session proceeds without history.
        if not call_id:
            return
        self._history_call_id = call_id
        self._history_call_started_mono = time.monotonic()
        self._history_turn_index = 0
        self._history_finalized = False
        logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
    async def _on_turn_complete(self, turn: ConversationTurn) -> None:
        """Process workflow transitions and persist completed turns to history."""
        if turn.text and turn.text.strip():
            role = (turn.role or "").lower()
            if role == "user":
                # Remember the latest user text for later workflow routing context.
                self._workflow_last_user_text = turn.text.strip()
            elif role == "assistant":
                await self._maybe_advance_workflow(turn.text.strip())
        if not self._history_call_id:
            return
        if not turn.text or not turn.text.strip():
            return
        role = (turn.role or "").lower()
        speaker = "human" if role == "user" else "ai"
        end_ms = 0
        if self._history_call_started_mono is not None:
            end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
        # Rough speech duration estimate: ~80 ms per character, clamped to [300 ms, 12 s].
        estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
        start_ms = max(0, end_ms - estimated_duration_ms)
        turn_index = self._history_turn_index
        await add_history_transcript(
            call_id=self._history_call_id,
            turn_index=turn_index,
            speaker=speaker,
            content=turn.text.strip(),
            start_ms=start_ms,
            end_ms=end_ms,
            duration_ms=max(1, end_ms - start_ms),
        )
        self._history_turn_index += 1
    async def _finalize_history(self, status: str) -> None:
        """Finalize history call record once."""
        if not self._history_call_id or self._history_finalized:
            return
        duration_seconds = 0
        if self._history_call_started_mono is not None:
            duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))
        ok = await finalize_history_call_record(
            call_id=self._history_call_id,
            status=status,
            duration_seconds=duration_seconds,
        )
        # Only mark finalized on success so a failed backend call can be retried later.
        if ok:
            self._history_finalized = True
    def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Parse workflow payload and return initial runtime overrides."""
        payload = metadata.get("workflow")
        self.workflow_runner = WorkflowRunner.from_payload(payload)
        self._workflow_initial_node = None
        if not self.workflow_runner:
            return {}
        node = self.workflow_runner.bootstrap()
        if not node:
            logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
            self.workflow_runner = None
            return {}
        self._workflow_initial_node = node
        logger.info(
            "Session {} workflow enabled: workflow={} start_node={}",
            self.id,
            self.workflow_runner.workflow_id,
            node.id,
        )
        return self.workflow_runner.build_runtime_metadata(node)
    async def _maybe_advance_workflow(self, assistant_text: str) -> None:
        """Attempt node transfer after assistant turn finalization."""
        if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
            return
        transition = await self.workflow_runner.route(
            user_text=self._workflow_last_user_text,
            assistant_text=assistant_text,
            llm_router=self._workflow_llm_route,
        )
        if not transition:
            return
        await self._apply_workflow_transition(transition, reason="rule_match")
        # Auto-advance through utility nodes when default edges are present.
        max_auto_hops = 6
        auto_hops = 0
        while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
            current = self.workflow_runner.current_node
            # Only pass-through node types (start/tool) auto-advance.
            if not current or current.node_type not in {"start", "tool"}:
                break
            next_default = self.workflow_runner.next_default_transition()
            if not next_default:
                break
            auto_hops += 1
            await self._apply_workflow_transition(next_default, reason="auto")
            if auto_hops >= max_auto_hops:
                logger.warning(
                    "Session {} workflow auto-advance reached hop limit (possible cycle)",
                    self.id,
                )
                break
    async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
        """Apply graph transition and emit workflow lifecycle events."""
        if not self.workflow_runner:
            return
        self.workflow_runner.apply_transition(transition)
        node = transition.node
        edge = transition.edge
        await self.transport.send_event(
            ev(
                "workflow.edge.taken",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                edgeId=edge.id,
                fromNodeId=edge.from_node_id,
                toNodeId=edge.to_node_id,
                reason=reason,
            )
        )
        await self.transport.send_event(
            ev(
                "workflow.node.entered",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                nodeName=node.name,
                nodeType=node.node_type,
            )
        )
        # Entering a node may change prompts/services for subsequent turns.
        node_runtime = self.workflow_runner.build_runtime_metadata(node)
        if node_runtime:
            self.pipeline.apply_runtime_overrides(node_runtime)
        if node.node_type == "tool":
            await self.transport.send_event(
                ev(
                    "workflow.tool.requested",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                    tool=node.tool or {},
                )
            )
            return
        if node.node_type == "human_transfer":
            await self.transport.send_event(
                ev(
                    "workflow.human_transfer",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            # Terminal node: hand off to a human and end the automated session.
            await self._handle_session_stop("workflow_human_transfer")
            return
        if node.node_type == "end":
            await self.transport.send_event(
                ev(
                    "workflow.ended",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_end")
    async def _workflow_llm_route(
        self,
        node: WorkflowNodeDef,
        candidates: List[WorkflowEdgeDef],
        context: Dict[str, str],
    ) -> Optional[str]:
        """LLM-based edge routing for condition.type == 'llm' edges."""
        llm_service = self.pipeline.llm_service
        if not llm_service:
            return None
        candidate_rows = [
            {
                "edgeId": edge.id,
                "toNodeId": edge.to_node_id,
                "label": edge.label,
                "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
            }
            for edge in candidates
        ]
        system_prompt = (
            "You are a workflow router. Pick exactly one edge. "
            "Return JSON only: {\"edgeId\":\"...\"}."
        )
        user_prompt = json.dumps(
            {
                "nodeId": node.id,
                "nodeName": node.name,
                "userText": context.get("userText", ""),
                "assistantText": context.get("assistantText", ""),
                "candidates": candidate_rows,
            },
            ensure_ascii=False,
        )
        try:
            reply = await llm_service.generate(
                [
                    LLMMessage(role="system", content=system_prompt),
                    LLMMessage(role="user", content=user_prompt),
                ],
                temperature=0.0,
                max_tokens=64,
            )
        except Exception as exc:
            logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
            return None
        if not reply:
            return None
        edge_ids = {edge.id for edge in candidates}
        node_ids = {edge.to_node_id for edge in candidates}
        # Prefer a well-formed JSON answer; fall back to substring matching below.
        parsed = self._extract_json_obj(reply)
        if isinstance(parsed, dict):
            edge_id = parsed.get("edgeId") or parsed.get("id")
            node_id = parsed.get("toNodeId") or parsed.get("nodeId")
            if isinstance(edge_id, str) and edge_id in edge_ids:
                return edge_id
            if isinstance(node_id, str) and node_id in node_ids:
                return node_id
        # Longest-token-first so an id that contains another id wins.
        token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
        lowered_reply = reply.lower()
        for token in token_candidates:
            if token.lower() in lowered_reply:
                return token
        return None
    def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Merge node-level metadata overrides into session.start metadata."""
        merged = dict(base or {})
        if not overrides:
            return merged
        for key, value in overrides.items():
            if key == "services" and isinstance(value, dict):
                # "services" merges dict-wise; all other keys replace wholesale.
                existing = merged.get("services")
                merged_services = dict(existing) if isinstance(existing, dict) else {}
                merged_services.update(value)
                merged["services"] = merged_services
            else:
                merged[key] = value
        return merged
    def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
        """Best-effort extraction of a JSON object from freeform text."""
        try:
            parsed = json.loads(text)
            if isinstance(parsed, dict):
                return parsed
        except Exception:
            pass
        # Greedy match grabs the outermost {...} span in prose-wrapped replies.
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if not match:
            return None
        try:
            parsed = json.loads(match.group(0))
            return parsed if isinstance(parsed, dict) else None
        except Exception:
            return None

340
core/tool_executor.py Normal file
View File

@@ -0,0 +1,340 @@
"""Server-side tool execution helpers."""
import asyncio
import ast
import operator
from datetime import datetime
from typing import Any, Dict
import aiohttp
from app.backend_client import fetch_tool_resource
# Whitelisted binary arithmetic operators for the safe expression evaluators below.
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
}
# Whitelisted unary operators (+x, -x).
_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Builtins callable from sandboxed "code_interpreter" expressions.
_SAFE_EVAL_FUNCS = {
    "abs": abs,
    "round": round,
    "min": min,
    "max": max,
    "sum": sum,
    "len": len,
}
def _validate_safe_expr(node: ast.AST) -> None:
    """
    Allow only a constrained subset of Python expression nodes.
    Raises ValueError for anything outside the whitelist (attributes,
    subscripts, comprehensions, lambdas, imports, unknown names, ...).
    """
    if isinstance(node, ast.Expression):
        _validate_safe_expr(node.body)
    elif isinstance(node, ast.Constant):
        pass
    elif isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        for element in node.elts:
            _validate_safe_expr(element)
    elif isinstance(node, ast.Dict):
        for key_node in node.keys:
            if key_node is not None:
                _validate_safe_expr(key_node)
        for value_node in node.values:
            _validate_safe_expr(value_node)
    elif isinstance(node, ast.BinOp):
        if type(node.op) not in _BIN_OPS:
            raise ValueError("unsupported operator")
        _validate_safe_expr(node.left)
        _validate_safe_expr(node.right)
    elif isinstance(node, ast.UnaryOp):
        if type(node.op) not in _UNARY_OPS:
            raise ValueError("unsupported unary operator")
        _validate_safe_expr(node.operand)
    elif isinstance(node, ast.BoolOp):
        for operand in node.values:
            _validate_safe_expr(operand)
    elif isinstance(node, ast.Compare):
        _validate_safe_expr(node.left)
        for comparator in node.comparators:
            _validate_safe_expr(comparator)
    elif isinstance(node, ast.Name):
        # Only whitelisted function names (True/False/None parse as Constant
        # on modern Python; kept here for safety).
        if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}:
            raise ValueError("unknown symbol")
    elif isinstance(node, ast.Call):
        if not isinstance(node.func, ast.Name):
            raise ValueError("unsafe call target")
        if node.func.id not in _SAFE_EVAL_FUNCS:
            raise ValueError("function not allowed")
        for arg in node.args:
            _validate_safe_expr(arg)
        for kw in node.keywords:
            _validate_safe_expr(kw.value)
    else:
        # Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.)
        raise ValueError("unsupported expression")
def _safe_eval_python_expr(expression: str) -> Any:
    """Validate *expression* against the safe-node whitelist, then evaluate it."""
    tree = ast.parse(expression, mode="eval")
    _validate_safe_expr(tree)
    compiled = compile(tree, "<code_interpreter>", "eval")
    # Empty builtins; only the whitelisted helper functions are in scope.
    return eval(  # noqa: S307 - validated AST + empty builtins
        compiled,
        {"__builtins__": {}},
        dict(_SAFE_EVAL_FUNCS),
    )
def _json_safe(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, (list, tuple)):
return [_json_safe(v) for v in value]
if isinstance(value, dict):
return {str(k): _json_safe(v) for k, v in value.items()}
return repr(value)
def _safe_eval_expr(expression: str) -> float:
    """Evaluate a purely numeric arithmetic expression (calculator tool) as a float."""
    def _reduce(node: ast.AST) -> float:
        # Recursive fold over the (numeric-only) expression tree.
        if isinstance(node, ast.Expression):
            return _reduce(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return float(node.value)
        if isinstance(node, ast.BinOp):
            fn = _BIN_OPS.get(type(node.op))
            if not fn:
                raise ValueError("unsupported operator")
            return float(fn(_reduce(node.left), _reduce(node.right)))
        if isinstance(node, ast.UnaryOp):
            fn = _UNARY_OPS.get(type(node.op))
            if not fn:
                raise ValueError("unsupported unary operator")
            return float(fn(_reduce(node.operand)))
        raise ValueError("unsupported expression")
    return _reduce(ast.parse(expression, mode="eval"))
def _extract_tool_name(tool_call: Dict[str, Any]) -> str:
function_payload = tool_call.get("function")
if isinstance(function_payload, dict):
return str(function_payload.get("name") or "").strip()
return ""
def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
function_payload = tool_call.get("function")
if not isinstance(function_payload, dict):
return {}
raw = function_payload.get("arguments")
if isinstance(raw, dict):
return raw
if not isinstance(raw, str):
return {}
text = raw.strip()
if not text:
return {}
try:
import json
parsed = json.loads(text)
return parsed if isinstance(parsed, dict) else {}
except Exception:
return {}
async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute a server-side tool and return normalized result payload.
    Built-in tools: calculator (numeric arithmetic), code_interpreter
    (sandboxed Python expression), current_time. Any other tool name is looked
    up in the backend; "query"-category tools are proxied as HTTP requests.
    The returned dict always carries tool_call_id, name, output, and a
    status {code, message} pair modeled on HTTP semantics.
    """
    call_id = str(tool_call.get("id") or "").strip()
    tool_name = _extract_tool_name(tool_call)
    args = _extract_tool_args(tool_call)
    if tool_name == "calculator":
        expression = str(args.get("expression") or "").strip()
        if not expression:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "missing expression"},
                "status": {"code": 400, "message": "bad_request"},
            }
        # Length cap guards against pathological expressions before parsing.
        if len(expression) > 200:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "error": "expression too long"},
                "status": {"code": 422, "message": "invalid_expression"},
            }
        try:
            value = _safe_eval_expr(expression)
            # Collapse whole-number floats to int for cleaner JSON output.
            if value.is_integer():
                value = int(value)
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "result": value},
                "status": {"code": 200, "message": "ok"},
            }
        except Exception as exc:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "error": str(exc)},
                "status": {"code": 422, "message": "invalid_expression"},
            }
    if tool_name == "code_interpreter":
        code = str(args.get("code") or args.get("expression") or "").strip()
        if not code:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "missing code"},
                "status": {"code": 400, "message": "bad_request"},
            }
        if len(code) > 500:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "code too long"},
                "status": {"code": 422, "message": "invalid_code"},
            }
        try:
            # Sandboxed: AST-whitelisted expression, empty builtins.
            result = _safe_eval_python_expr(code)
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"code": code, "result": _json_safe(result)},
                "status": {"code": 200, "message": "ok"},
            }
        except Exception as exc:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"code": code, "error": str(exc)},
                "status": {"code": 422, "message": "invalid_code"},
            }
    if tool_name == "current_time":
        # Local server time with timezone attached.
        now = datetime.now().astimezone()
        return {
            "tool_call_id": call_id,
            "name": tool_name,
            "output": {
                "local_time": now.strftime("%Y-%m-%d %H:%M:%S"),
                "iso": now.isoformat(),
                "timezone": str(now.tzinfo or ""),
                "timestamp": int(now.timestamp()),
            },
            "status": {"code": 200, "message": "ok"},
        }
    if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
        # Backend-configured tools: only "query" category is executed server-side.
        resource = await fetch_tool_resource(tool_name)
        if resource and str(resource.get("category") or "") == "query":
            method = str(resource.get("http_method") or "GET").strip().upper()
            if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:
                method = "GET"
            url = str(resource.get("http_url") or "").strip()
            headers = resource.get("http_headers") if isinstance(resource.get("http_headers"), dict) else {}
            timeout_ms = resource.get("http_timeout_ms")
            try:
                timeout_s = max(1.0, float(timeout_ms) / 1000.0)
            except Exception:
                # Missing/invalid timeout config falls back to 10 s.
                timeout_s = 10.0
            if not url:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"error": "http_url not configured"},
                    "status": {"code": 422, "message": "invalid_tool_config"},
                }
            request_kwargs: Dict[str, Any] = {}
            # Args travel as query params for body-less methods, JSON body otherwise.
            if method in {"GET", "DELETE"}:
                request_kwargs["params"] = args
            else:
                request_kwargs["json"] = args
            try:
                timeout = aiohttp.ClientTimeout(total=timeout_s)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.request(method, url, headers=headers, **request_kwargs) as resp:
                        content_type = str(resp.headers.get("Content-Type") or "").lower()
                        if "application/json" in content_type:
                            body: Any = await resp.json()
                        else:
                            body = await resp.text()
                        status_code = int(resp.status)
                        if 200 <= status_code < 300:
                            return {
                                "tool_call_id": call_id,
                                "name": tool_name,
                                "output": {
                                    "method": method,
                                    "url": url,
                                    "status_code": status_code,
                                    "response": _json_safe(body),
                                },
                                "status": {"code": 200, "message": "ok"},
                            }
                        # Non-2xx responses are reported with the upstream status code.
                        return {
                            "tool_call_id": call_id,
                            "name": tool_name,
                            "output": {
                                "method": method,
                                "url": url,
                                "status_code": status_code,
                                "response": _json_safe(body),
                            },
                            "status": {"code": status_code, "message": "http_error"},
                        }
            except asyncio.TimeoutError:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"method": method, "url": url, "error": "request timeout"},
                    "status": {"code": 504, "message": "http_timeout"},
                }
            except Exception as exc:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"method": method, "url": url, "error": str(exc)},
                    "status": {"code": 502, "message": "http_request_failed"},
                }
    return {
        "tool_call_id": call_id,
        "name": tool_name or "unknown_tool",
        "output": {"message": "server tool not implemented"},
        "status": {"code": 501, "message": "not_implemented"},
    }

247
core/transports.py Normal file
View File

@@ -0,0 +1,247 @@
"""Transport layer for WebSocket and WebRTC communication."""
import asyncio
import json
from abc import ABC, abstractmethod
from typing import Optional
from fastapi import WebSocket
from starlette.websockets import WebSocketState
from loguru import logger
# Try to import aiortc (optional dependency, only needed for WebRTC functionality).
# When it is missing, AIORTC_AVAILABLE is False and WebRtcTransport.__init__
# raises RuntimeError; the WebSocket transport keeps working regardless.
try:
    from aiortc import RTCPeerConnection
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    RTCPeerConnection = None  # Type hint placeholder
class BaseTransport(ABC):
    """Abstract transport contract.

    A concrete transport knows how to deliver JSON control events and raw
    PCM audio to one connected client, and how to tear itself down.
    """
    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """Deliver a JSON-serializable control event to the client.

        Args:
            event: Event payload as a dictionary.
        """
        ...
    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """Deliver raw audio to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz).
        """
        ...
    @abstractmethod
    async def close(self) -> None:
        """Release the underlying connection and any associated resources."""
        ...
class SocketTransport(BaseTransport):
    """WebSocket transport carrying mixed text/binary frames.

    JSON events go out as text frames and PCM audio as binary frames over
    the same socket. A single asyncio.Lock serializes every outbound send
    so frames from concurrent tasks never interleave.
    """
    def __init__(self, websocket: WebSocket):
        """Wrap an accepted FastAPI WebSocket connection.

        Args:
            websocket: FastAPI WebSocket connection.
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Serializes outbound frames.
        self._closed = False
    def _ws_disconnected(self) -> bool:
        """Best-effort check for websocket disconnection state."""
        client_gone = self.ws.client_state == WebSocketState.DISCONNECTED
        server_gone = self.ws.application_state == WebSocketState.DISCONNECTED
        return client_gone or server_gone
    async def send_event(self, event: dict) -> None:
        """Serialize *event* to JSON and send it as a text frame.

        Becomes a no-op (and latches the closed flag) once the socket is
        known to be closed or disconnected; send errors are logged, never
        raised to the caller.
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send event on closed transport")
            self._closed = True
            return
        async with self.lock:
            try:
                await self.ws.send_text(json.dumps(event))
                logger.debug(f"Sent event: {event.get('event', 'unknown')}")
            except RuntimeError as exc:
                self._closed = True
                benign = self._ws_disconnected() or "close message has been sent" in str(exc)
                if benign:
                    logger.debug(f"Skip sending event on closed websocket: {exc!r}")
                else:
                    logger.error(f"Error sending event: {exc!r}")
            except Exception as exc:
                self._closed = True
                if self._ws_disconnected():
                    logger.debug(f"Skip sending event on disconnected websocket: {exc!r}")
                else:
                    logger.error(f"Error sending event: {exc!r}")
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """Send raw PCM audio (16-bit, mono, 16kHz) as one binary frame.

        Same closed-socket and error-swallowing semantics as send_event.
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send audio on closed transport")
            self._closed = True
            return
        async with self.lock:
            try:
                await self.ws.send_bytes(pcm_bytes)
            except RuntimeError as exc:
                self._closed = True
                benign = self._ws_disconnected() or "close message has been sent" in str(exc)
                if benign:
                    logger.debug(f"Skip sending audio on closed websocket: {exc!r}")
                else:
                    logger.error(f"Error sending audio: {exc!r}")
            except Exception as exc:
                self._closed = True
                if self._ws_disconnected():
                    logger.debug(f"Skip sending audio on disconnected websocket: {exc!r}")
                else:
                    logger.error(f"Error sending audio: {exc!r}")
    async def close(self) -> None:
        """Close the WebSocket connection; safe to call more than once."""
        if self._closed:
            return
        self._closed = True
        if self._ws_disconnected():
            return
        try:
            await self.ws.close()
        except RuntimeError as exc:
            # Already closed by another task/path; safe to ignore.
            if "close message has been sent" in str(exc):
                logger.debug(f"WebSocket already closed: {exc}")
            else:
                logger.error(f"Error closing WebSocket: {exc}")
        except Exception as exc:
            logger.error(f"Error closing WebSocket: {exc}")
    @property
    def is_closed(self) -> bool:
        """True once the transport has been marked closed."""
        return self._closed
class WebRtcTransport(BaseTransport):
    """WebRTC transport: WebSocket for signaling, RTCPeerConnection for media.

    Control events travel over the signaling WebSocket; audio is intended to
    be pushed onto an outbound MediaStreamTrack (not implemented yet).
    """
    def __init__(self, websocket: WebSocket, pc):
        """Wrap the signaling socket and peer connection.

        Args:
            websocket: FastAPI WebSocket connection used for signaling.
            pc: RTCPeerConnection carrying the media.

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")
        self.ws = websocket
        self.pc = pc
        # MediaStreamTrack for outbound audio; attached later via set_outbound_track().
        self.outbound_track = None
        self._closed = False
    async def send_event(self, event: dict) -> None:
        """Send a JSON event over the signaling WebSocket.

        Args:
            event: Event payload as a dictionary.
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return
        try:
            await self.ws.send_text(json.dumps(event))
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as exc:
            logger.error(f"Error sending event: {exc}")
            self._closed = True
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """Queue audio for the WebRTC media path (placeholder).

        WebRTC media is frame-based: bytes must be pushed onto a
        MediaStreamTrack the peer connection reads, not written directly.
        That track is not implemented yet, so this currently only logs.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz).
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return
        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")
        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)
    async def close(self) -> None:
        """Close the peer connection, then the signaling WebSocket."""
        self._closed = True
        try:
            await self.pc.close()
            await self.ws.close()
        except Exception as exc:
            logger.error(f"Error closing WebRTC transport: {exc}")
    @property
    def is_closed(self) -> bool:
        """True once the transport has been marked closed."""
        return self._closed
    def set_outbound_track(self, track):
        """Attach the MediaStreamTrack used for outbound audio.

        Args:
            track: MediaStreamTrack for outbound audio.
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")

402
core/workflow_runner.py Normal file
View File

@@ -0,0 +1,402 @@
"""Workflow runtime helpers for session-level node routing.
MVP goals:
- Parse workflow graph payload from WS session.start metadata
- Track current node
- Evaluate edge conditions on each assistant turn completion
- Provide per-node runtime metadata overrides (prompt/greeting/services)
"""
from __future__ import annotations
from dataclasses import dataclass
import json
import re
from typing import Any, Awaitable, Callable, Dict, List, Optional
from loguru import logger
_NODE_TYPE_MAP = {
"conversation": "assistant",
"assistant": "assistant",
"human": "human_transfer",
"human_transfer": "human_transfer",
"tool": "tool",
"end": "end",
"start": "start",
}
def _normalize_node_type(raw_type: Any) -> str:
value = str(raw_type or "").strip().lower()
return _NODE_TYPE_MAP.get(value, "assistant")
def _safe_str(value: Any) -> str:
if value is None:
return ""
return str(value)
def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]:
if not isinstance(raw, dict):
if label:
return {"type": "contains", "source": "user", "value": str(label)}
return {"type": "always"}
condition = dict(raw)
condition_type = str(condition.get("type", "always")).strip().lower()
if not condition_type:
condition_type = "always"
condition["type"] = condition_type
condition["source"] = str(condition.get("source", "user")).strip().lower() or "user"
return condition
@dataclass
class WorkflowNodeDef:
    """Normalized definition of a single workflow node."""
    id: str  # Unique node id within the workflow graph (synthesized when absent).
    name: str  # Human-readable node name; falls back to the id.
    node_type: str  # Normalized type: assistant / human_transfer / tool / end / start.
    is_start: bool  # True when the payload flags isStart or node_type is "start".
    prompt: Optional[str]  # Node-level system prompt override, None when not given.
    message_plan: Dict[str, Any]  # messagePlan payload (e.g. firstMessage); {} when absent.
    assistant_id: Optional[str]  # Optional assistant id override for this node.
    assistant: Dict[str, Any]  # Raw assistant config dict; {} when absent.
    tool: Optional[Dict[str, Any]]  # Tool config dict, None when the node has none.
    raw: Dict[str, Any]  # Original node payload as received.
@dataclass
class WorkflowEdgeDef:
    """Normalized definition of a directed workflow edge."""
    id: str  # Unique edge id (synthesized when the payload omits one).
    from_node_id: str  # Source node id.
    to_node_id: str  # Target node id.
    label: Optional[str]  # Optional label; also used as a "contains" keyword fallback.
    condition: Dict[str, Any]  # Normalized condition dict (see _normalize_condition).
    priority: int  # Lower values are evaluated first; defaults to 100.
    order: int  # Original payload index; tie-breaker after priority.
    raw: Dict[str, Any]  # Original edge payload as received.
@dataclass
class WorkflowTransition:
    """An edge that matched, paired with its resolved target node."""
    edge: WorkflowEdgeDef  # The edge being traversed.
    node: WorkflowNodeDef  # Target node the session will move to.
# Async callback used to resolve LLM-judged edges: given the current node, the
# candidate edges, and a {"userText": ..., "assistantText": ...} context dict,
# it returns the chosen edge id / target node id, or None for no selection.
LlmRouter = Callable[
    [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]],
    Awaitable[Optional[str]],
]
class WorkflowRunner:
    """In-memory workflow graph for a single active session."""
    def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]):
        """Index nodes by id and keep edges in the order supplied."""
        self.workflow_id = workflow_id
        self.name = name
        self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes}
        self._edges = edges
        # No current node until bootstrap() resolves the start node.
        self.current_node_id: Optional[str] = None
    @classmethod
    def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]:
        """Build a runner from a session.start workflow payload.

        Returns None when the payload is not a dict or yields no usable
        nodes. Malformed node/edge entries are skipped rather than failing
        the whole workflow; edges referencing unknown node ids are dropped.
        """
        if not isinstance(payload, dict):
            return None
        raw_nodes = payload.get("nodes")
        raw_edges = payload.get("edges")
        if not isinstance(raw_nodes, list) or len(raw_nodes) == 0:
            return None
        nodes: List[WorkflowNodeDef] = []
        for i, raw in enumerate(raw_nodes):
            if not isinstance(raw, dict):
                continue
            # Synthesize an id/name when missing so downstream lookups never see "".
            node_id = _safe_str(raw.get("id") or raw.get("name") or f"node_{i + 1}").strip() or f"node_{i + 1}"
            node_name = _safe_str(raw.get("name") or node_id).strip() or node_id
            node_type = _normalize_node_type(raw.get("type"))
            is_start = bool(raw.get("isStart")) or node_type == "start"
            # Distinguish "prompt key absent" (None) from "prompt present but empty" ("").
            prompt: Optional[str] = None
            if "prompt" in raw:
                prompt = _safe_str(raw.get("prompt"))
            message_plan = raw.get("messagePlan")
            if not isinstance(message_plan, dict):
                message_plan = {}
            assistant_cfg = raw.get("assistant")
            if not isinstance(assistant_cfg, dict):
                assistant_cfg = {}
            tool_cfg = raw.get("tool")
            if not isinstance(tool_cfg, dict):
                tool_cfg = None
            assistant_id = raw.get("assistantId")
            if assistant_id is not None:
                assistant_id = _safe_str(assistant_id).strip() or None
            nodes.append(
                WorkflowNodeDef(
                    id=node_id,
                    name=node_name,
                    node_type=node_type,
                    is_start=is_start,
                    prompt=prompt,
                    message_plan=message_plan,
                    assistant_id=assistant_id,
                    assistant=assistant_cfg,
                    tool=tool_cfg,
                    raw=raw,
                )
            )
        if not nodes:
            return None
        node_ids = {node.id for node in nodes}
        edges: List[WorkflowEdgeDef] = []
        for i, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []):
            if not isinstance(raw, dict):
                continue
            # Accept several payload spellings for edge endpoints.
            from_node_id = _safe_str(
                raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
            ).strip()
            to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip()
            if not from_node_id or not to_node_id:
                continue
            # Drop edges pointing at nodes we did not keep.
            if from_node_id not in node_ids or to_node_id not in node_ids:
                continue
            label = raw.get("label")
            if label is not None:
                label = _safe_str(label)
            condition = _normalize_condition(raw.get("condition"), label=label)
            priority = 100
            try:
                priority = int(raw.get("priority", 100))
            except (TypeError, ValueError):
                priority = 100
            edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}"
            edges.append(
                WorkflowEdgeDef(
                    id=edge_id,
                    from_node_id=from_node_id,
                    to_node_id=to_node_id,
                    label=label,
                    condition=condition,
                    priority=priority,
                    order=i,
                    raw=raw,
                )
            )
        workflow_id = _safe_str(payload.get("id") or "workflow")
        workflow_name = _safe_str(payload.get("name") or workflow_id)
        return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges)
    def bootstrap(self) -> Optional[WorkflowNodeDef]:
        """Resolve and enter the start node; returns it, or None if unresolvable."""
        start_node = self._resolve_start_node()
        if not start_node:
            return None
        self.current_node_id = start_node.id
        return start_node
    @property
    def current_node(self) -> Optional[WorkflowNodeDef]:
        """Node the session is currently on, or None before bootstrap."""
        if not self.current_node_id:
            return None
        return self._nodes.get(self.current_node_id)
    def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]:
        """Edges leaving *node_id*, sorted by (priority, original order)."""
        edges = [edge for edge in self._edges if edge.from_node_id == node_id]
        return sorted(edges, key=lambda edge: (edge.priority, edge.order))
    def next_default_transition(self) -> Optional[WorkflowTransition]:
        """First unconditional transition out of the current node, if any."""
        node = self.current_node
        if not node:
            return None
        for edge in self.outgoing_edges(node.id):
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type in {"", "always", "default"}:
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)
        return None
    async def route(
        self,
        *,
        user_text: str,
        assistant_text: str,
        llm_router: Optional[LlmRouter] = None,
    ) -> Optional[WorkflowTransition]:
        """Pick the next transition after an assistant turn completes.

        Evaluation order:
        1. Deterministic (non-"llm") edge conditions, in priority order.
        2. LLM-judged edges via *llm_router*; the selection is matched
           against each candidate's edge id or target node id.
        3. Any unconditional ("always"/"default") edge as a fallback.
        Returns None when nothing matches (session stays on current node).
        """
        node = self.current_node
        if not node:
            return None
        outgoing = self.outgoing_edges(node.id)
        if not outgoing:
            return None
        llm_edges: List[WorkflowEdgeDef] = []
        for edge in outgoing:
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type == "llm":
                # Defer LLM-judged edges; consulted only if nothing matches here.
                llm_edges.append(edge)
                continue
            if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text):
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)
        if llm_edges and llm_router:
            selection = await llm_router(
                node,
                llm_edges,
                {
                    "userText": user_text,
                    "assistantText": assistant_text,
                },
            )
            if selection:
                for edge in llm_edges:
                    if selection in {edge.id, edge.to_node_id}:
                        target = self._nodes.get(edge.to_node_id)
                        if target:
                            return WorkflowTransition(edge=edge, node=target)
        for edge in outgoing:
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type in {"", "always", "default"}:
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)
        return None
    def apply_transition(self, transition: WorkflowTransition) -> None:
        """Commit *transition*: its target becomes the current node."""
        self.current_node_id = transition.node.id
    def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]:
        """Per-node runtime metadata overrides (prompt/greeting/services).

        Node-level fields win over the node's assistant config; a key is
        omitted entirely when neither source provides a value.
        """
        assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {}
        message_plan = node.message_plan if isinstance(node.message_plan, dict) else {}
        metadata: Dict[str, Any] = {}
        if node.prompt is not None:
            metadata["systemPrompt"] = node.prompt
        elif "systemPrompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("systemPrompt"))
        elif "prompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("prompt"))
        first_message = message_plan.get("firstMessage")
        if first_message is not None:
            metadata["greeting"] = _safe_str(first_message)
        elif "greeting" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("greeting"))
        elif "opener" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("opener"))
        services = assistant_cfg.get("services")
        if isinstance(services, dict):
            metadata["services"] = services
        if node.assistant_id:
            metadata["assistantId"] = node.assistant_id
        return metadata
    def _resolve_start_node(self) -> Optional[WorkflowNodeDef]:
        """Find the node the session should begin on.

        Preference order: a node flagged is_start, then a "start"-typed
        node (from which default edges are followed, bounded to 8 hops and
        cycle-checked), then the first assistant node, then any node.
        """
        explicit_start = next((node for node in self._nodes.values() if node.is_start), None)
        if not explicit_start:
            explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None)
        if explicit_start:
            # If a dedicated start node exists, try to move to its first default target.
            if explicit_start.node_type == "start":
                visited = {explicit_start.id}
                current = explicit_start
                for _ in range(8):
                    transition = self._first_default_transition_from(current.id)
                    if not transition:
                        return current
                    current = transition.node
                    if current.id in visited:
                        break
                    visited.add(current.id)
                return current
            return explicit_start
        assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None)
        if assistant_node:
            return assistant_node
        return next(iter(self._nodes.values()), None)
    def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]:
        """First unconditional transition out of *node_id*, if any."""
        for edge in self.outgoing_edges(node_id):
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type in {"", "always", "default"}:
                node = self._nodes.get(edge.to_node_id)
                if node:
                    return WorkflowTransition(edge=edge, node=node)
        return None
    def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool:
        """Evaluate a deterministic edge condition against the turn texts.

        Supported types: always/default, contains, equals, regex, json.
        "llm" conditions are handled in route(), not here. Text matching is
        case-insensitive; "source" selects user vs assistant text.
        """
        condition = edge.condition or {"type": "always"}
        cond_type = str(condition.get("type", "always")).strip().lower()
        source = str(condition.get("source", "user")).strip().lower()
        if cond_type in {"", "always", "default"}:
            return True
        text = assistant_text if source == "assistant" else user_text
        text_lower = (text or "").lower()
        if cond_type == "contains":
            values: List[str] = []
            if isinstance(condition.get("values"), list):
                values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()]
            if not values:
                # Fall back to a single keyword, ultimately the edge label.
                single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower()
                if single:
                    values = [single]
            if not values:
                return False
            return any(value in text_lower for value in values)
        if cond_type == "equals":
            expected = _safe_str(condition.get("value") or "").strip().lower()
            return bool(expected) and text_lower == expected
        if cond_type == "regex":
            pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip()
            if not pattern:
                return False
            try:
                return bool(re.search(pattern, text or "", re.IGNORECASE))
            except re.error:
                logger.warning(f"Invalid workflow regex condition: {pattern}")
                return False
        if cond_type == "json":
            # NOTE(review): this compares str(parsed_json) — Python repr, not
            # canonical JSON — to the raw condition value; looks like an MVP
            # placeholder. Confirm intended semantics before relying on it.
            value = _safe_str(condition.get("value") or "").strip()
            if not value:
                return False
            try:
                obj = json.loads(text or "")
            except Exception:
                return False
            return str(obj) == value
        return False

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
data/vad/silero_vad.onnx Normal file

Binary file not shown.

View File

@@ -0,0 +1,96 @@
<svg width="1200" height="620" viewBox="0 0 1200 620" xmlns="http://www.w3.org/2000/svg">
<defs>
<style>
.box { fill:#11131a; stroke:#3a3f4b; stroke-width:1.2; rx:10; ry:10; }
.title { font: 600 14px 'Arial'; fill:#f2f3f7; }
.text { font: 12px 'Arial'; fill:#c8ccd8; }
.arrow { stroke:#7aa2ff; stroke-width:1.6; marker-end:url(#arrow); fill:none; }
.arrow2 { stroke:#2dd4bf; stroke-width:1.6; marker-end:url(#arrow); fill:none; }
.arrow3 { stroke:#ff6b6b; stroke-width:1.6; marker-end:url(#arrow); fill:none; }
.label { font: 11px 'Arial'; fill:#9aa3b2; }
</style>
<marker id="arrow" markerWidth="8" markerHeight="8" refX="7" refY="4" orient="auto">
<path d="M0,0 L8,4 L0,8 Z" fill="#7aa2ff"/>
</marker>
</defs>
<rect x="40" y="40" width="250" height="120" class="box"/>
<text x="60" y="70" class="title">Web Client</text>
<text x="60" y="95" class="text">WS JSON commands</text>
<text x="60" y="115" class="text">WS binary PCM audio</text>
<rect x="350" y="40" width="250" height="120" class="box"/>
<text x="370" y="70" class="title">FastAPI /ws</text>
<text x="370" y="95" class="text">Session + Transport</text>
<rect x="660" y="40" width="250" height="120" class="box"/>
<text x="680" y="70" class="title">DuplexPipeline</text>
<text x="680" y="95" class="text">process_audio / process_text</text>
<rect x="920" y="40" width="240" height="120" class="box"/>
<text x="940" y="70" class="title">ConversationManager</text>
<text x="940" y="95" class="text">turns + state</text>
<rect x="660" y="200" width="180" height="100" class="box"/>
<text x="680" y="230" class="title">VADProcessor</text>
<text x="680" y="255" class="text">speech/silence</text>
<rect x="860" y="200" width="180" height="100" class="box"/>
<text x="880" y="230" class="title">EOU Detector</text>
<text x="880" y="255" class="text">end-of-utterance</text>
<rect x="1060" y="200" width="120" height="100" class="box"/>
<text x="1075" y="230" class="title">ASR</text>
<text x="1075" y="255" class="text">transcripts</text>
<rect x="920" y="350" width="240" height="110" class="box"/>
<text x="940" y="380" class="title">LLM (stream)</text>
<text x="940" y="405" class="text">llmResponse events</text>
<rect x="660" y="350" width="220" height="110" class="box"/>
<text x="680" y="380" class="title">TTS (stream)</text>
<text x="680" y="405" class="text">PCM audio</text>
<rect x="40" y="350" width="250" height="110" class="box"/>
<text x="60" y="380" class="title">Web Client</text>
<text x="60" y="405" class="text">audio playback + UI</text>
<path d="M290 80 L350 80" class="arrow"/>
<text x="300" y="70" class="label">JSON / PCM</text>
<path d="M600 80 L660 80" class="arrow"/>
<text x="615" y="70" class="label">dispatch</text>
<path d="M910 80 L920 80" class="arrow"/>
<text x="880" y="70" class="label">turn mgmt</text>
<path d="M750 160 L750 200" class="arrow"/>
<text x="705" y="190" class="label">audio chunks</text>
<path d="M840 250 L860 250" class="arrow"/>
<text x="835" y="240" class="label">vad status</text>
<path d="M1040 250 L1060 250" class="arrow"/>
<text x="1010" y="240" class="label">audio buffer</text>
<path d="M950 300 L950 350" class="arrow2"/>
<text x="930" y="340" class="label">EOU -> LLM</text>
<path d="M880 405 L920 405" class="arrow2"/>
<text x="870" y="395" class="label">text stream</text>
<path d="M660 405 L290 405" class="arrow2"/>
<text x="430" y="395" class="label">PCM audio</text>
<path d="M660 450 L350 450" class="arrow"/>
<text x="420" y="440" class="label">events: trackStart/End</text>
<path d="M350 450 L290 450" class="arrow"/>
<text x="315" y="440" class="label">UI updates</text>
<path d="M750 200 L750 160" class="arrow3"/>
<text x="700" y="145" class="label">barge-in detection</text>
<path d="M760 170 L920 170" class="arrow3"/>
<text x="820" y="160" class="label">interrupt event + cancel</text>
</svg>

After

Width:  |  Height:  |  Size: 3.9 KiB

187
docs/proejct_todo.md Normal file
View File

@@ -0,0 +1,187 @@
# OmniSense: 12-Week Sprint Board + Tech Stack (Python Backend) — TODO
## Scope
- [ ] Build a realtime AI SaaS (OmniSense) focused on web-first audio + video with WebSocket + WebRTC endpoints
- [ ] Deliver assistant builder, tool execution, observability, evals, optional telephony later
- [ ] Keep scope aligned to 2-person team, self-hosted services
---
## Sprint Board (12 weeks, 2-week sprints)
Team assumption: 2 engineers. Scope prioritized to web-first audio + video, with BYO-SFU adapters.
### Sprint 1 (Weeks 1–2) — Realtime Core MVP (WebSocket + WebRTC Audio)
- Deliverables
- [ ] WebSocket transport: audio in/out streaming (1:1)
- [ ] WebRTC transport: audio in/out streaming (1:1)
- [ ] Adapter contract wired into runtime (transport-agnostic session core)
- [ ] ASR → LLM → TTS pipeline, streaming both directions
- [ ] Basic session state (start/stop, silence timeout)
- [ ] Transcript persistence
- Acceptance criteria
- [ ] < 1.5s median round-trip for short responses
- [ ] Stable streaming for 10+ minute session
### Sprint 2 (Weeks 3–4) — Video + Realtime UX
- Deliverables
- [ ] WebRTC video capture + streaming (assistant can “see” frames)
- [ ] WebSocket video streaming for local/dev mode
- [ ] Low-latency UI: push-to-talk, live captions, speaking indicator
- [ ] Recording + transcript storage (web sessions)
- Acceptance criteria
- [ ] Video < 2.5s end-to-end latency for analysis
- [ ] Audio quality acceptable (no clipping, jitter handling)
### Sprint 3 (Weeks 5–6) — Assistant Builder v1
- Deliverables
- [ ] Assistant schema + versioning
- [ ] UI: Model/Voice/Transcriber/Tools/Video/Transport tabs
- [ ] “Test/Chat/Talk to Assistant” (web)
- Acceptance criteria
- [ ] Create/publish assistant and run a live web session
- [ ] All config changes tracked by version
### Sprint 4 (Weeks 7–8) — Tooling + Structured Outputs
- Deliverables
- [ ] Tool registry + custom HTTP tools
- [ ] Tool auth secrets management
- [ ] Structured outputs (JSON extraction)
- Acceptance criteria
- [ ] Tool calls executed with retries/timeouts
- [ ] Structured JSON stored per call/session
### Sprint 5 (Weeks 9–10) — Observability + QA + Dev Platform
- Deliverables
- [ ] Session logs + chat logs + media logs
- [ ] Evals engine + test suites
- [ ] Basic analytics dashboard
- [ ] Public WebSocket API spec + message schema
- [ ] JS/TS SDK (connect, send audio/video, receive transcripts)
- Acceptance criteria
- [ ] Reproducible test suite runs
- [ ] Log filters by assistant/time/status
- [ ] SDK demo app runs end-to-end
### Sprint 6 (Weeks 11–12) — SaaS Hardening
- Deliverables
- [ ] Org/RBAC + API keys + rate limits
- [ ] Usage metering + credits
- [ ] Stripe billing integration
- [ ] Self-hosted DB ops (migrations, backup/restore, monitoring)
- Acceptance criteria
- [ ] Metered usage per org
- [ ] Credits decrement correctly
- [ ] Optional telephony spike documented (defer build)
- [ ] Enterprise adapter guide published (BYO-SFU)
---
## Tech Stack by Service (Self-Hosted, Web-First)
### 1) Transport Gateway (Realtime)
- [ ] WebRTC (browser) + WebSocket (lightweight/dev) protocols
- [ ] BYO-SFU adapter (enterprise) + LiveKit optional adapter + WS transport server
- [ ] Python core (FastAPI + asyncio) + Node.js mediasoup adapters when needed
- [ ] Media: Opus/VP8, jitter buffer, VAD, echo cancellation
- [ ] Storage: S3-compatible (MinIO) for recordings
### 2) ASR Service
- [ ] Whisper (self-hosted) baseline
- [ ] gRPC/WebSocket streaming transport
- [ ] Python native service
- [ ] Optional cloud provider fallback (later)
### 3) TTS Service
- [ ] Piper or Coqui TTS (self-hosted)
- [ ] gRPC/WebSocket streaming transport
- [ ] Python native service
- [ ] Redis cache for common phrases
### 4) LLM Orchestrator
- [ ] Self-hosted (vLLM + open model)
- [ ] Python (FastAPI + asyncio)
- [ ] Streaming, tool calling, JSON mode
- [ ] Safety filters + prompt templates
### 5) Assistant Config Service
- [ ] PostgreSQL
- [ ] Python (SQLAlchemy or SQLModel)
- [ ] Versioning, publish/rollback
### 6) Session Service
- [ ] PostgreSQL + Redis
- [ ] Python
- [ ] State machine, timeouts, events
### 7) Tool Execution Layer
- [ ] PostgreSQL
- [ ] Python
- [ ] Auth secret vault, retry policies, tool schemas
### 8) Observability + Logs
- [ ] Postgres (metadata), ClickHouse (logs/metrics)
- [ ] OpenSearch for search
- [ ] Prometheus + Grafana metrics
- [ ] OpenTelemetry tracing
### 9) Billing + Usage Metering
- [ ] Stripe billing
- [ ] PostgreSQL
- [ ] NATS JetStream (events) + Redis counters
### 10) Web App (Dashboard)
- [ ] React + Next.js
- [ ] Tailwind or Radix UI
- [ ] WebRTC client + WS client; adapter-based RTC integration
- [ ] ECharts/Recharts
### 11) Auth + RBAC
- [ ] Keycloak (self-hosted) or custom JWT
- [ ] Org/user/role tables in Postgres
### 12) Public WebSocket API + SDK
- [ ] WS API: versioned schema, binary audio frames + JSON control messages
- [ ] SDKs: JS/TS first, optional Python/Go clients
- [ ] Docs: quickstart, auth flow, session lifecycle, examples
---
## Infrastructure (Self-Hosted)
- [ ] Docker Compose → k3s (later)
- [ ] Redis Streams or NATS
- [ ] MinIO object store
- [ ] GitHub Actions + Helm or kustomize
- [ ] Self-hosted Postgres + pgbackrest backups
- [ ] Vault for secrets
---
## Suggested MVP Sequence
- [ ] WebRTC demo + ASR/LLM/TTS streaming
- [ ] Assistant schema + versioning (web-first)
- [ ] Video capture + multimodal analysis
- [ ] Tool execution + structured outputs
- [ ] Logs + evals + public WS API + SDK
- [ ] Telephony (optional, later)
---
## Public WebSocket API (Minimum Spec)
- [ ] Auth: API key or JWT in initial `hello` message
- [ ] Core messages: `session.start`, `session.stop`, `audio.append`, `audio.commit`, `video.append`, `transcript.delta`, `assistant.response`, `tool.call`, `tool.result`, `error`
- [ ] Binary payloads: PCM/Opus frames with metadata in control channel
- [ ] Versioning: `v1` schema with backward compatibility rules
---
## Self-Hosted DB Ops Checklist
- [ ] Postgres in Docker/k3s with persistent volumes
- [ ] Migrations: `alembic` or `atlas`
- [ ] Backups: `pgbackrest` nightly + on-demand
- [ ] Monitoring: postgres_exporter + alerts
---
## RTC Adapter Contract (BYO-SFU First)
- [ ] Keep RTC pluggable; LiveKit optional, not core dependency
- [ ] Define adapter interface (TypeScript sketch)

199
docs/ws_v1_schema.md Normal file
View File

@@ -0,0 +1,199 @@
# WS v1 Protocol Schema (`/ws`)
This document defines the public WebSocket protocol for the `/ws` endpoint.
## Transport
- A single WebSocket connection carries:
- JSON text frames for control/events.
- Binary frames for raw PCM audio (`pcm_s16le`, mono, 16kHz by default).
## Handshake and State Machine
Required message order:
1. Client sends `hello`.
2. Server replies `hello.ack`.
3. Client sends `session.start`.
4. Server replies `session.started`.
5. Client may stream binary audio and/or send `input.text`.
6. Client sends `session.stop` (or closes socket).
If order is violated, server emits `error` with `code = "protocol.order"`.
## Client -> Server Messages
### `hello`
```json
{
"type": "hello",
"version": "v1",
"auth": {
"apiKey": "optional-api-key",
"jwt": "optional-jwt"
}
}
```
Rules:
- `version` must be `v1`.
- If `WS_API_KEY` is configured on server, `auth.apiKey` must match.
- If `WS_REQUIRE_AUTH=true`, either `auth.apiKey` or `auth.jwt` must be present.
### `session.start`
```json
{
"type": "session.start",
"audio": {
"encoding": "pcm_s16le",
"sample_rate_hz": 16000,
"channels": 1
},
"metadata": {
"client": "web-debug",
"output": {
"mode": "audio"
},
"systemPrompt": "You are concise.",
"greeting": "Hi, how can I help?",
"services": {
"llm": {
"provider": "openai",
"model": "gpt-4o-mini",
"apiKey": "sk-...",
"baseUrl": "https://api.openai.com/v1"
},
"asr": {
"provider": "openai_compatible",
"model": "FunAudioLLM/SenseVoiceSmall",
"apiKey": "sf-...",
"interimIntervalMs": 500,
"minAudioMs": 300
},
"tts": {
"enabled": true,
"provider": "openai_compatible",
"model": "FunAudioLLM/CosyVoice2-0.5B",
"apiKey": "sf-...",
"voice": "anna",
"speed": 1.0
}
}
}
}
```
`metadata.services` is optional. If omitted, server defaults to environment configuration.
Text-only mode:
- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`.
- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`.
### `input.text`
```json
{
"type": "input.text",
"text": "What can you do?"
}
```
### `response.cancel`
```json
{
"type": "response.cancel",
"graceful": false
}
```
### `session.stop`
```json
{
"type": "session.stop",
"reason": "client_disconnect"
}
```
### `tool_call.results`
Client tool execution results returned to server.
```json
{
"type": "tool_call.results",
"results": [
{
"tool_call_id": "call_abc123",
"name": "weather",
"output": { "temp_c": 21, "condition": "sunny" },
"status": { "code": 200, "message": "ok" }
}
]
}
```
## Server -> Client Events
All server events include:
```json
{
"type": "event.name",
"timestamp": 1730000000000
}
```
Common events:
- `hello.ack`
- Fields: `sessionId`, `version`
- `session.started`
- Fields: `sessionId`, `trackId`, `audio`
- `session.stopped`
- Fields: `sessionId`, `reason`
- `heartbeat`
- `input.speech_started`
- Fields: `trackId`, `probability`
- `input.speech_stopped`
- Fields: `trackId`, `probability`
- `transcript.delta`
- Fields: `trackId`, `text`
- `transcript.final`
- Fields: `trackId`, `text`
- `assistant.response.delta`
- Fields: `trackId`, `text`
- `assistant.response.final`
- Fields: `trackId`, `text`
- `assistant.tool_call`
- Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`)
- `assistant.tool_result`
- Fields: `trackId`, `source`, `result`
- `output.audio.start`
- Fields: `trackId`
- `output.audio.end`
- Fields: `trackId`
- `response.interrupted`
- Fields: `trackId`
- `metrics.ttfb`
- Fields: `trackId`, `latencyMs`
- `error`
- Fields: `sender`, `code`, `message`, `trackId`
## Binary Audio Frames
After `session.started`, client may send binary PCM chunks continuously.
Recommended format:
- 16-bit signed little-endian PCM.
- 1 channel.
- 16000 Hz.
- 20ms frames (640 bytes) preferred.
## Compatibility
This endpoint now enforces v1 message schema for JSON control frames.
Legacy command names (`invite`, `chat`, etc.) are no longer part of the public protocol.

601
examples/mic_client.py Normal file
View File

@@ -0,0 +1,601 @@
#!/usr/bin/env python3
"""
Microphone client for testing duplex voice conversation.
This client captures audio from the microphone, sends it to the server,
and plays back the AI's voice response through the speakers.
It also displays the LLM's text responses in the console.
Usage:
python examples/mic_client.py --url ws://localhost:8000/ws
python examples/mic_client.py --url ws://localhost:8000/ws --chat "Hello!"
python examples/mic_client.py --url ws://localhost:8000/ws --verbose
Requirements:
pip install sounddevice soundfile websockets numpy
"""
import argparse
import asyncio
import json
import sys
import time
import threading
import queue
from pathlib import Path
try:
import numpy as np
except ImportError:
print("Please install numpy: pip install numpy")
sys.exit(1)
try:
import sounddevice as sd
except ImportError:
print("Please install sounddevice: pip install sounddevice")
sys.exit(1)
try:
import websockets
except ImportError:
print("Please install websockets: pip install websockets")
sys.exit(1)
class MicrophoneClient:
    """
    Full-duplex microphone client for voice conversation.
    Features:
    - Real-time microphone capture
    - Real-time speaker playback
    - WebSocket communication
    - Text chat support

    Threading model: sounddevice invokes the input callback on its own audio
    thread; playback runs on a dedicated daemon thread; everything else runs
    on the asyncio event loop. `audio_output_buffer` is shared between the
    playback thread and the receiver task and is always accessed under
    `audio_output_lock`.
    """
    def __init__(
        self,
        url: str,
        sample_rate: int = 16000,
        chunk_duration_ms: int = 20,
        input_device: int = None,
        output_device: int = None
    ):
        """
        Initialize microphone client.
        Args:
            url: WebSocket server URL
            sample_rate: Audio sample rate (Hz)
            chunk_duration_ms: Audio chunk duration (ms)
            input_device: Input device ID (None for default)
            output_device: Output device ID (None for default)
        """
        self.url = url
        self.sample_rate = sample_rate
        self.chunk_duration_ms = chunk_duration_ms
        # Samples per outgoing frame (e.g. 320 samples at 16 kHz / 20 ms).
        self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
        self.input_device = input_device
        self.output_device = output_device
        # WebSocket connection
        self.ws = None
        self.running = False
        # Audio buffers
        self.audio_input_queue = queue.Queue()
        self.audio_output_buffer = b""  # Continuous buffer for smooth playback
        self.audio_output_lock = threading.Lock()
        # Statistics
        self.bytes_sent = 0
        self.bytes_received = 0
        # State
        self.is_recording = True
        self.is_playing = True
        # TTFB tracking (Time to First Byte)
        self.request_start_time = None
        self.first_audio_received = False
        # Interrupt handling - discard audio until next trackStart
        self._discard_audio = False
        self._audio_sequence = 0  # Track audio sequence to detect stale chunks
        # Verbose mode for streaming LLM responses
        self.verbose = False
    async def connect(self) -> None:
        """Connect to WebSocket server and start the session."""
        print(f"Connecting to {self.url}...")
        self.ws = await websockets.connect(self.url)
        self.running = True
        print("Connected!")
        # Send invite command
        # NOTE(review): this uses the legacy "invite" command; the protocol doc
        # in this repo describes a v1 hello/session.start handshake — confirm
        # which schema the target server build accepts.
        await self.send_command({
            "command": "invite",
            "option": {
                "codec": "pcm",
                "sampleRate": self.sample_rate
            }
        })
    async def send_command(self, cmd: dict) -> None:
        """Send JSON command to server (no-op if not connected)."""
        if self.ws:
            await self.ws.send(json.dumps(cmd))
            print(f"→ Command: {cmd.get('command', 'unknown')}")
    async def send_chat(self, text: str) -> None:
        """Send chat message (text input)."""
        # Reset TTFB tracking for new request
        self.request_start_time = time.time()
        self.first_audio_received = False
        await self.send_command({
            "command": "chat",
            "text": text
        })
        print(f"→ Chat: {text}")
    async def send_interrupt(self) -> None:
        """Send interrupt command."""
        await self.send_command({
            "command": "interrupt"
        })
    async def send_hangup(self, reason: str = "User quit") -> None:
        """Send hangup command."""
        await self.send_command({
            "command": "hangup",
            "reason": reason
        })
    def _audio_input_callback(self, indata, frames, time, status):
        """Callback for audio input (microphone).

        Runs on the sounddevice audio thread; only hands the converted chunk
        to a thread-safe queue. NOTE(review): the `time` parameter (part of
        the sounddevice callback signature) shadows the `time` module here;
        it is unused in this body, so this is harmless but worth renaming.
        """
        if status:
            print(f"Input status: {status}")
        if self.is_recording and self.running:
            # Convert float32 [-1, 1] to 16-bit PCM bytes
            audio_data = (indata[:, 0] * 32767).astype(np.int16).tobytes()
            self.audio_input_queue.put(audio_data)
    def _add_audio_to_buffer(self, audio_data: bytes):
        """Add audio data to playback buffer (thread-safe)."""
        with self.audio_output_lock:
            self.audio_output_buffer += audio_data
    def _playback_thread_func(self):
        """Thread function for continuous audio playback.

        Pulls 50ms chunks from the shared buffer and writes them to the
        output stream, emitting silence when the buffer runs dry so the
        stream clock keeps advancing.
        """
        import time  # NOTE(review): unused local import — safe to drop.
        # Chunk size: 50ms of audio
        chunk_samples = int(self.sample_rate * 0.05)
        chunk_bytes = chunk_samples * 2
        print(f"Audio playback thread started (device: {self.output_device or 'default'})")
        try:
            # Create output stream with callback
            with sd.OutputStream(
                samplerate=self.sample_rate,
                channels=1,
                dtype='int16',
                blocksize=chunk_samples,
                device=self.output_device,
                latency='low'
            ) as stream:
                while self.running:
                    # Get audio from buffer
                    with self.audio_output_lock:
                        if len(self.audio_output_buffer) >= chunk_bytes:
                            audio_data = self.audio_output_buffer[:chunk_bytes]
                            self.audio_output_buffer = self.audio_output_buffer[chunk_bytes:]
                        else:
                            # Not enough audio - output silence
                            audio_data = b'\x00' * chunk_bytes
                    # Convert to numpy array and write to stream
                    samples = np.frombuffer(audio_data, dtype=np.int16).reshape(-1, 1)
                    stream.write(samples)
        except Exception as e:
            print(f"Playback thread error: {e}")
            import traceback
            traceback.print_exc()
    async def _playback_task(self):
        """Start playback thread and monitor it."""
        # Run playback in a dedicated thread for reliable timing
        playback_thread = threading.Thread(target=self._playback_thread_func, daemon=True)
        playback_thread.start()
        # Wait for client to stop
        while self.running and playback_thread.is_alive():
            await asyncio.sleep(0.1)
        print("Audio playback stopped")
    async def audio_sender(self) -> None:
        """Send audio from microphone to server."""
        while self.running:
            try:
                # Get audio from queue with timeout; queue.Empty raised inside
                # the executor propagates through the await and is caught here.
                try:
                    audio_data = await asyncio.get_event_loop().run_in_executor(
                        None, lambda: self.audio_input_queue.get(timeout=0.1)
                    )
                except queue.Empty:
                    continue
                # Send to server
                if self.ws and self.is_recording:
                    await self.ws.send(audio_data)
                    self.bytes_sent += len(audio_data)
            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"Audio sender error: {e}")
                break
    async def receiver(self) -> None:
        """Receive messages from server.

        Binary frames are PCM audio for playback; text frames are JSON events
        dispatched to `_handle_event`.
        """
        try:
            while self.running:
                try:
                    message = await asyncio.wait_for(self.ws.recv(), timeout=0.1)
                    if isinstance(message, bytes):
                        # Audio data received
                        self.bytes_received += len(message)
                        # Check if we should discard this audio (after interrupt)
                        if self._discard_audio:
                            duration_ms = len(message) / (self.sample_rate * 2) * 1000
                            print(f"← Audio: {duration_ms:.0f}ms (DISCARDED - waiting for new track)")
                            continue
                        if self.is_playing:
                            self._add_audio_to_buffer(message)
                            # Calculate and display TTFB for first audio packet
                            if not self.first_audio_received and self.request_start_time:
                                client_ttfb_ms = (time.time() - self.request_start_time) * 1000
                                self.first_audio_received = True
                                print(f"← [TTFB] Client first audio latency: {client_ttfb_ms:.0f}ms")
                            # Show progress (less verbose)
                            with self.audio_output_lock:
                                buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000
                            duration_ms = len(message) / (self.sample_rate * 2) * 1000
                            print(f"← Audio: {duration_ms:.0f}ms (buffer: {buffer_ms:.0f}ms)")
                    else:
                        # JSON event
                        event = json.loads(message)
                        await self._handle_event(event)
                except asyncio.TimeoutError:
                    continue
                except websockets.ConnectionClosed:
                    print("Connection closed")
                    self.running = False
                    break
        except asyncio.CancelledError:
            pass
        except Exception as e:
            print(f"Receiver error: {e}")
            self.running = False
    async def _handle_event(self, event: dict) -> None:
        """Handle incoming event (legacy `event`-keyed schema)."""
        event_type = event.get("event", "unknown")
        if event_type == "answer":
            print("← Session ready!")
        elif event_type == "speaking":
            print("← User speech detected")
        elif event_type == "silence":
            print("← User silence detected")
        elif event_type == "transcript":
            # Display user speech transcription
            text = event.get("text", "")
            is_final = event.get("isFinal", False)
            if is_final:
                # Clear the interim line and print final
                print(" " * 80, end="\r")  # Clear previous interim text
                print(f"→ You: {text}")
            else:
                # Interim result - show with indicator (overwrite same line)
                display_text = text[:60] + "..." if len(text) > 60 else text
                print(f" [listening] {display_text}".ljust(80), end="\r")
        elif event_type == "ttfb":
            # Server-side TTFB event
            latency_ms = event.get("latencyMs", 0)
            print(f"← [TTFB] Server reported latency: {latency_ms}ms")
        elif event_type == "llmResponse":
            # LLM text response
            text = event.get("text", "")
            is_final = event.get("isFinal", False)
            if is_final:
                # Print final LLM response
                print(f"← AI: {text}")
            elif self.verbose:
                # Show streaming chunks only in verbose mode
                display_text = text[:60] + "..." if len(text) > 60 else text
                print(f" [streaming] {display_text}")
        elif event_type == "trackStart":
            print("← Bot started speaking")
            # IMPORTANT: Accept audio again after trackStart
            self._discard_audio = False
            self._audio_sequence += 1
            # Reset TTFB tracking for voice responses (when no chat was sent)
            if self.request_start_time is None:
                self.request_start_time = time.time()
                self.first_audio_received = False
            # Clear any old audio in buffer
            with self.audio_output_lock:
                self.audio_output_buffer = b""
        elif event_type == "trackEnd":
            print("← Bot finished speaking")
            # Reset TTFB tracking after response completes
            self.request_start_time = None
            self.first_audio_received = False
        elif event_type == "interrupt":
            print("← Bot interrupted!")
            # IMPORTANT: Discard all audio until next trackStart
            self._discard_audio = True
            # Clear audio buffer immediately
            with self.audio_output_lock:
                buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000
                self.audio_output_buffer = b""
            print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)")
        elif event_type == "error":
            print(f"← Error: {event.get('error')}")
        elif event_type == "hangup":
            print(f"← Hangup: {event.get('reason')}")
            self.running = False
        else:
            print(f"← Event: {event_type}")
    async def interactive_mode(self) -> None:
        """Run interactive mode for text chat."""
        print("\n" + "=" * 50)
        print("Voice Conversation Client")
        print("=" * 50)
        print("Speak into your microphone to talk to the AI.")
        print("Or type messages to send text.")
        print("")
        print("Commands:")
        print(" /quit - End conversation")
        print(" /mute - Mute microphone")
        print(" /unmute - Unmute microphone")
        print(" /interrupt - Interrupt AI speech")
        print(" /stats - Show statistics")
        print("=" * 50 + "\n")
        while self.running:
            try:
                # Blocking input() runs in the default executor so the event
                # loop (audio tasks) keeps running while we wait for the user.
                user_input = await asyncio.get_event_loop().run_in_executor(
                    None, input, ""
                )
                if not user_input:
                    continue
                # Handle commands
                if user_input.startswith("/"):
                    cmd = user_input.lower().strip()
                    if cmd == "/quit":
                        await self.send_hangup("User quit")
                        break
                    elif cmd == "/mute":
                        self.is_recording = False
                        print("Microphone muted")
                    elif cmd == "/unmute":
                        self.is_recording = True
                        print("Microphone unmuted")
                    elif cmd == "/interrupt":
                        await self.send_interrupt()
                    elif cmd == "/stats":
                        print(f"Sent: {self.bytes_sent / 1024:.1f} KB")
                        print(f"Received: {self.bytes_received / 1024:.1f} KB")
                    else:
                        print(f"Unknown command: {cmd}")
                else:
                    # Send as chat message
                    await self.send_chat(user_input)
            except EOFError:
                break
            except Exception as e:
                print(f"Input error: {e}")
    async def run(self, chat_message: str = None, interactive: bool = True) -> None:
        """
        Run the client.
        Args:
            chat_message: Optional single chat message to send
            interactive: Whether to run in interactive mode
        """
        try:
            await self.connect()
            # Wait for answer
            await asyncio.sleep(0.5)
            # Start audio input stream
            print("Starting audio streams...")
            input_stream = sd.InputStream(
                samplerate=self.sample_rate,
                channels=1,
                dtype=np.float32,
                blocksize=self.chunk_samples,
                device=self.input_device,
                callback=self._audio_input_callback
            )
            input_stream.start()
            print("Audio streams started")
            # Start background tasks
            sender_task = asyncio.create_task(self.audio_sender())
            receiver_task = asyncio.create_task(self.receiver())
            playback_task = asyncio.create_task(self._playback_task())
            if chat_message:
                # Send single message and wait
                await self.send_chat(chat_message)
                await asyncio.sleep(15)
            elif interactive:
                # Run interactive mode
                await self.interactive_mode()
            else:
                # Just wait
                while self.running:
                    await asyncio.sleep(0.1)
            # Cleanup
            self.running = False
            sender_task.cancel()
            receiver_task.cancel()
            playback_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass
            try:
                await receiver_task
            except asyncio.CancelledError:
                pass
            try:
                await playback_task
            except asyncio.CancelledError:
                pass
            input_stream.stop()
        except ConnectionRefusedError:
            print(f"Error: Could not connect to {self.url}")
            print("Make sure the server is running.")
        except Exception as e:
            print(f"Error: {e}")
        finally:
            await self.close()
    async def close(self) -> None:
        """Close the connection and print session totals."""
        self.running = False
        if self.ws:
            await self.ws.close()
        print(f"\nSession ended")
        print(f" Total sent: {self.bytes_sent / 1024:.1f} KB")
        print(f" Total received: {self.bytes_received / 1024:.1f} KB")
def list_devices():
    """Print an indexed table of all audio devices sounddevice can see."""
    print("\nAvailable audio devices:")
    print("-" * 60)
    default_in = sd.default.device[0]
    default_out = sd.default.device[1]
    for idx, info in enumerate(sd.query_devices()):
        # Build the IN/OUT capability label for this device.
        tags = []
        if info['max_input_channels'] > 0:
            tags.append("IN")
        if info['max_output_channels'] > 0:
            tags.append("OUT")
        io_label = "/".join(tags) if tags else "N/A"
        marker = ""
        if idx == default_in:
            marker += " [DEFAULT INPUT]"
        if idx == default_out:
            marker += " [DEFAULT OUTPUT]"
        print(f" {idx:2d}: {info['name'][:40]:40s} ({io_label}){marker}")
    print("-" * 60)
async def main():
    """Parse CLI arguments and launch the microphone client."""
    parser = argparse.ArgumentParser(
        description="Microphone client for duplex voice conversation"
    )
    parser.add_argument("--url", default="ws://localhost:8000/ws", help="WebSocket server URL")
    parser.add_argument("--chat", help="Send a single chat message instead of using microphone")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Audio sample rate (default: 16000)")
    parser.add_argument("--input-device", type=int, help="Input device ID")
    parser.add_argument("--output-device", type=int, help="Output device ID")
    parser.add_argument("--list-devices", action="store_true", help="List available audio devices and exit")
    parser.add_argument("--no-interactive", action="store_true", help="Disable interactive mode")
    parser.add_argument("--verbose", "-v", action="store_true", help="Show streaming LLM response chunks")
    args = parser.parse_args()
    # Device listing is a standalone action — no connection needed.
    if args.list_devices:
        list_devices()
        return
    client = MicrophoneClient(
        url=args.url,
        sample_rate=args.sample_rate,
        input_device=args.input_device,
        output_device=args.output_device,
    )
    client.verbose = args.verbose
    await client.run(chat_message=args.chat, interactive=not args.no_interactive)
if __name__ == "__main__":
    # Script entry point: run the async client until completion or Ctrl+C.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nInterrupted by user")

285
examples/simple_client.py Normal file
View File

@@ -0,0 +1,285 @@
#!/usr/bin/env python3
"""
Simple WebSocket client for testing voice conversation.
Uses PyAudio for more reliable audio playback on Windows.
Usage:
python examples/simple_client.py
python examples/simple_client.py --text "Hello"
"""
import argparse
import asyncio
import json
import sys
import time
import wave
import io
try:
import numpy as np
except ImportError:
print("pip install numpy")
sys.exit(1)
try:
import websockets
except ImportError:
print("pip install websockets")
sys.exit(1)
# Try PyAudio first (more reliable on Windows)
try:
import pyaudio
PYAUDIO_AVAILABLE = True
except ImportError:
PYAUDIO_AVAILABLE = False
print("PyAudio not available, trying sounddevice...")
try:
import sounddevice as sd
SD_AVAILABLE = True
except ImportError:
SD_AVAILABLE = False
if not PYAUDIO_AVAILABLE and not SD_AVAILABLE:
print("Please install pyaudio or sounddevice:")
print(" pip install pyaudio")
print(" or: pip install sounddevice")
sys.exit(1)
class SimpleVoiceClient:
    """Simple voice client with reliable audio playback.

    Prefers PyAudio when installed (more dependable on Windows), falling
    back to sounddevice. Uses the legacy `invite`/`chat` command schema.
    """
    def __init__(self, url: str, sample_rate: int = 16000):
        self.url = url
        self.sample_rate = sample_rate
        self.ws = None
        self.running = False
        # Audio buffer
        self.audio_buffer = b""
        # PyAudio setup; the output stream is created lazily on first playback.
        if PYAUDIO_AVAILABLE:
            self.pa = pyaudio.PyAudio()
            self.stream = None
        # Stats
        self.bytes_received = 0
        # TTFB tracking (Time to First Byte)
        self.request_start_time = None
        self.first_audio_received = False
        # Interrupt handling - discard audio until next trackStart
        self._discard_audio = False
    async def connect(self):
        """Connect to server and send the session invite."""
        print(f"Connecting to {self.url}...")
        self.ws = await websockets.connect(self.url)
        self.running = True
        print("Connected!")
        # Send invite
        # NOTE(review): legacy "invite" command — the repo's protocol doc
        # describes a v1 hello/session.start handshake; confirm server build.
        await self.ws.send(json.dumps({
            "command": "invite",
            "option": {"codec": "pcm", "sampleRate": self.sample_rate}
        }))
        print("-> invite")
    async def send_chat(self, text: str):
        """Send chat message."""
        # Reset TTFB tracking for new request
        self.request_start_time = time.time()
        self.first_audio_received = False
        await self.ws.send(json.dumps({"command": "chat", "text": text}))
        print(f"-> chat: {text}")
    def play_audio(self, audio_data: bytes):
        """Play audio data immediately (blocking; intended for an executor)."""
        if len(audio_data) == 0:
            return
        if PYAUDIO_AVAILABLE:
            # Use PyAudio - more reliable on Windows
            if self.stream is None:
                self.stream = self.pa.open(
                    format=pyaudio.paInt16,
                    channels=1,
                    rate=self.sample_rate,
                    output=True,
                    frames_per_buffer=1024
                )
            self.stream.write(audio_data)
        elif SD_AVAILABLE:
            # Use sounddevice: convert s16le bytes to float32 in [-1, 1]
            samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32767.0
            sd.play(samples, self.sample_rate, blocking=True)
    async def receive_loop(self):
        """Receive and play audio; print JSON events as they arrive."""
        print("\nWaiting for response...")
        while self.running:
            try:
                msg = await asyncio.wait_for(self.ws.recv(), timeout=0.1)
                if isinstance(msg, bytes):
                    # Audio data
                    self.bytes_received += len(msg)
                    duration_ms = len(msg) / (self.sample_rate * 2) * 1000
                    # Check if we should discard this audio (after interrupt)
                    if self._discard_audio:
                        print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms) [DISCARDED]")
                        continue
                    # Calculate and display TTFB for first audio packet
                    if not self.first_audio_received and self.request_start_time:
                        client_ttfb_ms = (time.time() - self.request_start_time) * 1000
                        self.first_audio_received = True
                        print(f"<- [TTFB] Client first audio latency: {client_ttfb_ms:.0f}ms")
                    print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms)")
                    # Play immediately in executor to not block
                    loop = asyncio.get_event_loop()
                    await loop.run_in_executor(None, self.play_audio, msg)
                else:
                    # JSON event
                    event = json.loads(msg)
                    etype = event.get("event", "?")
                    if etype == "transcript":
                        # User speech transcription
                        text = event.get("text", "")
                        is_final = event.get("isFinal", False)
                        if is_final:
                            print(f"<- You said: {text}")
                        else:
                            print(f"<- [listening] {text}", end="\r")
                    elif etype == "ttfb":
                        # Server-side TTFB event
                        latency_ms = event.get("latencyMs", 0)
                        print(f"<- [TTFB] Server reported latency: {latency_ms}ms")
                    elif etype == "trackStart":
                        # New track starting - accept audio again
                        self._discard_audio = False
                        print(f"<- {etype}")
                    elif etype == "interrupt":
                        # Interrupt - discard audio until next trackStart
                        self._discard_audio = True
                        print(f"<- {etype} (discarding audio until new track)")
                    elif etype == "hangup":
                        print(f"<- {etype}")
                        self.running = False
                        break
                    else:
                        print(f"<- {etype}")
            except asyncio.TimeoutError:
                continue
            except websockets.ConnectionClosed:
                print("Connection closed")
                self.running = False
                break
    async def run(self, text: str = None):
        """Run the client.

        Args:
            text: Optional single message; when omitted, run interactive mode.
        """
        try:
            await self.connect()
            await asyncio.sleep(0.5)
            # Start receiver
            recv_task = asyncio.create_task(self.receive_loop())
            if text:
                await self.send_chat(text)
                # Wait for response
                await asyncio.sleep(30)
            else:
                # Interactive mode
                print("\nType a message and press Enter (or 'quit' to exit):")
                while self.running:
                    try:
                        user_input = await asyncio.get_event_loop().run_in_executor(
                            None, input, "> "
                        )
                        if user_input.lower() == 'quit':
                            break
                        if user_input.strip():
                            await self.send_chat(user_input)
                    except EOFError:
                        break
            self.running = False
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass
        finally:
            await self.close()
    async def close(self):
        """Close connections and release audio resources."""
        self.running = False
        if PYAUDIO_AVAILABLE:
            if self.stream:
                self.stream.stop_stream()
                self.stream.close()
            self.pa.terminate()
        if self.ws:
            await self.ws.close()
        print(f"\nTotal audio received: {self.bytes_received / 1024:.1f} KB")
def list_audio_devices():
    """Print the output-capable audio devices for every available backend."""
    print("\n=== Audio Devices ===")
    if PYAUDIO_AVAILABLE:
        backend = pyaudio.PyAudio()
        print("\nPyAudio devices:")
        for idx in range(backend.get_device_count()):
            dev = backend.get_device_info_by_index(idx)
            if dev['maxOutputChannels'] > 0:
                tag = " [DEFAULT]" if idx == backend.get_default_output_device_info()['index'] else ""
                print(f" {idx}: {dev['name']}{tag}")
        backend.terminate()
    if SD_AVAILABLE:
        print("\nSounddevice devices:")
        for idx, dev in enumerate(sd.query_devices()):
            if dev['max_output_channels'] > 0:
                tag = " [DEFAULT]" if idx == sd.default.device[1] else ""
                print(f" {idx}: {dev['name']}{tag}")
async def main():
    """CLI entry point: parse arguments and run the simple voice client."""
    cli = argparse.ArgumentParser(description="Simple voice client")
    cli.add_argument("--url", default="ws://localhost:8000/ws")
    cli.add_argument("--text", help="Send text and play response")
    cli.add_argument("--list-devices", action="store_true")
    cli.add_argument("--sample-rate", type=int, default=16000)
    opts = cli.parse_args()
    if opts.list_devices:
        # Just enumerate devices; no connection is made.
        list_audio_devices()
    else:
        await SimpleVoiceClient(opts.url, opts.sample_rate).run(opts.text)
if __name__ == "__main__":
    # Script entry point: drive the async client with asyncio.
    asyncio.run(main())

176
examples/test_websocket.py Normal file
View File

@@ -0,0 +1,176 @@
"""WebSocket endpoint test client.
Tests the /ws endpoint with sine wave or file audio streaming.
Based on reference/py-active-call/exec/test_ws_endpoint/test_ws.py
"""
import asyncio
import aiohttp
import json
import struct
import math
import argparse
import os
from datetime import datetime
# Configuration
SERVER_URL = "ws://localhost:8000/ws"
SAMPLE_RATE = 16000
FREQUENCY = 440  # test tone frequency in Hz
CHUNK_DURATION_MS = 20
# One 20 ms frame of 16 kHz mono s16le PCM is 640 bytes.
CHUNK_SIZE_BYTES = int(SAMPLE_RATE * 2 * (CHUNK_DURATION_MS / 1000.0))
def generate_sine_wave(duration_ms=1000):
    """Return `duration_ms` of a FREQUENCY-Hz sine tone (16 kHz mono PCM s16le)."""
    total_samples = int(SAMPLE_RATE * (duration_ms / 1000.0))
    pcm = bytearray()
    # Angular step factor; identical FP evaluation order to 2*pi*F*n/SR.
    two_pi_f = 2 * math.pi * FREQUENCY
    for n in range(total_samples):
        sample = int(32767.0 * math.sin(two_pi_f * n / SAMPLE_RATE))
        pcm += struct.pack('<h', sample)
    return pcm
async def receive_loop(ws, ready_event: asyncio.Event):
    """Drain server messages, logging each; set ready_event on session.started."""
    print("👂 Listening for server responses...")
    async for frame in ws:
        now = datetime.now().strftime("%H:%M:%S")
        if frame.type == aiohttp.WSMsgType.TEXT:
            try:
                payload = json.loads(frame.data)
            except json.JSONDecodeError:
                # Non-JSON text frame — log the raw prefix.
                print(f"[{now}] 📨 Text: {frame.data[:100]}...")
                continue
            etype = payload.get('type', 'Unknown')
            print(f"[{now}] 📨 Event: {etype} | {frame.data[:150]}...")
            if etype == "session.started":
                ready_event.set()
        elif frame.type == aiohttp.WSMsgType.BINARY:
            # Received audio chunk back (e.g., TTS or echo)
            print(f"[{now}] 🔊 Audio: {len(frame.data)} bytes", end="\r")
        elif frame.type == aiohttp.WSMsgType.CLOSED:
            print(f"\n[{now}] ❌ Socket Closed")
            break
        elif frame.type == aiohttp.WSMsgType.ERROR:
            print(f"\n[{now}] ⚠️ Socket Error")
            break
async def send_file_loop(ws, file_path):
    """Stream a raw PCM (or WAV) file to the server at real-time pace."""
    if not os.path.exists(file_path):
        print(f"❌ Error: File '{file_path}' not found.")
        return
    print(f"📂 Streaming file: {file_path} ...")
    with open(file_path, "rb") as audio_file:
        # NOTE(review): assumes a canonical 44-byte WAV header; files with
        # extended headers would leak a few junk bytes into the stream.
        if file_path.endswith('.wav'):
            audio_file.read(44)
        while chunk := audio_file.read(CHUNK_SIZE_BYTES):
            await ws.send_bytes(chunk)
            # Pace the upload to simulate live microphone capture.
            await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)
    print(f"\n✅ Finished streaming {file_path}")
async def send_sine_loop(ws):
    """Stream 5 seconds of generated sine-wave audio to the server."""
    print("🎙️ Starting Audio Stream (Sine Wave)...")
    tone = generate_sine_wave(5000)  # 5 s of 16 kHz s16le PCM
    for offset in range(0, len(tone), CHUNK_SIZE_BYTES):
        await ws.send_bytes(tone[offset:offset + CHUNK_SIZE_BYTES])
        # Pace the upload to simulate real-time playback.
        await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)
    print("\n✅ Finished streaming test audio.")
async def run_client(url, file_path=None, use_sine=False):
    """Run the WebSocket test client.

    Connects to `url`, performs the v1 hello/session.start handshake, streams
    audio (sine wave by default, or a file), then stops the session.

    Args:
        url: WebSocket endpoint URL.
        file_path: Optional PCM/WAV file to stream instead of the sine wave.
        use_sine: Force sine-wave streaming even when file_path is given.
    """
    session = aiohttp.ClientSession()
    recv_task = None  # tracked in outer scope so finally can always reap it
    try:
        print(f"🔌 Connecting to {url}...")
        async with session.ws_connect(url) as ws:
            print("✅ Connected!")
            session_ready = asyncio.Event()
            recv_task = asyncio.create_task(receive_loop(ws, session_ready))
            # Send v1 hello + session.start handshake
            await ws.send_json({"type": "hello", "version": "v1"})
            await ws.send_json({
                "type": "session.start",
                "audio": {
                    "encoding": "pcm_s16le",
                    "sample_rate_hz": SAMPLE_RATE,
                    "channels": 1
                }
            })
            print("📤 Sent v1 hello/session.start")
            await asyncio.wait_for(session_ready.wait(), timeout=8)
            # Select sender based on args
            if use_sine:
                await send_sine_loop(ws)
            elif file_path:
                await send_file_loop(ws, file_path)
            else:
                # Default to sine wave
                await send_sine_loop(ws)
            await ws.send_json({"type": "session.stop", "reason": "test_complete"})
            # Give the server a moment to deliver trailing events.
            await asyncio.sleep(1)
    except aiohttp.ClientConnectorError:
        print(f"❌ Connection Failed. Is the server running at {url}?")
    except asyncio.TimeoutError:
        print("❌ Timeout waiting for session.started")
    except Exception as e:
        print(f"❌ Error: {e}")
    finally:
        # BUGFIX: reap the receiver task on every exit path. Previously it was
        # cancelled only on success, so a handshake timeout or send failure
        # left it pending ("Task was destroyed but it is pending!").
        if recv_task is not None:
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass
        await session.close()
if __name__ == "__main__":
    # CLI: choose endpoint and audio source, then run the async test client.
    parser = argparse.ArgumentParser(description="WebSocket Audio Test Client")
    parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
    parser.add_argument("--file", help="Path to PCM/WAV file to stream")
    parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
    args = parser.parse_args()
    try:
        asyncio.run(run_client(args.url, args.file, args.sine))
    except KeyboardInterrupt:
        print("\n👋 Client stopped.")

504
examples/wav_client.py Normal file
View File

@@ -0,0 +1,504 @@
#!/usr/bin/env python3
"""
WAV file client for testing duplex voice conversation.
This client reads audio from a WAV file, sends it to the server,
and saves the AI's voice response to an output WAV file.
Usage:
python examples/wav_client.py --input input.wav --output response.wav
python examples/wav_client.py --input input.wav --output response.wav --url ws://localhost:8000/ws
python examples/wav_client.py --input input.wav --output response.wav --wait-time 10
python wav_client.py --input ../data/audio_examples/two_utterances.wav -o response.wav
Requirements:
pip install soundfile websockets numpy
"""
import argparse
import asyncio
import json
import sys
import time
import wave
from pathlib import Path
try:
import numpy as np
except ImportError:
print("Please install numpy: pip install numpy")
sys.exit(1)
try:
import soundfile as sf
except ImportError:
print("Please install soundfile: pip install soundfile")
sys.exit(1)
try:
import websockets
except ImportError:
print("Please install websockets: pip install websockets")
sys.exit(1)
class WavFileClient:
"""
WAV file client for voice conversation testing.
Features:
- Read audio from WAV file
- Send audio to WebSocket server
- Receive and save response audio
- Event logging
"""
def __init__(
self,
url: str,
input_file: str,
output_file: str,
sample_rate: int = 16000,
chunk_duration_ms: int = 20,
wait_time: float = 15.0,
verbose: bool = False
):
"""
Initialize WAV file client.
Args:
url: WebSocket server URL
input_file: Input WAV file path
output_file: Output WAV file path
sample_rate: Audio sample rate (Hz)
chunk_duration_ms: Audio chunk duration (ms) for sending
wait_time: Time to wait for response after sending (seconds)
verbose: Enable verbose output
"""
self.url = url
self.input_file = Path(input_file)
self.output_file = Path(output_file)
self.sample_rate = sample_rate
self.chunk_duration_ms = chunk_duration_ms
self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
self.wait_time = wait_time
self.verbose = verbose
# WebSocket connection
self.ws = None
self.running = False
# Audio buffers
self.received_audio = bytearray()
# Statistics
self.bytes_sent = 0
self.bytes_received = 0
# TTFB tracking (per response)
self.send_start_time = None
self.response_start_time = None # set on each trackStart
self.waiting_for_first_audio = False
self.ttfb_ms = None # last TTFB for summary
self.ttfb_list = [] # TTFB for each response
# State tracking
self.track_started = False
self.track_ended = False
self.send_completed = False
# Events log
self.events_log = []
def log_event(self, direction: str, message: str):
"""Log an event with timestamp."""
timestamp = time.time()
self.events_log.append({
"timestamp": timestamp,
"direction": direction,
"message": message
})
# Handle encoding errors on Windows
try:
print(f"{direction} {message}")
except UnicodeEncodeError:
# Replace problematic characters for console output
safe_message = message.encode('ascii', errors='replace').decode('ascii')
print(f"{direction} {safe_message}")
async def connect(self) -> None:
"""Connect to WebSocket server."""
self.log_event("", f"Connecting to {self.url}...")
self.ws = await websockets.connect(self.url)
self.running = True
self.log_event("", "Connected!")
# Send invite command
await self.send_command({
"command": "invite",
"option": {
"codec": "pcm",
"sampleRate": self.sample_rate
}
})
async def send_command(self, cmd: dict) -> None:
"""Send JSON command to server."""
if self.ws:
await self.ws.send(json.dumps(cmd))
self.log_event("", f"Command: {cmd.get('command', 'unknown')}")
async def send_hangup(self, reason: str = "Session complete") -> None:
"""Send hangup command."""
await self.send_command({
"command": "hangup",
"reason": reason
})
def load_wav_file(self) -> tuple[np.ndarray, int]:
"""
Load and prepare WAV file for sending.
Returns:
Tuple of (audio_data as int16 numpy array, original sample rate)
"""
if not self.input_file.exists():
raise FileNotFoundError(f"Input file not found: {self.input_file}")
# Load audio file
audio_data, file_sample_rate = sf.read(self.input_file)
self.log_event("", f"Loaded: {self.input_file}")
self.log_event("", f" Original sample rate: {file_sample_rate} Hz")
self.log_event("", f" Duration: {len(audio_data) / file_sample_rate:.2f}s")
# Convert stereo to mono if needed
if len(audio_data.shape) > 1:
audio_data = audio_data.mean(axis=1)
self.log_event("", " Converted stereo to mono")
# Resample if needed
if file_sample_rate != self.sample_rate:
# Simple resampling using numpy
duration = len(audio_data) / file_sample_rate
num_samples = int(duration * self.sample_rate)
indices = np.linspace(0, len(audio_data) - 1, num_samples)
audio_data = np.interp(indices, np.arange(len(audio_data)), audio_data)
self.log_event("", f" Resampled to {self.sample_rate} Hz")
# Convert to int16
if audio_data.dtype != np.int16:
# Normalize to [-1, 1] if needed
max_val = np.max(np.abs(audio_data))
if max_val > 1.0:
audio_data = audio_data / max_val
audio_data = (audio_data * 32767).astype(np.int16)
self.log_event("", f" Prepared: {len(audio_data)} samples ({len(audio_data)/self.sample_rate:.2f}s)")
return audio_data, file_sample_rate
async def audio_sender(self, audio_data: np.ndarray) -> None:
"""Send audio data to server in chunks."""
total_samples = len(audio_data)
chunk_size = self.chunk_samples
sent_samples = 0
self.send_start_time = time.time()
self.log_event("", f"Starting audio transmission ({total_samples} samples)...")
while sent_samples < total_samples and self.running:
# Get next chunk
end_sample = min(sent_samples + chunk_size, total_samples)
chunk = audio_data[sent_samples:end_sample]
chunk_bytes = chunk.tobytes()
# Send to server
if self.ws:
await self.ws.send(chunk_bytes)
self.bytes_sent += len(chunk_bytes)
sent_samples = end_sample
# Progress logging (every 500ms worth of audio)
if self.verbose and sent_samples % (self.sample_rate // 2) == 0:
progress = (sent_samples / total_samples) * 100
print(f" Sending: {progress:.0f}%", end="\r")
# Delay to simulate real-time streaming
# Server expects audio at real-time pace for VAD/ASR to work properly
await asyncio.sleep(self.chunk_duration_ms / 1000)
self.send_completed = True
elapsed = time.time() - self.send_start_time
self.log_event("", f"Audio transmission complete ({elapsed:.2f}s, {self.bytes_sent/1024:.1f} KB)")
async def receiver(self) -> None:
    """Receive server messages until the client stops.

    Binary frames are 16-bit PCM response audio appended to
    self.received_audio; text frames are JSON events routed to
    _handle_event(). Client-side TTFB is measured from trackStart (armed in
    _handle_event) to the first audio frame of each response.
    """
    try:
        while self.running:
            try:
                # Short recv timeout so the loop re-checks self.running promptly.
                message = await asyncio.wait_for(self.ws.recv(), timeout=0.1)
                if isinstance(message, bytes):
                    # Audio data received
                    self.bytes_received += len(message)
                    self.received_audio.extend(message)
                    # Calculate TTFB on first audio of each response
                    if self.waiting_for_first_audio and self.response_start_time is not None:
                        ttfb_ms = (time.time() - self.response_start_time) * 1000
                        self.ttfb_ms = ttfb_ms
                        self.ttfb_list.append(ttfb_ms)
                        self.waiting_for_first_audio = False
                        self.log_event("", f"[TTFB] First audio latency: {ttfb_ms:.0f}ms")
                    # Log progress
                    # 2 bytes per sample (s16le mono) -> milliseconds of audio.
                    duration_ms = len(message) / (self.sample_rate * 2) * 1000
                    total_ms = len(self.received_audio) / (self.sample_rate * 2) * 1000
                    if self.verbose:
                        print(f"← Audio: +{duration_ms:.0f}ms (total: {total_ms:.0f}ms)", end="\r")
                else:
                    # JSON event
                    event = json.loads(message)
                    await self._handle_event(event)
            except asyncio.TimeoutError:
                continue
            except websockets.ConnectionClosed:
                self.log_event("", "Connection closed")
                self.running = False
                break
    except asyncio.CancelledError:
        # Normal cancellation during shutdown in run(); exit quietly.
        pass
    except Exception as e:
        self.log_event("!", f"Receiver error: {e}")
        self.running = False
async def _handle_event(self, event: dict) -> None:
    """Handle one incoming JSON event from the server.

    Prints user-facing status lines and updates the track/TTFB bookkeeping
    consumed by run() and receiver(); unknown event types are logged as-is.
    """
    event_type = event.get("event", "unknown")
    if event_type == "answer":
        self.log_event("", "Session ready!")
    elif event_type == "speaking":
        self.log_event("", "Speech detected")
    elif event_type == "silence":
        self.log_event("", "Silence detected")
    elif event_type == "transcript":
        # ASR transcript (interim = asrDelta-style, final = asrFinal-style)
        text = event.get("text", "")
        is_final = event.get("isFinal", False)
        if is_final:
            # Clear interim line and print final
            print(" " * 80, end="\r")
            self.log_event("", f"→ You: {text}")
        else:
            # Interim result - show with indicator (overwrite same line, as in mic_client)
            display_text = text[:60] + "..." if len(text) > 60 else text
            print(f" [listening] {display_text}".ljust(80), end="\r")
    elif event_type == "ttfb":
        latency_ms = event.get("latencyMs", 0)
        self.log_event("", f"[TTFB] Server latency: {latency_ms}ms")
    elif event_type == "llmResponse":
        text = event.get("text", "")
        is_final = event.get("isFinal", False)
        if is_final:
            self.log_event("", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}")
        elif self.verbose:
            # Show streaming chunks only in verbose mode
            self.log_event("", f"LLM: {text}")
    elif event_type == "trackStart":
        # Bot audio about to start: arm the client-side TTFB timer that
        # receiver() completes on the first binary frame.
        self.track_started = True
        self.response_start_time = time.time()
        self.waiting_for_first_audio = True
        self.log_event("", "Bot started speaking")
    elif event_type == "trackEnd":
        # run() uses track_ended (with send_completed) to finish early.
        self.track_ended = True
        self.log_event("", "Bot finished speaking")
    elif event_type == "interrupt":
        self.log_event("", "Bot interrupted!")
    elif event_type == "error":
        self.log_event("!", f"Error: {event.get('error')}")
    elif event_type == "hangup":
        # Server ended the session; stop the main loops.
        self.log_event("", f"Hangup: {event.get('reason')}")
        self.running = False
    else:
        self.log_event("", f"Event: {event_type}")
def save_output_wav(self) -> None:
    """Write all received PCM bytes to self.output_file as 16-bit mono WAV."""
    if not self.received_audio:
        self.log_event("!", "No audio received to save")
        return
    # Reinterpret the raw byte buffer as little-endian 16-bit samples.
    pcm = np.frombuffer(bytes(self.received_audio), dtype=np.int16)
    # Make sure the destination directory exists before writing.
    self.output_file.parent.mkdir(parents=True, exist_ok=True)
    with wave.open(str(self.output_file), 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(self.sample_rate)
        wav_file.writeframes(pcm.tobytes())
    seconds = len(pcm) / self.sample_rate
    self.log_event("", f"Saved output: {self.output_file}")
    self.log_event("", f" Duration: {seconds:.2f}s ({len(pcm)} samples)")
    self.log_event("", f" Size: {len(self.received_audio)/1024:.1f} KB")
async def run(self) -> None:
    """Run the WAV-file test end to end.

    Loads the input file, connects, streams the audio while a background
    task collects events/response audio, waits for the bot's reply, then
    saves the response WAV and prints a summary. Exits the process with
    status 1 on fatal errors; always closes the connection.
    """
    try:
        # Load input WAV file
        audio_data, _ = self.load_wav_file()
        # Connect to server
        await self.connect()
        # Wait for answer
        # (brief grace period so the server can finish session setup)
        await asyncio.sleep(0.5)
        # Start receiver task
        receiver_task = asyncio.create_task(self.receiver())
        # Send audio
        await self.audio_sender(audio_data)
        # Wait for response
        self.log_event("", f"Waiting {self.wait_time}s for response...")
        wait_start = time.time()
        while self.running and (time.time() - wait_start) < self.wait_time:
            # Check if track has ended (response complete)
            if self.track_ended and self.send_completed:
                # Give a little extra time for any remaining audio
                await asyncio.sleep(1.0)
                break
            await asyncio.sleep(0.1)
        # Cleanup: stop the loop, then cancel and await the receiver task.
        self.running = False
        receiver_task.cancel()
        try:
            await receiver_task
        except asyncio.CancelledError:
            pass
        # Save output
        self.save_output_wav()
        # Print summary
        self._print_summary()
    except FileNotFoundError as e:
        print(f"Error: {e}")
        sys.exit(1)
    except ConnectionRefusedError:
        print(f"Error: Could not connect to {self.url}")
        print("Make sure the server is running.")
        sys.exit(1)
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    finally:
        await self.close()
def _print_summary(self):
"""Print session summary."""
print("\n" + "=" * 50)
print("Session Summary")
print("=" * 50)
print(f" Input file: {self.input_file}")
print(f" Output file: {self.output_file}")
print(f" Bytes sent: {self.bytes_sent / 1024:.1f} KB")
print(f" Bytes received: {self.bytes_received / 1024:.1f} KB")
if self.ttfb_list:
if len(self.ttfb_list) == 1:
print(f" TTFB: {self.ttfb_list[0]:.0f} ms")
else:
print(f" TTFB (per response): {', '.join(f'{t:.0f}ms' for t in self.ttfb_list)}")
if self.received_audio:
duration = len(self.received_audio) / (self.sample_rate * 2)
print(f" Response duration: {duration:.2f}s")
print("=" * 50)
async def close(self) -> None:
    """Stop the client loop and close the WebSocket.

    Close errors are ignored (the connection may already be dead), but the
    original bare ``except:`` is narrowed to ``except Exception`` so that
    BaseException subclasses — notably asyncio.CancelledError and
    KeyboardInterrupt — still propagate during shutdown.
    """
    self.running = False
    if self.ws:
        try:
            await self.ws.close()
        except Exception:
            # Best-effort close only; nothing useful to do on failure.
            pass
async def main():
    """Parse CLI arguments and run a single WAV-file test session."""
    arg_parser = argparse.ArgumentParser(
        description="WAV file client for testing duplex voice conversation"
    )
    # Required I/O paths.
    arg_parser.add_argument("--input", "-i", required=True,
                            help="Input WAV file path")
    arg_parser.add_argument("--output", "-o", required=True,
                            help="Output WAV file path for response")
    # Connection and audio-format tuning.
    arg_parser.add_argument("--url", default="ws://localhost:8000/ws",
                            help="WebSocket server URL (default: ws://localhost:8000/ws)")
    arg_parser.add_argument("--sample-rate", type=int, default=16000,
                            help="Target sample rate for audio (default: 16000)")
    arg_parser.add_argument("--chunk-duration", type=int, default=20,
                            help="Chunk duration in ms for sending (default: 20)")
    arg_parser.add_argument("--wait-time", "-w", type=float, default=15.0,
                            help="Time to wait for response after sending (default: 15.0)")
    arg_parser.add_argument("--verbose", "-v", action="store_true",
                            help="Enable verbose output")
    opts = arg_parser.parse_args()
    client = WavFileClient(
        url=opts.url,
        input_file=opts.input,
        output_file=opts.output,
        sample_rate=opts.sample_rate,
        chunk_duration_ms=opts.chunk_duration,
        wait_time=opts.wait_time,
        verbose=opts.verbose,
    )
    await client.run()
if __name__ == "__main__":
    # Script entry point: Ctrl-C exits with a friendly message instead of a
    # traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nInterrupted by user")

766
examples/web_client.html Normal file
View File

@@ -0,0 +1,766 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Duplex Voice Web Client</title>
<style>
@import url("https://fonts.googleapis.com/css2?family=Fraunces:opsz,wght@9..144,300;9..144,500;9..144,700&family=Recursive:wght@300;400;600;700&display=swap");
:root {
--bg: #0b0b0f;
--panel: #14141c;
--panel-2: #101018;
--ink: #f2f3f7;
--muted: #a7acba;
--accent: #ff6b6b;
--accent-2: #ffd166;
--good: #2dd4bf;
--bad: #f87171;
--grid: rgba(255, 255, 255, 0.06);
--shadow: 0 20px 60px rgba(0, 0, 0, 0.45);
}
* {
box-sizing: border-box;
}
html,
body {
height: 100%;
margin: 0;
color: var(--ink);
background: radial-gradient(1200px 600px at 20% -10%, #1d1d2a 0%, transparent 60%),
radial-gradient(800px 800px at 110% 10%, #20203a 0%, transparent 50%),
var(--bg);
font-family: "Recursive", ui-sans-serif, system-ui, -apple-system, "Segoe UI", sans-serif;
}
.noise {
position: fixed;
inset: 0;
background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='120' height='120' viewBox='0 0 120 120'><filter id='n'><feTurbulence type='fractalNoise' baseFrequency='0.9' numOctaves='2' stitchTiles='stitch'/></filter><rect width='120' height='120' filter='url(%23n)' opacity='0.06'/></svg>");
pointer-events: none;
mix-blend-mode: soft-light;
}
header {
padding: 32px 28px 18px;
border-bottom: 1px solid var(--grid);
}
h1 {
font-family: "Fraunces", serif;
font-weight: 600;
margin: 0 0 6px;
letter-spacing: 0.4px;
}
.subtitle {
color: var(--muted);
font-size: 0.95rem;
}
main {
display: grid;
grid-template-columns: 1.1fr 1.4fr;
gap: 24px;
padding: 24px 28px 40px;
}
.panel {
background: linear-gradient(180deg, rgba(255, 255, 255, 0.02), transparent),
var(--panel);
border: 1px solid var(--grid);
border-radius: 16px;
padding: 20px;
box-shadow: var(--shadow);
}
.panel h2 {
margin: 0 0 12px;
font-size: 1.05rem;
font-weight: 600;
}
.stack {
display: grid;
gap: 12px;
}
label {
display: block;
font-size: 0.85rem;
color: var(--muted);
margin-bottom: 6px;
}
input,
select,
button,
textarea {
font-family: inherit;
}
input,
select,
textarea {
width: 100%;
padding: 10px 12px;
border-radius: 10px;
border: 1px solid var(--grid);
background: var(--panel-2);
color: var(--ink);
outline: none;
}
textarea {
min-height: 80px;
resize: vertical;
}
.row {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 12px;
}
.btn-row {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
button {
border: none;
border-radius: 999px;
padding: 10px 16px;
font-weight: 600;
background: var(--ink);
color: #111;
cursor: pointer;
transition: transform 0.2s ease, box-shadow 0.2s ease;
}
button.secondary {
background: transparent;
color: var(--ink);
border: 1px solid var(--grid);
}
button.accent {
background: linear-gradient(120deg, var(--accent), #f97316);
color: #0b0b0f;
}
button.good {
background: linear-gradient(120deg, var(--good), #22c55e);
color: #07261f;
}
button.bad {
background: linear-gradient(120deg, var(--bad), #f97316);
color: #2a0b0b;
}
button:active {
transform: translateY(1px) scale(0.99);
}
.status {
display: flex;
align-items: center;
gap: 12px;
padding: 12px;
background: rgba(255, 255, 255, 0.03);
border-radius: 12px;
border: 1px dashed var(--grid);
font-size: 0.9rem;
}
.dot {
width: 10px;
height: 10px;
border-radius: 999px;
background: var(--bad);
box-shadow: 0 0 12px rgba(248, 113, 113, 0.5);
}
.dot.on {
background: var(--good);
box-shadow: 0 0 12px rgba(45, 212, 191, 0.7);
}
.log {
height: 320px;
overflow: auto;
padding: 12px;
background: #0d0d14;
border-radius: 12px;
border: 1px solid var(--grid);
font-size: 0.85rem;
line-height: 1.4;
}
.chat {
height: 260px;
overflow: auto;
padding: 12px;
background: #0d0d14;
border-radius: 12px;
border: 1px solid var(--grid);
font-size: 0.9rem;
line-height: 1.45;
}
.chat-entry {
padding: 8px 10px;
margin-bottom: 8px;
border-radius: 10px;
background: rgba(255, 255, 255, 0.04);
border: 1px solid rgba(255, 255, 255, 0.06);
}
.chat-entry.user {
border-left: 3px solid var(--accent-2);
}
.chat-entry.ai {
border-left: 3px solid var(--good);
}
.chat-entry.interim {
opacity: 0.7;
font-style: italic;
}
.log-entry {
padding: 6px 8px;
border-bottom: 1px dashed rgba(255, 255, 255, 0.06);
}
.log-entry:last-child {
border-bottom: none;
}
.tag {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 2px 8px;
border-radius: 999px;
font-size: 0.7rem;
text-transform: uppercase;
letter-spacing: 0.6px;
background: rgba(255, 255, 255, 0.08);
color: var(--muted);
}
.tag.event {
background: rgba(255, 107, 107, 0.18);
color: #ffc1c1;
}
.tag.audio {
background: rgba(45, 212, 191, 0.2);
color: #c5f9f0;
}
.tag.sys {
background: rgba(255, 209, 102, 0.2);
color: #ffefb0;
}
.muted {
color: var(--muted);
}
footer {
padding: 0 28px 28px;
color: var(--muted);
font-size: 0.8rem;
}
@media (max-width: 1100px) {
main {
grid-template-columns: 1fr;
}
.log {
height: 360px;
}
.chat {
height: 260px;
}
}
</style>
</head>
<body>
<div class="noise"></div>
<header>
<h1>Duplex Voice Client</h1>
<div class="subtitle">Browser client for the WebSocket duplex pipeline. Device selection + event logging.</div>
</header>
<main>
<section class="panel stack">
<h2>Connection</h2>
<div>
<label for="wsUrl">WebSocket URL</label>
<input id="wsUrl" value="ws://localhost:8000/ws" />
</div>
<div class="btn-row">
<button class="accent" id="connectBtn">Connect</button>
<button class="secondary" id="disconnectBtn">Disconnect</button>
</div>
<div class="status">
<div id="statusDot" class="dot"></div>
<div>
<div id="statusText">Disconnected</div>
<div class="muted" id="statusSub">Waiting for connection</div>
</div>
</div>
<h2>Devices</h2>
<div class="row">
<div>
<label for="inputSelect">Input (Mic)</label>
<select id="inputSelect"></select>
</div>
<div>
<label for="outputSelect">Output (Speaker)</label>
<select id="outputSelect"></select>
</div>
</div>
<div class="btn-row">
<button class="secondary" id="refreshDevicesBtn">Refresh Devices</button>
<button class="good" id="startMicBtn">Start Mic</button>
<button class="secondary" id="stopMicBtn">Stop Mic</button>
</div>
<h2>Chat</h2>
<div class="stack">
<textarea id="chatInput" placeholder="Type a message, press Send"></textarea>
<div class="btn-row">
<button class="accent" id="sendChatBtn">Send Chat</button>
<button class="secondary" id="clearLogBtn">Clear Log</button>
</div>
</div>
</section>
<section class="stack">
<div class="panel stack">
<h2>Chat History</h2>
<div class="chat" id="chatHistory"></div>
</div>
<div class="panel stack">
<h2>Event Log</h2>
<div class="log" id="log"></div>
</div>
</section>
</main>
<footer>
Output device selection requires HTTPS + a browser that supports <code>setSinkId</code>.
Audio is sent as 16-bit PCM @ 16 kHz, matching <code>examples/mic_client.py</code>.
</footer>
<audio id="audioOut" autoplay></audio>
<script>
// --- DOM element handles ---
const wsUrl = document.getElementById("wsUrl");
const connectBtn = document.getElementById("connectBtn");
const disconnectBtn = document.getElementById("disconnectBtn");
const inputSelect = document.getElementById("inputSelect");
const outputSelect = document.getElementById("outputSelect");
const startMicBtn = document.getElementById("startMicBtn");
const stopMicBtn = document.getElementById("stopMicBtn");
const refreshDevicesBtn = document.getElementById("refreshDevicesBtn");
const sendChatBtn = document.getElementById("sendChatBtn");
const clearLogBtn = document.getElementById("clearLogBtn");
const chatInput = document.getElementById("chatInput");
const logEl = document.getElementById("log");
const chatHistory = document.getElementById("chatHistory");
const statusDot = document.getElementById("statusDot");
const statusText = document.getElementById("statusText");
const statusSub = document.getElementById("statusSub");
const audioOut = document.getElementById("audioOut");
// --- Connection / capture / playback state ---
let ws = null;
let audioCtx = null;
let micStream = null;
let processor = null;
let micSource = null;
let playbackDest = null;
// Tail of the scheduled playback queue (AudioContext time, seconds).
let playbackTime = 0;
// When true, schedulePlayback drops frames (set on interrupt/barge-in).
let discardAudio = false;
let playbackSources = [];
// In-progress (interim) chat bubbles and their accumulated text.
let interimUserEl = null;
let interimAiEl = null;
let interimUserText = "";
let interimAiText = "";
// Outgoing/incoming PCM sample rate; fade-out time when cutting playback.
const targetSampleRate = 16000;
const playbackStopRampSec = 0.008;
function logLine(type, text, data) {
  // Append one timestamped entry (tag + message, optional JSON payload)
  // to the event log and keep it scrolled to the bottom.
  const entry = document.createElement("div");
  entry.className = "log-entry";
  const tag = document.createElement("span");
  tag.className = `tag ${type}`;
  tag.textContent = type.toUpperCase();
  entry.appendChild(tag);
  const msg = document.createElement("span");
  msg.style.marginLeft = "10px";
  msg.textContent = `[${new Date().toLocaleTimeString()}] ${text}`;
  entry.appendChild(msg);
  if (data) {
    const payload = document.createElement("div");
    payload.className = "muted";
    payload.style.marginTop = "4px";
    payload.textContent = JSON.stringify(data);
    entry.appendChild(payload);
  }
  logEl.appendChild(entry);
  logEl.scrollTop = logEl.scrollHeight;
}
function addChat(role, text) {
  // Append a finalized chat bubble; AI and user bubbles get different accents.
  const bubble = document.createElement("div");
  bubble.className = role === "AI" ? "chat-entry ai" : "chat-entry user";
  bubble.textContent = `${role}: ${text}`;
  chatHistory.appendChild(bubble);
  chatHistory.scrollTop = chatHistory.scrollHeight;
}
function setInterim(role, text) {
  // Maintain at most one "in progress" bubble per side (user vs AI).
  // Passing empty text removes the bubble and clears the cached interim string.
  const isAi = role === "AI";
  let el = isAi ? interimAiEl : interimUserEl;
  if (!text) {
    if (el) el.remove();
    if (isAi) interimAiEl = null;
    else interimUserEl = null;
    if (isAi) interimAiText = "";
    else interimUserText = "";
    return;
  }
  if (!el) {
    // First delta of a new utterance: create the styled interim bubble.
    el = document.createElement("div");
    el.className = `chat-entry ${isAi ? "ai" : "user"} interim`;
    chatHistory.appendChild(el);
    if (isAi) interimAiEl = el;
    else interimUserEl = el;
  }
  el.textContent = `${role} (interim): ${text}`;
  chatHistory.scrollTop = chatHistory.scrollHeight;
}
function stopPlayback() {
  // Cut all queued bot audio. discardAudio stays true (so schedulePlayback
  // drops in-flight frames) until handleEvent clears it on the next
  // output.audio.start.
  discardAudio = true;
  const now = audioCtx ? audioCtx.currentTime : 0;
  playbackTime = now;
  playbackSources.forEach((node) => {
    try {
      if (audioCtx && node.gainNode && node.source) {
        // Short gain ramp before stop() to avoid an audible click.
        node.gainNode.gain.cancelScheduledValues(now);
        node.gainNode.gain.setValueAtTime(node.gainNode.gain.value || 1, now);
        node.gainNode.gain.linearRampToValueAtTime(0, now + playbackStopRampSec);
        node.source.stop(now + playbackStopRampSec + 0.002);
      } else if (node.source) {
        node.source.stop();
      }
    } catch (err) {}
  });
  playbackSources = [];
}
function setStatus(connected, detail) {
  // Update the status dot, headline, and sub-line in one place.
  statusDot.classList.toggle("on", connected);
  statusText.textContent = connected ? "Connected" : "Disconnected";
  statusSub.textContent = detail || "";
}
async function ensureAudioContext() {
  // Lazily create the shared AudioContext. Playback is routed through a
  // MediaStreamDestination into the <audio> element so setSinkId can pick
  // the output device.
  if (audioCtx) return;
  audioCtx = new (window.AudioContext || window.webkitAudioContext)();
  playbackDest = audioCtx.createMediaStreamDestination();
  audioOut.srcObject = playbackDest.stream;
  try {
    await audioOut.play();
  } catch (err) {
    // Autoplay policy: play() may be rejected until a user gesture occurs.
    logLine("sys", "Audio playback blocked (user gesture needed)", { err: String(err) });
  }
  if (outputSelect.value) {
    await setOutputDevice(outputSelect.value);
  }
}
function downsampleBuffer(buffer, inRate, outRate) {
  // Average-pooling resampler: each output slot averages the span of input
  // samples it covers. Intended for downsampling (inRate >= outRate), e.g.
  // 44.1/48 kHz mic audio -> 16 kHz.
  //
  // Fix: the original divided by `count` without guarding the empty-span
  // case (possible when outRate > inRate, where a span can round to zero
  // samples), writing NaN into the output. Fall back to the nearest input
  // sample instead; downsampling behavior is unchanged.
  if (outRate === inRate) return buffer;
  const ratio = inRate / outRate;
  const newLength = Math.round(buffer.length / ratio);
  const result = new Float32Array(newLength);
  let offsetResult = 0;
  let offsetBuffer = 0;
  while (offsetResult < result.length) {
    const nextOffsetBuffer = Math.round((offsetResult + 1) * ratio);
    let accum = 0;
    let count = 0;
    for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
      accum += buffer[i];
      count++;
    }
    result[offsetResult] = count > 0
      ? accum / count
      : buffer[Math.min(offsetBuffer, buffer.length - 1)];
    offsetResult++;
    offsetBuffer = nextOffsetBuffer;
  }
  return result;
}
function floatTo16BitPCM(float32) {
  // Convert [-1, 1] float samples to signed 16-bit PCM, clamping out-of-range
  // values. Negative values scale by 0x8000, positive by 0x7fff.
  const pcm = new Int16Array(float32.length);
  for (let i = 0; i < float32.length; i++) {
    let v = float32[i];
    if (v > 1) v = 1;
    else if (v < -1) v = -1;
    pcm[i] = v < 0 ? v * 0x8000 : v * 0x7fff;
  }
  return pcm;
}
function schedulePlayback(int16Data) {
  // Queue one PCM16 frame for gapless playback: convert to float, wrap in a
  // buffer source + gain node, and start it at the current queue tail.
  if (!audioCtx || !playbackDest) return;
  if (discardAudio) return;
  const float32 = new Float32Array(int16Data.length);
  for (let i = 0; i < int16Data.length; i++) {
    float32[i] = int16Data[i] / 32768;
  }
  const buffer = audioCtx.createBuffer(1, float32.length, targetSampleRate);
  buffer.copyToChannel(float32, 0);
  const source = audioCtx.createBufferSource();
  const gainNode = audioCtx.createGain();
  source.buffer = buffer;
  source.connect(gainNode);
  gainNode.connect(playbackDest);
  // 20ms of slack so the first chunk doesn't start in the past.
  const startTime = Math.max(audioCtx.currentTime + 0.02, playbackTime);
  gainNode.gain.setValueAtTime(1, startTime);
  source.start(startTime);
  playbackTime = startTime + buffer.duration;
  // Track live sources so stopPlayback() can fade and stop them.
  const playbackNode = { source, gainNode };
  playbackSources.push(playbackNode);
  source.onended = () => {
    playbackSources = playbackSources.filter((s) => s !== playbackNode);
  };
}
async function connect() {
  // Open the WebSocket, wire its handlers, and start the hello handshake
  // (session.start is sent from handleEvent on hello.ack).
  if (ws && ws.readyState === WebSocket.OPEN) return;
  ws = new WebSocket(wsUrl.value.trim());
  ws.binaryType = "arraybuffer";
  ws.onopen = () => {
    setStatus(true, "Session open");
    logLine("sys", "WebSocket connected");
    ensureAudioContext();
    sendCommand({ type: "hello", version: "v1" });
  };
  ws.onclose = () => {
    setStatus(false, "Connection closed");
    logLine("sys", "WebSocket closed");
    ws = null;
  };
  ws.onerror = (err) => {
    logLine("sys", "WebSocket error", { err: String(err) });
  };
  ws.onmessage = (msg) => {
    // Text frames are JSON events; binary frames are PCM16 bot audio.
    if (typeof msg.data === "string") {
      const event = JSON.parse(msg.data);
      handleEvent(event);
    } else {
      const audioBuf = msg.data;
      const int16 = new Int16Array(audioBuf);
      schedulePlayback(int16);
      logLine("audio", `Audio ${Math.round((int16.length / targetSampleRate) * 1000)}ms`);
    }
  };
}
function disconnect() {
  // Politely stop the session when the socket is still open, then drop it.
  const isOpen = ws && ws.readyState === WebSocket.OPEN;
  if (isOpen) {
    sendCommand({ type: "session.stop", reason: "client_disconnect" });
    ws.close();
  }
  ws = null;
  setStatus(false, "Disconnected");
}
function sendCommand(cmd) {
  // JSON-encode and send a control command; logs and no-ops when offline.
  if (!ws || ws.readyState !== WebSocket.OPEN) {
    logLine("sys", "Not connected");
    return;
  }
  const encoded = JSON.stringify(cmd);
  ws.send(encoded);
  logLine("sys", `${cmd.type}`, cmd);
}
function handleEvent(event) {
  // Route one JSON event from the server. Known types drive the chat UI and
  // the playback lifecycle; every event is also echoed to the event log.
  const type = event.type || "unknown";
  logLine("event", type, event);
  if (type === "hello.ack") {
    // Handshake done: declare our audio format before streaming mic data.
    sendCommand({
      type: "session.start",
      audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
    });
  }
  if (type === "transcript.final") {
    // Replace the user's interim bubble with a finalized one.
    if (event.text) {
      setInterim("You", "");
      addChat("You", event.text);
    }
  }
  if (type === "transcript.delta" && event.text) {
    // ASR deltas carry the full interim text, so no accumulation here.
    setInterim("You", event.text);
  }
  if (type === "assistant.response.final") {
    if (event.text) {
      setInterim("AI", "");
      addChat("AI", event.text);
    }
  }
  if (type === "assistant.response.delta" && event.text) {
    // LLM deltas are incremental chunks, so accumulate before displaying.
    interimAiText += event.text;
    setInterim("AI", interimAiText);
  }
  if (type === "output.audio.start") {
    // New bot audio: stop any previous playback to avoid overlap
    stopPlayback();
    discardAudio = false;
    interimAiText = "";
  }
  if (type === "input.speech_started") {
    // User started speaking: clear any in-flight audio to avoid overlap
    stopPlayback();
  }
  if (type === "response.interrupted") {
    stopPlayback();
  }
}
async function startMic() {
  // Capture the selected microphone and stream PCM16 @ 16 kHz frames over
  // the WebSocket.
  // NOTE(review): ScriptProcessorNode is deprecated in favor of AudioWorklet
  // but remains broadly supported — confirm before porting.
  if (!ws || ws.readyState !== WebSocket.OPEN) {
    logLine("sys", "Connect before starting mic");
    return;
  }
  await ensureAudioContext();
  const deviceId = inputSelect.value || undefined;
  micStream = await navigator.mediaDevices.getUserMedia({
    audio: deviceId ? { deviceId: { exact: deviceId } } : true,
  });
  micSource = audioCtx.createMediaStreamSource(micStream);
  processor = audioCtx.createScriptProcessor(2048, 1, 1);
  processor.onaudioprocess = (e) => {
    if (!ws || ws.readyState !== WebSocket.OPEN) return;
    const input = e.inputBuffer.getChannelData(0);
    // Downsample from the context's native rate to 16 kHz before encoding.
    const downsampled = downsampleBuffer(input, audioCtx.sampleRate, targetSampleRate);
    const pcm16 = floatTo16BitPCM(downsampled);
    ws.send(pcm16.buffer);
  };
  micSource.connect(processor);
  processor.connect(audioCtx.destination);
  logLine("sys", "Microphone started");
}
function stopMic() {
  // Tear down the capture chain: processor node, source node, then tracks.
  if (processor) {
    processor.disconnect();
    processor = null;
  }
  if (micSource) {
    micSource.disconnect();
    micSource = null;
  }
  if (micStream) {
    for (const track of micStream.getTracks()) track.stop();
    micStream = null;
  }
  logLine("sys", "Microphone stopped");
}
async function refreshDevices() {
  // Repopulate both device dropdowns from enumerateDevices().
  const devices = await navigator.mediaDevices.enumerateDevices();
  inputSelect.innerHTML = "";
  outputSelect.innerHTML = "";
  for (const device of devices) {
    if (device.kind !== "audioinput" && device.kind !== "audiooutput") continue;
    const isInput = device.kind === "audioinput";
    const target = isInput ? inputSelect : outputSelect;
    // Labels are hidden until mic permission is granted; fall back to a
    // numbered placeholder.
    const fallback = isInput ? "Mic" : "Output";
    const opt = document.createElement("option");
    opt.value = device.deviceId;
    opt.textContent = device.label || `${fallback} ${target.length + 1}`;
    target.appendChild(opt);
  }
}
async function requestDeviceAccess() {
  // Browsers hide device labels until mic permission has been granted once;
  // open and immediately release a throwaway stream to unlock them.
  try {
    const probe = await navigator.mediaDevices.getUserMedia({ audio: true });
    for (const track of probe.getTracks()) track.stop();
    logLine("sys", "Microphone permission granted");
  } catch (err) {
    logLine("sys", "Microphone permission denied", { err: String(err) });
  }
}
async function setOutputDevice(deviceId) {
  // Route playback to the chosen sink. setSinkId requires a secure context
  // and browser support; log and bail when unavailable.
  if (!audioOut.setSinkId) {
    logLine("sys", "setSinkId not supported in this browser");
    return;
  }
  await audioOut.setSinkId(deviceId);
  logLine("sys", `Output device set`, { deviceId });
}
// --- UI wiring: buttons, selects, and device hot-plug handling ---
connectBtn.addEventListener("click", connect);
disconnectBtn.addEventListener("click", disconnect);
refreshDevicesBtn.addEventListener("click", async () => {
  // Request permission first so enumerateDevices() returns real labels.
  await requestDeviceAccess();
  await refreshDevices();
});
startMicBtn.addEventListener("click", startMic);
stopMicBtn.addEventListener("click", stopMic);
sendChatBtn.addEventListener("click", () => {
  const text = chatInput.value.trim();
  if (!text) return;
  // Click is a user gesture, so audio playback can be unlocked here.
  ensureAudioContext();
  addChat("You", text);
  sendCommand({ type: "input.text", text });
  chatInput.value = "";
});
clearLogBtn.addEventListener("click", () => {
  // Clear both panels and reset interim state.
  logEl.innerHTML = "";
  chatHistory.innerHTML = "";
  setInterim("You", "");
  setInterim("AI", "");
  interimUserText = "";
  interimAiText = "";
});
inputSelect.addEventListener("change", () => {
  // Restart capture so the newly selected microphone takes effect.
  if (micStream) {
    stopMic();
    startMic();
  }
});
outputSelect.addEventListener("change", () => setOutputDevice(outputSelect.value));
navigator.mediaDevices.addEventListener("devicechange", refreshDevices);
// Initial population; ignore failures (e.g. no permission yet).
refreshDevices().catch(() => {});
</script>
</body>
</html>

1
models/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Data Models Package"""

143
models/commands.py Normal file
View File

@@ -0,0 +1,143 @@
"""Protocol command models matching the original active-call API."""
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field
# Each command model carries a fixed `command` string that doubles as the
# wire-level discriminator used by parse_command() to pick the model class.
class InviteCommand(BaseModel):
    """Invite command to initiate a call."""
    command: str = Field(default="invite", description="Command type")
    option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options")
class AcceptCommand(BaseModel):
    """Accept command to accept an incoming call."""
    command: str = Field(default="accept", description="Command type")
    option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options")
class RejectCommand(BaseModel):
    """Reject command to reject an incoming call."""
    command: str = Field(default="reject", description="Command type")
    reason: str = Field(default="", description="Reason for rejection")
    code: Optional[int] = Field(default=None, description="SIP response code")
class RingingCommand(BaseModel):
    """Ringing command to send ringing response."""
    command: str = Field(default="ringing", description="Command type")
    recorder: Optional[Dict[str, Any]] = Field(default=None, description="Call recording configuration")
    early_media: bool = Field(default=False, description="Enable early media")
    ringtone: Optional[str] = Field(default=None, description="Custom ringtone URL")
class TTSCommand(BaseModel):
    """TTS command to convert text to speech."""
    command: str = Field(default="tts", description="Command type")
    text: str = Field(..., description="Text to synthesize")
    speaker: Optional[str] = Field(default=None, description="Speaker voice name")
    play_id: Optional[str] = Field(default=None, description="Unique identifier for this TTS session")
    auto_hangup: bool = Field(default=False, description="Auto hangup after TTS completion")
    # streaming / end_of_stream work as a pair: send chunks with
    # streaming=True, then mark the last chunk with end_of_stream=True.
    streaming: bool = Field(default=False, description="Streaming text input")
    end_of_stream: bool = Field(default=False, description="End of streaming input")
    wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)")
    option: Optional[Dict[str, Any]] = Field(default=None, description="TTS provider specific options")
class PlayCommand(BaseModel):
    """Play command to play audio from URL."""
    command: str = Field(default="play", description="Command type")
    url: str = Field(..., description="URL of audio file to play")
    auto_hangup: bool = Field(default=False, description="Auto hangup after playback")
    wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)")
# Playback-control and conversation commands; `command` is the wire-level
# discriminator used by parse_command().
class InterruptCommand(BaseModel):
    """Interrupt command to interrupt current playback."""
    command: str = Field(default="interrupt", description="Command type")
    graceful: bool = Field(default=False, description="Wait for current TTS to complete")
class PauseCommand(BaseModel):
    """Pause command to pause current playback."""
    command: str = Field(default="pause", description="Command type")
class ResumeCommand(BaseModel):
    """Resume command to resume paused playback."""
    command: str = Field(default="resume", description="Command type")
class HangupCommand(BaseModel):
    """Hangup command to end the call."""
    command: str = Field(default="hangup", description="Command type")
    reason: Optional[str] = Field(default=None, description="Reason for hangup")
    initiator: Optional[str] = Field(default=None, description="Who initiated the hangup")
class HistoryCommand(BaseModel):
    """History command to add conversation history."""
    command: str = Field(default="history", description="Command type")
    speaker: str = Field(..., description="Speaker identifier")
    text: str = Field(..., description="Conversation text")
class ChatCommand(BaseModel):
    """Chat command for text-based conversation."""
    command: str = Field(default="chat", description="Command type")
    text: str = Field(..., description="Chat text message")
# Command type mapping
# Maps the wire-level `command` string to its pydantic model; used by
# parse_command() to dispatch incoming JSON payloads.
COMMAND_TYPES = {
    "invite": InviteCommand,
    "accept": AcceptCommand,
    "reject": RejectCommand,
    "ringing": RingingCommand,
    "tts": TTSCommand,
    "play": PlayCommand,
    "interrupt": InterruptCommand,
    "pause": PauseCommand,
    "resume": ResumeCommand,
    "hangup": HangupCommand,
    "history": HistoryCommand,
    "chat": ChatCommand,
}
def parse_command(data: Dict[str, Any]) -> BaseModel:
    """Instantiate the typed command model named by ``data["command"]``.

    Args:
        data: Decoded JSON payload; must carry a "command" key.

    Returns:
        The matching pydantic command model, validated from *data*.

    Raises:
        ValueError: If the "command" field is missing or unrecognized.
    """
    name = data.get("command")
    if not name:
        raise ValueError("Missing 'command' field")
    model_cls = COMMAND_TYPES.get(name)
    if model_cls is None:
        raise ValueError(f"Unknown command type: {name}")
    return model_cls(**data)

126
models/config.py Normal file
View File

@@ -0,0 +1,126 @@
"""Configuration models for call options."""
from typing import Optional, Dict, Any, List
from pydantic import BaseModel, Field
class VADOption(BaseModel):
    """Voice Activity Detection configuration."""
    # Defaults mirror the .env defaults (silero @ 16 kHz).
    type: str = Field(default="silero", description="VAD algorithm type (silero, webrtc)")
    samplerate: int = Field(default=16000, description="Audio sample rate for VAD")
    speech_padding: int = Field(default=250, description="Speech padding in milliseconds")
    silence_padding: int = Field(default=100, description="Silence padding in milliseconds")
    ratio: float = Field(default=0.5, description="Voice detection ratio threshold")
    voice_threshold: float = Field(default=0.5, description="Voice energy threshold")
    max_buffer_duration_secs: int = Field(default=50, description="Maximum buffer duration in seconds")
    silence_timeout: Optional[int] = Field(default=None, description="Silence timeout in milliseconds")
    # Remote-VAD fields, used only when an external VAD service is configured.
    endpoint: Optional[str] = Field(default=None, description="Custom VAD service endpoint")
    secret_key: Optional[str] = Field(default=None, description="VAD service secret key")
    secret_id: Optional[str] = Field(default=None, description="VAD service secret ID")
class ASROption(BaseModel):
    """Automatic Speech Recognition configuration."""
    # `provider` is the only required field; credentials depend on provider.
    provider: str = Field(..., description="ASR provider (tencent, aliyun, openai, etc.)")
    language: Optional[str] = Field(default=None, description="Language code (zh-CN, en-US)")
    app_id: Optional[str] = Field(default=None, description="Application ID")
    secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
    secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
    model_type: Optional[str] = Field(default=None, description="ASR model type (16k_zh, 8k_en)")
    buffer_size: Optional[int] = Field(default=None, description="Audio buffer size in bytes")
    samplerate: Optional[int] = Field(default=None, description="Audio sample rate")
    endpoint: Optional[str] = Field(default=None, description="Custom ASR service endpoint")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
    start_when_answer: bool = Field(default=False, description="Start ASR when call is answered")
class TTSOption(BaseModel):
    """Text-to-Speech configuration.

    Defaults to the keyless `msedge` provider; other providers need
    credentials (app_id/secret_id/secret_key) or an API endpoint.
    """
    samplerate: Optional[int] = Field(default=None, description="TTS output sample rate")
    provider: str = Field(default="msedge", description="TTS provider (tencent, aliyun, deepgram, msedge)")
    speed: float = Field(default=1.0, description="Speech speed multiplier")
    app_id: Optional[str] = Field(default=None, description="Application ID")
    secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
    secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
    volume: Optional[int] = Field(default=None, description="Speech volume level (1-10)")
    speaker: Optional[str] = Field(default=None, description="Voice speaker name")
    codec: Optional[str] = Field(default=None, description="Audio codec")
    subtitle: bool = Field(default=False, description="Enable subtitle generation")
    emotion: Optional[str] = Field(default=None, description="Speech emotion")
    endpoint: Optional[str] = Field(default=None, description="Custom TTS service endpoint")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
    max_concurrent_tasks: Optional[int] = Field(default=None, description="Max concurrent tasks")
class RecorderOption(BaseModel):
    """Call recording configuration.

    `recorder_file` is required; sample rate and packet time have defaults.
    """
    recorder_file: str = Field(..., description="Path to recording file")
    samplerate: int = Field(default=16000, description="Recording sample rate")
    ptime: int = Field(default=200, description="Packet time in milliseconds")
class MediaPassOption(BaseModel):
    """Media pass-through configuration for external audio processing.

    Input/output sample rates may differ, allowing resampling at the
    WebSocket boundary.
    """
    url: str = Field(..., description="WebSocket URL for media streaming")
    input_sample_rate: int = Field(default=16000, description="Sample rate of audio received from WebSocket")
    output_sample_rate: int = Field(default=16000, description="Sample rate of audio sent to WebSocket")
    packet_size: int = Field(default=2560, description="Packet size in bytes")
    ptime: Optional[int] = Field(default=None, description="Buffered playback period in milliseconds")
class SipOption(BaseModel):
    """SIP protocol configuration (credentials and extra headers)."""
    username: Optional[str] = Field(default=None, description="SIP username")
    password: Optional[str] = Field(default=None, description="SIP password")
    realm: Optional[str] = Field(default=None, description="SIP realm/domain")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional SIP headers")
class HandlerRule(BaseModel):
    """Handler routing rule.

    Matches calls by caller/callee regex patterns and dispatches to either a
    playbook file or a webhook URL.
    """
    caller: Optional[str] = Field(default=None, description="Caller pattern (regex)")
    callee: Optional[str] = Field(default=None, description="Callee pattern (regex)")
    playbook: Optional[str] = Field(default=None, description="Playbook file path")
    webhook: Optional[str] = Field(default=None, description="Webhook URL")
class CallOption(BaseModel):
    """Comprehensive call configuration options.

    Aggregates per-call settings: basic call identity, codec, optional
    per-component configs (recorder/ASR/VAD/TTS/media-pass/SIP), timeouts,
    and free-form extras.
    """
    # Basic options
    denoise: bool = Field(default=False, description="Enable noise reduction")
    offer: Optional[str] = Field(default=None, description="SDP offer string")
    callee: Optional[str] = Field(default=None, description="Callee SIP URI or phone number")
    caller: Optional[str] = Field(default=None, description="Caller SIP URI or phone number")
    # Audio codec
    codec: str = Field(default="pcm", description="Audio codec (pcm, pcma, pcmu, g722)")
    # Component configurations
    recorder: Optional[RecorderOption] = Field(default=None, description="Call recording config")
    asr: Optional[ASROption] = Field(default=None, description="ASR configuration")
    vad: Optional[VADOption] = Field(default=None, description="VAD configuration")
    tts: Optional[TTSOption] = Field(default=None, description="TTS configuration")
    media_pass: Optional[MediaPassOption] = Field(default=None, description="Media pass-through config")
    sip: Optional[SipOption] = Field(default=None, description="SIP configuration")
    # Timeouts and networking
    handshake_timeout: Optional[int] = Field(default=None, description="Handshake timeout in seconds")
    enable_ipv6: bool = Field(default=False, description="Enable IPv6 support")
    inactivity_timeout: Optional[int] = Field(default=None, description="Inactivity timeout in seconds")
    # EOU configuration
    eou: Optional[Dict[str, Any]] = Field(default=None, description="End of utterance detection config")
    # Extra parameters
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional custom parameters")
    class Config:
        # Allow population by attribute name as well as by alias on input.
        populate_by_name = True

231
models/events.py Normal file
View File

@@ -0,0 +1,231 @@
"""Protocol event models matching the original active-call API."""
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field
from datetime import datetime
def current_timestamp_ms() -> int:
    """Get current timestamp in milliseconds."""
    seconds_since_epoch = datetime.now().timestamp()
    return int(seconds_since_epoch * 1000)
# Base Event Model
class BaseEvent(BaseModel):
    """Base event model.

    Common envelope (event name, track id, ms timestamp) shared by all
    track-scoped events in this module.
    """
    event: str = Field(..., description="Event type")
    track_id: str = Field(..., description="Unique track identifier")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")
# Lifecycle Events
class IncomingEvent(BaseEvent):
    """Incoming call event (SIP only); carries caller/callee URIs and SDP offer."""
    event: str = Field(default="incoming", description="Event type")
    caller: Optional[str] = Field(default=None, description="Caller's SIP URI")
    callee: Optional[str] = Field(default=None, description="Callee's SIP URI")
    sdp: Optional[str] = Field(default=None, description="SDP offer from caller")
class AnswerEvent(BaseEvent):
    """Call answered event; carries the server's SDP answer when present."""
    event: str = Field(default="answer", description="Event type")
    sdp: Optional[str] = Field(default=None, description="SDP answer from server")
class RejectEvent(BaseEvent):
    """Call rejected event with optional reason and SIP response code."""
    event: str = Field(default="reject", description="Event type")
    reason: Optional[str] = Field(default=None, description="Rejection reason")
    code: Optional[int] = Field(default=None, description="SIP response code")
class RingingEvent(BaseEvent):
    """Call ringing event; early_media indicates pre-answer audio availability."""
    event: str = Field(default="ringing", description="Event type")
    early_media: bool = Field(default=False, description="Early media available")
class HangupEvent(BaseModel):
    """Call hangup event.

    NOTE: unlike most events this derives from BaseModel directly (no
    track_id), since hangup applies to the whole call, not a single track.
    """
    event: str = Field(default="hangup", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    reason: Optional[str] = Field(default=None, description="Hangup reason")
    initiator: Optional[str] = Field(default=None, description="Who initiated hangup")
    start_time: Optional[str] = Field(default=None, description="Call start time (ISO 8601)")
    hangup_time: Optional[str] = Field(default=None, description="Hangup time (ISO 8601)")
    answer_time: Optional[str] = Field(default=None, description="Answer time (ISO 8601)")
    ringing_time: Optional[str] = Field(default=None, description="Ringing time (ISO 8601)")
    # "from" is a Python keyword, so the attribute is from_ with a wire alias.
    from_: Optional[Dict[str, Any]] = Field(default=None, alias="from", description="Caller info")
    to: Optional[Dict[str, Any]] = Field(default=None, description="Callee info")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
    class Config:
        # Accept both the "from" alias and the from_ attribute name on input.
        populate_by_name = True
# VAD Events
class SpeakingEvent(BaseEvent):
    """Speech detected event (VAD transition to speech)."""
    event: str = Field(default="speaking", description="Event type")
    start_time: int = Field(default_factory=current_timestamp_ms, description="Speech start time")
class SilenceEvent(BaseEvent):
    """Silence detected event (VAD transition to silence)."""
    event: str = Field(default="silence", description="Event type")
    start_time: int = Field(default_factory=current_timestamp_ms, description="Silence start time")
    duration: int = Field(default=0, description="Silence duration in milliseconds")
# AI/ASR Events
class AsrFinalEvent(BaseEvent):
    """ASR final transcription event (end of one recognized segment)."""
    event: str = Field(default="asrFinal", description="Event type")
    index: int = Field(..., description="ASR result sequence number")
    start_time: Optional[int] = Field(default=None, description="Speech start time")
    end_time: Optional[int] = Field(default=None, description="Speech end time")
    text: str = Field(..., description="Transcribed text")
class AsrDeltaEvent(BaseEvent):
    """ASR partial transcription event (streaming interim result)."""
    event: str = Field(default="asrDelta", description="Event type")
    index: int = Field(..., description="ASR result sequence number")
    start_time: Optional[int] = Field(default=None, description="Speech start time")
    end_time: Optional[int] = Field(default=None, description="Speech end time")
    text: str = Field(..., description="Partial transcribed text")
class EouEvent(BaseEvent):
    """End of utterance detection event."""
    event: str = Field(default="eou", description="Event type")
    completed: bool = Field(default=True, description="Whether utterance was completed")
# Audio Track Events
class TrackStartEvent(BaseEvent):
    """Audio track start event."""
    event: str = Field(default="trackStart", description="Event type")
    play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")
class TrackEndEvent(BaseEvent):
    """Audio track end event; duration and ssrc are required."""
    event: str = Field(default="trackEnd", description="Event type")
    duration: int = Field(..., description="Track duration in milliseconds")
    ssrc: int = Field(..., description="RTP SSRC identifier")
    play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")
class InterruptionEvent(BaseEvent):
    """Playback interruption event (e.g. barge-in cut TTS playback short)."""
    event: str = Field(default="interruption", description="Event type")
    play_id: Optional[str] = Field(default=None, description="Play ID that was interrupted")
    subtitle: Optional[str] = Field(default=None, description="TTS text being played")
    position: Optional[int] = Field(default=None, description="Word index position")
    total_duration: Optional[int] = Field(default=None, description="Total TTS duration")
    current: Optional[int] = Field(default=None, description="Elapsed time when interrupted")
# System Events
class ErrorEvent(BaseEvent):
    """Error event; sender identifies the originating component."""
    event: str = Field(default="error", description="Event type")
    sender: str = Field(..., description="Component that generated the error")
    error: str = Field(..., description="Error message")
    code: Optional[int] = Field(default=None, description="Error code")
class MetricsEvent(BaseModel):
    """Performance metrics event.

    Derives from BaseModel directly (no track_id field).
    """
    event: str = Field(default="metrics", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    key: str = Field(..., description="Metric key")
    duration: int = Field(..., description="Duration in milliseconds")
    data: Optional[Dict[str, Any]] = Field(default=None, description="Additional metric data")
class AddHistoryEvent(BaseModel):
    """Conversation history entry added event.

    Derives from BaseModel directly (no track_id field).
    """
    event: str = Field(default="addHistory", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    sender: Optional[str] = Field(default=None, description="Component that added history")
    speaker: str = Field(..., description="Speaker identifier")
    text: str = Field(..., description="Conversation text")
class DTMFEvent(BaseEvent):
    """DTMF tone detected event."""
    event: str = Field(default="dtmf", description="Event type")
    digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)")
class HeartBeatEvent(BaseModel):
    """Server-to-client heartbeat to keep connection alive.

    Derives from BaseModel directly (no track_id field).
    """
    event: str = Field(default="heartBeat", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")
# Event type mapping
# Maps wire-format "event" strings to their pydantic models; used by
# create_event(). Keys must match each model's default `event` field value.
EVENT_TYPES = {
    "incoming": IncomingEvent,
    "answer": AnswerEvent,
    "reject": RejectEvent,
    "ringing": RingingEvent,
    "hangup": HangupEvent,
    "speaking": SpeakingEvent,
    "silence": SilenceEvent,
    "asrFinal": AsrFinalEvent,
    "asrDelta": AsrDeltaEvent,
    "eou": EouEvent,
    "trackStart": TrackStartEvent,
    "trackEnd": TrackEndEvent,
    "interruption": InterruptionEvent,
    "error": ErrorEvent,
    "metrics": MetricsEvent,
    "addHistory": AddHistoryEvent,
    "dtmf": DTMFEvent,
    "heartBeat": HeartBeatEvent,
}
def create_event(event_type: str, **kwargs) -> BaseModel:
    """
    Create an event model.

    Args:
        event_type: Type of event to create
        **kwargs: Event fields. An "event" key (e.g. from a decoded wire
            payload) is ignored in favor of event_type.

    Returns:
        Event model instance

    Raises:
        ValueError: If event type is unknown
    """
    event_class = EVENT_TYPES.get(event_type)
    if not event_class:
        raise ValueError(f"Unknown event type: {event_type}")
    # Fix: drop any incoming "event" key so it cannot clash with the explicit
    # event=event_type argument (previously raised TypeError: got multiple
    # values for keyword argument 'event' when kwargs came off the wire).
    kwargs.pop("event", None)
    return event_class(event=event_type, **kwargs)

73
models/ws_v1.py Normal file
View File

@@ -0,0 +1,73 @@
"""WS v1 protocol message models and helpers."""
from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field
def now_ms() -> int:
    """Current unix timestamp in milliseconds."""
    import time
    epoch_ms = time.time() * 1000
    return int(epoch_ms)
# Client -> Server messages
class HelloMessage(BaseModel):
    """Client handshake message carrying protocol version and optional auth."""
    type: Literal["hello"]
    version: str = Field(..., description="Protocol version, currently v1")
    auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
class SessionStartMessage(BaseModel):
    """Starts a session, optionally describing audio format and metadata."""
    type: Literal["session.start"]
    audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
class SessionStopMessage(BaseModel):
    """Stops the current session; reason is free-form."""
    type: Literal["session.stop"]
    reason: Optional[str] = None
class InputTextMessage(BaseModel):
    """Plain text input sent by the client."""
    type: Literal["input.text"]
    text: str
class ResponseCancelMessage(BaseModel):
    """Requests cancellation of the current response.

    `graceful` semantics are defined by the handler — presumably letting the
    current output finish; verify against the session handler.
    """
    type: Literal["response.cancel"]
    graceful: bool = False
class ToolCallResultsMessage(BaseModel):
    """Tool-call results returned by the client; empty list by default."""
    type: Literal["tool_call.results"]
    results: list[Dict[str, Any]] = Field(default_factory=list)
# Maps the wire "type" field to its pydantic model; used by
# parse_client_message(). Keys must match each model's Literal type value.
CLIENT_MESSAGE_TYPES = {
    "hello": HelloMessage,
    "session.start": SessionStartMessage,
    "session.stop": SessionStopMessage,
    "input.text": InputTextMessage,
    "response.cancel": ResponseCancelMessage,
    "tool_call.results": ToolCallResultsMessage,
}
def parse_client_message(data: Dict[str, Any]) -> BaseModel:
    """Parse and validate a WS v1 client message.

    Raises:
        ValueError: If the "type" field is missing/empty or unrecognized.
    """
    message_type = data.get("type")
    if not message_type:
        raise ValueError("Missing 'type' field")
    model_cls = CLIENT_MESSAGE_TYPES.get(message_type)
    if model_cls is None:
        raise ValueError(f"Unknown client message type: {message_type}")
    return model_cls(**data)
# Server -> Client event helpers
def ev(event_type: str, **payload: Any) -> Dict[str, Any]:
    """Create a WS v1 server event payload.

    Payload keys override the defaults, matching dict.update semantics.
    """
    return {"type": event_type, "timestamp": now_ms(), **payload}

6
processors/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""Audio Processors Package"""
from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor
__all__ = ["EouDetector", "SileroVAD", "VADProcessor"]

80
processors/eou.py Normal file
View File

@@ -0,0 +1,80 @@
"""End-of-Utterance Detection."""
import time
from typing import Optional
class EouDetector:
    """
    End-of-utterance detector. Fires EOU only after continuous silence for
    silence_threshold_ms. Short pauses between sentences do not trigger EOU
    because speech resets the silence timer (one EOU per turn).
    """

    def __init__(self, silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize EOU detector.

        Args:
            silence_threshold_ms: How long silence must last to trigger EOU (default 1000ms)
            min_speech_duration_ms: Minimum speech duration to consider valid (default 250ms)
        """
        # Thresholds in seconds, for comparison against time.time() deltas.
        self.threshold = silence_threshold_ms / 1000.0
        self.min_speech = min_speech_duration_ms / 1000.0
        # Raw ms values kept for introspection.
        self._silence_threshold_ms = silence_threshold_ms
        self._min_speech_duration_ms = min_speech_duration_ms
        # State
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time: Optional[float] = None
        self.triggered = False

    def process(self, vad_status: str, force_eligible: bool = False) -> bool:
        """
        Process VAD status and detect end of utterance.

        Input: "Speech" or "Silence" (from VAD); any other value is a no-op.
        Output: True if EOU detected, False otherwise.

        Short breaks between phrases reset the silence clock when speech
        resumes, so only one EOU is emitted after the user truly stops.

        Args:
            vad_status: VAD label for the latest audio chunk.
            force_eligible: Skip the minimum-speech-duration gate.
        """
        now = time.time()
        if vad_status == "Speech":
            if not self.is_speaking:
                self.is_speaking = True
                self.speech_start_time = now
                self.triggered = False
            # Any speech resets silence timer — short pause + more speech = one utterance
            self.silence_start_time = None
            return False
        if vad_status == "Silence":
            if not self.is_speaking:
                return False
            if self.silence_start_time is None:
                self.silence_start_time = now
            speech_duration = self.silence_start_time - self.speech_start_time
            if speech_duration < self.min_speech and not force_eligible:
                # Too short to be a real utterance (micro-noise): drop the turn.
                self.is_speaking = False
                self.silence_start_time = None
                return False
            silence_duration = now - self.silence_start_time
            if silence_duration >= self.threshold and not self.triggered:
                self.triggered = True
                self.is_speaking = False
                self.silence_start_time = None
                return True
            return False
        # Fix: unknown VAD labels previously fell off the end and returned
        # None, violating the declared -> bool contract.
        return False

    def reset(self) -> None:
        """Reset EOU detector state."""
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time = None
        self.triggered = False

168
processors/tracks.py Normal file
View File

@@ -0,0 +1,168 @@
"""Audio track processing for WebRTC."""
import asyncio
import fractions
from typing import Optional
from loguru import logger
# Try to import aiortc (optional for WebRTC functionality)
try:
from aiortc import AudioStreamTrack
AIORTC_AVAILABLE = True
except ImportError:
AIORTC_AVAILABLE = False
AudioStreamTrack = object # Dummy class for type hints
# Try to import PyAV (optional for audio resampling)
try:
from av import AudioFrame, AudioResampler
AV_AVAILABLE = True
except ImportError:
AV_AVAILABLE = False
# Create dummy classes for type hints
class AudioFrame:
pass
class AudioResampler:
pass
import numpy as np
class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Audio track that resamples input to 16kHz mono PCM.

    Wraps an existing MediaStreamTrack and converts its output
    to 16kHz mono 16-bit PCM format for the pipeline.
    """

    def __init__(self, track, target_sample_rate: int = 16000):
        """
        Initialize resampled track.

        Args:
            track: Source MediaStreamTrack
            target_sample_rate: Target sample rate (default: 16000)

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used")
        super().__init__()
        self.track = track
        self.target_sample_rate = target_sample_rate
        # Frames already resampled but not yet delivered (PyAV's resample()
        # may emit zero or several frames per input frame).
        self._pending = []
        if AV_AVAILABLE:
            self.resampler = AudioResampler(
                format="s16",
                layout="mono",
                rate=target_sample_rate
            )
        else:
            logger.warning("PyAV not available, audio resampling disabled")
            self.resampler = None
        self._closed = False

    async def recv(self):
        """
        Receive and resample next audio frame.

        Returns:
            Resampled AudioFrame at 16kHz mono (source frame passed through
            unchanged when PyAV is unavailable)

        Raises:
            RuntimeError: If the track has been closed.
        """
        if self._closed:
            raise RuntimeError("Track is closed")
        if not (AV_AVAILABLE and self.resampler):
            # No resampler available: pass source frames through untouched.
            return await self.track.recv()
        # Serve any frame left over from a previous resample() call first.
        if self._pending:
            return self._pending.pop(0)
        while True:
            frame = await self.track.recv()
            # Fix: PyAV >= 9 returns a *list* of frames from resample();
            # the previous code set .sample_rate on the list and crashed.
            # Normalize both return styles and keep pulling source frames
            # until the resampler emits at least one output frame.
            resampled = self.resampler.resample(frame)
            if resampled is None:
                continue
            frames = resampled if isinstance(resampled, list) else [resampled]
            if not frames:
                continue
            for out_frame in frames:
                # Ensure the frame advertises the correct rate
                out_frame.sample_rate = self.target_sample_rate
            self._pending = frames[1:]
            return frames[0]

    async def stop(self) -> None:
        """Stop the track and cleanup resources."""
        self._closed = True
        if hasattr(self, 'resampler') and self.resampler:
            del self.resampler
        logger.debug("Resampled track stopped")
class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Synthetic audio track that generates a sine wave.

    Useful for testing without requiring real audio input.
    """

    def __init__(self, sample_rate: int = 16000, frequency: int = 440):
        """
        Initialize sine wave track.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            frequency: Sine wave frequency in Hz (default: 440)

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - SineWaveTrack cannot be used")
        super().__init__()
        self.sample_rate = sample_rate
        self.frequency = frequency
        self.counter = 0
        self._stopped = False

    async def recv(self):
        """
        Generate the next 20 ms frame of the sine wave.

        Returns:
            AudioFrame with sine wave data, or a plain dict when PyAV is
            unavailable.

        Raises:
            RuntimeError: If the track has been stopped.
        """
        if self._stopped:
            raise RuntimeError("Track is stopped")
        sample_count = int(self.sample_rate * 0.02)  # 20 ms of audio
        frame_pts = self.counter
        tb = fractions.Fraction(1, self.sample_rate)
        # Time axis for this frame, continuing from the running counter.
        t = np.linspace(
            self.counter / self.sample_rate,
            (self.counter + sample_count) / self.sample_rate,
            sample_count,
            endpoint=False
        )
        # Half-amplitude sine, scaled to Int16 PCM.
        pcm = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16)
        self.counter += sample_count
        if not AV_AVAILABLE:
            # Plain-dict fallback so callers can still inspect the audio.
            return {
                'data': pcm,
                'sample_rate': self.sample_rate,
                'pts': frame_pts,
                'time_base': tb
            }
        frame = AudioFrame.from_ndarray(pcm.reshape(1, -1), format='s16', layout='mono')
        frame.pts = frame_pts
        frame.time_base = tb
        frame.sample_rate = self.sample_rate
        return frame

    def stop(self) -> None:
        """Stop the track."""
        self._stopped = True

221
processors/vad.py Normal file
View File

@@ -0,0 +1,221 @@
"""Voice Activity Detection using Silero VAD."""
import asyncio
import os
from typing import Tuple, Optional
import numpy as np
from loguru import logger
# Try to import onnxruntime (optional for VAD functionality)
try:
import onnxruntime as ort
ONNX_AVAILABLE = True
except ImportError:
ONNX_AVAILABLE = False
ort = None
logger.warning("onnxruntime not available - VAD will be disabled")
class SileroVAD:
    """
    Voice Activity Detection using Silero VAD model.

    Detects speech in audio chunks using the Silero VAD ONNX model.
    Returns "Speech" or "Silence" for each audio chunk. When the model file
    or onnxruntime is unavailable, falls back to an adaptive energy-based
    detector.
    """

    def __init__(self, model_path: str = "data/vad/silero_vad.onnx", sample_rate: int = 16000):
        """
        Initialize Silero VAD.

        Args:
            model_path: Path to Silero VAD ONNX model
            sample_rate: Audio sample rate (must be 16kHz for Silero VAD)
        """
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.session = None
        # Fix: initialize ALL runtime state before any early return below.
        # Previously these attributes were only set after a successful model
        # load, so the energy-fallback path in process_audio() and reset()
        # raised AttributeError whenever the model/onnxruntime was missing.
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.min_chunk_size = 512  # Silero consumes fixed 512-sample frames
        self.last_label = "Silence"
        self.last_probability = 0.0
        self._energy_noise_floor = 1e-4
        # Check if model exists
        if not os.path.exists(model_path):
            logger.warning(f"VAD model not found at {model_path}. VAD will be disabled.")
            return
        # Check if onnxruntime is available
        if not ONNX_AVAILABLE:
            logger.warning("onnxruntime not available - VAD will be disabled")
            return
        # Load ONNX model
        try:
            self.session = ort.InferenceSession(model_path)
            logger.info(f"Loaded Silero VAD model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load VAD model: {e}")
            self.session = None

    def _reset_state(self):
        """Reset the model's recurrent state and sample-rate tensor."""
        # Silero VAD V4+ expects state shape [2, 1, 128]
        self._state = np.zeros((2, 1, 128), dtype=np.float32)
        self._sr = np.array([self.sample_rate], dtype=np.int64)

    def process_audio(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Tuple[str, float]:
        """
        Process audio chunk and detect speech.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
            chunk_size_ms: Chunk duration in milliseconds (ignored for buffering logic)

        Returns:
            Tuple of (label, probability) where label is "Speech" or "Silence"
        """
        if self.session is None or not ONNX_AVAILABLE:
            # Fallback energy-based VAD with adaptive noise floor.
            if not pcm_bytes:
                return "Silence", 0.0
            audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
            if audio_int16.size == 0:
                return "Silence", 0.0
            audio_float = audio_int16.astype(np.float32) / 32768.0
            rms = float(np.sqrt(np.mean(audio_float * audio_float)))
            # Update adaptive noise floor (falls quickly, rises slowly)
            if rms < self._energy_noise_floor:
                self._energy_noise_floor = 0.95 * self._energy_noise_floor + 0.05 * rms
            else:
                self._energy_noise_floor = 0.995 * self._energy_noise_floor + 0.005 * rms
            # Compute SNR-like ratio and map to probability
            denom = max(self._energy_noise_floor, 1e-6)
            snr = max(0.0, (rms - denom) / denom)
            probability = min(1.0, snr / 3.0)  # ~3x above noise => strong speech
            label = "Speech" if probability >= 0.5 else "Silence"
            return label, probability
        # Convert bytes to numpy array of int16, normalized to [-1.0, 1.0)
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0
        # Accumulate; the model only accepts exact 512-sample frames.
        self.buffer = np.concatenate((self.buffer, audio_float))
        # Process all complete chunks in the buffer
        while len(self.buffer) >= self.min_chunk_size:
            # Slice exactly 512 samples
            chunk = self.buffer[:self.min_chunk_size]
            self.buffer = self.buffer[self.min_chunk_size:]
            # Input tensor shape: [batch, samples] -> [1, 512]
            input_tensor = chunk.reshape(1, -1)
            try:
                ort_inputs = {
                    'input': input_tensor,
                    'state': self._state,
                    'sr': self._sr
                }
                # Outputs: probability, state
                out, self._state = self.session.run(None, ort_inputs)
                self.last_probability = float(out[0][0])
                self.last_label = "Speech" if self.last_probability >= 0.5 else "Silence"
            except Exception as e:
                logger.error(f"VAD inference error: {e}")
                # Log expected input names to aid debugging of model mismatches
                try:
                    inputs = [x.name for x in self.session.get_inputs()]
                    logger.error(f"Model expects inputs: {inputs}")
                except Exception:
                    pass
                # Fail open: treat the chunk as speech so a broken model
                # never silently suppresses the user.
                return "Speech", 1.0
        return self.last_label, self.last_probability

    def reset(self) -> None:
        """Reset VAD internal state."""
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.last_label = "Silence"
        self.last_probability = 0.0
class VADProcessor:
    """
    High-level VAD processor with state management.

    Tracks speech/silence state and emits events on transitions.
    """

    def __init__(self, vad_model: "SileroVAD", threshold: float = 0.5):
        """
        Initialize VAD processor.

        Args:
            vad_model: Silero VAD model instance (any object exposing
                process_audio()/reset() works, e.g. a stub in tests)
            threshold: Speech detection threshold
        """
        self.vad = vad_model
        self.threshold = threshold
        self.is_speaking = False
        self.speech_start_time: Optional[float] = None
        self.silence_start_time: Optional[float] = None

    def process(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Optional[Tuple[str, float]]:
        """
        Process audio chunk and detect state changes.

        Args:
            pcm_bytes: PCM audio data
            chunk_size_ms: Chunk duration in milliseconds

        Returns:
            Tuple of (event_type, probability) if state changed, None otherwise
        """
        import time  # local import: module does not import time at top level

        label, probability = self.vad.process_audio(pcm_bytes, chunk_size_ms)
        # Check if this is speech based on threshold
        is_speech = probability >= self.threshold
        # State transition: Silence -> Speech
        if is_speech and not self.is_speaking:
            self.is_speaking = True
            # Fix: time.monotonic() replaces asyncio.get_event_loop().time().
            # Calling get_event_loop() from sync code with no running loop is
            # deprecated (Python 3.10+) and slated to raise; BaseEventLoop's
            # clock is time.monotonic anyway, so timestamps stay comparable.
            self.speech_start_time = time.monotonic()
            self.silence_start_time = None
            return ("speaking", probability)
        # State transition: Speech -> Silence
        elif not is_speech and self.is_speaking:
            self.is_speaking = False
            self.silence_start_time = time.monotonic()
            self.speech_start_time = None
            return ("silence", probability)
        return None

    def reset(self) -> None:
        """Reset VAD state."""
        self.vad.reset()
        self.is_speaking = False
        self.speech_start_time = None
        self.silence_start_time = None

134
pyproject.toml Normal file
View File

@@ -0,0 +1,134 @@
[build-system]
requires = ["setuptools>=68.0"]
build-backend = "setuptools.build_meta"
[project]
name = "py-active-call-cc"
version = "0.1.0"
description = "Python Active-Call: Real-time audio streaming with WebSocket and WebRTC"
readme = "README.md"
requires-python = ">=3.11"
license = {text = "MIT"}
authors = [
{name = "Your Name", email = "your.email@example.com"}
]
keywords = ["webrtc", "websocket", "audio", "voip", "real-time"]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Topic :: Communications :: Telephony",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
[project.urls]
Homepage = "https://github.com/yourusername/py-active-call-cc"
Documentation = "https://github.com/yourusername/py-active-call-cc/blob/main/README.md"
Repository = "https://github.com/yourusername/py-active-call-cc.git"
Issues = "https://github.com/yourusername/py-active-call-cc/issues"
[tool.setuptools.packages.find]
where = ["."]
include = ["app*"]
exclude = ["tests*", "scripts*", "reference*"]
[tool.black]
line-length = 100
target-version = ['py311']
include = '\.pyi?$'
extend-exclude = '''
/(
# directories
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| build
| dist
| reference
)/
'''
[tool.ruff]
line-length = 100
target-version = "py311"
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
]
ignore = [
"E501", # line too long (handled by black)
"B008", # do not perform function calls in argument defaults
]
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".hg",
".mypy_cache",
".nox",
".pants.d",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
"reference",
]
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"] # unused imports
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
check_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
strict_equality = true
exclude = [
"venv",
"reference",
"build",
"dist",
]
[[tool.mypy.overrides]]
module = [
"aiortc.*",
"av.*",
"onnxruntime.*",
]
ignore_missing_imports = true
[tool.pytest.ini_options]
minversion = "7.0"
addopts = "-ra -q --strict-markers --strict-config"
testpaths = ["tests"]
pythonpath = ["."]
asyncio_mode = "auto"
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"integration: marks tests as integration tests",
]

37
requirements.txt Normal file
View File

@@ -0,0 +1,37 @@
# Web Framework
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
websockets>=12.0
python-multipart>=0.0.6
# WebRTC (optional - for WebRTC transport)
aiortc>=1.6.0
# Audio Processing
av>=12.1.0
numpy>=1.26.3
onnxruntime>=1.16.3
# Configuration
pydantic>=2.5.3
pydantic-settings>=2.1.0
python-dotenv>=1.0.0
toml>=0.10.2
# Logging
loguru>=0.7.2
# HTTP Client
aiohttp>=3.9.1
# AI Services - LLM
openai>=1.0.0
# AI Services - TTS
edge-tts>=6.1.0
pydub>=0.25.0 # For audio format conversion
# Microphone client dependencies
sounddevice>=0.4.6
soundfile>=0.12.1
pyaudio>=0.2.13 # More reliable audio on Windows

1
scripts/README.md Normal file
View File

@@ -0,0 +1 @@
# Development Script

View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python3
"""
Generate test audio file with utterances using SiliconFlow TTS API.
Creates a 16kHz mono WAV file with real speech segments separated by
configurable silence (for VAD/testing).
Usage:
python generate_test_audio.py [OPTIONS]
Options:
-o, --output PATH Output WAV path (default: data/audio_examples/two_utterances_16k.wav)
-u, --utterance TEXT Utterance text; repeat for multiple (ignored if -j is set)
-j, --json PATH JSON file: array of strings or {"utterances": [...]}
--silence-ms MS Silence in ms between utterances (default: 500)
--lead-silence-ms MS Silence in ms at start (default: 200)
--trail-silence-ms MS Silence in ms at end (default: 300)
Examples:
# Default utterances and output
python generate_test_audio.py
# Custom output path
python generate_test_audio.py -o out.wav
# Utterances from command line
python generate_test_audio.py -u "Hello" -u "World" -o test.wav
# Utterances from a JSON file
python generate_test_audio.py -j utterances.json -o test.wav
# Custom silence (1s between utterances)
python generate_test_audio.py -u "One" -u "Two" --silence-ms 1000 -o test.wav
Requires SILICONFLOW_API_KEY in .env.
"""
import wave
import struct
import argparse
import asyncio
import aiohttp
import json
import os
from pathlib import Path
from dotenv import load_dotenv
# Load .env file from project root
project_root = Path(__file__).parent.parent.parent
load_dotenv(project_root / ".env")
# SiliconFlow TTS Configuration
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/audio/speech"
SILICONFLOW_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
# Available voices
VOICES = {
"alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
"anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
"bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
"benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
"charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
"claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
"david": "FunAudioLLM/CosyVoice2-0.5B:david",
"diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
}
def generate_silence(duration_ms: int, sample_rate: int = 16000) -> bytes:
    """Return *duration_ms* of digital silence as 16-bit mono PCM bytes."""
    sample_count = int(sample_rate * (duration_ms / 1000.0))
    # Each 16-bit sample is two zero bytes; bytes(n) is already all-zero.
    return bytes(2 * sample_count)
async def synthesize_speech(
    text: str,
    api_key: str,
    voice: str = "anna",
    sample_rate: int = 16000,
    speed: float = 1.0
) -> bytes:
    """
    Synthesize speech using SiliconFlow TTS API.
    Args:
        text: Text to synthesize
        api_key: SiliconFlow API key
        voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
        sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
        speed: Speech speed (0.25 to 4.0)
    Returns:
        PCM audio bytes (16-bit signed, little-endian)
    Raises:
        RuntimeError: On any non-200 response from the API.
    """
    # Short aliases resolve to full voice ids; unknown names pass through.
    request_body = {
        "model": SILICONFLOW_MODEL,
        "input": text,
        "voice": VOICES.get(voice, voice),
        "response_format": "pcm",
        "sample_rate": sample_rate,
        "stream": False,
        "speed": speed
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(SILICONFLOW_API_URL, json=request_body, headers=auth_headers) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"SiliconFlow TTS error: {response.status} - {error_text}")
            return await response.read()
async def generate_test_audio(
    output_path: str,
    utterances: list[str],
    silence_ms: int = 500,
    lead_silence_ms: int = 200,
    trail_silence_ms: int = 300,
    voice: str = "anna",
    sample_rate: int = 16000,
    speed: float = 1.0
):
    """
    Generate test audio with multiple utterances separated by silence.
    Args:
        output_path: Path to save the WAV file
        utterances: List of text strings for each utterance
        silence_ms: Silence duration between utterances (milliseconds)
        lead_silence_ms: Silence at the beginning (milliseconds)
        trail_silence_ms: Silence at the end (milliseconds)
        voice: TTS voice to use
        sample_rate: Audio sample rate
        speed: TTS speech speed
    Raises:
        ValueError: If SILICONFLOW_API_KEY is missing from the environment.
        RuntimeError: If the TTS API returns a non-200 response.
    """
    api_key = os.getenv("SILICONFLOW_API_KEY")
    if not api_key:
        raise ValueError(
            "SILICONFLOW_API_KEY not found in environment.\n"
            "Please set it in your .env file:\n"
            " SILICONFLOW_API_KEY=your-api-key-here"
        )
    print(f"Using SiliconFlow TTS API")
    print(f" Voice: {voice}")
    print(f" Sample rate: {sample_rate}Hz")
    print(f" Speed: {speed}x")
    print()
    segments = []
    # Lead-in silence
    if lead_silence_ms > 0:
        segments.append(generate_silence(lead_silence_ms, sample_rate))
        print(f" [silence: {lead_silence_ms}ms]")
    # Generate each utterance with silence between
    for i, text in enumerate(utterances):
        print(f" Synthesizing utterance {i + 1}: \"{text}\"")
        audio = await synthesize_speech(
            text=text,
            api_key=api_key,
            voice=voice,
            sample_rate=sample_rate,
            speed=speed
        )
        segments.append(audio)
        # Add silence between utterances (not after the last one)
        if i < len(utterances) - 1:
            segments.append(generate_silence(silence_ms, sample_rate))
            print(f" [silence: {silence_ms}ms]")
    # Trail silence
    if trail_silence_ms > 0:
        segments.append(generate_silence(trail_silence_ms, sample_rate))
        print(f" [silence: {trail_silence_ms}ms]")
    # Concatenate all segments
    audio_data = b''.join(segments)
    # Write WAV file
    with wave.open(output_path, 'wb') as wf:
        wf.setnchannels(1)  # Mono
        wf.setsampwidth(2)  # 16-bit
        wf.setframerate(sample_rate)
        wf.writeframes(audio_data)
    # 2 bytes per 16-bit mono sample -> duration in seconds.
    duration_sec = len(audio_data) / (sample_rate * 2)
    print()
    print(f"Generated: {output_path}")
    print(f" Duration: {duration_sec:.2f}s")
    print(f" Sample rate: {sample_rate}Hz")
    print(f" Format: 16-bit mono PCM WAV")
    print(f" Size: {len(audio_data):,} bytes")
def load_utterances_from_json(path: Path) -> list[str]:
    """
    Read utterance texts from a JSON file.
    Accepted shapes:
      - a JSON array: ["utterance 1", "utterance 2"]
      - an object with an "utterances" key: {"utterances": ["a", "b"]}
    Any other shape raises ValueError.
    """
    with open(path, encoding="utf-8") as handle:
        data = json.load(handle)
    if isinstance(data, list):
        return [str(item) for item in data]
    if isinstance(data, dict) and "utterances" in data:
        return [str(item) for item in data["utterances"]]
    raise ValueError(
        f"JSON file must be an array of strings or an object with 'utterances' key. "
        f"Got: {type(data).__name__}"
    )
def parse_args():
    """Build the CLI for this script and return the parsed arguments."""
    script_dir = Path(__file__).parent
    default_output = script_dir.parent / "data" / "audio_examples" / "two_utterances_16k.wav"
    parser = argparse.ArgumentParser(
        description="Generate test audio with SiliconFlow TTS (utterances + silence)."
    )
    parser.add_argument(
        "-o", "--output",
        type=Path,
        default=default_output,
        help=f"Output WAV file path (default: {default_output})",
    )
    parser.add_argument(
        "-u", "--utterance",
        action="append",
        dest="utterances",
        metavar="TEXT",
        help="Utterance text (repeat for multiple). Ignored if --json is set.",
    )
    parser.add_argument(
        "-j", "--json",
        type=Path,
        metavar="PATH",
        help="JSON file with utterances: array of strings or object with 'utterances' key",
    )
    # The three silence options share the same shape; register them in a loop.
    for flag, default_ms, help_text in (
        ("--silence-ms", 500, "Silence in ms between utterances (default: 500)"),
        ("--lead-silence-ms", 200, "Silence in ms at start of file (default: 200)"),
        ("--trail-silence-ms", 300, "Silence in ms at end of file (default: 300)"),
    ):
        parser.add_argument(flag, type=int, default=default_ms, metavar="MS", help=help_text)
    return parser.parse_args()
async def main():
    """Main entry point: resolve utterances, then synthesize the test WAV.
    Utterance priority: --json file > repeated -u flags > built-in defaults.
    Raises FileNotFoundError / ValueError on a missing or empty JSON file.
    """
    args = parse_args()
    output_path = args.output
    # Ensure the destination directory exists before writing.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Resolve utterances: JSON file > -u args > defaults
    if args.json is not None:
        if not args.json.is_file():
            raise FileNotFoundError(f"Utterances JSON file not found: {args.json}")
        utterances = load_utterances_from_json(args.json)
        if not utterances:
            raise ValueError(f"JSON file has no utterances: {args.json}")
    elif args.utterances:
        utterances = args.utterances
    else:
        utterances = [
            "Hello, how are you doing today?",
            "I'm doing great, thank you for asking!"
        ]
    # Voice/sample-rate/speed are fixed here; only timing is CLI-configurable.
    await generate_test_audio(
        output_path=str(output_path),
        utterances=utterances,
        silence_ms=args.silence_ms,
        lead_silence_ms=args.lead_silence_ms,
        trail_silence_ms=args.trail_silence_ms,
        voice="anna",
        sample_rate=16000,
        speed=1.0
    )
# Script entry point: run the async workflow on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())

51
services/__init__.py Normal file
View File

@@ -0,0 +1,51 @@
"""AI Services package.
Provides ASR, LLM, TTS, and Realtime API services for voice conversation.
"""
from services.base import (
ServiceState,
ASRResult,
LLMMessage,
TTSChunk,
BaseASRService,
BaseLLMService,
BaseTTSService,
)
from services.llm import OpenAILLMService, MockLLMService
from services.tts import EdgeTTSService, MockTTSService
from services.asr import BufferedASRService, MockASRService
from services.openai_compatible_asr import OpenAICompatibleASRService, SiliconFlowASRService
from services.openai_compatible_tts import OpenAICompatibleTTSService, SiliconFlowTTSService
from services.streaming_tts_adapter import StreamingTTSAdapter
from services.realtime import RealtimeService, RealtimeConfig, RealtimePipeline
__all__ = [
# Base classes
"ServiceState",
"ASRResult",
"LLMMessage",
"TTSChunk",
"BaseASRService",
"BaseLLMService",
"BaseTTSService",
# LLM
"OpenAILLMService",
"MockLLMService",
# TTS
"EdgeTTSService",
"MockTTSService",
# ASR
"BufferedASRService",
"MockASRService",
"OpenAICompatibleASRService",
"SiliconFlowASRService",
# TTS (SiliconFlow)
"OpenAICompatibleTTSService",
"SiliconFlowTTSService",
"StreamingTTSAdapter",
# Realtime
"RealtimeService",
"RealtimeConfig",
"RealtimePipeline",
]

147
services/asr.py Normal file
View File

@@ -0,0 +1,147 @@
"""ASR (Automatic Speech Recognition) Service implementations.
Provides speech-to-text capabilities with streaming support.
"""
import os
import asyncio
import json
from typing import AsyncIterator, Optional
from loguru import logger
from services.base import BaseASRService, ASRResult, ServiceState
# Try to import websockets for streaming ASR
try:
import websockets
WEBSOCKETS_AVAILABLE = True
except ImportError:
WEBSOCKETS_AVAILABLE = False
class BufferedASRService(BaseASRService):
    """
    Buffered ASR service that accumulates audio and provides
    a simple text accumulator for use with EOU detection.
    This is a lightweight implementation that works with the
    existing VAD + EOU pattern without requiring external ASR.
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        language: str = "en"
    ):
        super().__init__(sample_rate=sample_rate, language=language)
        # Raw PCM accumulated via send_audio(); cleared on disconnect/get.
        self._audio_buffer: bytes = b""
        # Latest transcript text supplied via set_text().
        self._current_text: str = ""
        # Unbounded queue consumed by receive_transcripts().
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
    async def connect(self) -> None:
        """No connection needed for buffered ASR."""
        self.state = ServiceState.CONNECTED
        logger.info("Buffered ASR service connected")
    async def disconnect(self) -> None:
        """Clear buffers on disconnect."""
        self._audio_buffer = b""
        self._current_text = ""
        self.state = ServiceState.DISCONNECTED
        logger.info("Buffered ASR service disconnected")
    async def send_audio(self, audio: bytes) -> None:
        """Buffer audio for later processing."""
        self._audio_buffer += audio
    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield transcription results as they are queued.

        Polls with a short timeout so the loop stays cancellable.
        """
        while True:
            try:
                result = await asyncio.wait_for(
                    self._transcript_queue.get(),
                    timeout=0.1
                )
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
    def set_text(self, text: str) -> None:
        """
        Set the current transcript text directly.
        This allows external integration (e.g., Whisper, other ASR)
        to provide transcripts.
        """
        self._current_text = text
        result = ASRResult(text=text, is_final=False)
        # Fix: was asyncio.create_task(queue.put(result)) — that raises
        # RuntimeError when no event loop is running in this thread, and the
        # unreferenced task could be garbage-collected before it ran.
        # put_nowait never blocks on an unbounded Queue and needs no loop.
        self._transcript_queue.put_nowait(result)
    def get_and_clear_text(self) -> str:
        """Get accumulated text and clear both text and audio buffers."""
        text = self._current_text
        self._current_text = ""
        self._audio_buffer = b""
        return text
    def get_audio_buffer(self) -> bytes:
        """Get accumulated audio buffer."""
        return self._audio_buffer
    def clear_audio_buffer(self) -> None:
        """Clear audio buffer."""
        self._audio_buffer = b""
class MockASRService(BaseASRService):
    """
    Mock ASR service for testing without actual recognition.
    Emits canned transcripts on demand via trigger_transcript().
    """
    def __init__(self, sample_rate: int = 16000, language: str = "en"):
        super().__init__(sample_rate=sample_rate, language=language)
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
        # Canned texts cycled through by trigger_transcript().
        self._mock_texts = [
            "Hello, how are you?",
            "That's interesting.",
            "Tell me more about that.",
            "I understand.",
        ]
        self._text_index = 0
    async def connect(self) -> None:
        self.state = ServiceState.CONNECTED
        logger.info("Mock ASR service connected")
    async def disconnect(self) -> None:
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock ASR service disconnected")
    async def send_audio(self, audio: bytes) -> None:
        """Mock audio processing - audio is ignored."""
        pass
    def trigger_transcript(self) -> None:
        """Manually trigger a transcript (for testing)."""
        text = self._mock_texts[self._text_index % len(self._mock_texts)]
        self._text_index += 1
        result = ASRResult(text=text, is_final=True, confidence=0.95)
        # Fix: was asyncio.create_task(queue.put(result)) — requires a running
        # loop and drops the task reference (risking GC before execution).
        # put_nowait is synchronous and safe on an unbounded Queue.
        self._transcript_queue.put_nowait(result)
    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield transcription results; polls so the loop stays cancellable."""
        while True:
            try:
                result = await asyncio.wait_for(
                    self._transcript_queue.get(),
                    timeout=0.1
                )
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

253
services/base.py Normal file
View File

@@ -0,0 +1,253 @@
"""Base classes for AI services.
Defines abstract interfaces for ASR, LLM, and TTS services,
inspired by pipecat's service architecture and active-call's
StreamEngine pattern.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, List, Dict, Any, Literal
from enum import Enum
class ServiceState(Enum):
    """Service connection state shared by ASR/LLM/TTS services."""
    DISCONNECTED = "disconnected"  # no active connection
    CONNECTING = "connecting"      # connection being established
    CONNECTED = "connected"        # ready for use
    ERROR = "error"                # a failure occurred
@dataclass
class ASRResult:
    """One ASR transcription result, either partial or final."""
    text: str
    is_final: bool = False
    confidence: float = 1.0
    language: Optional[str] = None
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    def __str__(self) -> str:
        # Render as "[FINAL] ..." or "[PARTIAL] ..." for logging.
        return f"[{'FINAL' if self.is_final else 'PARTIAL'}] {self.text}"
@dataclass
class LLMMessage:
    """A single message in an LLM conversation."""
    role: str  # "system", "user", "assistant", "function"
    content: str
    name: Optional[str] = None  # For function calls
    function_call: Optional[Dict[str, Any]] = None
    def to_dict(self) -> Dict[str, Any]:
        """Convert to API-compatible dict (optional fields only when set)."""
        payload: Dict[str, Any] = {"role": self.role, "content": self.content}
        for optional_key in ("name", "function_call"):
            value = getattr(self, optional_key)
            if value:
                payload[optional_key] = value
        return payload
@dataclass
class LLMStreamEvent:
    """Structured LLM stream event.
    Event kinds:
      - "text_delta": incremental response text carried in `text`
      - "tool_call": a completed tool invocation carried in `tool_call`
      - "done": terminal marker, no payload
    """
    type: Literal["text_delta", "tool_call", "done"]
    text: Optional[str] = None
    tool_call: Optional[Dict[str, Any]] = None
@dataclass
class TTSChunk:
    """TTS audio chunk produced by streaming synthesis."""
    audio: bytes  # PCM audio data
    sample_rate: int = 16000   # samples per second
    channels: int = 1          # mono by default
    bits_per_sample: int = 16  # PCM sample width
    is_final: bool = False     # True on the last chunk of an utterance
    text_offset: Optional[int] = None  # Character offset in original text
class BaseASRService(ABC):
    """
    Abstract base class for ASR (Speech-to-Text) services.
    Supports both streaming and non-streaming transcription.
    """
    def __init__(self, sample_rate: int = 16000, language: str = "en"):
        self.sample_rate = sample_rate  # expected PCM sample rate (Hz)
        self.language = language        # transcription language hint
        self.state = ServiceState.DISCONNECTED
    @abstractmethod
    async def connect(self) -> None:
        """Establish connection to ASR service."""
        pass
    @abstractmethod
    async def disconnect(self) -> None:
        """Close connection to ASR service."""
        pass
    @abstractmethod
    async def send_audio(self, audio: bytes) -> None:
        """
        Send audio chunk for transcription.
        Args:
            audio: PCM audio data (16-bit, mono)
        """
        pass
    @abstractmethod
    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """
        Receive transcription results.
        Yields:
            ASRResult objects as they become available
        """
        pass
    async def transcribe(self, audio: bytes) -> ASRResult:
        """
        Transcribe a complete audio buffer (non-streaming).
        Args:
            audio: Complete PCM audio data
        Returns:
            Final ASRResult
        """
        # Default implementation using streaming
        # NOTE(review): this awaits the first is_final result; if an
        # implementation never emits a final result the loop will not end —
        # confirm every subclass eventually finalizes.
        await self.send_audio(audio)
        async for result in self.receive_transcripts():
            if result.is_final:
                return result
        return ASRResult(text="", is_final=True)
class BaseLLMService(ABC):
    """
    Abstract base class for LLM (Language Model) services.
    Supports streaming responses for real-time conversation.
    """
    def __init__(self, model: str = "gpt-4"):
        self.model = model  # provider model identifier
        self.state = ServiceState.DISCONNECTED
    @abstractmethod
    async def connect(self) -> None:
        """Initialize LLM service connection."""
        pass
    @abstractmethod
    async def disconnect(self) -> None:
        """Close LLM service connection."""
        pass
    @abstractmethod
    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response.
        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
        Returns:
            Complete response text
        """
        pass
    @abstractmethod
    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """
        Generate response in streaming mode.
        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
        Yields:
            Stream events (text delta/tool call/done)
        """
        pass
class BaseTTSService(ABC):
    """
    Abstract base class for TTS (Text-to-Speech) services.
    Supports streaming audio synthesis for low-latency playback.
    """
    def __init__(
        self,
        voice: str = "default",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        self.voice = voice              # provider voice identifier
        self.sample_rate = sample_rate  # output PCM sample rate (Hz)
        self.speed = speed              # playback speed multiplier
        self.state = ServiceState.DISCONNECTED
    @abstractmethod
    async def connect(self) -> None:
        """Initialize TTS service connection."""
        pass
    @abstractmethod
    async def disconnect(self) -> None:
        """Close TTS service connection."""
        pass
    @abstractmethod
    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio for text (non-streaming).
        Args:
            text: Text to synthesize
        Returns:
            Complete PCM audio data
        """
        pass
    @abstractmethod
    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.
        Args:
            text: Text to synthesize
        Yields:
            TTSChunk objects as audio is generated
        """
        pass
    async def cancel(self) -> None:
        """Cancel ongoing synthesis (for barge-in support).

        Default is a no-op; subclasses override when they support it.
        """
        pass

443
services/llm.py Normal file
View File

@@ -0,0 +1,443 @@
"""LLM (Large Language Model) Service implementations.
Provides OpenAI-compatible LLM integration with streaming support
for real-time voice conversation.
"""
import os
import asyncio
import uuid
from typing import AsyncIterator, Optional, List, Dict, Any
from loguru import logger
from app.backend_client import search_knowledge_context
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
# Try to import openai
try:
from openai import AsyncOpenAI
OPENAI_AVAILABLE = True
except ImportError:
OPENAI_AVAILABLE = False
logger.warning("openai package not available - LLM service will be disabled")
class OpenAILLMService(BaseLLMService):
    """
    OpenAI-compatible LLM service.
    Supports streaming responses for low-latency voice conversation.
    Works with OpenAI API, Azure OpenAI, and compatible APIs.
    Optionally injects retrieved knowledge-base snippets (RAG) into the
    prompt and forwards tool schemas for function calling.
    """
    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        system_prompt: Optional[str] = None,
        knowledge_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize OpenAI LLM service.
        Args:
            model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            base_url: Custom API base URL (for Azure or compatible APIs)
            system_prompt: Default system prompt for conversations
            knowledge_config: Optional RAG config (kbId/knowledgeBaseId,
                nResults, enabled) consumed by _with_knowledge_context
        """
        super().__init__(model=model)
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.base_url = base_url or os.getenv("OPENAI_API_URL")
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation."
        )
        self.client: Optional[AsyncOpenAI] = None
        # Set by cancel() to abort an in-flight generate_stream (barge-in).
        self._cancel_event = asyncio.Event()
        self._knowledge_config: Dict[str, Any] = knowledge_config or {}
        self._tool_schemas: List[Dict[str, Any]] = []
    # RAG retrieval limits.
    _RAG_DEFAULT_RESULTS = 5
    _RAG_MAX_RESULTS = 8
    _RAG_MAX_CONTEXT_CHARS = 4000
    async def connect(self) -> None:
        """Initialize OpenAI client."""
        if not OPENAI_AVAILABLE:
            raise RuntimeError("openai package not installed")
        if not self.api_key:
            raise ValueError("OpenAI API key not provided")
        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"OpenAI LLM service connected: model={self.model}")
    async def disconnect(self) -> None:
        """Close OpenAI client."""
        if self.client:
            await self.client.close()
            self.client = None
        self.state = ServiceState.DISCONNECTED
        logger.info("OpenAI LLM service disconnected")
    def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Prepare messages list with system prompt."""
        result = []
        # Add system prompt if not already present
        has_system = any(m.role == "system" for m in messages)
        if not has_system and self.system_prompt:
            result.append({"role": "system", "content": self.system_prompt})
        # Add all messages
        for msg in messages:
            result.append(msg.to_dict())
        return result
    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        """Update runtime knowledge retrieval config."""
        self._knowledge_config = config or {}
    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        """Update runtime tool schemas.
        Accepts full OpenAI tool dicts ({"type": "function", "function": ...})
        or bare {"name", "description", "parameters"} dicts, which are wrapped
        into the full form. Invalid entries are silently skipped.
        """
        self._tool_schemas = []
        if not isinstance(schemas, list):
            return
        for item in schemas:
            if not isinstance(item, dict):
                continue
            fn = item.get("function")
            if isinstance(fn, dict) and fn.get("name"):
                self._tool_schemas.append(item)
            elif item.get("name"):
                self._tool_schemas.append(
                    {
                        "type": "function",
                        "function": {
                            "name": str(item.get("name")),
                            "description": str(item.get("description") or ""),
                            "parameters": item.get("parameters") or {"type": "object", "properties": {}},
                        },
                    }
                )
    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """Best-effort int conversion; fall back to default on bad input."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default
    def _resolve_kb_id(self) -> Optional[str]:
        """Extract the knowledge-base id from config (several accepted keys)."""
        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
        kb_id = str(
            cfg.get("kbId")
            or cfg.get("knowledgeBaseId")
            or cfg.get("knowledge_base_id")
            or ""
        ).strip()
        return kb_id or None
    def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
        """Format retrieved snippets into a system-prompt block.
        Returns None when nothing usable was retrieved. Total snippet text
        is capped at _RAG_MAX_CONTEXT_CHARS.
        """
        if not results:
            return None
        lines = [
            "You have retrieved the following knowledge base snippets.",
            "Use them only when relevant to the latest user request.",
            "If snippets are insufficient, say you are not sure instead of guessing.",
            "",
        ]
        used_chars = 0
        used_count = 0
        for item in results:
            content = str(item.get("content") or "").strip()
            if not content:
                continue
            if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
                break
            metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
            doc_id = metadata.get("document_id")
            chunk_index = metadata.get("chunk_index")
            distance = item.get("distance")
            source_parts = []
            if doc_id:
                source_parts.append(f"doc={doc_id}")
            if chunk_index is not None:
                source_parts.append(f"chunk={chunk_index}")
            source = f" ({', '.join(source_parts)})" if source_parts else ""
            distance_text = ""
            try:
                if distance is not None:
                    distance_text = f", distance={float(distance):.4f}"
            except (TypeError, ValueError):
                distance_text = ""
            # Truncate each snippet to the remaining character budget.
            remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
            snippet = content[:remaining].strip()
            if not snippet:
                continue
            used_count += 1
            lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
            used_chars += len(snippet)
        if used_count == 0:
            return None
        return "\n".join(lines)
    async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
        """Optionally prepend a RAG system message built from KB search results.

        Returns the original list unchanged when RAG is disabled, no KB id is
        configured, there is no user message, or retrieval yields nothing.
        """
        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
        enabled = cfg.get("enabled", True)
        if isinstance(enabled, str):
            enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
        if not enabled:
            return messages
        kb_id = self._resolve_kb_id()
        if not kb_id:
            return messages
        # Retrieve against the most recent user message only.
        latest_user = ""
        for msg in reversed(messages):
            if msg.role == "user":
                latest_user = (msg.content or "").strip()
                break
        if not latest_user:
            return messages
        n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
        n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
        results = await search_knowledge_context(
            kb_id=kb_id,
            query=latest_user,
            n_results=n_results,
        )
        prompt = self._build_knowledge_prompt(results)
        if not prompt:
            return messages
        logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
        rag_system = LLMMessage(role="system", content=prompt)
        # Keep the caller's system prompt first; RAG context goes second.
        if messages and messages[0].role == "system":
            return [messages[0], rag_system, *messages[1:]]
        return [rag_system, *messages]
    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response.
        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
        Returns:
            Complete response text
        Raises:
            RuntimeError: If connect() has not been called.
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")
        rag_messages = await self._with_knowledge_context(messages)
        prepared = self._prepare_messages(rag_messages)
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content or ""
            logger.debug(f"LLM response: {content[:100]}...")
            return content
        except Exception as e:
            logger.error(f"LLM generation error: {e}")
            raise
    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """
        Generate response in streaming mode.
        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
        Yields:
            Structured stream events. Always terminates with a "done" event
            unless the underlying stream raises.
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")
        rag_messages = await self._with_knowledge_context(messages)
        prepared = self._prepare_messages(rag_messages)
        self._cancel_event.clear()
        # Accumulates incremental tool_call deltas keyed by call index.
        tool_accumulator: Dict[int, Dict[str, str]] = {}
        openai_tools = self._tool_schemas or None
        try:
            create_args: Dict[str, Any] = dict(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
            )
            if openai_tools:
                create_args["tools"] = openai_tools
                create_args["tool_choice"] = "auto"
            stream = await self.client.chat.completions.create(**create_args)
            async for chunk in stream:
                # Check for cancellation
                if self._cancel_event.is_set():
                    logger.info("LLM stream cancelled")
                    break
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                delta = getattr(choice, "delta", None)
                if delta and getattr(delta, "content", None):
                    content = delta.content
                    yield LLMStreamEvent(type="text_delta", text=content)
                # OpenAI streams function calls via incremental tool_calls deltas.
                tool_calls = getattr(delta, "tool_calls", None) if delta else None
                if tool_calls:
                    for tc in tool_calls:
                        index = getattr(tc, "index", 0) or 0
                        item = tool_accumulator.setdefault(
                            int(index),
                            {"id": "", "name": "", "arguments": ""},
                        )
                        tc_id = getattr(tc, "id", None)
                        if tc_id:
                            item["id"] = str(tc_id)
                        fn = getattr(tc, "function", None)
                        if fn:
                            fn_name = getattr(fn, "name", None)
                            if fn_name:
                                item["name"] = str(fn_name)
                            fn_args = getattr(fn, "arguments", None)
                            if fn_args:
                                item["arguments"] += str(fn_args)
                finish_reason = getattr(choice, "finish_reason", None)
                if finish_reason == "tool_calls" and tool_accumulator:
                    # Flush completed tool calls in index order.
                    for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
                        call_name = payload.get("name", "").strip()
                        if not call_name:
                            continue
                        call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
                        yield LLMStreamEvent(
                            type="tool_call",
                            tool_call={
                                "id": call_id,
                                "type": "function",
                                "function": {
                                    "name": call_name,
                                    "arguments": payload.get("arguments", "") or "{}",
                                },
                            },
                        )
                    yield LLMStreamEvent(type="done")
                    return
                if finish_reason in {"stop", "length", "content_filter"}:
                    yield LLMStreamEvent(type="done")
                    return
            # Fix: previously a cancelled stream (break above) or a stream that
            # ended without a recognized finish_reason terminated with no
            # "done" event, which could leave consumers waiting forever.
            yield LLMStreamEvent(type="done")
        except asyncio.CancelledError:
            logger.info("LLM stream cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise
    def cancel(self) -> None:
        """Cancel ongoing generation (checked between stream chunks)."""
        self._cancel_event.set()
class MockLLMService(BaseLLMService):
    """
    Mock LLM service that fabricates canned replies, for testing
    without any API calls.
    """
    def __init__(self, response_delay: float = 0.5):
        super().__init__(model="mock")
        self.response_delay = response_delay
        # Canned replies cycled through on each generate() call.
        self.responses = [
            "Hello! How can I help you today?",
            "That's an interesting question. Let me think about it.",
            "I understand. Is there anything else you'd like to know?",
            "Great! I'm here if you need anything else.",
        ]
        self._response_index = 0
    async def connect(self) -> None:
        self.state = ServiceState.CONNECTED
        logger.info("Mock LLM service connected")
    async def disconnect(self) -> None:
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock LLM service disconnected")
    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        # Simulate network latency, then serve the next canned reply.
        await asyncio.sleep(self.response_delay)
        reply = self.responses[self._response_index % len(self.responses)]
        self._response_index += 1
        return reply
    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        reply = await self.generate(messages, temperature, max_tokens)
        # Emit one word at a time, with each separating space as its own delta.
        for position, word in enumerate(reply.split()):
            if position:
                yield LLMStreamEvent(type="text_delta", text=" ")
            yield LLMStreamEvent(type="text_delta", text=word)
            await asyncio.sleep(0.05)  # Simulate streaming delay
        yield LLMStreamEvent(type="done")

View File

@@ -0,0 +1,321 @@
"""OpenAI-compatible ASR (Automatic Speech Recognition) Service.
Uses the SiliconFlow API for speech-to-text transcription.
API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcriptions
"""
import asyncio
import io
import wave
from typing import AsyncIterator, Optional, Callable, Awaitable
from loguru import logger
try:
import aiohttp
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
logger.warning("aiohttp not available - OpenAICompatibleASRService will not work")
from services.base import BaseASRService, ASRResult, ServiceState
class OpenAICompatibleASRService(BaseASRService):
"""
OpenAI-compatible ASR service for speech-to-text transcription.
Features:
- Buffers incoming audio chunks
- Provides interim transcriptions periodically (for streaming to client)
- Final transcription on EOU
API Details:
- Endpoint: POST https://api.siliconflow.cn/v1/audio/transcriptions
- Models: FunAudioLLM/SenseVoiceSmall (default), TeleAI/TeleSpeechASR
- Input: Audio file (multipart/form-data)
- Output: {"text": "transcribed text"}
"""
# Supported models
MODELS = {
"sensevoice": "FunAudioLLM/SenseVoiceSmall",
"telespeech": "TeleAI/TeleSpeechASR",
}
API_URL = "https://api.siliconflow.cn/v1/audio/transcriptions"
    def __init__(
        self,
        api_key: str,
        model: str = "FunAudioLLM/SenseVoiceSmall",
        sample_rate: int = 16000,
        language: str = "auto",
        interim_interval_ms: int = 500,  # How often to send interim results
        min_audio_for_interim_ms: int = 300,  # Min audio before first interim
        on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None
    ):
        """
        Initialize OpenAI-compatible ASR service.
        Args:
            api_key: Provider API key
            model: ASR model name or alias
            sample_rate: Audio sample rate (16000 recommended)
            language: Language code (auto for automatic detection)
            interim_interval_ms: How often to generate interim transcriptions
            min_audio_for_interim_ms: Minimum audio duration before first interim
            on_transcript: Callback for transcription results (text, is_final)
        Raises:
            RuntimeError: If aiohttp is not installed.
        """
        super().__init__(sample_rate=sample_rate, language=language)
        if not AIOHTTP_AVAILABLE:
            raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
        self.api_key = api_key
        # Resolve short aliases ("sensevoice") to full model ids; unknown
        # names are passed through unchanged.
        self.model = self.MODELS.get(model.lower(), model)
        self.interim_interval_ms = interim_interval_ms
        self.min_audio_for_interim_ms = min_audio_for_interim_ms
        self.on_transcript = on_transcript
        # Session
        self._session: Optional[aiohttp.ClientSession] = None
        # Audio buffer
        self._audio_buffer: bytes = b""
        self._current_text: str = ""
        self._last_interim_time: float = 0
        # Transcript queue for async iteration
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
        # Background task for interim results
        self._interim_task: Optional[asyncio.Task] = None
        self._running = False
        logger.info(f"OpenAICompatibleASRService initialized with model: {self.model}")
    async def connect(self) -> None:
        """Connect to the service.

        Bearer auth is attached once here; every request made on this
        session inherits it.
        """
        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}"
            }
        )
        self._running = True
        self.state = ServiceState.CONNECTED
        logger.info("OpenAICompatibleASRService connected")
async def disconnect(self) -> None:
"""Disconnect and cleanup."""
self._running = False
if self._interim_task:
self._interim_task.cancel()
try:
await self._interim_task
except asyncio.CancelledError:
pass
self._interim_task = None
if self._session:
await self._session.close()
self._session = None
self._audio_buffer = b""
self._current_text = ""
self.state = ServiceState.DISCONNECTED
logger.info("OpenAICompatibleASRService disconnected")
async def send_audio(self, audio: bytes) -> None:
"""
Buffer incoming audio data.
Args:
audio: PCM audio data (16-bit, mono)
"""
self._audio_buffer += audio
    async def transcribe_buffer(self, is_final: bool = False) -> Optional[str]:
        """
        Transcribe current audio buffer.

        The buffered PCM is wrapped in an in-memory WAV container and posted
        to the provider's transcription endpoint. Successful results update
        ``_current_text``, invoke the ``on_transcript`` callback and are
        queued for ``receive_transcripts`` consumers.

        Args:
            is_final: Whether this is the final transcription

        Returns:
            Transcribed text or None if not enough audio / on error
        """
        if not self._session:
            logger.warning("ASR session not connected")
            return None
        # Check minimum audio duration (bytes / (rate * 2 bytes per sample) -> seconds)
        audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000
        # Interim decodes wait for a configurable minimum amount of audio.
        if not is_final and audio_duration_ms < self.min_audio_for_interim_ms:
            return None
        if audio_duration_ms < 100:  # Less than 100ms - too short
            return None
        try:
            # Convert PCM to WAV in memory (the API expects a file upload).
            wav_buffer = io.BytesIO()
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(self.sample_rate)
                wav_file.writeframes(self._audio_buffer)
            wav_buffer.seek(0)
            wav_data = wav_buffer.read()
            # Send to API as multipart form data.
            form_data = aiohttp.FormData()
            form_data.add_field(
                'file',
                wav_data,
                filename='audio.wav',
                content_type='audio/wav'
            )
            form_data.add_field('model', self.model)
            async with self._session.post(self.API_URL, data=form_data) as response:
                if response.status == 200:
                    result = await response.json()
                    text = result.get("text", "").strip()
                    if text:
                        self._current_text = text
                        # Notify via callback
                        if self.on_transcript:
                            await self.on_transcript(text, is_final)
                        # Queue result for async iteration
                        await self._transcript_queue.put(
                            ASRResult(text=text, is_final=is_final)
                        )
                        logger.debug(f"ASR {'final' if is_final else 'interim'}: {text[:50]}...")
                    return text
                else:
                    error_text = await response.text()
                    logger.error(f"ASR API error {response.status}: {error_text}")
                    return None
        except Exception as e:
            # Best-effort: transcription failures are logged, not raised.
            logger.error(f"ASR transcription error: {e}")
            return None
async def get_final_transcription(self) -> str:
"""
Get final transcription and clear buffer.
Call this when EOU is detected.
Returns:
Final transcribed text
"""
# Transcribe full buffer as final
text = await self.transcribe_buffer(is_final=True)
# Clear buffer
result = text or self._current_text
self._audio_buffer = b""
self._current_text = ""
return result
def get_and_clear_text(self) -> str:
"""
Get accumulated text and clear buffer.
Compatible with BufferedASRService interface.
"""
text = self._current_text
self._current_text = ""
self._audio_buffer = b""
return text
    def get_audio_buffer(self) -> bytes:
        """Get current audio buffer."""
        # Raw 16-bit mono PCM accumulated since the last clear.
        return self._audio_buffer
def get_audio_duration_ms(self) -> float:
"""Get current audio buffer duration in milliseconds."""
return len(self._audio_buffer) / (self.sample_rate * 2) * 1000
def clear_buffer(self) -> None:
"""Clear audio and text buffers."""
self._audio_buffer = b""
self._current_text = ""
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
"""
Async iterator for transcription results.
Yields:
ASRResult with text and is_final flag
"""
while self._running:
try:
result = await asyncio.wait_for(
self._transcript_queue.get(),
timeout=0.1
)
yield result
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
async def start_interim_transcription(self) -> None:
"""
Start background task for interim transcriptions.
This periodically transcribes buffered audio for
real-time feedback to the user.
"""
if self._interim_task and not self._interim_task.done():
return
self._interim_task = asyncio.create_task(self._interim_loop())
async def stop_interim_transcription(self) -> None:
"""Stop interim transcription task."""
if self._interim_task:
self._interim_task.cancel()
try:
await self._interim_task
except asyncio.CancelledError:
pass
self._interim_task = None
    async def _interim_loop(self) -> None:
        """Background loop for interim transcriptions.

        Sleeps for one interval, then triggers an interim decode when both
        enough wall-clock time has passed since the last interim and enough
        audio has been buffered. Runs until ``self._running`` goes False or
        the task is cancelled; unexpected errors are logged and the loop
        continues.
        """
        import time
        while self._running:
            try:
                await asyncio.sleep(self.interim_interval_ms / 1000)
                # Check if we have enough new audio
                current_time = time.time()
                time_since_last = (current_time - self._last_interim_time) * 1000
                if time_since_last >= self.interim_interval_ms:
                    audio_duration = self.get_audio_duration_ms()
                    if audio_duration >= self.min_audio_for_interim_ms:
                        await self.transcribe_buffer(is_final=False)
                        # Only advance the clock when a decode was attempted.
                        self._last_interim_time = current_time
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Interim transcription error: {e}")
# Backward-compatible alias so existing imports of the old class name resolve.
SiliconFlowASRService = OpenAICompatibleASRService

View File

@@ -0,0 +1,324 @@
"""OpenAI-compatible TTS Service with streaming support.
Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency
text-to-speech synthesis with streaming.
API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
"""
import os
import asyncio
import aiohttp
from typing import AsyncIterator, Optional
from loguru import logger
from services.base import BaseTTSService, TTSChunk, ServiceState
from services.streaming_tts_adapter import StreamingTTSAdapter # backward-compatible re-export
class OpenAICompatibleTTSService(BaseTTSService):
    """
    OpenAI-compatible TTS service with streaming support.
    Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models.
    """
    # Available voices: short name -> fully-qualified "model:voice" ID.
    VOICES = {
        "alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
        "anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
        "bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
        "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
        "charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
        "claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
        "david": "FunAudioLLM/CosyVoice2-0.5B:david",
        "diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        voice: str = "anna",
        model: str = "FunAudioLLM/CosyVoice2-0.5B",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize OpenAI-compatible TTS service.

        Args:
            api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var)
            voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
            model: Model name
            sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
            speed: Speech speed (0.25 to 4.0)
        """
        # Resolve voice name (case-insensitive), and normalize "model:VoiceId" suffix.
        resolved_voice = (voice or "").strip()
        voice_lookup = resolved_voice.lower()
        if voice_lookup in self.VOICES:
            # Short alias, e.g. "anna" -> "FunAudioLLM/CosyVoice2-0.5B:anna"
            full_voice = self.VOICES[voice_lookup]
        elif ":" in resolved_voice:
            # Already qualified: keep the caller's model part, lowercase the voice.
            model_part, voice_part = resolved_voice.split(":", 1)
            normalized_voice_part = voice_part.strip().lower()
            if normalized_voice_part in self.VOICES:
                full_voice = f"{(model_part or model).strip()}:{normalized_voice_part}"
            else:
                # Unknown voice id: pass through unchanged.
                full_voice = resolved_voice
        else:
            full_voice = resolved_voice
        super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)
        self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
        self.model = model
        self.api_url = "https://api.siliconflow.cn/v1/audio/speech"
        self._session: Optional[aiohttp.ClientSession] = None
        # Set to abort an in-flight synthesize_stream (barge-in support).
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Initialize HTTP session."""
        if not self.api_key:
            raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")
        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}")

    async def disconnect(self) -> None:
        """Close HTTP session."""
        if self._session:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("SiliconFlow TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize complete audio for text (concatenates the stream)."""
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio
        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        On a non-200 response the error is logged and the generator ends
        without yielding (callers receive empty audio rather than an error).

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio; exactly one chunk is tagged
            ``is_final=True`` (the last one).
        """
        if not self._session:
            raise RuntimeError("TTS service not connected")
        if not text.strip():
            return
        self._cancel_event.clear()
        payload = {
            "model": self.model,
            "input": text,
            "voice": self.voice,
            "response_format": "pcm",
            "sample_rate": self.sample_rate,
            "stream": True,
            "speed": self.speed
        }
        try:
            async with self._session.post(self.api_url, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}")
                    return
                # Stream audio chunks
                chunk_size = self.sample_rate * 2 // 10  # 100ms chunks (16-bit mono)
                buffer = b""
                pending_chunk = None
                async for chunk in response.content.iter_any():
                    if self._cancel_event.is_set():
                        logger.info("TTS synthesis cancelled")
                        return
                    buffer += chunk
                    # Yield complete chunks
                    while len(buffer) >= chunk_size:
                        audio_chunk = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]
                        # Keep one full chunk buffered so we can always tag the true
                        # last full chunk as final when stream length is an exact multiple.
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False
                            )
                        pending_chunk = audio_chunk
                # Flush pending chunk(s) and remaining tail.
                if pending_chunk is not None:
                    if buffer:
                        # A tail remains, so the held chunk is not the last one.
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=False
                        )
                        pending_chunk = None
                    else:
                        # Stream length was an exact multiple of chunk_size.
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=True
                        )
                        pending_chunk = None
                if buffer:
                    yield TTSChunk(
                        audio=buffer,
                        sample_rate=self.sample_rate,
                        is_final=True
                    )
        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def cancel(self) -> None:
        """Cancel ongoing synthesis."""
        self._cancel_event.set()
class StreamingTTSAdapter:
    """
    Adapter for streaming LLM text to TTS with sentence-level chunking.

    This reduces latency by starting TTS as soon as a complete sentence
    is received from the LLM, rather than waiting for the full response.

    NOTE(review): this class shadows the StreamingTTSAdapter imported from
    services.streaming_tts_adapter at the top of this module; the two
    implementations should be consolidated to avoid divergence.
    """
    # Sentence delimiters. The fullwidth CJK enders had been reduced to
    # empty strings by an encoding error (an empty string can never match a
    # single character, so CJK text was never split); restored here.
    SENTENCE_ENDS = {'\u3002', '\uff01', '\uff1f', '\uff1b', '.', '!', '?', '\n'}

    def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
        """
        Args:
            tts_service: Backend used to synthesize each sentence.
            transport: Object exposing an async ``send_audio(bytes)`` method.
            session_id: Identifier of the owning conversation session.
        """
        self.tts_service = tts_service
        self.transport = transport
        self.session_id = session_id
        self._buffer = ""
        self._cancel_event = asyncio.Event()
        self._is_speaking = False

    def _is_non_sentence_period(self, text: str, idx: int) -> bool:
        """Check whether '.' should NOT be treated as a sentence delimiter."""
        if text[idx] != ".":
            return False
        # Decimal/version segment: 1.2, v1.2.3
        if idx > 0 and idx < len(text) - 1 and text[idx - 1].isdigit() and text[idx + 1].isdigit():
            return True
        # Number abbreviations: No.1 / No. 1
        left_start = idx - 1
        while left_start >= 0 and text[left_start].isalpha():
            left_start -= 1
        left_token = text[left_start + 1:idx].lower()
        if left_token == "no":
            j = idx + 1
            while j < len(text) and text[j].isspace():
                j += 1
            if j < len(text) and text[j].isdigit():
                return True
        return False

    async def process_text_chunk(self, text_chunk: str) -> None:
        """
        Process a text chunk from LLM and trigger TTS when sentence is complete.

        Args:
            text_chunk: Text chunk from LLM streaming
        """
        if self._cancel_event.is_set():
            return
        self._buffer += text_chunk
        # Check for sentence completion; speak every full sentence found.
        while True:
            split_idx = -1
            for i, char in enumerate(self._buffer):
                if char == "." and self._is_non_sentence_period(self._buffer, i):
                    continue
                if char in self.SENTENCE_ENDS:
                    split_idx = i
                    break
            if split_idx < 0:
                break
            # Absorb any run of consecutive delimiters (e.g. "!!", "?!").
            end_idx = split_idx + 1
            while end_idx < len(self._buffer) and self._buffer[end_idx] in self.SENTENCE_ENDS:
                end_idx += 1
            sentence = self._buffer[:end_idx].strip()
            self._buffer = self._buffer[end_idx:]
            # Skip punctuation-only fragments with nothing pronounceable.
            if sentence and any(ch.isalnum() for ch in sentence):
                await self._speak_sentence(sentence)

    async def flush(self) -> None:
        """Flush remaining buffer (speak any trailing partial sentence)."""
        if self._buffer.strip() and not self._cancel_event.is_set():
            await self._speak_sentence(self._buffer.strip())
        self._buffer = ""

    async def _speak_sentence(self, text: str) -> None:
        """Synthesize one sentence and forward its audio to the transport."""
        if not text or self._cancel_event.is_set():
            return
        self._is_speaking = True
        try:
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._cancel_event.is_set():
                    break
                await self.transport.send_audio(chunk.audio)
                await asyncio.sleep(0.01)  # Prevent flooding
        except Exception as e:
            logger.error(f"TTS speak error: {e}")
        finally:
            self._is_speaking = False

    def cancel(self) -> None:
        """Cancel ongoing speech and drop buffered text."""
        self._cancel_event.set()
        self._buffer = ""

    def reset(self) -> None:
        """Reset for new turn."""
        self._cancel_event.clear()
        self._buffer = ""
        self._is_speaking = False

    @property
    def is_speaking(self) -> bool:
        # True while _speak_sentence is actively streaming audio.
        return self._is_speaking
# Backward-compatible alias so existing imports of the old class name resolve.
SiliconFlowTTSService = OpenAICompatibleTTSService

548
services/realtime.py Normal file
View File

@@ -0,0 +1,548 @@
"""OpenAI Realtime API Service.
Provides true duplex voice conversation using OpenAI's Realtime API,
similar to active-call's RealtimeProcessor. This bypasses the need for
separate ASR/LLM/TTS services by handling everything server-side.
The Realtime API provides:
- Server-side VAD with turn detection
- Streaming speech-to-text
- Streaming LLM responses
- Streaming text-to-speech
- Function calling support
- Barge-in/interruption handling
"""
import os
import asyncio
import json
import base64
from typing import Optional, Dict, Any, Callable, Awaitable, List
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger
try:
import websockets
WEBSOCKETS_AVAILABLE = True
except ImportError:
WEBSOCKETS_AVAILABLE = False
logger.warning("websockets not available - Realtime API will be disabled")
class RealtimeState(Enum):
    """Realtime API connection state."""
    DISCONNECTED = "disconnected"  # no websocket open (initial / after disconnect())
    CONNECTING = "connecting"      # connect() in progress
    CONNECTED = "connected"        # session configured, receive loop running
    ERROR = "error"                # connect() or the receive loop failed
@dataclass
class RealtimeConfig:
    """Configuration for OpenAI Realtime API.

    All fields have working defaults; only ``api_key`` is mandatory (it may
    also be supplied via the OPENAI_API_KEY environment variable by
    RealtimeService).
    """
    # API Configuration
    api_key: Optional[str] = None
    model: str = "gpt-4o-realtime-preview"
    endpoint: Optional[str] = None  # For Azure or custom endpoints
    # Voice Configuration
    voice: str = "alloy"  # alloy, echo, shimmer, etc.
    instructions: str = (
        "You are a helpful, friendly voice assistant. "
        "Keep your responses concise and conversational."
    )
    # Turn Detection (Server-side VAD). Passed verbatim in session.update;
    # see the Realtime API docs for field semantics.
    turn_detection: Optional[Dict[str, Any]] = field(default_factory=lambda: {
        "type": "server_vad",
        "threshold": 0.5,
        "prefix_padding_ms": 300,
        "silence_duration_ms": 500
    })
    # Audio Configuration
    input_audio_format: str = "pcm16"
    output_audio_format: str = "pcm16"
    # Tools/Functions (Realtime API tool definitions; only sent when non-empty)
    tools: List[Dict[str, Any]] = field(default_factory=list)
class RealtimeService:
    """
    OpenAI Realtime API service for true duplex voice conversation.

    This service handles the entire voice conversation pipeline:
    1. Audio input → Server-side VAD → Speech-to-text
    2. Text → LLM processing → Response generation
    3. Response → Text-to-speech → Audio output

    Events emitted:
    - on_audio: Audio output from the assistant
    - on_transcript: Text transcript (user or assistant)
    - on_speech_started: User started speaking
    - on_speech_stopped: User stopped speaking
    - on_response_started: Assistant started responding
    - on_response_done: Assistant finished responding
    - on_function_call: Function call requested
    - on_error: Error occurred
    """

    def __init__(self, config: Optional[RealtimeConfig] = None):
        """
        Initialize Realtime API service.

        Args:
            config: Realtime configuration (uses defaults if not provided)
        """
        self.config = config or RealtimeConfig()
        # API key falls back to the OPENAI_API_KEY environment variable.
        self.config.api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        self.state = RealtimeState.DISCONNECTED
        self._ws = None
        self._receive_task: Optional[asyncio.Task] = None
        self._cancel_event = asyncio.Event()
        # Event callbacks: event name -> list of async handlers.
        self._callbacks: Dict[str, List[Callable]] = {
            "on_audio": [],
            "on_transcript": [],
            "on_speech_started": [],
            "on_speech_stopped": [],
            "on_response_started": [],
            "on_response_done": [],
            "on_function_call": [],
            "on_error": [],
            "on_interrupted": [],
        }
        logger.debug(f"RealtimeService initialized with model={self.config.model}")

    def on(self, event: str, callback: Callable[..., Awaitable[None]]) -> None:
        """
        Register event callback.

        Args:
            event: Event name (must be a key of ``_callbacks``; unknown
                names are silently ignored)
            callback: Async callback function
        """
        if event in self._callbacks:
            self._callbacks[event].append(callback)

    async def _emit(self, event: str, *args, **kwargs) -> None:
        """Emit event to all registered callbacks.

        A failing handler is logged but does not prevent the remaining
        handlers from running.
        """
        for callback in self._callbacks.get(event, []):
            try:
                await callback(*args, **kwargs)
            except Exception as e:
                logger.error(f"Event callback error ({event}): {e}")

    async def connect(self) -> None:
        """Connect to OpenAI Realtime API.

        Raises:
            RuntimeError: If the websockets package is not installed.
            ValueError: If no API key is configured.
        """
        if not WEBSOCKETS_AVAILABLE:
            raise RuntimeError("websockets package not installed")
        if not self.config.api_key:
            raise ValueError("OpenAI API key not provided")
        self.state = RealtimeState.CONNECTING
        # Build URL
        if self.config.endpoint:
            # Azure or custom endpoint
            url = f"{self.config.endpoint}/openai/realtime?api-version=2024-10-01-preview&deployment={self.config.model}"
        else:
            # OpenAI endpoint
            url = f"wss://api.openai.com/v1/realtime?model={self.config.model}"
        # Build headers
        headers = {}
        if self.config.endpoint:
            # Azure-style auth uses the "api-key" header.
            headers["api-key"] = self.config.api_key
        else:
            headers["Authorization"] = f"Bearer {self.config.api_key}"
            headers["OpenAI-Beta"] = "realtime=v1"
        try:
            logger.info(f"Connecting to Realtime API: {url}")
            # NOTE(review): `extra_headers` is the legacy websockets-client
            # argument; newer websockets releases renamed it to
            # `additional_headers` — verify against the pinned version.
            self._ws = await websockets.connect(url, extra_headers=headers)
            # Send session configuration
            await self._configure_session()
            # Start receive loop
            self._receive_task = asyncio.create_task(self._receive_loop())
            self.state = RealtimeState.CONNECTED
            logger.info("Realtime API connected successfully")
        except Exception as e:
            self.state = RealtimeState.ERROR
            logger.error(f"Realtime API connection failed: {e}")
            raise

    async def _configure_session(self) -> None:
        """Send session configuration to server (session.update event)."""
        session_config = {
            "type": "session.update",
            "session": {
                "modalities": ["text", "audio"],
                "instructions": self.config.instructions,
                "voice": self.config.voice,
                "input_audio_format": self.config.input_audio_format,
                "output_audio_format": self.config.output_audio_format,
                "turn_detection": self.config.turn_detection,
            }
        }
        if self.config.tools:
            session_config["session"]["tools"] = self.config.tools
        await self._send(session_config)
        logger.debug("Session configuration sent")

    async def _send(self, data: Dict[str, Any]) -> None:
        """Send JSON data to server (no-op when not connected)."""
        if self._ws:
            await self._ws.send(json.dumps(data))

    async def send_audio(self, audio_bytes: bytes) -> None:
        """
        Send audio to the Realtime API.

        Args:
            audio_bytes: PCM audio data (16-bit, mono, 24kHz by default)
        """
        if self.state != RealtimeState.CONNECTED:
            return
        # The API requires audio to be base64-encoded inside the JSON event.
        audio_b64 = base64.standard_b64encode(audio_bytes).decode()
        await self._send({
            "type": "input_audio_buffer.append",
            "audio": audio_b64
        })

    async def send_text(self, text: str) -> None:
        """
        Send text input (bypassing audio).

        Args:
            text: User text input
        """
        if self.state != RealtimeState.CONNECTED:
            return
        # Create a conversation item with user text
        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": text}]
            }
        })
        # Trigger response
        await self._send({"type": "response.create"})

    async def cancel_response(self) -> None:
        """Cancel the current response (for barge-in)."""
        if self.state != RealtimeState.CONNECTED:
            return
        await self._send({"type": "response.cancel"})
        logger.debug("Response cancelled")

    async def commit_audio(self) -> None:
        """Commit the audio buffer and trigger response.

        Only needed when server-side turn detection is disabled.
        """
        if self.state != RealtimeState.CONNECTED:
            return
        await self._send({"type": "input_audio_buffer.commit"})
        await self._send({"type": "response.create"})

    async def clear_audio_buffer(self) -> None:
        """Clear the input audio buffer."""
        if self.state != RealtimeState.CONNECTED:
            return
        await self._send({"type": "input_audio_buffer.clear"})

    async def submit_function_result(self, call_id: str, result: str) -> None:
        """
        Submit function call result.

        Args:
            call_id: The function call ID
            result: JSON string result
        """
        if self.state != RealtimeState.CONNECTED:
            return
        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result
            }
        })
        # Trigger response with the function result
        await self._send({"type": "response.create"})

    async def _receive_loop(self) -> None:
        """Receive and process messages from the Realtime API.

        Runs until the socket closes, the task is cancelled, or an
        unexpected error occurs; updates ``self.state`` accordingly.
        """
        if not self._ws:
            return
        try:
            async for message in self._ws:
                try:
                    data = json.loads(message)
                    await self._handle_event(data)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON received: {message[:100]}")
        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
        except websockets.ConnectionClosed as e:
            logger.info(f"WebSocket closed: {e}")
            self.state = RealtimeState.DISCONNECTED
        except Exception as e:
            logger.error(f"Receive loop error: {e}")
            self.state = RealtimeState.ERROR

    async def _handle_event(self, data: Dict[str, Any]) -> None:
        """Handle incoming event from Realtime API.

        Dispatches server events to the registered callbacks; unrecognized
        event types are logged at debug level and otherwise ignored.
        """
        event_type = data.get("type", "unknown")
        # Audio delta - streaming audio output (base64-encoded PCM)
        if event_type == "response.audio.delta":
            if "delta" in data:
                audio_bytes = base64.standard_b64decode(data["delta"])
                await self._emit("on_audio", audio_bytes)
        # Audio transcript delta - streaming text
        elif event_type == "response.audio_transcript.delta":
            if "delta" in data:
                await self._emit("on_transcript", data["delta"], "assistant", False)
        # Audio transcript done
        elif event_type == "response.audio_transcript.done":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "assistant", True)
        # Input audio transcript (user speech)
        elif event_type == "conversation.item.input_audio_transcription.completed":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "user", True)
        # Speech started (server VAD detected speech)
        elif event_type == "input_audio_buffer.speech_started":
            await self._emit("on_speech_started", data.get("audio_start_ms", 0))
        # Speech stopped
        elif event_type == "input_audio_buffer.speech_stopped":
            await self._emit("on_speech_stopped", data.get("audio_end_ms", 0))
        # Response started
        elif event_type == "response.created":
            await self._emit("on_response_started", data.get("response", {}))
        # Response done
        elif event_type == "response.done":
            await self._emit("on_response_done", data.get("response", {}))
        # Function call
        elif event_type == "response.function_call_arguments.done":
            call_id = data.get("call_id")
            name = data.get("name")
            arguments = data.get("arguments", "{}")
            await self._emit("on_function_call", call_id, name, arguments)
        # Error
        elif event_type == "error":
            error = data.get("error", {})
            logger.error(f"Realtime API error: {error}")
            await self._emit("on_error", error)
        # Session events
        elif event_type == "session.created":
            logger.info("Session created")
        elif event_type == "session.updated":
            logger.debug("Session updated")
        else:
            logger.debug(f"Unhandled event type: {event_type}")

    async def disconnect(self) -> None:
        """Disconnect from Realtime API (cancels receive loop, closes socket)."""
        self._cancel_event.set()
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._ws:
            await self._ws.close()
            self._ws = None
        self.state = RealtimeState.DISCONNECTED
        logger.info("Realtime API disconnected")
class RealtimePipeline:
    """
    Pipeline adapter for RealtimeService.

    Provides a compatible interface with DuplexPipeline but uses
    OpenAI Realtime API for all processing.
    """

    def __init__(
        self,
        transport,
        session_id: str,
        config: Optional[RealtimeConfig] = None
    ):
        """
        Initialize Realtime pipeline.

        Args:
            transport: Transport for sending audio/events (must expose async
                ``send_audio(bytes)`` and ``send_event(dict)``)
            session_id: Session identifier (also used as the event trackId)
            config: Realtime configuration
        """
        self.transport = transport
        self.session_id = session_id
        self.service = RealtimeService(config)
        # Wire service events to local handlers that forward to the transport.
        self.service.on("on_audio", self._on_audio)
        self.service.on("on_transcript", self._on_transcript)
        self.service.on("on_speech_started", self._on_speech_started)
        self.service.on("on_speech_stopped", self._on_speech_stopped)
        self.service.on("on_response_started", self._on_response_started)
        self.service.on("on_response_done", self._on_response_done)
        self.service.on("on_error", self._on_error)
        self._is_speaking = False
        self._running = True
        logger.info(f"RealtimePipeline initialized for session {session_id}")

    async def start(self) -> None:
        """Start the pipeline (connects the underlying RealtimeService)."""
        await self.service.connect()

    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio.

        Note: Realtime API expects 24kHz audio by default.
        You may need to resample from 16kHz.
        """
        if not self._running:
            return
        # TODO: Resample from 16kHz to 24kHz if needed
        await self.service.send_audio(pcm_bytes)

    async def process_text(self, text: str) -> None:
        """Process text input."""
        if not self._running:
            return
        await self.service.send_text(text)

    async def interrupt(self) -> None:
        """Interrupt current response and notify the client."""
        await self.service.cancel_response()
        await self.transport.send_event({
            "event": "interrupt",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def cleanup(self) -> None:
        """Cleanup resources (stops processing and disconnects the service)."""
        self._running = False
        await self.service.disconnect()

    # Event handlers

    async def _on_audio(self, audio_bytes: bytes) -> None:
        """Handle audio output: forward assistant audio to the transport."""
        await self.transport.send_audio(audio_bytes)

    async def _on_transcript(self, text: str, role: str, is_final: bool) -> None:
        """Handle transcript (log only; truncated to 50 chars)."""
        logger.info(f"[{role.upper()}] {text[:50]}..." if len(text) > 50 else f"[{role.upper()}] {text}")

    async def _on_speech_started(self, start_ms: int) -> None:
        """Handle user speech start."""
        self._is_speaking = True
        await self.transport.send_event({
            "event": "speaking",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "startTime": start_ms
        })
        # Cancel any ongoing response (barge-in)
        await self.service.cancel_response()

    async def _on_speech_stopped(self, end_ms: int) -> None:
        """Handle user speech stop."""
        self._is_speaking = False
        await self.transport.send_event({
            "event": "silence",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "duration": end_ms
        })

    async def _on_response_started(self, response: Dict) -> None:
        """Handle response start."""
        await self.transport.send_event({
            "event": "trackStart",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def _on_response_done(self, response: Dict) -> None:
        """Handle response complete."""
        await self.transport.send_event({
            "event": "trackEnd",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def _on_error(self, error: Dict) -> None:
        """Handle error: forward it to the client as an error event."""
        await self.transport.send_event({
            "event": "error",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "sender": "realtime",
            "error": str(error)
        })

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        import time
        return int(time.time() * 1000)

    @property
    def is_speaking(self) -> bool:
        """Check if user is speaking (tracked via server VAD events)."""
        return self._is_speaking

View File

@@ -0,0 +1,8 @@
"""Backward-compatible imports for legacy siliconflow_asr module."""
from services.openai_compatible_asr import OpenAICompatibleASRService
# Backward-compatible alias
SiliconFlowASRService = OpenAICompatibleASRService
__all__ = ["OpenAICompatibleASRService", "SiliconFlowASRService"]

View File

@@ -0,0 +1,8 @@
"""Backward-compatible imports for legacy siliconflow_tts module."""
from services.openai_compatible_tts import OpenAICompatibleTTSService, StreamingTTSAdapter
# Backward-compatible alias
SiliconFlowTTSService = OpenAICompatibleTTSService
__all__ = ["OpenAICompatibleTTSService", "SiliconFlowTTSService", "StreamingTTSAdapter"]

View File

@@ -0,0 +1,86 @@
"""Shared text chunking helpers for streaming TTS."""
from typing import Optional
def is_non_sentence_period(text: str, idx: int) -> bool:
    """Return True when the '.' at ``text[idx]`` is not a sentence ending.

    Two exceptions are recognized:
    * a digit on both sides (decimals / version numbers such as ``1.2``);
    * the abbreviation ``No.`` followed (possibly after spaces) by a digit.
    """
    if not (0 <= idx < len(text)) or text[idx] != ".":
        return False

    # Decimal or version segment, e.g. "1.2" or "v1.2.3".
    if 0 < idx < len(text) - 1 and text[idx - 1].isdigit() and text[idx + 1].isdigit():
        return True

    # "No." abbreviation: scan the alphabetic run ending just before the dot.
    start = idx
    while start > 0 and text[start - 1].isalpha():
        start -= 1
    if text[start:idx].lower() == "no":
        after = idx + 1
        while after < len(text) and text[after].isspace():
            after += 1
        if after < len(text) and text[after].isdigit():
            return True

    return False
def has_spoken_content(text: str) -> bool:
    """Return True if ``text`` has at least one pronounceable (alphanumeric) character."""
    for ch in text:
        if ch.isalnum():
            return True
    return False
def extract_tts_sentence(
    text_buffer: str,
    *,
    end_chars: frozenset[str],
    trailing_chars: frozenset[str],
    closers: frozenset[str],
    min_split_spoken_chars: int = 0,
    hold_trailing_at_buffer_end: bool = False,
    force: bool = False,
) -> Optional[tuple[str, str]]:
    """Split one speakable sentence off the front of ``text_buffer``.

    Scans for the first delimiter in ``end_chars`` (skipping periods that
    :func:`is_non_sentence_period` classifies as decimals or ``No.``), then
    absorbs any run of ``trailing_chars`` followed by ``closers``.

    Args:
        text_buffer: Accumulated streaming text.
        end_chars: Characters that may terminate a sentence.
        trailing_chars: Characters merged onto the sentence end (e.g. "!!").
        closers: Closing quotes/brackets merged after the trailing run.
        min_split_spoken_chars: If > 0, candidates with fewer alphanumeric
            characters are skipped in favor of a later delimiter (unless
            ``force`` is set or the match reaches the buffer end).
        hold_trailing_at_buffer_end: When True, do not split if the match
            touches the end of the buffer (more trailing chars may arrive).
        force: Ignore the hold/min-length heuristics.

    Returns:
        ``(sentence, remainder)`` or ``None`` when no split is possible yet.
    """
    if not text_buffer:
        return None

    buf_len = len(text_buffer)
    scan_from = 0
    while True:
        # Locate the next usable delimiter at or after scan_from.
        delim = -1
        for pos in range(scan_from, buf_len):
            ch = text_buffer[pos]
            if ch == "." and is_non_sentence_period(text_buffer, pos):
                continue
            if ch in end_chars:
                delim = pos
                break
        if delim == -1:
            return None

        # Absorb consecutive trailing delimiters, then closing punctuation.
        cut = delim + 1
        while cut < buf_len and text_buffer[cut] in trailing_chars:
            cut += 1
        while cut < buf_len and text_buffer[cut] in closers:
            cut += 1

        # Optionally wait for more input when the match touches buffer end.
        if hold_trailing_at_buffer_end and not force and cut >= buf_len:
            return None

        candidate = text_buffer[:cut].strip()
        spoken = sum(1 for ch in candidate if ch.isalnum())
        too_short = (
            not force
            and min_split_spoken_chars > 0
            and 0 < spoken < min_split_spoken_chars
            and cut < buf_len
        )
        if too_short:
            # Candidate too short to speak on its own; look further right.
            scan_from = cut
            continue

        return candidate, text_buffer[cut:]

View File

@@ -0,0 +1,95 @@
"""Backend-agnostic streaming adapter from LLM text to TTS audio."""
import asyncio
from loguru import logger
from services.base import BaseTTSService
from services.streaming_text import extract_tts_sentence, has_spoken_content
class StreamingTTSAdapter:
    """
    Adapter for streaming LLM text to TTS with sentence-level chunking.

    This reduces latency by starting TTS as soon as a complete sentence
    is received from the LLM, rather than waiting for the full response.
    """

    # Sentence delimiters. The fullwidth CJK enders had been reduced to
    # empty strings by an encoding error (an empty string never matches a
    # single character, so CJK text was never split); restored here.
    SENTENCE_ENDS = {"\u3002", "\uff01", "\uff1f", ".", "!", "?", "\n"}
    # No closing-quote/bracket handling by default.
    SENTENCE_CLOSERS = frozenset()

    def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
        """
        Args:
            tts_service: Backend used to synthesize each sentence.
            transport: Object exposing an async ``send_audio(bytes)`` method.
            session_id: Identifier of the owning conversation session.
        """
        self.tts_service = tts_service
        self.transport = transport
        self.session_id = session_id
        self._buffer = ""
        self._cancel_event = asyncio.Event()
        self._is_speaking = False

    async def process_text_chunk(self, text_chunk: str) -> None:
        """
        Process a text chunk from LLM and trigger TTS when sentence is complete.

        Args:
            text_chunk: Text chunk from LLM streaming
        """
        if self._cancel_event.is_set():
            return
        self._buffer += text_chunk
        # Check for sentence completion; speak every full sentence found.
        while True:
            split_result = extract_tts_sentence(
                self._buffer,
                end_chars=frozenset(self.SENTENCE_ENDS),
                trailing_chars=frozenset(self.SENTENCE_ENDS),
                closers=self.SENTENCE_CLOSERS,
                force=False,
            )
            if not split_result:
                break
            sentence, self._buffer = split_result
            # Skip punctuation-only fragments with nothing pronounceable.
            if sentence and has_spoken_content(sentence):
                await self._speak_sentence(sentence)

    async def flush(self) -> None:
        """Flush remaining buffer (speak any trailing partial sentence)."""
        if self._buffer.strip() and not self._cancel_event.is_set():
            await self._speak_sentence(self._buffer.strip())
        self._buffer = ""

    async def _speak_sentence(self, text: str) -> None:
        """Synthesize one sentence and forward its audio to the transport."""
        if not text or self._cancel_event.is_set():
            return
        self._is_speaking = True
        try:
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._cancel_event.is_set():
                    break
                await self.transport.send_audio(chunk.audio)
                await asyncio.sleep(0.01)  # Prevent flooding
        except Exception as e:
            logger.error(f"TTS speak error: {e}")
        finally:
            self._is_speaking = False

    def cancel(self) -> None:
        """Cancel ongoing speech and drop buffered text."""
        self._cancel_event.set()
        self._buffer = ""

    def reset(self) -> None:
        """Reset for new turn."""
        self._cancel_event.clear()
        self._buffer = ""
        self._is_speaking = False

    @property
    def is_speaking(self) -> bool:
        # True while _speak_sentence is actively streaming audio.
        return self._is_speaking

271
services/tts.py Normal file
View File

@@ -0,0 +1,271 @@
"""TTS (Text-to-Speech) Service implementations.
Provides multiple TTS backend options including edge-tts (free)
and placeholder for cloud services.
"""
import os
import io
import asyncio
import struct
from typing import AsyncIterator, Optional
from loguru import logger
from services.base import BaseTTSService, TTSChunk, ServiceState
# Try to import edge-tts
try:
import edge_tts
EDGE_TTS_AVAILABLE = True
except ImportError:
EDGE_TTS_AVAILABLE = False
logger.warning("edge-tts not available - EdgeTTS service will be disabled")
class EdgeTTSService(BaseTTSService):
    """
    Microsoft Edge TTS service.

    Uses edge-tts library for free, high-quality speech synthesis.
    Supports streaming for low-latency playback.
    """

    # Voice mapping for common languages: short language codes resolve to a
    # default neural voice for that locale.
    VOICE_MAP = {
        "en": "en-US-JennyNeural",
        "en-US": "en-US-JennyNeural",
        "en-GB": "en-GB-SoniaNeural",
        "zh": "zh-CN-XiaoxiaoNeural",
        "zh-CN": "zh-CN-XiaoxiaoNeural",
        "zh-TW": "zh-TW-HsiaoChenNeural",
        "ja": "ja-JP-NanamiNeural",
        "ko": "ko-KR-SunHiNeural",
        "fr": "fr-FR-DeniseNeural",
        "de": "de-DE-KatjaNeural",
        "es": "es-ES-ElviraNeural",
    }

    def __init__(
        self,
        voice: str = "en-US-JennyNeural",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize Edge TTS service.

        Args:
            voice: Voice name (e.g., "en-US-JennyNeural") or language code (e.g., "en")
            sample_rate: Target sample rate (will be resampled)
            speed: Speech speed multiplier (1.0 = normal)
        """
        # Resolve voice from language code if needed
        if voice in self.VOICE_MAP:
            voice = self.VOICE_MAP[voice]
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Edge TTS doesn't require explicit connection; just validate availability."""
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts package not installed")
        self.state = ServiceState.CONNECTED
        logger.info(f"Edge TTS service ready: voice={self.voice}")

    async def disconnect(self) -> None:
        """Edge TTS doesn't require explicit disconnection."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Edge TTS service disconnected")

    def _get_rate_string(self) -> str:
        """Convert speed multiplier to edge-tts rate syntax: "+0%", "-10%", "+20%"."""
        # Use round() instead of int(): int() truncates toward zero, so
        # speed=0.9 gives int((0.9 - 1.0) * 100) == int(-9.999...) == -9
        # instead of the intended -10.
        percentage = round((self.speed - 1.0) * 100)
        if percentage >= 0:
            return f"+{percentage}%"
        return f"{percentage}%"

    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio for text.

        Args:
            text: Text to synthesize

        Returns:
            PCM audio data (16-bit, mono, at self.sample_rate)

        Raises:
            RuntimeError: If the edge-tts package is not installed.
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts not available")
        # Collect all streaming chunks into a single buffer.
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio
        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio (~100ms per chunk)

        Raises:
            RuntimeError: If the edge-tts package is not installed.
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts not available")
        self._cancel_event.clear()
        try:
            communicate = edge_tts.Communicate(
                text,
                voice=self.voice,
                rate=self._get_rate_string()
            )
            # edge-tts outputs MP3, we need to decode to PCM.
            # For now, collect MP3 chunks and yield after conversion.
            mp3_data = b""
            async for chunk in communicate.stream():
                # Check for cancellation between network chunks.
                if self._cancel_event.is_set():
                    logger.info("TTS synthesis cancelled")
                    return
                if chunk["type"] == "audio":
                    mp3_data += chunk["data"]
            # Convert MP3 to PCM
            if mp3_data:
                pcm_data = await self._convert_mp3_to_pcm(mp3_data)
                if pcm_data:
                    # Yield in chunks for streaming playback
                    chunk_size = self.sample_rate * 2 // 10  # 100ms of 16-bit mono
                    for i in range(0, len(pcm_data), chunk_size):
                        if self._cancel_event.is_set():
                            return
                        chunk_data = pcm_data[i:i + chunk_size]
                        yield TTSChunk(
                            audio=chunk_data,
                            sample_rate=self.sample_rate,
                            is_final=(i + chunk_size >= len(pcm_data))
                        )
        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def _convert_mp3_to_pcm(self, mp3_data: bytes) -> bytes:
        """
        Convert MP3 audio to PCM (16-bit mono at self.sample_rate).

        Tries pydub first, then falls back to an ffmpeg subprocess.
        Returns b"" when conversion fails.
        """
        try:
            # Try using pydub (requires ffmpeg)
            from pydub import AudioSegment
            # Load MP3 from bytes
            audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))
            # Convert to target format
            audio = audio.set_frame_rate(self.sample_rate)
            audio = audio.set_channels(1)
            audio = audio.set_sample_width(2)  # 16-bit
            # Export as raw PCM
            return audio.raw_data
        except ImportError:
            logger.warning("pydub not available, trying fallback")
            # Fallback: use subprocess to call ffmpeg directly
            return await self._ffmpeg_convert(mp3_data)
        except Exception as e:
            logger.error(f"Audio conversion error: {e}")
            return b""

    async def _ffmpeg_convert(self, mp3_data: bytes) -> bytes:
        """Convert MP3 to PCM using an ffmpeg subprocess (pipe in, pipe out)."""
        try:
            process = await asyncio.create_subprocess_exec(
                "ffmpeg",
                "-i", "pipe:0",
                "-f", "s16le",
                "-acodec", "pcm_s16le",
                "-ar", str(self.sample_rate),
                "-ac", "1",
                "pipe:1",
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.DEVNULL
            )
            stdout, _ = await process.communicate(input=mp3_data)
            return stdout
        except Exception as e:
            logger.error(f"ffmpeg conversion error: {e}")
            return b""

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (takes effect at the next chunk boundary)."""
        self._cancel_event.set()
class MockTTSService(BaseTTSService):
    """
    Mock TTS service for testing without actual synthesis.
    Generates silence or simple tones.
    """

    def __init__(self, voice: str = "mock", sample_rate: int = 16000, speed: float = 1.0):
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)

    async def connect(self) -> None:
        self.state = ServiceState.CONNECTED
        logger.info("Mock TTS service connected")

    async def disconnect(self) -> None:
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Return 16-bit silence sized at roughly 100ms per word."""
        n_words = len(text.split())
        n_samples = int(self.sample_rate * (n_words * 100) / 1000)
        # Zero bytes == silence; 2 bytes per 16-bit sample.
        return bytes(n_samples * 2)

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Yield the silence in 100ms chunks, pacing like a real engine."""
        audio = await self.synthesize(text)
        step = self.sample_rate * 2 // 10  # 100ms of 16-bit mono
        offset = 0
        while offset < len(audio):
            piece = audio[offset:offset + step]
            offset += step
            yield TTSChunk(
                audio=piece,
                sample_rate=self.sample_rate,
                is_final=(offset >= len(audio))
            )
            await asyncio.sleep(0.05)  # Simulate processing time

View File

@@ -0,0 +1,331 @@
import asyncio
from typing import Any, Dict, List
import pytest
from core.duplex_pipeline import DuplexPipeline
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
from services.base import LLMStreamEvent
class _DummySileroVAD:
    """Stand-in VAD model: reports zero speech probability for any audio."""
    def __init__(self, *args, **kwargs):
        pass
    def process_audio(self, _pcm: bytes) -> float:
        return 0.0
class _DummyVADProcessor:
    """Stand-in VAD state machine: always reports silence."""
    def __init__(self, *args, **kwargs):
        pass
    def process(self, _speech_prob: float):
        # Returns (status, level) mirroring the real processor's shape.
        return "Silence", 0.0
class _DummyEouDetector:
    """Stand-in end-of-utterance detector: never signals end of turn."""
    def __init__(self, *args, **kwargs):
        pass
    def process(self, _vad_status: str) -> bool:
        return False
    def reset(self) -> None:
        return None
class _FakeTransport:
    """Transport double that silently swallows outgoing events and audio."""
    async def send_event(self, _event: Dict[str, Any]) -> None:
        return None
    async def send_audio(self, _audio: bytes) -> None:
        return None
class _FakeTTS:
    """TTS double whose stream yields nothing (empty async generator)."""
    async def synthesize_stream(self, _text: str):
        # The unreachable `yield` makes this an async generator function.
        if False:
            yield None
class _FakeASR:
    """ASR double with a no-op connect."""
    async def connect(self) -> None:
        return None
class _FakeLLM:
    """Scripted LLM double: replays one event list per generate_stream call."""

    def __init__(self, rounds: List[List[LLMStreamEvent]]):
        self._rounds = rounds
        self._call_index = 0

    async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
        round_no = self._call_index
        self._call_index += 1
        # After the scripted rounds run out, emit a bare "done" event.
        if round_no < len(self._rounds):
            events = self._rounds[round_no]
        else:
            events = [LLMStreamEvent(type="done")]
        for event in events:
            yield event
def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
    """Build a DuplexPipeline wired to fakes; return it plus the captured-event list."""
    # Replace the heavyweight audio components with inert dummies before construction.
    monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD)
    monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor)
    monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector)
    pipeline = DuplexPipeline(
        transport=_FakeTransport(),
        session_id="s_test",
        llm_service=_FakeLLM(llm_rounds),
        tts_service=_FakeTTS(),
        asr_service=_FakeASR(),
    )
    events: List[Dict[str, Any]] = []
    async def _capture_event(event: Dict[str, Any], priority: int = 20):
        # Record every outbound event instead of sending it to a client.
        events.append(event)
    async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
        return None
    monkeypatch.setattr(pipeline, "_send_event", _capture_event)
    monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak)
    return pipeline, events
@pytest.mark.asyncio
async def test_ws_message_parses_tool_call_results():
    """tool_call.results client messages parse into ToolCallResultsMessage."""
    msg = parse_client_message(
        {
            "type": "tool_call.results",
            "results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
        }
    )
    assert isinstance(msg, ToolCallResultsMessage)
    assert msg.results[0]["tool_call_id"] == "call_1"
@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
    """A plain text turn streams deltas and a final, with no tool_call event."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="hello "),
                LLMStreamEvent(type="text_delta", text="world."),
                LLMStreamEvent(type="done"),
            ]
        ],
    )
    await pipeline._handle_turn("hi")
    event_types = [e.get("type") for e in events]
    assert "assistant.response.delta" in event_types
    assert "assistant.response.final" in event_types
    assert "assistant.tool_call" not in event_types
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "metadata",
    [
        {"output": {"mode": "text"}},
        {"services": {"tts": {"enabled": False}}},
    ],
)
async def test_text_output_mode_skips_audio_events(monkeypatch, metadata):
    """Either text-mode output or disabled TTS suppresses output.audio events."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="hello "),
                LLMStreamEvent(type="text_delta", text="world."),
                LLMStreamEvent(type="done"),
            ]
        ],
    )
    # Overrides must be applied before the turn runs to take effect.
    pipeline.apply_runtime_overrides(metadata)
    await pipeline._handle_turn("hi")
    event_types = [e.get("type") for e in events]
    assert "assistant.response.delta" in event_types
    assert "assistant.response.final" in event_types
    assert "output.audio.start" not in event_types
    assert "output.audio.end" not in event_types
@pytest.mark.asyncio
async def test_turn_with_tool_call_then_results(monkeypatch):
    """A client-executed tool call is emitted, and its results drive round two."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="let me check."),
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_ok",
                        "executor": "client",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="it's sunny."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    task = asyncio.create_task(pipeline._handle_turn("weather?"))
    # Poll (up to ~1s) until the pipeline publishes the tool_call event,
    # then deliver the client's results while the turn is still waiting.
    for _ in range(200):
        if any(e.get("type") == "assistant.tool_call" for e in events):
            break
        await asyncio.sleep(0.005)
    await pipeline.handle_tool_call_results(
        [
            {
                "tool_call_id": "call_ok",
                "name": "weather",
                "output": {"temp": 21},
                "status": {"code": 200, "message": "ok"},
            }
        ]
    )
    await task
    assert any(e.get("type") == "assistant.tool_call" for e in events)
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "it's sunny" in finals[-1].get("text", "")
@pytest.mark.asyncio
async def test_turn_with_tool_call_timeout(monkeypatch):
    """If client tool results never arrive, the turn falls back after the wait timeout."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_timeout",
                        "executor": "client",
                        "type": "function",
                        "function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="fallback answer."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    # Shrink the wait window so the test doesn't block for the real timeout.
    pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01
    await pipeline._handle_turn("query")
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "fallback answer" in finals[-1].get("text", "")
@pytest.mark.asyncio
async def test_duplicate_tool_results_are_ignored(monkeypatch):
    """Only the first result delivered for a tool_call_id is kept."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])
    await pipeline.handle_tool_call_results(
        [{"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}]
    )
    # Second delivery for the same id must not overwrite the first.
    await pipeline.handle_tool_call_results(
        [{"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}]
    )
    result = await pipeline._wait_for_single_tool_result("call_dup")
    assert result.get("output", {}).get("value") == 1
@pytest.mark.asyncio
async def test_server_calculator_emits_tool_result(monkeypatch):
    """A server-executed calculator call runs in-process and emits its result event."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_calc",
                        "executor": "server",
                        "type": "function",
                        "function": {"name": "calculator", "arguments": "{\"expression\":\"1+2\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="done."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    await pipeline._handle_turn("calc")
    tool_results = [e for e in events if e.get("type") == "assistant.tool_result"]
    assert tool_results
    payload = tool_results[-1].get("result", {})
    assert payload.get("status", {}).get("code") == 200
    assert payload.get("output", {}).get("result") == 3
@pytest.mark.asyncio
async def test_server_tool_timeout_emits_504_and_continues(monkeypatch):
    """A slow server tool times out with status 504, and the turn still completes."""
    async def _slow_execute(_call):
        # Sleeps longer than the (shrunk) server-tool timeout below.
        await asyncio.sleep(0.05)
        return {
            "tool_call_id": "call_slow",
            "name": "weather",
            "output": {"ok": True},
            "status": {"code": 200, "message": "ok"},
        }
    monkeypatch.setattr("core.duplex_pipeline.execute_server_tool", _slow_execute)
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_slow",
                        "executor": "server",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="timeout fallback."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    pipeline._SERVER_TOOL_TIMEOUT_SECONDS = 0.01
    await pipeline._handle_turn("weather?")
    tool_results = [e for e in events if e.get("type") == "assistant.tool_result"]
    assert tool_results
    payload = tool_results[-1].get("result", {})
    assert payload.get("status", {}).get("code") == 504
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "timeout fallback" in finals[-1].get("text", "")

View File

@@ -0,0 +1,57 @@
import pytest
from core.tool_executor import execute_server_tool
@pytest.mark.asyncio
async def test_code_interpreter_simple_expression():
    """code_interpreter evaluates a benign arithmetic expression."""
    result = await execute_server_tool(
        {
            "id": "call_ci_ok",
            "function": {
                "name": "code_interpreter",
                "arguments": '{"code":"sum([1, 2, 3]) + 4"}',
            },
        }
    )
    assert result["status"]["code"] == 200
    assert result["output"]["result"] == 10
@pytest.mark.asyncio
async def test_code_interpreter_blocks_import_and_io():
    """code_interpreter rejects code that tries to reach the OS via __import__."""
    result = await execute_server_tool(
        {
            "id": "call_ci_bad",
            "function": {
                "name": "code_interpreter",
                "arguments": '{"code":"__import__(\\"os\\").system(\\"ls\\")"}',
            },
        }
    )
    assert result["status"]["code"] == 422
    assert result["status"]["message"] == "invalid_code"
@pytest.mark.asyncio
async def test_current_time_uses_local_system_clock(monkeypatch):
    """current_time answers from the local clock without fetching remote resources."""
    async def _should_not_be_called(_tool_id):
        raise AssertionError("fetch_tool_resource should not be called for current_time")
    # Guard: any remote lookup would trip the assertion above.
    monkeypatch.setattr("core.tool_executor.fetch_tool_resource", _should_not_be_called)
    result = await execute_server_tool(
        {
            "id": "call_time_ok",
            "function": {
                "name": "current_time",
                "arguments": "{}",
            },
        }
    )
    assert result["status"]["code"] == 200
    assert result["status"]["message"] == "ok"
    assert "local_time" in result["output"]
    assert "iso" in result["output"]
    assert "timestamp" in result["output"]

1
utils/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Utilities Package"""

83
utils/logging.py Normal file
View File

@@ -0,0 +1,83 @@
"""Logging configuration utilities."""
import sys
from loguru import logger
from pathlib import Path
def setup_logging(
    log_level: str = "INFO",
    log_format: str = "text",
    log_to_file: bool = True,
    log_dir: str = "logs"
):
    """
    Configure structured logging with loguru.

    Args:
        log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
        log_format: Format type (json or text)
        log_to_file: Whether to log to file
        log_dir: Directory for log files (may be a nested path)

    Returns:
        The configured loguru logger.
    """
    # Remove default handler
    logger.remove()
    # Console handler: serialized JSON for machine ingestion,
    # otherwise a colorized human-readable line.
    if log_format == "json":
        logger.add(
            sys.stdout,
            format="{message}",
            level=log_level,
            serialize=True,
            colorize=False
        )
    else:
        logger.add(
            sys.stdout,
            format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
            level=log_level,
            colorize=True
        )
    # File handler with daily rotation, 7-day retention, zip compression.
    if log_to_file:
        log_path = Path(log_dir)
        # parents=True so nested dirs like "var/logs/app" work; plain
        # mkdir(exist_ok=True) raises FileNotFoundError for missing parents.
        log_path.mkdir(parents=True, exist_ok=True)
        if log_format == "json":
            logger.add(
                log_path / "active_call_{time:YYYY-MM-DD}.log",
                format="{message}",
                level=log_level,
                rotation="1 day",
                retention="7 days",
                compression="zip",
                serialize=True
            )
        else:
            logger.add(
                log_path / "active_call_{time:YYYY-MM-DD}.log",
                format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
                level=log_level,
                rotation="1 day",
                retention="7 days",
                compression="zip"
            )
    return logger
def get_logger(name: str = ""):
    """
    Get a logger instance.

    Args:
        name: Optional logger name; when non-empty it is bound as extra context.
              (Default changed from None to "" so the annotation is honest;
              falsy values behave identically.)

    Returns:
        Logger instance
    """
    if name:
        return logger.bind(name=name)
    return logger