commit 30eb4397c254373812e3926a3f218996832c53cc Author: Xin Wang Date: Tue Feb 17 10:39:23 2026 +0800 Init commit diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..f62a4c6 --- /dev/null +++ b/.env.example @@ -0,0 +1,92 @@ +# ----------------------------------------------------------------------------- +# Engine .env example (safe template) +# Notes: +# - Never commit real API keys. +# - Start with defaults below, then tune from logs. +# ----------------------------------------------------------------------------- + +# Server +HOST=0.0.0.0 +PORT=8000 +# EXTERNAL_IP=1.2.3.4 + +# Backend bridge (optional) +BACKEND_URL=http://127.0.0.1:8100 +BACKEND_TIMEOUT_SEC=10 +HISTORY_DEFAULT_USER_ID=1 + +# Audio +SAMPLE_RATE=16000 +# 20ms is recommended for VAD stability and latency. +# 100ms works but usually worsens start-of-speech accuracy. +CHUNK_SIZE_MS=20 +DEFAULT_CODEC=pcm +MAX_AUDIO_BUFFER_SECONDS=30 + +# VAD / EOU +VAD_TYPE=silero +VAD_MODEL_PATH=data/vad/silero_vad.onnx +# Higher = stricter speech detection (fewer false positives, more misses). +VAD_THRESHOLD=0.5 +# Require this much continuous speech before utterance can be valid. +VAD_MIN_SPEECH_DURATION_MS=100 +# Silence duration required to finalize one user turn. +VAD_EOU_THRESHOLD_MS=800 + +# LLM +OPENAI_API_KEY=your_openai_api_key_here +# Optional for OpenAI-compatible providers. +# OPENAI_API_URL=https://api.openai.com/v1 +LLM_MODEL=gpt-4o-mini +LLM_TEMPERATURE=0.7 + +# TTS +# edge: no API key needed +# openai_compatible: compatible with SiliconFlow-style endpoints +TTS_PROVIDER=openai_compatible +TTS_VOICE=anna +TTS_SPEED=1.0 + +# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible) +SILICONFLOW_API_KEY=your_siliconflow_api_key_here +SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B +SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall + +# ASR +ASR_PROVIDER=openai_compatible +# Interim cadence and minimum audio before interim decode. 
+ASR_INTERIM_INTERVAL_MS=500 +ASR_MIN_AUDIO_MS=300 +# ASR start gate: ignore micro-noise, then commit to one turn once started. +ASR_START_MIN_SPEECH_MS=160 +# Pre-roll protects beginning phonemes. +ASR_PRE_SPEECH_MS=240 +# Tail silence protects ending phonemes. +ASR_FINAL_TAIL_MS=120 + +# Duplex behavior +DUPLEX_ENABLED=true +# DUPLEX_GREETING=Hello! How can I help you today? +DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational. + +# Barge-in (user interrupting assistant) +# Min user speech duration needed to interrupt assistant audio. +BARGE_IN_MIN_DURATION_MS=200 +# Allowed silence during potential barge-in (ms) before reset. +BARGE_IN_SILENCE_TOLERANCE_MS=60 + +# Logging +LOG_LEVEL=INFO +# json is better for production/observability; text is easier locally. +LOG_FORMAT=json + +# WebSocket behavior +INACTIVITY_TIMEOUT_SEC=60 +HEARTBEAT_INTERVAL_SEC=50 +WS_PROTOCOL_VERSION=v1 +# WS_API_KEY=replace_with_shared_secret +WS_REQUIRE_AUTH=false + +# CORS / ICE (JSON strings) +CORS_ORIGINS=["http://localhost:3000","http://localhost:8080"] +ICE_SERVERS=[{"urls":"stun:stun.l.google.com:19302"}] diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5cd10e8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,148 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# 
Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# poetry +poetry.lock + +# pdm +.pdm.toml + +# PEP 582 +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# IDEs +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Project specific +recordings/ +logs/ +running/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..17d9e3a --- /dev/null +++ b/README.md @@ -0,0 +1,31 @@ +# py-active-call-cc + +Python Active-Call: real-time audio streaming with WebSocket and WebRTC. + +This repo contains a Python 3.11+ codebase for building low-latency voice +pipelines (capture, stream, and process audio) using WebRTC and WebSockets. +It is currently in an early, experimental stage. + +## Usage + +Start the server: + +``` +uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 +``` + +Run the test clients: + +``` +python examples/test_websocket.py +``` + +``` +python mic_client.py +``` + +## WS Protocol + +`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames. + +See `docs/ws_v1_schema.md`. 
diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..c136b14 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1 @@ +"""Active-Call Application Package""" diff --git a/app/backend_client.py b/app/backend_client.py new file mode 100644 index 0000000..b750564 --- /dev/null +++ b/app/backend_client.py @@ -0,0 +1,211 @@ +"""Backend API client for assistant config and history persistence.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import aiohttp +from loguru import logger + +from app.config import settings + + +async def fetch_assistant_config(assistant_id: str) -> Optional[Dict[str, Any]]: + """Fetch assistant config payload from backend API. + + Expected response shape: + { + "assistant": {...}, + "voice": {...} | null + } + """ + if not settings.backend_url: + logger.warning("BACKEND_URL not set; skipping assistant config fetch") + return None + + url = f"{settings.backend_url.rstrip('/')}/api/assistants/{assistant_id}/config" + timeout = aiohttp.ClientTimeout(total=settings.backend_timeout_sec) + + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url) as resp: + if resp.status == 404: + logger.warning(f"Assistant config not found: {assistant_id}") + return None + resp.raise_for_status() + payload = await resp.json() + if not isinstance(payload, dict): + logger.warning("Assistant config payload is not a dict; ignoring") + return None + return payload + except Exception as exc: + logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}") + return None + + +def _backend_base_url() -> Optional[str]: + if not settings.backend_url: + return None + return settings.backend_url.rstrip("/") + + +def _timeout() -> aiohttp.ClientTimeout: + return aiohttp.ClientTimeout(total=settings.backend_timeout_sec) + + +async def create_history_call_record( + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", +) -> Optional[str]: + 
"""Create a call record via backend history API and return call_id.""" + base_url = _backend_base_url() + if not base_url: + return None + + url = f"{base_url}/api/history" + payload: Dict[str, Any] = { + "user_id": user_id, + "assistant_id": assistant_id, + "source": source, + "status": "connected", + } + + try: + async with aiohttp.ClientSession(timeout=_timeout()) as session: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + data = await resp.json() + call_id = str((data or {}).get("id") or "") + return call_id or None + except Exception as exc: + logger.warning(f"Failed to create history call record: {exc}") + return None + + +async def add_history_transcript( + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, +) -> bool: + """Append a transcript segment to backend history.""" + base_url = _backend_base_url() + if not base_url or not call_id: + return False + + url = f"{base_url}/api/history/{call_id}/transcripts" + payload: Dict[str, Any] = { + "turn_index": turn_index, + "speaker": speaker, + "content": content, + "confidence": confidence, + "start_ms": start_ms, + "end_ms": end_ms, + "duration_ms": duration_ms, + } + + try: + async with aiohttp.ClientSession(timeout=_timeout()) as session: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + return True + except Exception as exc: + logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}") + return False + + +async def finalize_history_call_record( + *, + call_id: str, + status: str, + duration_seconds: int, +) -> bool: + """Finalize a call record with status and duration.""" + base_url = _backend_base_url() + if not base_url or not call_id: + return False + + url = f"{base_url}/api/history/{call_id}" + payload: Dict[str, Any] = { + "status": status, + "duration_seconds": 
duration_seconds, + } + + try: + async with aiohttp.ClientSession(timeout=_timeout()) as session: + async with session.put(url, json=payload) as resp: + resp.raise_for_status() + return True + except Exception as exc: + logger.warning(f"Failed to finalize history call record ({call_id}): {exc}") + return False + + +async def search_knowledge_context( + *, + kb_id: str, + query: str, + n_results: int = 5, +) -> List[Dict[str, Any]]: + """Search backend knowledge base and return retrieval results.""" + base_url = _backend_base_url() + if not base_url: + return [] + if not kb_id or not query.strip(): + return [] + try: + safe_n_results = max(1, int(n_results)) + except (TypeError, ValueError): + safe_n_results = 5 + + url = f"{base_url}/api/knowledge/search" + payload: Dict[str, Any] = { + "kb_id": kb_id, + "query": query, + "nResults": safe_n_results, + } + + try: + async with aiohttp.ClientSession(timeout=_timeout()) as session: + async with session.post(url, json=payload) as resp: + if resp.status == 404: + logger.warning(f"Knowledge base not found for retrieval: {kb_id}") + return [] + resp.raise_for_status() + data = await resp.json() + if not isinstance(data, dict): + return [] + results = data.get("results", []) + if not isinstance(results, list): + return [] + return [r for r in results if isinstance(r, dict)] + except Exception as exc: + logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}") + return [] + + +async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]: + """Fetch tool resource configuration from backend API.""" + base_url = _backend_base_url() + if not base_url or not tool_id: + return None + + url = f"{base_url}/api/tools/resources/{tool_id}" + try: + async with aiohttp.ClientSession(timeout=_timeout()) as session: + async with session.get(url) as resp: + if resp.status == 404: + return None + resp.raise_for_status() + data = await resp.json() + return data if isinstance(data, dict) else None + except Exception as exc: 
+ logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}") + return None diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..1e3e1b3 --- /dev/null +++ b/app/config.py @@ -0,0 +1,154 @@ +"""Configuration management using Pydantic settings.""" + +from typing import List, Optional +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict +import json + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore" + ) + + # Server Configuration + host: str = Field(default="0.0.0.0", description="Server host address") + port: int = Field(default=8000, description="Server port") + external_ip: Optional[str] = Field(default=None, description="External IP for NAT traversal") + + # Audio Configuration + sample_rate: int = Field(default=16000, description="Audio sample rate in Hz") + chunk_size_ms: int = Field(default=20, description="Audio chunk duration in milliseconds") + default_codec: str = Field(default="pcm", description="Default audio codec") + max_audio_buffer_seconds: int = Field( + default=30, + description="Maximum buffered user audio duration kept in memory for current turn" + ) + + # VAD Configuration + vad_type: str = Field(default="silero", description="VAD algorithm type") + vad_model_path: str = Field(default="data/vad/silero_vad.onnx", description="Path to VAD model") + vad_threshold: float = Field(default=0.5, description="VAD detection threshold") + vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds") + vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds") + + # OpenAI / LLM Configuration + openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key") + openai_api_url: Optional[str] = 
Field(default=None, description="OpenAI API base URL (for Azure/compatible)") + llm_model: str = Field(default="gpt-4o-mini", description="LLM model name") + llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation") + + # TTS Configuration + tts_provider: str = Field( + default="openai_compatible", + description="TTS provider (edge, openai_compatible; siliconflow alias supported)" + ) + tts_voice: str = Field(default="anna", description="TTS voice name") + tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier") + + # SiliconFlow Configuration + siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key") + siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model") + + # ASR Configuration + asr_provider: str = Field( + default="openai_compatible", + description="ASR provider (openai_compatible, buffered; siliconflow alias supported)" + ) + siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model") + asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms") + asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result") + asr_start_min_speech_ms: int = Field( + default=160, + description="Minimum continuous speech duration before ASR capture starts" + ) + asr_pre_speech_ms: int = Field( + default=240, + description="Audio context (ms) prepended before detected speech to avoid clipping first phoneme" + ) + asr_final_tail_ms: int = Field( + default=120, + description="Silence tail (ms) appended before final ASR decode to protect utterance ending" + ) + + # Duplex Pipeline Configuration + duplex_enabled: bool = Field(default=True, description="Enable duplex voice pipeline") + duplex_greeting: Optional[str] = Field(default=None, description="Optional greeting message") + 
duplex_system_prompt: Optional[str] = Field( + default="You are a helpful, friendly voice assistant. Keep your responses concise and conversational.", + description="System prompt for LLM" + ) + + # Barge-in (interruption) Configuration + barge_in_min_duration_ms: int = Field( + default=200, + description="Minimum speech duration (ms) required to trigger barge-in. Lower=more sensitive." + ) + barge_in_silence_tolerance_ms: int = Field( + default=60, + description="How much silence (ms) is tolerated during potential barge-in before reset" + ) + + # Logging + log_level: str = Field(default="INFO", description="Logging level") + log_format: str = Field(default="json", description="Log format (json or text)") + + # CORS + cors_origins: str = Field( + default='["http://localhost:3000", "http://localhost:8080"]', + description="CORS allowed origins" + ) + + # ICE Servers (WebRTC) + ice_servers: str = Field( + default='[{"urls": "stun:stun.l.google.com:19302"}]', + description="ICE servers configuration" + ) + + # WebSocket heartbeat and inactivity + inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)") + heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds") + ws_protocol_version: str = Field(default="v1", description="Public WS protocol version") + ws_api_key: Optional[str] = Field(default=None, description="Optional API key required for WS hello auth") + ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set") + + # Backend bridge configuration (for call/transcript persistence) + backend_url: Optional[str] = Field(default=None, description="Backend API base URL (e.g. 
http://localhost:8787)") + backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds") + history_default_user_id: int = Field(default=1, description="Fallback user_id for history records") + + @property + def chunk_size_bytes(self) -> int: + """Calculate chunk size in bytes based on sample rate and duration.""" + # 16-bit (2 bytes) per sample, mono channel + return int(self.sample_rate * 2 * (self.chunk_size_ms / 1000.0)) + + @property + def cors_origins_list(self) -> List[str]: + """Parse CORS origins from JSON string.""" + try: + return json.loads(self.cors_origins) + except json.JSONDecodeError: + return ["http://localhost:3000", "http://localhost:8080"] + + @property + def ice_servers_list(self) -> List[dict]: + """Parse ICE servers from JSON string.""" + try: + return json.loads(self.ice_servers) + except json.JSONDecodeError: + return [{"urls": "stun:stun.l.google.com:19302"}] + + +# Global settings instance +settings = Settings() + + +def get_settings() -> Settings: + """Get application settings instance.""" + return settings diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..259204c --- /dev/null +++ b/app/main.py @@ -0,0 +1,396 @@ +"""FastAPI application with WebSocket and WebRTC endpoints.""" + +import asyncio +import json +import time +import uuid +from pathlib import Path +from typing import Dict, Any, Optional, List +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, FileResponse +from loguru import logger + +# Try to import aiortc (optional for WebRTC functionality) +try: + from aiortc import RTCPeerConnection, RTCSessionDescription + AIORTC_AVAILABLE = True +except ImportError: + AIORTC_AVAILABLE = False + logger.warning("aiortc not available - WebRTC endpoint will be disabled") + +from app.config import settings +from core.transports import SocketTransport, 
WebRtcTransport, BaseTransport +from core.session import Session +from processors.tracks import Resampled16kTrack +from core.events import get_event_bus, reset_event_bus +from models.ws_v1 import ev + +# Check interval for heartbeat/timeout (seconds) +_HEARTBEAT_CHECK_INTERVAL_SEC = 5 + + +async def heartbeat_and_timeout_task( + transport: BaseTransport, + session: Session, + session_id: str, + last_received_at: List[float], + last_heartbeat_at: List[float], + inactivity_timeout_sec: int, + heartbeat_interval_sec: int, +) -> None: + """ + Background task: send heartBeat every ~heartbeat_interval_sec and close + connection if no message from client for inactivity_timeout_sec. + """ + while True: + await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC) + if transport.is_closed: + break + now = time.monotonic() + if now - last_received_at[0] > inactivity_timeout_sec: + logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing") + await session.cleanup() + break + if now - last_heartbeat_at[0] >= heartbeat_interval_sec: + try: + await transport.send_event({ + **ev("heartbeat"), + }) + last_heartbeat_at[0] = now + except Exception as e: + logger.debug(f"Session {session_id}: heartbeat send failed: {e}") + break + + +# Initialize FastAPI +app = FastAPI(title="Python Active-Call", version="0.1.0") +_WEB_CLIENT_PATH = Path(__file__).resolve().parent.parent / "examples" / "web_client.html" + +# Configure CORS +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins_list, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Active sessions storage +active_sessions: Dict[str, Session] = {} + +# Configure logging +logger.remove() +logger.add( + "./logs/active_call_{time}.log", + rotation="1 day", + retention="7 days", + level=settings.log_level, + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}" +) +logger.add( + lambda msg: print(msg, end=""), + level=settings.log_level, 
+ format="{time:HH:mm:ss} | {level: <8} | {message}" +) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "sessions": len(active_sessions)} + + +@app.get("/") +async def web_client_root(): + """Serve the web client.""" + if not _WEB_CLIENT_PATH.exists(): + raise HTTPException(status_code=404, detail="Web client not found") + return FileResponse(_WEB_CLIENT_PATH) + + +@app.get("/client") +async def web_client_alias(): + """Alias for the web client.""" + if not _WEB_CLIENT_PATH.exists(): + raise HTTPException(status_code=404, detail="Web client not found") + return FileResponse(_WEB_CLIENT_PATH) + + + + +@app.get("/iceservers") +async def get_ice_servers(): + """Get ICE servers configuration for WebRTC.""" + return settings.ice_servers_list + + +@app.get("/call/lists") +async def list_calls(): + """List all active calls.""" + return { + "calls": [ + { + "id": session_id, + "state": session.state, + "created_at": session.created_at + } + for session_id, session in active_sessions.items() + ] + } + + +@app.post("/call/kill/{session_id}") +async def kill_call(session_id: str): + """Kill a specific active call.""" + if session_id not in active_sessions: + raise HTTPException(status_code=404, detail="Session not found") + + session = active_sessions[session_id] + await session.cleanup() + del active_sessions[session_id] + + return True + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for raw audio streaming. 
+ + Accepts mixed text/binary frames: + - Text frames: JSON commands + - Binary frames: PCM audio data (16kHz, 16-bit, mono) + """ + await websocket.accept() + session_id = str(uuid.uuid4()) + + # Create transport and session + transport = SocketTransport(websocket) + session = Session(session_id, transport) + active_sessions[session_id] = session + + logger.info(f"WebSocket connection established: {session_id}") + + last_received_at: List[float] = [time.monotonic()] + last_heartbeat_at: List[float] = [0.0] + hb_task = asyncio.create_task( + heartbeat_and_timeout_task( + transport, + session, + session_id, + last_received_at, + last_heartbeat_at, + settings.inactivity_timeout_sec, + settings.heartbeat_interval_sec, + ) + ) + + try: + # Receive loop + while True: + message = await websocket.receive() + message_type = message.get("type") + + if message_type == "websocket.disconnect": + logger.info(f"WebSocket disconnected: {session_id}") + break + + last_received_at[0] = time.monotonic() + + # Handle binary audio data + if "bytes" in message: + await session.handle_audio(message["bytes"]) + + # Handle text commands + elif "text" in message: + await session.handle_text(message["text"]) + + except WebSocketDisconnect: + logger.info(f"WebSocket disconnected: {session_id}") + + except Exception as e: + logger.error(f"WebSocket error: {e}", exc_info=True) + + finally: + hb_task.cancel() + try: + await hb_task + except asyncio.CancelledError: + pass + # Cleanup session + if session_id in active_sessions: + await session.cleanup() + del active_sessions[session_id] + + logger.info(f"Session {session_id} removed") + + +@app.websocket("/webrtc") +async def webrtc_endpoint(websocket: WebSocket): + """ + WebRTC endpoint for WebRTC audio streaming. + + Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport. 
+ """ + # Check if aiortc is available + if not AIORTC_AVAILABLE: + await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed") + logger.warning("WebRTC connection attempted but aiortc is not available") + return + await websocket.accept() + session_id = str(uuid.uuid4()) + + # Create WebRTC peer connection + pc = RTCPeerConnection() + + # Create transport and session + transport = WebRtcTransport(websocket, pc) + session = Session(session_id, transport) + active_sessions[session_id] = session + + logger.info(f"WebRTC connection established: {session_id}") + + last_received_at: List[float] = [time.monotonic()] + last_heartbeat_at: List[float] = [0.0] + hb_task = asyncio.create_task( + heartbeat_and_timeout_task( + transport, + session, + session_id, + last_received_at, + last_heartbeat_at, + settings.inactivity_timeout_sec, + settings.heartbeat_interval_sec, + ) + ) + + # Track handler for incoming audio + @pc.on("track") + def on_track(track): + logger.info(f"Track received: {track.kind}") + + if track.kind == "audio": + # Wrap track with resampler + wrapped_track = Resampled16kTrack(track) + + # Create task to pull audio from track + async def pull_audio(): + try: + while True: + frame = await wrapped_track.recv() + # Convert frame to bytes + pcm_bytes = frame.to_ndarray().tobytes() + # Feed to session + await session.handle_audio(pcm_bytes) + except Exception as e: + logger.error(f"Error pulling audio from track: {e}") + + asyncio.create_task(pull_audio()) + + @pc.on("connectionstatechange") + async def on_connectionstatechange(): + logger.info(f"Connection state: {pc.connectionState}") + if pc.connectionState == "failed" or pc.connectionState == "closed": + await session.cleanup() + + try: + # Signaling loop + while True: + message = await websocket.receive() + + if "text" not in message: + continue + + last_received_at[0] = time.monotonic() + data = json.loads(message["text"]) + + # Handle SDP offer/answer + if "sdp" in data and 
"type" in data: + logger.info(f"Received SDP {data['type']}") + + # Set remote description + offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"]) + await pc.setRemoteDescription(offer) + + # Create and set local description + if data["type"] == "offer": + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + # Send answer back + await websocket.send_text(json.dumps({ + "event": "answer", + "trackId": session_id, + "timestamp": int(asyncio.get_event_loop().time() * 1000), + "sdp": pc.localDescription.sdp + })) + + logger.info(f"Sent SDP answer") + + else: + # Handle other commands + await session.handle_text(message["text"]) + + except WebSocketDisconnect: + logger.info(f"WebRTC WebSocket disconnected: {session_id}") + + except Exception as e: + logger.error(f"WebRTC error: {e}", exc_info=True) + + finally: + hb_task.cancel() + try: + await hb_task + except asyncio.CancelledError: + pass + # Cleanup + await pc.close() + if session_id in active_sessions: + await session.cleanup() + del active_sessions[session_id] + + logger.info(f"WebRTC session {session_id} removed") + + +@app.on_event("startup") +async def startup_event(): + """Run on application startup.""" + logger.info("Starting Python Active-Call server") + logger.info(f"Server: {settings.host}:{settings.port}") + logger.info(f"Sample rate: {settings.sample_rate} Hz") + logger.info(f"VAD model: {settings.vad_model_path}") + + +@app.on_event("shutdown") +async def shutdown_event(): + """Run on application shutdown.""" + logger.info("Shutting down Python Active-Call server") + + # Cleanup all sessions + for session_id, session in active_sessions.items(): + await session.cleanup() + + # Close event bus + event_bus = get_event_bus() + await event_bus.close() + reset_event_bus() + + logger.info("Server shutdown complete") + + +if __name__ == "__main__": + import uvicorn + + # Create logs directory + import os + os.makedirs("logs", exist_ok=True) + + # Run server + uvicorn.run( + 
"app.main:app", + host=settings.host, + port=settings.port, + reload=True, + log_level=settings.log_level.lower() + ) diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..0110686 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,20 @@ +"""Core Components Package""" + +from core.events import EventBus, get_event_bus +from core.transports import BaseTransport, SocketTransport, WebRtcTransport +from core.session import Session +from core.conversation import ConversationManager, ConversationState, ConversationTurn +from core.duplex_pipeline import DuplexPipeline + +__all__ = [ + "EventBus", + "get_event_bus", + "BaseTransport", + "SocketTransport", + "WebRtcTransport", + "Session", + "ConversationManager", + "ConversationState", + "ConversationTurn", + "DuplexPipeline", +] diff --git a/core/conversation.py b/core/conversation.py new file mode 100644 index 0000000..08b23c6 --- /dev/null +++ b/core/conversation.py @@ -0,0 +1,279 @@ +"""Conversation management for voice AI. + +Handles conversation context, turn-taking, and message history +for multi-turn voice conversations. 
+""" + +import asyncio +from typing import List, Optional, Dict, Any, Callable, Awaitable +from dataclasses import dataclass, field +from enum import Enum +from loguru import logger + +from services.base import LLMMessage + + +class ConversationState(Enum): + """State of the conversation.""" + IDLE = "idle" # Waiting for user input + LISTENING = "listening" # User is speaking + PROCESSING = "processing" # Processing user input (LLM) + SPEAKING = "speaking" # Bot is speaking + INTERRUPTED = "interrupted" # Bot was interrupted + + +@dataclass +class ConversationTurn: + """A single turn in the conversation.""" + role: str # "user" or "assistant" + text: str + audio_duration_ms: Optional[int] = None + timestamp: float = field(default_factory=lambda: asyncio.get_event_loop().time()) + was_interrupted: bool = False + + +class ConversationManager: + """ + Manages conversation state and history. + + Provides: + - Message history for LLM context + - Turn management + - State tracking + - Event callbacks for state changes + """ + + def __init__( + self, + system_prompt: Optional[str] = None, + max_history: int = 20, + greeting: Optional[str] = None + ): + """ + Initialize conversation manager. + + Args: + system_prompt: System prompt for LLM + max_history: Maximum number of turns to keep + greeting: Optional greeting message when conversation starts + """ + self.system_prompt = system_prompt or ( + "You are a helpful, friendly voice assistant. " + "Keep your responses concise and conversational. " + "Respond naturally as if having a phone conversation. " + "If you don't understand something, ask for clarification." 
+ ) + self.max_history = max_history + self.greeting = greeting + + # State + self.state = ConversationState.IDLE + self.turns: List[ConversationTurn] = [] + + # Callbacks + self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = [] + self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = [] + + # Current turn tracking + self._current_user_text: str = "" + self._current_assistant_text: str = "" + + logger.info("ConversationManager initialized") + + def on_state_change( + self, + callback: Callable[[ConversationState, ConversationState], Awaitable[None]] + ) -> None: + """Register callback for state changes.""" + self._state_callbacks.append(callback) + + def on_turn_complete( + self, + callback: Callable[[ConversationTurn], Awaitable[None]] + ) -> None: + """Register callback for turn completion.""" + self._turn_callbacks.append(callback) + + async def set_state(self, new_state: ConversationState) -> None: + """Set conversation state and notify listeners.""" + if new_state != self.state: + old_state = self.state + self.state = new_state + logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}") + + for callback in self._state_callbacks: + try: + await callback(old_state, new_state) + except Exception as e: + logger.error(f"State callback error: {e}") + + def get_messages(self) -> List[LLMMessage]: + """ + Get conversation history as LLM messages. 
+ + Returns: + List of LLMMessage objects including system prompt + """ + messages = [LLMMessage(role="system", content=self.system_prompt)] + + # Add conversation history + for turn in self.turns[-self.max_history:]: + messages.append(LLMMessage(role=turn.role, content=turn.text)) + + # Add current user text if any + if self._current_user_text: + messages.append(LLMMessage(role="user", content=self._current_user_text)) + + return messages + + async def start_user_turn(self) -> None: + """Signal that user has started speaking.""" + await self.set_state(ConversationState.LISTENING) + self._current_user_text = "" + + async def update_user_text(self, text: str, is_final: bool = False) -> None: + """ + Update current user text (from ASR). + + Args: + text: Transcribed text + is_final: Whether this is the final transcript + """ + self._current_user_text = text + + async def end_user_turn(self, text: str) -> None: + """ + End user turn and add to history. + + Args: + text: Final user text + """ + if text.strip(): + turn = ConversationTurn(role="user", text=text.strip()) + self.turns.append(turn) + + for callback in self._turn_callbacks: + try: + await callback(turn) + except Exception as e: + logger.error(f"Turn callback error: {e}") + + logger.info(f"User: {text[:50]}...") + + self._current_user_text = "" + await self.set_state(ConversationState.PROCESSING) + + async def start_assistant_turn(self) -> None: + """Signal that assistant has started speaking.""" + await self.set_state(ConversationState.SPEAKING) + self._current_assistant_text = "" + + async def update_assistant_text(self, text: str) -> None: + """ + Update current assistant text (streaming). + + Args: + text: Text chunk from LLM + """ + self._current_assistant_text += text + + async def end_assistant_turn(self, was_interrupted: bool = False) -> None: + """ + End assistant turn and add to history. 
+ + Args: + was_interrupted: Whether the turn was interrupted by user + """ + text = self._current_assistant_text.strip() + if text: + turn = ConversationTurn( + role="assistant", + text=text, + was_interrupted=was_interrupted + ) + self.turns.append(turn) + + for callback in self._turn_callbacks: + try: + await callback(turn) + except Exception as e: + logger.error(f"Turn callback error: {e}") + + status = " (interrupted)" if was_interrupted else "" + logger.info(f"Assistant{status}: {text[:50]}...") + + self._current_assistant_text = "" + + if was_interrupted: + # A new user turn may already be active (LISTENING) when interrupted. + # Avoid overriding it back to INTERRUPTED, which can stall EOU flow. + if self.state != ConversationState.LISTENING: + await self.set_state(ConversationState.INTERRUPTED) + else: + await self.set_state(ConversationState.IDLE) + + async def add_assistant_turn(self, text: str, was_interrupted: bool = False) -> None: + """Append an assistant turn directly without mutating conversation state.""" + content = text.strip() + if not content: + return + + turn = ConversationTurn( + role="assistant", + text=content, + was_interrupted=was_interrupted, + ) + self.turns.append(turn) + + for callback in self._turn_callbacks: + try: + await callback(turn) + except Exception as e: + logger.error(f"Turn callback error: {e}") + + logger.info(f"Assistant (injected): {content[:50]}...") + + async def interrupt(self) -> None: + """Handle interruption (barge-in).""" + if self.state == ConversationState.SPEAKING: + await self.end_assistant_turn(was_interrupted=True) + + def reset(self) -> None: + """Reset conversation history.""" + self.turns = [] + self._current_user_text = "" + self._current_assistant_text = "" + self.state = ConversationState.IDLE + logger.info("Conversation reset") + + @property + def turn_count(self) -> int: + """Get number of turns in conversation.""" + return len(self.turns) + + @property + def last_user_text(self) -> Optional[str]: 
+ """Get last user text.""" + for turn in reversed(self.turns): + if turn.role == "user": + return turn.text + return None + + @property + def last_assistant_text(self) -> Optional[str]: + """Get last assistant text.""" + for turn in reversed(self.turns): + if turn.role == "assistant": + return turn.text + return None + + def get_context_summary(self) -> Dict[str, Any]: + """Get a summary of conversation context.""" + return { + "state": self.state.value, + "turn_count": self.turn_count, + "last_user": self.last_user_text, + "last_assistant": self.last_assistant_text, + "current_user": self._current_user_text or None, + "current_assistant": self._current_assistant_text or None + } diff --git a/core/duplex_pipeline.py b/core/duplex_pipeline.py new file mode 100644 index 0000000..508ba2b --- /dev/null +++ b/core/duplex_pipeline.py @@ -0,0 +1,1507 @@ +"""Full duplex audio pipeline for AI voice conversation. + +This module implements the core duplex pipeline that orchestrates: +- VAD (Voice Activity Detection) +- EOU (End of Utterance) Detection +- ASR (Automatic Speech Recognition) - optional +- LLM (Language Model) +- TTS (Text-to-Speech) + +Inspired by pipecat's frame-based architecture and active-call's +event-driven design. 
+""" + +import asyncio +import json +import time +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from loguru import logger + +from app.config import settings +from core.conversation import ConversationManager, ConversationState +from core.events import get_event_bus +from core.tool_executor import execute_server_tool +from core.transports import BaseTransport +from models.ws_v1 import ev +from processors.eou import EouDetector +from processors.vad import SileroVAD, VADProcessor +from services.asr import BufferedASRService +from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage, LLMStreamEvent +from services.llm import MockLLMService, OpenAILLMService +from services.openai_compatible_asr import OpenAICompatibleASRService +from services.openai_compatible_tts import OpenAICompatibleTTSService +from services.streaming_text import extract_tts_sentence, has_spoken_content +from services.tts import EdgeTTSService, MockTTSService + + +class DuplexPipeline: + """ + Full duplex audio pipeline for AI voice conversation. 
+ + Handles bidirectional audio flow with: + - User speech detection and transcription + - AI response generation + - Text-to-speech synthesis + - Barge-in (interruption) support + + Architecture (inspired by pipecat): + + User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out + ↓ + Barge-in Detection → Interrupt + """ + + _SENTENCE_END_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "\n"}) + _SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"}) + _SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"}) + _MIN_SPLIT_SPOKEN_CHARS = 6 + _TOOL_WAIT_TIMEOUT_SECONDS = 15.0 + _SERVER_TOOL_TIMEOUT_SECONDS = 15.0 + _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = { + "current_time": { + "name": "current_time", + "description": "Get current local time", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + + def __init__( + self, + transport: BaseTransport, + session_id: str, + llm_service: Optional[BaseLLMService] = None, + tts_service: Optional[BaseTTSService] = None, + asr_service: Optional[BaseASRService] = None, + system_prompt: Optional[str] = None, + greeting: Optional[str] = None + ): + """ + Initialize duplex pipeline. 
+ + Args: + transport: Transport for sending audio/events + session_id: Session identifier + llm_service: LLM service (defaults to OpenAI) + tts_service: TTS service (defaults to EdgeTTS) + asr_service: ASR service (optional) + system_prompt: System prompt for LLM + greeting: Optional greeting to speak on start + """ + self.transport = transport + self.session_id = session_id + self.event_bus = get_event_bus() + + # Initialize VAD + self.vad_model = SileroVAD( + model_path=settings.vad_model_path, + sample_rate=settings.sample_rate + ) + self.vad_processor = VADProcessor( + vad_model=self.vad_model, + threshold=settings.vad_threshold + ) + + # Initialize EOU detector + self.eou_detector = EouDetector( + silence_threshold_ms=settings.vad_eou_threshold_ms, + min_speech_duration_ms=settings.vad_min_speech_duration_ms + ) + + # Initialize services + self.llm_service = llm_service + self.tts_service = tts_service + self.asr_service = asr_service # Will be initialized in start() + + # Track last sent transcript to avoid duplicates + self._last_sent_transcript = "" + + # Conversation manager + self.conversation = ConversationManager( + system_prompt=system_prompt, + greeting=greeting + ) + + # State + self._running = True + self._is_bot_speaking = False + self._current_turn_task: Optional[asyncio.Task] = None + self._audio_buffer: bytes = b"" + max_buffer_seconds = settings.max_audio_buffer_seconds + self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds) + self._asr_start_min_speech_ms: int = settings.asr_start_min_speech_ms + self._asr_capture_active: bool = False + self._pending_speech_audio: bytes = b"" + # Keep a short rolling pre-speech window so VAD transition latency + # does not clip the first phoneme/character sent to ASR. 
+ pre_speech_ms = settings.asr_pre_speech_ms + self._asr_pre_speech_bytes = int(settings.sample_rate * 2 * (pre_speech_ms / 1000.0)) + self._pre_speech_buffer: bytes = b"" + # Add a tiny trailing silence tail before final ASR to avoid + # clipping the last phoneme at utterance boundaries. + asr_final_tail_ms = settings.asr_final_tail_ms + self._asr_final_tail_bytes = int(settings.sample_rate * 2 * (asr_final_tail_ms / 1000.0)) + self._last_vad_status: str = "Silence" + self._process_lock = asyncio.Lock() + # Priority outbound dispatcher (lower value = higher priority). + self._outbound_q: asyncio.PriorityQueue[Tuple[int, int, str, Any]] = asyncio.PriorityQueue() + self._outbound_seq = 0 + self._outbound_task: Optional[asyncio.Task] = None + self._drop_outbound_audio = False + + # Interruption handling + self._interrupt_event = asyncio.Event() + + # Latency tracking - TTFB (Time to First Byte) + self._turn_start_time: Optional[float] = None + self._first_audio_sent: bool = False + + # Barge-in filtering - require minimum speech duration to interrupt + self._barge_in_speech_start_time: Optional[float] = None + self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms + self._barge_in_silence_tolerance_ms: int = settings.barge_in_silence_tolerance_ms + self._barge_in_speech_frames: int = 0 # Count speech frames + self._barge_in_silence_frames: int = 0 # Count silence frames during potential barge-in + + # Runtime overrides injected from session.start metadata + self._runtime_llm: Dict[str, Any] = {} + self._runtime_asr: Dict[str, Any] = {} + self._runtime_tts: Dict[str, Any] = {} + self._runtime_output: Dict[str, Any] = {} + self._runtime_system_prompt: Optional[str] = None + self._runtime_first_turn_mode: str = "bot_first" + self._runtime_greeting: Optional[str] = None + self._runtime_generated_opener_enabled: Optional[bool] = None + self._runtime_barge_in_enabled: Optional[bool] = None + self._runtime_barge_in_min_duration_ms: Optional[int] = None + 
self._runtime_knowledge: Dict[str, Any] = {} + self._runtime_knowledge_base_id: Optional[str] = None + self._runtime_tools: List[Any] = [] + self._runtime_tool_executor: Dict[str, str] = {} + self._pending_tool_waiters: Dict[str, asyncio.Future] = {} + self._early_tool_results: Dict[str, Dict[str, Any]] = {} + self._completed_tool_call_ids: set[str] = set() + + logger.info(f"DuplexPipeline initialized for session {session_id}") + + def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None: + """ + Apply runtime overrides from WS session.start metadata. + + Expected metadata shape: + { + "systemPrompt": "...", + "greeting": "...", + "services": { + "llm": {...}, + "asr": {...}, + "tts": {...} + } + } + """ + if not metadata: + return + + if "systemPrompt" in metadata: + self._runtime_system_prompt = str(metadata.get("systemPrompt") or "") + if self._runtime_system_prompt: + self.conversation.system_prompt = self._runtime_system_prompt + if "firstTurnMode" in metadata: + raw_mode = str(metadata.get("firstTurnMode") or "").strip().lower() + self._runtime_first_turn_mode = "user_first" if raw_mode == "user_first" else "bot_first" + if "greeting" in metadata: + greeting_payload = metadata.get("greeting") + if isinstance(greeting_payload, dict): + self._runtime_greeting = str(greeting_payload.get("text") or "") + generated_flag = self._coerce_bool(greeting_payload.get("generated")) + if generated_flag is not None: + self._runtime_generated_opener_enabled = generated_flag + else: + self._runtime_greeting = str(greeting_payload or "") + self.conversation.greeting = self._runtime_greeting or None + generated_opener_flag = self._coerce_bool(metadata.get("generatedOpenerEnabled")) + if generated_opener_flag is not None: + self._runtime_generated_opener_enabled = generated_opener_flag + + services = metadata.get("services") or {} + if isinstance(services, dict): + if isinstance(services.get("llm"), dict): + self._runtime_llm = services["llm"] + if 
isinstance(services.get("asr"), dict): + self._runtime_asr = services["asr"] + if isinstance(services.get("tts"), dict): + self._runtime_tts = services["tts"] + output = metadata.get("output") or {} + if isinstance(output, dict): + self._runtime_output = output + barge_in = metadata.get("bargeIn") + if isinstance(barge_in, dict): + barge_in_enabled = self._coerce_bool(barge_in.get("enabled")) + if barge_in_enabled is not None: + self._runtime_barge_in_enabled = barge_in_enabled + min_duration = barge_in.get("minDurationMs") + if isinstance(min_duration, (int, float, str)): + try: + self._runtime_barge_in_min_duration_ms = max(0, int(min_duration)) + except (TypeError, ValueError): + self._runtime_barge_in_min_duration_ms = None + + knowledge_base_id = metadata.get("knowledgeBaseId") + if knowledge_base_id is not None: + kb_id = str(knowledge_base_id).strip() + self._runtime_knowledge_base_id = kb_id or None + + knowledge = metadata.get("knowledge") + if isinstance(knowledge, dict): + self._runtime_knowledge = knowledge + kb_id = str(knowledge.get("kbId") or knowledge.get("knowledgeBaseId") or "").strip() + if kb_id: + self._runtime_knowledge_base_id = kb_id + + tools_payload = metadata.get("tools") + if isinstance(tools_payload, list): + self._runtime_tools = tools_payload + self._runtime_tool_executor = self._resolved_tool_executor_map() + elif "tools" in metadata: + self._runtime_tools = [] + self._runtime_tool_executor = {} + + if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"): + self.llm_service.set_knowledge_config(self._resolved_knowledge_config()) + if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"): + self.llm_service.set_tool_schemas(self._resolved_tool_schemas()) + + @staticmethod + def _coerce_bool(value: Any) -> Optional[bool]: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized 
in {"1", "true", "yes", "on", "enabled"}: + return True + if normalized in {"0", "false", "no", "off", "disabled"}: + return False + return None + + @staticmethod + def _is_openai_compatible_provider(provider: Any) -> bool: + normalized = str(provider or "").strip().lower() + return normalized in {"openai_compatible", "openai-compatible", "siliconflow"} + + def _tts_output_enabled(self) -> bool: + enabled = self._coerce_bool(self._runtime_tts.get("enabled")) + if enabled is not None: + return enabled + + output_mode = str(self._runtime_output.get("mode") or "").strip().lower() + if output_mode in {"text", "text_only", "text-only"}: + return False + + return True + + def _generated_opener_enabled(self) -> bool: + return self._runtime_generated_opener_enabled is True + + def _bot_starts_first(self) -> bool: + return self._runtime_first_turn_mode != "user_first" + + def _barge_in_enabled(self) -> bool: + if self._runtime_barge_in_enabled is not None: + return self._runtime_barge_in_enabled + return True + + def _resolved_barge_in_min_duration_ms(self) -> int: + if self._runtime_barge_in_min_duration_ms is not None: + return self._runtime_barge_in_min_duration_ms + return self._barge_in_min_duration_ms + + def _barge_in_silence_tolerance_frames(self) -> int: + """Convert silence tolerance from ms to frame count using current chunk size.""" + chunk_ms = max(1, settings.chunk_size_ms) + return max(1, int(np.ceil(self._barge_in_silence_tolerance_ms / chunk_ms))) + + async def _generate_runtime_greeting(self) -> Optional[str]: + if not self.llm_service: + return None + + prompt_hint = (self._runtime_greeting or "").strip() + system_context = (self.conversation.system_prompt or self._runtime_system_prompt or "").strip() + # Keep context concise to avoid overloading greeting generation. 
+ if len(system_context) > 1200: + system_context = system_context[:1200] + system_prompt = ( + "你是语音通话助手的开场白生成器。" + "请只输出一句自然、简洁、友好的中文开场白。" + "不要使用引号,不要使用 markdown,不要加解释。" + ) + user_prompt = "请生成一句中文开场白(不超过25个汉字)。" + if system_context: + user_prompt += f"\n\n以下是该助手的系统提示词,请据此决定语气、角色和边界:\n{system_context}" + if prompt_hint: + user_prompt += f"\n\n额外风格提示:{prompt_hint}" + + try: + generated = await self.llm_service.generate( + [ + LLMMessage(role="system", content=system_prompt), + LLMMessage(role="user", content=user_prompt), + ], + temperature=0.7, + max_tokens=64, + ) + except Exception as exc: + logger.warning(f"Failed to generate runtime greeting: {exc}") + return None + + text = (generated or "").strip() + if not text: + return None + return text.strip().strip('"').strip("'") + + async def start(self) -> None: + """Start the pipeline and connect services.""" + try: + # Connect LLM service + if not self.llm_service: + llm_api_key = self._runtime_llm.get("apiKey") or settings.openai_api_key + llm_base_url = self._runtime_llm.get("baseUrl") or settings.openai_api_url + llm_model = self._runtime_llm.get("model") or settings.llm_model + llm_provider = (self._runtime_llm.get("provider") or "openai").lower() + + if llm_provider == "openai" and llm_api_key: + self.llm_service = OpenAILLMService( + api_key=llm_api_key, + base_url=llm_base_url, + model=llm_model, + knowledge_config=self._resolved_knowledge_config(), + ) + else: + logger.warning("No OpenAI API key - using mock LLM") + self.llm_service = MockLLMService() + + if hasattr(self.llm_service, "set_knowledge_config"): + self.llm_service.set_knowledge_config(self._resolved_knowledge_config()) + if hasattr(self.llm_service, "set_tool_schemas"): + self.llm_service.set_tool_schemas(self._resolved_tool_schemas()) + + await self.llm_service.connect() + + tts_output_enabled = self._tts_output_enabled() + + # Connect TTS service only when audio output is enabled. 
+ if tts_output_enabled: + if not self.tts_service: + tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower() + tts_api_key = self._runtime_tts.get("apiKey") or settings.siliconflow_api_key + tts_voice = self._runtime_tts.get("voice") or settings.tts_voice + tts_model = self._runtime_tts.get("model") or settings.siliconflow_tts_model + tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed) + + if self._is_openai_compatible_provider(tts_provider) and tts_api_key: + self.tts_service = OpenAICompatibleTTSService( + api_key=tts_api_key, + voice=tts_voice, + model=tts_model, + sample_rate=settings.sample_rate, + speed=tts_speed + ) + logger.info("Using OpenAI-compatible TTS service (SiliconFlow implementation)") + else: + self.tts_service = EdgeTTSService( + voice=tts_voice, + sample_rate=settings.sample_rate + ) + logger.info("Using Edge TTS service") + + try: + await self.tts_service.connect() + except Exception as e: + logger.warning(f"TTS backend unavailable ({e}); falling back to MockTTS") + self.tts_service = MockTTSService( + sample_rate=settings.sample_rate + ) + await self.tts_service.connect() + else: + self.tts_service = None + logger.info("TTS output disabled by runtime metadata") + + # Connect ASR service + if not self.asr_service: + asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower() + asr_api_key = self._runtime_asr.get("apiKey") or settings.siliconflow_api_key + asr_model = self._runtime_asr.get("model") or settings.siliconflow_asr_model + asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms) + asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms) + + if self._is_openai_compatible_provider(asr_provider) and asr_api_key: + self.asr_service = OpenAICompatibleASRService( + api_key=asr_api_key, + model=asr_model, + sample_rate=settings.sample_rate, + interim_interval_ms=asr_interim_interval, + 
min_audio_for_interim_ms=asr_min_audio_ms, + on_transcript=self._on_transcript_callback + ) + logger.info("Using OpenAI-compatible ASR service (SiliconFlow implementation)") + else: + self.asr_service = BufferedASRService( + sample_rate=settings.sample_rate + ) + logger.info("Using Buffered ASR service (no real transcription)") + + await self.asr_service.connect() + + logger.info("DuplexPipeline services connected") + if not self._outbound_task or self._outbound_task.done(): + self._outbound_task = asyncio.create_task(self._outbound_loop()) + + # Resolve greeting once per session start. + # Always emit text opener event so text-only sessions can display it. + if self._bot_starts_first(): + greeting_to_speak = self.conversation.greeting + if self._generated_opener_enabled(): + generated_greeting = await self._generate_runtime_greeting() + if generated_greeting: + greeting_to_speak = generated_greeting + self.conversation.greeting = generated_greeting + if greeting_to_speak: + await self._send_event( + ev( + "assistant.response.final", + text=greeting_to_speak, + trackId=self.session_id, + ), + priority=20, + ) + await self.conversation.add_assistant_turn(greeting_to_speak) + if tts_output_enabled: + await self._speak(greeting_to_speak) + + except Exception as e: + logger.error(f"Failed to start pipeline: {e}") + raise + + async def _enqueue_outbound(self, kind: str, payload: Any, priority: int) -> None: + """Queue outbound message with priority ordering.""" + self._outbound_seq += 1 + await self._outbound_q.put((priority, self._outbound_seq, kind, payload)) + + async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None: + await self._enqueue_outbound("event", event, priority) + + async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None: + await self._enqueue_outbound("audio", pcm_bytes, priority) + + async def _outbound_loop(self) -> None: + """Single sender loop that enforces priority for interrupt events.""" + while True: + 
_priority, _seq, kind, payload = await self._outbound_q.get() + try: + if kind == "stop": + return + if kind == "audio": + if self._drop_outbound_audio: + continue + await self.transport.send_audio(payload) + elif kind == "event": + await self.transport.send_event(payload) + except Exception as e: + logger.error(f"Outbound send error ({kind}): {e}") + finally: + self._outbound_q.task_done() + + async def process_audio(self, pcm_bytes: bytes) -> None: + """ + Process incoming audio chunk. + + This is the main entry point for audio from the user. + + Args: + pcm_bytes: PCM audio data (16-bit, mono, 16kHz) + """ + if not self._running: + return + + try: + async with self._process_lock: + if pcm_bytes: + self._pre_speech_buffer += pcm_bytes + if len(self._pre_speech_buffer) > self._asr_pre_speech_bytes: + self._pre_speech_buffer = self._pre_speech_buffer[-self._asr_pre_speech_bytes:] + + # 1. Process through VAD + vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms) + + vad_status = "Silence" + if vad_result: + event_type, probability = vad_result + vad_status = "Speech" if event_type == "speaking" else "Silence" + + # Emit VAD event + await self.event_bus.publish(event_type, { + "trackId": self.session_id, + "probability": probability + }) + await self._send_event( + ev( + "input.speech_started" if event_type == "speaking" else "input.speech_stopped", + trackId=self.session_id, + probability=probability, + ), + priority=30, + ) + else: + # No state change - keep previous status + vad_status = self._last_vad_status + + # Update state based on VAD + if vad_status == "Speech" and self._last_vad_status != "Speech": + await self._on_speech_start() + + self._last_vad_status = vad_status + + # 2. 
Check for barge-in (user speaking while bot speaking) + # Filter false interruptions by requiring minimum speech duration + if self._is_bot_speaking and self._barge_in_enabled(): + if vad_status == "Speech": + # User is speaking while bot is speaking + self._barge_in_silence_frames = 0 # Reset silence counter + + if self._barge_in_speech_start_time is None: + # Start tracking speech duration + self._barge_in_speech_start_time = time.time() + self._barge_in_speech_frames = 1 + logger.debug("Potential barge-in detected, tracking duration...") + else: + self._barge_in_speech_frames += 1 + # Check if speech duration exceeds threshold + speech_duration_ms = (time.time() - self._barge_in_speech_start_time) * 1000 + if speech_duration_ms >= self._resolved_barge_in_min_duration_ms(): + logger.info(f"Barge-in confirmed after {speech_duration_ms:.0f}ms of speech ({self._barge_in_speech_frames} frames)") + await self._handle_barge_in() + else: + # Silence frame during potential barge-in + if self._barge_in_speech_start_time is not None: + self._barge_in_silence_frames += 1 + # Allow brief silence gaps (VAD flickering) + if self._barge_in_silence_frames > self._barge_in_silence_tolerance_frames(): + # Too much silence - reset barge-in tracking + logger.debug(f"Barge-in cancelled after {self._barge_in_silence_frames} silence frames") + self._barge_in_speech_start_time = None + self._barge_in_speech_frames = 0 + self._barge_in_silence_frames = 0 + elif self._is_bot_speaking and not self._barge_in_enabled(): + self._barge_in_speech_start_time = None + self._barge_in_speech_frames = 0 + self._barge_in_silence_frames = 0 + + # 3. Buffer audio for ASR. + # Gate ASR startup by a short speech-duration threshold to reduce + # false positives from micro noises, then always close the turn + # by EOU once ASR has started. 
+ just_started_asr = False + if vad_status == "Speech" and not self._asr_capture_active: + self._pending_speech_audio += pcm_bytes + pending_ms = (len(self._pending_speech_audio) / (settings.sample_rate * 2)) * 1000.0 + if pending_ms >= self._asr_start_min_speech_ms: + await self._start_asr_capture() + just_started_asr = True + + if self._asr_capture_active: + if not just_started_asr: + self._audio_buffer += pcm_bytes + if len(self._audio_buffer) > self._max_audio_buffer_bytes: + # Keep only the most recent audio to cap memory usage + self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:] + await self.asr_service.send_audio(pcm_bytes) + + # For SiliconFlow ASR, trigger interim transcription periodically + # The service handles timing internally via start_interim_transcription() + + # 4. Check for End of Utterance - this triggers LLM response + if self.eou_detector.process(vad_status, force_eligible=self._asr_capture_active): + await self._on_end_of_utterance() + elif ( + vad_status == "Silence" + and not self.eou_detector.is_speaking + and not self._asr_capture_active + and self.conversation.state == ConversationState.LISTENING + ): + # Speech was too short to pass ASR gate; reset turn so next + # utterance can start cleanly. + self._pending_speech_audio = b"" + self._audio_buffer = b"" + self._last_sent_transcript = "" + await self.conversation.set_state(ConversationState.IDLE) + + except Exception as e: + logger.error(f"Pipeline audio processing error: {e}", exc_info=True) + + async def process_text(self, text: str) -> None: + """ + Process text input (chat command). + + Allows direct text input to bypass ASR. 
+ + Args: + text: User text input + """ + if not self._running: + return + + logger.info(f"Processing text input: {text[:50]}...") + + # Cancel any current speaking + await self._stop_current_speech() + + # Start new turn + await self.conversation.end_user_turn(text) + self._current_turn_task = asyncio.create_task(self._handle_turn(text)) + + async def interrupt(self) -> None: + """Interrupt current bot speech (manual interrupt command).""" + await self._handle_barge_in() + + async def _on_transcript_callback(self, text: str, is_final: bool) -> None: + """ + Callback for ASR transcription results. + + Streams transcription to client for display. + + Args: + text: Transcribed text + is_final: Whether this is the final transcription + """ + # Avoid sending duplicate transcripts + if text == self._last_sent_transcript and not is_final: + return + + self._last_sent_transcript = text + + # Send transcript event to client + await self._send_event({ + **ev( + "transcript.final" if is_final else "transcript.delta", + trackId=self.session_id, + text=text, + ) + }, priority=30) + + if not is_final: + logger.info(f"[ASR] ASR interim: {text[:100]}") + logger.debug(f"Sent transcript ({'final' if is_final else 'interim'}): {text[:50]}...") + + async def _on_speech_start(self) -> None: + """Handle user starting to speak.""" + if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED): + await self.conversation.start_user_turn() + self._audio_buffer = b"" + self._last_sent_transcript = "" + self.eou_detector.reset() + self._asr_capture_active = False + self._pending_speech_audio = b"" + + # Clear ASR buffer. Interim starts only after ASR capture is activated. 
+ if hasattr(self.asr_service, 'clear_buffer'): + self.asr_service.clear_buffer() + + logger.debug("User speech started") + + async def _start_asr_capture(self) -> None: + """Start ASR capture for the current turn after min speech gate passes.""" + if self._asr_capture_active: + return + + if hasattr(self.asr_service, 'start_interim_transcription'): + await self.asr_service.start_interim_transcription() + + # Prime ASR with a short pre-speech context window so the utterance + # start isn't lost while waiting for VAD to transition to Speech. + pre_roll = self._pre_speech_buffer + # _pre_speech_buffer already includes current speech frames; avoid + # duplicating onset audio when we append pending speech below. + if self._pending_speech_audio and len(pre_roll) > len(self._pending_speech_audio): + pre_roll = pre_roll[:-len(self._pending_speech_audio)] + elif self._pending_speech_audio: + pre_roll = b"" + capture_audio = pre_roll + self._pending_speech_audio + if capture_audio: + await self.asr_service.send_audio(capture_audio) + self._audio_buffer = capture_audio[-self._max_audio_buffer_bytes:] + + self._asr_capture_active = True + logger.debug( + f"ASR capture started after speech gate ({self._asr_start_min_speech_ms}ms), " + f"capture={len(capture_audio)} bytes" + ) + + async def _on_end_of_utterance(self) -> None: + """Handle end of user utterance.""" + if self.conversation.state not in (ConversationState.LISTENING, ConversationState.INTERRUPTED): + return + + # Add a tiny trailing silence tail to stabilize final-token decoding. 
+ if self._asr_final_tail_bytes > 0: + final_tail = b"\x00" * self._asr_final_tail_bytes + await self.asr_service.send_audio(final_tail) + + # Stop interim transcriptions + if hasattr(self.asr_service, 'stop_interim_transcription'): + await self.asr_service.stop_interim_transcription() + + # Get final transcription from ASR service + user_text = "" + + if hasattr(self.asr_service, 'get_final_transcription'): + # SiliconFlow ASR - get final transcription + user_text = await self.asr_service.get_final_transcription() + elif hasattr(self.asr_service, 'get_and_clear_text'): + # Buffered ASR - get accumulated text + user_text = self.asr_service.get_and_clear_text() + + # Skip if no meaningful text + if not user_text or not user_text.strip(): + logger.debug("[EOU] Detected but no transcription - skipping") + # Reset for next utterance + self._audio_buffer = b"" + self._last_sent_transcript = "" + self._asr_capture_active = False + self._pending_speech_audio = b"" + # Return to idle; don't force LISTENING which causes buffering on silence + await self.conversation.set_state(ConversationState.IDLE) + return + + logger.info(f"[EOU] Detected - user said: {user_text[:100]}...") + + # For ASR backends that already emitted final via callback, + # avoid duplicating transcript.final on EOU. 
+ if user_text != self._last_sent_transcript: + await self._send_event({ + **ev( + "transcript.final", + trackId=self.session_id, + text=user_text, + ) + }, priority=25) + + # Clear buffers + self._audio_buffer = b"" + self._last_sent_transcript = "" + self._asr_capture_active = False + self._pending_speech_audio = b"" + + # Process the turn - trigger LLM response + # Cancel any existing turn to avoid overlapping assistant responses + await self._stop_current_speech() + await self.conversation.end_user_turn(user_text) + self._current_turn_task = asyncio.create_task(self._handle_turn(user_text)) + + def _resolved_knowledge_config(self) -> Dict[str, Any]: + cfg: Dict[str, Any] = {} + if isinstance(self._runtime_knowledge, dict): + cfg.update(self._runtime_knowledge) + kb_id = self._runtime_knowledge_base_id or str( + cfg.get("kbId") or cfg.get("knowledgeBaseId") or "" + ).strip() + if kb_id: + cfg["kbId"] = kb_id + cfg.setdefault("enabled", True) + return cfg + + def _resolved_tool_schemas(self) -> List[Dict[str, Any]]: + schemas: List[Dict[str, Any]] = [] + for item in self._runtime_tools: + if isinstance(item, str): + base = self._DEFAULT_TOOL_SCHEMAS.get(item) + if base: + schemas.append( + { + "type": "function", + "function": { + "name": base["name"], + "description": base.get("description") or "", + "parameters": base.get("parameters") or {"type": "object", "properties": {}}, + }, + } + ) + continue + + if not isinstance(item, dict): + continue + + fn = item.get("function") + if isinstance(fn, dict) and fn.get("name"): + fn_name = str(fn.get("name")) + schemas.append( + { + "type": "function", + "function": { + "name": str(fn.get("name")), + "description": str(fn.get("description") or item.get("description") or ""), + "parameters": fn.get("parameters") or {"type": "object", "properties": {}}, + }, + } + ) + continue + + if item.get("name"): + schemas.append( + { + "type": "function", + "function": { + "name": str(item.get("name")), + "description": 
str(item.get("description") or ""), + "parameters": item.get("parameters") or {"type": "object", "properties": {}}, + }, + } + ) + return schemas + + def _resolved_tool_executor_map(self) -> Dict[str, str]: + result: Dict[str, str] = {} + for item in self._runtime_tools: + if not isinstance(item, dict): + continue + fn = item.get("function") + if isinstance(fn, dict) and fn.get("name"): + name = str(fn.get("name")) + else: + name = str(item.get("name") or "").strip() + if not name: + continue + executor = str(item.get("executor") or item.get("run_on") or "").strip().lower() + if executor in {"client", "server"}: + result[name] = executor + return result + + def _tool_name(self, tool_call: Dict[str, Any]) -> str: + fn = tool_call.get("function") + if isinstance(fn, dict): + return str(fn.get("name") or "").strip() + return "" + + def _tool_executor(self, tool_call: Dict[str, Any]) -> str: + name = self._tool_name(tool_call) + if name and name in self._runtime_tool_executor: + return self._runtime_tool_executor[name] + # Default to server execution unless explicitly marked as client. 
+ return "server" + + async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None: + tool_name = str(result.get("name") or "unknown_tool") + call_id = str(result.get("tool_call_id") or result.get("id") or "") + status = result.get("status") if isinstance(result.get("status"), dict) else {} + status_code = int(status.get("code") or 0) if status else 0 + status_message = str(status.get("message") or "") if status else "" + logger.info( + f"[Tool] emit result source={source} name={tool_name} call_id={call_id} " + f"status={status_code} {status_message}".strip() + ) + await self._send_event( + { + **ev( + "assistant.tool_result", + trackId=self.session_id, + source=source, + result=result, + ) + }, + priority=22, + ) + + async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None: + """Handle client tool execution results.""" + if not isinstance(results, list): + return + + for item in results: + if not isinstance(item, dict): + continue + call_id = str(item.get("tool_call_id") or item.get("id") or "").strip() + if not call_id: + continue + if call_id in self._completed_tool_call_ids: + logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}") + continue + status = item.get("status") if isinstance(item.get("status"), dict) else {} + status_code = int(status.get("code") or 0) if status else 0 + status_message = str(status.get("message") or "") if status else "" + tool_name = str(item.get("name") or "unknown_tool") + logger.info( + f"[Tool] received client result name={tool_name} call_id={call_id} " + f"status={status_code} {status_message}".strip() + ) + + waiter = self._pending_tool_waiters.get(call_id) + if waiter and not waiter.done(): + waiter.set_result(item) + self._completed_tool_call_ids.add(call_id) + continue + self._early_tool_results[call_id] = item + self._completed_tool_call_ids.add(call_id) + + async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]: + if call_id in 
self._completed_tool_call_ids and call_id not in self._early_tool_results: + return { + "tool_call_id": call_id, + "status": {"code": 208, "message": "tool_call result already handled"}, + "output": "", + } + if call_id in self._early_tool_results: + self._completed_tool_call_ids.add(call_id) + return self._early_tool_results.pop(call_id) + + loop = asyncio.get_running_loop() + future = loop.create_future() + self._pending_tool_waiters[call_id] = future + try: + return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS) + except asyncio.TimeoutError: + self._completed_tool_call_ids.add(call_id) + return { + "tool_call_id": call_id, + "status": {"code": 504, "message": "tool_call timeout"}, + "output": "", + } + finally: + self._pending_tool_waiters.pop(call_id, None) + + def _normalize_stream_event(self, item: Any) -> LLMStreamEvent: + if isinstance(item, LLMStreamEvent): + return item + if isinstance(item, str): + return LLMStreamEvent(type="text_delta", text=item) + if isinstance(item, dict): + event_type = str(item.get("type") or "") + if event_type in {"text_delta", "tool_call", "done"}: + return LLMStreamEvent( + type=event_type, # type: ignore[arg-type] + text=item.get("text"), + tool_call=item.get("tool_call"), + ) + return LLMStreamEvent(type="done") + + async def _handle_turn(self, user_text: str) -> None: + """ + Handle a complete conversation turn. + + Uses sentence-by-sentence streaming TTS for lower latency. 
+ + Args: + user_text: User's transcribed text + """ + try: + # Start latency tracking + self._turn_start_time = time.time() + self._first_audio_sent = False + + full_response = "" + messages = self.conversation.get_messages() + max_rounds = 3 + + await self.conversation.start_assistant_turn() + self._is_bot_speaking = True + self._interrupt_event.clear() + self._drop_outbound_audio = False + + first_audio_sent = False + for _ in range(max_rounds): + if self._interrupt_event.is_set(): + break + + sentence_buffer = "" + pending_punctuation = "" + round_response = "" + tool_calls: List[Dict[str, Any]] = [] + allow_text_output = True + + async for raw_event in self.llm_service.generate_stream(messages): + if self._interrupt_event.is_set(): + break + + event = self._normalize_stream_event(raw_event) + if event.type == "tool_call": + tool_call = event.tool_call if isinstance(event.tool_call, dict) else None + if not tool_call: + continue + allow_text_output = False + executor = self._tool_executor(tool_call) + enriched_tool_call = dict(tool_call) + enriched_tool_call["executor"] = executor + tool_name = self._tool_name(enriched_tool_call) or "unknown_tool" + call_id = str(enriched_tool_call.get("id") or "").strip() + fn_payload = enriched_tool_call.get("function") + raw_args = str(fn_payload.get("arguments") or "") if isinstance(fn_payload, dict) else "" + args_preview = raw_args if len(raw_args) <= 160 else f"{raw_args[:160]}..." 
+ logger.info( + f"[Tool] call requested name={tool_name} call_id={call_id} " + f"executor={executor} args={args_preview}" + ) + tool_calls.append(enriched_tool_call) + await self._send_event( + { + **ev( + "assistant.tool_call", + trackId=self.session_id, + tool_call=enriched_tool_call, + ) + }, + priority=22, + ) + continue + + if event.type != "text_delta": + continue + + text_chunk = event.text or "" + if not text_chunk: + continue + + if not allow_text_output: + continue + + full_response += text_chunk + round_response += text_chunk + sentence_buffer += text_chunk + await self.conversation.update_assistant_text(text_chunk) + + await self._send_event( + { + **ev( + "assistant.response.delta", + trackId=self.session_id, + text=text_chunk, + ) + }, + # Keep delta/final on the same event priority so FIFO seq + # preserves stream order (avoid late-delta after final). + priority=20, + ) + + while True: + split_result = extract_tts_sentence( + sentence_buffer, + end_chars=self._SENTENCE_END_CHARS, + trailing_chars=self._SENTENCE_TRAILING_CHARS, + closers=self._SENTENCE_CLOSERS, + min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS, + hold_trailing_at_buffer_end=True, + force=False, + ) + if not split_result: + break + sentence, sentence_buffer = split_result + if not sentence: + continue + + sentence = f"{pending_punctuation}{sentence}".strip() + pending_punctuation = "" + if not sentence: + continue + + if not has_spoken_content(sentence): + pending_punctuation = sentence + continue + + if self._tts_output_enabled() and not self._interrupt_event.is_set(): + if not first_audio_sent: + await self._send_event( + { + **ev( + "output.audio.start", + trackId=self.session_id, + ) + }, + priority=10, + ) + first_audio_sent = True + + await self._speak_sentence( + sentence, + fade_in_ms=0, + fade_out_ms=8, + ) + + remaining_text = f"{pending_punctuation}{sentence_buffer}".strip() + if ( + self._tts_output_enabled() + and remaining_text + and 
has_spoken_content(remaining_text) + and not self._interrupt_event.is_set() + ): + if not first_audio_sent: + await self._send_event( + { + **ev( + "output.audio.start", + trackId=self.session_id, + ) + }, + priority=10, + ) + first_audio_sent = True + await self._speak_sentence( + remaining_text, + fade_in_ms=0, + fade_out_ms=8, + ) + + if not tool_calls: + break + + tool_results: List[Dict[str, Any]] = [] + for call in tool_calls: + call_id = str(call.get("id") or "").strip() + if not call_id: + continue + executor = str(call.get("executor") or "server").strip().lower() + tool_name = self._tool_name(call) or "unknown_tool" + logger.info(f"[Tool] execute start name={tool_name} call_id={call_id} executor={executor}") + if executor == "client": + result = await self._wait_for_single_tool_result(call_id) + await self._emit_tool_result(result, source="client") + tool_results.append(result) + continue + + try: + result = await asyncio.wait_for( + execute_server_tool(call), + timeout=self._SERVER_TOOL_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + result = { + "tool_call_id": call_id, + "name": self._tool_name(call) or "unknown_tool", + "output": {"message": "server tool timeout"}, + "status": {"code": 504, "message": "server_tool_timeout"}, + } + await self._emit_tool_result(result, source="server") + tool_results.append(result) + + messages = [ + *messages, + LLMMessage( + role="assistant", + content=round_response.strip(), + ), + LLMMessage( + role="system", + content=( + "Tool execution results are available. " + "Continue answering the user naturally using these results. 
" + "Do not request the same tool again in this turn.\n" + f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n" + f"tool_results={json.dumps(tool_results, ensure_ascii=False)}" + ), + ), + ] + + if full_response and not self._interrupt_event.is_set(): + await self._send_event( + { + **ev( + "assistant.response.final", + trackId=self.session_id, + text=full_response, + ) + }, + priority=20, + ) + + # Send track end + if first_audio_sent: + await self._send_event({ + **ev( + "output.audio.end", + trackId=self.session_id, + ) + }, priority=10) + + # End assistant turn + await self.conversation.end_assistant_turn( + was_interrupted=self._interrupt_event.is_set() + ) + + except asyncio.CancelledError: + logger.info("Turn handling cancelled") + await self.conversation.end_assistant_turn(was_interrupted=True) + except Exception as e: + logger.error(f"Turn handling error: {e}", exc_info=True) + await self.conversation.end_assistant_turn(was_interrupted=True) + finally: + self._is_bot_speaking = False + # Reset barge-in tracking when bot finishes speaking + self._barge_in_speech_start_time = None + self._barge_in_speech_frames = 0 + self._barge_in_silence_frames = 0 + + async def _speak_sentence(self, text: str, fade_in_ms: int = 0, fade_out_ms: int = 8) -> None: + """ + Synthesize and send a single sentence. 
+ + Args: + text: Sentence to speak + fade_in_ms: Fade-in duration for sentence start chunks + fade_out_ms: Fade-out duration for sentence end chunks + """ + if not self._tts_output_enabled(): + return + + if not text.strip() or self._interrupt_event.is_set() or not self.tts_service: + return + + logger.info(f"[TTS] split sentence: {text!r}") + + try: + is_first_chunk = True + async for chunk in self.tts_service.synthesize_stream(text): + # Check interrupt at the start of each iteration + if self._interrupt_event.is_set(): + logger.debug("TTS sentence interrupted") + break + + # Track and log first audio packet latency (TTFB) + if not self._first_audio_sent and self._turn_start_time: + ttfb_ms = (time.time() - self._turn_start_time) * 1000 + self._first_audio_sent = True + logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})") + + # Send TTFB event to client + await self._send_event({ + **ev( + "metrics.ttfb", + trackId=self.session_id, + latencyMs=round(ttfb_ms), + ) + }, priority=25) + + # Double-check interrupt right before sending audio + if self._interrupt_event.is_set(): + break + + smoothed_audio = self._apply_edge_fade( + pcm_bytes=chunk.audio, + sample_rate=chunk.sample_rate, + fade_in=is_first_chunk, + fade_out=bool(chunk.is_final), + fade_in_ms=fade_in_ms, + fade_out_ms=fade_out_ms, + ) + is_first_chunk = False + + await self._send_audio(smoothed_audio, priority=50) + except asyncio.CancelledError: + logger.debug("TTS sentence cancelled") + except Exception as e: + logger.error(f"TTS sentence error: {e}") + + def _apply_edge_fade( + self, + pcm_bytes: bytes, + sample_rate: int, + fade_in: bool = False, + fade_out: bool = False, + fade_in_ms: int = 0, + fade_out_ms: int = 8, + ) -> bytes: + """Apply short edge fades to reduce click/pop at sentence boundaries.""" + if not pcm_bytes or (not fade_in and not fade_out): + return pcm_bytes + + try: + samples = np.frombuffer(pcm_bytes, dtype=" 0: + fade_in_samples = 
int(sample_rate * (fade_in_ms / 1000.0)) + fade_in_samples = max(1, min(fade_in_samples, samples.size)) + samples[:fade_in_samples] *= np.linspace(0.0, 1.0, fade_in_samples, endpoint=True) + if fade_out: + fade_out_samples = int(sample_rate * (fade_out_ms / 1000.0)) + fade_out_samples = max(1, min(fade_out_samples, samples.size)) + samples[-fade_out_samples:] *= np.linspace(1.0, 0.0, fade_out_samples, endpoint=True) + + return np.clip(samples, -32768, 32767).astype(" None: + """ + Synthesize and send speech. + + Args: + text: Text to speak + """ + if not self._tts_output_enabled(): + return + + if not text.strip() or not self.tts_service: + return + + try: + self._drop_outbound_audio = False + # Start latency tracking for greeting + speak_start_time = time.time() + first_audio_sent = False + + # Send track start event + await self._send_event({ + **ev( + "output.audio.start", + trackId=self.session_id, + ) + }, priority=10) + + self._is_bot_speaking = True + + # Stream TTS audio + async for chunk in self.tts_service.synthesize_stream(text): + if self._interrupt_event.is_set(): + logger.info("TTS interrupted by barge-in") + break + + # Track and log first audio packet latency (TTFB) + if not first_audio_sent: + ttfb_ms = (time.time() - speak_start_time) * 1000 + first_audio_sent = True + logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})") + + # Send TTFB event to client + await self._send_event({ + **ev( + "metrics.ttfb", + trackId=self.session_id, + latencyMs=round(ttfb_ms), + ) + }, priority=25) + + # Send audio to client + await self._send_audio(chunk.audio, priority=50) + + # Small delay to prevent flooding + await asyncio.sleep(0.01) + + # Send track end event + await self._send_event({ + **ev( + "output.audio.end", + trackId=self.session_id, + ) + }, priority=10) + + except asyncio.CancelledError: + logger.info("TTS cancelled") + raise + except Exception as e: + logger.error(f"TTS error: {e}") + finally: + 
self._is_bot_speaking = False + + async def _handle_barge_in(self) -> None: + """Handle user barge-in (interruption).""" + if not self._is_bot_speaking: + return + + logger.info("Barge-in detected - interrupting bot speech") + + # Reset barge-in tracking + self._barge_in_speech_start_time = None + self._barge_in_speech_frames = 0 + self._barge_in_silence_frames = 0 + + # IMPORTANT: Signal interruption FIRST to stop audio sending + self._interrupt_event.set() + self._is_bot_speaking = False + self._drop_outbound_audio = True + + # Send interrupt event to client IMMEDIATELY + # This must happen BEFORE canceling services, so client knows to discard in-flight audio + await self._send_event({ + **ev( + "response.interrupted", + trackId=self.session_id, + ) + }, priority=0) + + # Cancel TTS + if self.tts_service: + await self.tts_service.cancel() + + # Cancel LLM + if self.llm_service and hasattr(self.llm_service, 'cancel'): + self.llm_service.cancel() + + # Interrupt conversation only if there is no active turn task. + # When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks. 
+ if not (self._current_turn_task and not self._current_turn_task.done()): + await self.conversation.interrupt() + + # Reset for new user turn + await self.conversation.start_user_turn() + self._audio_buffer = b"" + self.eou_detector.reset() + self._asr_capture_active = False + self._pending_speech_audio = b"" + + async def _stop_current_speech(self) -> None: + """Stop any current speech task.""" + self._drop_outbound_audio = True + if self._current_turn_task and not self._current_turn_task.done(): + self._interrupt_event.set() + self._current_turn_task.cancel() + try: + await self._current_turn_task + except asyncio.CancelledError: + pass + + # Ensure underlying services are cancelled to avoid leaking work/audio + if self.tts_service: + await self.tts_service.cancel() + if self.llm_service and hasattr(self.llm_service, 'cancel'): + self.llm_service.cancel() + + self._is_bot_speaking = False + self._interrupt_event.clear() + + async def cleanup(self) -> None: + """Cleanup pipeline resources.""" + logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}") + + self._running = False + await self._stop_current_speech() + if self._outbound_task and not self._outbound_task.done(): + await self._enqueue_outbound("stop", None, priority=-1000) + await self._outbound_task + self._outbound_task = None + + # Disconnect services + if self.llm_service: + await self.llm_service.disconnect() + if self.tts_service: + await self.tts_service.disconnect() + if self.asr_service: + await self.asr_service.disconnect() + + def _get_timestamp_ms(self) -> int: + """Get current timestamp in milliseconds.""" + import time + return int(time.time() * 1000) + + @property + def is_speaking(self) -> bool: + """Check if bot is currently speaking.""" + return self._is_bot_speaking + + @property + def state(self) -> ConversationState: + """Get current conversation state.""" + return self.conversation.state diff --git a/core/events.py b/core/events.py new file mode 100644 index 
0000000..1762148 --- /dev/null +++ b/core/events.py @@ -0,0 +1,134 @@ +"""Event bus for pub/sub communication between components.""" + +import asyncio +from typing import Callable, Dict, List, Any, Optional +from collections import defaultdict +from loguru import logger + + +class EventBus: + """ + Async event bus for pub/sub communication. + + Similar to the original Rust implementation's broadcast channel. + Components can subscribe to specific event types and receive events asynchronously. + """ + + def __init__(self): + """Initialize the event bus.""" + self._subscribers: Dict[str, List[Callable]] = defaultdict(list) + self._lock = asyncio.Lock() + self._running = True + + def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None: + """ + Subscribe to an event type. + + Args: + event_type: Type of event to subscribe to (e.g., "speaking", "silence") + callback: Async callback function that receives event data + """ + if not self._running: + logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}") + return + + self._subscribers[event_type].append(callback) + logger.debug(f"Subscribed to event type: {event_type}") + + def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None: + """ + Unsubscribe from an event type. + + Args: + event_type: Type of event to unsubscribe from + callback: Callback function to remove + """ + if callback in self._subscribers[event_type]: + self._subscribers[event_type].remove(callback) + logger.debug(f"Unsubscribed from event type: {event_type}") + + async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None: + """ + Publish an event to all subscribers. 
+ + Args: + event_type: Type of event to publish + event_data: Event data to send to subscribers + """ + if not self._running: + logger.warning(f"Event bus is shut down, ignoring event: {event_type}") + return + + # Get subscribers for this event type + subscribers = self._subscribers.get(event_type, []) + + if not subscribers: + logger.debug(f"No subscribers for event type: {event_type}") + return + + # Notify all subscribers concurrently + tasks = [] + for callback in subscribers: + try: + # Create task for each subscriber + task = asyncio.create_task(self._call_subscriber(callback, event_data)) + tasks.append(task) + except Exception as e: + logger.error(f"Error creating task for subscriber: {e}") + + # Wait for all subscribers to complete + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers") + + async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None: + """ + Call a subscriber callback with error handling. + + Args: + callback: Subscriber callback function + event_data: Event data to pass to callback + """ + try: + # Check if callback is a coroutine function + if asyncio.iscoroutinefunction(callback): + await callback(event_data) + else: + callback(event_data) + except Exception as e: + logger.error(f"Error in subscriber callback: {e}", exc_info=True) + + async def close(self) -> None: + """Close the event bus and stop processing events.""" + self._running = False + self._subscribers.clear() + logger.info("Event bus closed") + + @property + def is_running(self) -> bool: + """Check if the event bus is running.""" + return self._running + + +# Global event bus instance +_event_bus: Optional[EventBus] = None + + +def get_event_bus() -> EventBus: + """ + Get the global event bus instance. 
+ + Returns: + EventBus instance + """ + global _event_bus + if _event_bus is None: + _event_bus = EventBus() + return _event_bus + + +def reset_event_bus() -> None: + """Reset the global event bus (mainly for testing).""" + global _event_bus + _event_bus = None diff --git a/core/session.py b/core/session.py new file mode 100644 index 0000000..3f8f18d --- /dev/null +++ b/core/session.py @@ -0,0 +1,648 @@ +"""Session management for active calls.""" + +import asyncio +import uuid +import json +import time +import re +from enum import Enum +from typing import Optional, Dict, Any, List +from loguru import logger + +from app.backend_client import ( + create_history_call_record, + add_history_transcript, + finalize_history_call_record, +) +from core.transports import BaseTransport +from core.duplex_pipeline import DuplexPipeline +from core.conversation import ConversationTurn +from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef +from app.config import settings +from services.base import LLMMessage +from models.ws_v1 import ( + parse_client_message, + ev, + HelloMessage, + SessionStartMessage, + SessionStopMessage, + InputTextMessage, + ResponseCancelMessage, + ToolCallResultsMessage, +) + + +class WsSessionState(str, Enum): + """Protocol state machine for WS sessions.""" + + WAIT_HELLO = "wait_hello" + WAIT_START = "wait_start" + ACTIVE = "active" + STOPPED = "stopped" + + +class Session: + """ + Manages a single call session. + + Handles command routing, audio processing, and session lifecycle. + Uses full duplex voice conversation pipeline. + """ + + def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None): + """ + Initialize session. 
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: Optional[bool] = None):
    """
    Initialize session.

    Args:
        session_id: Unique session identifier
        transport: Transport instance for communication
        use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
    """
    self.id = session_id
    self.transport = transport
    # Fix: annotation was `bool = None`; None is a valid sentinel, so the
    # parameter is Optional[bool].
    self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled

    self.pipeline = DuplexPipeline(
        transport=transport,
        session_id=session_id,
        system_prompt=settings.duplex_system_prompt,
        greeting=settings.duplex_greeting
    )

    # Session state
    self.created_at = None
    self.state = "created"  # Legacy call state for /call/lists
    self.ws_state = WsSessionState.WAIT_HELLO
    self._pipeline_started = False
    self.protocol_version: Optional[str] = None
    self.authenticated: bool = False

    # Track IDs
    self.current_track_id: Optional[str] = str(uuid.uuid4())
    # History-bridge bookkeeping (backend call record + transcript turns).
    self._history_call_id: Optional[str] = None
    self._history_turn_index: int = 0
    self._history_call_started_mono: Optional[float] = None
    self._history_finalized: bool = False
    self._cleanup_lock = asyncio.Lock()
    self._cleaned_up = False
    # Optional workflow state machine driving the conversation.
    self.workflow_runner: Optional[WorkflowRunner] = None
    self._workflow_last_user_text: str = ""
    self._workflow_initial_node: Optional[WorkflowNodeDef] = None

    self.pipeline.conversation.on_turn_complete(self._on_turn_complete)

    logger.info(f"Session {self.id} created (duplex={self.use_duplex})")

async def handle_text(self, text_data: str) -> None:
    """
    Handle incoming text data (WS v1 JSON control messages).

    Args:
        text_data: JSON text data
    """
    try:
        data = json.loads(text_data)
        message = parse_client_message(data)
        await self._handle_v1_message(message)

    except json.JSONDecodeError as e:
        logger.error(f"Session {self.id} JSON decode error: {e}")
        await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")

    except ValueError as e:
        logger.error(f"Session {self.id} command parse error: {e}")
        await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")

    except Exception as e:
        logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
        await self._send_error("server", f"Internal error: {e}", "server.internal")

async def handle_audio(self, audio_bytes: bytes) -> None:
    """
    Handle incoming audio data.

    Args:
        audio_bytes: PCM audio data
    """
    # Audio is only accepted once the session handshake is complete.
    if self.ws_state != WsSessionState.ACTIVE:
        await self._send_error(
            "client",
            "Audio received before session.start",
            "protocol.order",
        )
        return

    try:
        await self.pipeline.process_audio(audio_bytes)
    except Exception as e:
        logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)

async def _handle_v1_message(self, message: Any) -> None:
    """Route validated WS v1 message to handlers.

    Enforces protocol ordering: hello -> session.start -> everything else.
    """
    msg_type = message.type
    logger.info(f"Session {self.id} received message: {msg_type}")

    if isinstance(message, HelloMessage):
        await self._handle_hello(message)
        return

    # All messages below require hello handshake first
    if self.ws_state == WsSessionState.WAIT_HELLO:
        await self._send_error(
            "client",
            "Expected hello message first",
            "protocol.order",
        )
        return

    if isinstance(message, SessionStartMessage):
        await self._handle_session_start(message)
        return

    # All messages below require active session
    if self.ws_state != WsSessionState.ACTIVE:
        await self._send_error(
            "client",
            f"Message '{msg_type}' requires active session",
            "protocol.order",
        )
        return

    if isinstance(message, InputTextMessage):
        await self.pipeline.process_text(message.text)
    elif isinstance(message, ResponseCancelMessage):
        if message.graceful:
            logger.info(f"Session {self.id} graceful response.cancel")
        else:
            await self.pipeline.interrupt()
    elif isinstance(message, ToolCallResultsMessage):
        await self.pipeline.handle_tool_call_results(message.results)
    elif isinstance(message, SessionStopMessage):
        await self._handle_session_stop(message.reason)
    else:
        await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")

async def _handle_hello(self, message: HelloMessage) -> None:
    """Handle initial hello/auth/version negotiation.

    Closes the transport on version mismatch or failed authentication.
    """
    if self.ws_state != WsSessionState.WAIT_HELLO:
        await self._send_error("client", "Duplicate hello", "protocol.order")
        return

    if message.version != settings.ws_protocol_version:
        await self._send_error(
            "client",
            f"Unsupported protocol version '{message.version}'",
            "protocol.version_unsupported",
        )
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED
        return

    auth_payload = message.auth or {}
    api_key = auth_payload.get("apiKey")
    jwt = auth_payload.get("jwt")

    # A configured shared API key takes precedence; otherwise any
    # credential satisfies ws_require_auth.
    if settings.ws_api_key:
        if api_key != settings.ws_api_key:
            await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
            await self.transport.close()
            self.ws_state = WsSessionState.STOPPED
            return
    elif settings.ws_require_auth and not (api_key or jwt):
        await self._send_error("auth", "Authentication required", "auth.required")
        await self.transport.close()
        self.ws_state = WsSessionState.STOPPED
        return

    self.authenticated = True
    self.protocol_version = message.version
    self.ws_state = WsSessionState.WAIT_START
    await self.transport.send_event(
        ev(
            "hello.ack",
            sessionId=self.id,
            version=self.protocol_version,
        )
    )
start after successful hello.""" + if self.ws_state != WsSessionState.WAIT_START: + await self._send_error("client", "Duplicate session.start", "protocol.order") + return + + metadata = message.metadata or {} + metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata)) + + # Create history call record early so later turn callbacks can append transcripts. + await self._start_history_bridge(metadata) + + # Apply runtime service/prompt overrides from backend if provided + self.pipeline.apply_runtime_overrides(metadata) + + # Start duplex pipeline + if not self._pipeline_started: + await self.pipeline.start() + self._pipeline_started = True + logger.info(f"Session {self.id} duplex pipeline started") + + self.state = "accepted" + self.ws_state = WsSessionState.ACTIVE + await self.transport.send_event( + ev( + "session.started", + sessionId=self.id, + trackId=self.current_track_id, + audio=message.audio or {}, + ) + ) + if self.workflow_runner and self._workflow_initial_node: + await self.transport.send_event( + ev( + "workflow.started", + sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + workflowName=self.workflow_runner.name, + nodeId=self._workflow_initial_node.id, + ) + ) + await self.transport.send_event( + ev( + "workflow.node.entered", + sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + nodeId=self._workflow_initial_node.id, + nodeName=self._workflow_initial_node.name, + nodeType=self._workflow_initial_node.node_type, + ) + ) + + async def _handle_session_stop(self, reason: Optional[str]) -> None: + """Handle session stop.""" + if self.ws_state == WsSessionState.STOPPED: + return + + stop_reason = reason or "client_requested" + self.state = "hungup" + self.ws_state = WsSessionState.STOPPED + await self.transport.send_event( + ev( + "session.stopped", + sessionId=self.id, + reason=stop_reason, + ) + ) + await self._finalize_history(status="connected") + await self.transport.close() + + async def 
_send_error(self, sender: str, error_message: str, code: str) -> None: + """ + Send error event to client. + + Args: + sender: Component that generated the error + error_message: Error message + code: Machine-readable error code + """ + await self.transport.send_event( + ev( + "error", + sender=sender, + code=code, + message=error_message, + trackId=self.current_track_id, + ) + ) + + def _get_timestamp_ms(self) -> int: + """Get current timestamp in milliseconds.""" + import time + return int(time.time() * 1000) + + async def cleanup(self) -> None: + """Cleanup session resources.""" + async with self._cleanup_lock: + if self._cleaned_up: + return + + self._cleaned_up = True + logger.info(f"Session {self.id} cleaning up") + await self._finalize_history(status="connected") + await self.pipeline.cleanup() + await self.transport.close() + + async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None: + """Initialize backend history call record for this session.""" + if self._history_call_id: + return + + history_meta: Dict[str, Any] = {} + if isinstance(metadata.get("history"), dict): + history_meta = metadata["history"] + + raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id)) + try: + user_id = int(raw_user_id) + except (TypeError, ValueError): + user_id = settings.history_default_user_id + + assistant_id = history_meta.get("assistantId", metadata.get("assistantId")) + source = str(history_meta.get("source", metadata.get("source", "debug"))) + + call_id = await create_history_call_record( + user_id=user_id, + assistant_id=str(assistant_id) if assistant_id else None, + source=source, + ) + if not call_id: + return + + self._history_call_id = call_id + self._history_call_started_mono = time.monotonic() + self._history_turn_index = 0 + self._history_finalized = False + logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})") + + async def _on_turn_complete(self, turn: 
ConversationTurn) -> None: + """Process workflow transitions and persist completed turns to history.""" + if turn.text and turn.text.strip(): + role = (turn.role or "").lower() + if role == "user": + self._workflow_last_user_text = turn.text.strip() + elif role == "assistant": + await self._maybe_advance_workflow(turn.text.strip()) + + if not self._history_call_id: + return + if not turn.text or not turn.text.strip(): + return + + role = (turn.role or "").lower() + speaker = "human" if role == "user" else "ai" + + end_ms = 0 + if self._history_call_started_mono is not None: + end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000)) + estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80)) + start_ms = max(0, end_ms - estimated_duration_ms) + + turn_index = self._history_turn_index + await add_history_transcript( + call_id=self._history_call_id, + turn_index=turn_index, + speaker=speaker, + content=turn.text.strip(), + start_ms=start_ms, + end_ms=end_ms, + duration_ms=max(1, end_ms - start_ms), + ) + self._history_turn_index += 1 + + async def _finalize_history(self, status: str) -> None: + """Finalize history call record once.""" + if not self._history_call_id or self._history_finalized: + return + + duration_seconds = 0 + if self._history_call_started_mono is not None: + duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono)) + + ok = await finalize_history_call_record( + call_id=self._history_call_id, + status=status, + duration_seconds=duration_seconds, + ) + if ok: + self._history_finalized = True + + def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """Parse workflow payload and return initial runtime overrides.""" + payload = metadata.get("workflow") + self.workflow_runner = WorkflowRunner.from_payload(payload) + self._workflow_initial_node = None + if not self.workflow_runner: + return {} + + node = self.workflow_runner.bootstrap() + if not node: + 
    async def _maybe_advance_workflow(self, assistant_text: str) -> None:
        """Attempt node transfer after assistant turn finalization.

        Routes once based on the buffered user text and the just-finished
        assistant text, then auto-advances through utility (start/tool)
        nodes that expose a default edge, bounded by a hop limit to break
        accidental cycles in the graph.
        """
        if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
            return

        transition = await self.workflow_runner.route(
            user_text=self._workflow_last_user_text,
            assistant_text=assistant_text,
            llm_router=self._workflow_llm_route,
        )
        if not transition:
            return

        await self._apply_workflow_transition(transition, reason="rule_match")

        # Auto-advance through utility nodes when default edges are present.
        max_auto_hops = 6
        auto_hops = 0
        # The loop re-checks both runner and state each hop because
        # _apply_workflow_transition can stop the session (end / transfer).
        while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
            current = self.workflow_runner.current_node
            if not current or current.node_type not in {"start", "tool"}:
                break

            next_default = self.workflow_runner.next_default_transition()
            if not next_default:
                break

            auto_hops += 1
            await self._apply_workflow_transition(next_default, reason="auto")
            if auto_hops >= max_auto_hops:
                logger.warning(
                    "Session {} workflow auto-advance reached hop limit (possible cycle)",
                    self.id,
                )
                break

    async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
        """Apply graph transition and emit workflow lifecycle events.

        Emits edge.taken then node.entered, applies any per-node runtime
        overrides, and handles terminal node types: tool nodes request a
        client-side tool call; human_transfer and end nodes stop the
        session with a matching reason.

        Args:
            transition: Edge + target node to move to.
            reason: Why the edge was taken ("rule_match" or "auto").
        """
        if not self.workflow_runner:
            return

        self.workflow_runner.apply_transition(transition)
        node = transition.node
        edge = transition.edge

        await self.transport.send_event(
            ev(
                "workflow.edge.taken",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                edgeId=edge.id,
                fromNodeId=edge.from_node_id,
                toNodeId=edge.to_node_id,
                reason=reason,
            )
        )
        await self.transport.send_event(
            ev(
                "workflow.node.entered",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                nodeName=node.name,
                nodeType=node.node_type,
            )
        )

        # Per-node prompt/greeting/service overrides take effect immediately.
        node_runtime = self.workflow_runner.build_runtime_metadata(node)
        if node_runtime:
            self.pipeline.apply_runtime_overrides(node_runtime)

        if node.node_type == "tool":
            await self.transport.send_event(
                ev(
                    "workflow.tool.requested",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                    tool=node.tool or {},
                )
            )
            return

        if node.node_type == "human_transfer":
            await self.transport.send_event(
                ev(
                    "workflow.human_transfer",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_human_transfer")
            return

        if node.node_type == "end":
            await self.transport.send_event(
                ev(
                    "workflow.ended",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_end")
+ ) + user_prompt = json.dumps( + { + "nodeId": node.id, + "nodeName": node.name, + "userText": context.get("userText", ""), + "assistantText": context.get("assistantText", ""), + "candidates": candidate_rows, + }, + ensure_ascii=False, + ) + + try: + reply = await llm_service.generate( + [ + LLMMessage(role="system", content=system_prompt), + LLMMessage(role="user", content=user_prompt), + ], + temperature=0.0, + max_tokens=64, + ) + except Exception as exc: + logger.warning(f"Session {self.id} workflow llm routing failed: {exc}") + return None + + if not reply: + return None + + edge_ids = {edge.id for edge in candidates} + node_ids = {edge.to_node_id for edge in candidates} + + parsed = self._extract_json_obj(reply) + if isinstance(parsed, dict): + edge_id = parsed.get("edgeId") or parsed.get("id") + node_id = parsed.get("toNodeId") or parsed.get("nodeId") + if isinstance(edge_id, str) and edge_id in edge_ids: + return edge_id + if isinstance(node_id, str) and node_id in node_ids: + return node_id + + token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True) + lowered_reply = reply.lower() + for token in token_candidates: + if token.lower() in lowered_reply: + return token + return None + + def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: + """Merge node-level metadata overrides into session.start metadata.""" + merged = dict(base or {}) + if not overrides: + return merged + for key, value in overrides.items(): + if key == "services" and isinstance(value, dict): + existing = merged.get("services") + merged_services = dict(existing) if isinstance(existing, dict) else {} + merged_services.update(value) + merged["services"] = merged_services + else: + merged[key] = value + return merged + + def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]: + """Best-effort extraction of a JSON object from freeform text.""" + try: + parsed = json.loads(text) + if isinstance(parsed, dict): + return 
parsed + except Exception: + pass + + match = re.search(r"\{.*\}", text, re.DOTALL) + if not match: + return None + try: + parsed = json.loads(match.group(0)) + return parsed if isinstance(parsed, dict) else None + except Exception: + return None diff --git a/core/tool_executor.py b/core/tool_executor.py new file mode 100644 index 0000000..407e199 --- /dev/null +++ b/core/tool_executor.py @@ -0,0 +1,340 @@ +"""Server-side tool execution helpers.""" + +import asyncio +import ast +import operator +from datetime import datetime +from typing import Any, Dict + +import aiohttp + +from app.backend_client import fetch_tool_resource + +_BIN_OPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Mod: operator.mod, +} + +_UNARY_OPS = { + ast.UAdd: operator.pos, + ast.USub: operator.neg, +} + +_SAFE_EVAL_FUNCS = { + "abs": abs, + "round": round, + "min": min, + "max": max, + "sum": sum, + "len": len, +} + + +def _validate_safe_expr(node: ast.AST) -> None: + """Allow only a constrained subset of Python expression nodes.""" + if isinstance(node, ast.Expression): + _validate_safe_expr(node.body) + return + + if isinstance(node, ast.Constant): + return + + if isinstance(node, (ast.List, ast.Tuple, ast.Set)): + for elt in node.elts: + _validate_safe_expr(elt) + return + + if isinstance(node, ast.Dict): + for key in node.keys: + if key is not None: + _validate_safe_expr(key) + for value in node.values: + _validate_safe_expr(value) + return + + if isinstance(node, ast.BinOp): + if type(node.op) not in _BIN_OPS: + raise ValueError("unsupported operator") + _validate_safe_expr(node.left) + _validate_safe_expr(node.right) + return + + if isinstance(node, ast.UnaryOp): + if type(node.op) not in _UNARY_OPS: + raise ValueError("unsupported unary operator") + _validate_safe_expr(node.operand) + return + + if isinstance(node, ast.BoolOp): + for value in node.values: + _validate_safe_expr(value) + return + + if isinstance(node, 
ast.Compare): + _validate_safe_expr(node.left) + for comp in node.comparators: + _validate_safe_expr(comp) + return + + if isinstance(node, ast.Name): + if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}: + raise ValueError("unknown symbol") + return + + if isinstance(node, ast.Call): + if not isinstance(node.func, ast.Name): + raise ValueError("unsafe call target") + if node.func.id not in _SAFE_EVAL_FUNCS: + raise ValueError("function not allowed") + for arg in node.args: + _validate_safe_expr(arg) + for kw in node.keywords: + _validate_safe_expr(kw.value) + return + + # Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.) + raise ValueError("unsupported expression") + + +def _safe_eval_python_expr(expression: str) -> Any: + tree = ast.parse(expression, mode="eval") + _validate_safe_expr(tree) + return eval( # noqa: S307 - validated AST + empty builtins + compile(tree, "", "eval"), + {"__builtins__": {}}, + dict(_SAFE_EVAL_FUNCS), + ) + + +def _json_safe(value: Any) -> Any: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, (list, tuple)): + return [_json_safe(v) for v in value] + if isinstance(value, dict): + return {str(k): _json_safe(v) for k, v in value.items()} + return repr(value) + + +def _safe_eval_expr(expression: str) -> float: + tree = ast.parse(expression, mode="eval") + + def _eval(node: ast.AST) -> float: + if isinstance(node, ast.Expression): + return _eval(node.body) + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return float(node.value) + if isinstance(node, ast.BinOp): + op = _BIN_OPS.get(type(node.op)) + if not op: + raise ValueError("unsupported operator") + return float(op(_eval(node.left), _eval(node.right))) + if isinstance(node, ast.UnaryOp): + op = _UNARY_OPS.get(type(node.op)) + if not op: + raise ValueError("unsupported unary operator") + return float(op(_eval(node.operand))) + raise 
ValueError("unsupported expression") + + return _eval(tree) + + +def _extract_tool_name(tool_call: Dict[str, Any]) -> str: + function_payload = tool_call.get("function") + if isinstance(function_payload, dict): + return str(function_payload.get("name") or "").strip() + return "" + + +def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]: + function_payload = tool_call.get("function") + if not isinstance(function_payload, dict): + return {} + raw = function_payload.get("arguments") + if isinstance(raw, dict): + return raw + if not isinstance(raw, str): + return {} + text = raw.strip() + if not text: + return {} + try: + import json + + parsed = json.loads(text) + return parsed if isinstance(parsed, dict) else {} + except Exception: + return {} + + +async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: + """Execute a server-side tool and return normalized result payload.""" + call_id = str(tool_call.get("id") or "").strip() + tool_name = _extract_tool_name(tool_call) + args = _extract_tool_args(tool_call) + + if tool_name == "calculator": + expression = str(args.get("expression") or "").strip() + if not expression: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"error": "missing expression"}, + "status": {"code": 400, "message": "bad_request"}, + } + if len(expression) > 200: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"expression": expression, "error": "expression too long"}, + "status": {"code": 422, "message": "invalid_expression"}, + } + try: + value = _safe_eval_expr(expression) + if value.is_integer(): + value = int(value) + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"expression": expression, "result": value}, + "status": {"code": 200, "message": "ok"}, + } + except Exception as exc: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"expression": expression, "error": str(exc)}, + "status": {"code": 422, "message": 
"invalid_expression"}, + } + + if tool_name == "code_interpreter": + code = str(args.get("code") or args.get("expression") or "").strip() + if not code: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"error": "missing code"}, + "status": {"code": 400, "message": "bad_request"}, + } + if len(code) > 500: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"error": "code too long"}, + "status": {"code": 422, "message": "invalid_code"}, + } + try: + result = _safe_eval_python_expr(code) + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"code": code, "result": _json_safe(result)}, + "status": {"code": 200, "message": "ok"}, + } + except Exception as exc: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"code": code, "error": str(exc)}, + "status": {"code": 422, "message": "invalid_code"}, + } + + if tool_name == "current_time": + now = datetime.now().astimezone() + return { + "tool_call_id": call_id, + "name": tool_name, + "output": { + "local_time": now.strftime("%Y-%m-%d %H:%M:%S"), + "iso": now.isoformat(), + "timezone": str(now.tzinfo or ""), + "timestamp": int(now.timestamp()), + }, + "status": {"code": 200, "message": "ok"}, + } + + if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}: + resource = await fetch_tool_resource(tool_name) + if resource and str(resource.get("category") or "") == "query": + method = str(resource.get("http_method") or "GET").strip().upper() + if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}: + method = "GET" + url = str(resource.get("http_url") or "").strip() + headers = resource.get("http_headers") if isinstance(resource.get("http_headers"), dict) else {} + timeout_ms = resource.get("http_timeout_ms") + try: + timeout_s = max(1.0, float(timeout_ms) / 1000.0) + except Exception: + timeout_s = 10.0 + + if not url: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"error": "http_url not 
configured"}, + "status": {"code": 422, "message": "invalid_tool_config"}, + } + + request_kwargs: Dict[str, Any] = {} + if method in {"GET", "DELETE"}: + request_kwargs["params"] = args + else: + request_kwargs["json"] = args + + try: + timeout = aiohttp.ClientTimeout(total=timeout_s) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.request(method, url, headers=headers, **request_kwargs) as resp: + content_type = str(resp.headers.get("Content-Type") or "").lower() + if "application/json" in content_type: + body: Any = await resp.json() + else: + body = await resp.text() + status_code = int(resp.status) + if 200 <= status_code < 300: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": { + "method": method, + "url": url, + "status_code": status_code, + "response": _json_safe(body), + }, + "status": {"code": 200, "message": "ok"}, + } + return { + "tool_call_id": call_id, + "name": tool_name, + "output": { + "method": method, + "url": url, + "status_code": status_code, + "response": _json_safe(body), + }, + "status": {"code": status_code, "message": "http_error"}, + } + except asyncio.TimeoutError: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"method": method, "url": url, "error": "request timeout"}, + "status": {"code": 504, "message": "http_timeout"}, + } + except Exception as exc: + return { + "tool_call_id": call_id, + "name": tool_name, + "output": {"method": method, "url": url, "error": str(exc)}, + "status": {"code": 502, "message": "http_request_failed"}, + } + + return { + "tool_call_id": call_id, + "name": tool_name or "unknown_tool", + "output": {"message": "server tool not implemented"}, + "status": {"code": 501, "message": "not_implemented"}, + } diff --git a/core/transports.py b/core/transports.py new file mode 100644 index 0000000..31e398c --- /dev/null +++ b/core/transports.py @@ -0,0 +1,247 @@ +"""Transport layer for WebSocket and WebRTC communication.""" + +import 
class BaseTransport(ABC):
    """
    Abstract base class for transports.

    All transports must implement send_event and send_audio methods.
    Implementations are expected to be safe to call after close() — the
    concrete classes in this module log and drop rather than raise.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event to the client.

        Args:
            event: Event data as dictionary
        """
        pass

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the transport and cleanup resources."""
        pass
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Silently drops frames once the transport is closed or the socket
        is disconnected; any send failure latches ``_closed`` so later
        calls short-circuit.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send audio on closed transport")
            # Latch closed so subsequent sends skip the disconnection probe.
            self._closed = True
            return

        # The lock prevents interleaving binary frames with JSON frames
        # written by send_event on the same socket.
        async with self.lock:
            try:
                await self.ws.send_bytes(pcm_bytes)
            except RuntimeError as e:
                # Mark closed FIRST; starlette raises RuntimeError when the
                # close handshake has already been sent.
                self._closed = True
                if self._ws_disconnected() or "close message has been sent" in str(e):
                    logger.debug(f"Skip sending audio on closed websocket: {e!r}")
                    return
                logger.error(f"Error sending audio: {e!r}")
            except Exception as e:
                self._closed = True
                if self._ws_disconnected():
                    logger.debug(f"Skip sending audio on disconnected websocket: {e!r}")
                    return
                logger.error(f"Error sending audio: {e!r}")
+ if "close message has been sent" in str(e): + logger.debug(f"WebSocket already closed: {e}") + return + logger.error(f"Error closing WebSocket: {e}") + except Exception as e: + logger.error(f"Error closing WebSocket: {e}") + + @property + def is_closed(self) -> bool: + """Check if the transport is closed.""" + return self._closed + + +class WebRtcTransport(BaseTransport): + """ + WebRTC transport for WebRTC audio streaming. + + Uses WebSocket for signaling and RTCPeerConnection for media. + """ + + def __init__(self, websocket: WebSocket, pc): + """ + Initialize WebRTC transport. + + Args: + websocket: FastAPI WebSocket connection for signaling + pc: RTCPeerConnection for media transport + """ + if not AIORTC_AVAILABLE: + raise RuntimeError("aiortc is not available - WebRTC transport cannot be used") + + self.ws = websocket + self.pc = pc + self.outbound_track = None # MediaStreamTrack for outbound audio + self._closed = False + + async def send_event(self, event: dict) -> None: + """ + Send a JSON event via WebSocket signaling. + + Args: + event: Event data as dictionary + """ + if self._closed: + logger.warning("Attempted to send event on closed transport") + return + + try: + await self.ws.send_text(json.dumps(event)) + logger.debug(f"Sent event: {event.get('event', 'unknown')}") + except Exception as e: + logger.error(f"Error sending event: {e}") + self._closed = True + + async def send_audio(self, pcm_bytes: bytes) -> None: + """ + Send audio data via WebRTC track. + + Note: In WebRTC, you don't send bytes directly. You push frames + to a MediaStreamTrack that the peer connection is reading. 
+ + Args: + pcm_bytes: PCM audio data (16-bit, mono, 16kHz) + """ + if self._closed: + logger.warning("Attempted to send audio on closed transport") + return + + # This would require a custom MediaStreamTrack implementation + # For now, we'll log this as a placeholder + logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes") + + # TODO: Implement outbound audio track if needed + # if self.outbound_track: + # await self.outbound_track.add_frame(pcm_bytes) + + async def close(self) -> None: + """Close the WebRTC connection.""" + self._closed = True + try: + await self.pc.close() + await self.ws.close() + except Exception as e: + logger.error(f"Error closing WebRTC transport: {e}") + + @property + def is_closed(self) -> bool: + """Check if the transport is closed.""" + return self._closed + + def set_outbound_track(self, track): + """ + Set the outbound audio track for sending audio to client. + + Args: + track: MediaStreamTrack for outbound audio + """ + self.outbound_track = track + logger.debug("Set outbound track for WebRTC transport") diff --git a/core/workflow_runner.py b/core/workflow_runner.py new file mode 100644 index 0000000..2ad7ded --- /dev/null +++ b/core/workflow_runner.py @@ -0,0 +1,402 @@ +"""Workflow runtime helpers for session-level node routing. 
+ +MVP goals: +- Parse workflow graph payload from WS session.start metadata +- Track current node +- Evaluate edge conditions on each assistant turn completion +- Provide per-node runtime metadata overrides (prompt/greeting/services) +""" + +from __future__ import annotations + +from dataclasses import dataclass +import json +import re +from typing import Any, Awaitable, Callable, Dict, List, Optional + +from loguru import logger + + +_NODE_TYPE_MAP = { + "conversation": "assistant", + "assistant": "assistant", + "human": "human_transfer", + "human_transfer": "human_transfer", + "tool": "tool", + "end": "end", + "start": "start", +} + + +def _normalize_node_type(raw_type: Any) -> str: + value = str(raw_type or "").strip().lower() + return _NODE_TYPE_MAP.get(value, "assistant") + + +def _safe_str(value: Any) -> str: + if value is None: + return "" + return str(value) + + +def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]: + if not isinstance(raw, dict): + if label: + return {"type": "contains", "source": "user", "value": str(label)} + return {"type": "always"} + + condition = dict(raw) + condition_type = str(condition.get("type", "always")).strip().lower() + if not condition_type: + condition_type = "always" + condition["type"] = condition_type + condition["source"] = str(condition.get("source", "user")).strip().lower() or "user" + return condition + + +@dataclass +class WorkflowNodeDef: + id: str + name: str + node_type: str + is_start: bool + prompt: Optional[str] + message_plan: Dict[str, Any] + assistant_id: Optional[str] + assistant: Dict[str, Any] + tool: Optional[Dict[str, Any]] + raw: Dict[str, Any] + + +@dataclass +class WorkflowEdgeDef: + id: str + from_node_id: str + to_node_id: str + label: Optional[str] + condition: Dict[str, Any] + priority: int + order: int + raw: Dict[str, Any] + + +@dataclass +class WorkflowTransition: + edge: WorkflowEdgeDef + node: WorkflowNodeDef + + +LlmRouter = Callable[ + [WorkflowNodeDef, 
    def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]):
        """Build an in-memory graph for one active session.

        Args:
            workflow_id: Identifier of the workflow definition.
            name: Human-readable workflow name.
            nodes: Parsed node definitions; indexed by id for O(1) lookup.
            edges: Parsed edge definitions, kept in declaration order.
        """
        self.workflow_id = workflow_id
        self.name = name
        # Later duplicate node ids silently win in this index — assumes
        # upstream parsing de-duplicates; TODO confirm.
        self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes}
        self._edges = edges
        # No current node until bootstrap() resolves the start node.
        self.current_node_id: Optional[str] = None
nodes} + edges: List[WorkflowEdgeDef] = [] + for i, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []): + if not isinstance(raw, dict): + continue + + from_node_id = _safe_str( + raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source") + ).strip() + to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip() + if not from_node_id or not to_node_id: + continue + if from_node_id not in node_ids or to_node_id not in node_ids: + continue + + label = raw.get("label") + if label is not None: + label = _safe_str(label) + + condition = _normalize_condition(raw.get("condition"), label=label) + + priority = 100 + try: + priority = int(raw.get("priority", 100)) + except (TypeError, ValueError): + priority = 100 + + edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}" + + edges.append( + WorkflowEdgeDef( + id=edge_id, + from_node_id=from_node_id, + to_node_id=to_node_id, + label=label, + condition=condition, + priority=priority, + order=i, + raw=raw, + ) + ) + + workflow_id = _safe_str(payload.get("id") or "workflow") + workflow_name = _safe_str(payload.get("name") or workflow_id) + return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges) + + def bootstrap(self) -> Optional[WorkflowNodeDef]: + start_node = self._resolve_start_node() + if not start_node: + return None + self.current_node_id = start_node.id + return start_node + + @property + def current_node(self) -> Optional[WorkflowNodeDef]: + if not self.current_node_id: + return None + return self._nodes.get(self.current_node_id) + + def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]: + edges = [edge for edge in self._edges if edge.from_node_id == node_id] + return sorted(edges, key=lambda edge: (edge.priority, edge.order)) + + def next_default_transition(self) -> Optional[WorkflowTransition]: + node = self.current_node + if not node: + return None + for edge in 
self.outgoing_edges(node.id): + cond_type = str(edge.condition.get("type", "always")).strip().lower() + if cond_type in {"", "always", "default"}: + target = self._nodes.get(edge.to_node_id) + if target: + return WorkflowTransition(edge=edge, node=target) + return None + + async def route( + self, + *, + user_text: str, + assistant_text: str, + llm_router: Optional[LlmRouter] = None, + ) -> Optional[WorkflowTransition]: + node = self.current_node + if not node: + return None + + outgoing = self.outgoing_edges(node.id) + if not outgoing: + return None + + llm_edges: List[WorkflowEdgeDef] = [] + for edge in outgoing: + cond_type = str(edge.condition.get("type", "always")).strip().lower() + if cond_type == "llm": + llm_edges.append(edge) + continue + if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text): + target = self._nodes.get(edge.to_node_id) + if target: + return WorkflowTransition(edge=edge, node=target) + + if llm_edges and llm_router: + selection = await llm_router( + node, + llm_edges, + { + "userText": user_text, + "assistantText": assistant_text, + }, + ) + if selection: + for edge in llm_edges: + if selection in {edge.id, edge.to_node_id}: + target = self._nodes.get(edge.to_node_id) + if target: + return WorkflowTransition(edge=edge, node=target) + + for edge in outgoing: + cond_type = str(edge.condition.get("type", "always")).strip().lower() + if cond_type in {"", "always", "default"}: + target = self._nodes.get(edge.to_node_id) + if target: + return WorkflowTransition(edge=edge, node=target) + return None + + def apply_transition(self, transition: WorkflowTransition) -> None: + self.current_node_id = transition.node.id + + def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]: + assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {} + message_plan = node.message_plan if isinstance(node.message_plan, dict) else {} + metadata: Dict[str, Any] = {} + + if node.prompt is not None: + 
metadata["systemPrompt"] = node.prompt + elif "systemPrompt" in assistant_cfg: + metadata["systemPrompt"] = _safe_str(assistant_cfg.get("systemPrompt")) + elif "prompt" in assistant_cfg: + metadata["systemPrompt"] = _safe_str(assistant_cfg.get("prompt")) + + first_message = message_plan.get("firstMessage") + if first_message is not None: + metadata["greeting"] = _safe_str(first_message) + elif "greeting" in assistant_cfg: + metadata["greeting"] = _safe_str(assistant_cfg.get("greeting")) + elif "opener" in assistant_cfg: + metadata["greeting"] = _safe_str(assistant_cfg.get("opener")) + + services = assistant_cfg.get("services") + if isinstance(services, dict): + metadata["services"] = services + + if node.assistant_id: + metadata["assistantId"] = node.assistant_id + + return metadata + + def _resolve_start_node(self) -> Optional[WorkflowNodeDef]: + explicit_start = next((node for node in self._nodes.values() if node.is_start), None) + if not explicit_start: + explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None) + + if explicit_start: + # If a dedicated start node exists, try to move to its first default target. 
+ if explicit_start.node_type == "start": + visited = {explicit_start.id} + current = explicit_start + for _ in range(8): + transition = self._first_default_transition_from(current.id) + if not transition: + return current + current = transition.node + if current.id in visited: + break + visited.add(current.id) + return current + return explicit_start + + assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None) + if assistant_node: + return assistant_node + return next(iter(self._nodes.values()), None) + + def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]: + for edge in self.outgoing_edges(node_id): + cond_type = str(edge.condition.get("type", "always")).strip().lower() + if cond_type in {"", "always", "default"}: + node = self._nodes.get(edge.to_node_id) + if node: + return WorkflowTransition(edge=edge, node=node) + return None + + def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool: + condition = edge.condition or {"type": "always"} + cond_type = str(condition.get("type", "always")).strip().lower() + source = str(condition.get("source", "user")).strip().lower() + + if cond_type in {"", "always", "default"}: + return True + + text = assistant_text if source == "assistant" else user_text + text_lower = (text or "").lower() + + if cond_type == "contains": + values: List[str] = [] + if isinstance(condition.get("values"), list): + values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()] + if not values: + single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower() + if single: + values = [single] + if not values: + return False + return any(value in text_lower for value in values) + + if cond_type == "equals": + expected = _safe_str(condition.get("value") or "").strip().lower() + return bool(expected) and text_lower == expected + + if cond_type == "regex": + 
pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip() + if not pattern: + return False + try: + return bool(re.search(pattern, text or "", re.IGNORECASE)) + except re.error: + logger.warning(f"Invalid workflow regex condition: {pattern}") + return False + + if cond_type == "json": + value = _safe_str(condition.get("value") or "").strip() + if not value: + return False + try: + obj = json.loads(text or "") + except Exception: + return False + return str(obj) == value + + return False diff --git a/data/audio_examples/single_utterance_16k.wav b/data/audio_examples/single_utterance_16k.wav new file mode 100644 index 0000000..8c7bbe5 Binary files /dev/null and b/data/audio_examples/single_utterance_16k.wav differ diff --git a/data/audio_examples/three_utterances.wav b/data/audio_examples/three_utterances.wav new file mode 100644 index 0000000..c2dca2f Binary files /dev/null and b/data/audio_examples/three_utterances.wav differ diff --git a/data/audio_examples/two_utterances.wav b/data/audio_examples/two_utterances.wav new file mode 100644 index 0000000..5c66f70 Binary files /dev/null and b/data/audio_examples/two_utterances.wav differ diff --git a/data/vad/silero_vad.onnx b/data/vad/silero_vad.onnx new file mode 100644 index 0000000..b3e3a90 Binary files /dev/null and b/data/vad/silero_vad.onnx differ diff --git a/docs/duplex_interaction.svg b/docs/duplex_interaction.svg new file mode 100644 index 0000000..9ccd0bb --- /dev/null +++ b/docs/duplex_interaction.svg @@ -0,0 +1,96 @@ + + + + + + + + + + Web Client + WS JSON commands + WS binary PCM audio + + + FastAPI /ws + Session + Transport + + + DuplexPipeline + process_audio / process_text + + + ConversationManager + turns + state + + + VADProcessor + speech/silence + + + EOU Detector + end-of-utterance + + + ASR + transcripts + + + LLM (stream) + llmResponse events + + + TTS (stream) + PCM audio + + + Web Client + audio playback + UI + + + JSON / PCM + + + dispatch + + + turn mgmt + + + 
audio chunks + + + vad status + + + audio buffer + + + EOU -> LLM + + + text stream + + + PCM audio + + + events: trackStart/End + + + UI updates + + + barge-in detection + + + interrupt event + cancel + diff --git a/docs/proejct_todo.md b/docs/proejct_todo.md new file mode 100644 index 0000000..18a9f17 --- /dev/null +++ b/docs/proejct_todo.md @@ -0,0 +1,187 @@ +# OmniSense: 12-Week Sprint Board + Tech Stack (Python Backend) — TODO + +## Scope +- [ ] Build a realtime AI SaaS (OmniSense) focused on web-first audio + video with WebSocket + WebRTC endpoints +- [ ] Deliver assistant builder, tool execution, observability, evals, optional telephony later +- [ ] Keep scope aligned to 2-person team, self-hosted services + +--- + +## Sprint Board (12 weeks, 2-week sprints) +Team assumption: 2 engineers. Scope prioritized to web-first audio + video, with BYO-SFU adapters. + +### Sprint 1 (Weeks 1–2) — Realtime Core MVP (WebSocket + WebRTC Audio) +- Deliverables + - [ ] WebSocket transport: audio in/out streaming (1:1) + - [ ] WebRTC transport: audio in/out streaming (1:1) + - [ ] Adapter contract wired into runtime (transport-agnostic session core) + - [ ] ASR → LLM → TTS pipeline, streaming both directions + - [ ] Basic session state (start/stop, silence timeout) + - [ ] Transcript persistence +- Acceptance criteria + - [ ] < 1.5s median round-trip for short responses + - [ ] Stable streaming for 10+ minute session + +### Sprint 2 (Weeks 3–4) — Video + Realtime UX +- Deliverables + - [ ] WebRTC video capture + streaming (assistant can “see” frames) + - [ ] WebSocket video streaming for local/dev mode + - [ ] Low-latency UI: push-to-talk, live captions, speaking indicator + - [ ] Recording + transcript storage (web sessions) +- Acceptance criteria + - [ ] Video < 2.5s end-to-end latency for analysis + - [ ] Audio quality acceptable (no clipping, jitter handling) + +### Sprint 3 (Weeks 5–6) — Assistant Builder v1 +- Deliverables + - [ ] Assistant schema + versioning + - [ ] 
UI: Model/Voice/Transcriber/Tools/Video/Transport tabs + - [ ] “Test/Chat/Talk to Assistant” (web) +- Acceptance criteria + - [ ] Create/publish assistant and run a live web session + - [ ] All config changes tracked by version + +### Sprint 4 (Weeks 7–8) — Tooling + Structured Outputs +- Deliverables + - [ ] Tool registry + custom HTTP tools + - [ ] Tool auth secrets management + - [ ] Structured outputs (JSON extraction) +- Acceptance criteria + - [ ] Tool calls executed with retries/timeouts + - [ ] Structured JSON stored per call/session + +### Sprint 5 (Weeks 9–10) — Observability + QA + Dev Platform +- Deliverables + - [ ] Session logs + chat logs + media logs + - [ ] Evals engine + test suites + - [ ] Basic analytics dashboard + - [ ] Public WebSocket API spec + message schema + - [ ] JS/TS SDK (connect, send audio/video, receive transcripts) +- Acceptance criteria + - [ ] Reproducible test suite runs + - [ ] Log filters by assistant/time/status + - [ ] SDK demo app runs end-to-end + +### Sprint 6 (Weeks 11–12) — SaaS Hardening +- Deliverables + - [ ] Org/RBAC + API keys + rate limits + - [ ] Usage metering + credits + - [ ] Stripe billing integration + - [ ] Self-hosted DB ops (migrations, backup/restore, monitoring) +- Acceptance criteria + - [ ] Metered usage per org + - [ ] Credits decrement correctly + - [ ] Optional telephony spike documented (defer build) + - [ ] Enterprise adapter guide published (BYO-SFU) + +--- + +## Tech Stack by Service (Self-Hosted, Web-First) + +### 1) Transport Gateway (Realtime) +- [ ] WebRTC (browser) + WebSocket (lightweight/dev) protocols +- [ ] BYO-SFU adapter (enterprise) + LiveKit optional adapter + WS transport server +- [ ] Python core (FastAPI + asyncio) + Node.js mediasoup adapters when needed +- [ ] Media: Opus/VP8, jitter buffer, VAD, echo cancellation +- [ ] Storage: S3-compatible (MinIO) for recordings + +### 2) ASR Service +- [ ] Whisper (self-hosted) baseline +- [ ] gRPC/WebSocket streaming transport +- [ ] 
Python native service +- [ ] Optional cloud provider fallback (later) + +### 3) TTS Service +- [ ] Piper or Coqui TTS (self-hosted) +- [ ] gRPC/WebSocket streaming transport +- [ ] Python native service +- [ ] Redis cache for common phrases + +### 4) LLM Orchestrator +- [ ] Self-hosted (vLLM + open model) +- [ ] Python (FastAPI + asyncio) +- [ ] Streaming, tool calling, JSON mode +- [ ] Safety filters + prompt templates + +### 5) Assistant Config Service +- [ ] PostgreSQL +- [ ] Python (SQLAlchemy or SQLModel) +- [ ] Versioning, publish/rollback + +### 6) Session Service +- [ ] PostgreSQL + Redis +- [ ] Python +- [ ] State machine, timeouts, events + +### 7) Tool Execution Layer +- [ ] PostgreSQL +- [ ] Python +- [ ] Auth secret vault, retry policies, tool schemas + +### 8) Observability + Logs +- [ ] Postgres (metadata), ClickHouse (logs/metrics) +- [ ] OpenSearch for search +- [ ] Prometheus + Grafana metrics +- [ ] OpenTelemetry tracing + +### 9) Billing + Usage Metering +- [ ] Stripe billing +- [ ] PostgreSQL +- [ ] NATS JetStream (events) + Redis counters + +### 10) Web App (Dashboard) +- [ ] React + Next.js +- [ ] Tailwind or Radix UI +- [ ] WebRTC client + WS client; adapter-based RTC integration +- [ ] ECharts/Recharts + +### 11) Auth + RBAC +- [ ] Keycloak (self-hosted) or custom JWT +- [ ] Org/user/role tables in Postgres + +### 12) Public WebSocket API + SDK +- [ ] WS API: versioned schema, binary audio frames + JSON control messages +- [ ] SDKs: JS/TS first, optional Python/Go clients +- [ ] Docs: quickstart, auth flow, session lifecycle, examples + +--- + +## Infrastructure (Self-Hosted) +- [ ] Docker Compose → k3s (later) +- [ ] Redis Streams or NATS +- [ ] MinIO object store +- [ ] GitHub Actions + Helm or kustomize +- [ ] Self-hosted Postgres + pgbackrest backups +- [ ] Vault for secrets + +--- + +## Suggested MVP Sequence +- [ ] WebRTC demo + ASR/LLM/TTS streaming +- [ ] Assistant schema + versioning (web-first) +- [ ] Video capture + multimodal 
analysis +- [ ] Tool execution + structured outputs +- [ ] Logs + evals + public WS API + SDK +- [ ] Telephony (optional, later) + +--- + +## Public WebSocket API (Minimum Spec) +- [ ] Auth: API key or JWT in initial `hello` message +- [ ] Core messages: `session.start`, `session.stop`, `audio.append`, `audio.commit`, `video.append`, `transcript.delta`, `assistant.response`, `tool.call`, `tool.result`, `error` +- [ ] Binary payloads: PCM/Opus frames with metadata in control channel +- [ ] Versioning: `v1` schema with backward compatibility rules + +--- + +## Self-Hosted DB Ops Checklist +- [ ] Postgres in Docker/k3s with persistent volumes +- [ ] Migrations: `alembic` or `atlas` +- [ ] Backups: `pgbackrest` nightly + on-demand +- [ ] Monitoring: postgres_exporter + alerts + +--- + +## RTC Adapter Contract (BYO-SFU First) +- [ ] Keep RTC pluggable; LiveKit optional, not core dependency +- [ ] Define adapter interface (TypeScript sketch) \ No newline at end of file diff --git a/docs/ws_v1_schema.md b/docs/ws_v1_schema.md new file mode 100644 index 0000000..9db0900 --- /dev/null +++ b/docs/ws_v1_schema.md @@ -0,0 +1,199 @@ +# WS v1 Protocol Schema (`/ws`) + +This document defines the public WebSocket protocol for the `/ws` endpoint. + +## Transport + +- A single WebSocket connection carries: + - JSON text frames for control/events. + - Binary frames for raw PCM audio (`pcm_s16le`, mono, 16kHz by default). + +## Handshake and State Machine + +Required message order: + +1. Client sends `hello`. +2. Server replies `hello.ack`. +3. Client sends `session.start`. +4. Server replies `session.started`. +5. Client may stream binary audio and/or send `input.text`. +6. Client sends `session.stop` (or closes socket). + +If order is violated, server emits `error` with `code = "protocol.order"`. 
+ +## Client -> Server Messages + +### `hello` + +```json +{ + "type": "hello", + "version": "v1", + "auth": { + "apiKey": "optional-api-key", + "jwt": "optional-jwt" + } +} +``` + +Rules: +- `version` must be `v1`. +- If `WS_API_KEY` is configured on server, `auth.apiKey` must match. +- If `WS_REQUIRE_AUTH=true`, either `auth.apiKey` or `auth.jwt` must be present. + +### `session.start` + +```json +{ + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": 16000, + "channels": 1 + }, + "metadata": { + "client": "web-debug", + "output": { + "mode": "audio" + }, + "systemPrompt": "You are concise.", + "greeting": "Hi, how can I help?", + "services": { + "llm": { + "provider": "openai", + "model": "gpt-4o-mini", + "apiKey": "sk-...", + "baseUrl": "https://api.openai.com/v1" + }, + "asr": { + "provider": "openai_compatible", + "model": "FunAudioLLM/SenseVoiceSmall", + "apiKey": "sf-...", + "interimIntervalMs": 500, + "minAudioMs": 300 + }, + "tts": { + "enabled": true, + "provider": "openai_compatible", + "model": "FunAudioLLM/CosyVoice2-0.5B", + "apiKey": "sf-...", + "voice": "anna", + "speed": 1.0 + } + } + } +} +``` + +`metadata.services` is optional. If omitted, server defaults to environment configuration. + +Text-only mode: +- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`. +- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`. + +### `input.text` + +```json +{ + "type": "input.text", + "text": "What can you do?" +} +``` + +### `response.cancel` + +```json +{ + "type": "response.cancel", + "graceful": false +} +``` + +### `session.stop` + +```json +{ + "type": "session.stop", + "reason": "client_disconnect" +} +``` + +### `tool_call.results` + +Client tool execution results returned to server. 
+ +```json +{ + "type": "tool_call.results", + "results": [ + { + "tool_call_id": "call_abc123", + "name": "weather", + "output": { "temp_c": 21, "condition": "sunny" }, + "status": { "code": 200, "message": "ok" } + } + ] +} +``` + +## Server -> Client Events + +All server events include: + +```json +{ + "type": "event.name", + "timestamp": 1730000000000 +} +``` + +Common events: + +- `hello.ack` + - Fields: `sessionId`, `version` +- `session.started` + - Fields: `sessionId`, `trackId`, `audio` +- `session.stopped` + - Fields: `sessionId`, `reason` +- `heartbeat` +- `input.speech_started` + - Fields: `trackId`, `probability` +- `input.speech_stopped` + - Fields: `trackId`, `probability` +- `transcript.delta` + - Fields: `trackId`, `text` +- `transcript.final` + - Fields: `trackId`, `text` +- `assistant.response.delta` + - Fields: `trackId`, `text` +- `assistant.response.final` + - Fields: `trackId`, `text` +- `assistant.tool_call` + - Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`) +- `assistant.tool_result` + - Fields: `trackId`, `source`, `result` +- `output.audio.start` + - Fields: `trackId` +- `output.audio.end` + - Fields: `trackId` +- `response.interrupted` + - Fields: `trackId` +- `metrics.ttfb` + - Fields: `trackId`, `latencyMs` +- `error` + - Fields: `sender`, `code`, `message`, `trackId` + +## Binary Audio Frames + +After `session.started`, client may send binary PCM chunks continuously. + +Recommended format: +- 16-bit signed little-endian PCM. +- 1 channel. +- 16000 Hz. +- 20ms frames (640 bytes) preferred. + +## Compatibility + +This endpoint now enforces v1 message schema for JSON control frames. +Legacy command names (`invite`, `chat`, etc.) are no longer part of the public protocol. 
diff --git a/examples/mic_client.py b/examples/mic_client.py new file mode 100644 index 0000000..509aeaa --- /dev/null +++ b/examples/mic_client.py @@ -0,0 +1,601 @@ +#!/usr/bin/env python3 +""" +Microphone client for testing duplex voice conversation. + +This client captures audio from the microphone, sends it to the server, +and plays back the AI's voice response through the speakers. +It also displays the LLM's text responses in the console. + +Usage: + python examples/mic_client.py --url ws://localhost:8000/ws + python examples/mic_client.py --url ws://localhost:8000/ws --chat "Hello!" + python examples/mic_client.py --url ws://localhost:8000/ws --verbose + +Requirements: + pip install sounddevice soundfile websockets numpy +""" + +import argparse +import asyncio +import json +import sys +import time +import threading +import queue +from pathlib import Path + +try: + import numpy as np +except ImportError: + print("Please install numpy: pip install numpy") + sys.exit(1) + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice: pip install sounddevice") + sys.exit(1) + +try: + import websockets +except ImportError: + print("Please install websockets: pip install websockets") + sys.exit(1) + + +class MicrophoneClient: + """ + Full-duplex microphone client for voice conversation. + + Features: + - Real-time microphone capture + - Real-time speaker playback + - WebSocket communication + - Text chat support + """ + + def __init__( + self, + url: str, + sample_rate: int = 16000, + chunk_duration_ms: int = 20, + input_device: int = None, + output_device: int = None + ): + """ + Initialize microphone client. 
+ + Args: + url: WebSocket server URL + sample_rate: Audio sample rate (Hz) + chunk_duration_ms: Audio chunk duration (ms) + input_device: Input device ID (None for default) + output_device: Output device ID (None for default) + """ + self.url = url + self.sample_rate = sample_rate + self.chunk_duration_ms = chunk_duration_ms + self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000) + self.input_device = input_device + self.output_device = output_device + + # WebSocket connection + self.ws = None + self.running = False + + # Audio buffers + self.audio_input_queue = queue.Queue() + self.audio_output_buffer = b"" # Continuous buffer for smooth playback + self.audio_output_lock = threading.Lock() + + # Statistics + self.bytes_sent = 0 + self.bytes_received = 0 + + # State + self.is_recording = True + self.is_playing = True + + # TTFB tracking (Time to First Byte) + self.request_start_time = None + self.first_audio_received = False + + # Interrupt handling - discard audio until next trackStart + self._discard_audio = False + self._audio_sequence = 0 # Track audio sequence to detect stale chunks + + # Verbose mode for streaming LLM responses + self.verbose = False + + async def connect(self) -> None: + """Connect to WebSocket server.""" + print(f"Connecting to {self.url}...") + self.ws = await websockets.connect(self.url) + self.running = True + print("Connected!") + + # Send invite command + await self.send_command({ + "command": "invite", + "option": { + "codec": "pcm", + "sampleRate": self.sample_rate + } + }) + + async def send_command(self, cmd: dict) -> None: + """Send JSON command to server.""" + if self.ws: + await self.ws.send(json.dumps(cmd)) + print(f"→ Command: {cmd.get('command', 'unknown')}") + + async def send_chat(self, text: str) -> None: + """Send chat message (text input).""" + # Reset TTFB tracking for new request + self.request_start_time = time.time() + self.first_audio_received = False + + await self.send_command({ + "command": "chat", + 
"text": text + }) + print(f"→ Chat: {text}") + + async def send_interrupt(self) -> None: + """Send interrupt command.""" + await self.send_command({ + "command": "interrupt" + }) + + async def send_hangup(self, reason: str = "User quit") -> None: + """Send hangup command.""" + await self.send_command({ + "command": "hangup", + "reason": reason + }) + + def _audio_input_callback(self, indata, frames, time, status): + """Callback for audio input (microphone).""" + if status: + print(f"Input status: {status}") + + if self.is_recording and self.running: + # Convert to 16-bit PCM + audio_data = (indata[:, 0] * 32767).astype(np.int16).tobytes() + self.audio_input_queue.put(audio_data) + + def _add_audio_to_buffer(self, audio_data: bytes): + """Add audio data to playback buffer.""" + with self.audio_output_lock: + self.audio_output_buffer += audio_data + + def _playback_thread_func(self): + """Thread function for continuous audio playback.""" + import time + + # Chunk size: 50ms of audio + chunk_samples = int(self.sample_rate * 0.05) + chunk_bytes = chunk_samples * 2 + + print(f"Audio playback thread started (device: {self.output_device or 'default'})") + + try: + # Create output stream with callback + with sd.OutputStream( + samplerate=self.sample_rate, + channels=1, + dtype='int16', + blocksize=chunk_samples, + device=self.output_device, + latency='low' + ) as stream: + while self.running: + # Get audio from buffer + with self.audio_output_lock: + if len(self.audio_output_buffer) >= chunk_bytes: + audio_data = self.audio_output_buffer[:chunk_bytes] + self.audio_output_buffer = self.audio_output_buffer[chunk_bytes:] + else: + # Not enough audio - output silence + audio_data = b'\x00' * chunk_bytes + + # Convert to numpy array and write to stream + samples = np.frombuffer(audio_data, dtype=np.int16).reshape(-1, 1) + stream.write(samples) + + except Exception as e: + print(f"Playback thread error: {e}") + import traceback + traceback.print_exc() + + async def 
_playback_task(self): + """Start playback thread and monitor it.""" + # Run playback in a dedicated thread for reliable timing + playback_thread = threading.Thread(target=self._playback_thread_func, daemon=True) + playback_thread.start() + + # Wait for client to stop + while self.running and playback_thread.is_alive(): + await asyncio.sleep(0.1) + + print("Audio playback stopped") + + async def audio_sender(self) -> None: + """Send audio from microphone to server.""" + while self.running: + try: + # Get audio from queue with timeout + try: + audio_data = await asyncio.get_event_loop().run_in_executor( + None, lambda: self.audio_input_queue.get(timeout=0.1) + ) + except queue.Empty: + continue + + # Send to server + if self.ws and self.is_recording: + await self.ws.send(audio_data) + self.bytes_sent += len(audio_data) + + except asyncio.CancelledError: + break + except Exception as e: + print(f"Audio sender error: {e}") + break + + async def receiver(self) -> None: + """Receive messages from server.""" + try: + while self.running: + try: + message = await asyncio.wait_for(self.ws.recv(), timeout=0.1) + + if isinstance(message, bytes): + # Audio data received + self.bytes_received += len(message) + + # Check if we should discard this audio (after interrupt) + if self._discard_audio: + duration_ms = len(message) / (self.sample_rate * 2) * 1000 + print(f"← Audio: {duration_ms:.0f}ms (DISCARDED - waiting for new track)") + continue + + if self.is_playing: + self._add_audio_to_buffer(message) + + # Calculate and display TTFB for first audio packet + if not self.first_audio_received and self.request_start_time: + client_ttfb_ms = (time.time() - self.request_start_time) * 1000 + self.first_audio_received = True + print(f"← [TTFB] Client first audio latency: {client_ttfb_ms:.0f}ms") + + # Show progress (less verbose) + with self.audio_output_lock: + buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000 + duration_ms = len(message) / (self.sample_rate * 
2) * 1000 + print(f"← Audio: {duration_ms:.0f}ms (buffer: {buffer_ms:.0f}ms)") + + else: + # JSON event + event = json.loads(message) + await self._handle_event(event) + + except asyncio.TimeoutError: + continue + except websockets.ConnectionClosed: + print("Connection closed") + self.running = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receiver error: {e}") + self.running = False + + async def _handle_event(self, event: dict) -> None: + """Handle incoming event.""" + event_type = event.get("event", "unknown") + + if event_type == "answer": + print("← Session ready!") + elif event_type == "speaking": + print("← User speech detected") + elif event_type == "silence": + print("← User silence detected") + elif event_type == "transcript": + # Display user speech transcription + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + # Clear the interim line and print final + print(" " * 80, end="\r") # Clear previous interim text + print(f"→ You: {text}") + else: + # Interim result - show with indicator (overwrite same line) + display_text = text[:60] + "..." if len(text) > 60 else text + print(f" [listening] {display_text}".ljust(80), end="\r") + elif event_type == "ttfb": + # Server-side TTFB event + latency_ms = event.get("latencyMs", 0) + print(f"← [TTFB] Server reported latency: {latency_ms}ms") + elif event_type == "llmResponse": + # LLM text response + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + # Print final LLM response + print(f"← AI: {text}") + elif self.verbose: + # Show streaming chunks only in verbose mode + display_text = text[:60] + "..." 
if len(text) > 60 else text + print(f" [streaming] {display_text}") + elif event_type == "trackStart": + print("← Bot started speaking") + # IMPORTANT: Accept audio again after trackStart + self._discard_audio = False + self._audio_sequence += 1 + # Reset TTFB tracking for voice responses (when no chat was sent) + if self.request_start_time is None: + self.request_start_time = time.time() + self.first_audio_received = False + # Clear any old audio in buffer + with self.audio_output_lock: + self.audio_output_buffer = b"" + elif event_type == "trackEnd": + print("← Bot finished speaking") + # Reset TTFB tracking after response completes + self.request_start_time = None + self.first_audio_received = False + elif event_type == "interrupt": + print("← Bot interrupted!") + # IMPORTANT: Discard all audio until next trackStart + self._discard_audio = True + # Clear audio buffer immediately + with self.audio_output_lock: + buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000 + self.audio_output_buffer = b"" + print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)") + elif event_type == "error": + print(f"← Error: {event.get('error')}") + elif event_type == "hangup": + print(f"← Hangup: {event.get('reason')}") + self.running = False + else: + print(f"← Event: {event_type}") + + async def interactive_mode(self) -> None: + """Run interactive mode for text chat.""" + print("\n" + "=" * 50) + print("Voice Conversation Client") + print("=" * 50) + print("Speak into your microphone to talk to the AI.") + print("Or type messages to send text.") + print("") + print("Commands:") + print(" /quit - End conversation") + print(" /mute - Mute microphone") + print(" /unmute - Unmute microphone") + print(" /interrupt - Interrupt AI speech") + print(" /stats - Show statistics") + print("=" * 50 + "\n") + + while self.running: + try: + user_input = await asyncio.get_event_loop().run_in_executor( + None, input, "" + ) + + if not user_input: + continue + + # 
Handle commands + if user_input.startswith("/"): + cmd = user_input.lower().strip() + + if cmd == "/quit": + await self.send_hangup("User quit") + break + elif cmd == "/mute": + self.is_recording = False + print("Microphone muted") + elif cmd == "/unmute": + self.is_recording = True + print("Microphone unmuted") + elif cmd == "/interrupt": + await self.send_interrupt() + elif cmd == "/stats": + print(f"Sent: {self.bytes_sent / 1024:.1f} KB") + print(f"Received: {self.bytes_received / 1024:.1f} KB") + else: + print(f"Unknown command: {cmd}") + else: + # Send as chat message + await self.send_chat(user_input) + + except EOFError: + break + except Exception as e: + print(f"Input error: {e}") + + async def run(self, chat_message: str = None, interactive: bool = True) -> None: + """ + Run the client. + + Args: + chat_message: Optional single chat message to send + interactive: Whether to run in interactive mode + """ + try: + await self.connect() + + # Wait for answer + await asyncio.sleep(0.5) + + # Start audio input stream + print("Starting audio streams...") + + input_stream = sd.InputStream( + samplerate=self.sample_rate, + channels=1, + dtype=np.float32, + blocksize=self.chunk_samples, + device=self.input_device, + callback=self._audio_input_callback + ) + + input_stream.start() + print("Audio streams started") + + # Start background tasks + sender_task = asyncio.create_task(self.audio_sender()) + receiver_task = asyncio.create_task(self.receiver()) + playback_task = asyncio.create_task(self._playback_task()) + + if chat_message: + # Send single message and wait + await self.send_chat(chat_message) + await asyncio.sleep(15) + elif interactive: + # Run interactive mode + await self.interactive_mode() + else: + # Just wait + while self.running: + await asyncio.sleep(0.1) + + # Cleanup + self.running = False + sender_task.cancel() + receiver_task.cancel() + playback_task.cancel() + + try: + await sender_task + except asyncio.CancelledError: + pass + + try: + await 
receiver_task + except asyncio.CancelledError: + pass + + try: + await playback_task + except asyncio.CancelledError: + pass + + input_stream.stop() + + except ConnectionRefusedError: + print(f"Error: Could not connect to {self.url}") + print("Make sure the server is running.") + except Exception as e: + print(f"Error: {e}") + finally: + await self.close() + + async def close(self) -> None: + """Close the connection.""" + self.running = False + if self.ws: + await self.ws.close() + + print(f"\nSession ended") + print(f" Total sent: {self.bytes_sent / 1024:.1f} KB") + print(f" Total received: {self.bytes_received / 1024:.1f} KB") + + +def list_devices(): + """List available audio devices.""" + print("\nAvailable audio devices:") + print("-" * 60) + devices = sd.query_devices() + for i, device in enumerate(devices): + direction = [] + if device['max_input_channels'] > 0: + direction.append("IN") + if device['max_output_channels'] > 0: + direction.append("OUT") + direction_str = "/".join(direction) if direction else "N/A" + + default = "" + if i == sd.default.device[0]: + default += " [DEFAULT INPUT]" + if i == sd.default.device[1]: + default += " [DEFAULT OUTPUT]" + + print(f" {i:2d}: {device['name'][:40]:40s} ({direction_str}){default}") + print("-" * 60) + + +async def main(): + parser = argparse.ArgumentParser( + description="Microphone client for duplex voice conversation" + ) + parser.add_argument( + "--url", + default="ws://localhost:8000/ws", + help="WebSocket server URL" + ) + parser.add_argument( + "--chat", + help="Send a single chat message instead of using microphone" + ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="Audio sample rate (default: 16000)" + ) + parser.add_argument( + "--input-device", + type=int, + help="Input device ID" + ) + parser.add_argument( + "--output-device", + type=int, + help="Output device ID" + ) + parser.add_argument( + "--list-devices", + action="store_true", + help="List available audio 
devices and exit" + ) + parser.add_argument( + "--no-interactive", + action="store_true", + help="Disable interactive mode" + ) + parser.add_argument( + "--verbose", "-v", + action="store_true", + help="Show streaming LLM response chunks" + ) + + args = parser.parse_args() + + if args.list_devices: + list_devices() + return + + client = MicrophoneClient( + url=args.url, + sample_rate=args.sample_rate, + input_device=args.input_device, + output_device=args.output_device + ) + client.verbose = args.verbose + + await client.run( + chat_message=args.chat, + interactive=not args.no_interactive + ) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nInterrupted by user") diff --git a/examples/simple_client.py b/examples/simple_client.py new file mode 100644 index 0000000..4280f93 --- /dev/null +++ b/examples/simple_client.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +""" +Simple WebSocket client for testing voice conversation. +Uses PyAudio for more reliable audio playback on Windows. 
+ +Usage: + python examples/simple_client.py + python examples/simple_client.py --text "Hello" +""" + +import argparse +import asyncio +import json +import sys +import time +import wave +import io + +try: + import numpy as np +except ImportError: + print("pip install numpy") + sys.exit(1) + +try: + import websockets +except ImportError: + print("pip install websockets") + sys.exit(1) + +# Try PyAudio first (more reliable on Windows) +try: + import pyaudio + PYAUDIO_AVAILABLE = True +except ImportError: + PYAUDIO_AVAILABLE = False + print("PyAudio not available, trying sounddevice...") + +try: + import sounddevice as sd + SD_AVAILABLE = True +except ImportError: + SD_AVAILABLE = False + +if not PYAUDIO_AVAILABLE and not SD_AVAILABLE: + print("Please install pyaudio or sounddevice:") + print(" pip install pyaudio") + print(" or: pip install sounddevice") + sys.exit(1) + + +class SimpleVoiceClient: + """Simple voice client with reliable audio playback.""" + + def __init__(self, url: str, sample_rate: int = 16000): + self.url = url + self.sample_rate = sample_rate + self.ws = None + self.running = False + + # Audio buffer + self.audio_buffer = b"" + + # PyAudio setup + if PYAUDIO_AVAILABLE: + self.pa = pyaudio.PyAudio() + self.stream = None + + # Stats + self.bytes_received = 0 + + # TTFB tracking (Time to First Byte) + self.request_start_time = None + self.first_audio_received = False + + # Interrupt handling - discard audio until next trackStart + self._discard_audio = False + + async def connect(self): + """Connect to server.""" + print(f"Connecting to {self.url}...") + self.ws = await websockets.connect(self.url) + self.running = True + print("Connected!") + + # Send invite + await self.ws.send(json.dumps({ + "command": "invite", + "option": {"codec": "pcm", "sampleRate": self.sample_rate} + })) + print("-> invite") + + async def send_chat(self, text: str): + """Send chat message.""" + # Reset TTFB tracking for new request + self.request_start_time = time.time() + 
self.first_audio_received = False + + await self.ws.send(json.dumps({"command": "chat", "text": text})) + print(f"-> chat: {text}") + + def play_audio(self, audio_data: bytes): + """Play audio data immediately.""" + if len(audio_data) == 0: + return + + if PYAUDIO_AVAILABLE: + # Use PyAudio - more reliable on Windows + if self.stream is None: + self.stream = self.pa.open( + format=pyaudio.paInt16, + channels=1, + rate=self.sample_rate, + output=True, + frames_per_buffer=1024 + ) + self.stream.write(audio_data) + elif SD_AVAILABLE: + # Use sounddevice + samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32767.0 + sd.play(samples, self.sample_rate, blocking=True) + + async def receive_loop(self): + """Receive and play audio.""" + print("\nWaiting for response...") + + while self.running: + try: + msg = await asyncio.wait_for(self.ws.recv(), timeout=0.1) + + if isinstance(msg, bytes): + # Audio data + self.bytes_received += len(msg) + duration_ms = len(msg) / (self.sample_rate * 2) * 1000 + + # Check if we should discard this audio (after interrupt) + if self._discard_audio: + print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms) [DISCARDED]") + continue + + # Calculate and display TTFB for first audio packet + if not self.first_audio_received and self.request_start_time: + client_ttfb_ms = (time.time() - self.request_start_time) * 1000 + self.first_audio_received = True + print(f"<- [TTFB] Client first audio latency: {client_ttfb_ms:.0f}ms") + + print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms)") + + # Play immediately in executor to not block + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self.play_audio, msg) + else: + # JSON event + event = json.loads(msg) + etype = event.get("event", "?") + + if etype == "transcript": + # User speech transcription + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + print(f"<- You said: {text}") + else: + print(f"<- [listening] {text}", 
end="\r") + elif etype == "ttfb": + # Server-side TTFB event + latency_ms = event.get("latencyMs", 0) + print(f"<- [TTFB] Server reported latency: {latency_ms}ms") + elif etype == "trackStart": + # New track starting - accept audio again + self._discard_audio = False + print(f"<- {etype}") + elif etype == "interrupt": + # Interrupt - discard audio until next trackStart + self._discard_audio = True + print(f"<- {etype} (discarding audio until new track)") + elif etype == "hangup": + print(f"<- {etype}") + self.running = False + break + else: + print(f"<- {etype}") + + except asyncio.TimeoutError: + continue + except websockets.ConnectionClosed: + print("Connection closed") + self.running = False + break + + async def run(self, text: str = None): + """Run the client.""" + try: + await self.connect() + await asyncio.sleep(0.5) + + # Start receiver + recv_task = asyncio.create_task(self.receive_loop()) + + if text: + await self.send_chat(text) + # Wait for response + await asyncio.sleep(30) + else: + # Interactive mode + print("\nType a message and press Enter (or 'quit' to exit):") + while self.running: + try: + user_input = await asyncio.get_event_loop().run_in_executor( + None, input, "> " + ) + if user_input.lower() == 'quit': + break + if user_input.strip(): + await self.send_chat(user_input) + except EOFError: + break + + self.running = False + recv_task.cancel() + try: + await recv_task + except asyncio.CancelledError: + pass + + finally: + await self.close() + + async def close(self): + """Close connections.""" + self.running = False + + if PYAUDIO_AVAILABLE: + if self.stream: + self.stream.stop_stream() + self.stream.close() + self.pa.terminate() + + if self.ws: + await self.ws.close() + + print(f"\nTotal audio received: {self.bytes_received / 1024:.1f} KB") + + +def list_audio_devices(): + """List available audio devices.""" + print("\n=== Audio Devices ===") + + if PYAUDIO_AVAILABLE: + pa = pyaudio.PyAudio() + print("\nPyAudio devices:") + for i in 
range(pa.get_device_count()): + info = pa.get_device_info_by_index(i) + if info['maxOutputChannels'] > 0: + default = " [DEFAULT]" if i == pa.get_default_output_device_info()['index'] else "" + print(f" {i}: {info['name']}{default}") + pa.terminate() + + if SD_AVAILABLE: + print("\nSounddevice devices:") + for i, d in enumerate(sd.query_devices()): + if d['max_output_channels'] > 0: + default = " [DEFAULT]" if i == sd.default.device[1] else "" + print(f" {i}: {d['name']}{default}") + + +async def main(): + parser = argparse.ArgumentParser(description="Simple voice client") + parser.add_argument("--url", default="ws://localhost:8000/ws") + parser.add_argument("--text", help="Send text and play response") + parser.add_argument("--list-devices", action="store_true") + parser.add_argument("--sample-rate", type=int, default=16000) + + args = parser.parse_args() + + if args.list_devices: + list_audio_devices() + return + + client = SimpleVoiceClient(args.url, args.sample_rate) + await client.run(args.text) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/test_websocket.py b/examples/test_websocket.py new file mode 100644 index 0000000..0d2675d --- /dev/null +++ b/examples/test_websocket.py @@ -0,0 +1,176 @@ +"""WebSocket endpoint test client. + +Tests the /ws endpoint with sine wave or file audio streaming. 
+Based on reference/py-active-call/exec/test_ws_endpoint/test_ws.py +""" + +import asyncio +import aiohttp +import json +import struct +import math +import argparse +import os +from datetime import datetime + +# Configuration +SERVER_URL = "ws://localhost:8000/ws" +SAMPLE_RATE = 16000 +FREQUENCY = 440 # 440Hz Sine Wave +CHUNK_DURATION_MS = 20 +# 16kHz * 16-bit (2 bytes) * 20ms = 640 bytes per chunk +CHUNK_SIZE_BYTES = int(SAMPLE_RATE * 2 * (CHUNK_DURATION_MS / 1000.0)) + + +def generate_sine_wave(duration_ms=1000): + """Generates sine wave audio (16kHz mono PCM 16-bit).""" + num_samples = int(SAMPLE_RATE * (duration_ms / 1000.0)) + audio_data = bytearray() + + for x in range(num_samples): + # Generate sine wave sample + value = int(32767.0 * math.sin(2 * math.pi * FREQUENCY * x / SAMPLE_RATE)) + # Pack as little-endian 16-bit integer + audio_data.extend(struct.pack(' None: + """Connect to WebSocket server.""" + self.log_event("→", f"Connecting to {self.url}...") + self.ws = await websockets.connect(self.url) + self.running = True + self.log_event("←", "Connected!") + + # Send invite command + await self.send_command({ + "command": "invite", + "option": { + "codec": "pcm", + "sampleRate": self.sample_rate + } + }) + + async def send_command(self, cmd: dict) -> None: + """Send JSON command to server.""" + if self.ws: + await self.ws.send(json.dumps(cmd)) + self.log_event("→", f"Command: {cmd.get('command', 'unknown')}") + + async def send_hangup(self, reason: str = "Session complete") -> None: + """Send hangup command.""" + await self.send_command({ + "command": "hangup", + "reason": reason + }) + + def load_wav_file(self) -> tuple[np.ndarray, int]: + """ + Load and prepare WAV file for sending. 
+ + Returns: + Tuple of (audio_data as int16 numpy array, original sample rate) + """ + if not self.input_file.exists(): + raise FileNotFoundError(f"Input file not found: {self.input_file}") + + # Load audio file + audio_data, file_sample_rate = sf.read(self.input_file) + self.log_event("→", f"Loaded: {self.input_file}") + self.log_event("→", f" Original sample rate: {file_sample_rate} Hz") + self.log_event("→", f" Duration: {len(audio_data) / file_sample_rate:.2f}s") + + # Convert stereo to mono if needed + if len(audio_data.shape) > 1: + audio_data = audio_data.mean(axis=1) + self.log_event("→", " Converted stereo to mono") + + # Resample if needed + if file_sample_rate != self.sample_rate: + # Simple resampling using numpy + duration = len(audio_data) / file_sample_rate + num_samples = int(duration * self.sample_rate) + indices = np.linspace(0, len(audio_data) - 1, num_samples) + audio_data = np.interp(indices, np.arange(len(audio_data)), audio_data) + self.log_event("→", f" Resampled to {self.sample_rate} Hz") + + # Convert to int16 + if audio_data.dtype != np.int16: + # Normalize to [-1, 1] if needed + max_val = np.max(np.abs(audio_data)) + if max_val > 1.0: + audio_data = audio_data / max_val + audio_data = (audio_data * 32767).astype(np.int16) + + self.log_event("→", f" Prepared: {len(audio_data)} samples ({len(audio_data)/self.sample_rate:.2f}s)") + + return audio_data, file_sample_rate + + async def audio_sender(self, audio_data: np.ndarray) -> None: + """Send audio data to server in chunks.""" + total_samples = len(audio_data) + chunk_size = self.chunk_samples + sent_samples = 0 + + self.send_start_time = time.time() + self.log_event("→", f"Starting audio transmission ({total_samples} samples)...") + + while sent_samples < total_samples and self.running: + # Get next chunk + end_sample = min(sent_samples + chunk_size, total_samples) + chunk = audio_data[sent_samples:end_sample] + chunk_bytes = chunk.tobytes() + + # Send to server + if self.ws: + await 
self.ws.send(chunk_bytes) + self.bytes_sent += len(chunk_bytes) + + sent_samples = end_sample + + # Progress logging (every 500ms worth of audio) + if self.verbose and sent_samples % (self.sample_rate // 2) == 0: + progress = (sent_samples / total_samples) * 100 + print(f" Sending: {progress:.0f}%", end="\r") + + # Delay to simulate real-time streaming + # Server expects audio at real-time pace for VAD/ASR to work properly + await asyncio.sleep(self.chunk_duration_ms / 1000) + + self.send_completed = True + elapsed = time.time() - self.send_start_time + self.log_event("→", f"Audio transmission complete ({elapsed:.2f}s, {self.bytes_sent/1024:.1f} KB)") + + async def receiver(self) -> None: + """Receive messages from server.""" + try: + while self.running: + try: + message = await asyncio.wait_for(self.ws.recv(), timeout=0.1) + + if isinstance(message, bytes): + # Audio data received + self.bytes_received += len(message) + self.received_audio.extend(message) + + # Calculate TTFB on first audio of each response + if self.waiting_for_first_audio and self.response_start_time is not None: + ttfb_ms = (time.time() - self.response_start_time) * 1000 + self.ttfb_ms = ttfb_ms + self.ttfb_list.append(ttfb_ms) + self.waiting_for_first_audio = False + self.log_event("←", f"[TTFB] First audio latency: {ttfb_ms:.0f}ms") + + # Log progress + duration_ms = len(message) / (self.sample_rate * 2) * 1000 + total_ms = len(self.received_audio) / (self.sample_rate * 2) * 1000 + if self.verbose: + print(f"← Audio: +{duration_ms:.0f}ms (total: {total_ms:.0f}ms)", end="\r") + + else: + # JSON event + event = json.loads(message) + await self._handle_event(event) + + except asyncio.TimeoutError: + continue + except websockets.ConnectionClosed: + self.log_event("←", "Connection closed") + self.running = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + self.log_event("!", f"Receiver error: {e}") + self.running = False + + async def _handle_event(self, event: 
dict) -> None: + """Handle incoming event.""" + event_type = event.get("event", "unknown") + + if event_type == "answer": + self.log_event("←", "Session ready!") + elif event_type == "speaking": + self.log_event("←", "Speech detected") + elif event_type == "silence": + self.log_event("←", "Silence detected") + elif event_type == "transcript": + # ASR transcript (interim = asrDelta-style, final = asrFinal-style) + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + # Clear interim line and print final + print(" " * 80, end="\r") + self.log_event("←", f"→ You: {text}") + else: + # Interim result - show with indicator (overwrite same line, as in mic_client) + display_text = text[:60] + "..." if len(text) > 60 else text + print(f" [listening] {display_text}".ljust(80), end="\r") + elif event_type == "ttfb": + latency_ms = event.get("latencyMs", 0) + self.log_event("←", f"[TTFB] Server latency: {latency_ms}ms") + elif event_type == "llmResponse": + text = event.get("text", "") + is_final = event.get("isFinal", False) + if is_final: + self.log_event("←", f"LLM Response (final): {text[:100]}{'...' 
if len(text) > 100 else ''}") + elif self.verbose: + # Show streaming chunks only in verbose mode + self.log_event("←", f"LLM: {text}") + elif event_type == "trackStart": + self.track_started = True + self.response_start_time = time.time() + self.waiting_for_first_audio = True + self.log_event("←", "Bot started speaking") + elif event_type == "trackEnd": + self.track_ended = True + self.log_event("←", "Bot finished speaking") + elif event_type == "interrupt": + self.log_event("←", "Bot interrupted!") + elif event_type == "error": + self.log_event("!", f"Error: {event.get('error')}") + elif event_type == "hangup": + self.log_event("←", f"Hangup: {event.get('reason')}") + self.running = False + else: + self.log_event("←", f"Event: {event_type}") + + def save_output_wav(self) -> None: + """Save received audio to output WAV file.""" + if not self.received_audio: + self.log_event("!", "No audio received to save") + return + + # Convert bytes to numpy array + audio_data = np.frombuffer(bytes(self.received_audio), dtype=np.int16) + + # Ensure output directory exists + self.output_file.parent.mkdir(parents=True, exist_ok=True) + + # Save using wave module for compatibility + with wave.open(str(self.output_file), 'wb') as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) # 16-bit + wav_file.setframerate(self.sample_rate) + wav_file.writeframes(audio_data.tobytes()) + + duration = len(audio_data) / self.sample_rate + self.log_event("→", f"Saved output: {self.output_file}") + self.log_event("→", f" Duration: {duration:.2f}s ({len(audio_data)} samples)") + self.log_event("→", f" Size: {len(self.received_audio)/1024:.1f} KB") + + async def run(self) -> None: + """Run the WAV file test.""" + try: + # Load input WAV file + audio_data, _ = self.load_wav_file() + + # Connect to server + await self.connect() + + # Wait for answer + await asyncio.sleep(0.5) + + # Start receiver task + receiver_task = asyncio.create_task(self.receiver()) + + # Send audio + await 
self.audio_sender(audio_data) + + # Wait for response + self.log_event("→", f"Waiting {self.wait_time}s for response...") + + wait_start = time.time() + while self.running and (time.time() - wait_start) < self.wait_time: + # Check if track has ended (response complete) + if self.track_ended and self.send_completed: + # Give a little extra time for any remaining audio + await asyncio.sleep(1.0) + break + await asyncio.sleep(0.1) + + # Cleanup + self.running = False + receiver_task.cancel() + + try: + await receiver_task + except asyncio.CancelledError: + pass + + # Save output + self.save_output_wav() + + # Print summary + self._print_summary() + + except FileNotFoundError as e: + print(f"Error: {e}") + sys.exit(1) + except ConnectionRefusedError: + print(f"Error: Could not connect to {self.url}") + print("Make sure the server is running.") + sys.exit(1) + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + finally: + await self.close() + + def _print_summary(self): + """Print session summary.""" + print("\n" + "=" * 50) + print("Session Summary") + print("=" * 50) + print(f" Input file: {self.input_file}") + print(f" Output file: {self.output_file}") + print(f" Bytes sent: {self.bytes_sent / 1024:.1f} KB") + print(f" Bytes received: {self.bytes_received / 1024:.1f} KB") + if self.ttfb_list: + if len(self.ttfb_list) == 1: + print(f" TTFB: {self.ttfb_list[0]:.0f} ms") + else: + print(f" TTFB (per response): {', '.join(f'{t:.0f}ms' for t in self.ttfb_list)}") + if self.received_audio: + duration = len(self.received_audio) / (self.sample_rate * 2) + print(f" Response duration: {duration:.2f}s") + print("=" * 50) + + async def close(self) -> None: + """Close the connection.""" + self.running = False + if self.ws: + try: + await self.ws.close() + except: + pass + + +async def main(): + parser = argparse.ArgumentParser( + description="WAV file client for testing duplex voice conversation" + ) + parser.add_argument( + 
"--input", "-i", + required=True, + help="Input WAV file path" + ) + parser.add_argument( + "--output", "-o", + required=True, + help="Output WAV file path for response" + ) + parser.add_argument( + "--url", + default="ws://localhost:8000/ws", + help="WebSocket server URL (default: ws://localhost:8000/ws)" + ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="Target sample rate for audio (default: 16000)" + ) + parser.add_argument( + "--chunk-duration", + type=int, + default=20, + help="Chunk duration in ms for sending (default: 20)" + ) + parser.add_argument( + "--wait-time", "-w", + type=float, + default=15.0, + help="Time to wait for response after sending (default: 15.0)" + ) + parser.add_argument( + "--verbose", "-v", + action="store_true", + help="Enable verbose output" + ) + + args = parser.parse_args() + + client = WavFileClient( + url=args.url, + input_file=args.input, + output_file=args.output, + sample_rate=args.sample_rate, + chunk_duration_ms=args.chunk_duration, + wait_time=args.wait_time, + verbose=args.verbose + ) + + await client.run() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nInterrupted by user") diff --git a/examples/web_client.html b/examples/web_client.html new file mode 100644 index 0000000..aaeb636 --- /dev/null +++ b/examples/web_client.html @@ -0,0 +1,766 @@ + + + + + + Duplex Voice Web Client + + + +
+
+

Duplex Voice Client

+
Browser client for the WebSocket duplex pipeline. Device selection + event logging.
+
+ +
+
+

Connection

+
+ + +
+
+ + +
+
+
+
+
Disconnected
+
Waiting for connection
+
+
+ +

Devices

+
+
+ + +
+
+ + +
+
+
+ + + +
+ +

Chat

+
+ +
+ + +
+
+
+ +
+
+

Chat History

+
+
+
+

Event Log

+
+
+
+
+ +
+ Output device selection requires HTTPS + a browser that supports setSinkId. + Audio is sent as 16-bit PCM @ 16 kHz, matching examples/mic_client.py. +
+ + + + + + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..924d5fd --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +"""Data Models Package""" diff --git a/models/commands.py b/models/commands.py new file mode 100644 index 0000000..5bcf47e --- /dev/null +++ b/models/commands.py @@ -0,0 +1,143 @@ +"""Protocol command models matching the original active-call API.""" + +from typing import Optional, Dict, Any +from pydantic import BaseModel, Field + + +class InviteCommand(BaseModel): + """Invite command to initiate a call.""" + + command: str = Field(default="invite", description="Command type") + option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options") + + +class AcceptCommand(BaseModel): + """Accept command to accept an incoming call.""" + + command: str = Field(default="accept", description="Command type") + option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options") + + +class RejectCommand(BaseModel): + """Reject command to reject an incoming call.""" + + command: str = Field(default="reject", description="Command type") + reason: str = Field(default="", description="Reason for rejection") + code: Optional[int] = Field(default=None, description="SIP response code") + + +class RingingCommand(BaseModel): + """Ringing command to send ringing response.""" + + command: str = Field(default="ringing", description="Command type") + recorder: Optional[Dict[str, Any]] = Field(default=None, description="Call recording configuration") + early_media: bool = Field(default=False, description="Enable early media") + ringtone: Optional[str] = Field(default=None, description="Custom ringtone URL") + + +class TTSCommand(BaseModel): + """TTS command to convert text to speech.""" + + command: str = Field(default="tts", description="Command type") + text: str = Field(..., description="Text to synthesize") + speaker: Optional[str] = Field(default=None, 
description="Speaker voice name") + play_id: Optional[str] = Field(default=None, description="Unique identifier for this TTS session") + auto_hangup: bool = Field(default=False, description="Auto hangup after TTS completion") + streaming: bool = Field(default=False, description="Streaming text input") + end_of_stream: bool = Field(default=False, description="End of streaming input") + wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)") + option: Optional[Dict[str, Any]] = Field(default=None, description="TTS provider specific options") + + +class PlayCommand(BaseModel): + """Play command to play audio from URL.""" + + command: str = Field(default="play", description="Command type") + url: str = Field(..., description="URL of audio file to play") + auto_hangup: bool = Field(default=False, description="Auto hangup after playback") + wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)") + + +class InterruptCommand(BaseModel): + """Interrupt command to interrupt current playback.""" + + command: str = Field(default="interrupt", description="Command type") + graceful: bool = Field(default=False, description="Wait for current TTS to complete") + + +class PauseCommand(BaseModel): + """Pause command to pause current playback.""" + + command: str = Field(default="pause", description="Command type") + + +class ResumeCommand(BaseModel): + """Resume command to resume paused playback.""" + + command: str = Field(default="resume", description="Command type") + + +class HangupCommand(BaseModel): + """Hangup command to end the call.""" + + command: str = Field(default="hangup", description="Command type") + reason: Optional[str] = Field(default=None, description="Reason for hangup") + initiator: Optional[str] = Field(default=None, description="Who initiated the hangup") + + +class HistoryCommand(BaseModel): + """History command to add conversation history.""" + + command: 
str = Field(default="history", description="Command type") + speaker: str = Field(..., description="Speaker identifier") + text: str = Field(..., description="Conversation text") + + +class ChatCommand(BaseModel): + """Chat command for text-based conversation.""" + + command: str = Field(default="chat", description="Command type") + text: str = Field(..., description="Chat text message") + + +# Command type mapping +COMMAND_TYPES = { + "invite": InviteCommand, + "accept": AcceptCommand, + "reject": RejectCommand, + "ringing": RingingCommand, + "tts": TTSCommand, + "play": PlayCommand, + "interrupt": InterruptCommand, + "pause": PauseCommand, + "resume": ResumeCommand, + "hangup": HangupCommand, + "history": HistoryCommand, + "chat": ChatCommand, +} + + +def parse_command(data: Dict[str, Any]) -> BaseModel: + """ + Parse a command from JSON data. + + Args: + data: JSON data as dictionary + + Returns: + Parsed command model + + Raises: + ValueError: If command type is unknown + """ + command_type = data.get("command") + + if not command_type: + raise ValueError("Missing 'command' field") + + command_class = COMMAND_TYPES.get(command_type) + + if not command_class: + raise ValueError(f"Unknown command type: {command_type}") + + return command_class(**data) diff --git a/models/config.py b/models/config.py new file mode 100644 index 0000000..009411e --- /dev/null +++ b/models/config.py @@ -0,0 +1,126 @@ +"""Configuration models for call options.""" + +from typing import Optional, Dict, Any, List +from pydantic import BaseModel, Field + + +class VADOption(BaseModel): + """Voice Activity Detection configuration.""" + + type: str = Field(default="silero", description="VAD algorithm type (silero, webrtc)") + samplerate: int = Field(default=16000, description="Audio sample rate for VAD") + speech_padding: int = Field(default=250, description="Speech padding in milliseconds") + silence_padding: int = Field(default=100, description="Silence padding in milliseconds") + ratio: 
class ASROption(BaseModel):
    """Automatic Speech Recognition configuration."""

    provider: str = Field(..., description="ASR provider (tencent, aliyun, openai, etc.)")
    language: Optional[str] = Field(default=None, description="Language code (zh-CN, en-US)")
    app_id: Optional[str] = Field(default=None, description="Application ID")
    secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
    secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
    model_type: Optional[str] = Field(default=None, description="ASR model type (16k_zh, 8k_en)")
    buffer_size: Optional[int] = Field(default=None, description="Audio buffer size in bytes")
    samplerate: Optional[int] = Field(default=None, description="Audio sample rate")
    endpoint: Optional[str] = Field(default=None, description="Custom ASR service endpoint")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
    start_when_answer: bool = Field(default=False, description="Start ASR when call is answered")


class TTSOption(BaseModel):
    """Text-to-Speech configuration."""

    samplerate: Optional[int] = Field(default=None, description="TTS output sample rate")
    provider: str = Field(default="msedge", description="TTS provider (tencent, aliyun, deepgram, msedge)")
    speed: float = Field(default=1.0, description="Speech speed multiplier")
    app_id: Optional[str] = Field(default=None, description="Application ID")
    secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
    secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
    volume: Optional[int] = Field(default=None, description="Speech volume level (1-10)")
    speaker: Optional[str] = Field(default=None, description="Voice speaker name")
    codec: Optional[str] = Field(default=None, description="Audio codec")
    subtitle: bool = Field(default=False, description="Enable subtitle generation")
    emotion: Optional[str] = Field(default=None, description="Speech emotion")
    endpoint: Optional[str] = Field(default=None, description="Custom TTS service endpoint")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
    max_concurrent_tasks: Optional[int] = Field(default=None, description="Max concurrent tasks")


class RecorderOption(BaseModel):
    """Call recording configuration."""

    recorder_file: str = Field(..., description="Path to recording file")
    samplerate: int = Field(default=16000, description="Recording sample rate")
    ptime: int = Field(default=200, description="Packet time in milliseconds")


class MediaPassOption(BaseModel):
    """Media pass-through configuration for external audio processing."""

    url: str = Field(..., description="WebSocket URL for media streaming")
    input_sample_rate: int = Field(default=16000, description="Sample rate of audio received from WebSocket")
    output_sample_rate: int = Field(default=16000, description="Sample rate of audio sent to WebSocket")
    packet_size: int = Field(default=2560, description="Packet size in bytes")
    ptime: Optional[int] = Field(default=None, description="Buffered playback period in milliseconds")


class SipOption(BaseModel):
    """SIP protocol configuration."""

    username: Optional[str] = Field(default=None, description="SIP username")
    password: Optional[str] = Field(default=None, description="SIP password")
    realm: Optional[str] = Field(default=None, description="SIP realm/domain")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional SIP headers")


class HandlerRule(BaseModel):
    """Handler routing rule."""

    caller: Optional[str] = Field(default=None, description="Caller pattern (regex)")
    callee: Optional[str] = Field(default=None, description="Callee pattern (regex)")
    playbook: Optional[str] = Field(default=None, description="Playbook file path")
    webhook: Optional[str] = Field(default=None, description="Webhook URL")


class CallOption(BaseModel):
    """Comprehensive call configuration options."""

    # Basic options
    denoise: bool = Field(default=False, description="Enable noise reduction")
    offer: Optional[str] = Field(default=None, description="SDP offer string")
    callee: Optional[str] = Field(default=None, description="Callee SIP URI or phone number")
    caller: Optional[str] = Field(default=None, description="Caller SIP URI or phone number")

    # Audio codec
    codec: str = Field(default="pcm", description="Audio codec (pcm, pcma, pcmu, g722)")

    # Component configurations
    recorder: Optional[RecorderOption] = Field(default=None, description="Call recording config")
    asr: Optional[ASROption] = Field(default=None, description="ASR configuration")
    vad: Optional[VADOption] = Field(default=None, description="VAD configuration")
    tts: Optional[TTSOption] = Field(default=None, description="TTS configuration")
    media_pass: Optional[MediaPassOption] = Field(default=None, description="Media pass-through config")
    sip: Optional[SipOption] = Field(default=None, description="SIP configuration")

    # Timeouts and networking
    handshake_timeout: Optional[int] = Field(default=None, description="Handshake timeout in seconds")
    enable_ipv6: bool = Field(default=False, description="Enable IPv6 support")
    inactivity_timeout: Optional[int] = Field(default=None, description="Inactivity timeout in seconds")

    # EOU configuration
    eou: Optional[Dict[str, Any]] = Field(default=None, description="End of utterance detection config")

    # Extra parameters
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional custom parameters")

    class Config:
        # Allow population by field name as well as alias.
        populate_by_name = True


# ===== file: models/events.py =====
"""Protocol event models matching the original active-call API."""

from typing import Optional, Dict, Any
from pydantic import BaseModel, Field
from datetime import datetime


def current_timestamp_ms() -> int:
    """Get current timestamp in milliseconds."""
    # NOTE(review): datetime.now() is the local wall clock, not UTC — confirm
    # downstream consumers expect that.
    return int(datetime.now().timestamp() * 1000)


# Base Event Model
class BaseEvent(BaseModel):
    """Base event model."""

    event: str = Field(..., description="Event type")
    track_id: str = Field(..., description="Unique track identifier")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")


# Lifecycle Events
class IncomingEvent(BaseEvent):
    """Incoming call event (SIP only)."""

    event: str = Field(default="incoming", description="Event type")
    caller: Optional[str] = Field(default=None, description="Caller's SIP URI")
    callee: Optional[str] = Field(default=None, description="Callee's SIP URI")
    sdp: Optional[str] = Field(default=None, description="SDP offer from caller")


class AnswerEvent(BaseEvent):
    """Call answered event."""

    event: str = Field(default="answer", description="Event type")
    sdp: Optional[str] = Field(default=None, description="SDP answer from server")


class RejectEvent(BaseEvent):
    """Call rejected event."""

    event: str = Field(default="reject", description="Event type")
    reason: Optional[str] = Field(default=None, description="Rejection reason")
    code: Optional[int] = Field(default=None, description="SIP response code")
class RingingEvent(BaseEvent):
    """Call ringing event."""

    event: str = Field(default="ringing", description="Event type")
    early_media: bool = Field(default=False, description="Early media available")


class HangupEvent(BaseModel):
    """Call hangup event.

    Deliberately derived from BaseModel (not BaseEvent): a hangup summary
    has no track_id.
    """

    event: str = Field(default="hangup", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    reason: Optional[str] = Field(default=None, description="Hangup reason")
    initiator: Optional[str] = Field(default=None, description="Who initiated hangup")
    start_time: Optional[str] = Field(default=None, description="Call start time (ISO 8601)")
    hangup_time: Optional[str] = Field(default=None, description="Hangup time (ISO 8601)")
    answer_time: Optional[str] = Field(default=None, description="Answer time (ISO 8601)")
    ringing_time: Optional[str] = Field(default=None, description="Ringing time (ISO 8601)")
    from_: Optional[Dict[str, Any]] = Field(default=None, alias="from", description="Caller info")
    to: Optional[Dict[str, Any]] = Field(default=None, description="Callee info")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")

    class Config:
        # "from" is a Python keyword, so the field is from_ with an alias;
        # allow constructing by either name.
        populate_by_name = True


# VAD Events
class SpeakingEvent(BaseEvent):
    """Speech detected event."""

    event: str = Field(default="speaking", description="Event type")
    start_time: int = Field(default_factory=current_timestamp_ms, description="Speech start time")


class SilenceEvent(BaseEvent):
    """Silence detected event."""

    event: str = Field(default="silence", description="Event type")
    start_time: int = Field(default_factory=current_timestamp_ms, description="Silence start time")
    duration: int = Field(default=0, description="Silence duration in milliseconds")


# AI/ASR Events
class AsrFinalEvent(BaseEvent):
    """ASR final transcription event."""

    event: str = Field(default="asrFinal", description="Event type")
    index: int = Field(..., description="ASR result sequence number")
    start_time: Optional[int] = Field(default=None, description="Speech start time")
    end_time: Optional[int] = Field(default=None, description="Speech end time")
    text: str = Field(..., description="Transcribed text")


class AsrDeltaEvent(BaseEvent):
    """ASR partial transcription event (streaming)."""

    event: str = Field(default="asrDelta", description="Event type")
    index: int = Field(..., description="ASR result sequence number")
    start_time: Optional[int] = Field(default=None, description="Speech start time")
    end_time: Optional[int] = Field(default=None, description="Speech end time")
    text: str = Field(..., description="Partial transcribed text")


class EouEvent(BaseEvent):
    """End of utterance detection event."""

    event: str = Field(default="eou", description="Event type")
    completed: bool = Field(default=True, description="Whether utterance was completed")


# Audio Track Events
class TrackStartEvent(BaseEvent):
    """Audio track start event."""

    event: str = Field(default="trackStart", description="Event type")
    play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")


class TrackEndEvent(BaseEvent):
    """Audio track end event."""

    event: str = Field(default="trackEnd", description="Event type")
    duration: int = Field(..., description="Track duration in milliseconds")
    ssrc: int = Field(..., description="RTP SSRC identifier")
    play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")


class InterruptionEvent(BaseEvent):
    """Playback interruption event."""

    event: str = Field(default="interruption", description="Event type")
    play_id: Optional[str] = Field(default=None, description="Play ID that was interrupted")
    subtitle: Optional[str] = Field(default=None, description="TTS text being played")
    position: Optional[int] = Field(default=None, description="Word index position")
    total_duration: Optional[int] = Field(default=None, description="Total TTS duration")
    current: Optional[int] = Field(default=None, description="Elapsed time when interrupted")


# System Events
class ErrorEvent(BaseEvent):
    """Error event."""

    event: str = Field(default="error", description="Event type")
    sender: str = Field(..., description="Component that generated the error")
    error: str = Field(..., description="Error message")
    code: Optional[int] = Field(default=None, description="Error code")


class MetricsEvent(BaseModel):
    """Performance metrics event (no track_id, hence BaseModel)."""

    event: str = Field(default="metrics", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    key: str = Field(..., description="Metric key")
    duration: int = Field(..., description="Duration in milliseconds")
    data: Optional[Dict[str, Any]] = Field(default=None, description="Additional metric data")


class AddHistoryEvent(BaseModel):
    """Conversation history entry added event (no track_id, hence BaseModel)."""

    event: str = Field(default="addHistory", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    sender: Optional[str] = Field(default=None, description="Component that added history")
    speaker: str = Field(..., description="Speaker identifier")
    text: str = Field(..., description="Conversation text")


class DTMFEvent(BaseEvent):
    """DTMF tone detected event."""

    event: str = Field(default="dtmf", description="Event type")
    digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)")


class HeartBeatEvent(BaseModel):
    """Server-to-client heartbeat to keep connection alive."""

    event: str = Field(default="heartBeat", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")


# Event type mapping: wire-level "event" value -> pydantic model class.
EVENT_TYPES = {
    "incoming": IncomingEvent,
    "answer": AnswerEvent,
    "reject": RejectEvent,
    "ringing": RingingEvent,
    "hangup": HangupEvent,
    "speaking": SpeakingEvent,
    "silence": SilenceEvent,
    "asrFinal": AsrFinalEvent,
    "asrDelta": AsrDeltaEvent,
    "eou": EouEvent,
    "trackStart": TrackStartEvent,
    "trackEnd": TrackEndEvent,
    "interruption": InterruptionEvent,
    "error": ErrorEvent,
    "metrics": MetricsEvent,
    "addHistory": AddHistoryEvent,
    "dtmf": DTMFEvent,
    "heartBeat": HeartBeatEvent,
}


def create_event(event_type: str, **kwargs) -> BaseModel:
    """
    Create an event model.

    Args:
        event_type: Type of event to create
        **kwargs: Event fields

    Returns:
        Event model instance

    Raises:
        ValueError: If event type is unknown
    """
    event_class = EVENT_TYPES.get(event_type)

    if not event_class:
        raise ValueError(f"Unknown event type: {event_type}")

    # Bug fix: callers commonly forward a whole payload dict which already
    # contains an "event" key (e.g. create_event(data["event"], **data));
    # passing it through alongside the explicit event=event_type below raised
    # TypeError("got multiple values for keyword argument 'event'").
    kwargs.pop("event", None)

    return event_class(event=event_type, **kwargs)


# ===== file: models/ws_v1.py =====
"""WS v1 protocol message models and helpers."""

from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field


def now_ms() -> int:
    """Current unix timestamp in milliseconds."""
    import time

    return int(time.time() * 1000)


# Client -> Server messages
class HelloMessage(BaseModel):
    """Initial client handshake message."""

    type: Literal["hello"]
    version: str = Field(..., description="Protocol version, currently v1")
    auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
{'apiKey': '...'}") + + +class SessionStartMessage(BaseModel): + type: Literal["session.start"] + audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata") + metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata") + + +class SessionStopMessage(BaseModel): + type: Literal["session.stop"] + reason: Optional[str] = None + + +class InputTextMessage(BaseModel): + type: Literal["input.text"] + text: str + + +class ResponseCancelMessage(BaseModel): + type: Literal["response.cancel"] + graceful: bool = False + + +class ToolCallResultsMessage(BaseModel): + type: Literal["tool_call.results"] + results: list[Dict[str, Any]] = Field(default_factory=list) + + +CLIENT_MESSAGE_TYPES = { + "hello": HelloMessage, + "session.start": SessionStartMessage, + "session.stop": SessionStopMessage, + "input.text": InputTextMessage, + "response.cancel": ResponseCancelMessage, + "tool_call.results": ToolCallResultsMessage, +} + + +def parse_client_message(data: Dict[str, Any]) -> BaseModel: + """Parse and validate a WS v1 client message.""" + msg_type = data.get("type") + if not msg_type: + raise ValueError("Missing 'type' field") + msg_class = CLIENT_MESSAGE_TYPES.get(msg_type) + if not msg_class: + raise ValueError(f"Unknown client message type: {msg_type}") + return msg_class(**data) + + +# Server -> Client event helpers +def ev(event_type: str, **payload: Any) -> Dict[str, Any]: + """Create a WS v1 server event payload.""" + base = {"type": event_type, "timestamp": now_ms()} + base.update(payload) + return base diff --git a/processors/__init__.py b/processors/__init__.py new file mode 100644 index 0000000..1952777 --- /dev/null +++ b/processors/__init__.py @@ -0,0 +1,6 @@ +"""Audio Processors Package""" + +from processors.eou import EouDetector +from processors.vad import SileroVAD, VADProcessor + +__all__ = ["EouDetector", "SileroVAD", "VADProcessor"] diff --git a/processors/eou.py b/processors/eou.py new 
# ===== file: processors/eou.py =====
"""End-of-Utterance Detection."""

import time
from typing import Optional


class EouDetector:
    """Turns a stream of VAD labels into at most one EOU signal per turn.

    The detector fires only after the silence that follows valid speech has
    lasted for ``silence_threshold_ms``. Because any resumed speech cancels
    the pending silence countdown, brief pauses inside one utterance never
    fire — exactly one EOU is emitted when the user truly stops talking.
    """

    def __init__(self, silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize EOU detector.

        Args:
            silence_threshold_ms: How long silence must last to trigger EOU (default 1000ms)
            min_speech_duration_ms: Minimum speech duration to consider valid (default 250ms)
        """
        # Thresholds are kept in seconds for comparison against time.time().
        self.threshold = silence_threshold_ms / 1000.0
        self.min_speech = min_speech_duration_ms / 1000.0
        self._silence_threshold_ms = silence_threshold_ms
        self._min_speech_duration_ms = min_speech_duration_ms

        # Per-turn state.
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time: Optional[float] = None
        self.triggered = False

    def process(self, vad_status: str, force_eligible: bool = False) -> bool:
        """Consume one VAD label and report whether an EOU just occurred.

        Args:
            vad_status: "Speech" or "Silence" (anything else is a no-op).
            force_eligible: Bypass the minimum-speech-duration gate.

        Returns:
            True exactly once per turn, when the trailing silence has lasted
            long enough; False otherwise.
        """
        moment = time.time()

        if vad_status == "Speech":
            if not self.is_speaking:
                # A fresh turn begins: remember when it started and re-arm.
                self.is_speaking = True
                self.speech_start_time = moment
                self.triggered = False
            # Speech always cancels any pending silence countdown, so a short
            # pause followed by more speech stays a single utterance.
            self.silence_start_time = None
            return False

        if vad_status != "Silence" or not self.is_speaking:
            # Unknown label, or silence with no active turn — nothing to do.
            return False

        if self.silence_start_time is None:
            self.silence_start_time = moment

        spoken_for = self.silence_start_time - self.speech_start_time
        if spoken_for < self.min_speech and not force_eligible:
            # Too little speech to count as a real utterance; drop the turn.
            self.is_speaking = False
            self.silence_start_time = None
            return False

        quiet_for = moment - self.silence_start_time
        if quiet_for >= self.threshold and not self.triggered:
            # The user has genuinely stopped: fire once and close the turn.
            self.triggered = True
            self.is_speaking = False
            self.silence_start_time = None
            return True

        return False

    def reset(self) -> None:
        """Reset EOU detector state."""
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time = None
        self.triggered = False
# ===== file: processors/tracks.py =====
"""Audio track processing for WebRTC."""

import asyncio
import fractions
from typing import Optional
from loguru import logger

# Try to import aiortc (optional for WebRTC functionality)
try:
    from aiortc import AudioStreamTrack
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    AudioStreamTrack = object  # Dummy class for type hints

# Try to import PyAV (optional for audio resampling)
try:
    from av import AudioFrame, AudioResampler
    AV_AVAILABLE = True
except ImportError:
    AV_AVAILABLE = False
    # Create dummy classes for type hints
    class AudioFrame:
        pass
    class AudioResampler:
        pass

import numpy as np


class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Audio track that resamples input to 16kHz mono PCM.

    Wraps an existing MediaStreamTrack and converts its output
    to 16kHz mono 16-bit PCM format for the pipeline.
    """

    def __init__(self, track, target_sample_rate: int = 16000):
        """
        Initialize resampled track.

        Args:
            track: Source MediaStreamTrack
            target_sample_rate: Target sample rate (default: 16000)
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used")

        super().__init__()
        self.track = track
        self.target_sample_rate = target_sample_rate

        if AV_AVAILABLE:
            self.resampler = AudioResampler(
                format="s16",
                layout="mono",
                rate=target_sample_rate
            )
        else:
            logger.warning("PyAV not available, audio resampling disabled")
            self.resampler = None

        # FIFO of frames emitted by the resampler but not yet delivered by
        # recv(). Needed because PyAV >= 9 may emit zero or several frames
        # per input frame.
        self._pending = []
        self._closed = False

    async def recv(self):
        """
        Receive and resample the next audio frame.

        Bug fix: PyAV >= 9 returns a *list* of frames from
        AudioResampler.resample() (possibly empty while the resampler is
        buffering input); the previous code assumed a single frame object and
        set .sample_rate on it, which raised AttributeError on every frame
        with modern PyAV (requirements pin av>=12). Frames are now queued and
        delivered one per call. The manual sample_rate override was dropped:
        the resampler already emits frames at the requested rate.

        Returns:
            Resampled AudioFrame at 16kHz mono (source frame unchanged when
            PyAV is unavailable).
        """
        if self._closed:
            raise RuntimeError("Track is closed")

        if not (AV_AVAILABLE and self.resampler):
            # No resampler available: pass source frames through unchanged.
            return await self.track.recv()

        while not self._pending:
            frame = await self.track.recv()
            resampled = self.resampler.resample(frame)
            if resampled is None:
                # Older PyAV could yield None while buffering.
                continue
            if not isinstance(resampled, list):
                # Older PyAV returned a bare frame; normalize to a list.
                resampled = [resampled]
            self._pending.extend(resampled)

        return self._pending.pop(0)

    async def stop(self) -> None:
        """Stop the track and cleanup resources."""
        self._closed = True
        self._pending.clear()
        if hasattr(self, 'resampler') and self.resampler:
            del self.resampler
        logger.debug("Resampled track stopped")


class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Synthetic audio track that generates a sine wave.

    Useful for testing without requiring real audio input.
    """

    def __init__(self, sample_rate: int = 16000, frequency: int = 440):
        """
        Initialize sine wave track.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            frequency: Sine wave frequency in Hz (default: 440)
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - SineWaveTrack cannot be used")

        super().__init__()
        self.sample_rate = sample_rate
        self.frequency = frequency
        self.counter = 0  # running sample count, doubles as pts
        self._stopped = False

    async def recv(self):
        """
        Generate next audio frame with sine wave.

        Returns:
            AudioFrame with sine wave data (a plain dict when PyAV is absent)
        """
        if self._stopped:
            raise RuntimeError("Track is stopped")

        # Generate 20ms of audio
        samples = int(self.sample_rate * 0.02)
        pts = self.counter
        time_base = fractions.Fraction(1, self.sample_rate)

        # Time axis continues from where the previous frame ended so the
        # waveform is phase-continuous across frames.
        t = np.linspace(
            self.counter / self.sample_rate,
            (self.counter + samples) / self.sample_rate,
            samples,
            endpoint=False
        )

        # Generate sine wave (Int16 PCM, half amplitude to avoid clipping)
        data = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16)

        # Update counter
        self.counter += samples

        # Create AudioFrame if AV is available
        if AV_AVAILABLE:
            frame = AudioFrame.from_ndarray(data.reshape(1, -1), format='s16', layout='mono')
            frame.pts = pts
            frame.time_base = time_base
            frame.sample_rate = self.sample_rate
            return frame
        else:
            # Return simple data structure if AV is not available
            return {
                'data': data,
                'sample_rate': self.sample_rate,
                'pts': pts,
                'time_base': time_base
            }

    def stop(self) -> None:
        """Stop the track."""
        self._stopped = True


# ===== file: processors/vad.py =====
"""Voice Activity Detection using Silero VAD."""

import asyncio
import os
from typing import Tuple, Optional
import numpy as np
from loguru import logger


# Try to import onnxruntime (optional for VAD functionality)
try:
    import onnxruntime as ort
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False
    ort = None
    logger.warning("onnxruntime not available - VAD will be disabled")
import time  # monotonic timestamps for VADProcessor state transitions


class SileroVAD:
    """
    Voice Activity Detection using Silero VAD model.

    Detects speech in audio chunks using the Silero VAD ONNX model and
    returns "Speech" or "Silence" for each audio chunk. When the model (or
    onnxruntime) is unavailable, process_audio() falls back to an adaptive
    energy-based detector.
    """

    def __init__(self, model_path: str = "data/vad/silero_vad.onnx", sample_rate: int = 16000):
        """
        Initialize Silero VAD.

        Args:
            model_path: Path to Silero VAD ONNX model
            sample_rate: Audio sample rate (must be 16kHz for Silero VAD)
        """
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.session = None

        # Bug fix: this state was previously initialized only after a
        # successful model load; the early returns below then left the
        # instance without buffer/_energy_noise_floor/last_label, so the
        # advertised energy fallback (and reset()) crashed with
        # AttributeError whenever the model file was missing or onnxruntime
        # was not installed. Initialize unconditionally, up front.
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.min_chunk_size = 512  # Silero consumes exactly 512-sample windows
        self.last_label = "Silence"
        self.last_probability = 0.0
        self._energy_noise_floor = 1e-4

        # Check if model exists
        if not os.path.exists(model_path):
            logger.warning(f"VAD model not found at {model_path}. VAD will be disabled.")
            return

        # Check if onnxruntime is available
        if not ONNX_AVAILABLE:
            logger.warning("onnxruntime not available - VAD will be disabled")
            return

        # Load ONNX model
        try:
            self.session = ort.InferenceSession(model_path)
            logger.info(f"Loaded Silero VAD model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load VAD model: {e}")
            self.session = None

    def _reset_state(self):
        # Silero VAD V4+ expects state shape [2, 1, 128]
        self._state = np.zeros((2, 1, 128), dtype=np.float32)
        self._sr = np.array([self.sample_rate], dtype=np.int64)

    def process_audio(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Tuple[str, float]:
        """
        Process audio chunk and detect speech.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
            chunk_size_ms: Chunk duration in milliseconds (ignored for buffering logic)

        Returns:
            Tuple of (label, probability) where label is "Speech" or "Silence"
        """
        if self.session is None or not ONNX_AVAILABLE:
            # Fallback energy-based VAD with adaptive noise floor.
            if not pcm_bytes:
                return "Silence", 0.0
            audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
            if audio_int16.size == 0:
                return "Silence", 0.0
            audio_float = audio_int16.astype(np.float32) / 32768.0
            rms = float(np.sqrt(np.mean(audio_float * audio_float)))

            # Update adaptive noise floor (drops quickly, rises slowly so a
            # burst of speech does not inflate the floor).
            if rms < self._energy_noise_floor:
                self._energy_noise_floor = 0.95 * self._energy_noise_floor + 0.05 * rms
            else:
                self._energy_noise_floor = 0.995 * self._energy_noise_floor + 0.005 * rms

            # Compute SNR-like ratio and map to probability
            denom = max(self._energy_noise_floor, 1e-6)
            snr = max(0.0, (rms - denom) / denom)
            probability = min(1.0, snr / 3.0)  # ~3x above noise => strong speech
            label = "Speech" if probability >= 0.5 else "Silence"
            return label, probability

        # Convert bytes to int16 and normalize to float32 in [-1.0, 1.0).
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0

        # Accumulate; the model only accepts fixed 512-sample windows.
        self.buffer = np.concatenate((self.buffer, audio_float))

        # Process all complete 512-sample windows in the buffer; the label of
        # the last processed window wins (and persists across calls that do
        # not complete a window).
        while len(self.buffer) >= self.min_chunk_size:
            chunk = self.buffer[:self.min_chunk_size]
            self.buffer = self.buffer[self.min_chunk_size:]

            # Input tensor shape: [batch, samples] -> [1, 512]
            input_tensor = chunk.reshape(1, -1)

            try:
                ort_inputs = {
                    'input': input_tensor,
                    'state': self._state,
                    'sr': self._sr
                }

                # Outputs: probability, recurrent state
                out, self._state = self.session.run(None, ort_inputs)

                self.last_probability = float(out[0][0])
                self.last_label = "Speech" if self.last_probability >= 0.5 else "Silence"

            except Exception as e:
                logger.error(f"VAD inference error: {e}")
                # Best-effort diagnostics: report the input names the model
                # actually expects.
                try:
                    inputs = [x.name for x in self.session.get_inputs()]
                    logger.error(f"Model expects inputs: {inputs}")
                except Exception:  # was a bare except; don't mask SystemExit/KeyboardInterrupt
                    pass
                # Fail open: treat the chunk as speech so a broken model
                # never silently mutes the user mid-utterance.
                return "Speech", 1.0

        return self.last_label, self.last_probability

    def reset(self) -> None:
        """Reset VAD internal state."""
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.last_label = "Silence"
        self.last_probability = 0.0


class VADProcessor:
    """
    High-level VAD processor with state management.

    Tracks speech/silence state and emits events on transitions.
    """

    def __init__(self, vad_model: SileroVAD, threshold: float = 0.5):
        """
        Initialize VAD processor.

        Args:
            vad_model: Silero VAD model instance (any object exposing
                process_audio()/reset() works)
            threshold: Speech detection threshold
        """
        self.vad = vad_model
        self.threshold = threshold
        self.is_speaking = False
        self.speech_start_time: Optional[float] = None
        self.silence_start_time: Optional[float] = None

    def process(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Optional[Tuple[str, float]]:
        """
        Process audio chunk and detect state changes.

        Args:
            pcm_bytes: PCM audio data
            chunk_size_ms: Chunk duration in milliseconds

        Returns:
            Tuple of (event_type, probability) if state changed, None otherwise
        """
        label, probability = self.vad.process_audio(pcm_bytes, chunk_size_ms)

        # Check if this is speech based on threshold
        is_speech = probability >= self.threshold

        # State transition: Silence -> Speech
        if is_speech and not self.is_speaking:
            self.is_speaking = True
            # time.monotonic() replaces asyncio.get_event_loop().time():
            # calling get_event_loop() outside a running loop is deprecated
            # (Python 3.10+/3.12) and loop.time() is monotonic anyway; the
            # timestamps are only used internally for durations.
            self.speech_start_time = time.monotonic()
            self.silence_start_time = None
            return ("speaking", probability)

        # State transition: Speech -> Silence
        elif not is_speech and self.is_speaking:
            self.is_speaking = False
            self.silence_start_time = time.monotonic()
            self.speech_start_time = None
            return ("silence", probability)

        return None

    def reset(self) -> None:
        """Reset VAD state."""
        self.vad.reset()
        self.is_speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
"https://github.com/yourusername/py-active-call-cc/blob/main/README.md" +Repository = "https://github.com/yourusername/py-active-call-cc.git" +Issues = "https://github.com/yourusername/py-active-call-cc/issues" + +[tool.setuptools.packages.find] +where = ["."] +include = ["app*"] +exclude = ["tests*", "scripts*", "reference*"] + +[tool.black] +line-length = 100 +target-version = ['py311'] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist + | reference +)/ +''' + +[tool.ruff] +line-length = 100 +target-version = "py311" +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by black) + "B008", # do not perform function calls in argument defaults +] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "reference", +] + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] # unused imports + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +check_untyped_defs = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +strict_equality = true +exclude = [ + "venv", + "reference", + "build", + "dist", +] + +[[tool.mypy.overrides]] +module = [ + "aiortc.*", + "av.*", + "onnxruntime.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = "-ra -q --strict-markers --strict-config" +testpaths = ["tests"] +pythonpath = ["."] +asyncio_mode = "auto" +markers = [ + "slow: marks tests as slow (deselect with '-m 
\"not slow\"')", + "integration: marks tests as integration tests", +] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3d38414 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,37 @@ +# Web Framework +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 +websockets>=12.0 +python-multipart>=0.0.6 + +# WebRTC (optional - for WebRTC transport) +aiortc>=1.6.0 + +# Audio Processing +av>=12.1.0 +numpy>=1.26.3 +onnxruntime>=1.16.3 + +# Configuration +pydantic>=2.5.3 +pydantic-settings>=2.1.0 +python-dotenv>=1.0.0 +toml>=0.10.2 + +# Logging +loguru>=0.7.2 + +# HTTP Client +aiohttp>=3.9.1 + +# AI Services - LLM +openai>=1.0.0 + +# AI Services - TTS +edge-tts>=6.1.0 +pydub>=0.25.0 # For audio format conversion + +# Microphone client dependencies +sounddevice>=0.4.6 +soundfile>=0.12.1 +pyaudio>=0.2.13 # More reliable audio on Windows diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..8b6f7a0 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1 @@ +# Development Script \ No newline at end of file diff --git a/scripts/generate_test_audio/generate_test_audio.py b/scripts/generate_test_audio/generate_test_audio.py new file mode 100644 index 0000000..9b37f5f --- /dev/null +++ b/scripts/generate_test_audio/generate_test_audio.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +""" +Generate test audio file with utterances using SiliconFlow TTS API. + +Creates a 16kHz mono WAV file with real speech segments separated by +configurable silence (for VAD/testing). 
+
+Usage:
+    python generate_test_audio.py [OPTIONS]
+
+Options:
+    -o, --output PATH       Output WAV path (default: data/audio_examples/two_utterances_16k.wav)
+    -u, --utterance TEXT    Utterance text; repeat for multiple (ignored if -j is set)
+    -j, --json PATH         JSON file: array of strings or {"utterances": [...]}
+    --silence-ms MS         Silence in ms between utterances (default: 500)
+    --lead-silence-ms MS    Silence in ms at start (default: 200)
+    --trail-silence-ms MS   Silence in ms at end (default: 300)
+
+Examples:
+    # Default utterances and output
+    python generate_test_audio.py
+
+    # Custom output path
+    python generate_test_audio.py -o out.wav
+
+    # Utterances from command line
+    python generate_test_audio.py -u "Hello" -u "World" -o test.wav
+
+    # Utterances from a JSON file
+    python generate_test_audio.py -j utterances.json -o test.wav
+
+    # Custom silence (1s between utterances)
+    python generate_test_audio.py -u "One" -u "Two" --silence-ms 1000 -o test.wav
+
+Requires SILICONFLOW_API_KEY in .env.
+"""
+
+import wave
+import struct
+import argparse
+import asyncio
+import aiohttp
+import json
+import os
+from pathlib import Path
+from dotenv import load_dotenv
+
+
+# Load .env file from project root
+project_root = Path(__file__).parent.parent.parent
+load_dotenv(project_root / ".env")
+
+
+# SiliconFlow TTS Configuration
+SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/audio/speech"
+SILICONFLOW_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
+
+# Available voices
+VOICES = {
+    "alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
+    "anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
+    "bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
+    "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
+    "charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
+    "claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
+    "david": "FunAudioLLM/CosyVoice2-0.5B:david",
+    "diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
+}
+
+
+def generate_silence(duration_ms: int, sample_rate: int = 16000) -> bytes:
+    """Generate silence as PCM bytes."""
+    num_samples = int(sample_rate * (duration_ms / 1000.0))
+    return b'\x00\x00' * num_samples
+
+
+async def synthesize_speech(
+    text: str,
+    api_key: str,
+    voice: str = "anna",
+    sample_rate: int = 16000,
+    speed: float = 1.0
+) -> bytes:
+    """
+    Synthesize speech using SiliconFlow TTS API.
+
+    Args:
+        text: Text to synthesize
+        api_key: SiliconFlow API key
+        voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
+        sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
+        speed: Speech speed (0.25 to 4.0)
+
+    Returns:
+        PCM audio bytes (16-bit signed, little-endian)
+    """
+    # Resolve voice name
+    full_voice = VOICES.get(voice, voice)
+
+    payload = {
+        "model": SILICONFLOW_MODEL,
+        "input": text,
+        "voice": full_voice,
+        "response_format": "pcm",
+        "sample_rate": sample_rate,
+        "stream": False,
+        "speed": speed
+    }
+
+    headers = {
+        "Authorization": f"Bearer {api_key}",
+        "Content-Type": "application/json"
+    }
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(SILICONFLOW_API_URL, json=payload, headers=headers) as response:
+            if response.status != 200:
+                error_text = await response.text()
+                raise RuntimeError(f"SiliconFlow TTS error: {response.status} - {error_text}")
+
+            return await response.read()
+
+
+async def generate_test_audio(
+    output_path: str,
+    utterances: list[str],
+    silence_ms: int = 500,
+    lead_silence_ms: int = 200,
+    trail_silence_ms: int = 300,
+    voice: str = "anna",
+    sample_rate: int = 16000,
+    speed: float = 1.0
+):
+    """
+    Generate test audio with multiple utterances separated by silence. 
+ + Args: + output_path: Path to save the WAV file + utterances: List of text strings for each utterance + silence_ms: Silence duration between utterances (milliseconds) + lead_silence_ms: Silence at the beginning (milliseconds) + trail_silence_ms: Silence at the end (milliseconds) + voice: TTS voice to use + sample_rate: Audio sample rate + speed: TTS speech speed + """ + api_key = os.getenv("SILICONFLOW_API_KEY") + if not api_key: + raise ValueError( + "SILICONFLOW_API_KEY not found in environment.\n" + "Please set it in your .env file:\n" + " SILICONFLOW_API_KEY=your-api-key-here" + ) + + print(f"Using SiliconFlow TTS API") + print(f" Voice: {voice}") + print(f" Sample rate: {sample_rate}Hz") + print(f" Speed: {speed}x") + print() + + segments = [] + + # Lead-in silence + if lead_silence_ms > 0: + segments.append(generate_silence(lead_silence_ms, sample_rate)) + print(f" [silence: {lead_silence_ms}ms]") + + # Generate each utterance with silence between + for i, text in enumerate(utterances): + print(f" Synthesizing utterance {i + 1}: \"{text}\"") + audio = await synthesize_speech( + text=text, + api_key=api_key, + voice=voice, + sample_rate=sample_rate, + speed=speed + ) + segments.append(audio) + + # Add silence between utterances (not after the last one) + if i < len(utterances) - 1: + segments.append(generate_silence(silence_ms, sample_rate)) + print(f" [silence: {silence_ms}ms]") + + # Trail silence + if trail_silence_ms > 0: + segments.append(generate_silence(trail_silence_ms, sample_rate)) + print(f" [silence: {trail_silence_ms}ms]") + + # Concatenate all segments + audio_data = b''.join(segments) + + # Write WAV file + with wave.open(output_path, 'wb') as wf: + wf.setnchannels(1) # Mono + wf.setsampwidth(2) # 16-bit + wf.setframerate(sample_rate) + wf.writeframes(audio_data) + + duration_sec = len(audio_data) / (sample_rate * 2) + print() + print(f"Generated: {output_path}") + print(f" Duration: {duration_sec:.2f}s") + print(f" Sample rate: 
{sample_rate}Hz") + print(f" Format: 16-bit mono PCM WAV") + print(f" Size: {len(audio_data):,} bytes") + + +def load_utterances_from_json(path: Path) -> list[str]: + """ + Load utterances from a JSON file. + + Accepts either: + - A JSON array: ["utterance 1", "utterance 2"] + - A JSON object with "utterances" key: {"utterances": ["a", "b"]} + """ + with open(path, encoding="utf-8") as f: + data = json.load(f) + if isinstance(data, list): + return [str(s) for s in data] + if isinstance(data, dict) and "utterances" in data: + return [str(s) for s in data["utterances"]] + raise ValueError( + f"JSON file must be an array of strings or an object with 'utterances' key. " + f"Got: {type(data).__name__}" + ) + + +def parse_args(): + """Parse command-line arguments.""" + script_dir = Path(__file__).parent + default_output = script_dir.parent / "data" / "audio_examples" / "two_utterances_16k.wav" + + parser = argparse.ArgumentParser(description="Generate test audio with SiliconFlow TTS (utterances + silence).") + parser.add_argument( + "-o", "--output", + type=Path, + default=default_output, + help=f"Output WAV file path (default: {default_output})" + ) + parser.add_argument( + "-u", "--utterance", + action="append", + dest="utterances", + metavar="TEXT", + help="Utterance text (repeat for multiple). Ignored if --json is set." 
+ ) + parser.add_argument( + "-j", "--json", + type=Path, + metavar="PATH", + help="JSON file with utterances: array of strings or object with 'utterances' key" + ) + parser.add_argument( + "--silence-ms", + type=int, + default=500, + metavar="MS", + help="Silence in ms between utterances (default: 500)" + ) + parser.add_argument( + "--lead-silence-ms", + type=int, + default=200, + metavar="MS", + help="Silence in ms at start of file (default: 200)" + ) + parser.add_argument( + "--trail-silence-ms", + type=int, + default=300, + metavar="MS", + help="Silence in ms at end of file (default: 300)" + ) + return parser.parse_args() + + +async def main(): + """Main entry point.""" + args = parse_args() + output_path = args.output + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Resolve utterances: JSON file > -u args > defaults + if args.json is not None: + if not args.json.is_file(): + raise FileNotFoundError(f"Utterances JSON file not found: {args.json}") + utterances = load_utterances_from_json(args.json) + if not utterances: + raise ValueError(f"JSON file has no utterances: {args.json}") + elif args.utterances: + utterances = args.utterances + else: + utterances = [ + "Hello, how are you doing today?", + "I'm doing great, thank you for asking!" + ] + + await generate_test_audio( + output_path=str(output_path), + utterances=utterances, + silence_ms=args.silence_ms, + lead_silence_ms=args.lead_silence_ms, + trail_silence_ms=args.trail_silence_ms, + voice="anna", + sample_rate=16000, + speed=1.0 + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..0bab6b3 --- /dev/null +++ b/services/__init__.py @@ -0,0 +1,51 @@ +"""AI Services package. + +Provides ASR, LLM, TTS, and Realtime API services for voice conversation. 
+""" + +from services.base import ( + ServiceState, + ASRResult, + LLMMessage, + TTSChunk, + BaseASRService, + BaseLLMService, + BaseTTSService, +) +from services.llm import OpenAILLMService, MockLLMService +from services.tts import EdgeTTSService, MockTTSService +from services.asr import BufferedASRService, MockASRService +from services.openai_compatible_asr import OpenAICompatibleASRService, SiliconFlowASRService +from services.openai_compatible_tts import OpenAICompatibleTTSService, SiliconFlowTTSService +from services.streaming_tts_adapter import StreamingTTSAdapter +from services.realtime import RealtimeService, RealtimeConfig, RealtimePipeline + +__all__ = [ + # Base classes + "ServiceState", + "ASRResult", + "LLMMessage", + "TTSChunk", + "BaseASRService", + "BaseLLMService", + "BaseTTSService", + # LLM + "OpenAILLMService", + "MockLLMService", + # TTS + "EdgeTTSService", + "MockTTSService", + # ASR + "BufferedASRService", + "MockASRService", + "OpenAICompatibleASRService", + "SiliconFlowASRService", + # TTS (SiliconFlow) + "OpenAICompatibleTTSService", + "SiliconFlowTTSService", + "StreamingTTSAdapter", + # Realtime + "RealtimeService", + "RealtimeConfig", + "RealtimePipeline", +] diff --git a/services/asr.py b/services/asr.py new file mode 100644 index 0000000..51ab584 --- /dev/null +++ b/services/asr.py @@ -0,0 +1,147 @@ +"""ASR (Automatic Speech Recognition) Service implementations. + +Provides speech-to-text capabilities with streaming support. +""" + +import os +import asyncio +import json +from typing import AsyncIterator, Optional +from loguru import logger + +from services.base import BaseASRService, ASRResult, ServiceState + +# Try to import websockets for streaming ASR +try: + import websockets + WEBSOCKETS_AVAILABLE = True +except ImportError: + WEBSOCKETS_AVAILABLE = False + + +class BufferedASRService(BaseASRService): + """ + Buffered ASR service that accumulates audio and provides + a simple text accumulator for use with EOU detection. 
+ + This is a lightweight implementation that works with the + existing VAD + EOU pattern without requiring external ASR. + """ + + def __init__( + self, + sample_rate: int = 16000, + language: str = "en" + ): + super().__init__(sample_rate=sample_rate, language=language) + + self._audio_buffer: bytes = b"" + self._current_text: str = "" + self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue() + + async def connect(self) -> None: + """No connection needed for buffered ASR.""" + self.state = ServiceState.CONNECTED + logger.info("Buffered ASR service connected") + + async def disconnect(self) -> None: + """Clear buffers on disconnect.""" + self._audio_buffer = b"" + self._current_text = "" + self.state = ServiceState.DISCONNECTED + logger.info("Buffered ASR service disconnected") + + async def send_audio(self, audio: bytes) -> None: + """Buffer audio for later processing.""" + self._audio_buffer += audio + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + """Yield transcription results.""" + while True: + try: + result = await asyncio.wait_for( + self._transcript_queue.get(), + timeout=0.1 + ) + yield result + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + def set_text(self, text: str) -> None: + """ + Set the current transcript text directly. + + This allows external integration (e.g., Whisper, other ASR) + to provide transcripts. 
+ """ + self._current_text = text + result = ASRResult(text=text, is_final=False) + asyncio.create_task(self._transcript_queue.put(result)) + + def get_and_clear_text(self) -> str: + """Get accumulated text and clear buffer.""" + text = self._current_text + self._current_text = "" + self._audio_buffer = b"" + return text + + def get_audio_buffer(self) -> bytes: + """Get accumulated audio buffer.""" + return self._audio_buffer + + def clear_audio_buffer(self) -> None: + """Clear audio buffer.""" + self._audio_buffer = b"" + + +class MockASRService(BaseASRService): + """ + Mock ASR service for testing without actual recognition. + """ + + def __init__(self, sample_rate: int = 16000, language: str = "en"): + super().__init__(sample_rate=sample_rate, language=language) + self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue() + self._mock_texts = [ + "Hello, how are you?", + "That's interesting.", + "Tell me more about that.", + "I understand.", + ] + self._text_index = 0 + + async def connect(self) -> None: + self.state = ServiceState.CONNECTED + logger.info("Mock ASR service connected") + + async def disconnect(self) -> None: + self.state = ServiceState.DISCONNECTED + logger.info("Mock ASR service disconnected") + + async def send_audio(self, audio: bytes) -> None: + """Mock audio processing - generates fake transcripts periodically.""" + pass + + def trigger_transcript(self) -> None: + """Manually trigger a transcript (for testing).""" + text = self._mock_texts[self._text_index % len(self._mock_texts)] + self._text_index += 1 + + result = ASRResult(text=text, is_final=True, confidence=0.95) + asyncio.create_task(self._transcript_queue.put(result)) + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + """Yield transcription results.""" + while True: + try: + result = await asyncio.wait_for( + self._transcript_queue.get(), + timeout=0.1 + ) + yield result + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break 
diff --git a/services/base.py b/services/base.py new file mode 100644 index 0000000..7238416 --- /dev/null +++ b/services/base.py @@ -0,0 +1,253 @@ +"""Base classes for AI services. + +Defines abstract interfaces for ASR, LLM, and TTS services, +inspired by pipecat's service architecture and active-call's +StreamEngine pattern. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import AsyncIterator, Optional, List, Dict, Any, Literal +from enum import Enum + + +class ServiceState(Enum): + """Service connection state.""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + ERROR = "error" + + +@dataclass +class ASRResult: + """ASR transcription result.""" + text: str + is_final: bool = False + confidence: float = 1.0 + language: Optional[str] = None + start_time: Optional[float] = None + end_time: Optional[float] = None + + def __str__(self) -> str: + status = "FINAL" if self.is_final else "PARTIAL" + return f"[{status}] {self.text}" + + +@dataclass +class LLMMessage: + """LLM conversation message.""" + role: str # "system", "user", "assistant", "function" + content: str + name: Optional[str] = None # For function calls + function_call: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to API-compatible dict.""" + d = {"role": self.role, "content": self.content} + if self.name: + d["name"] = self.name + if self.function_call: + d["function_call"] = self.function_call + return d + + +@dataclass +class LLMStreamEvent: + """Structured LLM stream event.""" + + type: Literal["text_delta", "tool_call", "done"] + text: Optional[str] = None + tool_call: Optional[Dict[str, Any]] = None + + +@dataclass +class TTSChunk: + """TTS audio chunk.""" + audio: bytes # PCM audio data + sample_rate: int = 16000 + channels: int = 1 + bits_per_sample: int = 16 + is_final: bool = False + text_offset: Optional[int] = None # Character offset in original text + + +class 
BaseASRService(ABC): + """ + Abstract base class for ASR (Speech-to-Text) services. + + Supports both streaming and non-streaming transcription. + """ + + def __init__(self, sample_rate: int = 16000, language: str = "en"): + self.sample_rate = sample_rate + self.language = language + self.state = ServiceState.DISCONNECTED + + @abstractmethod + async def connect(self) -> None: + """Establish connection to ASR service.""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Close connection to ASR service.""" + pass + + @abstractmethod + async def send_audio(self, audio: bytes) -> None: + """ + Send audio chunk for transcription. + + Args: + audio: PCM audio data (16-bit, mono) + """ + pass + + @abstractmethod + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + """ + Receive transcription results. + + Yields: + ASRResult objects as they become available + """ + pass + + async def transcribe(self, audio: bytes) -> ASRResult: + """ + Transcribe a complete audio buffer (non-streaming). + + Args: + audio: Complete PCM audio data + + Returns: + Final ASRResult + """ + # Default implementation using streaming + await self.send_audio(audio) + async for result in self.receive_transcripts(): + if result.is_final: + return result + return ASRResult(text="", is_final=True) + + +class BaseLLMService(ABC): + """ + Abstract base class for LLM (Language Model) services. + + Supports streaming responses for real-time conversation. + """ + + def __init__(self, model: str = "gpt-4"): + self.model = model + self.state = ServiceState.DISCONNECTED + + @abstractmethod + async def connect(self) -> None: + """Initialize LLM service connection.""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Close LLM service connection.""" + pass + + @abstractmethod + async def generate( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> str: + """ + Generate a complete response. 
+ + Args: + messages: Conversation history + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + + Returns: + Complete response text + """ + pass + + @abstractmethod + async def generate_stream( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> AsyncIterator[LLMStreamEvent]: + """ + Generate response in streaming mode. + + Args: + messages: Conversation history + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + + Yields: + Stream events (text delta/tool call/done) + """ + pass + + +class BaseTTSService(ABC): + """ + Abstract base class for TTS (Text-to-Speech) services. + + Supports streaming audio synthesis for low-latency playback. + """ + + def __init__( + self, + voice: str = "default", + sample_rate: int = 16000, + speed: float = 1.0 + ): + self.voice = voice + self.sample_rate = sample_rate + self.speed = speed + self.state = ServiceState.DISCONNECTED + + @abstractmethod + async def connect(self) -> None: + """Initialize TTS service connection.""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Close TTS service connection.""" + pass + + @abstractmethod + async def synthesize(self, text: str) -> bytes: + """ + Synthesize complete audio for text (non-streaming). + + Args: + text: Text to synthesize + + Returns: + Complete PCM audio data + """ + pass + + @abstractmethod + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + """ + Synthesize audio in streaming mode. + + Args: + text: Text to synthesize + + Yields: + TTSChunk objects as audio is generated + """ + pass + + async def cancel(self) -> None: + """Cancel ongoing synthesis (for barge-in support).""" + pass diff --git a/services/llm.py b/services/llm.py new file mode 100644 index 0000000..a25ff26 --- /dev/null +++ b/services/llm.py @@ -0,0 +1,443 @@ +"""LLM (Large Language Model) Service implementations. 
+ +Provides OpenAI-compatible LLM integration with streaming support +for real-time voice conversation. +""" + +import os +import asyncio +import uuid +from typing import AsyncIterator, Optional, List, Dict, Any +from loguru import logger + +from app.backend_client import search_knowledge_context +from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState + +# Try to import openai +try: + from openai import AsyncOpenAI + OPENAI_AVAILABLE = True +except ImportError: + OPENAI_AVAILABLE = False + logger.warning("openai package not available - LLM service will be disabled") + + +class OpenAILLMService(BaseLLMService): + """ + OpenAI-compatible LLM service. + + Supports streaming responses for low-latency voice conversation. + Works with OpenAI API, Azure OpenAI, and compatible APIs. + """ + + def __init__( + self, + model: str = "gpt-4o-mini", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + system_prompt: Optional[str] = None, + knowledge_config: Optional[Dict[str, Any]] = None, + ): + """ + Initialize OpenAI LLM service. + + Args: + model: Model name (e.g., "gpt-4o-mini", "gpt-4o") + api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + base_url: Custom API base URL (for Azure or compatible APIs) + system_prompt: Default system prompt for conversations + """ + super().__init__(model=model) + + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + self.base_url = base_url or os.getenv("OPENAI_API_URL") + self.system_prompt = system_prompt or ( + "You are a helpful, friendly voice assistant. " + "Keep your responses concise and conversational. " + "Respond naturally as if having a phone conversation." 
+ ) + + self.client: Optional[AsyncOpenAI] = None + self._cancel_event = asyncio.Event() + self._knowledge_config: Dict[str, Any] = knowledge_config or {} + self._tool_schemas: List[Dict[str, Any]] = [] + + _RAG_DEFAULT_RESULTS = 5 + _RAG_MAX_RESULTS = 8 + _RAG_MAX_CONTEXT_CHARS = 4000 + + async def connect(self) -> None: + """Initialize OpenAI client.""" + if not OPENAI_AVAILABLE: + raise RuntimeError("openai package not installed") + + if not self.api_key: + raise ValueError("OpenAI API key not provided") + + self.client = AsyncOpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + self.state = ServiceState.CONNECTED + logger.info(f"OpenAI LLM service connected: model={self.model}") + + async def disconnect(self) -> None: + """Close OpenAI client.""" + if self.client: + await self.client.close() + self.client = None + self.state = ServiceState.DISCONNECTED + logger.info("OpenAI LLM service disconnected") + + def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]: + """Prepare messages list with system prompt.""" + result = [] + + # Add system prompt if not already present + has_system = any(m.role == "system" for m in messages) + if not has_system and self.system_prompt: + result.append({"role": "system", "content": self.system_prompt}) + + # Add all messages + for msg in messages: + result.append(msg.to_dict()) + + return result + + def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None: + """Update runtime knowledge retrieval config.""" + self._knowledge_config = config or {} + + def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None: + """Update runtime tool schemas.""" + self._tool_schemas = [] + if not isinstance(schemas, list): + return + for item in schemas: + if not isinstance(item, dict): + continue + fn = item.get("function") + if isinstance(fn, dict) and fn.get("name"): + self._tool_schemas.append(item) + elif item.get("name"): + self._tool_schemas.append( + { + "type": "function", 
+ "function": { + "name": str(item.get("name")), + "description": str(item.get("description") or ""), + "parameters": item.get("parameters") or {"type": "object", "properties": {}}, + }, + } + ) + + @staticmethod + def _coerce_int(value: Any, default: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + def _resolve_kb_id(self) -> Optional[str]: + cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {} + kb_id = str( + cfg.get("kbId") + or cfg.get("knowledgeBaseId") + or cfg.get("knowledge_base_id") + or "" + ).strip() + return kb_id or None + + def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]: + if not results: + return None + + lines = [ + "You have retrieved the following knowledge base snippets.", + "Use them only when relevant to the latest user request.", + "If snippets are insufficient, say you are not sure instead of guessing.", + "", + ] + + used_chars = 0 + used_count = 0 + for item in results: + content = str(item.get("content") or "").strip() + if not content: + continue + if used_chars >= self._RAG_MAX_CONTEXT_CHARS: + break + + metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {} + doc_id = metadata.get("document_id") + chunk_index = metadata.get("chunk_index") + distance = item.get("distance") + + source_parts = [] + if doc_id: + source_parts.append(f"doc={doc_id}") + if chunk_index is not None: + source_parts.append(f"chunk={chunk_index}") + source = f" ({', '.join(source_parts)})" if source_parts else "" + + distance_text = "" + try: + if distance is not None: + distance_text = f", distance={float(distance):.4f}" + except (TypeError, ValueError): + distance_text = "" + + remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars + snippet = content[:remaining].strip() + if not snippet: + continue + + used_count += 1 + lines.append(f"[{used_count}{source}{distance_text}] {snippet}") + used_chars += len(snippet) + + if used_count 
== 0: + return None + + return "\n".join(lines) + + async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]: + cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {} + enabled = cfg.get("enabled", True) + if isinstance(enabled, str): + enabled = enabled.strip().lower() not in {"false", "0", "off", "no"} + if not enabled: + return messages + + kb_id = self._resolve_kb_id() + if not kb_id: + return messages + + latest_user = "" + for msg in reversed(messages): + if msg.role == "user": + latest_user = (msg.content or "").strip() + break + if not latest_user: + return messages + + n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS) + n_results = max(1, min(n_results, self._RAG_MAX_RESULTS)) + + results = await search_knowledge_context( + kb_id=kb_id, + query=latest_user, + n_results=n_results, + ) + prompt = self._build_knowledge_prompt(results) + if not prompt: + return messages + + logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})") + rag_system = LLMMessage(role="system", content=prompt) + if messages and messages[0].role == "system": + return [messages[0], rag_system, *messages[1:]] + return [rag_system, *messages] + + async def generate( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> str: + """ + Generate a complete response. 
+ + Args: + messages: Conversation history + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + + Returns: + Complete response text + """ + if not self.client: + raise RuntimeError("LLM service not connected") + + rag_messages = await self._with_knowledge_context(messages) + prepared = self._prepare_messages(rag_messages) + + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=prepared, + temperature=temperature, + max_tokens=max_tokens + ) + + content = response.choices[0].message.content or "" + logger.debug(f"LLM response: {content[:100]}...") + return content + + except Exception as e: + logger.error(f"LLM generation error: {e}") + raise + + async def generate_stream( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None + ) -> AsyncIterator[LLMStreamEvent]: + """ + Generate response in streaming mode. + + Args: + messages: Conversation history + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + + Yields: + Structured stream events + """ + if not self.client: + raise RuntimeError("LLM service not connected") + + rag_messages = await self._with_knowledge_context(messages) + prepared = self._prepare_messages(rag_messages) + self._cancel_event.clear() + tool_accumulator: Dict[int, Dict[str, str]] = {} + openai_tools = self._tool_schemas or None + + try: + create_args: Dict[str, Any] = dict( + model=self.model, + messages=prepared, + temperature=temperature, + max_tokens=max_tokens, + stream=True, + ) + if openai_tools: + create_args["tools"] = openai_tools + create_args["tool_choice"] = "auto" + stream = await self.client.chat.completions.create(**create_args) + + async for chunk in stream: + # Check for cancellation + if self._cancel_event.is_set(): + logger.info("LLM stream cancelled") + break + + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = getattr(choice, "delta", None) + if delta and 
class MockLLMService(BaseLLMService):
    """
    Mock LLM service for testing without API calls.

    Serves a fixed set of canned replies round-robin, with an artificial
    delay to mimic network latency.
    """

    def __init__(self, response_delay: float = 0.5):
        """
        Args:
            response_delay: Seconds to wait before returning a reply.
        """
        super().__init__(model="mock")
        self.response_delay = response_delay
        # Canned replies, served in order and wrapped around.
        self.responses = [
            "Hello! How can I help you today?",
            "That's an interesting question. Let me think about it.",
            "I understand. Is there anything else you'd like to know?",
            "Great! I'm here if you need anything else.",
        ]
        self._response_index = 0

    async def connect(self) -> None:
        """Mark the mock service as connected."""
        self.state = ServiceState.CONNECTED
        logger.info("Mock LLM service connected")

    async def disconnect(self) -> None:
        """Mark the mock service as disconnected."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock LLM service disconnected")

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """Return the next canned reply after the configured delay."""
        await asyncio.sleep(self.response_delay)
        pick = self._response_index % len(self.responses)
        self._response_index += 1
        return self.responses[pick]

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """Stream a canned reply word by word, then emit a done event."""
        reply = await self.generate(messages, temperature, max_tokens)

        # Emit words separated by single spaces, pacing each word slightly.
        separator = ""
        for word in reply.split():
            if separator:
                yield LLMStreamEvent(type="text_delta", text=separator)
            yield LLMStreamEvent(type="text_delta", text=word)
            separator = " "
            await asyncio.sleep(0.05)  # Simulate streaming delay
        yield LLMStreamEvent(type="done")
+API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcriptions +""" + +import asyncio +import io +import wave +from typing import AsyncIterator, Optional, Callable, Awaitable +from loguru import logger + +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + logger.warning("aiohttp not available - OpenAICompatibleASRService will not work") + +from services.base import BaseASRService, ASRResult, ServiceState + + +class OpenAICompatibleASRService(BaseASRService): + """ + OpenAI-compatible ASR service for speech-to-text transcription. + + Features: + - Buffers incoming audio chunks + - Provides interim transcriptions periodically (for streaming to client) + - Final transcription on EOU + + API Details: + - Endpoint: POST https://api.siliconflow.cn/v1/audio/transcriptions + - Models: FunAudioLLM/SenseVoiceSmall (default), TeleAI/TeleSpeechASR + - Input: Audio file (multipart/form-data) + - Output: {"text": "transcribed text"} + """ + + # Supported models + MODELS = { + "sensevoice": "FunAudioLLM/SenseVoiceSmall", + "telespeech": "TeleAI/TeleSpeechASR", + } + + API_URL = "https://api.siliconflow.cn/v1/audio/transcriptions" + + def __init__( + self, + api_key: str, + model: str = "FunAudioLLM/SenseVoiceSmall", + sample_rate: int = 16000, + language: str = "auto", + interim_interval_ms: int = 500, # How often to send interim results + min_audio_for_interim_ms: int = 300, # Min audio before first interim + on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None + ): + """ + Initialize OpenAI-compatible ASR service. 
+ + Args: + api_key: Provider API key + model: ASR model name or alias + sample_rate: Audio sample rate (16000 recommended) + language: Language code (auto for automatic detection) + interim_interval_ms: How often to generate interim transcriptions + min_audio_for_interim_ms: Minimum audio duration before first interim + on_transcript: Callback for transcription results (text, is_final) + """ + super().__init__(sample_rate=sample_rate, language=language) + + if not AIOHTTP_AVAILABLE: + raise RuntimeError("aiohttp is required for OpenAICompatibleASRService") + + self.api_key = api_key + self.model = self.MODELS.get(model.lower(), model) + self.interim_interval_ms = interim_interval_ms + self.min_audio_for_interim_ms = min_audio_for_interim_ms + self.on_transcript = on_transcript + + # Session + self._session: Optional[aiohttp.ClientSession] = None + + # Audio buffer + self._audio_buffer: bytes = b"" + self._current_text: str = "" + self._last_interim_time: float = 0 + + # Transcript queue for async iteration + self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue() + + # Background task for interim results + self._interim_task: Optional[asyncio.Task] = None + self._running = False + + logger.info(f"OpenAICompatibleASRService initialized with model: {self.model}") + + async def connect(self) -> None: + """Connect to the service.""" + self._session = aiohttp.ClientSession( + headers={ + "Authorization": f"Bearer {self.api_key}" + } + ) + self._running = True + self.state = ServiceState.CONNECTED + logger.info("OpenAICompatibleASRService connected") + + async def disconnect(self) -> None: + """Disconnect and cleanup.""" + self._running = False + + if self._interim_task: + self._interim_task.cancel() + try: + await self._interim_task + except asyncio.CancelledError: + pass + self._interim_task = None + + if self._session: + await self._session.close() + self._session = None + + self._audio_buffer = b"" + self._current_text = "" + self.state = 
ServiceState.DISCONNECTED + logger.info("OpenAICompatibleASRService disconnected") + + async def send_audio(self, audio: bytes) -> None: + """ + Buffer incoming audio data. + + Args: + audio: PCM audio data (16-bit, mono) + """ + self._audio_buffer += audio + + async def transcribe_buffer(self, is_final: bool = False) -> Optional[str]: + """ + Transcribe current audio buffer. + + Args: + is_final: Whether this is the final transcription + + Returns: + Transcribed text or None if not enough audio + """ + if not self._session: + logger.warning("ASR session not connected") + return None + + # Check minimum audio duration + audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000 + + if not is_final and audio_duration_ms < self.min_audio_for_interim_ms: + return None + + if audio_duration_ms < 100: # Less than 100ms - too short + return None + + try: + # Convert PCM to WAV in memory + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, 'wb') as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) # 16-bit + wav_file.setframerate(self.sample_rate) + wav_file.writeframes(self._audio_buffer) + + wav_buffer.seek(0) + wav_data = wav_buffer.read() + + # Send to API + form_data = aiohttp.FormData() + form_data.add_field( + 'file', + wav_data, + filename='audio.wav', + content_type='audio/wav' + ) + form_data.add_field('model', self.model) + + async with self._session.post(self.API_URL, data=form_data) as response: + if response.status == 200: + result = await response.json() + text = result.get("text", "").strip() + + if text: + self._current_text = text + + # Notify via callback + if self.on_transcript: + await self.on_transcript(text, is_final) + + # Queue result + await self._transcript_queue.put( + ASRResult(text=text, is_final=is_final) + ) + + logger.debug(f"ASR {'final' if is_final else 'interim'}: {text[:50]}...") + return text + else: + error_text = await response.text() + logger.error(f"ASR API error {response.status}: {error_text}") 
+ return None + + except Exception as e: + logger.error(f"ASR transcription error: {e}") + return None + + async def get_final_transcription(self) -> str: + """ + Get final transcription and clear buffer. + + Call this when EOU is detected. + + Returns: + Final transcribed text + """ + # Transcribe full buffer as final + text = await self.transcribe_buffer(is_final=True) + + # Clear buffer + result = text or self._current_text + self._audio_buffer = b"" + self._current_text = "" + + return result + + def get_and_clear_text(self) -> str: + """ + Get accumulated text and clear buffer. + + Compatible with BufferedASRService interface. + """ + text = self._current_text + self._current_text = "" + self._audio_buffer = b"" + return text + + def get_audio_buffer(self) -> bytes: + """Get current audio buffer.""" + return self._audio_buffer + + def get_audio_duration_ms(self) -> float: + """Get current audio buffer duration in milliseconds.""" + return len(self._audio_buffer) / (self.sample_rate * 2) * 1000 + + def clear_buffer(self) -> None: + """Clear audio and text buffers.""" + self._audio_buffer = b"" + self._current_text = "" + + async def receive_transcripts(self) -> AsyncIterator[ASRResult]: + """ + Async iterator for transcription results. + + Yields: + ASRResult with text and is_final flag + """ + while self._running: + try: + result = await asyncio.wait_for( + self._transcript_queue.get(), + timeout=0.1 + ) + yield result + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + async def start_interim_transcription(self) -> None: + """ + Start background task for interim transcriptions. + + This periodically transcribes buffered audio for + real-time feedback to the user. 
+ """ + if self._interim_task and not self._interim_task.done(): + return + + self._interim_task = asyncio.create_task(self._interim_loop()) + + async def stop_interim_transcription(self) -> None: + """Stop interim transcription task.""" + if self._interim_task: + self._interim_task.cancel() + try: + await self._interim_task + except asyncio.CancelledError: + pass + self._interim_task = None + + async def _interim_loop(self) -> None: + """Background loop for interim transcriptions.""" + import time + + while self._running: + try: + await asyncio.sleep(self.interim_interval_ms / 1000) + + # Check if we have enough new audio + current_time = time.time() + time_since_last = (current_time - self._last_interim_time) * 1000 + + if time_since_last >= self.interim_interval_ms: + audio_duration = self.get_audio_duration_ms() + + if audio_duration >= self.min_audio_for_interim_ms: + await self.transcribe_buffer(is_final=False) + self._last_interim_time = current_time + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Interim transcription error: {e}") + + +# Backward-compatible alias +SiliconFlowASRService = OpenAICompatibleASRService diff --git a/services/openai_compatible_tts.py b/services/openai_compatible_tts.py new file mode 100644 index 0000000..4967557 --- /dev/null +++ b/services/openai_compatible_tts.py @@ -0,0 +1,324 @@ +"""OpenAI-compatible TTS Service with streaming support. + +Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency +text-to-speech synthesis with streaming. 
+ +API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech +""" + +import os +import asyncio +import aiohttp +from typing import AsyncIterator, Optional +from loguru import logger + +from services.base import BaseTTSService, TTSChunk, ServiceState +from services.streaming_tts_adapter import StreamingTTSAdapter # backward-compatible re-export + + +class OpenAICompatibleTTSService(BaseTTSService): + """ + OpenAI-compatible TTS service with streaming support. + + Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models. + """ + + # Available voices + VOICES = { + "alex": "FunAudioLLM/CosyVoice2-0.5B:alex", + "anna": "FunAudioLLM/CosyVoice2-0.5B:anna", + "bella": "FunAudioLLM/CosyVoice2-0.5B:bella", + "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin", + "charles": "FunAudioLLM/CosyVoice2-0.5B:charles", + "claire": "FunAudioLLM/CosyVoice2-0.5B:claire", + "david": "FunAudioLLM/CosyVoice2-0.5B:david", + "diana": "FunAudioLLM/CosyVoice2-0.5B:diana", + } + + def __init__( + self, + api_key: Optional[str] = None, + voice: str = "anna", + model: str = "FunAudioLLM/CosyVoice2-0.5B", + sample_rate: int = 16000, + speed: float = 1.0 + ): + """ + Initialize OpenAI-compatible TTS service. + + Args: + api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var) + voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana) + model: Model name + sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100) + speed: Speech speed (0.25 to 4.0) + """ + # Resolve voice name (case-insensitive), and normalize "model:VoiceId" suffix. 
+ resolved_voice = (voice or "").strip() + voice_lookup = resolved_voice.lower() + if voice_lookup in self.VOICES: + full_voice = self.VOICES[voice_lookup] + elif ":" in resolved_voice: + model_part, voice_part = resolved_voice.split(":", 1) + normalized_voice_part = voice_part.strip().lower() + if normalized_voice_part in self.VOICES: + full_voice = f"{(model_part or model).strip()}:{normalized_voice_part}" + else: + full_voice = resolved_voice + else: + full_voice = resolved_voice + + super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed) + + self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY") + self.model = model + self.api_url = "https://api.siliconflow.cn/v1/audio/speech" + + self._session: Optional[aiohttp.ClientSession] = None + self._cancel_event = asyncio.Event() + + async def connect(self) -> None: + """Initialize HTTP session.""" + if not self.api_key: + raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.") + + self._session = aiohttp.ClientSession( + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + ) + self.state = ServiceState.CONNECTED + logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}") + + async def disconnect(self) -> None: + """Close HTTP session.""" + if self._session: + await self._session.close() + self._session = None + self.state = ServiceState.DISCONNECTED + logger.info("SiliconFlow TTS service disconnected") + + async def synthesize(self, text: str) -> bytes: + """Synthesize complete audio for text.""" + audio_data = b"" + async for chunk in self.synthesize_stream(text): + audio_data += chunk.audio + return audio_data + + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + """ + Synthesize audio in streaming mode. 
+ + Args: + text: Text to synthesize + + Yields: + TTSChunk objects with PCM audio + """ + if not self._session: + raise RuntimeError("TTS service not connected") + + if not text.strip(): + return + + self._cancel_event.clear() + + payload = { + "model": self.model, + "input": text, + "voice": self.voice, + "response_format": "pcm", + "sample_rate": self.sample_rate, + "stream": True, + "speed": self.speed + } + + try: + async with self._session.post(self.api_url, json=payload) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}") + return + + # Stream audio chunks + chunk_size = self.sample_rate * 2 // 10 # 100ms chunks + buffer = b"" + pending_chunk = None + + async for chunk in response.content.iter_any(): + if self._cancel_event.is_set(): + logger.info("TTS synthesis cancelled") + return + + buffer += chunk + + # Yield complete chunks + while len(buffer) >= chunk_size: + audio_chunk = buffer[:chunk_size] + buffer = buffer[chunk_size:] + + # Keep one full chunk buffered so we can always tag the true + # last full chunk as final when stream length is an exact multiple. + if pending_chunk is not None: + yield TTSChunk( + audio=pending_chunk, + sample_rate=self.sample_rate, + is_final=False + ) + pending_chunk = audio_chunk + + # Flush pending chunk(s) and remaining tail. 
+ if pending_chunk is not None: + if buffer: + yield TTSChunk( + audio=pending_chunk, + sample_rate=self.sample_rate, + is_final=False + ) + pending_chunk = None + else: + yield TTSChunk( + audio=pending_chunk, + sample_rate=self.sample_rate, + is_final=True + ) + pending_chunk = None + + if buffer: + yield TTSChunk( + audio=buffer, + sample_rate=self.sample_rate, + is_final=True + ) + + except asyncio.CancelledError: + logger.info("TTS synthesis cancelled via asyncio") + raise + except Exception as e: + logger.error(f"TTS synthesis error: {e}") + raise + + async def cancel(self) -> None: + """Cancel ongoing synthesis.""" + self._cancel_event.set() + + +class StreamingTTSAdapter: + """ + Adapter for streaming LLM text to TTS with sentence-level chunking. + + This reduces latency by starting TTS as soon as a complete sentence + is received from the LLM, rather than waiting for the full response. + """ + + # Sentence delimiters + SENTENCE_ENDS = {',', '。', '!', '?', '.', '!', '?', '\n'} + + def __init__(self, tts_service: BaseTTSService, transport, session_id: str): + self.tts_service = tts_service + self.transport = transport + self.session_id = session_id + self._buffer = "" + self._cancel_event = asyncio.Event() + self._is_speaking = False + + def _is_non_sentence_period(self, text: str, idx: int) -> bool: + """Check whether '.' should NOT be treated as a sentence delimiter.""" + if text[idx] != ".": + return False + + # Decimal/version segment: 1.2, v1.2.3 + if idx > 0 and idx < len(text) - 1 and text[idx - 1].isdigit() and text[idx + 1].isdigit(): + return True + + # Number abbreviations: No.1 / No. 
1 + left_start = idx - 1 + while left_start >= 0 and text[left_start].isalpha(): + left_start -= 1 + left_token = text[left_start + 1:idx].lower() + if left_token == "no": + j = idx + 1 + while j < len(text) and text[j].isspace(): + j += 1 + if j < len(text) and text[j].isdigit(): + return True + + return False + + async def process_text_chunk(self, text_chunk: str) -> None: + """ + Process a text chunk from LLM and trigger TTS when sentence is complete. + + Args: + text_chunk: Text chunk from LLM streaming + """ + if self._cancel_event.is_set(): + return + + self._buffer += text_chunk + + # Check for sentence completion + while True: + split_idx = -1 + for i, char in enumerate(self._buffer): + if char == "." and self._is_non_sentence_period(self._buffer, i): + continue + if char in self.SENTENCE_ENDS: + split_idx = i + break + if split_idx < 0: + break + + end_idx = split_idx + 1 + while end_idx < len(self._buffer) and self._buffer[end_idx] in self.SENTENCE_ENDS: + end_idx += 1 + + sentence = self._buffer[:end_idx].strip() + self._buffer = self._buffer[end_idx:] + + if sentence and any(ch.isalnum() for ch in sentence): + await self._speak_sentence(sentence) + + async def flush(self) -> None: + """Flush remaining buffer.""" + if self._buffer.strip() and not self._cancel_event.is_set(): + await self._speak_sentence(self._buffer.strip()) + self._buffer = "" + + async def _speak_sentence(self, text: str) -> None: + """Synthesize and send a sentence.""" + if not text or self._cancel_event.is_set(): + return + + self._is_speaking = True + + try: + async for chunk in self.tts_service.synthesize_stream(text): + if self._cancel_event.is_set(): + break + await self.transport.send_audio(chunk.audio) + await asyncio.sleep(0.01) # Prevent flooding + except Exception as e: + logger.error(f"TTS speak error: {e}") + finally: + self._is_speaking = False + + def cancel(self) -> None: + """Cancel ongoing speech.""" + self._cancel_event.set() + self._buffer = "" + + def reset(self) 
-> None: + """Reset for new turn.""" + self._cancel_event.clear() + self._buffer = "" + self._is_speaking = False + + @property + def is_speaking(self) -> bool: + return self._is_speaking + + +# Backward-compatible alias +SiliconFlowTTSService = OpenAICompatibleTTSService diff --git a/services/realtime.py b/services/realtime.py new file mode 100644 index 0000000..3fd95c1 --- /dev/null +++ b/services/realtime.py @@ -0,0 +1,548 @@ +"""OpenAI Realtime API Service. + +Provides true duplex voice conversation using OpenAI's Realtime API, +similar to active-call's RealtimeProcessor. This bypasses the need for +separate ASR/LLM/TTS services by handling everything server-side. + +The Realtime API provides: +- Server-side VAD with turn detection +- Streaming speech-to-text +- Streaming LLM responses +- Streaming text-to-speech +- Function calling support +- Barge-in/interruption handling +""" + +import os +import asyncio +import json +import base64 +from typing import Optional, Dict, Any, Callable, Awaitable, List +from dataclasses import dataclass, field +from enum import Enum +from loguru import logger + +try: + import websockets + WEBSOCKETS_AVAILABLE = True +except ImportError: + WEBSOCKETS_AVAILABLE = False + logger.warning("websockets not available - Realtime API will be disabled") + + +class RealtimeState(Enum): + """Realtime API connection state.""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + ERROR = "error" + + +@dataclass +class RealtimeConfig: + """Configuration for OpenAI Realtime API.""" + + # API Configuration + api_key: Optional[str] = None + model: str = "gpt-4o-realtime-preview" + endpoint: Optional[str] = None # For Azure or custom endpoints + + # Voice Configuration + voice: str = "alloy" # alloy, echo, shimmer, etc. + instructions: str = ( + "You are a helpful, friendly voice assistant. " + "Keep your responses concise and conversational." 
+ ) + + # Turn Detection (Server-side VAD) + turn_detection: Optional[Dict[str, Any]] = field(default_factory=lambda: { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500 + }) + + # Audio Configuration + input_audio_format: str = "pcm16" + output_audio_format: str = "pcm16" + + # Tools/Functions + tools: List[Dict[str, Any]] = field(default_factory=list) + + +class RealtimeService: + """ + OpenAI Realtime API service for true duplex voice conversation. + + This service handles the entire voice conversation pipeline: + 1. Audio input → Server-side VAD → Speech-to-text + 2. Text → LLM processing → Response generation + 3. Response → Text-to-speech → Audio output + + Events emitted: + - on_audio: Audio output from the assistant + - on_transcript: Text transcript (user or assistant) + - on_speech_started: User started speaking + - on_speech_stopped: User stopped speaking + - on_response_started: Assistant started responding + - on_response_done: Assistant finished responding + - on_function_call: Function call requested + - on_error: Error occurred + """ + + def __init__(self, config: Optional[RealtimeConfig] = None): + """ + Initialize Realtime API service. 
+ + Args: + config: Realtime configuration (uses defaults if not provided) + """ + self.config = config or RealtimeConfig() + self.config.api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") + + self.state = RealtimeState.DISCONNECTED + self._ws = None + self._receive_task: Optional[asyncio.Task] = None + self._cancel_event = asyncio.Event() + + # Event callbacks + self._callbacks: Dict[str, List[Callable]] = { + "on_audio": [], + "on_transcript": [], + "on_speech_started": [], + "on_speech_stopped": [], + "on_response_started": [], + "on_response_done": [], + "on_function_call": [], + "on_error": [], + "on_interrupted": [], + } + + logger.debug(f"RealtimeService initialized with model={self.config.model}") + + def on(self, event: str, callback: Callable[..., Awaitable[None]]) -> None: + """ + Register event callback. + + Args: + event: Event name + callback: Async callback function + """ + if event in self._callbacks: + self._callbacks[event].append(callback) + + async def _emit(self, event: str, *args, **kwargs) -> None: + """Emit event to all registered callbacks.""" + for callback in self._callbacks.get(event, []): + try: + await callback(*args, **kwargs) + except Exception as e: + logger.error(f"Event callback error ({event}): {e}") + + async def connect(self) -> None: + """Connect to OpenAI Realtime API.""" + if not WEBSOCKETS_AVAILABLE: + raise RuntimeError("websockets package not installed") + + if not self.config.api_key: + raise ValueError("OpenAI API key not provided") + + self.state = RealtimeState.CONNECTING + + # Build URL + if self.config.endpoint: + # Azure or custom endpoint + url = f"{self.config.endpoint}/openai/realtime?api-version=2024-10-01-preview&deployment={self.config.model}" + else: + # OpenAI endpoint + url = f"wss://api.openai.com/v1/realtime?model={self.config.model}" + + # Build headers + headers = {} + if self.config.endpoint: + headers["api-key"] = self.config.api_key + else: + headers["Authorization"] = f"Bearer 
{self.config.api_key}" + headers["OpenAI-Beta"] = "realtime=v1" + + try: + logger.info(f"Connecting to Realtime API: {url}") + self._ws = await websockets.connect(url, extra_headers=headers) + + # Send session configuration + await self._configure_session() + + # Start receive loop + self._receive_task = asyncio.create_task(self._receive_loop()) + + self.state = RealtimeState.CONNECTED + logger.info("Realtime API connected successfully") + + except Exception as e: + self.state = RealtimeState.ERROR + logger.error(f"Realtime API connection failed: {e}") + raise + + async def _configure_session(self) -> None: + """Send session configuration to server.""" + session_config = { + "type": "session.update", + "session": { + "modalities": ["text", "audio"], + "instructions": self.config.instructions, + "voice": self.config.voice, + "input_audio_format": self.config.input_audio_format, + "output_audio_format": self.config.output_audio_format, + "turn_detection": self.config.turn_detection, + } + } + + if self.config.tools: + session_config["session"]["tools"] = self.config.tools + + await self._send(session_config) + logger.debug("Session configuration sent") + + async def _send(self, data: Dict[str, Any]) -> None: + """Send JSON data to server.""" + if self._ws: + await self._ws.send(json.dumps(data)) + + async def send_audio(self, audio_bytes: bytes) -> None: + """ + Send audio to the Realtime API. + + Args: + audio_bytes: PCM audio data (16-bit, mono, 24kHz by default) + """ + if self.state != RealtimeState.CONNECTED: + return + + # Encode audio as base64 + audio_b64 = base64.standard_b64encode(audio_bytes).decode() + + await self._send({ + "type": "input_audio_buffer.append", + "audio": audio_b64 + }) + + async def send_text(self, text: str) -> None: + """ + Send text input (bypassing audio). 
+ + Args: + text: User text input + """ + if self.state != RealtimeState.CONNECTED: + return + + # Create a conversation item with user text + await self._send({ + "type": "conversation.item.create", + "item": { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}] + } + }) + + # Trigger response + await self._send({"type": "response.create"}) + + async def cancel_response(self) -> None: + """Cancel the current response (for barge-in).""" + if self.state != RealtimeState.CONNECTED: + return + + await self._send({"type": "response.cancel"}) + logger.debug("Response cancelled") + + async def commit_audio(self) -> None: + """Commit the audio buffer and trigger response.""" + if self.state != RealtimeState.CONNECTED: + return + + await self._send({"type": "input_audio_buffer.commit"}) + await self._send({"type": "response.create"}) + + async def clear_audio_buffer(self) -> None: + """Clear the input audio buffer.""" + if self.state != RealtimeState.CONNECTED: + return + + await self._send({"type": "input_audio_buffer.clear"}) + + async def submit_function_result(self, call_id: str, result: str) -> None: + """ + Submit function call result. 
+ + Args: + call_id: The function call ID + result: JSON string result + """ + if self.state != RealtimeState.CONNECTED: + return + + await self._send({ + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": call_id, + "output": result + } + }) + + # Trigger response with the function result + await self._send({"type": "response.create"}) + + async def _receive_loop(self) -> None: + """Receive and process messages from the Realtime API.""" + if not self._ws: + return + + try: + async for message in self._ws: + try: + data = json.loads(message) + await self._handle_event(data) + except json.JSONDecodeError: + logger.warning(f"Invalid JSON received: {message[:100]}") + + except asyncio.CancelledError: + logger.debug("Receive loop cancelled") + except websockets.ConnectionClosed as e: + logger.info(f"WebSocket closed: {e}") + self.state = RealtimeState.DISCONNECTED + except Exception as e: + logger.error(f"Receive loop error: {e}") + self.state = RealtimeState.ERROR + + async def _handle_event(self, data: Dict[str, Any]) -> None: + """Handle incoming event from Realtime API.""" + event_type = data.get("type", "unknown") + + # Audio delta - streaming audio output + if event_type == "response.audio.delta": + if "delta" in data: + audio_bytes = base64.standard_b64decode(data["delta"]) + await self._emit("on_audio", audio_bytes) + + # Audio transcript delta - streaming text + elif event_type == "response.audio_transcript.delta": + if "delta" in data: + await self._emit("on_transcript", data["delta"], "assistant", False) + + # Audio transcript done + elif event_type == "response.audio_transcript.done": + if "transcript" in data: + await self._emit("on_transcript", data["transcript"], "assistant", True) + + # Input audio transcript (user speech) + elif event_type == "conversation.item.input_audio_transcription.completed": + if "transcript" in data: + await self._emit("on_transcript", data["transcript"], "user", True) + + # Speech 
started (server VAD detected speech) + elif event_type == "input_audio_buffer.speech_started": + await self._emit("on_speech_started", data.get("audio_start_ms", 0)) + + # Speech stopped + elif event_type == "input_audio_buffer.speech_stopped": + await self._emit("on_speech_stopped", data.get("audio_end_ms", 0)) + + # Response started + elif event_type == "response.created": + await self._emit("on_response_started", data.get("response", {})) + + # Response done + elif event_type == "response.done": + await self._emit("on_response_done", data.get("response", {})) + + # Function call + elif event_type == "response.function_call_arguments.done": + call_id = data.get("call_id") + name = data.get("name") + arguments = data.get("arguments", "{}") + await self._emit("on_function_call", call_id, name, arguments) + + # Error + elif event_type == "error": + error = data.get("error", {}) + logger.error(f"Realtime API error: {error}") + await self._emit("on_error", error) + + # Session events + elif event_type == "session.created": + logger.info("Session created") + elif event_type == "session.updated": + logger.debug("Session updated") + + else: + logger.debug(f"Unhandled event type: {event_type}") + + async def disconnect(self) -> None: + """Disconnect from Realtime API.""" + self._cancel_event.set() + + if self._receive_task: + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + + if self._ws: + await self._ws.close() + self._ws = None + + self.state = RealtimeState.DISCONNECTED + logger.info("Realtime API disconnected") + + +class RealtimePipeline: + """ + Pipeline adapter for RealtimeService. + + Provides a compatible interface with DuplexPipeline but uses + OpenAI Realtime API for all processing. + """ + + def __init__( + self, + transport, + session_id: str, + config: Optional[RealtimeConfig] = None + ): + """ + Initialize Realtime pipeline. 
+ + Args: + transport: Transport for sending audio/events + session_id: Session identifier + config: Realtime configuration + """ + self.transport = transport + self.session_id = session_id + + self.service = RealtimeService(config) + + # Register callbacks + self.service.on("on_audio", self._on_audio) + self.service.on("on_transcript", self._on_transcript) + self.service.on("on_speech_started", self._on_speech_started) + self.service.on("on_speech_stopped", self._on_speech_stopped) + self.service.on("on_response_started", self._on_response_started) + self.service.on("on_response_done", self._on_response_done) + self.service.on("on_error", self._on_error) + + self._is_speaking = False + self._running = True + + logger.info(f"RealtimePipeline initialized for session {session_id}") + + async def start(self) -> None: + """Start the pipeline.""" + await self.service.connect() + + async def process_audio(self, pcm_bytes: bytes) -> None: + """ + Process incoming audio. + + Note: Realtime API expects 24kHz audio by default. + You may need to resample from 16kHz. 
+ """ + if not self._running: + return + + # TODO: Resample from 16kHz to 24kHz if needed + await self.service.send_audio(pcm_bytes) + + async def process_text(self, text: str) -> None: + """Process text input.""" + if not self._running: + return + + await self.service.send_text(text) + + async def interrupt(self) -> None: + """Interrupt current response.""" + await self.service.cancel_response() + await self.transport.send_event({ + "event": "interrupt", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + async def cleanup(self) -> None: + """Cleanup resources.""" + self._running = False + await self.service.disconnect() + + # Event handlers + + async def _on_audio(self, audio_bytes: bytes) -> None: + """Handle audio output.""" + await self.transport.send_audio(audio_bytes) + + async def _on_transcript(self, text: str, role: str, is_final: bool) -> None: + """Handle transcript.""" + logger.info(f"[{role.upper()}] {text[:50]}..." if len(text) > 50 else f"[{role.upper()}] {text}") + + async def _on_speech_started(self, start_ms: int) -> None: + """Handle user speech start.""" + self._is_speaking = True + await self.transport.send_event({ + "event": "speaking", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms(), + "startTime": start_ms + }) + + # Cancel any ongoing response (barge-in) + await self.service.cancel_response() + + async def _on_speech_stopped(self, end_ms: int) -> None: + """Handle user speech stop.""" + self._is_speaking = False + await self.transport.send_event({ + "event": "silence", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms(), + "duration": end_ms + }) + + async def _on_response_started(self, response: Dict) -> None: + """Handle response start.""" + await self.transport.send_event({ + "event": "trackStart", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + async def _on_response_done(self, response: Dict) -> None: + """Handle response complete.""" + 
await self.transport.send_event({ + "event": "trackEnd", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms() + }) + + async def _on_error(self, error: Dict) -> None: + """Handle error.""" + await self.transport.send_event({ + "event": "error", + "trackId": self.session_id, + "timestamp": self._get_timestamp_ms(), + "sender": "realtime", + "error": str(error) + }) + + def _get_timestamp_ms(self) -> int: + """Get current timestamp in milliseconds.""" + import time + return int(time.time() * 1000) + + @property + def is_speaking(self) -> bool: + """Check if user is speaking.""" + return self._is_speaking diff --git a/services/siliconflow_asr.py b/services/siliconflow_asr.py new file mode 100644 index 0000000..2cb95dc --- /dev/null +++ b/services/siliconflow_asr.py @@ -0,0 +1,8 @@ +"""Backward-compatible imports for legacy siliconflow_asr module.""" + +from services.openai_compatible_asr import OpenAICompatibleASRService + +# Backward-compatible alias +SiliconFlowASRService = OpenAICompatibleASRService + +__all__ = ["OpenAICompatibleASRService", "SiliconFlowASRService"] diff --git a/services/siliconflow_tts.py b/services/siliconflow_tts.py new file mode 100644 index 0000000..3cdf32a --- /dev/null +++ b/services/siliconflow_tts.py @@ -0,0 +1,8 @@ +"""Backward-compatible imports for legacy siliconflow_tts module.""" + +from services.openai_compatible_tts import OpenAICompatibleTTSService, StreamingTTSAdapter + +# Backward-compatible alias +SiliconFlowTTSService = OpenAICompatibleTTSService + +__all__ = ["OpenAICompatibleTTSService", "SiliconFlowTTSService", "StreamingTTSAdapter"] diff --git a/services/streaming_text.py b/services/streaming_text.py new file mode 100644 index 0000000..d5c123f --- /dev/null +++ b/services/streaming_text.py @@ -0,0 +1,86 @@ +"""Shared text chunking helpers for streaming TTS.""" + +from typing import Optional + + +def is_non_sentence_period(text: str, idx: int) -> bool: + """Check whether '.' 
def is_non_sentence_period(text: str, idx: int) -> bool:
    """Return True when the '.' at *idx* is NOT a sentence terminator."""
    if not (0 <= idx < len(text)) or text[idx] != ".":
        return False

    # Decimal or version component: "1.2", "v1.2.3".
    if 0 < idx < len(text) - 1 and text[idx - 1].isdigit() and text[idx + 1].isdigit():
        return True

    # "No." style number abbreviation: "No.1" / "No. 1".
    start = idx
    while start > 0 and text[start - 1].isalpha():
        start -= 1
    if text[start:idx].lower() == "no":
        after = idx + 1
        while after < len(text) and text[after].isspace():
            after += 1
        if after < len(text) and text[after].isdigit():
            return True

    return False


def has_spoken_content(text: str) -> bool:
    """Return True if *text* has at least one pronounceable (alphanumeric) char."""
    return any(ch.isalnum() for ch in text)


def extract_tts_sentence(
    text_buffer: str,
    *,
    end_chars: frozenset[str],
    trailing_chars: frozenset[str],
    closers: frozenset[str],
    min_split_spoken_chars: int = 0,
    hold_trailing_at_buffer_end: bool = False,
    force: bool = False,
) -> Optional[tuple[str, str]]:
    """Split one speakable sentence off the front of *text_buffer*.

    Returns ``(sentence, remainder)``, or ``None`` when no acceptable
    sentence boundary exists yet.
    """
    if not text_buffer:
        return None

    total = len(text_buffer)
    scan_from = 0
    while True:
        # Locate the next terminator, skipping periods that belong to
        # numbers/abbreviations rather than sentence ends.
        boundary = next(
            (
                i
                for i in range(scan_from, total)
                if text_buffer[i] in end_chars
                and not (text_buffer[i] == "." and is_non_sentence_period(text_buffer, i))
            ),
            -1,
        )
        if boundary == -1:
            return None

        cut = boundary + 1
        # Absorb repeated terminators, then any closing punctuation.
        while cut < total and text_buffer[cut] in trailing_chars:
            cut += 1
        while cut < total and text_buffer[cut] in closers:
            cut += 1

        # Optionally wait for more text when the boundary sits at the very
        # end of the buffer (more trailing marks may still arrive).
        if hold_trailing_at_buffer_end and not force and cut >= total:
            return None

        sentence = text_buffer[:cut].strip()
        spoken = sum(1 for ch in sentence if ch.isalnum())

        if (
            not force
            and min_split_spoken_chars > 0
            and 0 < spoken < min_split_spoken_chars
            and cut < total
        ):
            # Too short to speak on its own; widen to the next boundary.
            scan_from = cut
            continue

        return sentence, text_buffer[cut:]
+ + Args: + text_chunk: Text chunk from LLM streaming + """ + if self._cancel_event.is_set(): + return + + self._buffer += text_chunk + + # Check for sentence completion + while True: + split_result = extract_tts_sentence( + self._buffer, + end_chars=frozenset(self.SENTENCE_ENDS), + trailing_chars=frozenset(self.SENTENCE_ENDS), + closers=self.SENTENCE_CLOSERS, + force=False, + ) + if not split_result: + break + + sentence, self._buffer = split_result + if sentence and has_spoken_content(sentence): + await self._speak_sentence(sentence) + + async def flush(self) -> None: + """Flush remaining buffer.""" + if self._buffer.strip() and not self._cancel_event.is_set(): + await self._speak_sentence(self._buffer.strip()) + self._buffer = "" + + async def _speak_sentence(self, text: str) -> None: + """Synthesize and send a sentence.""" + if not text or self._cancel_event.is_set(): + return + + self._is_speaking = True + + try: + async for chunk in self.tts_service.synthesize_stream(text): + if self._cancel_event.is_set(): + break + await self.transport.send_audio(chunk.audio) + await asyncio.sleep(0.01) # Prevent flooding + except Exception as e: + logger.error(f"TTS speak error: {e}") + finally: + self._is_speaking = False + + def cancel(self) -> None: + """Cancel ongoing speech.""" + self._cancel_event.set() + self._buffer = "" + + def reset(self) -> None: + """Reset for new turn.""" + self._cancel_event.clear() + self._buffer = "" + self._is_speaking = False + + @property + def is_speaking(self) -> bool: + return self._is_speaking diff --git a/services/tts.py b/services/tts.py new file mode 100644 index 0000000..e838f08 --- /dev/null +++ b/services/tts.py @@ -0,0 +1,271 @@ +"""TTS (Text-to-Speech) Service implementations. + +Provides multiple TTS backend options including edge-tts (free) +and placeholder for cloud services. 
class EdgeTTSService(BaseTTSService):
    """
    Microsoft Edge TTS service.

    Uses edge-tts library for free, high-quality speech synthesis.
    Supports streaming for low-latency playback.
    """

    # Voice mapping for common languages
    VOICE_MAP = {
        "en": "en-US-JennyNeural",
        "en-US": "en-US-JennyNeural",
        "en-GB": "en-GB-SoniaNeural",
        "zh": "zh-CN-XiaoxiaoNeural",
        "zh-CN": "zh-CN-XiaoxiaoNeural",
        "zh-TW": "zh-TW-HsiaoChenNeural",
        "ja": "ja-JP-NanamiNeural",
        "ko": "ko-KR-SunHiNeural",
        "fr": "fr-FR-DeniseNeural",
        "de": "de-DE-KatjaNeural",
        "es": "es-ES-ElviraNeural",
    }

    def __init__(
        self,
        voice: str = "en-US-JennyNeural",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize Edge TTS service.

        Args:
            voice: Voice name (e.g., "en-US-JennyNeural") or language code (e.g., "en")
            sample_rate: Target sample rate (will be resampled)
            speed: Speech speed multiplier
        """
        # Resolve a bare language code ("en", "zh") to a concrete voice name.
        if voice in self.VOICE_MAP:
            voice = self.VOICE_MAP[voice]

        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Edge TTS doesn't require explicit connection."""
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts package not installed")
        self.state = ServiceState.CONNECTED
        logger.info(f"Edge TTS service ready: voice={self.voice}")

    async def disconnect(self) -> None:
        """Edge TTS doesn't require explicit disconnection."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Edge TTS service disconnected")

    def _get_rate_string(self) -> str:
        """Convert speed multiplier to edge-tts rate string ("+20%", "-10%")."""
        percentage = int((self.speed - 1.0) * 100)
        # edge-tts requires an explicit sign, so prepend "+" for >= 0.
        if percentage >= 0:
            return f"+{percentage}%"
        return f"{percentage}%"

    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio for text.

        Args:
            text: Text to synthesize

        Returns:
            PCM audio data (16-bit, mono, at self.sample_rate)
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts not available")

        # IMPROVED: gather chunks into a list and join once; the previous
        # `audio_data += chunk.audio` loop was quadratic in output size.
        parts = [chunk.audio async for chunk in self.synthesize_stream(text)]
        return b"".join(parts)

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio (~100ms per chunk)
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts not available")

        self._cancel_event.clear()

        try:
            communicate = edge_tts.Communicate(
                text,
                voice=self.voice,
                rate=self._get_rate_string()
            )

            # edge-tts outputs MP3; collect it fully, then decode to PCM.
            mp3_parts = []

            async for chunk in communicate.stream():
                # Check for cancellation between network chunks.
                if self._cancel_event.is_set():
                    logger.info("TTS synthesis cancelled")
                    return

                if chunk["type"] == "audio":
                    mp3_parts.append(chunk["data"])

            mp3_data = b"".join(mp3_parts)
            if not mp3_data:
                return

            pcm_data = await self._convert_mp3_to_pcm(mp3_data)
            if not pcm_data:
                return

            # Yield in ~100ms chunks for streaming playback.
            chunk_size = self.sample_rate * 2 // 10
            for i in range(0, len(pcm_data), chunk_size):
                if self._cancel_event.is_set():
                    return

                yield TTSChunk(
                    audio=pcm_data[i:i + chunk_size],
                    sample_rate=self.sample_rate,
                    is_final=(i + chunk_size >= len(pcm_data))
                )

        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def _convert_mp3_to_pcm(self, mp3_data: bytes) -> bytes:
        """
        Convert MP3 audio to 16-bit mono PCM at self.sample_rate.

        Prefers pydub (requires ffmpeg); falls back to invoking ffmpeg
        directly. Returns b"" on conversion failure.
        """
        try:
            # Try using pydub (requires ffmpeg)
            from pydub import AudioSegment

            audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))

            # Convert to target format
            audio = audio.set_frame_rate(self.sample_rate)
            audio = audio.set_channels(1)
            audio = audio.set_sample_width(2)  # 16-bit

            return audio.raw_data

        except ImportError:
            logger.warning("pydub not available, trying fallback")
            return await self._ffmpeg_convert(mp3_data)
        except Exception as e:
            logger.error(f"Audio conversion error: {e}")
            return b""

    async def _ffmpeg_convert(self, mp3_data: bytes) -> bytes:
        """Convert MP3 to PCM using ffmpeg subprocess; returns b"" on failure."""
        try:
            process = await asyncio.create_subprocess_exec(
                "ffmpeg",
                "-i", "pipe:0",
                "-f", "s16le",
                "-acodec", "pcm_s16le",
                "-ar", str(self.sample_rate),
                "-ac", "1",
                "pipe:1",
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.DEVNULL
            )

            stdout, _ = await process.communicate(input=mp3_data)

            # BUGFIX: a failing ffmpeg run previously returned partial/empty
            # stdout as if it were valid audio; treat non-zero exit as error.
            if process.returncode != 0:
                logger.error(f"ffmpeg exited with code {process.returncode}")
                return b""
            return stdout

        except Exception as e:
            logger.error(f"ffmpeg conversion error: {e}")
            return b""

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (stream stops at the next check point)."""
        self._cancel_event.set()
+ """ + + def __init__( + self, + voice: str = "mock", + sample_rate: int = 16000, + speed: float = 1.0 + ): + super().__init__(voice=voice, sample_rate=sample_rate, speed=speed) + + async def connect(self) -> None: + self.state = ServiceState.CONNECTED + logger.info("Mock TTS service connected") + + async def disconnect(self) -> None: + self.state = ServiceState.DISCONNECTED + logger.info("Mock TTS service disconnected") + + async def synthesize(self, text: str) -> bytes: + """Generate silence based on text length.""" + # Approximate: 100ms per word + word_count = len(text.split()) + duration_ms = word_count * 100 + samples = int(self.sample_rate * duration_ms / 1000) + + # Generate silence (zeros) + return bytes(samples * 2) # 16-bit = 2 bytes per sample + + async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]: + """Generate silence chunks.""" + audio = await self.synthesize(text) + + # Yield in 100ms chunks + chunk_size = self.sample_rate * 2 // 10 + for i in range(0, len(audio), chunk_size): + chunk_data = audio[i:i + chunk_size] + yield TTSChunk( + audio=chunk_data, + sample_rate=self.sample_rate, + is_final=(i + chunk_size >= len(audio)) + ) + await asyncio.sleep(0.05) # Simulate processing time diff --git a/tests/test_tool_call_flow.py b/tests/test_tool_call_flow.py new file mode 100644 index 0000000..e5f241b --- /dev/null +++ b/tests/test_tool_call_flow.py @@ -0,0 +1,331 @@ +import asyncio +from typing import Any, Dict, List + +import pytest + +from core.duplex_pipeline import DuplexPipeline +from models.ws_v1 import ToolCallResultsMessage, parse_client_message +from services.base import LLMStreamEvent + + +class _DummySileroVAD: + def __init__(self, *args, **kwargs): + pass + + def process_audio(self, _pcm: bytes) -> float: + return 0.0 + + +class _DummyVADProcessor: + def __init__(self, *args, **kwargs): + pass + + def process(self, _speech_prob: float): + return "Silence", 0.0 + + +class _DummyEouDetector: + def __init__(self, *args, 
**kwargs): + pass + + def process(self, _vad_status: str) -> bool: + return False + + def reset(self) -> None: + return None + + +class _FakeTransport: + async def send_event(self, _event: Dict[str, Any]) -> None: + return None + + async def send_audio(self, _audio: bytes) -> None: + return None + + +class _FakeTTS: + async def synthesize_stream(self, _text: str): + if False: + yield None + + +class _FakeASR: + async def connect(self) -> None: + return None + + +class _FakeLLM: + def __init__(self, rounds: List[List[LLMStreamEvent]]): + self._rounds = rounds + self._call_index = 0 + + async def generate_stream(self, _messages, temperature=0.7, max_tokens=None): + idx = self._call_index + self._call_index += 1 + events = self._rounds[idx] if idx < len(self._rounds) else [LLMStreamEvent(type="done")] + for event in events: + yield event + + +def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]: + monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD) + monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor) + monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector) + + pipeline = DuplexPipeline( + transport=_FakeTransport(), + session_id="s_test", + llm_service=_FakeLLM(llm_rounds), + tts_service=_FakeTTS(), + asr_service=_FakeASR(), + ) + events: List[Dict[str, Any]] = [] + + async def _capture_event(event: Dict[str, Any], priority: int = 20): + events.append(event) + + async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8): + return None + + monkeypatch.setattr(pipeline, "_send_event", _capture_event) + monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak) + return pipeline, events + + +@pytest.mark.asyncio +async def test_ws_message_parses_tool_call_results(): + msg = parse_client_message( + { + "type": "tool_call.results", + "results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}], + } + ) 
+ assert isinstance(msg, ToolCallResultsMessage) + assert msg.results[0]["tool_call_id"] == "call_1" + + +@pytest.mark.asyncio +async def test_turn_without_tool_keeps_streaming(monkeypatch): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent(type="text_delta", text="hello "), + LLMStreamEvent(type="text_delta", text="world."), + LLMStreamEvent(type="done"), + ] + ], + ) + + await pipeline._handle_turn("hi") + + event_types = [e.get("type") for e in events] + assert "assistant.response.delta" in event_types + assert "assistant.response.final" in event_types + assert "assistant.tool_call" not in event_types + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "metadata", + [ + {"output": {"mode": "text"}}, + {"services": {"tts": {"enabled": False}}}, + ], +) +async def test_text_output_mode_skips_audio_events(monkeypatch, metadata): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent(type="text_delta", text="hello "), + LLMStreamEvent(type="text_delta", text="world."), + LLMStreamEvent(type="done"), + ] + ], + ) + pipeline.apply_runtime_overrides(metadata) + + await pipeline._handle_turn("hi") + + event_types = [e.get("type") for e in events] + assert "assistant.response.delta" in event_types + assert "assistant.response.final" in event_types + assert "output.audio.start" not in event_types + assert "output.audio.end" not in event_types + + +@pytest.mark.asyncio +async def test_turn_with_tool_call_then_results(monkeypatch): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent(type="text_delta", text="let me check."), + LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_ok", + "executor": "client", + "type": "function", + "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"}, + }, + ), + LLMStreamEvent(type="done"), + ], + [ + LLMStreamEvent(type="text_delta", text="it's sunny."), + LLMStreamEvent(type="done"), + ], + ], + ) + + task = 
asyncio.create_task(pipeline._handle_turn("weather?")) + for _ in range(200): + if any(e.get("type") == "assistant.tool_call" for e in events): + break + await asyncio.sleep(0.005) + + await pipeline.handle_tool_call_results( + [ + { + "tool_call_id": "call_ok", + "name": "weather", + "output": {"temp": 21}, + "status": {"code": 200, "message": "ok"}, + } + ] + ) + await task + + assert any(e.get("type") == "assistant.tool_call" for e in events) + finals = [e for e in events if e.get("type") == "assistant.response.final"] + assert finals + assert "it's sunny" in finals[-1].get("text", "") + + +@pytest.mark.asyncio +async def test_turn_with_tool_call_timeout(monkeypatch): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_timeout", + "executor": "client", + "type": "function", + "function": {"name": "search", "arguments": "{\"query\":\"x\"}"}, + }, + ), + LLMStreamEvent(type="done"), + ], + [ + LLMStreamEvent(type="text_delta", text="fallback answer."), + LLMStreamEvent(type="done"), + ], + ], + ) + pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01 + + await pipeline._handle_turn("query") + + finals = [e for e in events if e.get("type") == "assistant.response.final"] + assert finals + assert "fallback answer" in finals[-1].get("text", "") + + +@pytest.mark.asyncio +async def test_duplicate_tool_results_are_ignored(monkeypatch): + pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]]) + + await pipeline.handle_tool_call_results( + [{"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}] + ) + await pipeline.handle_tool_call_results( + [{"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}] + ) + result = await pipeline._wait_for_single_tool_result("call_dup") + + assert result.get("output", {}).get("value") == 1 + + +@pytest.mark.asyncio +async def 
test_server_calculator_emits_tool_result(monkeypatch): + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_calc", + "executor": "server", + "type": "function", + "function": {"name": "calculator", "arguments": "{\"expression\":\"1+2\"}"}, + }, + ), + LLMStreamEvent(type="done"), + ], + [ + LLMStreamEvent(type="text_delta", text="done."), + LLMStreamEvent(type="done"), + ], + ], + ) + + await pipeline._handle_turn("calc") + + tool_results = [e for e in events if e.get("type") == "assistant.tool_result"] + assert tool_results + payload = tool_results[-1].get("result", {}) + assert payload.get("status", {}).get("code") == 200 + assert payload.get("output", {}).get("result") == 3 + + +@pytest.mark.asyncio +async def test_server_tool_timeout_emits_504_and_continues(monkeypatch): + async def _slow_execute(_call): + await asyncio.sleep(0.05) + return { + "tool_call_id": "call_slow", + "name": "weather", + "output": {"ok": True}, + "status": {"code": 200, "message": "ok"}, + } + + monkeypatch.setattr("core.duplex_pipeline.execute_server_tool", _slow_execute) + + pipeline, events = _build_pipeline( + monkeypatch, + [ + [ + LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_slow", + "executor": "server", + "type": "function", + "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"}, + }, + ), + LLMStreamEvent(type="done"), + ], + [ + LLMStreamEvent(type="text_delta", text="timeout fallback."), + LLMStreamEvent(type="done"), + ], + ], + ) + pipeline._SERVER_TOOL_TIMEOUT_SECONDS = 0.01 + + await pipeline._handle_turn("weather?") + + tool_results = [e for e in events if e.get("type") == "assistant.tool_result"] + assert tool_results + payload = tool_results[-1].get("result", {}) + assert payload.get("status", {}).get("code") == 504 + finals = [e for e in events if e.get("type") == "assistant.response.final"] + assert finals + assert "timeout fallback" in finals[-1].get("text", "") 
diff --git a/tests/test_tool_executor.py b/tests/test_tool_executor.py new file mode 100644 index 0000000..17345c7 --- /dev/null +++ b/tests/test_tool_executor.py @@ -0,0 +1,57 @@ +import pytest + +from core.tool_executor import execute_server_tool + + +@pytest.mark.asyncio +async def test_code_interpreter_simple_expression(): + result = await execute_server_tool( + { + "id": "call_ci_ok", + "function": { + "name": "code_interpreter", + "arguments": '{"code":"sum([1, 2, 3]) + 4"}', + }, + } + ) + assert result["status"]["code"] == 200 + assert result["output"]["result"] == 10 + + +@pytest.mark.asyncio +async def test_code_interpreter_blocks_import_and_io(): + result = await execute_server_tool( + { + "id": "call_ci_bad", + "function": { + "name": "code_interpreter", + "arguments": '{"code":"__import__(\\"os\\").system(\\"ls\\")"}', + }, + } + ) + assert result["status"]["code"] == 422 + assert result["status"]["message"] == "invalid_code" + + +@pytest.mark.asyncio +async def test_current_time_uses_local_system_clock(monkeypatch): + async def _should_not_be_called(_tool_id): + raise AssertionError("fetch_tool_resource should not be called for current_time") + + monkeypatch.setattr("core.tool_executor.fetch_tool_resource", _should_not_be_called) + + result = await execute_server_tool( + { + "id": "call_time_ok", + "function": { + "name": "current_time", + "arguments": "{}", + }, + } + ) + + assert result["status"]["code"] == 200 + assert result["status"]["message"] == "ok" + assert "local_time" in result["output"] + assert "iso" in result["output"] + assert "timestamp" in result["output"] diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..48a989f --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +"""Utilities Package""" diff --git a/utils/logging.py b/utils/logging.py new file mode 100644 index 0000000..28b3a8f --- /dev/null +++ b/utils/logging.py @@ -0,0 +1,83 @@ +"""Logging configuration utilities.""" + +import sys +from loguru 
import logger +from pathlib import Path + + +def setup_logging( + log_level: str = "INFO", + log_format: str = "text", + log_to_file: bool = True, + log_dir: str = "logs" +): + """ + Configure structured logging with loguru. + + Args: + log_level: Logging level (DEBUG, INFO, WARNING, ERROR) + log_format: Format type (json or text) + log_to_file: Whether to log to file + log_dir: Directory for log files + """ + # Remove default handler + logger.remove() + + # Console handler + if log_format == "json": + logger.add( + sys.stdout, + format="{message}", + level=log_level, + serialize=True, + colorize=False + ) + else: + logger.add( + sys.stdout, + format="{time:HH:mm:ss} | {level: <8} | {message}", + level=log_level, + colorize=True + ) + + # File handler + if log_to_file: + log_path = Path(log_dir) + log_path.mkdir(exist_ok=True) + + if log_format == "json": + logger.add( + log_path / "active_call_{time:YYYY-MM-DD}.log", + format="{message}", + level=log_level, + rotation="1 day", + retention="7 days", + compression="zip", + serialize=True + ) + else: + logger.add( + log_path / "active_call_{time:YYYY-MM-DD}.log", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", + level=log_level, + rotation="1 day", + retention="7 days", + compression="zip" + ) + + return logger + + +def get_logger(name: str = None): + """ + Get a logger instance. + + Args: + name: Logger name (optional) + + Returns: + Logger instance + """ + if name: + return logger.bind(name=name) + return logger