Init commit
This commit is contained in:
92
.env.example
Normal file
92
.env.example
Normal file
@@ -0,0 +1,92 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Engine .env example (safe template)
|
||||
# Notes:
|
||||
# - Never commit real API keys.
|
||||
# - Start with defaults below, then tune from logs.
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Server
|
||||
HOST=0.0.0.0
|
||||
PORT=8000
|
||||
# EXTERNAL_IP=1.2.3.4
|
||||
|
||||
# Backend bridge (optional)
|
||||
BACKEND_URL=http://127.0.0.1:8100
|
||||
BACKEND_TIMEOUT_SEC=10
|
||||
HISTORY_DEFAULT_USER_ID=1
|
||||
|
||||
# Audio
|
||||
SAMPLE_RATE=16000
|
||||
# 20ms is recommended for VAD stability and latency.
|
||||
# 100ms works but usually worsens start-of-speech accuracy.
|
||||
CHUNK_SIZE_MS=20
|
||||
DEFAULT_CODEC=pcm
|
||||
MAX_AUDIO_BUFFER_SECONDS=30
|
||||
|
||||
# VAD / EOU
|
||||
VAD_TYPE=silero
|
||||
VAD_MODEL_PATH=data/vad/silero_vad.onnx
|
||||
# Higher = stricter speech detection (fewer false positives, more misses).
|
||||
VAD_THRESHOLD=0.5
|
||||
# Require this much continuous speech before utterance can be valid.
|
||||
VAD_MIN_SPEECH_DURATION_MS=100
|
||||
# Silence duration required to finalize one user turn.
|
||||
VAD_EOU_THRESHOLD_MS=800
|
||||
|
||||
# LLM
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
# Optional for OpenAI-compatible providers.
|
||||
# OPENAI_API_URL=https://api.openai.com/v1
|
||||
LLM_MODEL=gpt-4o-mini
|
||||
LLM_TEMPERATURE=0.7
|
||||
|
||||
# TTS
|
||||
# edge: no API key needed
|
||||
# openai_compatible: compatible with SiliconFlow-style endpoints
|
||||
TTS_PROVIDER=openai_compatible
|
||||
TTS_VOICE=anna
|
||||
TTS_SPEED=1.0
|
||||
|
||||
# SiliconFlow (used by TTS and/or ASR when provider=openai_compatible)
|
||||
SILICONFLOW_API_KEY=your_siliconflow_api_key_here
|
||||
SILICONFLOW_TTS_MODEL=FunAudioLLM/CosyVoice2-0.5B
|
||||
SILICONFLOW_ASR_MODEL=FunAudioLLM/SenseVoiceSmall
|
||||
|
||||
# ASR
|
||||
ASR_PROVIDER=openai_compatible
|
||||
# Interim cadence and minimum audio before interim decode.
|
||||
ASR_INTERIM_INTERVAL_MS=500
|
||||
ASR_MIN_AUDIO_MS=300
|
||||
# ASR start gate: ignore micro-noise, then commit to one turn once started.
|
||||
ASR_START_MIN_SPEECH_MS=160
|
||||
# Pre-roll protects beginning phonemes.
|
||||
ASR_PRE_SPEECH_MS=240
|
||||
# Tail silence protects ending phonemes.
|
||||
ASR_FINAL_TAIL_MS=120
|
||||
|
||||
# Duplex behavior
|
||||
DUPLEX_ENABLED=true
|
||||
# DUPLEX_GREETING=Hello! How can I help you today?
|
||||
DUPLEX_SYSTEM_PROMPT=You are a helpful, friendly voice assistant. Keep your responses concise and conversational.
|
||||
|
||||
# Barge-in (user interrupting assistant)
|
||||
# Min user speech duration needed to interrupt assistant audio.
|
||||
BARGE_IN_MIN_DURATION_MS=200
|
||||
# Allowed silence during potential barge-in (ms) before reset.
|
||||
BARGE_IN_SILENCE_TOLERANCE_MS=60
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL=INFO
|
||||
# json is better for production/observability; text is easier locally.
|
||||
LOG_FORMAT=json
|
||||
|
||||
# WebSocket behavior
|
||||
INACTIVITY_TIMEOUT_SEC=60
|
||||
HEARTBEAT_INTERVAL_SEC=50
|
||||
WS_PROTOCOL_VERSION=v1
|
||||
# WS_API_KEY=replace_with_shared_secret
|
||||
WS_REQUIRE_AUTH=false
|
||||
|
||||
# CORS / ICE (JSON strings)
|
||||
CORS_ORIGINS=["http://localhost:3000","http://localhost:8080"]
|
||||
ICE_SERVERS=[{"urls":"stun:stun.l.google.com:19302"}]
|
||||
148
.gitignore
vendored
Normal file
148
.gitignore
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# poetry
|
||||
poetry.lock
|
||||
|
||||
# pdm
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# IDEs
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Project specific
|
||||
recordings/
|
||||
logs/
|
||||
running/
|
||||
31
README.md
Normal file
31
README.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# py-active-call-cc
|
||||
|
||||
Python Active-Call: real-time audio streaming with WebSocket and WebRTC.
|
||||
|
||||
This repo contains a Python 3.11+ codebase for building low-latency voice
|
||||
pipelines (capture, stream, and process audio) using WebRTC and WebSockets.
|
||||
It is currently in an early, experimental stage.
|
||||
|
||||
# Usage
|
||||
|
||||
Start the server:
|
||||
|
||||
```
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
Run the test clients:
|
||||
|
||||
```
|
||||
python examples/test_websocket.py
|
||||
```
|
||||
|
||||
```
|
||||
python mic_client.py
|
||||
```
|
||||
|
||||
## WS Protocol
|
||||
|
||||
`/ws` uses a strict `v1` JSON control protocol with binary PCM audio frames.
|
||||
|
||||
See `docs/ws_v1_schema.md`.
|
||||
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Active-Call Application Package"""
|
||||
211
app/backend_client.py
Normal file
211
app/backend_client.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Backend API client for assistant config and history persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
async def fetch_assistant_config(assistant_id: str) -> Optional[Dict[str, Any]]:
    """Fetch assistant config payload from backend API.

    Expected response shape:
    {
        "assistant": {...},
        "voice": {...} | null
    }
    """
    if not settings.backend_url:
        logger.warning("BACKEND_URL not set; skipping assistant config fetch")
        return None

    endpoint = f"{settings.backend_url.rstrip('/')}/api/assistants/{assistant_id}/config"
    request_timeout = aiohttp.ClientTimeout(total=settings.backend_timeout_sec)

    try:
        async with aiohttp.ClientSession(timeout=request_timeout) as http:
            async with http.get(endpoint) as response:
                if response.status == 404:
                    logger.warning(f"Assistant config not found: {assistant_id}")
                    return None
                response.raise_for_status()
                body = await response.json()
    except Exception as exc:
        # Best-effort fetch: any network/HTTP failure degrades to "no config".
        logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}")
        return None

    # Defend against non-object JSON bodies (e.g. a bare list or string).
    if not isinstance(body, dict):
        logger.warning("Assistant config payload is not a dict; ignoring")
        return None
    return body
|
||||
|
||||
|
||||
def _backend_base_url() -> Optional[str]:
    """Return the backend base URL without a trailing slash, or None if unset."""
    base = settings.backend_url
    return base.rstrip("/") if base else None
|
||||
|
||||
|
||||
def _timeout() -> aiohttp.ClientTimeout:
    """Build the client timeout used for every backend request."""
    total_sec = settings.backend_timeout_sec
    return aiohttp.ClientTimeout(total=total_sec)
|
||||
|
||||
|
||||
async def create_history_call_record(
    *,
    user_id: int,
    assistant_id: Optional[str],
    source: str = "debug",
) -> Optional[str]:
    """Create a call record via backend history API and return call_id."""
    base_url = _backend_base_url()
    if not base_url:
        return None

    url = f"{base_url}/api/history"
    record: Dict[str, Any] = {
        "user_id": user_id,
        "assistant_id": assistant_id,
        "source": source,
        "status": "connected",
    }

    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.post(url, json=record) as response:
                response.raise_for_status()
                data = await response.json()
    except Exception as exc:
        logger.warning(f"Failed to create history call record: {exc}")
        return None

    # Treat a missing or empty "id" in the response as a failed create.
    record_id = str((data or {}).get("id") or "")
    return record_id or None
|
||||
|
||||
|
||||
async def add_history_transcript(
    *,
    call_id: str,
    turn_index: int,
    speaker: str,
    content: str,
    start_ms: int,
    end_ms: int,
    confidence: Optional[float] = None,
    duration_ms: Optional[int] = None,
) -> bool:
    """Append a transcript segment to backend history."""
    base_url = _backend_base_url()
    if not base_url or not call_id:
        return False

    url = f"{base_url}/api/history/{call_id}/transcripts"
    segment: Dict[str, Any] = {
        "turn_index": turn_index,
        "speaker": speaker,
        "content": content,
        "confidence": confidence,
        "start_ms": start_ms,
        "end_ms": end_ms,
        "duration_ms": duration_ms,
    }

    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.post(url, json=segment) as response:
                response.raise_for_status()
    except Exception as exc:
        # Transcript persistence is best-effort; a failure must not break
        # the live call.
        logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}")
        return False
    return True
|
||||
|
||||
|
||||
async def finalize_history_call_record(
    *,
    call_id: str,
    status: str,
    duration_seconds: int,
) -> bool:
    """Finalize a call record with status and duration."""
    base_url = _backend_base_url()
    if not base_url or not call_id:
        return False

    url = f"{base_url}/api/history/{call_id}"
    update: Dict[str, Any] = {
        "status": status,
        "duration_seconds": duration_seconds,
    }

    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.put(url, json=update) as response:
                response.raise_for_status()
    except Exception as exc:
        logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
        return False
    return True
|
||||
|
||||
|
||||
async def search_knowledge_context(
    *,
    kb_id: str,
    query: str,
    n_results: int = 5,
) -> List[Dict[str, Any]]:
    """Search backend knowledge base and return retrieval results."""
    base_url = _backend_base_url()
    if not base_url or not kb_id or not query.strip():
        return []

    # Clamp to a positive int; fall back to the default on bad input.
    try:
        safe_n_results = max(1, int(n_results))
    except (TypeError, ValueError):
        safe_n_results = 5

    url = f"{base_url}/api/knowledge/search"
    request_body: Dict[str, Any] = {
        "kb_id": kb_id,
        "query": query,
        "nResults": safe_n_results,
    }

    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.post(url, json=request_body) as response:
                if response.status == 404:
                    logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
                    return []
                response.raise_for_status()
                data = await response.json()
    except Exception as exc:
        logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
        return []

    # Only a dict body with a list "results" field counts; keep dict items.
    if not isinstance(data, dict):
        return []
    hits = data.get("results", [])
    if not isinstance(hits, list):
        return []
    return [hit for hit in hits if isinstance(hit, dict)]
|
||||
|
||||
|
||||
async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]:
    """Fetch tool resource configuration from backend API."""
    base_url = _backend_base_url()
    if not base_url or not tool_id:
        return None

    url = f"{base_url}/api/tools/resources/{tool_id}"
    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as http:
            async with http.get(url) as response:
                if response.status == 404:
                    # A missing resource is an expected outcome; no log.
                    return None
                response.raise_for_status()
                data = await response.json()
    except Exception as exc:
        logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}")
        return None
    return data if isinstance(data, dict) else None
|
||||
154
app/config.py
Normal file
154
app/config.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Configuration management using Pydantic settings."""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
import json
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Each field maps to the upper-cased environment variable of the same
    name (pydantic-settings, case-insensitive), with values also read
    from a local ``.env`` file. Unknown env vars are ignored.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore"
    )

    # Server Configuration
    host: str = Field(default="0.0.0.0", description="Server host address")
    port: int = Field(default=8000, description="Server port")
    external_ip: Optional[str] = Field(default=None, description="External IP for NAT traversal")

    # Audio Configuration
    sample_rate: int = Field(default=16000, description="Audio sample rate in Hz")
    chunk_size_ms: int = Field(default=20, description="Audio chunk duration in milliseconds")
    default_codec: str = Field(default="pcm", description="Default audio codec")
    max_audio_buffer_seconds: int = Field(
        default=30,
        description="Maximum buffered user audio duration kept in memory for current turn"
    )

    # VAD Configuration
    vad_type: str = Field(default="silero", description="VAD algorithm type")
    vad_model_path: str = Field(default="data/vad/silero_vad.onnx", description="Path to VAD model")
    vad_threshold: float = Field(default=0.5, description="VAD detection threshold")
    vad_min_speech_duration_ms: int = Field(default=100, description="Minimum speech duration in milliseconds")
    vad_eou_threshold_ms: int = Field(default=800, description="End of utterance (silence) threshold in milliseconds")

    # OpenAI / LLM Configuration
    openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key")
    openai_api_url: Optional[str] = Field(default=None, description="OpenAI API base URL (for Azure/compatible)")
    llm_model: str = Field(default="gpt-4o-mini", description="LLM model name")
    llm_temperature: float = Field(default=0.7, description="LLM temperature for response generation")

    # TTS Configuration
    tts_provider: str = Field(
        default="openai_compatible",
        description="TTS provider (edge, openai_compatible; siliconflow alias supported)"
    )
    tts_voice: str = Field(default="anna", description="TTS voice name")
    tts_speed: float = Field(default=1.0, description="TTS speech speed multiplier")

    # SiliconFlow Configuration
    siliconflow_api_key: Optional[str] = Field(default=None, description="SiliconFlow API key")
    siliconflow_tts_model: str = Field(default="FunAudioLLM/CosyVoice2-0.5B", description="SiliconFlow TTS model")

    # ASR Configuration
    asr_provider: str = Field(
        default="openai_compatible",
        description="ASR provider (openai_compatible, buffered; siliconflow alias supported)"
    )
    siliconflow_asr_model: str = Field(default="FunAudioLLM/SenseVoiceSmall", description="SiliconFlow ASR model")
    asr_interim_interval_ms: int = Field(default=500, description="Interval for interim ASR results in ms")
    asr_min_audio_ms: int = Field(default=300, description="Minimum audio duration before first ASR result")
    asr_start_min_speech_ms: int = Field(
        default=160,
        description="Minimum continuous speech duration before ASR capture starts"
    )
    asr_pre_speech_ms: int = Field(
        default=240,
        description="Audio context (ms) prepended before detected speech to avoid clipping first phoneme"
    )
    asr_final_tail_ms: int = Field(
        default=120,
        description="Silence tail (ms) appended before final ASR decode to protect utterance ending"
    )

    # Duplex Pipeline Configuration
    duplex_enabled: bool = Field(default=True, description="Enable duplex voice pipeline")
    duplex_greeting: Optional[str] = Field(default=None, description="Optional greeting message")
    duplex_system_prompt: Optional[str] = Field(
        default="You are a helpful, friendly voice assistant. Keep your responses concise and conversational.",
        description="System prompt for LLM"
    )

    # Barge-in (interruption) Configuration
    barge_in_min_duration_ms: int = Field(
        default=200,
        description="Minimum speech duration (ms) required to trigger barge-in. Lower=more sensitive."
    )
    barge_in_silence_tolerance_ms: int = Field(
        default=60,
        description="How much silence (ms) is tolerated during potential barge-in before reset"
    )

    # Logging
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(default="json", description="Log format (json or text)")

    # CORS
    # Stored as a raw JSON string; parsed lazily by cors_origins_list below.
    cors_origins: str = Field(
        default='["http://localhost:3000", "http://localhost:8080"]',
        description="CORS allowed origins"
    )

    # ICE Servers (WebRTC)
    # Stored as a raw JSON string; parsed lazily by ice_servers_list below.
    ice_servers: str = Field(
        default='[{"urls": "stun:stun.l.google.com:19302"}]',
        description="ICE servers configuration"
    )

    # WebSocket heartbeat and inactivity
    inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
    heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
    ws_protocol_version: str = Field(default="v1", description="Public WS protocol version")
    ws_api_key: Optional[str] = Field(default=None, description="Optional API key required for WS hello auth")
    ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set")

    # Backend bridge configuration (for call/transcript persistence)
    backend_url: Optional[str] = Field(default=None, description="Backend API base URL (e.g. http://localhost:8787)")
    backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds")
    history_default_user_id: int = Field(default=1, description="Fallback user_id for history records")

    @property
    def chunk_size_bytes(self) -> int:
        """Calculate chunk size in bytes based on sample rate and duration."""
        # 16-bit (2 bytes) per sample, mono channel
        return int(self.sample_rate * 2 * (self.chunk_size_ms / 1000.0))

    @property
    def cors_origins_list(self) -> List[str]:
        """Parse CORS origins from JSON string."""
        # Malformed JSON falls back to the same defaults as the field.
        try:
            return json.loads(self.cors_origins)
        except json.JSONDecodeError:
            return ["http://localhost:3000", "http://localhost:8080"]

    @property
    def ice_servers_list(self) -> List[dict]:
        """Parse ICE servers from JSON string."""
        # Malformed JSON falls back to the same default as the field.
        try:
            return json.loads(self.ice_servers)
        except json.JSONDecodeError:
            return [{"urls": "stun:stun.l.google.com:19302"}]
|
||||
|
||||
|
||||
# Global settings instance
# Constructed at import time; reads the environment and .env once.
settings = Settings()


def get_settings() -> Settings:
    """Get application settings instance.

    Returns the module-level singleton; useful as a FastAPI dependency.
    """
    return settings
|
||||
396
app/main.py
Normal file
396
app/main.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""FastAPI application with WebSocket and WebRTC endpoints."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from loguru import logger
|
||||
|
||||
# Try to import aiortc (optional for WebRTC functionality)
|
||||
try:
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
AIORTC_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIORTC_AVAILABLE = False
|
||||
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
|
||||
|
||||
from app.config import settings
|
||||
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
|
||||
from core.session import Session
|
||||
from processors.tracks import Resampled16kTrack
|
||||
from core.events import get_event_bus, reset_event_bus
|
||||
from models.ws_v1 import ev
|
||||
|
||||
# Check interval for heartbeat/timeout (seconds)
|
||||
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
async def heartbeat_and_timeout_task(
    transport: BaseTransport,
    session: Session,
    session_id: str,
    last_received_at: List[float],
    last_heartbeat_at: List[float],
    inactivity_timeout_sec: int,
    heartbeat_interval_sec: int,
) -> None:
    """
    Background task: send heartBeat every ~heartbeat_interval_sec and close
    connection if no message from client for inactivity_timeout_sec.

    Args:
        transport: Transport used to emit the heartbeat event.
        session: Session to clean up when the inactivity timeout fires.
        session_id: Used only for log messages.
        last_received_at: Single-element list used as a mutable cell; the
            receive loop overwrites index 0 with time.monotonic() on every
            inbound message.
        last_heartbeat_at: Single-element mutable cell holding the monotonic
            time of the last heartbeat sent by this task.
        inactivity_timeout_sec: Seconds without inbound traffic before the
            session is closed.
        heartbeat_interval_sec: Minimum seconds between heartbeat events.
    """
    while True:
        # Deadlines are only checked every _HEARTBEAT_CHECK_INTERVAL_SEC
        # seconds, so both timeouts are approximate by up to that amount.
        await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC)
        if transport.is_closed:
            break
        # Monotonic clock: immune to wall-clock adjustments.
        now = time.monotonic()
        if now - last_received_at[0] > inactivity_timeout_sec:
            logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
            await session.cleanup()
            break
        if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
            try:
                # Send a copy of the event dict so the transport cannot
                # mutate a shared payload.
                await transport.send_event({
                    **ev("heartbeat"),
                })
                last_heartbeat_at[0] = now
            except Exception as e:
                # A failed send means the connection is effectively gone;
                # stop the task and let the endpoint's cleanup run.
                logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
                break
|
||||
|
||||
|
||||
# Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0")
# Bundled demo client page, served at "/" and "/client".
_WEB_CLIENT_PATH = Path(__file__).resolve().parent.parent / "examples" / "web_client.html"

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Active sessions storage
# Maps session_id (uuid4 string) -> Session for the lifetime of a connection.
active_sessions: Dict[str, Session] = {}

# Configure logging
# Drop loguru's default stderr sink, then install a rotating file sink and a
# plain stdout sink at the configured level.
logger.remove()
logger.add(
    "./logs/active_call_{time}.log",
    rotation="1 day",
    retention="7 days",
    level=settings.log_level,
    format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
)
logger.add(
    lambda msg: print(msg, end=""),
    level=settings.log_level,
    format="{time:HH:mm:ss} | {level: <8} | {message}"
)
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Health check reporting liveness and the active session count."""
    session_count = len(active_sessions)
    return {"status": "healthy", "sessions": session_count}
|
||||
|
||||
|
||||
@app.get("/")
async def web_client_root():
    """Serve the bundled web client page, or 404 if it is missing."""
    if _WEB_CLIENT_PATH.exists():
        return FileResponse(_WEB_CLIENT_PATH)
    raise HTTPException(status_code=404, detail="Web client not found")
|
||||
|
||||
|
||||
@app.get("/client")
async def web_client_alias():
    """Alias route serving the same web client page as "/"."""
    if _WEB_CLIENT_PATH.exists():
        return FileResponse(_WEB_CLIENT_PATH)
    raise HTTPException(status_code=404, detail="Web client not found")
|
||||
|
||||
|
||||
|
||||
|
||||
@app.get("/iceservers")
async def get_ice_servers():
    """Return the configured ICE servers for WebRTC clients."""
    ice_config = settings.ice_servers_list
    return ice_config
|
||||
|
||||
|
||||
@app.get("/call/lists")
async def list_calls():
    """List all active calls with their state and creation time."""
    calls = []
    for sid, sess in active_sessions.items():
        calls.append({
            "id": sid,
            "state": sess.state,
            "created_at": sess.created_at,
        })
    return {"calls": calls}
|
||||
|
||||
|
||||
@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
    """Kill a specific active call.

    Args:
        session_id: Identifier of the session to terminate.

    Returns:
        True once the session has been removed and cleaned up.

    Raises:
        HTTPException: 404 if no session with that id is registered.
    """
    # Atomically remove the session first: this closes the race between the
    # membership check and deletion, and guarantees the entry is gone even
    # if cleanup() raises.
    session = active_sessions.pop(session_id, None)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    await session.cleanup()
    return True
|
||||
|
||||
|
||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for raw audio streaming.

    Accepts mixed text/binary frames:
    - Text frames: JSON commands
    - Binary frames: PCM audio data (16kHz, 16-bit, mono)
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create transport and session
    transport = SocketTransport(websocket)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebSocket connection established: {session_id}")

    # Single-element lists act as mutable cells shared with the background
    # heartbeat/timeout task.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    try:
        # Receive loop
        while True:
            # Low-level receive: returns a raw ASGI message dict so a
            # single loop can handle both text and binary frames.
            message = await websocket.receive()
            message_type = message.get("type")

            if message_type == "websocket.disconnect":
                logger.info(f"WebSocket disconnected: {session_id}")
                break

            # Any inbound frame counts as activity for the timeout task.
            last_received_at[0] = time.monotonic()

            # Handle binary audio data
            if "bytes" in message:
                await session.handle_audio(message["bytes"])

            # Handle text commands
            elif "text" in message:
                await session.handle_text(message["text"])

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")

    except Exception as e:
        # NOTE(review): loguru does not honor the stdlib `exc_info` kwarg;
        # the traceback is likely not logged here — confirm and consider
        # logger.exception(...) instead.
        logger.error(f"WebSocket error: {e}", exc_info=True)

    finally:
        # Stop the heartbeat task before tearing down the session.
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        # Cleanup session
        if session_id in active_sessions:
            await session.cleanup()
            del active_sessions[session_id]

        logger.info(f"Session {session_id} removed")
|
||||
|
||||
|
||||
@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
    """
    WebRTC endpoint for WebRTC audio streaming.

    Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
    """
    # Check if aiortc is available
    if not AIORTC_AVAILABLE:
        # 1011 = internal error; reject before accepting any signaling.
        await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
        logger.warning("WebRTC connection attempted but aiortc is not available")
        return
    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create WebRTC peer connection
    pc = RTCPeerConnection()

    # Create transport and session
    transport = WebRtcTransport(websocket, pc)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebRTC connection established: {session_id}")

    # Mutable cells shared with the background heartbeat/timeout task.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    # Track handler for incoming audio
    @pc.on("track")
    def on_track(track):
        logger.info(f"Track received: {track.kind}")

        if track.kind == "audio":
            # Wrap track with resampler
            wrapped_track = Resampled16kTrack(track)

            # Create task to pull audio from track
            async def pull_audio():
                try:
                    while True:
                        frame = await wrapped_track.recv()
                        # Convert frame to bytes
                        pcm_bytes = frame.to_ndarray().tobytes()
                        # Feed to session
                        await session.handle_audio(pcm_bytes)
                except Exception as e:
                    logger.error(f"Error pulling audio from track: {e}")

            # Fire-and-forget pump; it ends when recv() raises (track ended).
            asyncio.create_task(pull_audio())

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"Connection state: {pc.connectionState}")
        if pc.connectionState == "failed" or pc.connectionState == "closed":
            await session.cleanup()

    try:
        # Signaling loop
        while True:
            # Low-level receive: raw ASGI message dict.
            # NOTE(review): a "websocket.disconnect" message has no "text"
            # key and is skipped here; the loop then exits via the exception
            # raised by the next receive() — confirm this is intended.
            message = await websocket.receive()

            if "text" not in message:
                continue

            last_received_at[0] = time.monotonic()
            data = json.loads(message["text"])

            # Handle SDP offer/answer
            if "sdp" in data and "type" in data:
                logger.info(f"Received SDP {data['type']}")

                # Set remote description
                offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
                await pc.setRemoteDescription(offer)

                # Create and set local description
                if data["type"] == "offer":
                    answer = await pc.createAnswer()
                    await pc.setLocalDescription(answer)

                    # Send answer back
                    await websocket.send_text(json.dumps({
                        "event": "answer",
                        "trackId": session_id,
                        "timestamp": int(asyncio.get_event_loop().time() * 1000),
                        "sdp": pc.localDescription.sdp
                    }))

                    logger.info(f"Sent SDP answer")

            else:
                # Handle other commands
                await session.handle_text(message["text"])

    except WebSocketDisconnect:
        logger.info(f"WebRTC WebSocket disconnected: {session_id}")

    except Exception as e:
        # NOTE(review): loguru does not honor the stdlib `exc_info` kwarg;
        # the traceback is likely not logged — consider logger.exception(...).
        logger.error(f"WebRTC error: {e}", exc_info=True)

    finally:
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        # Cleanup
        await pc.close()
        if session_id in active_sessions:
            await session.cleanup()
            del active_sessions[session_id]

        logger.info(f"WebRTC session {session_id} removed")
|
||||
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """Log the effective runtime configuration when the application starts."""
    # Emit a startup banner with the key settings for quick diagnosis.
    banner = (
        "Starting Python Active-Call server",
        f"Server: {settings.host}:{settings.port}",
        f"Sample rate: {settings.sample_rate} Hz",
        f"VAD model: {settings.vad_model_path}",
    )
    for line in banner:
        logger.info(line)
|
||||
|
||||
|
||||
@app.on_event("shutdown")
async def shutdown_event():
    """Run on application shutdown.

    Cleans up every live session, then closes and resets the global event
    bus so a subsequent startup gets a fresh instance.
    """
    logger.info("Shutting down Python Active-Call server")

    # Cleanup all sessions. Iterate over a snapshot: Session.cleanup() can
    # trigger transport callbacks that remove entries from active_sessions,
    # and mutating the dict while iterating it would raise RuntimeError.
    for session in list(active_sessions.values()):
        await session.cleanup()

    # Close event bus and drop the global instance.
    event_bus = get_event_bus()
    await event_bus.close()
    reset_event_bus()

    logger.info("Server shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import os

    import uvicorn

    # Ensure the log directory exists before uvicorn starts writing to it.
    os.makedirs("logs", exist_ok=True)

    # Launch the ASGI app with auto-reload for local development.
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=True,
        log_level=settings.log_level.lower(),
    )
|
||||
20
core/__init__.py
Normal file
20
core/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Core Components Package"""
|
||||
|
||||
from core.events import EventBus, get_event_bus
|
||||
from core.transports import BaseTransport, SocketTransport, WebRtcTransport
|
||||
from core.session import Session
|
||||
from core.conversation import ConversationManager, ConversationState, ConversationTurn
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
|
||||
__all__ = [
|
||||
"EventBus",
|
||||
"get_event_bus",
|
||||
"BaseTransport",
|
||||
"SocketTransport",
|
||||
"WebRtcTransport",
|
||||
"Session",
|
||||
"ConversationManager",
|
||||
"ConversationState",
|
||||
"ConversationTurn",
|
||||
"DuplexPipeline",
|
||||
]
|
||||
279
core/conversation.py
Normal file
279
core/conversation.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Conversation management for voice AI.
|
||||
|
||||
Handles conversation context, turn-taking, and message history
|
||||
for multi-turn voice conversations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
|
||||
from services.base import LLMMessage
|
||||
|
||||
|
||||
class ConversationState(Enum):
    """Lifecycle states of a voice conversation."""

    IDLE = "idle"                # Waiting for user input
    LISTENING = "listening"      # User is speaking
    PROCESSING = "processing"    # Processing user input (LLM)
    SPEAKING = "speaking"        # Bot is speaking
    INTERRUPTED = "interrupted"  # Bot was interrupted
|
||||
|
||||
|
||||
def _loop_time() -> float:
    """Return the current event-loop time for turn timestamps.

    Avoids a bare ``asyncio.get_event_loop()`` call, which is deprecated
    when no loop is running. The fallback uses ``time.monotonic()``, the
    same clock asyncio's default loop uses, so values stay comparable.
    """
    try:
        return asyncio.get_running_loop().time()
    except RuntimeError:
        # No running loop (e.g. object constructed synchronously in tests).
        import time
        return time.monotonic()


@dataclass
class ConversationTurn:
    """A single turn in the conversation."""

    role: str  # "user" or "assistant"
    text: str
    audio_duration_ms: Optional[int] = None
    # Captured at construction time on the event-loop (monotonic) clock.
    timestamp: float = field(default_factory=_loop_time)
    was_interrupted: bool = False
|
||||
|
||||
|
||||
class ConversationManager:
    """
    Manages conversation state and history.

    Provides:
    - Message history for LLM context
    - Turn management
    - State tracking
    - Event callbacks for state changes
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_history: int = 20,
        greeting: Optional[str] = None
    ):
        """
        Initialize conversation manager.

        Args:
            system_prompt: System prompt for LLM (a voice-assistant default
                is used when omitted)
            max_history: Maximum number of turns included in LLM context
            greeting: Optional greeting message when conversation starts
        """
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation. "
            "If you don't understand something, ask for clarification."
        )
        self.max_history = max_history
        self.greeting = greeting

        # State
        self.state = ConversationState.IDLE
        self.turns: List[ConversationTurn] = []

        # Callbacks
        self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
        self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []

        # Current (in-progress) turn tracking
        self._current_user_text: str = ""
        self._current_assistant_text: str = ""

        logger.info("ConversationManager initialized")

    def on_state_change(
        self,
        callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
    ) -> None:
        """Register callback for state changes (receives old_state, new_state)."""
        self._state_callbacks.append(callback)

    def on_turn_complete(
        self,
        callback: Callable[[ConversationTurn], Awaitable[None]]
    ) -> None:
        """Register callback for turn completion."""
        self._turn_callbacks.append(callback)

    async def set_state(self, new_state: ConversationState) -> None:
        """Set conversation state and notify listeners (no-op if unchanged)."""
        if new_state != self.state:
            old_state = self.state
            self.state = new_state
            logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")

            for callback in self._state_callbacks:
                try:
                    await callback(old_state, new_state)
                except Exception as e:
                    logger.error(f"State callback error: {e}")

    async def _record_turn(self, turn: ConversationTurn) -> None:
        """Append a turn to history and notify turn-complete listeners.

        Shared by end_user_turn / end_assistant_turn / add_assistant_turn so
        callback dispatch and its error handling live in one place.
        """
        self.turns.append(turn)
        for callback in self._turn_callbacks:
            try:
                await callback(turn)
            except Exception as e:
                logger.error(f"Turn callback error: {e}")

    def get_messages(self) -> List[LLMMessage]:
        """
        Get conversation history as LLM messages.

        Returns:
            List of LLMMessage objects: system prompt, the most recent
            max_history turns, then any in-progress user text.
        """
        messages = [LLMMessage(role="system", content=self.system_prompt)]

        # Add conversation history (bounded window)
        for turn in self.turns[-self.max_history:]:
            messages.append(LLMMessage(role=turn.role, content=turn.text))

        # Add current user text if any
        if self._current_user_text:
            messages.append(LLMMessage(role="user", content=self._current_user_text))

        return messages

    async def start_user_turn(self) -> None:
        """Signal that user has started speaking."""
        await self.set_state(ConversationState.LISTENING)
        self._current_user_text = ""

    async def update_user_text(self, text: str, is_final: bool = False) -> None:
        """
        Update current user text (from ASR).

        Args:
            text: Transcribed text
            is_final: Whether this is the final transcript (currently
                informational only; the text is stored either way)
        """
        self._current_user_text = text

    async def end_user_turn(self, text: str) -> None:
        """
        End user turn and add to history, then move to PROCESSING.

        Args:
            text: Final user text (blank text is not recorded)
        """
        if text.strip():
            turn = ConversationTurn(role="user", text=text.strip())
            await self._record_turn(turn)
            logger.info(f"User: {text[:50]}...")

        self._current_user_text = ""
        await self.set_state(ConversationState.PROCESSING)

    async def start_assistant_turn(self) -> None:
        """Signal that assistant has started speaking."""
        await self.set_state(ConversationState.SPEAKING)
        self._current_assistant_text = ""

    async def update_assistant_text(self, text: str) -> None:
        """
        Update current assistant text (streaming).

        Args:
            text: Text chunk from LLM (appended to the in-progress reply)
        """
        self._current_assistant_text += text

    async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
        """
        End assistant turn and add to history.

        Args:
            was_interrupted: Whether the turn was interrupted by user
        """
        text = self._current_assistant_text.strip()
        if text:
            turn = ConversationTurn(
                role="assistant",
                text=text,
                was_interrupted=was_interrupted
            )
            await self._record_turn(turn)

            status = " (interrupted)" if was_interrupted else ""
            logger.info(f"Assistant{status}: {text[:50]}...")

        self._current_assistant_text = ""

        if was_interrupted:
            # A new user turn may already be active (LISTENING) when interrupted.
            # Avoid overriding it back to INTERRUPTED, which can stall EOU flow.
            if self.state != ConversationState.LISTENING:
                await self.set_state(ConversationState.INTERRUPTED)
        else:
            await self.set_state(ConversationState.IDLE)

    async def add_assistant_turn(self, text: str, was_interrupted: bool = False) -> None:
        """Append an assistant turn directly without mutating conversation state."""
        content = text.strip()
        if not content:
            return

        turn = ConversationTurn(
            role="assistant",
            text=content,
            was_interrupted=was_interrupted,
        )
        await self._record_turn(turn)
        logger.info(f"Assistant (injected): {content[:50]}...")

    async def interrupt(self) -> None:
        """Handle interruption (barge-in): finalize the spoken turn as interrupted."""
        if self.state == ConversationState.SPEAKING:
            await self.end_assistant_turn(was_interrupted=True)

    def reset(self) -> None:
        """Reset conversation history and return to IDLE."""
        self.turns = []
        self._current_user_text = ""
        self._current_assistant_text = ""
        self.state = ConversationState.IDLE
        logger.info("Conversation reset")

    @property
    def turn_count(self) -> int:
        """Get number of turns in conversation."""
        return len(self.turns)

    @property
    def last_user_text(self) -> Optional[str]:
        """Get last user text, or None if the user has not spoken yet."""
        for turn in reversed(self.turns):
            if turn.role == "user":
                return turn.text
        return None

    @property
    def last_assistant_text(self) -> Optional[str]:
        """Get last assistant text, or None if the assistant has not spoken yet."""
        for turn in reversed(self.turns):
            if turn.role == "assistant":
                return turn.text
        return None

    def get_context_summary(self) -> Dict[str, Any]:
        """Get a summary of conversation context for logging/inspection."""
        return {
            "state": self.state.value,
            "turn_count": self.turn_count,
            "last_user": self.last_user_text,
            "last_assistant": self.last_assistant_text,
            "current_user": self._current_user_text or None,
            "current_assistant": self._current_assistant_text or None
        }
|
||||
1507
core/duplex_pipeline.py
Normal file
1507
core/duplex_pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
134
core/events.py
Normal file
134
core/events.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Event bus for pub/sub communication between components."""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Dict, List, Any, Optional
|
||||
from collections import defaultdict
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class EventBus:
    """
    Async event bus for pub/sub communication.

    Similar to the original Rust implementation's broadcast channel.
    Components can subscribe to specific event types and receive events asynchronously.
    """

    def __init__(self):
        """Initialize the event bus."""
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        # NOTE(review): reserved for callers that need to serialize bus
        # operations; not used by any method in this class at the moment.
        self._lock = asyncio.Lock()
        self._running = True

    def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Subscribe to an event type.

        Args:
            event_type: Type of event to subscribe to (e.g., "speaking", "silence")
            callback: Sync or async callback that receives the event data dict
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
            return

        self._subscribers[event_type].append(callback)
        logger.debug(f"Subscribed to event type: {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Unsubscribe from an event type.

        Args:
            event_type: Type of event to unsubscribe from
            callback: Callback function to remove
        """
        # Use .get() so probing an unknown event type does not create an
        # empty defaultdict entry as a side effect.
        callbacks = self._subscribers.get(event_type)
        if callbacks and callback in callbacks:
            callbacks.remove(callback)
            logger.debug(f"Unsubscribed from event type: {event_type}")

    async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribers.

        Args:
            event_type: Type of event to publish
            event_data: Event data to send to subscribers
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
            return

        # Get subscribers for this event type (without creating an entry)
        subscribers = self._subscribers.get(event_type, [])

        if not subscribers:
            logger.debug(f"No subscribers for event type: {event_type}")
            return

        # Notify all subscribers concurrently
        tasks = []
        for callback in subscribers:
            try:
                # Create task for each subscriber
                task = asyncio.create_task(self._call_subscriber(callback, event_data))
                tasks.append(task)
            except Exception as e:
                logger.error(f"Error creating task for subscriber: {e}")

        # Wait for all subscribers to complete; return_exceptions keeps one
        # failing callback from cancelling the rest.
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")

    async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
        """
        Call a subscriber callback with error handling.

        Args:
            callback: Subscriber callback function
            event_data: Event data to pass to callback
        """
        try:
            # Support both coroutine functions and plain callables
            if asyncio.iscoroutinefunction(callback):
                await callback(event_data)
            else:
                callback(event_data)
        except Exception as e:
            logger.error(f"Error in subscriber callback: {e}", exc_info=True)

    async def close(self) -> None:
        """Close the event bus and stop processing events."""
        self._running = False
        self._subscribers.clear()
        logger.info("Event bus closed")

    @property
    def is_running(self) -> bool:
        """Check if the event bus is running."""
        return self._running
|
||||
|
||||
|
||||
# Lazily-created global event bus instance
_event_bus: Optional[EventBus] = None


def get_event_bus() -> EventBus:
    """
    Return the process-wide event bus, creating it on first use.

    Returns:
        EventBus instance
    """
    global _event_bus
    if _event_bus is None:
        _event_bus = EventBus()
    return _event_bus
|
||||
|
||||
|
||||
def reset_event_bus() -> None:
    """Reset the global event bus (mainly for testing)."""
    # Drop the cached instance; the next get_event_bus() call recreates it.
    global _event_bus
    _event_bus = None
|
||||
648
core/session.py
Normal file
648
core/session.py
Normal file
@@ -0,0 +1,648 @@
|
||||
"""Session management for active calls."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any, List
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import (
|
||||
create_history_call_record,
|
||||
add_history_transcript,
|
||||
finalize_history_call_record,
|
||||
)
|
||||
from core.transports import BaseTransport
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from core.conversation import ConversationTurn
|
||||
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
|
||||
from app.config import settings
|
||||
from services.base import LLMMessage
|
||||
from models.ws_v1 import (
|
||||
parse_client_message,
|
||||
ev,
|
||||
HelloMessage,
|
||||
SessionStartMessage,
|
||||
SessionStopMessage,
|
||||
InputTextMessage,
|
||||
ResponseCancelMessage,
|
||||
ToolCallResultsMessage,
|
||||
)
|
||||
|
||||
|
||||
class WsSessionState(str, Enum):
    """Protocol state machine for WS sessions."""

    WAIT_HELLO = "wait_hello"  # Waiting for the client hello handshake
    WAIT_START = "wait_start"  # Hello accepted; waiting for session.start
    ACTIVE = "active"          # Session running; audio and commands accepted
    STOPPED = "stopped"        # Session terminated
|
||||
|
||||
|
||||
class Session:
|
||||
"""
|
||||
Manages a single call session.
|
||||
|
||||
Handles command routing, audio processing, and session lifecycle.
|
||||
Uses full duplex voice conversation pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None):
|
||||
"""
|
||||
Initialize session.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier
|
||||
transport: Transport instance for communication
|
||||
use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
|
||||
"""
|
||||
self.id = session_id
|
||||
self.transport = transport
|
||||
self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
|
||||
|
||||
self.pipeline = DuplexPipeline(
|
||||
transport=transport,
|
||||
session_id=session_id,
|
||||
system_prompt=settings.duplex_system_prompt,
|
||||
greeting=settings.duplex_greeting
|
||||
)
|
||||
|
||||
# Session state
|
||||
self.created_at = None
|
||||
self.state = "created" # Legacy call state for /call/lists
|
||||
self.ws_state = WsSessionState.WAIT_HELLO
|
||||
self._pipeline_started = False
|
||||
self.protocol_version: Optional[str] = None
|
||||
self.authenticated: bool = False
|
||||
|
||||
# Track IDs
|
||||
self.current_track_id: Optional[str] = str(uuid.uuid4())
|
||||
self._history_call_id: Optional[str] = None
|
||||
self._history_turn_index: int = 0
|
||||
self._history_call_started_mono: Optional[float] = None
|
||||
self._history_finalized: bool = False
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
self._cleaned_up = False
|
||||
self.workflow_runner: Optional[WorkflowRunner] = None
|
||||
self._workflow_last_user_text: str = ""
|
||||
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
|
||||
|
||||
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
|
||||
|
||||
logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
||||
|
||||
async def handle_text(self, text_data: str) -> None:
|
||||
"""
|
||||
Handle incoming text data (WS v1 JSON control messages).
|
||||
|
||||
Args:
|
||||
text_data: JSON text data
|
||||
"""
|
||||
try:
|
||||
data = json.loads(text_data)
|
||||
message = parse_client_message(data)
|
||||
await self._handle_v1_message(message)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Session {self.id} JSON decode error: {e}")
|
||||
await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Session {self.id} command parse error: {e}")
|
||||
await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
|
||||
await self._send_error("server", f"Internal error: {e}", "server.internal")
|
||||
|
||||
async def handle_audio(self, audio_bytes: bytes) -> None:
|
||||
"""
|
||||
Handle incoming audio data.
|
||||
|
||||
Args:
|
||||
audio_bytes: PCM audio data
|
||||
"""
|
||||
if self.ws_state != WsSessionState.ACTIVE:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Audio received before session.start",
|
||||
"protocol.order",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await self.pipeline.process_audio(audio_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||
|
||||
async def _handle_v1_message(self, message: Any) -> None:
|
||||
"""Route validated WS v1 message to handlers."""
|
||||
msg_type = message.type
|
||||
logger.info(f"Session {self.id} received message: {msg_type}")
|
||||
|
||||
if isinstance(message, HelloMessage):
|
||||
await self._handle_hello(message)
|
||||
return
|
||||
|
||||
# All messages below require hello handshake first
|
||||
if self.ws_state == WsSessionState.WAIT_HELLO:
|
||||
await self._send_error(
|
||||
"client",
|
||||
"Expected hello message first",
|
||||
"protocol.order",
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(message, SessionStartMessage):
|
||||
await self._handle_session_start(message)
|
||||
return
|
||||
|
||||
# All messages below require active session
|
||||
if self.ws_state != WsSessionState.ACTIVE:
|
||||
await self._send_error(
|
||||
"client",
|
||||
f"Message '{msg_type}' requires active session",
|
||||
"protocol.order",
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(message, InputTextMessage):
|
||||
await self.pipeline.process_text(message.text)
|
||||
elif isinstance(message, ResponseCancelMessage):
|
||||
if message.graceful:
|
||||
logger.info(f"Session {self.id} graceful response.cancel")
|
||||
else:
|
||||
await self.pipeline.interrupt()
|
||||
elif isinstance(message, ToolCallResultsMessage):
|
||||
await self.pipeline.handle_tool_call_results(message.results)
|
||||
elif isinstance(message, SessionStopMessage):
|
||||
await self._handle_session_stop(message.reason)
|
||||
else:
|
||||
await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
|
||||
|
||||
async def _handle_hello(self, message: HelloMessage) -> None:
|
||||
"""Handle initial hello/auth/version negotiation."""
|
||||
if self.ws_state != WsSessionState.WAIT_HELLO:
|
||||
await self._send_error("client", "Duplicate hello", "protocol.order")
|
||||
return
|
||||
|
||||
if message.version != settings.ws_protocol_version:
|
||||
await self._send_error(
|
||||
"client",
|
||||
f"Unsupported protocol version '{message.version}'",
|
||||
"protocol.version_unsupported",
|
||||
)
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
auth_payload = message.auth or {}
|
||||
api_key = auth_payload.get("apiKey")
|
||||
jwt = auth_payload.get("jwt")
|
||||
|
||||
if settings.ws_api_key:
|
||||
if api_key != settings.ws_api_key:
|
||||
await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
elif settings.ws_require_auth and not (api_key or jwt):
|
||||
await self._send_error("auth", "Authentication required", "auth.required")
|
||||
await self.transport.close()
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
return
|
||||
|
||||
self.authenticated = True
|
||||
self.protocol_version = message.version
|
||||
self.ws_state = WsSessionState.WAIT_START
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"hello.ack",
|
||||
sessionId=self.id,
|
||||
version=self.protocol_version,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_session_start(self, message: SessionStartMessage) -> None:
|
||||
"""Handle explicit session start after successful hello."""
|
||||
if self.ws_state != WsSessionState.WAIT_START:
|
||||
await self._send_error("client", "Duplicate session.start", "protocol.order")
|
||||
return
|
||||
|
||||
metadata = message.metadata or {}
|
||||
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
|
||||
|
||||
# Create history call record early so later turn callbacks can append transcripts.
|
||||
await self._start_history_bridge(metadata)
|
||||
|
||||
# Apply runtime service/prompt overrides from backend if provided
|
||||
self.pipeline.apply_runtime_overrides(metadata)
|
||||
|
||||
# Start duplex pipeline
|
||||
if not self._pipeline_started:
|
||||
await self.pipeline.start()
|
||||
self._pipeline_started = True
|
||||
logger.info(f"Session {self.id} duplex pipeline started")
|
||||
|
||||
self.state = "accepted"
|
||||
self.ws_state = WsSessionState.ACTIVE
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"session.started",
|
||||
sessionId=self.id,
|
||||
trackId=self.current_track_id,
|
||||
audio=message.audio or {},
|
||||
)
|
||||
)
|
||||
if self.workflow_runner and self._workflow_initial_node:
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"workflow.started",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
workflowName=self.workflow_runner.name,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
)
|
||||
)
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"workflow.node.entered",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
nodeName=self._workflow_initial_node.name,
|
||||
nodeType=self._workflow_initial_node.node_type,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_session_stop(self, reason: Optional[str]) -> None:
|
||||
"""Handle session stop."""
|
||||
if self.ws_state == WsSessionState.STOPPED:
|
||||
return
|
||||
|
||||
stop_reason = reason or "client_requested"
|
||||
self.state = "hungup"
|
||||
self.ws_state = WsSessionState.STOPPED
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"session.stopped",
|
||||
sessionId=self.id,
|
||||
reason=stop_reason,
|
||||
)
|
||||
)
|
||||
await self._finalize_history(status="connected")
|
||||
await self.transport.close()
|
||||
|
||||
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
|
||||
"""
|
||||
Send error event to client.
|
||||
|
||||
Args:
|
||||
sender: Component that generated the error
|
||||
error_message: Error message
|
||||
code: Machine-readable error code
|
||||
"""
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"error",
|
||||
sender=sender,
|
||||
code=code,
|
||||
message=error_message,
|
||||
trackId=self.current_track_id,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup session resources."""
|
||||
async with self._cleanup_lock:
|
||||
if self._cleaned_up:
|
||||
return
|
||||
|
||||
self._cleaned_up = True
|
||||
logger.info(f"Session {self.id} cleaning up")
|
||||
await self._finalize_history(status="connected")
|
||||
await self.pipeline.cleanup()
|
||||
await self.transport.close()
|
||||
|
||||
async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
|
||||
"""Initialize backend history call record for this session."""
|
||||
if self._history_call_id:
|
||||
return
|
||||
|
||||
history_meta: Dict[str, Any] = {}
|
||||
if isinstance(metadata.get("history"), dict):
|
||||
history_meta = metadata["history"]
|
||||
|
||||
raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id))
|
||||
try:
|
||||
user_id = int(raw_user_id)
|
||||
except (TypeError, ValueError):
|
||||
user_id = settings.history_default_user_id
|
||||
|
||||
assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
|
||||
source = str(history_meta.get("source", metadata.get("source", "debug")))
|
||||
|
||||
call_id = await create_history_call_record(
|
||||
user_id=user_id,
|
||||
assistant_id=str(assistant_id) if assistant_id else None,
|
||||
source=source,
|
||||
)
|
||||
if not call_id:
|
||||
return
|
||||
|
||||
self._history_call_id = call_id
|
||||
self._history_call_started_mono = time.monotonic()
|
||||
self._history_turn_index = 0
|
||||
self._history_finalized = False
|
||||
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
|
||||
|
||||
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
|
||||
"""Process workflow transitions and persist completed turns to history."""
|
||||
if turn.text and turn.text.strip():
|
||||
role = (turn.role or "").lower()
|
||||
if role == "user":
|
||||
self._workflow_last_user_text = turn.text.strip()
|
||||
elif role == "assistant":
|
||||
await self._maybe_advance_workflow(turn.text.strip())
|
||||
|
||||
if not self._history_call_id:
|
||||
return
|
||||
if not turn.text or not turn.text.strip():
|
||||
return
|
||||
|
||||
role = (turn.role or "").lower()
|
||||
speaker = "human" if role == "user" else "ai"
|
||||
|
||||
end_ms = 0
|
||||
if self._history_call_started_mono is not None:
|
||||
end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
|
||||
estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
|
||||
start_ms = max(0, end_ms - estimated_duration_ms)
|
||||
|
||||
turn_index = self._history_turn_index
|
||||
await add_history_transcript(
|
||||
call_id=self._history_call_id,
|
||||
turn_index=turn_index,
|
||||
speaker=speaker,
|
||||
content=turn.text.strip(),
|
||||
start_ms=start_ms,
|
||||
end_ms=end_ms,
|
||||
duration_ms=max(1, end_ms - start_ms),
|
||||
)
|
||||
self._history_turn_index += 1
|
||||
|
||||
async def _finalize_history(self, status: str) -> None:
|
||||
"""Finalize history call record once."""
|
||||
if not self._history_call_id or self._history_finalized:
|
||||
return
|
||||
|
||||
duration_seconds = 0
|
||||
if self._history_call_started_mono is not None:
|
||||
duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))
|
||||
|
||||
ok = await finalize_history_call_record(
|
||||
call_id=self._history_call_id,
|
||||
status=status,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
if ok:
|
||||
self._history_finalized = True
|
||||
|
||||
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Parse workflow payload and return initial runtime overrides."""
|
||||
payload = metadata.get("workflow")
|
||||
self.workflow_runner = WorkflowRunner.from_payload(payload)
|
||||
self._workflow_initial_node = None
|
||||
if not self.workflow_runner:
|
||||
return {}
|
||||
|
||||
node = self.workflow_runner.bootstrap()
|
||||
if not node:
|
||||
logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
|
||||
self.workflow_runner = None
|
||||
return {}
|
||||
|
||||
self._workflow_initial_node = node
|
||||
logger.info(
|
||||
"Session {} workflow enabled: workflow={} start_node={}",
|
||||
self.id,
|
||||
self.workflow_runner.workflow_id,
|
||||
node.id,
|
||||
)
|
||||
return self.workflow_runner.build_runtime_metadata(node)
|
||||
|
||||
async def _maybe_advance_workflow(self, assistant_text: str) -> None:
    """Attempt node transfer after assistant turn finalization.

    First evaluates the current node's outgoing edges against the last user
    turn plus the just-finished assistant turn; on a match, applies that
    transition. Then auto-advances through "utility" nodes (start/tool)
    that expose a default edge, bounded by a hop limit so a mis-authored
    cycle cannot spin forever.
    """
    if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
        return

    transition = await self.workflow_runner.route(
        user_text=self._workflow_last_user_text,
        assistant_text=assistant_text,
        llm_router=self._workflow_llm_route,
    )
    if not transition:
        return

    await self._apply_workflow_transition(transition, reason="rule_match")

    # Auto-advance through utility nodes when default edges are present.
    max_auto_hops = 6
    auto_hops = 0
    # _apply_workflow_transition may stop the session or clear the runner,
    # so both conditions are re-checked on every iteration.
    while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
        current = self.workflow_runner.current_node
        if not current or current.node_type not in {"start", "tool"}:
            break

        next_default = self.workflow_runner.next_default_transition()
        if not next_default:
            break

        auto_hops += 1
        await self._apply_workflow_transition(next_default, reason="auto")
        # Limit checked after applying, so the final hop still takes effect.
        if auto_hops >= max_auto_hops:
            logger.warning(
                "Session {} workflow auto-advance reached hop limit (possible cycle)",
                self.id,
            )
            break
|
||||
|
||||
async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
    """Apply graph transition and emit workflow lifecycle events.

    Event order is part of the client contract: edge.taken first, then
    node.entered, then any node-type-specific event. Nodes of type
    "human_transfer" and "end" additionally stop the session.

    Args:
        transition: Edge/node pair chosen by the router.
        reason: Why the edge was taken ("rule_match" or "auto").
    """
    if not self.workflow_runner:
        return

    self.workflow_runner.apply_transition(transition)
    node = transition.node
    edge = transition.edge

    await self.transport.send_event(
        ev(
            "workflow.edge.taken",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            edgeId=edge.id,
            fromNodeId=edge.from_node_id,
            toNodeId=edge.to_node_id,
            reason=reason,
        )
    )
    await self.transport.send_event(
        ev(
            "workflow.node.entered",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            nodeId=node.id,
            nodeName=node.name,
            nodeType=node.node_type,
        )
    )

    # Per-node prompt/greeting/service overrides take effect immediately.
    node_runtime = self.workflow_runner.build_runtime_metadata(node)
    if node_runtime:
        self.pipeline.apply_runtime_overrides(node_runtime)

    if node.node_type == "tool":
        # Tool execution is requested from the client side; session continues.
        await self.transport.send_event(
            ev(
                "workflow.tool.requested",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                tool=node.tool or {},
            )
        )
        return

    if node.node_type == "human_transfer":
        await self.transport.send_event(
            ev(
                "workflow.human_transfer",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        # Hand-off to a human ends the automated session.
        await self._handle_session_stop("workflow_human_transfer")
        return

    if node.node_type == "end":
        await self.transport.send_event(
            ev(
                "workflow.ended",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        await self._handle_session_stop("workflow_end")
|
||||
|
||||
async def _workflow_llm_route(
    self,
    node: WorkflowNodeDef,
    candidates: List[WorkflowEdgeDef],
    context: Dict[str, str],
) -> Optional[str]:
    """LLM-based edge routing for condition.type == 'llm' edges.

    Asks the session's LLM to pick one edge, then accepts the answer via
    three progressively looser strategies: exact JSON edge id, JSON target
    node id, and finally a case-insensitive substring scan of the reply.
    Returns the chosen edge/node id, or None when routing fails.
    """
    llm_service = self.pipeline.llm_service
    if not llm_service:
        return None

    candidate_rows = [
        {
            "edgeId": edge.id,
            "toNodeId": edge.to_node_id,
            "label": edge.label,
            # "prompt" on the condition is an authoring hint for the router.
            "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
        }
        for edge in candidates
    ]
    system_prompt = (
        "You are a workflow router. Pick exactly one edge. "
        "Return JSON only: {\"edgeId\":\"...\"}."
    )
    user_prompt = json.dumps(
        {
            "nodeId": node.id,
            "nodeName": node.name,
            "userText": context.get("userText", ""),
            "assistantText": context.get("assistantText", ""),
            "candidates": candidate_rows,
        },
        ensure_ascii=False,
    )

    try:
        # temperature=0 for deterministic routing; 64 tokens is ample for
        # the expected {"edgeId": "..."} reply.
        reply = await llm_service.generate(
            [
                LLMMessage(role="system", content=system_prompt),
                LLMMessage(role="user", content=user_prompt),
            ],
            temperature=0.0,
            max_tokens=64,
        )
    except Exception as exc:
        logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
        return None

    if not reply:
        return None

    edge_ids = {edge.id for edge in candidates}
    node_ids = {edge.to_node_id for edge in candidates}

    # Strategy 1+2: structured JSON answer with an edge id or node id.
    parsed = self._extract_json_obj(reply)
    if isinstance(parsed, dict):
        edge_id = parsed.get("edgeId") or parsed.get("id")
        node_id = parsed.get("toNodeId") or parsed.get("nodeId")
        if isinstance(edge_id, str) and edge_id in edge_ids:
            return edge_id
        if isinstance(node_id, str) and node_id in node_ids:
            return node_id

    # Strategy 3: substring match, longest token first to avoid a short id
    # matching inside a longer one.
    token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
    lowered_reply = reply.lower()
    for token in token_candidates:
        if token.lower() in lowered_reply:
            return token
    return None
|
||||
|
||||
def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Merge node-level metadata overrides into session.start metadata."""
|
||||
merged = dict(base or {})
|
||||
if not overrides:
|
||||
return merged
|
||||
for key, value in overrides.items():
|
||||
if key == "services" and isinstance(value, dict):
|
||||
existing = merged.get("services")
|
||||
merged_services = dict(existing) if isinstance(existing, dict) else {}
|
||||
merged_services.update(value)
|
||||
merged["services"] = merged_services
|
||||
else:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
|
||||
"""Best-effort extraction of a JSON object from freeform text."""
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
match = re.search(r"\{.*\}", text, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(match.group(0))
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
except Exception:
|
||||
return None
|
||||
340
core/tool_executor.py
Normal file
340
core/tool_executor.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""Server-side tool execution helpers."""
|
||||
|
||||
import asyncio
|
||||
import ast
|
||||
import operator
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.backend_client import fetch_tool_resource
|
||||
|
||||
# Binary operators permitted by the safe expression evaluators
# (AST node class -> implementation). No power/bitshift: kept out of the
# allow-list (e.g. huge exponents would be a trivial DoS vector).
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
}

# Unary operators permitted by the safe expression evaluators.
_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}

# Builtins callable from sandboxed "code_interpreter" expressions; these are
# also the only names resolvable inside such expressions.
_SAFE_EVAL_FUNCS = {
    "abs": abs,
    "round": round,
    "min": min,
    "max": max,
    "sum": sum,
    "len": len,
}
|
||||
|
||||
|
||||
def _validate_safe_expr(node: ast.AST) -> None:
    """Allow only a constrained subset of Python expression nodes.

    Recursively walks the AST and raises ValueError on the first node that
    falls outside the allow-list: literals, container displays, arithmetic
    from _BIN_OPS/_UNARY_OPS, boolean/comparison expressions, and calls to
    names in _SAFE_EVAL_FUNCS. Anything else (attributes, subscripts,
    comprehensions, lambdas, imports, ...) is rejected by the fall-through.
    """
    if isinstance(node, ast.Expression):
        _validate_safe_expr(node.body)
        return

    # Literals (numbers, strings, True/False/None in py3.8+) are safe.
    if isinstance(node, ast.Constant):
        return

    if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        for elt in node.elts:
            _validate_safe_expr(elt)
        return

    if isinstance(node, ast.Dict):
        for key in node.keys:
            # key is None for a "**mapping" expansion; its value is still checked.
            if key is not None:
                _validate_safe_expr(key)
        for value in node.values:
            _validate_safe_expr(value)
        return

    if isinstance(node, ast.BinOp):
        if type(node.op) not in _BIN_OPS:
            raise ValueError("unsupported operator")
        _validate_safe_expr(node.left)
        _validate_safe_expr(node.right)
        return

    if isinstance(node, ast.UnaryOp):
        if type(node.op) not in _UNARY_OPS:
            raise ValueError("unsupported unary operator")
        _validate_safe_expr(node.operand)
        return

    # and/or: operands checked; the operators themselves are harmless.
    if isinstance(node, ast.BoolOp):
        for value in node.values:
            _validate_safe_expr(value)
        return

    if isinstance(node, ast.Compare):
        _validate_safe_expr(node.left)
        for comp in node.comparators:
            _validate_safe_expr(comp)
        return

    if isinstance(node, ast.Name):
        # Only allow-listed function names (and legacy keyword spellings).
        if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}:
            raise ValueError("unknown symbol")
        return

    if isinstance(node, ast.Call):
        # Only direct calls to allow-listed names; no attribute calls.
        if not isinstance(node.func, ast.Name):
            raise ValueError("unsafe call target")
        if node.func.id not in _SAFE_EVAL_FUNCS:
            raise ValueError("function not allowed")
        for arg in node.args:
            _validate_safe_expr(arg)
        for kw in node.keywords:
            _validate_safe_expr(kw.value)
        return

    # Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.)
    raise ValueError("unsupported expression")
|
||||
|
||||
|
||||
def _safe_eval_python_expr(expression: str) -> Any:
    """Evaluate one Python expression after AST allow-list validation.

    Raises SyntaxError on malformed input and ValueError (from
    _validate_safe_expr) on disallowed constructs.
    """
    tree = ast.parse(expression, mode="eval")
    _validate_safe_expr(tree)
    # Empty __builtins__ plus only the allow-listed functions as globals keeps
    # eval confined to nodes already approved above.
    return eval(  # noqa: S307 - validated AST + empty builtins
        compile(tree, "<code_interpreter>", "eval"),
        {"__builtins__": {}},
        dict(_SAFE_EVAL_FUNCS),
    )
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_json_safe(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(k): _json_safe(v) for k, v in value.items()}
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _safe_eval_expr(expression: str) -> float:
|
||||
tree = ast.parse(expression, mode="eval")
|
||||
|
||||
def _eval(node: ast.AST) -> float:
|
||||
if isinstance(node, ast.Expression):
|
||||
return _eval(node.body)
|
||||
if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
|
||||
return float(node.value)
|
||||
if isinstance(node, ast.BinOp):
|
||||
op = _BIN_OPS.get(type(node.op))
|
||||
if not op:
|
||||
raise ValueError("unsupported operator")
|
||||
return float(op(_eval(node.left), _eval(node.right)))
|
||||
if isinstance(node, ast.UnaryOp):
|
||||
op = _UNARY_OPS.get(type(node.op))
|
||||
if not op:
|
||||
raise ValueError("unsupported unary operator")
|
||||
return float(op(_eval(node.operand)))
|
||||
raise ValueError("unsupported expression")
|
||||
|
||||
return _eval(tree)
|
||||
|
||||
|
||||
def _extract_tool_name(tool_call: Dict[str, Any]) -> str:
|
||||
function_payload = tool_call.get("function")
|
||||
if isinstance(function_payload, dict):
|
||||
return str(function_payload.get("name") or "").strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
function_payload = tool_call.get("function")
|
||||
if not isinstance(function_payload, dict):
|
||||
return {}
|
||||
raw = function_payload.get("arguments")
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if not isinstance(raw, str):
|
||||
return {}
|
||||
text = raw.strip()
|
||||
if not text:
|
||||
return {}
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed = json.loads(text)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a server-side tool and return normalized result payload.

    Built-in tools: "calculator" (restricted arithmetic AST), "code_interpreter"
    (allow-listed Python expression), "current_time". Any other tool name is
    looked up in the backend tool registry and, when registered with
    category == "query", proxied as an HTTP request. All branches return the
    same shape: {"tool_call_id", "name", "output", "status": {"code", "message"}}.
    This function never raises; failures are encoded in "status".
    """
    call_id = str(tool_call.get("id") or "").strip()
    tool_name = _extract_tool_name(tool_call)
    args = _extract_tool_args(tool_call)

    if tool_name == "calculator":
        expression = str(args.get("expression") or "").strip()
        if not expression:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "missing expression"},
                "status": {"code": 400, "message": "bad_request"},
            }
        # Length cap bounds evaluation cost for untrusted input.
        if len(expression) > 200:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "error": "expression too long"},
                "status": {"code": 422, "message": "invalid_expression"},
            }
        try:
            value = _safe_eval_expr(expression)
            # Present whole-number results as ints (e.g. 4.0 -> 4).
            if value.is_integer():
                value = int(value)
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "result": value},
                "status": {"code": 200, "message": "ok"},
            }
        except Exception as exc:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "error": str(exc)},
                "status": {"code": 422, "message": "invalid_expression"},
            }

    if tool_name == "code_interpreter":
        code = str(args.get("code") or args.get("expression") or "").strip()
        if not code:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "missing code"},
                "status": {"code": 400, "message": "bad_request"},
            }
        if len(code) > 500:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "code too long"},
                "status": {"code": 422, "message": "invalid_code"},
            }
        try:
            result = _safe_eval_python_expr(code)
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"code": code, "result": _json_safe(result)},
                "status": {"code": 200, "message": "ok"},
            }
        except Exception as exc:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"code": code, "error": str(exc)},
                "status": {"code": 422, "message": "invalid_code"},
            }

    if tool_name == "current_time":
        # Server-local time with timezone attached.
        now = datetime.now().astimezone()
        return {
            "tool_call_id": call_id,
            "name": tool_name,
            "output": {
                "local_time": now.strftime("%Y-%m-%d %H:%M:%S"),
                "iso": now.isoformat(),
                "timezone": str(now.tzinfo or ""),
                "timestamp": int(now.timestamp()),
            },
            "status": {"code": 200, "message": "ok"},
        }

    if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
        # Dynamic tools: configuration is fetched from the backend registry.
        resource = await fetch_tool_resource(tool_name)
        if resource and str(resource.get("category") or "") == "query":
            method = str(resource.get("http_method") or "GET").strip().upper()
            if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:
                method = "GET"
            url = str(resource.get("http_url") or "").strip()
            headers = resource.get("http_headers") if isinstance(resource.get("http_headers"), dict) else {}
            timeout_ms = resource.get("http_timeout_ms")
            try:
                # Floor of 1s; malformed/missing timeout falls back to 10s.
                timeout_s = max(1.0, float(timeout_ms) / 1000.0)
            except Exception:
                timeout_s = 10.0

            if not url:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"error": "http_url not configured"},
                    "status": {"code": 422, "message": "invalid_tool_config"},
                }

            # Arguments travel as query params for body-less methods,
            # otherwise as a JSON body.
            request_kwargs: Dict[str, Any] = {}
            if method in {"GET", "DELETE"}:
                request_kwargs["params"] = args
            else:
                request_kwargs["json"] = args

            try:
                timeout = aiohttp.ClientTimeout(total=timeout_s)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.request(method, url, headers=headers, **request_kwargs) as resp:
                        content_type = str(resp.headers.get("Content-Type") or "").lower()
                        if "application/json" in content_type:
                            body: Any = await resp.json()
                        else:
                            body = await resp.text()
                        status_code = int(resp.status)
                        if 200 <= status_code < 300:
                            return {
                                "tool_call_id": call_id,
                                "name": tool_name,
                                "output": {
                                    "method": method,
                                    "url": url,
                                    "status_code": status_code,
                                    "response": _json_safe(body),
                                },
                                "status": {"code": 200, "message": "ok"},
                            }
                        # Non-2xx: surface the upstream body and status code.
                        return {
                            "tool_call_id": call_id,
                            "name": tool_name,
                            "output": {
                                "method": method,
                                "url": url,
                                "status_code": status_code,
                                "response": _json_safe(body),
                            },
                            "status": {"code": status_code, "message": "http_error"},
                        }
            except asyncio.TimeoutError:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"method": method, "url": url, "error": "request timeout"},
                    "status": {"code": 504, "message": "http_timeout"},
                }
            except Exception as exc:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"method": method, "url": url, "error": str(exc)},
                    "status": {"code": 502, "message": "http_request_failed"},
                }

    # Unknown/unregistered tool (or registered with a non-"query" category).
    return {
        "tool_call_id": call_id,
        "name": tool_name or "unknown_tool",
        "output": {"message": "server tool not implemented"},
        "status": {"code": 501, "message": "not_implemented"},
    }
|
||||
247
core/transports.py
Normal file
247
core/transports.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Transport layer for WebSocket and WebRTC communication."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
from loguru import logger
|
||||
|
||||
# Try to import aiortc (optional for WebRTC functionality)
|
||||
try:
|
||||
from aiortc import RTCPeerConnection
|
||||
AIORTC_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIORTC_AVAILABLE = False
|
||||
RTCPeerConnection = None # Type hint placeholder
|
||||
|
||||
|
||||
class BaseTransport(ABC):
    """
    Abstract base class for transports.

    All transports must implement send_event and send_audio methods.
    Concrete implementations in this module: SocketTransport (plain
    WebSocket) and WebRtcTransport (WebSocket signaling + RTCPeerConnection).
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event to the client.

        Args:
            event: Event data as dictionary (JSON-serializable)
        """
        pass

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the transport and cleanup resources."""
        pass
|
||||
|
||||
|
||||
class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Handles mixed text/binary frames over WebSocket connection.
    Uses asyncio.Lock to prevent frame interleaving.

    Send failures latch ``_closed`` so subsequent sends short-circuit
    instead of raising repeatedly.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        self._closed = False

    def _ws_disconnected(self) -> bool:
        """Best-effort check for websocket disconnection state."""
        # Either side (client or application) being DISCONNECTED means no
        # further frames can be delivered.
        return (
            self.ws.client_state == WebSocketState.DISCONNECTED
            or self.ws.application_state == WebSocketState.DISCONNECTED
        )

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        No-op once the transport is closed or the peer has disconnected.

        Args:
            event: Event data as dictionary
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send event on closed transport")
            self._closed = True
            return

        # Lock serializes text/binary frames from concurrent tasks.
        async with self.lock:
            try:
                await self.ws.send_text(json.dumps(event))
                logger.debug(f"Sent event: {event.get('event', 'unknown')}")
            except RuntimeError as e:
                self._closed = True
                # Starlette raises RuntimeError mentioning "close message has
                # been sent" when the socket closed mid-send; treat as benign.
                if self._ws_disconnected() or "close message has been sent" in str(e):
                    logger.debug(f"Skip sending event on closed websocket: {e!r}")
                    return
                logger.error(f"Error sending event: {e!r}")
            except Exception as e:
                self._closed = True
                if self._ws_disconnected():
                    logger.debug(f"Skip sending event on disconnected websocket: {e!r}")
                    return
                logger.error(f"Error sending event: {e!r}")

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Mirrors send_event's failure handling (latched ``_closed``, benign
        disconnect downgraded to debug logging).

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send audio on closed transport")
            self._closed = True
            return

        async with self.lock:
            try:
                await self.ws.send_bytes(pcm_bytes)
            except RuntimeError as e:
                self._closed = True
                if self._ws_disconnected() or "close message has been sent" in str(e):
                    logger.debug(f"Skip sending audio on closed websocket: {e!r}")
                    return
                logger.error(f"Error sending audio: {e!r}")
            except Exception as e:
                self._closed = True
                if self._ws_disconnected():
                    logger.debug(f"Skip sending audio on disconnected websocket: {e!r}")
                    return
                logger.error(f"Error sending audio: {e!r}")

    async def close(self) -> None:
        """Close the WebSocket connection (idempotent)."""
        if self._closed:
            return

        # Latch before awaiting so concurrent senders short-circuit.
        self._closed = True
        if self._ws_disconnected():
            return

        try:
            await self.ws.close()
        except RuntimeError as e:
            # Already closed by another task/path; safe to ignore.
            if "close message has been sent" in str(e):
                logger.debug(f"WebSocket already closed: {e}")
                return
            logger.error(f"Error closing WebSocket: {e}")
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed
|
||||
|
||||
|
||||
class WebRtcTransport(BaseTransport):
    """
    WebRTC transport for WebRTC audio streaming.

    Uses WebSocket for signaling and RTCPeerConnection for media.
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport

        Raises:
            RuntimeError: when aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")

        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        Becomes a no-op once the transport is closed; any send failure
        latches the closed flag.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        try:
            await self.ws.send_text(json.dumps(event))
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track.

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading.
        Currently a logged placeholder until an outbound track exists.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")

        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """Close the WebRTC connection.

        The peer connection and the signaling socket are torn down
        independently: previously both awaits shared one try block, so an
        exception from pc.close() skipped ws.close() and leaked the socket.
        """
        self._closed = True
        try:
            await self.pc.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")
        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")
|
||||
402
core/workflow_runner.py
Normal file
402
core/workflow_runner.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""Workflow runtime helpers for session-level node routing.
|
||||
|
||||
MVP goals:
|
||||
- Parse workflow graph payload from WS session.start metadata
|
||||
- Track current node
|
||||
- Evaluate edge conditions on each assistant turn completion
|
||||
- Provide per-node runtime metadata overrides (prompt/greeting/services)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
_NODE_TYPE_MAP = {
|
||||
"conversation": "assistant",
|
||||
"assistant": "assistant",
|
||||
"human": "human_transfer",
|
||||
"human_transfer": "human_transfer",
|
||||
"tool": "tool",
|
||||
"end": "end",
|
||||
"start": "start",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_node_type(raw_type: Any) -> str:
|
||||
value = str(raw_type or "").strip().lower()
|
||||
return _NODE_TYPE_MAP.get(value, "assistant")
|
||||
|
||||
|
||||
def _safe_str(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
|
||||
def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]:
|
||||
if not isinstance(raw, dict):
|
||||
if label:
|
||||
return {"type": "contains", "source": "user", "value": str(label)}
|
||||
return {"type": "always"}
|
||||
|
||||
condition = dict(raw)
|
||||
condition_type = str(condition.get("type", "always")).strip().lower()
|
||||
if not condition_type:
|
||||
condition_type = "always"
|
||||
condition["type"] = condition_type
|
||||
condition["source"] = str(condition.get("source", "user")).strip().lower() or "user"
|
||||
return condition
|
||||
|
||||
|
||||
@dataclass
class WorkflowNodeDef:
    """Parsed definition of one workflow graph node."""

    id: str  # unique node id within the workflow
    name: str  # display name (falls back to id)
    node_type: str  # canonical type: assistant/tool/human_transfer/end/start
    is_start: bool  # True for the graph entry node
    prompt: Optional[str]  # node-level prompt override, when provided
    message_plan: Dict[str, Any]  # greeting/message configuration ({} if absent)
    assistant_id: Optional[str]  # referenced assistant id, if any
    assistant: Dict[str, Any]  # inline assistant config ({} if absent)
    tool: Optional[Dict[str, Any]]  # tool config for tool nodes
    raw: Dict[str, Any]  # original payload dict for this node
|
||||
|
||||
|
||||
@dataclass
class WorkflowEdgeDef:
    """Parsed definition of one directed workflow edge."""

    id: str  # edge id (synthesized when the payload omits one)
    from_node_id: str  # source node id (validated to exist)
    to_node_id: str  # target node id (validated to exist)
    label: Optional[str]  # optional display label / fallback match text
    condition: Dict[str, Any]  # normalized condition (see _normalize_condition)
    priority: int  # lower evaluates first; payload default is 100
    order: int  # declaration index, tiebreaker after priority
    raw: Dict[str, Any]  # original payload dict for this edge
|
||||
|
||||
|
||||
@dataclass
class WorkflowTransition:
    """A routed move through the graph: the edge taken and the node entered."""

    edge: WorkflowEdgeDef  # edge that was matched
    node: WorkflowNodeDef  # destination node of that edge
|
||||
|
||||
|
||||
# Async callback used to decide among "llm"-conditioned edges:
# (current node, candidate edges, {"userText": ..., "assistantText": ...})
# -> chosen edge id or target node id, or None when routing failed.
LlmRouter = Callable[
    [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]],
    Awaitable[Optional[str]],
]
|
||||
|
||||
|
||||
class WorkflowRunner:
|
||||
"""In-memory workflow graph for a single active session."""
|
||||
|
||||
def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]):
    """Hold a parsed workflow graph plus the mutable cursor for one session."""
    self.workflow_id = workflow_id
    self.name = name
    # Node lookup by id; a later duplicate id silently wins (payload order).
    self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes}
    self._edges = edges
    # Id of the node the session currently occupies; None until bootstrap().
    self.current_node_id: Optional[str] = None
|
||||
|
||||
@classmethod
def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]:
    """Build a runner from a session.start workflow payload.

    Tolerant parser: malformed nodes/edges are skipped, ids are synthesized
    when missing, and edges referencing unknown nodes are dropped. Returns
    None when the payload is not a dict or yields no usable nodes.
    """
    if not isinstance(payload, dict):
        return None

    raw_nodes = payload.get("nodes")
    raw_edges = payload.get("edges")
    if not isinstance(raw_nodes, list) or len(raw_nodes) == 0:
        return None

    nodes: List[WorkflowNodeDef] = []
    for i, raw in enumerate(raw_nodes):
        if not isinstance(raw, dict):
            continue

        # Ids fall back to name, then to a positional synthetic id.
        node_id = _safe_str(raw.get("id") or raw.get("name") or f"node_{i + 1}").strip() or f"node_{i + 1}"
        node_name = _safe_str(raw.get("name") or node_id).strip() or node_id
        node_type = _normalize_node_type(raw.get("type"))
        # Either an explicit isStart flag or a "start"-typed node marks entry.
        is_start = bool(raw.get("isStart")) or node_type == "start"

        # Only distinguish "prompt key present" from "absent"; present-but-None
        # becomes "".
        prompt: Optional[str] = None
        if "prompt" in raw:
            prompt = _safe_str(raw.get("prompt"))

        message_plan = raw.get("messagePlan")
        if not isinstance(message_plan, dict):
            message_plan = {}

        assistant_cfg = raw.get("assistant")
        if not isinstance(assistant_cfg, dict):
            assistant_cfg = {}

        tool_cfg = raw.get("tool")
        if not isinstance(tool_cfg, dict):
            tool_cfg = None

        assistant_id = raw.get("assistantId")
        if assistant_id is not None:
            assistant_id = _safe_str(assistant_id).strip() or None

        nodes.append(
            WorkflowNodeDef(
                id=node_id,
                name=node_name,
                node_type=node_type,
                is_start=is_start,
                prompt=prompt,
                message_plan=message_plan,
                assistant_id=assistant_id,
                assistant=assistant_cfg,
                tool=tool_cfg,
                raw=raw,
            )
        )

    if not nodes:
        return None

    # Edges must connect two known nodes; anything else is dropped.
    node_ids = {node.id for node in nodes}
    edges: List[WorkflowEdgeDef] = []
    for i, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []):
        if not isinstance(raw, dict):
            continue

        # Accept several payload spellings for the endpoints.
        from_node_id = _safe_str(
            raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
        ).strip()
        to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip()
        if not from_node_id or not to_node_id:
            continue
        if from_node_id not in node_ids or to_node_id not in node_ids:
            continue

        label = raw.get("label")
        if label is not None:
            label = _safe_str(label)

        condition = _normalize_condition(raw.get("condition"), label=label)

        priority = 100
        try:
            priority = int(raw.get("priority", 100))
        except (TypeError, ValueError):
            priority = 100

        edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}"

        edges.append(
            WorkflowEdgeDef(
                id=edge_id,
                from_node_id=from_node_id,
                to_node_id=to_node_id,
                label=label,
                condition=condition,
                priority=priority,
                order=i,
                raw=raw,
            )
        )

    workflow_id = _safe_str(payload.get("id") or "workflow")
    workflow_name = _safe_str(payload.get("name") or workflow_id)
    return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges)
|
||||
|
||||
def bootstrap(self) -> Optional[WorkflowNodeDef]:
    """Enter the workflow at its start node.

    Resolves the start node, records it as the current position, and
    returns it. Returns None (leaving the position untouched) when no
    usable start node exists.
    """
    node = self._resolve_start_node()
    if node:
        self.current_node_id = node.id
        return node
    return None
|
||||
|
||||
@property
def current_node(self) -> Optional[WorkflowNodeDef]:
    """The node the workflow currently points at, or None if unset or unknown."""
    node_id = self.current_node_id
    return self._nodes.get(node_id) if node_id else None
|
||||
|
||||
def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]:
    """Edges leaving *node_id*, sorted by priority then insertion order."""
    matching = (edge for edge in self._edges if edge.from_node_id == node_id)
    return sorted(matching, key=lambda e: (e.priority, e.order))
|
||||
|
||||
def next_default_transition(self) -> Optional[WorkflowTransition]:
    """Return the first unconditional transition out of the current node.

    An edge counts as unconditional when its condition type is "",
    "always" or "default" and its target node exists. Returns None when
    there is no current node or no such edge.

    Delegates to _first_default_transition_from so the definition of a
    "default" edge lives in exactly one place.
    """
    node = self.current_node
    if not node:
        return None
    return self._first_default_transition_from(node.id)
|
||||
|
||||
async def route(
    self,
    *,
    user_text: str,
    assistant_text: str,
    llm_router: Optional[LlmRouter] = None,
) -> Optional[WorkflowTransition]:
    """Choose the next transition out of the current node for one turn.

    Resolution order:
      1. Non-LLM edges, in priority order, matched deterministically
         against the latest user/assistant text via _matches_condition.
      2. "llm"-conditioned edges, decided by *llm_router* (if provided),
         which must return a selected edge id or target node id.
      3. The first unconditional ("" / "always" / "default") edge.

    Returns None when the node has no outgoing edges or nothing matched.
    The current position is NOT advanced; call apply_transition for that.
    """
    node = self.current_node
    if not node:
        return None

    outgoing = self.outgoing_edges(node.id)
    if not outgoing:
        return None

    # Pass 1: deterministic conditions; defer "llm" edges to pass 2.
    llm_edges: List[WorkflowEdgeDef] = []
    for edge in outgoing:
        cond_type = str(edge.condition.get("type", "always")).strip().lower()
        if cond_type == "llm":
            llm_edges.append(edge)
            continue
        if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text):
            target = self._nodes.get(edge.to_node_id)
            if target:
                return WorkflowTransition(edge=edge, node=target)

    # Pass 2: ask the LLM router to pick among the "llm" edges, if any.
    if llm_edges and llm_router:
        selection = await llm_router(
            node,
            llm_edges,
            {
                "userText": user_text,
                "assistantText": assistant_text,
            },
        )
        if selection:
            # The router may answer with either an edge id or a target node id.
            for edge in llm_edges:
                if selection in {edge.id, edge.to_node_id}:
                    target = self._nodes.get(edge.to_node_id)
                    if target:
                        return WorkflowTransition(edge=edge, node=target)

    # Pass 3: fall back to the first default edge. The shared helper keeps
    # the definition of a "default" edge in one place (was duplicated here).
    return self._first_default_transition_from(node.id)
|
||||
|
||||
def apply_transition(self, transition: WorkflowTransition) -> None:
    """Advance the workflow pointer to the transition's target node."""
    target = transition.node
    self.current_node_id = target.id
|
||||
|
||||
def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]:
    """Collect per-node runtime overrides for the session.

    Produces any of: "systemPrompt", "greeting", "services",
    "assistantId". Precedence: node-level fields win over the node's
    assistant config; within the assistant config, "systemPrompt" beats
    "prompt" and "greeting" beats "opener".
    """
    assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {}
    message_plan = node.message_plan if isinstance(node.message_plan, dict) else {}
    metadata: Dict[str, Any] = {}

    # System prompt: explicit node prompt first, then assistant fallbacks.
    if node.prompt is not None:
        metadata["systemPrompt"] = node.prompt
    else:
        for key in ("systemPrompt", "prompt"):
            if key in assistant_cfg:
                metadata["systemPrompt"] = _safe_str(assistant_cfg.get(key))
                break

    # Greeting: message plan first, then assistant fallbacks.
    first_message = message_plan.get("firstMessage")
    if first_message is not None:
        metadata["greeting"] = _safe_str(first_message)
    else:
        for key in ("greeting", "opener"):
            if key in assistant_cfg:
                metadata["greeting"] = _safe_str(assistant_cfg.get(key))
                break

    services = assistant_cfg.get("services")
    if isinstance(services, dict):
        metadata["services"] = services

    if node.assistant_id:
        metadata["assistantId"] = node.assistant_id

    return metadata
|
||||
|
||||
def _resolve_start_node(self) -> Optional[WorkflowNodeDef]:
    """Pick the node the workflow should begin on.

    Preference order:
      1. A node flagged ``is_start``.
      2. A node whose ``node_type`` is "start".
      3. The first node of type "assistant".
      4. Any node at all (dict insertion order), else None.

    When the chosen node is a dedicated "start" marker, walk forward along
    default edges (at most 8 hops, cycle-guarded) so the session begins on
    the first real node rather than on the marker itself.
    """
    # The explicit is_start flag wins over the "start" node type.
    explicit_start = next((node for node in self._nodes.values() if node.is_start), None)
    if not explicit_start:
        explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None)

    if explicit_start:
        # If a dedicated start node exists, try to move to its first default target.
        if explicit_start.node_type == "start":
            visited = {explicit_start.id}
            current = explicit_start
            # Bounded walk: 8 hops caps traversal on pathological graphs.
            for _ in range(8):
                transition = self._first_default_transition_from(current.id)
                if not transition:
                    # Dead end: this is the first node with no default exit.
                    return current
                current = transition.node
                if current.id in visited:
                    # Cycle detected; settle on the repeated node.
                    break
                visited.add(current.id)
            return current
        return explicit_start

    # No start marker at all: prefer an assistant node, then any node.
    assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None)
    if assistant_node:
        return assistant_node
    return next(iter(self._nodes.values()), None)
|
||||
|
||||
def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]:
    """First outgoing edge of *node_id* whose condition is unconditional
    ("", "always" or "default") and whose target node exists, else None."""
    for edge in self.outgoing_edges(node_id):
        kind = str(edge.condition.get("type", "always")).strip().lower()
        if kind not in {"", "always", "default"}:
            continue
        target = self._nodes.get(edge.to_node_id)
        if target:
            return WorkflowTransition(edge=edge, node=target)
    return None
|
||||
|
||||
def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool:
    """Evaluate a deterministic (non-LLM) edge condition against turn text.

    Supported condition types:
      - "" / "always" / "default": always True.
      - "contains": any configured keyword occurs in the text
        (case-insensitive); falls back to a single value/keyword/edge label.
      - "equals": case-insensitive exact match against ``value``.
      - "regex": case-insensitive search using ``value`` or ``pattern``.
      - "json": parse the text as JSON and compare its str() form to ``value``.
    Unknown types evaluate to False. ``source`` selects which side's text
    is inspected ("assistant", or the user's text by default).
    """
    condition = edge.condition or {"type": "always"}
    cond_type = str(condition.get("type", "always")).strip().lower()
    source = str(condition.get("source", "user")).strip().lower()

    if cond_type in {"", "always", "default"}:
        return True

    # Any source other than "assistant" (including typos) inspects user text.
    text = assistant_text if source == "assistant" else user_text
    text_lower = (text or "").lower()

    if cond_type == "contains":
        values: List[str] = []
        if isinstance(condition.get("values"), list):
            values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()]
        if not values:
            # Fallbacks: single "value", then "keyword", then the edge label.
            single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower()
            if single:
                values = [single]
        if not values:
            return False
        return any(value in text_lower for value in values)

    if cond_type == "equals":
        expected = _safe_str(condition.get("value") or "").strip().lower()
        return bool(expected) and text_lower == expected

    if cond_type == "regex":
        pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip()
        if not pattern:
            return False
        try:
            return bool(re.search(pattern, text or "", re.IGNORECASE))
        except re.error:
            # Bad pattern from workflow config: warn and treat as non-matching.
            logger.warning(f"Invalid workflow regex condition: {pattern}")
            return False

    if cond_type == "json":
        value = _safe_str(condition.get("value") or "").strip()
        if not value:
            return False
        try:
            obj = json.loads(text or "")
        except Exception:
            return False
        # NOTE(review): this compares Python's str() of the parsed object
        # (e.g. single-quoted dict repr), not the original JSON text —
        # confirm that is the intended matching semantics.
        return str(obj) == value

    # Unknown condition type: never match.
    return False
|
||||
BIN
data/audio_examples/single_utterance_16k.wav
Normal file
BIN
data/audio_examples/single_utterance_16k.wav
Normal file
Binary file not shown.
BIN
data/audio_examples/three_utterances.wav
Normal file
BIN
data/audio_examples/three_utterances.wav
Normal file
Binary file not shown.
BIN
data/audio_examples/two_utterances.wav
Normal file
BIN
data/audio_examples/two_utterances.wav
Normal file
Binary file not shown.
BIN
data/vad/silero_vad.onnx
Normal file
BIN
data/vad/silero_vad.onnx
Normal file
Binary file not shown.
96
docs/duplex_interaction.svg
Normal file
96
docs/duplex_interaction.svg
Normal file
@@ -0,0 +1,96 @@
|
||||
<svg width="1200" height="620" viewBox="0 0 1200 620" xmlns="http://www.w3.org/2000/svg">
|
||||
<defs>
|
||||
<style>
|
||||
.box { fill:#11131a; stroke:#3a3f4b; stroke-width:1.2; rx:10; ry:10; }
|
||||
.title { font: 600 14px 'Arial'; fill:#f2f3f7; }
|
||||
.text { font: 12px 'Arial'; fill:#c8ccd8; }
|
||||
.arrow { stroke:#7aa2ff; stroke-width:1.6; marker-end:url(#arrow); fill:none; }
|
||||
.arrow2 { stroke:#2dd4bf; stroke-width:1.6; marker-end:url(#arrow); fill:none; }
|
||||
.arrow3 { stroke:#ff6b6b; stroke-width:1.6; marker-end:url(#arrow); fill:none; }
|
||||
.label { font: 11px 'Arial'; fill:#9aa3b2; }
|
||||
</style>
|
||||
<marker id="arrow" markerWidth="8" markerHeight="8" refX="7" refY="4" orient="auto">
|
||||
<path d="M0,0 L8,4 L0,8 Z" fill="#7aa2ff"/>
|
||||
</marker>
|
||||
</defs>
|
||||
|
||||
<rect x="40" y="40" width="250" height="120" class="box"/>
|
||||
<text x="60" y="70" class="title">Web Client</text>
|
||||
<text x="60" y="95" class="text">WS JSON commands</text>
|
||||
<text x="60" y="115" class="text">WS binary PCM audio</text>
|
||||
|
||||
<rect x="350" y="40" width="250" height="120" class="box"/>
|
||||
<text x="370" y="70" class="title">FastAPI /ws</text>
|
||||
<text x="370" y="95" class="text">Session + Transport</text>
|
||||
|
||||
<rect x="660" y="40" width="250" height="120" class="box"/>
|
||||
<text x="680" y="70" class="title">DuplexPipeline</text>
|
||||
<text x="680" y="95" class="text">process_audio / process_text</text>
|
||||
|
||||
<rect x="920" y="40" width="240" height="120" class="box"/>
|
||||
<text x="940" y="70" class="title">ConversationManager</text>
|
||||
<text x="940" y="95" class="text">turns + state</text>
|
||||
|
||||
<rect x="660" y="200" width="180" height="100" class="box"/>
|
||||
<text x="680" y="230" class="title">VADProcessor</text>
|
||||
<text x="680" y="255" class="text">speech/silence</text>
|
||||
|
||||
<rect x="860" y="200" width="180" height="100" class="box"/>
|
||||
<text x="880" y="230" class="title">EOU Detector</text>
|
||||
<text x="880" y="255" class="text">end-of-utterance</text>
|
||||
|
||||
<rect x="1060" y="200" width="120" height="100" class="box"/>
|
||||
<text x="1075" y="230" class="title">ASR</text>
|
||||
<text x="1075" y="255" class="text">transcripts</text>
|
||||
|
||||
<rect x="920" y="350" width="240" height="110" class="box"/>
|
||||
<text x="940" y="380" class="title">LLM (stream)</text>
|
||||
<text x="940" y="405" class="text">llmResponse events</text>
|
||||
|
||||
<rect x="660" y="350" width="220" height="110" class="box"/>
|
||||
<text x="680" y="380" class="title">TTS (stream)</text>
|
||||
<text x="680" y="405" class="text">PCM audio</text>
|
||||
|
||||
<rect x="40" y="350" width="250" height="110" class="box"/>
|
||||
<text x="60" y="380" class="title">Web Client</text>
|
||||
<text x="60" y="405" class="text">audio playback + UI</text>
|
||||
|
||||
<path d="M290 80 L350 80" class="arrow"/>
|
||||
<text x="300" y="70" class="label">JSON / PCM</text>
|
||||
|
||||
<path d="M600 80 L660 80" class="arrow"/>
|
||||
<text x="615" y="70" class="label">dispatch</text>
|
||||
|
||||
<path d="M910 80 L920 80" class="arrow"/>
|
||||
<text x="880" y="70" class="label">turn mgmt</text>
|
||||
|
||||
<path d="M750 160 L750 200" class="arrow"/>
|
||||
<text x="705" y="190" class="label">audio chunks</text>
|
||||
|
||||
<path d="M840 250 L860 250" class="arrow"/>
|
||||
<text x="835" y="240" class="label">vad status</text>
|
||||
|
||||
<path d="M1040 250 L1060 250" class="arrow"/>
|
||||
<text x="1010" y="240" class="label">audio buffer</text>
|
||||
|
||||
<path d="M950 300 L950 350" class="arrow2"/>
|
||||
<text x="930" y="340" class="label">EOU -> LLM</text>
|
||||
|
||||
<path d="M880 405 L920 405" class="arrow2"/>
|
||||
<text x="870" y="395" class="label">text stream</text>
|
||||
|
||||
<path d="M660 405 L290 405" class="arrow2"/>
|
||||
<text x="430" y="395" class="label">PCM audio</text>
|
||||
|
||||
<path d="M660 450 L350 450" class="arrow"/>
|
||||
<text x="420" y="440" class="label">events: trackStart/End</text>
|
||||
|
||||
<path d="M350 450 L290 450" class="arrow"/>
|
||||
<text x="315" y="440" class="label">UI updates</text>
|
||||
|
||||
<path d="M750 200 L750 160" class="arrow3"/>
|
||||
<text x="700" y="145" class="label">barge-in detection</text>
|
||||
|
||||
<path d="M760 170 L920 170" class="arrow3"/>
|
||||
<text x="820" y="160" class="label">interrupt event + cancel</text>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.9 KiB |
187
docs/proejct_todo.md
Normal file
187
docs/proejct_todo.md
Normal file
@@ -0,0 +1,187 @@
|
||||
# OmniSense: 12-Week Sprint Board + Tech Stack (Python Backend) — TODO
|
||||
|
||||
## Scope
|
||||
- [ ] Build a realtime AI SaaS (OmniSense) focused on web-first audio + video with WebSocket + WebRTC endpoints
|
||||
- [ ] Deliver assistant builder, tool execution, observability, evals, optional telephony later
|
||||
- [ ] Keep scope aligned to 2-person team, self-hosted services
|
||||
|
||||
---
|
||||
|
||||
## Sprint Board (12 weeks, 2-week sprints)
|
||||
Team assumption: 2 engineers. Scope prioritized to web-first audio + video, with BYO-SFU adapters.
|
||||
|
||||
### Sprint 1 (Weeks 1–2) — Realtime Core MVP (WebSocket + WebRTC Audio)
|
||||
- Deliverables
|
||||
- [ ] WebSocket transport: audio in/out streaming (1:1)
|
||||
- [ ] WebRTC transport: audio in/out streaming (1:1)
|
||||
- [ ] Adapter contract wired into runtime (transport-agnostic session core)
|
||||
- [ ] ASR → LLM → TTS pipeline, streaming both directions
|
||||
- [ ] Basic session state (start/stop, silence timeout)
|
||||
- [ ] Transcript persistence
|
||||
- Acceptance criteria
|
||||
- [ ] < 1.5s median round-trip for short responses
|
||||
- [ ] Stable streaming for 10+ minute session
|
||||
|
||||
### Sprint 2 (Weeks 3–4) — Video + Realtime UX
|
||||
- Deliverables
|
||||
- [ ] WebRTC video capture + streaming (assistant can “see” frames)
|
||||
- [ ] WebSocket video streaming for local/dev mode
|
||||
- [ ] Low-latency UI: push-to-talk, live captions, speaking indicator
|
||||
- [ ] Recording + transcript storage (web sessions)
|
||||
- Acceptance criteria
|
||||
- [ ] Video < 2.5s end-to-end latency for analysis
|
||||
- [ ] Audio quality acceptable (no clipping, jitter handling)
|
||||
|
||||
### Sprint 3 (Weeks 5–6) — Assistant Builder v1
|
||||
- Deliverables
|
||||
- [ ] Assistant schema + versioning
|
||||
- [ ] UI: Model/Voice/Transcriber/Tools/Video/Transport tabs
|
||||
- [ ] “Test/Chat/Talk to Assistant” (web)
|
||||
- Acceptance criteria
|
||||
- [ ] Create/publish assistant and run a live web session
|
||||
- [ ] All config changes tracked by version
|
||||
|
||||
### Sprint 4 (Weeks 7–8) — Tooling + Structured Outputs
|
||||
- Deliverables
|
||||
- [ ] Tool registry + custom HTTP tools
|
||||
- [ ] Tool auth secrets management
|
||||
- [ ] Structured outputs (JSON extraction)
|
||||
- Acceptance criteria
|
||||
- [ ] Tool calls executed with retries/timeouts
|
||||
- [ ] Structured JSON stored per call/session
|
||||
|
||||
### Sprint 5 (Weeks 9–10) — Observability + QA + Dev Platform
|
||||
- Deliverables
|
||||
- [ ] Session logs + chat logs + media logs
|
||||
- [ ] Evals engine + test suites
|
||||
- [ ] Basic analytics dashboard
|
||||
- [ ] Public WebSocket API spec + message schema
|
||||
- [ ] JS/TS SDK (connect, send audio/video, receive transcripts)
|
||||
- Acceptance criteria
|
||||
- [ ] Reproducible test suite runs
|
||||
- [ ] Log filters by assistant/time/status
|
||||
- [ ] SDK demo app runs end-to-end
|
||||
|
||||
### Sprint 6 (Weeks 11–12) — SaaS Hardening
|
||||
- Deliverables
|
||||
- [ ] Org/RBAC + API keys + rate limits
|
||||
- [ ] Usage metering + credits
|
||||
- [ ] Stripe billing integration
|
||||
- [ ] Self-hosted DB ops (migrations, backup/restore, monitoring)
|
||||
- Acceptance criteria
|
||||
- [ ] Metered usage per org
|
||||
- [ ] Credits decrement correctly
|
||||
- [ ] Optional telephony spike documented (defer build)
|
||||
- [ ] Enterprise adapter guide published (BYO-SFU)
|
||||
|
||||
---
|
||||
|
||||
## Tech Stack by Service (Self-Hosted, Web-First)
|
||||
|
||||
### 1) Transport Gateway (Realtime)
|
||||
- [ ] WebRTC (browser) + WebSocket (lightweight/dev) protocols
|
||||
- [ ] BYO-SFU adapter (enterprise) + LiveKit optional adapter + WS transport server
|
||||
- [ ] Python core (FastAPI + asyncio) + Node.js mediasoup adapters when needed
|
||||
- [ ] Media: Opus/VP8, jitter buffer, VAD, echo cancellation
|
||||
- [ ] Storage: S3-compatible (MinIO) for recordings
|
||||
|
||||
### 2) ASR Service
|
||||
- [ ] Whisper (self-hosted) baseline
|
||||
- [ ] gRPC/WebSocket streaming transport
|
||||
- [ ] Python native service
|
||||
- [ ] Optional cloud provider fallback (later)
|
||||
|
||||
### 3) TTS Service
|
||||
- [ ] Piper or Coqui TTS (self-hosted)
|
||||
- [ ] gRPC/WebSocket streaming transport
|
||||
- [ ] Python native service
|
||||
- [ ] Redis cache for common phrases
|
||||
|
||||
### 4) LLM Orchestrator
|
||||
- [ ] Self-hosted (vLLM + open model)
|
||||
- [ ] Python (FastAPI + asyncio)
|
||||
- [ ] Streaming, tool calling, JSON mode
|
||||
- [ ] Safety filters + prompt templates
|
||||
|
||||
### 5) Assistant Config Service
|
||||
- [ ] PostgreSQL
|
||||
- [ ] Python (SQLAlchemy or SQLModel)
|
||||
- [ ] Versioning, publish/rollback
|
||||
|
||||
### 6) Session Service
|
||||
- [ ] PostgreSQL + Redis
|
||||
- [ ] Python
|
||||
- [ ] State machine, timeouts, events
|
||||
|
||||
### 7) Tool Execution Layer
|
||||
- [ ] PostgreSQL
|
||||
- [ ] Python
|
||||
- [ ] Auth secret vault, retry policies, tool schemas
|
||||
|
||||
### 8) Observability + Logs
|
||||
- [ ] Postgres (metadata), ClickHouse (logs/metrics)
|
||||
- [ ] OpenSearch for search
|
||||
- [ ] Prometheus + Grafana metrics
|
||||
- [ ] OpenTelemetry tracing
|
||||
|
||||
### 9) Billing + Usage Metering
|
||||
- [ ] Stripe billing
|
||||
- [ ] PostgreSQL
|
||||
- [ ] NATS JetStream (events) + Redis counters
|
||||
|
||||
### 10) Web App (Dashboard)
|
||||
- [ ] React + Next.js
|
||||
- [ ] Tailwind or Radix UI
|
||||
- [ ] WebRTC client + WS client; adapter-based RTC integration
|
||||
- [ ] ECharts/Recharts
|
||||
|
||||
### 11) Auth + RBAC
|
||||
- [ ] Keycloak (self-hosted) or custom JWT
|
||||
- [ ] Org/user/role tables in Postgres
|
||||
|
||||
### 12) Public WebSocket API + SDK
|
||||
- [ ] WS API: versioned schema, binary audio frames + JSON control messages
|
||||
- [ ] SDKs: JS/TS first, optional Python/Go clients
|
||||
- [ ] Docs: quickstart, auth flow, session lifecycle, examples
|
||||
|
||||
---
|
||||
|
||||
## Infrastructure (Self-Hosted)
|
||||
- [ ] Docker Compose → k3s (later)
|
||||
- [ ] Redis Streams or NATS
|
||||
- [ ] MinIO object store
|
||||
- [ ] GitHub Actions + Helm or kustomize
|
||||
- [ ] Self-hosted Postgres + pgbackrest backups
|
||||
- [ ] Vault for secrets
|
||||
|
||||
---
|
||||
|
||||
## Suggested MVP Sequence
|
||||
- [ ] WebRTC demo + ASR/LLM/TTS streaming
|
||||
- [ ] Assistant schema + versioning (web-first)
|
||||
- [ ] Video capture + multimodal analysis
|
||||
- [ ] Tool execution + structured outputs
|
||||
- [ ] Logs + evals + public WS API + SDK
|
||||
- [ ] Telephony (optional, later)
|
||||
|
||||
---
|
||||
|
||||
## Public WebSocket API (Minimum Spec)
|
||||
- [ ] Auth: API key or JWT in initial `hello` message
|
||||
- [ ] Core messages: `session.start`, `session.stop`, `audio.append`, `audio.commit`, `video.append`, `transcript.delta`, `assistant.response`, `tool.call`, `tool.result`, `error`
|
||||
- [ ] Binary payloads: PCM/Opus frames with metadata in control channel
|
||||
- [ ] Versioning: `v1` schema with backward compatibility rules
|
||||
|
||||
---
|
||||
|
||||
## Self-Hosted DB Ops Checklist
|
||||
- [ ] Postgres in Docker/k3s with persistent volumes
|
||||
- [ ] Migrations: `alembic` or `atlas`
|
||||
- [ ] Backups: `pgbackrest` nightly + on-demand
|
||||
- [ ] Monitoring: postgres_exporter + alerts
|
||||
|
||||
---
|
||||
|
||||
## RTC Adapter Contract (BYO-SFU First)
|
||||
- [ ] Keep RTC pluggable; LiveKit optional, not core dependency
|
||||
- [ ] Define adapter interface (TypeScript sketch)
|
||||
199
docs/ws_v1_schema.md
Normal file
199
docs/ws_v1_schema.md
Normal file
@@ -0,0 +1,199 @@
|
||||
# WS v1 Protocol Schema (`/ws`)
|
||||
|
||||
This document defines the public WebSocket protocol for the `/ws` endpoint.
|
||||
|
||||
## Transport
|
||||
|
||||
- A single WebSocket connection carries:
|
||||
- JSON text frames for control/events.
|
||||
- Binary frames for raw PCM audio (`pcm_s16le`, mono, 16kHz by default).
|
||||
|
||||
## Handshake and State Machine
|
||||
|
||||
Required message order:
|
||||
|
||||
1. Client sends `hello`.
|
||||
2. Server replies `hello.ack`.
|
||||
3. Client sends `session.start`.
|
||||
4. Server replies `session.started`.
|
||||
5. Client may stream binary audio and/or send `input.text`.
|
||||
6. Client sends `session.stop` (or closes socket).
|
||||
|
||||
If order is violated, server emits `error` with `code = "protocol.order"`.
|
||||
|
||||
## Client -> Server Messages
|
||||
|
||||
### `hello`
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hello",
|
||||
"version": "v1",
|
||||
"auth": {
|
||||
"apiKey": "optional-api-key",
|
||||
"jwt": "optional-jwt"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Rules:
|
||||
- `version` must be `v1`.
|
||||
- If `WS_API_KEY` is configured on server, `auth.apiKey` must match.
|
||||
- If `WS_REQUIRE_AUTH=true`, either `auth.apiKey` or `auth.jwt` must be present.
|
||||
|
||||
### `session.start`
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "session.start",
|
||||
"audio": {
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate_hz": 16000,
|
||||
"channels": 1
|
||||
},
|
||||
"metadata": {
|
||||
"client": "web-debug",
|
||||
"output": {
|
||||
"mode": "audio"
|
||||
},
|
||||
"systemPrompt": "You are concise.",
|
||||
"greeting": "Hi, how can I help?",
|
||||
"services": {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o-mini",
|
||||
"apiKey": "sk-...",
|
||||
"baseUrl": "https://api.openai.com/v1"
|
||||
},
|
||||
"asr": {
|
||||
"provider": "openai_compatible",
|
||||
"model": "FunAudioLLM/SenseVoiceSmall",
|
||||
"apiKey": "sf-...",
|
||||
"interimIntervalMs": 500,
|
||||
"minAudioMs": 300
|
||||
},
|
||||
"tts": {
|
||||
"enabled": true,
|
||||
"provider": "openai_compatible",
|
||||
"model": "FunAudioLLM/CosyVoice2-0.5B",
|
||||
"apiKey": "sf-...",
|
||||
"voice": "anna",
|
||||
"speed": 1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`metadata.services` is optional. If omitted, server defaults to environment configuration.
|
||||
|
||||
Text-only mode:
|
||||
- Set `metadata.output.mode = "text"` OR `metadata.services.tts.enabled = false`.
|
||||
- In this mode server still sends `assistant.response.delta/final`, but will not emit audio frames or `output.audio.start/end`.
|
||||
|
||||
### `input.text`
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "input.text",
|
||||
"text": "What can you do?"
|
||||
}
|
||||
```
|
||||
|
||||
### `response.cancel`
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "response.cancel",
|
||||
"graceful": false
|
||||
}
|
||||
```
|
||||
|
||||
### `session.stop`
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "session.stop",
|
||||
"reason": "client_disconnect"
|
||||
}
|
||||
```
|
||||
|
||||
### `tool_call.results`
|
||||
|
||||
Client tool execution results returned to server.
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "tool_call.results",
|
||||
"results": [
|
||||
{
|
||||
"tool_call_id": "call_abc123",
|
||||
"name": "weather",
|
||||
"output": { "temp_c": 21, "condition": "sunny" },
|
||||
"status": { "code": 200, "message": "ok" }
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Server -> Client Events
|
||||
|
||||
All server events include:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "event.name",
|
||||
"timestamp": 1730000000000
|
||||
}
|
||||
```
|
||||
|
||||
Common events:
|
||||
|
||||
- `hello.ack`
|
||||
- Fields: `sessionId`, `version`
|
||||
- `session.started`
|
||||
- Fields: `sessionId`, `trackId`, `audio`
|
||||
- `session.stopped`
|
||||
- Fields: `sessionId`, `reason`
|
||||
- `heartbeat`
|
||||
- `input.speech_started`
|
||||
- Fields: `trackId`, `probability`
|
||||
- `input.speech_stopped`
|
||||
- Fields: `trackId`, `probability`
|
||||
- `transcript.delta`
|
||||
- Fields: `trackId`, `text`
|
||||
- `transcript.final`
|
||||
- Fields: `trackId`, `text`
|
||||
- `assistant.response.delta`
|
||||
- Fields: `trackId`, `text`
|
||||
- `assistant.response.final`
|
||||
- Fields: `trackId`, `text`
|
||||
- `assistant.tool_call`
|
||||
- Fields: `trackId`, `tool_call` (`tool_call.executor` is `client` or `server`)
|
||||
- `assistant.tool_result`
|
||||
- Fields: `trackId`, `source`, `result`
|
||||
- `output.audio.start`
|
||||
- Fields: `trackId`
|
||||
- `output.audio.end`
|
||||
- Fields: `trackId`
|
||||
- `response.interrupted`
|
||||
- Fields: `trackId`
|
||||
- `metrics.ttfb`
|
||||
- Fields: `trackId`, `latencyMs`
|
||||
- `error`
|
||||
- Fields: `sender`, `code`, `message`, `trackId`
|
||||
|
||||
## Binary Audio Frames
|
||||
|
||||
After `session.started`, client may send binary PCM chunks continuously.
|
||||
|
||||
Recommended format:
|
||||
- 16-bit signed little-endian PCM.
|
||||
- 1 channel.
|
||||
- 16000 Hz.
|
||||
- 20ms frames (640 bytes) preferred.
|
||||
|
||||
## Compatibility
|
||||
|
||||
This endpoint now enforces v1 message schema for JSON control frames.
|
||||
Legacy command names (`invite`, `chat`, etc.) are no longer part of the public protocol.
|
||||
601
examples/mic_client.py
Normal file
601
examples/mic_client.py
Normal file
@@ -0,0 +1,601 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Microphone client for testing duplex voice conversation.
|
||||
|
||||
This client captures audio from the microphone, sends it to the server,
|
||||
and plays back the AI's voice response through the speakers.
|
||||
It also displays the LLM's text responses in the console.
|
||||
|
||||
Usage:
|
||||
python examples/mic_client.py --url ws://localhost:8000/ws
|
||||
python examples/mic_client.py --url ws://localhost:8000/ws --chat "Hello!"
|
||||
python examples/mic_client.py --url ws://localhost:8000/ws --verbose
|
||||
|
||||
Requirements:
|
||||
pip install sounddevice soundfile websockets numpy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import queue
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
print("Please install numpy: pip install numpy")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError:
|
||||
print("Please install sounddevice: pip install sounddevice")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Please install websockets: pip install websockets")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class MicrophoneClient:
|
||||
"""
|
||||
Full-duplex microphone client for voice conversation.
|
||||
|
||||
Features:
|
||||
- Real-time microphone capture
|
||||
- Real-time speaker playback
|
||||
- WebSocket communication
|
||||
- Text chat support
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    url: str,
    sample_rate: int = 16000,
    chunk_duration_ms: int = 20,
    input_device: "int | None" = None,
    output_device: "int | None" = None
):
    """
    Initialize microphone client.

    Args:
        url: WebSocket server URL
        sample_rate: Audio sample rate (Hz)
        chunk_duration_ms: Audio chunk duration (ms)
        input_device: sounddevice input device ID (None for system default)
        output_device: sounddevice output device ID (None for system default)
    """
    self.url = url
    self.sample_rate = sample_rate
    self.chunk_duration_ms = chunk_duration_ms
    # Samples per capture chunk (e.g. 320 for 20ms @ 16kHz).
    self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
    self.input_device = input_device
    self.output_device = output_device

    # WebSocket connection
    self.ws = None
    self.running = False

    # Audio buffers
    self.audio_input_queue = queue.Queue()
    self.audio_output_buffer = b""  # Continuous buffer for smooth playback
    self.audio_output_lock = threading.Lock()

    # Statistics (raw byte counters, for the status display)
    self.bytes_sent = 0
    self.bytes_received = 0

    # State
    self.is_recording = True
    self.is_playing = True

    # TTFB tracking (Time to First Byte)
    self.request_start_time = None
    self.first_audio_received = False

    # Interrupt handling - discard audio until next trackStart
    self._discard_audio = False
    self._audio_sequence = 0  # Track audio sequence to detect stale chunks

    # Verbose mode for streaming LLM responses
    self.verbose = False
|
||||
|
||||
async def connect(self) -> None:
    """Open the WebSocket connection and send the initial invite command."""
    print(f"Connecting to {self.url}...")
    self.ws = await websockets.connect(self.url)
    self.running = True
    print("Connected!")

    # Start the session by describing our audio format to the server.
    invite = {
        "command": "invite",
        "option": {"codec": "pcm", "sampleRate": self.sample_rate},
    }
    await self.send_command(invite)
|
||||
|
||||
async def send_command(self, cmd: dict) -> None:
    """Serialize *cmd* as JSON and send it over the socket, if connected."""
    if not self.ws:
        return
    await self.ws.send(json.dumps(cmd))
    print(f"→ Command: {cmd.get('command', 'unknown')}")
|
||||
|
||||
async def send_chat(self, text: str) -> None:
    """Send a text chat turn and arm TTFB measurement for the reply."""
    # New request: restart the time-to-first-audio clock.
    self.request_start_time = time.time()
    self.first_audio_received = False

    payload = {"command": "chat", "text": text}
    await self.send_command(payload)
    print(f"→ Chat: {text}")
|
||||
|
||||
async def send_interrupt(self) -> None:
    """Ask the server to cut off the response currently playing."""
    await self.send_command({"command": "interrupt"})
|
||||
|
||||
async def send_hangup(self, reason: str = "User quit") -> None:
    """Tell the server the call is over, with a human-readable reason."""
    await self.send_command({"command": "hangup", "reason": reason})
|
||||
|
||||
def _audio_input_callback(self, indata, frames, time, status):
|
||||
"""Callback for audio input (microphone)."""
|
||||
if status:
|
||||
print(f"Input status: {status}")
|
||||
|
||||
if self.is_recording and self.running:
|
||||
# Convert to 16-bit PCM
|
||||
audio_data = (indata[:, 0] * 32767).astype(np.int16).tobytes()
|
||||
self.audio_input_queue.put(audio_data)
|
||||
|
||||
def _add_audio_to_buffer(self, audio_data: bytes):
|
||||
"""Add audio data to playback buffer."""
|
||||
with self.audio_output_lock:
|
||||
self.audio_output_buffer += audio_data
|
||||
|
||||
def _playback_thread_func(self):
    """Continuously pull PCM from the output buffer and play it.

    Runs in a dedicated daemon thread. Whenever the buffer holds less
    than one 50ms chunk, silence is emitted so the output stream keeps
    steady timing instead of underrunning.
    """
    # 50ms chunks: cheap enough per write, small enough for low latency.
    chunk_samples = int(self.sample_rate * 0.05)
    chunk_bytes = chunk_samples * 2  # 16-bit mono -> 2 bytes per sample

    print(f"Audio playback thread started (device: {self.output_device or 'default'})")

    try:
        # Blocking output stream: stream.write() paces this loop in real time.
        with sd.OutputStream(
            samplerate=self.sample_rate,
            channels=1,
            dtype='int16',
            blocksize=chunk_samples,
            device=self.output_device,
            latency='low'
        ) as stream:
            while self.running:
                # Drain one chunk under the lock; pad with silence if short.
                with self.audio_output_lock:
                    if len(self.audio_output_buffer) >= chunk_bytes:
                        audio_data = self.audio_output_buffer[:chunk_bytes]
                        self.audio_output_buffer = self.audio_output_buffer[chunk_bytes:]
                    else:
                        # Not enough audio - output silence
                        audio_data = b'\x00' * chunk_bytes

                # Convert to a (frames, channels) int16 array for the stream.
                samples = np.frombuffer(audio_data, dtype=np.int16).reshape(-1, 1)
                stream.write(samples)

    except Exception as e:
        print(f"Playback thread error: {e}")
        import traceback
        traceback.print_exc()
|
||||
|
||||
async def _playback_task(self):
|
||||
"""Start playback thread and monitor it."""
|
||||
# Run playback in a dedicated thread for reliable timing
|
||||
playback_thread = threading.Thread(target=self._playback_thread_func, daemon=True)
|
||||
playback_thread.start()
|
||||
|
||||
# Wait for client to stop
|
||||
while self.running and playback_thread.is_alive():
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
print("Audio playback stopped")
|
||||
|
||||
async def audio_sender(self) -> None:
    """Send audio from microphone to server.

    Drains the thread-safe input queue via an executor (so the 0.1s
    blocking `get` never stalls the event loop) and forwards bytes over
    the WebSocket while recording is enabled. Runs until self.running
    goes False, the task is cancelled, or an unrecoverable error occurs.

    Fix: use asyncio.get_running_loop() (get_event_loop() inside a
    coroutine is deprecated) and hoist the loop lookup out of the
    per-chunk hot path.
    """
    loop = asyncio.get_running_loop()
    while self.running:
        try:
            # Get audio from queue with timeout, off the event loop.
            try:
                audio_data = await loop.run_in_executor(
                    None, lambda: self.audio_input_queue.get(timeout=0.1)
                )
            except queue.Empty:
                continue

            # Send to server only when connected and unmuted.
            if self.ws and self.is_recording:
                await self.ws.send(audio_data)
                self.bytes_sent += len(audio_data)

        except asyncio.CancelledError:
            break
        except Exception as e:
            print(f"Audio sender error: {e}")
            break
async def receiver(self) -> None:
    """Receive messages from server.

    Binary frames are PCM audio destined for the playback buffer; text
    frames are JSON events dispatched to _handle_event(). Runs until
    self.running goes False or the connection closes.
    """
    try:
        while self.running:
            try:
                # Short timeout so the loop re-checks self.running regularly.
                message = await asyncio.wait_for(self.ws.recv(), timeout=0.1)

                if isinstance(message, bytes):
                    # Audio data received
                    self.bytes_received += len(message)

                    # Check if we should discard this audio (after interrupt)
                    if self._discard_audio:
                        # bytes / (rate * 2 bytes-per-sample) -> seconds -> ms
                        duration_ms = len(message) / (self.sample_rate * 2) * 1000
                        print(f"← Audio: {duration_ms:.0f}ms (DISCARDED - waiting for new track)")
                        continue

                    if self.is_playing:
                        self._add_audio_to_buffer(message)

                        # Calculate and display TTFB for first audio packet
                        if not self.first_audio_received and self.request_start_time:
                            client_ttfb_ms = (time.time() - self.request_start_time) * 1000
                            self.first_audio_received = True
                            print(f"← [TTFB] Client first audio latency: {client_ttfb_ms:.0f}ms")

                        # Show progress (less verbose)
                        with self.audio_output_lock:
                            buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000
                            duration_ms = len(message) / (self.sample_rate * 2) * 1000
                            print(f"← Audio: {duration_ms:.0f}ms (buffer: {buffer_ms:.0f}ms)")

                else:
                    # JSON event
                    event = json.loads(message)
                    await self._handle_event(event)

            except asyncio.TimeoutError:
                # No message within 0.1s; loop to re-check self.running.
                continue
            except websockets.ConnectionClosed:
                print("Connection closed")
                self.running = False
                break

    except asyncio.CancelledError:
        pass
    except Exception as e:
        print(f"Receiver error: {e}")
        self.running = False
async def _handle_event(self, event: dict) -> None:
    """Handle incoming event.

    Dispatches one server JSON event to console output and/or client
    state updates (TTFB tracking, audio-discard flag, playback buffer).
    """
    event_type = event.get("event", "unknown")

    if event_type == "answer":
        print("← Session ready!")
    elif event_type == "speaking":
        print("← User speech detected")
    elif event_type == "silence":
        print("← User silence detected")
    elif event_type == "transcript":
        # Display user speech transcription
        text = event.get("text", "")
        is_final = event.get("isFinal", False)
        if is_final:
            # Clear the interim line and print final
            print(" " * 80, end="\r")  # Clear previous interim text
            print(f"→ You: {text}")
        else:
            # Interim result - show with indicator (overwrite same line)
            display_text = text[:60] + "..." if len(text) > 60 else text
            print(f" [listening] {display_text}".ljust(80), end="\r")
    elif event_type == "ttfb":
        # Server-side TTFB event
        latency_ms = event.get("latencyMs", 0)
        print(f"← [TTFB] Server reported latency: {latency_ms}ms")
    elif event_type == "llmResponse":
        # LLM text response
        text = event.get("text", "")
        is_final = event.get("isFinal", False)
        if is_final:
            # Print final LLM response
            print(f"← AI: {text}")
        elif self.verbose:
            # Show streaming chunks only in verbose mode
            display_text = text[:60] + "..." if len(text) > 60 else text
            print(f" [streaming] {display_text}")
    elif event_type == "trackStart":
        print("← Bot started speaking")
        # IMPORTANT: Accept audio again after trackStart
        self._discard_audio = False
        self._audio_sequence += 1
        # Reset TTFB tracking for voice responses (when no chat was sent)
        if self.request_start_time is None:
            self.request_start_time = time.time()
            self.first_audio_received = False
        # Clear any old audio in buffer
        with self.audio_output_lock:
            self.audio_output_buffer = b""
    elif event_type == "trackEnd":
        print("← Bot finished speaking")
        # Reset TTFB tracking after response completes
        self.request_start_time = None
        self.first_audio_received = False
    elif event_type == "interrupt":
        print("← Bot interrupted!")
        # IMPORTANT: Discard all audio until next trackStart
        self._discard_audio = True
        # Clear audio buffer immediately
        with self.audio_output_lock:
            buffer_ms = len(self.audio_output_buffer) / (self.sample_rate * 2) * 1000
            self.audio_output_buffer = b""
        print(f" (cleared {buffer_ms:.0f}ms, discarding audio until new track)")
    elif event_type == "error":
        print(f"← Error: {event.get('error')}")
    elif event_type == "hangup":
        print(f"← Hangup: {event.get('reason')}")
        self.running = False
    else:
        # Unknown event type: just echo it.
        print(f"← Event: {event_type}")
async def interactive_mode(self) -> None:
    """Run interactive mode for text chat.

    Reads stdin lines in an executor (input() blocks), interprets
    slash-commands locally, and sends anything else as a chat message.
    Returns when the user quits, stdin closes, or self.running goes False.
    """
    print("\n" + "=" * 50)
    print("Voice Conversation Client")
    print("=" * 50)
    print("Speak into your microphone to talk to the AI.")
    print("Or type messages to send text.")
    print("")
    print("Commands:")
    print(" /quit - End conversation")
    print(" /mute - Mute microphone")
    print(" /unmute - Unmute microphone")
    print(" /interrupt - Interrupt AI speech")
    print(" /stats - Show statistics")
    print("=" * 50 + "\n")

    while self.running:
        try:
            # input() blocks, so run it in a worker thread.
            user_input = await asyncio.get_event_loop().run_in_executor(
                None, input, ""
            )

            if not user_input:
                continue

            # Handle commands
            if user_input.startswith("/"):
                cmd = user_input.lower().strip()

                if cmd == "/quit":
                    await self.send_hangup("User quit")
                    break
                elif cmd == "/mute":
                    # Mic callback checks is_recording before queueing audio.
                    self.is_recording = False
                    print("Microphone muted")
                elif cmd == "/unmute":
                    self.is_recording = True
                    print("Microphone unmuted")
                elif cmd == "/interrupt":
                    await self.send_interrupt()
                elif cmd == "/stats":
                    print(f"Sent: {self.bytes_sent / 1024:.1f} KB")
                    print(f"Received: {self.bytes_received / 1024:.1f} KB")
                else:
                    print(f"Unknown command: {cmd}")
            else:
                # Send as chat message
                await self.send_chat(user_input)

        except EOFError:
            # stdin closed (e.g. piped input exhausted).
            break
        except Exception as e:
            print(f"Input error: {e}")
async def run(self, chat_message: str = None, interactive: bool = True) -> None:
    """
    Run the client.

    Connects, starts the microphone input stream and the three
    background tasks (sender, receiver, playback), then either sends a
    single chat message, runs interactive mode, or idles until hangup.

    Args:
        chat_message: Optional single chat message to send
        interactive: Whether to run in interactive mode

    Fix: cleanup (task cancellation, input-stream stop/close) previously
    ran only on the success path; it now runs in `finally` on every exit,
    so exceptions no longer leak the audio stream or leave tasks running.
    """
    input_stream = None
    tasks = []
    try:
        await self.connect()

        # Wait briefly for the server's "answer" event.
        await asyncio.sleep(0.5)

        # Start audio input stream
        print("Starting audio streams...")
        input_stream = sd.InputStream(
            samplerate=self.sample_rate,
            channels=1,
            dtype=np.float32,
            blocksize=self.chunk_samples,
            device=self.input_device,
            callback=self._audio_input_callback
        )
        input_stream.start()
        print("Audio streams started")

        # Background tasks: mic -> server, server -> client, playback.
        tasks = [
            asyncio.create_task(self.audio_sender()),
            asyncio.create_task(self.receiver()),
            asyncio.create_task(self._playback_task()),
        ]

        if chat_message:
            # Send single message and wait for the spoken response.
            await self.send_chat(chat_message)
            await asyncio.sleep(15)
        elif interactive:
            # Run interactive mode
            await self.interactive_mode()
        else:
            # Just wait until something stops the client.
            while self.running:
                await asyncio.sleep(0.1)

    except ConnectionRefusedError:
        print(f"Error: Could not connect to {self.url}")
        print("Make sure the server is running.")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        # Cleanup on every exit path.
        self.running = False
        for task in tasks:
            task.cancel()
        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass
        if input_stream is not None:
            input_stream.stop()
            input_stream.close()
        await self.close()
async def close(self) -> None:
    """Shut down the WebSocket connection and report transfer totals."""
    self.running = False
    ws = self.ws
    if ws:
        await ws.close()

    sent_kb = self.bytes_sent / 1024
    recv_kb = self.bytes_received / 1024
    print(f"\nSession ended")
    print(f" Total sent: {sent_kb:.1f} KB")
    print(f" Total received: {recv_kb:.1f} KB")
|
||||
def list_devices():
    """List available audio devices.

    Prints every device sounddevice reports, tagging each with IN/OUT
    capability and marking the system default input/output devices.
    """
    print("\nAvailable audio devices:")
    print("-" * 60)
    devices = sd.query_devices()
    for i, device in enumerate(devices):
        # A device can support input, output, both, or neither.
        direction = []
        if device['max_input_channels'] > 0:
            direction.append("IN")
        if device['max_output_channels'] > 0:
            direction.append("OUT")
        direction_str = "/".join(direction) if direction else "N/A"

        # sd.default.device is (default input index, default output index).
        default = ""
        if i == sd.default.device[0]:
            default += " [DEFAULT INPUT]"
        if i == sd.default.device[1]:
            default += " [DEFAULT OUTPUT]"

        print(f" {i:2d}: {device['name'][:40]:40s} ({direction_str}){default}")
    print("-" * 60)
||||
|
||||
async def main():
    """Parse CLI arguments and run the microphone client.

    With --list-devices, prints the device table and exits; otherwise
    constructs a MicrophoneClient and runs it (single --chat message,
    interactive mode, or passive listening with --no-interactive).
    """
    parser = argparse.ArgumentParser(
        description="Microphone client for duplex voice conversation"
    )
    parser.add_argument(
        "--url",
        default="ws://localhost:8000/ws",
        help="WebSocket server URL"
    )
    parser.add_argument(
        "--chat",
        help="Send a single chat message instead of using microphone"
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="Audio sample rate (default: 16000)"
    )
    parser.add_argument(
        "--input-device",
        type=int,
        help="Input device ID"
    )
    parser.add_argument(
        "--output-device",
        type=int,
        help="Output device ID"
    )
    parser.add_argument(
        "--list-devices",
        action="store_true",
        help="List available audio devices and exit"
    )
    parser.add_argument(
        "--no-interactive",
        action="store_true",
        help="Disable interactive mode"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Show streaming LLM response chunks"
    )

    args = parser.parse_args()

    if args.list_devices:
        list_devices()
        return

    client = MicrophoneClient(
        url=args.url,
        sample_rate=args.sample_rate,
        input_device=args.input_device,
        output_device=args.output_device
    )
    client.verbose = args.verbose

    await client.run(
        chat_message=args.chat,
        interactive=not args.no_interactive
    )
||||
if __name__ == "__main__":
    # Script entry point: run the async client until it finishes or Ctrl+C.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nInterrupted by user")
||||
285
examples/simple_client.py
Normal file
285
examples/simple_client.py
Normal file
@@ -0,0 +1,285 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple WebSocket client for testing voice conversation.
|
||||
Uses PyAudio for more reliable audio playback on Windows.
|
||||
|
||||
Usage:
|
||||
python examples/simple_client.py
|
||||
python examples/simple_client.py --text "Hello"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import wave
|
||||
import io
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
print("pip install numpy")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("pip install websockets")
|
||||
sys.exit(1)
|
||||
|
||||
# Try PyAudio first (more reliable on Windows)
|
||||
try:
|
||||
import pyaudio
|
||||
PYAUDIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYAUDIO_AVAILABLE = False
|
||||
print("PyAudio not available, trying sounddevice...")
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
SD_AVAILABLE = True
|
||||
except ImportError:
|
||||
SD_AVAILABLE = False
|
||||
|
||||
if not PYAUDIO_AVAILABLE and not SD_AVAILABLE:
|
||||
print("Please install pyaudio or sounddevice:")
|
||||
print(" pip install pyaudio")
|
||||
print(" or: pip install sounddevice")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class SimpleVoiceClient:
    """Simple voice client with reliable audio playback.

    Connects to the duplex voice WebSocket server, sends "chat" commands,
    and plays returned PCM audio via PyAudio when available (preferred on
    Windows), else via sounddevice.
    """

    def __init__(self, url: str, sample_rate: int = 16000):
        """Store connection settings and initialize playback state."""
        self.url = url
        self.sample_rate = sample_rate
        self.ws = None
        self.running = False

        # Audio buffer
        self.audio_buffer = b""

        # PyAudio setup (output stream is opened lazily on first playback)
        if PYAUDIO_AVAILABLE:
            self.pa = pyaudio.PyAudio()
            self.stream = None

        # Stats
        self.bytes_received = 0

        # TTFB tracking (Time to First Byte)
        self.request_start_time = None
        self.first_audio_received = False

        # Interrupt handling - discard audio until next trackStart
        self._discard_audio = False

    async def connect(self):
        """Connect to server and send the initial invite command."""
        print(f"Connecting to {self.url}...")
        self.ws = await websockets.connect(self.url)
        self.running = True
        print("Connected!")

        # Send invite
        await self.ws.send(json.dumps({
            "command": "invite",
            "option": {"codec": "pcm", "sampleRate": self.sample_rate}
        }))
        print("-> invite")

    async def send_chat(self, text: str):
        """Send chat message."""
        # Reset TTFB tracking for new request
        self.request_start_time = time.time()
        self.first_audio_received = False

        await self.ws.send(json.dumps({"command": "chat", "text": text}))
        print(f"-> chat: {text}")

    def play_audio(self, audio_data: bytes):
        """Play audio data immediately (blocking until written)."""
        if len(audio_data) == 0:
            return

        if PYAUDIO_AVAILABLE:
            # Use PyAudio - more reliable on Windows
            if self.stream is None:
                self.stream = self.pa.open(
                    format=pyaudio.paInt16,
                    channels=1,
                    rate=self.sample_rate,
                    output=True,
                    frames_per_buffer=1024
                )
            self.stream.write(audio_data)
        elif SD_AVAILABLE:
            # Use sounddevice: convert int16 bytes to float [-1, 1].
            samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32767.0
            sd.play(samples, self.sample_rate, blocking=True)

    async def receive_loop(self):
        """Receive and play audio; dispatch JSON events to the console."""
        print("\nWaiting for response...")

        while self.running:
            try:
                msg = await asyncio.wait_for(self.ws.recv(), timeout=0.1)

                if isinstance(msg, bytes):
                    # Audio data
                    self.bytes_received += len(msg)
                    duration_ms = len(msg) / (self.sample_rate * 2) * 1000

                    # Check if we should discard this audio (after interrupt)
                    if self._discard_audio:
                        print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms) [DISCARDED]")
                        continue

                    # Calculate and display TTFB for first audio packet
                    if not self.first_audio_received and self.request_start_time:
                        client_ttfb_ms = (time.time() - self.request_start_time) * 1000
                        self.first_audio_received = True
                        print(f"<- [TTFB] Client first audio latency: {client_ttfb_ms:.0f}ms")

                    print(f"<- audio: {len(msg)} bytes ({duration_ms:.0f}ms)")

                    # Play immediately in executor to not block
                    loop = asyncio.get_event_loop()
                    await loop.run_in_executor(None, self.play_audio, msg)
                else:
                    # JSON event
                    event = json.loads(msg)
                    etype = event.get("event", "?")

                    if etype == "transcript":
                        # User speech transcription
                        text = event.get("text", "")
                        is_final = event.get("isFinal", False)
                        if is_final:
                            print(f"<- You said: {text}")
                        else:
                            print(f"<- [listening] {text}", end="\r")
                    elif etype == "ttfb":
                        # Server-side TTFB event
                        latency_ms = event.get("latencyMs", 0)
                        print(f"<- [TTFB] Server reported latency: {latency_ms}ms")
                    elif etype == "trackStart":
                        # New track starting - accept audio again
                        self._discard_audio = False
                        print(f"<- {etype}")
                    elif etype == "interrupt":
                        # Interrupt - discard audio until next trackStart
                        self._discard_audio = True
                        print(f"<- {etype} (discarding audio until new track)")
                    elif etype == "hangup":
                        print(f"<- {etype}")
                        self.running = False
                        break
                    else:
                        print(f"<- {etype}")

            except asyncio.TimeoutError:
                # No message within 0.1s; re-check self.running.
                continue
            except websockets.ConnectionClosed:
                print("Connection closed")
                self.running = False
                break

    async def run(self, text: str = None):
        """Run the client: one-shot --text mode or interactive stdin loop."""
        try:
            await self.connect()
            await asyncio.sleep(0.5)

            # Start receiver
            recv_task = asyncio.create_task(self.receive_loop())

            if text:
                await self.send_chat(text)
                # Wait for response
                await asyncio.sleep(30)
            else:
                # Interactive mode
                print("\nType a message and press Enter (or 'quit' to exit):")
                while self.running:
                    try:
                        # input() blocks, so run it in a worker thread.
                        user_input = await asyncio.get_event_loop().run_in_executor(
                            None, input, "> "
                        )
                        if user_input.lower() == 'quit':
                            break
                        if user_input.strip():
                            await self.send_chat(user_input)
                    except EOFError:
                        break

            self.running = False
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass

        finally:
            await self.close()

    async def close(self):
        """Close connections and release the audio backend."""
        self.running = False

        if PYAUDIO_AVAILABLE:
            if self.stream:
                self.stream.stop_stream()
                self.stream.close()
            self.pa.terminate()

        if self.ws:
            await self.ws.close()

        print(f"\nTotal audio received: {self.bytes_received / 1024:.1f} KB")
|
||||
|
||||
def list_audio_devices():
    """List available audio devices.

    Prints output-capable devices for each installed backend (PyAudio
    and/or sounddevice), marking the system default.
    """
    print("\n=== Audio Devices ===")

    if PYAUDIO_AVAILABLE:
        pa = pyaudio.PyAudio()
        print("\nPyAudio devices:")
        for i in range(pa.get_device_count()):
            info = pa.get_device_info_by_index(i)
            if info['maxOutputChannels'] > 0:
                default = " [DEFAULT]" if i == pa.get_default_output_device_info()['index'] else ""
                print(f" {i}: {info['name']}{default}")
        pa.terminate()

    if SD_AVAILABLE:
        print("\nSounddevice devices:")
        for i, d in enumerate(sd.query_devices()):
            if d['max_output_channels'] > 0:
                # sd.default.device[1] is the default output device index.
                default = " [DEFAULT]" if i == sd.default.device[1] else ""
                print(f" {i}: {d['name']}{default}")
||||
|
||||
async def main():
    """Parse CLI arguments and run the simple voice client."""
    parser = argparse.ArgumentParser(description="Simple voice client")
    parser.add_argument("--url", default="ws://localhost:8000/ws")
    parser.add_argument("--text", help="Send text and play response")
    parser.add_argument("--list-devices", action="store_true")
    parser.add_argument("--sample-rate", type=int, default=16000)

    args = parser.parse_args()

    if args.list_devices:
        # Informational mode: print devices and exit without connecting.
        list_audio_devices()
        return

    client = SimpleVoiceClient(args.url, args.sample_rate)
    await client.run(args.text)
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point.
    asyncio.run(main())
|
||||
176
examples/test_websocket.py
Normal file
176
examples/test_websocket.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""WebSocket endpoint test client.
|
||||
|
||||
Tests the /ws endpoint with sine wave or file audio streaming.
|
||||
Based on reference/py-active-call/exec/test_ws_endpoint/test_ws.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import struct
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
# Configuration
|
||||
SERVER_URL = "ws://localhost:8000/ws"
|
||||
SAMPLE_RATE = 16000
|
||||
FREQUENCY = 440 # 440Hz Sine Wave
|
||||
CHUNK_DURATION_MS = 20
|
||||
# 16kHz * 16-bit (2 bytes) * 20ms = 640 bytes per chunk
|
||||
CHUNK_SIZE_BYTES = int(SAMPLE_RATE * 2 * (CHUNK_DURATION_MS / 1000.0))
|
||||
|
||||
|
||||
def generate_sine_wave(duration_ms=1000, sample_rate=16000, frequency=440.0):
    """Generate mono PCM 16-bit little-endian sine-wave audio.

    Generalized: sample rate and frequency were hard-coded module
    constants (SAMPLE_RATE=16000, FREQUENCY=440); they are now
    parameters with the same defaults, so existing callers are
    unaffected.

    Args:
        duration_ms: Length of the clip in milliseconds.
        sample_rate: Samples per second.
        frequency: Tone frequency in Hz.

    Returns:
        bytearray of signed 16-bit little-endian samples.
    """
    num_samples = int(sample_rate * (duration_ms / 1000.0))
    audio_data = bytearray()

    for x in range(num_samples):
        # Scale sin() output to the full int16 range.
        value = int(32767.0 * math.sin(2 * math.pi * frequency * x / sample_rate))
        # Pack as little-endian 16-bit integer
        audio_data.extend(struct.pack('<h', value))

    return audio_data
||||
|
||||
async def receive_loop(ws, ready_event: asyncio.Event):
    """Listen for incoming messages from the server.

    Args:
        ws: aiohttp ClientWebSocketResponse to read from.
        ready_event: Set when the "session.started" event arrives so the
            sender side knows the handshake completed.
    """
    print("👂 Listening for server responses...")
    async for msg in ws:
        timestamp = datetime.now().strftime("%H:%M:%S")

        if msg.type == aiohttp.WSMsgType.TEXT:
            try:
                data = json.loads(msg.data)
                event_type = data.get('type', 'Unknown')
                print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...")
                if event_type == "session.started":
                    ready_event.set()
            except json.JSONDecodeError:
                # Non-JSON text frame: log it raw.
                print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...")

        elif msg.type == aiohttp.WSMsgType.BINARY:
            # Received audio chunk back (e.g., TTS or echo)
            print(f"[{timestamp}] 🔊 Audio: {len(msg.data)} bytes", end="\r")

        elif msg.type == aiohttp.WSMsgType.CLOSED:
            print(f"\n[{timestamp}] ❌ Socket Closed")
            break

        elif msg.type == aiohttp.WSMsgType.ERROR:
            print(f"\n[{timestamp}] ⚠️ Socket Error")
            break
|
||||
async def send_file_loop(ws, file_path):
    """Stream a raw PCM/WAV file to the server in real-time-sized chunks."""
    if not os.path.exists(file_path):
        print(f"❌ Error: File '{file_path}' not found.")
        return

    print(f"📂 Streaming file: {file_path} ...")

    with open(file_path, "rb") as f:
        # Skip WAV header if present (first 44 bytes)
        # NOTE(review): assumes a canonical 44-byte RIFF header; WAV files
        # with extra chunks would need real parsing (e.g. the wave module).
        if file_path.endswith('.wav'):
            f.read(44)

        while True:
            chunk = f.read(CHUNK_SIZE_BYTES)
            if not chunk:
                break

            # Send binary frame
            await ws.send_bytes(chunk)

            # Sleep to simulate real-time playback
            await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print(f"\n✅ Finished streaming {file_path}")
||||
|
||||
async def send_sine_loop(ws):
    """Stream generated sine wave to the server in real-time-sized chunks."""
    print("🎙️ Starting Audio Stream (Sine Wave)...")

    # Generate 5 seconds of audio buffer (5000 ms; the old comment saying
    # "10 seconds" was wrong).
    audio_buffer = generate_sine_wave(5000)
    cursor = 0

    while cursor < len(audio_buffer):
        chunk = audio_buffer[cursor:cursor + CHUNK_SIZE_BYTES]
        if not chunk:
            break

        await ws.send_bytes(chunk)
        cursor += len(chunk)

        # Pace sends to simulate real-time capture.
        await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print("\n✅ Finished streaming test audio.")
||||
|
||||
async def run_client(url, file_path=None, use_sine=False):
    """Run the WebSocket test client.

    Flow: connect, spawn the receive loop, perform the v1
    hello/session.start handshake, wait for session.started (8s timeout),
    stream audio (sine wave by default, or a file), then send
    session.stop and shut down.
    """
    session = aiohttp.ClientSession()
    try:
        print(f"🔌 Connecting to {url}...")
        async with session.ws_connect(url) as ws:
            print("✅ Connected!")
            session_ready = asyncio.Event()
            recv_task = asyncio.create_task(receive_loop(ws, session_ready))

            # Send v1 hello + session.start handshake
            await ws.send_json({"type": "hello", "version": "v1"})
            await ws.send_json({
                "type": "session.start",
                "audio": {
                    "encoding": "pcm_s16le",
                    "sample_rate_hz": SAMPLE_RATE,
                    "channels": 1
                }
            })
            print("📤 Sent v1 hello/session.start")
            await asyncio.wait_for(session_ready.wait(), timeout=8)

            # Select sender based on args
            if use_sine:
                await send_sine_loop(ws)
            elif file_path:
                await send_file_loop(ws, file_path)
            else:
                # Default to sine wave
                await send_sine_loop(ws)

            await ws.send_json({"type": "session.stop", "reason": "test_complete"})
            # Give the server a moment to respond before tearing down.
            await asyncio.sleep(1)
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass

    except aiohttp.ClientConnectorError:
        print(f"❌ Connection Failed. Is the server running at {url}?")
    except asyncio.TimeoutError:
        print("❌ Timeout waiting for session.started")
    except Exception as e:
        print(f"❌ Error: {e}")
    finally:
        await session.close()
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the test client.
    parser = argparse.ArgumentParser(description="WebSocket Audio Test Client")
    parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
    parser.add_argument("--file", help="Path to PCM/WAV file to stream")
    parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
    args = parser.parse_args()

    try:
        asyncio.run(run_client(args.url, args.file, args.sine))
    except KeyboardInterrupt:
        print("\n👋 Client stopped.")
|
||||
504
examples/wav_client.py
Normal file
504
examples/wav_client.py
Normal file
@@ -0,0 +1,504 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
WAV file client for testing duplex voice conversation.
|
||||
|
||||
This client reads audio from a WAV file, sends it to the server,
|
||||
and saves the AI's voice response to an output WAV file.
|
||||
|
||||
Usage:
|
||||
python examples/wav_client.py --input input.wav --output response.wav
|
||||
python examples/wav_client.py --input input.wav --output response.wav --url ws://localhost:8000/ws
|
||||
python examples/wav_client.py --input input.wav --output response.wav --wait-time 10
|
||||
python wav_client.py --input ../data/audio_examples/two_utterances.wav -o response.wav
|
||||
Requirements:
|
||||
pip install soundfile websockets numpy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
print("Please install numpy: pip install numpy")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import soundfile as sf
|
||||
except ImportError:
|
||||
print("Please install soundfile: pip install soundfile")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Please install websockets: pip install websockets")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class WavFileClient:
|
||||
"""
|
||||
WAV file client for voice conversation testing.
|
||||
|
||||
Features:
|
||||
- Read audio from WAV file
|
||||
- Send audio to WebSocket server
|
||||
- Receive and save response audio
|
||||
- Event logging
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    url: str,
    input_file: str,
    output_file: str,
    sample_rate: int = 16000,
    chunk_duration_ms: int = 20,
    wait_time: float = 15.0,
    verbose: bool = False
):
    """
    Initialize WAV file client.

    Args:
        url: WebSocket server URL
        input_file: Input WAV file path
        output_file: Output WAV file path
        sample_rate: Audio sample rate (Hz)
        chunk_duration_ms: Audio chunk duration (ms) for sending
        wait_time: Time to wait for response after sending (seconds)
        verbose: Enable verbose output
    """
    self.url = url
    self.input_file = Path(input_file)
    self.output_file = Path(output_file)
    self.sample_rate = sample_rate
    self.chunk_duration_ms = chunk_duration_ms
    # Number of PCM samples per outgoing chunk.
    self.chunk_samples = int(sample_rate * chunk_duration_ms / 1000)
    self.wait_time = wait_time
    self.verbose = verbose

    # WebSocket connection
    self.ws = None
    self.running = False

    # Audio buffers (response audio accumulates here)
    self.received_audio = bytearray()

    # Statistics
    self.bytes_sent = 0
    self.bytes_received = 0

    # TTFB tracking (per response)
    self.send_start_time = None
    self.response_start_time = None  # set on each trackStart
    self.waiting_for_first_audio = False
    self.ttfb_ms = None  # last TTFB for summary
    self.ttfb_list = []  # TTFB for each response

    # State tracking
    self.track_started = False
    self.track_ended = False
    self.send_completed = False

    # Events log (see log_event)
    self.events_log = []
||||
def log_event(self, direction: str, message: str):
    """Record an event with a timestamp and echo it to the console."""
    entry = {
        "timestamp": time.time(),
        "direction": direction,
        "message": message,
    }
    self.events_log.append(entry)
    # Handle encoding errors on Windows consoles that reject some glyphs.
    try:
        print(f"{direction} {message}")
    except UnicodeEncodeError:
        fallback = message.encode('ascii', errors='replace').decode('ascii')
        print(f"{direction} {fallback}")
||||
async def connect(self) -> None:
    """Connect to WebSocket server and send the initial invite command.

    Sets self.ws and self.running; the invite declares the PCM codec and
    sample rate for the session.
    """
    self.log_event("→", f"Connecting to {self.url}...")
    self.ws = await websockets.connect(self.url)
    self.running = True
    self.log_event("←", "Connected!")

    # Send invite command
    await self.send_command({
        "command": "invite",
        "option": {
            "codec": "pcm",
            "sampleRate": self.sample_rate
        }
    })
||||
async def send_command(self, cmd: dict) -> None:
    """Serialize *cmd* as JSON, send it over the socket, and log it."""
    if not self.ws:
        # Not connected yet; silently drop, matching original behavior.
        return
    await self.ws.send(json.dumps(cmd))
    self.log_event("→", f"Command: {cmd.get('command', 'unknown')}")
||||
async def send_hangup(self, reason: str = "Session complete") -> None:
    """Ask the server to terminate the session, citing *reason*."""
    payload = {"command": "hangup", "reason": reason}
    await self.send_command(payload)
|
||||
|
||||
def load_wav_file(self) -> tuple[np.ndarray, int]:
    """Load the input WAV file and prepare it for streaming.

    The audio is mixed down to mono, resampled to ``self.sample_rate``
    via linear interpolation, and converted to int16 PCM.

    Returns:
        Tuple of (audio as int16 numpy array, original sample rate).

    Raises:
        FileNotFoundError: if ``self.input_file`` does not exist.
    """
    if not self.input_file.exists():
        raise FileNotFoundError(f"Input file not found: {self.input_file}")

    samples, src_rate = sf.read(self.input_file)
    self.log_event("→", f"Loaded: {self.input_file}")
    self.log_event("→", f"  Original sample rate: {src_rate} Hz")
    self.log_event("→", f"  Duration: {len(samples) / src_rate:.2f}s")

    # Mix multi-channel audio down to a single channel.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
        self.log_event("→", "  Converted stereo to mono")

    # Linear-interpolation resample to the target wire rate.
    if src_rate != self.sample_rate:
        target_len = int(len(samples) / src_rate * self.sample_rate)
        positions = np.linspace(0, len(samples) - 1, target_len)
        samples = np.interp(positions, np.arange(len(samples)), samples)
        self.log_event("→", f"  Resampled to {self.sample_rate} Hz")

    # Convert to 16-bit PCM; normalize first if values exceed [-1, 1].
    if samples.dtype != np.int16:
        peak = np.max(np.abs(samples))
        if peak > 1.0:
            samples = samples / peak
        samples = (samples * 32767).astype(np.int16)

    self.log_event("→", f"  Prepared: {len(samples)} samples ({len(samples)/self.sample_rate:.2f}s)")

    return samples, src_rate
|
||||
|
||||
async def audio_sender(self, audio_data: np.ndarray) -> None:
    """Stream *audio_data* to the server in real-time-paced chunks.

    Pacing is scheduled against the wall clock (absolute deadlines)
    rather than sleeping a fixed amount per chunk, so per-chunk send
    and processing time does not accumulate as drift over long files.

    Args:
        audio_data: int16 PCM samples at ``self.sample_rate``.
    """
    total_samples = len(audio_data)
    chunk_size = self.chunk_samples
    sent_samples = 0

    self.send_start_time = time.time()
    self.log_event("→", f"Starting audio transmission ({total_samples} samples)...")

    chunk_index = 0
    while sent_samples < total_samples and self.running:
        # Slice off the next chunk (last one may be short).
        end_sample = min(sent_samples + chunk_size, total_samples)
        chunk_bytes = audio_data[sent_samples:end_sample].tobytes()

        if self.ws:
            await self.ws.send(chunk_bytes)
            self.bytes_sent += len(chunk_bytes)

        sent_samples = end_sample
        chunk_index += 1

        # Progress logging roughly every 500ms worth of audio. Using
        # `< chunk_size` (not `== 0`) so it fires even when the chunk
        # size does not divide the half-second boundary exactly.
        if self.verbose and sent_samples % (self.sample_rate // 2) < chunk_size:
            progress = (sent_samples / total_samples) * 100
            print(f"  Sending: {progress:.0f}%", end="\r")

        # Sleep until the next chunk is due. Server expects audio at
        # real-time pace for VAD/ASR to work properly.
        next_due = self.send_start_time + chunk_index * (self.chunk_duration_ms / 1000)
        delay = next_due - time.time()
        if delay > 0:
            await asyncio.sleep(delay)

    self.send_completed = True
    elapsed = time.time() - self.send_start_time
    self.log_event("→", f"Audio transmission complete ({elapsed:.2f}s, {self.bytes_sent/1024:.1f} KB)")
|
||||
|
||||
async def receiver(self) -> None:
    """Receive and dispatch server messages until the session stops.

    Binary frames are treated as PCM audio and appended to
    ``self.received_audio``; text frames are parsed as JSON events and
    forwarded to ``_handle_event``. Returns when ``self.running`` goes
    False, the task is cancelled, or the connection closes.
    """
    try:
        while self.running:
            try:
                # Short timeout keeps the loop responsive to self.running.
                message = await asyncio.wait_for(self.ws.recv(), timeout=0.1)

                if isinstance(message, bytes):
                    # Audio data received
                    self.bytes_received += len(message)
                    self.received_audio.extend(message)

                    # Calculate TTFB on first audio of each response
                    # (response_start_time is armed by the trackStart event).
                    if self.waiting_for_first_audio and self.response_start_time is not None:
                        ttfb_ms = (time.time() - self.response_start_time) * 1000
                        self.ttfb_ms = ttfb_ms
                        self.ttfb_list.append(ttfb_ms)
                        self.waiting_for_first_audio = False
                        self.log_event("←", f"[TTFB] First audio latency: {ttfb_ms:.0f}ms")

                    # Log progress (16-bit mono PCM: 2 bytes per sample).
                    duration_ms = len(message) / (self.sample_rate * 2) * 1000
                    total_ms = len(self.received_audio) / (self.sample_rate * 2) * 1000
                    if self.verbose:
                        print(f"← Audio: +{duration_ms:.0f}ms (total: {total_ms:.0f}ms)", end="\r")

                else:
                    # JSON event
                    event = json.loads(message)
                    await self._handle_event(event)

            except asyncio.TimeoutError:
                # No frame within the poll window; re-check self.running.
                continue
            except websockets.ConnectionClosed:
                self.log_event("←", "Connection closed")
                self.running = False
                break

    except asyncio.CancelledError:
        # Normal shutdown path: run() cancels this task during cleanup.
        pass
    except Exception as e:
        self.log_event("!", f"Receiver error: {e}")
        self.running = False
|
||||
|
||||
async def _handle_event(self, event: dict) -> None:
    """Dispatch a single JSON event from the server.

    Updates session state flags (track_started/track_ended, TTFB timing)
    and logs a human-readable line for each known event type.
    """
    event_type = event.get("event", "unknown")

    if event_type == "answer":
        self.log_event("←", "Session ready!")
    elif event_type == "speaking":
        self.log_event("←", "Speech detected")
    elif event_type == "silence":
        self.log_event("←", "Silence detected")
    elif event_type == "transcript":
        # ASR transcript (interim = asrDelta-style, final = asrFinal-style)
        text = event.get("text", "")
        is_final = event.get("isFinal", False)
        if is_final:
            # Clear interim line and print final
            print(" " * 80, end="\r")
            self.log_event("←", f"→ You: {text}")
        else:
            # Interim result - show with indicator (overwrite same line, as in mic_client)
            display_text = text[:60] + "..." if len(text) > 60 else text
            print(f"  [listening] {display_text}".ljust(80), end="\r")
    elif event_type == "ttfb":
        latency_ms = event.get("latencyMs", 0)
        self.log_event("←", f"[TTFB] Server latency: {latency_ms}ms")
    elif event_type == "llmResponse":
        text = event.get("text", "")
        is_final = event.get("isFinal", False)
        if is_final:
            self.log_event("←", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}")
        elif self.verbose:
            # Show streaming chunks only in verbose mode
            self.log_event("←", f"LLM: {text}")
    elif event_type == "trackStart":
        # Bot audio is about to start: arm TTFB measurement for receiver().
        self.track_started = True
        self.response_start_time = time.time()
        self.waiting_for_first_audio = True
        self.log_event("←", "Bot started speaking")
    elif event_type == "trackEnd":
        self.track_ended = True
        self.log_event("←", "Bot finished speaking")
    elif event_type == "interrupt":
        self.log_event("←", "Bot interrupted!")
    elif event_type == "error":
        self.log_event("!", f"Error: {event.get('error')}")
    elif event_type == "hangup":
        # Server ended the session; stop the receive loop.
        self.log_event("←", f"Hangup: {event.get('reason')}")
        self.running = False
    else:
        # Unknown event types are logged but otherwise ignored.
        self.log_event("←", f"Event: {event_type}")
|
||||
|
||||
def save_output_wav(self) -> None:
    """Write all received audio to ``self.output_file`` as 16-bit mono WAV."""
    if not self.received_audio:
        self.log_event("!", "No audio received to save")
        return

    pcm = np.frombuffer(bytes(self.received_audio), dtype=np.int16)

    self.output_file.parent.mkdir(parents=True, exist_ok=True)

    # The wave module keeps the output maximally compatible.
    with wave.open(str(self.output_file), 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(self.sample_rate)
        wav_file.writeframes(pcm.tobytes())

    seconds = len(pcm) / self.sample_rate
    self.log_event("→", f"Saved output: {self.output_file}")
    self.log_event("→", f"  Duration: {seconds:.2f}s ({len(pcm)} samples)")
    self.log_event("→", f"  Size: {len(self.received_audio)/1024:.1f} KB")
|
||||
|
||||
async def run(self) -> None:
    """Run the full WAV-file test: connect, stream, collect, summarize.

    Exits the process via sys.exit(1) on a missing input file, a refused
    connection, or any unexpected error; the WebSocket is always closed
    in the ``finally`` block.
    """
    try:
        # Load input WAV file
        audio_data, _ = self.load_wav_file()

        # Connect to server
        await self.connect()

        # Wait for answer
        await asyncio.sleep(0.5)

        # Start receiver task (runs concurrently with the sender below)
        receiver_task = asyncio.create_task(self.receiver())

        # Send audio
        await self.audio_sender(audio_data)

        # Wait for response
        self.log_event("→", f"Waiting {self.wait_time}s for response...")

        wait_start = time.time()
        while self.running and (time.time() - wait_start) < self.wait_time:
            # Check if track has ended (response complete)
            if self.track_ended and self.send_completed:
                # Give a little extra time for any remaining audio
                await asyncio.sleep(1.0)
                break
            await asyncio.sleep(0.1)

        # Cleanup: stop the loop flag first, then cancel the receiver.
        self.running = False
        receiver_task.cancel()

        try:
            await receiver_task
        except asyncio.CancelledError:
            pass

        # Save output
        self.save_output_wav()

        # Print summary
        self._print_summary()

    except FileNotFoundError as e:
        print(f"Error: {e}")
        sys.exit(1)
    except ConnectionRefusedError:
        print(f"Error: Could not connect to {self.url}")
        print("Make sure the server is running.")
        sys.exit(1)
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    finally:
        await self.close()
|
||||
|
||||
def _print_summary(self):
|
||||
"""Print session summary."""
|
||||
print("\n" + "=" * 50)
|
||||
print("Session Summary")
|
||||
print("=" * 50)
|
||||
print(f" Input file: {self.input_file}")
|
||||
print(f" Output file: {self.output_file}")
|
||||
print(f" Bytes sent: {self.bytes_sent / 1024:.1f} KB")
|
||||
print(f" Bytes received: {self.bytes_received / 1024:.1f} KB")
|
||||
if self.ttfb_list:
|
||||
if len(self.ttfb_list) == 1:
|
||||
print(f" TTFB: {self.ttfb_list[0]:.0f} ms")
|
||||
else:
|
||||
print(f" TTFB (per response): {', '.join(f'{t:.0f}ms' for t in self.ttfb_list)}")
|
||||
if self.received_audio:
|
||||
duration = len(self.received_audio) / (self.sample_rate * 2)
|
||||
print(f" Response duration: {duration:.2f}s")
|
||||
print("=" * 50)
|
||||
|
||||
async def close(self) -> None:
    """Close the WebSocket connection, ignoring errors during shutdown."""
    self.running = False
    if self.ws:
        try:
            await self.ws.close()
        except Exception:
            # Best-effort close; the connection may already be gone.
            # (A bare `except:` here would also swallow KeyboardInterrupt
            # and SystemExit, which must propagate.)
            pass
|
||||
|
||||
|
||||
async def main():
    """Parse CLI arguments and run a single WAV-file test session."""
    parser = argparse.ArgumentParser(
        description="WAV file client for testing duplex voice conversation"
    )
    parser.add_argument("--input", "-i", required=True,
                        help="Input WAV file path")
    parser.add_argument("--output", "-o", required=True,
                        help="Output WAV file path for response")
    parser.add_argument("--url", default="ws://localhost:8000/ws",
                        help="WebSocket server URL (default: ws://localhost:8000/ws)")
    parser.add_argument("--sample-rate", type=int, default=16000,
                        help="Target sample rate for audio (default: 16000)")
    parser.add_argument("--chunk-duration", type=int, default=20,
                        help="Chunk duration in ms for sending (default: 20)")
    parser.add_argument("--wait-time", "-w", type=float, default=15.0,
                        help="Time to wait for response after sending (default: 15.0)")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Enable verbose output")

    args = parser.parse_args()

    client = WavFileClient(
        url=args.url,
        input_file=args.input,
        output_file=args.output,
        sample_rate=args.sample_rate,
        chunk_duration_ms=args.chunk_duration,
        wait_time=args.wait_time,
        verbose=args.verbose,
    )

    await client.run()
|
||||
|
||||
|
||||
# Script entry point: run the async client; exit quietly on Ctrl-C.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nInterrupted by user")
|
||||
766
examples/web_client.html
Normal file
766
examples/web_client.html
Normal file
@@ -0,0 +1,766 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Duplex Voice Web Client</title>
|
||||
<style>
|
||||
@import url("https://fonts.googleapis.com/css2?family=Fraunces:opsz,wght@9..144,300;9..144,500;9..144,700&family=Recursive:wght@300;400;600;700&display=swap");
|
||||
|
||||
:root {
|
||||
--bg: #0b0b0f;
|
||||
--panel: #14141c;
|
||||
--panel-2: #101018;
|
||||
--ink: #f2f3f7;
|
||||
--muted: #a7acba;
|
||||
--accent: #ff6b6b;
|
||||
--accent-2: #ffd166;
|
||||
--good: #2dd4bf;
|
||||
--bad: #f87171;
|
||||
--grid: rgba(255, 255, 255, 0.06);
|
||||
--shadow: 0 20px 60px rgba(0, 0, 0, 0.45);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
html,
|
||||
body {
|
||||
height: 100%;
|
||||
margin: 0;
|
||||
color: var(--ink);
|
||||
background: radial-gradient(1200px 600px at 20% -10%, #1d1d2a 0%, transparent 60%),
|
||||
radial-gradient(800px 800px at 110% 10%, #20203a 0%, transparent 50%),
|
||||
var(--bg);
|
||||
font-family: "Recursive", ui-sans-serif, system-ui, -apple-system, "Segoe UI", sans-serif;
|
||||
}
|
||||
|
||||
.noise {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='120' height='120' viewBox='0 0 120 120'><filter id='n'><feTurbulence type='fractalNoise' baseFrequency='0.9' numOctaves='2' stitchTiles='stitch'/></filter><rect width='120' height='120' filter='url(%23n)' opacity='0.06'/></svg>");
|
||||
pointer-events: none;
|
||||
mix-blend-mode: soft-light;
|
||||
}
|
||||
|
||||
header {
|
||||
padding: 32px 28px 18px;
|
||||
border-bottom: 1px solid var(--grid);
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-family: "Fraunces", serif;
|
||||
font-weight: 600;
|
||||
margin: 0 0 6px;
|
||||
letter-spacing: 0.4px;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
color: var(--muted);
|
||||
font-size: 0.95rem;
|
||||
}
|
||||
|
||||
main {
|
||||
display: grid;
|
||||
grid-template-columns: 1.1fr 1.4fr;
|
||||
gap: 24px;
|
||||
padding: 24px 28px 40px;
|
||||
}
|
||||
|
||||
.panel {
|
||||
background: linear-gradient(180deg, rgba(255, 255, 255, 0.02), transparent),
|
||||
var(--panel);
|
||||
border: 1px solid var(--grid);
|
||||
border-radius: 16px;
|
||||
padding: 20px;
|
||||
box-shadow: var(--shadow);
|
||||
}
|
||||
|
||||
.panel h2 {
|
||||
margin: 0 0 12px;
|
||||
font-size: 1.05rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.stack {
|
||||
display: grid;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
label {
|
||||
display: block;
|
||||
font-size: 0.85rem;
|
||||
color: var(--muted);
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
|
||||
input,
|
||||
select,
|
||||
button,
|
||||
textarea {
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
input,
|
||||
select,
|
||||
textarea {
|
||||
width: 100%;
|
||||
padding: 10px 12px;
|
||||
border-radius: 10px;
|
||||
border: 1px solid var(--grid);
|
||||
background: var(--panel-2);
|
||||
color: var(--ink);
|
||||
outline: none;
|
||||
}
|
||||
|
||||
textarea {
|
||||
min-height: 80px;
|
||||
resize: vertical;
|
||||
}
|
||||
|
||||
.row {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.btn-row {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
button {
|
||||
border: none;
|
||||
border-radius: 999px;
|
||||
padding: 10px 16px;
|
||||
font-weight: 600;
|
||||
background: var(--ink);
|
||||
color: #111;
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
|
||||
button.secondary {
|
||||
background: transparent;
|
||||
color: var(--ink);
|
||||
border: 1px solid var(--grid);
|
||||
}
|
||||
|
||||
button.accent {
|
||||
background: linear-gradient(120deg, var(--accent), #f97316);
|
||||
color: #0b0b0f;
|
||||
}
|
||||
|
||||
button.good {
|
||||
background: linear-gradient(120deg, var(--good), #22c55e);
|
||||
color: #07261f;
|
||||
}
|
||||
|
||||
button.bad {
|
||||
background: linear-gradient(120deg, var(--bad), #f97316);
|
||||
color: #2a0b0b;
|
||||
}
|
||||
|
||||
button:active {
|
||||
transform: translateY(1px) scale(0.99);
|
||||
}
|
||||
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
padding: 12px;
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
border-radius: 12px;
|
||||
border: 1px dashed var(--grid);
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.dot {
|
||||
width: 10px;
|
||||
height: 10px;
|
||||
border-radius: 999px;
|
||||
background: var(--bad);
|
||||
box-shadow: 0 0 12px rgba(248, 113, 113, 0.5);
|
||||
}
|
||||
|
||||
.dot.on {
|
||||
background: var(--good);
|
||||
box-shadow: 0 0 12px rgba(45, 212, 191, 0.7);
|
||||
}
|
||||
|
||||
.log {
|
||||
height: 320px;
|
||||
overflow: auto;
|
||||
padding: 12px;
|
||||
background: #0d0d14;
|
||||
border-radius: 12px;
|
||||
border: 1px solid var(--grid);
|
||||
font-size: 0.85rem;
|
||||
line-height: 1.4;
|
||||
}
|
||||
|
||||
.chat {
|
||||
height: 260px;
|
||||
overflow: auto;
|
||||
padding: 12px;
|
||||
background: #0d0d14;
|
||||
border-radius: 12px;
|
||||
border: 1px solid var(--grid);
|
||||
font-size: 0.9rem;
|
||||
line-height: 1.45;
|
||||
}
|
||||
|
||||
.chat-entry {
|
||||
padding: 8px 10px;
|
||||
margin-bottom: 8px;
|
||||
border-radius: 10px;
|
||||
background: rgba(255, 255, 255, 0.04);
|
||||
border: 1px solid rgba(255, 255, 255, 0.06);
|
||||
}
|
||||
|
||||
.chat-entry.user {
|
||||
border-left: 3px solid var(--accent-2);
|
||||
}
|
||||
|
||||
.chat-entry.ai {
|
||||
border-left: 3px solid var(--good);
|
||||
}
|
||||
|
||||
.chat-entry.interim {
|
||||
opacity: 0.7;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.log-entry {
|
||||
padding: 6px 8px;
|
||||
border-bottom: 1px dashed rgba(255, 255, 255, 0.06);
|
||||
}
|
||||
|
||||
.log-entry:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.tag {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 2px 8px;
|
||||
border-radius: 999px;
|
||||
font-size: 0.7rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.6px;
|
||||
background: rgba(255, 255, 255, 0.08);
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.tag.event {
|
||||
background: rgba(255, 107, 107, 0.18);
|
||||
color: #ffc1c1;
|
||||
}
|
||||
|
||||
.tag.audio {
|
||||
background: rgba(45, 212, 191, 0.2);
|
||||
color: #c5f9f0;
|
||||
}
|
||||
|
||||
.tag.sys {
|
||||
background: rgba(255, 209, 102, 0.2);
|
||||
color: #ffefb0;
|
||||
}
|
||||
|
||||
.muted {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
footer {
|
||||
padding: 0 28px 28px;
|
||||
color: var(--muted);
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
|
||||
@media (max-width: 1100px) {
|
||||
main {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
.log {
|
||||
height: 360px;
|
||||
}
|
||||
.chat {
|
||||
height: 260px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="noise"></div>
|
||||
<header>
|
||||
<h1>Duplex Voice Client</h1>
|
||||
<div class="subtitle">Browser client for the WebSocket duplex pipeline. Device selection + event logging.</div>
|
||||
</header>
|
||||
|
||||
<main>
|
||||
<section class="panel stack">
|
||||
<h2>Connection</h2>
|
||||
<div>
|
||||
<label for="wsUrl">WebSocket URL</label>
|
||||
<input id="wsUrl" value="ws://localhost:8000/ws" />
|
||||
</div>
|
||||
<div class="btn-row">
|
||||
<button class="accent" id="connectBtn">Connect</button>
|
||||
<button class="secondary" id="disconnectBtn">Disconnect</button>
|
||||
</div>
|
||||
<div class="status">
|
||||
<div id="statusDot" class="dot"></div>
|
||||
<div>
|
||||
<div id="statusText">Disconnected</div>
|
||||
<div class="muted" id="statusSub">Waiting for connection</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h2>Devices</h2>
|
||||
<div class="row">
|
||||
<div>
|
||||
<label for="inputSelect">Input (Mic)</label>
|
||||
<select id="inputSelect"></select>
|
||||
</div>
|
||||
<div>
|
||||
<label for="outputSelect">Output (Speaker)</label>
|
||||
<select id="outputSelect"></select>
|
||||
</div>
|
||||
</div>
|
||||
<div class="btn-row">
|
||||
<button class="secondary" id="refreshDevicesBtn">Refresh Devices</button>
|
||||
<button class="good" id="startMicBtn">Start Mic</button>
|
||||
<button class="secondary" id="stopMicBtn">Stop Mic</button>
|
||||
</div>
|
||||
|
||||
<h2>Chat</h2>
|
||||
<div class="stack">
|
||||
<textarea id="chatInput" placeholder="Type a message, press Send"></textarea>
|
||||
<div class="btn-row">
|
||||
<button class="accent" id="sendChatBtn">Send Chat</button>
|
||||
<button class="secondary" id="clearLogBtn">Clear Log</button>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section class="stack">
|
||||
<div class="panel stack">
|
||||
<h2>Chat History</h2>
|
||||
<div class="chat" id="chatHistory"></div>
|
||||
</div>
|
||||
<div class="panel stack">
|
||||
<h2>Event Log</h2>
|
||||
<div class="log" id="log"></div>
|
||||
</div>
|
||||
</section>
|
||||
</main>
|
||||
|
||||
<footer>
|
||||
Output device selection requires HTTPS + a browser that supports <code>setSinkId</code>.
|
||||
Audio is sent as 16-bit PCM @ 16 kHz, matching <code>examples/mic_client.py</code>.
|
||||
</footer>
|
||||
|
||||
<audio id="audioOut" autoplay></audio>
|
||||
|
||||
<script>
|
||||
// --- DOM element handles -------------------------------------------------
const wsUrl = document.getElementById("wsUrl");
const connectBtn = document.getElementById("connectBtn");
const disconnectBtn = document.getElementById("disconnectBtn");
const inputSelect = document.getElementById("inputSelect");
const outputSelect = document.getElementById("outputSelect");
const startMicBtn = document.getElementById("startMicBtn");
const stopMicBtn = document.getElementById("stopMicBtn");
const refreshDevicesBtn = document.getElementById("refreshDevicesBtn");
const sendChatBtn = document.getElementById("sendChatBtn");
const clearLogBtn = document.getElementById("clearLogBtn");
const chatInput = document.getElementById("chatInput");
const logEl = document.getElementById("log");
const chatHistory = document.getElementById("chatHistory");
const statusDot = document.getElementById("statusDot");
const statusText = document.getElementById("statusText");
const statusSub = document.getElementById("statusSub");
const audioOut = document.getElementById("audioOut");

// --- Mutable session state ----------------------------------------------
let ws = null;              // active WebSocket, or null
let audioCtx = null;        // lazily created AudioContext
let micStream = null;       // MediaStream from getUserMedia
let processor = null;       // ScriptProcessorNode capturing mic PCM
let micSource = null;       // MediaStreamSource feeding the processor
let playbackDest = null;    // MediaStreamDestination routed to <audio>
let playbackTime = 0;       // absolute schedule time for the next chunk
let discardAudio = false;   // drop incoming audio after an interrupt
let playbackSources = [];   // in-flight { source, gainNode } pairs
let interimUserEl = null;   // interim chat DOM nodes / accumulated text
let interimAiEl = null;
let interimUserText = "";
let interimAiText = "";

// Wire PCM rate (both directions) and the fade-out used when stopping playback.
const targetSampleRate = 16000;
const playbackStopRampSec = 0.008;
|
||||
|
||||
// Append a tagged, timestamped entry to the event log panel.
// `data`, when given, is JSON-stringified onto a muted second line.
function logLine(type, text, data) {
  const time = new Date().toLocaleTimeString();
  const entry = document.createElement("div");
  entry.className = "log-entry";
  const tag = document.createElement("span");
  tag.className = `tag ${type}`;
  tag.textContent = type.toUpperCase();
  const msg = document.createElement("span");
  msg.style.marginLeft = "10px";
  msg.textContent = `[${time}] ${text}`;
  entry.appendChild(tag);
  entry.appendChild(msg);
  if (data) {
    const pre = document.createElement("div");
    pre.className = "muted";
    pre.textContent = JSON.stringify(data);
    pre.style.marginTop = "4px";
    entry.appendChild(pre);
  }
  logEl.appendChild(entry);
  logEl.scrollTop = logEl.scrollHeight; // keep newest entry visible
}
|
||||
|
||||
// Append a finalized chat line and keep the pane scrolled to the bottom.
function addChat(role, text) {
  const roleClass = role === "AI" ? "ai" : "user";
  const entry = document.createElement("div");
  entry.className = `chat-entry ${roleClass}`;
  entry.textContent = `${role}: ${text}`;
  chatHistory.appendChild(entry);
  chatHistory.scrollTop = chatHistory.scrollHeight;
}
|
||||
|
||||
// Show, update, or clear the single "interim" (in-progress) chat line for
// a role. Passing empty text removes the line and resets the cached text.
function setInterim(role, text) {
  const isAi = role === "AI";
  let el = isAi ? interimAiEl : interimUserEl;
  if (!text) {
    if (el) el.remove();
    if (isAi) interimAiEl = null;
    else interimUserEl = null;
    if (isAi) interimAiText = "";
    else interimUserText = "";
    return;
  }
  if (!el) {
    // First interim chunk for this role: create the placeholder entry.
    el = document.createElement("div");
    el.className = `chat-entry ${isAi ? "ai" : "user"} interim`;
    chatHistory.appendChild(el);
    if (isAi) interimAiEl = el;
    else interimUserEl = el;
  }
  el.textContent = `${role} (interim): ${text}`;
  chatHistory.scrollTop = chatHistory.scrollHeight;
}
|
||||
|
||||
// Halt all scheduled bot audio: fade each in-flight source out over a few
// milliseconds (avoids clicks), and mark further incoming audio to be
// discarded until the next output.audio.start event re-enables playback.
function stopPlayback() {
  discardAudio = true;
  const now = audioCtx ? audioCtx.currentTime : 0;
  playbackTime = now; // next schedulePlayback() starts from "now"
  playbackSources.forEach((node) => {
    try {
      if (audioCtx && node.gainNode && node.source) {
        node.gainNode.gain.cancelScheduledValues(now);
        node.gainNode.gain.setValueAtTime(node.gainNode.gain.value || 1, now);
        node.gainNode.gain.linearRampToValueAtTime(0, now + playbackStopRampSec);
        node.source.stop(now + playbackStopRampSec + 0.002);
      } else if (node.source) {
        node.source.stop();
      }
    } catch (err) {} // source may already have ended; ignore
  });
  playbackSources = [];
}
|
||||
|
||||
// Reflect connection state in the status dot and its two text lines.
function setStatus(connected, detail) {
  statusDot.classList.toggle("on", connected);
  if (connected) {
    statusText.textContent = "Connected";
  } else {
    statusText.textContent = "Disconnected";
  }
  statusSub.textContent = detail || "";
}
|
||||
|
||||
// Lazily create the AudioContext and route playback through a hidden
// <audio> element (so an output device can later be chosen via setSinkId).
async function ensureAudioContext() {
  if (audioCtx) return;
  audioCtx = new (window.AudioContext || window.webkitAudioContext)();
  playbackDest = audioCtx.createMediaStreamDestination();
  audioOut.srcObject = playbackDest.stream;
  try {
    await audioOut.play();
  } catch (err) {
    // Autoplay policy: playback needs a user gesture first.
    logLine("sys", "Audio playback blocked (user gesture needed)", { err: String(err) });
  }
  if (outputSelect.value) {
    // setOutputDevice is defined elsewhere in this script.
    await setOutputDevice(outputSelect.value);
  }
}
|
||||
|
||||
// Average-pool Float32 samples from inRate down to outRate.
// Returns the input unchanged when no rate conversion is needed.
function downsampleBuffer(buffer, inRate, outRate) {
  if (outRate === inRate) return buffer;
  const ratio = inRate / outRate;
  const newLength = Math.round(buffer.length / ratio);
  const result = new Float32Array(newLength);
  let offsetResult = 0;
  let offsetBuffer = 0;
  while (offsetResult < result.length) {
    const nextOffsetBuffer = Math.round((offsetResult + 1) * ratio);
    let accum = 0;
    let count = 0;
    for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
      accum += buffer[i];
      count++;
    }
    if (count > 0) {
      result[offsetResult] = accum / count;
    } else {
      // Guard: with ratio < 1 (upsampling) or rounding collisions a pooling
      // window can be empty; fall back to the nearest sample instead of
      // emitting NaN (0/0) into the output.
      const idx = Math.min(offsetBuffer, buffer.length - 1);
      result[offsetResult] = idx >= 0 ? buffer[idx] : 0;
    }
    offsetResult++;
    offsetBuffer = nextOffsetBuffer;
  }
  return result;
}
|
||||
|
||||
// Clamp each float sample to [-1, 1] and scale to signed 16-bit PCM.
function floatTo16BitPCM(float32) {
  const out = new Int16Array(float32.length);
  for (let i = 0; i < float32.length; i++) {
    let sample = float32[i];
    if (sample > 1) sample = 1;
    if (sample < -1) sample = -1;
    out[i] = sample < 0 ? sample * 0x8000 : sample * 0x7fff;
  }
  return out;
}
|
||||
|
||||
// Queue one chunk of int16 PCM for gapless playback: convert to float,
// wrap it in an AudioBuffer, and start it at the end of the current
// schedule (playbackTime). Chunks arriving after an interrupt are dropped.
function schedulePlayback(int16Data) {
  if (!audioCtx || !playbackDest) return;
  if (discardAudio) return; // an interrupt is in effect; drop this chunk
  const float32 = new Float32Array(int16Data.length);
  for (let i = 0; i < int16Data.length; i++) {
    float32[i] = int16Data[i] / 32768;
  }
  const buffer = audioCtx.createBuffer(1, float32.length, targetSampleRate);
  buffer.copyToChannel(float32, 0);
  const source = audioCtx.createBufferSource();
  const gainNode = audioCtx.createGain();
  source.buffer = buffer;
  source.connect(gainNode);
  gainNode.connect(playbackDest);
  // 20ms floor keeps the very first chunk from being scheduled in the past.
  const startTime = Math.max(audioCtx.currentTime + 0.02, playbackTime);
  gainNode.gain.setValueAtTime(1, startTime);
  source.start(startTime);
  playbackTime = startTime + buffer.duration;
  const playbackNode = { source, gainNode };
  playbackSources.push(playbackNode);
  source.onended = () => {
    // Drop finished nodes so stopPlayback() only touches live ones.
    playbackSources = playbackSources.filter((s) => s !== playbackNode);
  };
}
|
||||
|
||||
// Open the WebSocket, wire up its handlers, and start the hello handshake.
// No-op when a connection is already open.
async function connect() {
  if (ws && ws.readyState === WebSocket.OPEN) return;
  ws = new WebSocket(wsUrl.value.trim());
  ws.binaryType = "arraybuffer"; // binary frames are raw PCM

  ws.onopen = () => {
    setStatus(true, "Session open");
    logLine("sys", "WebSocket connected");
    ensureAudioContext();
    sendCommand({ type: "hello", version: "v1" });
  };

  ws.onclose = () => {
    setStatus(false, "Connection closed");
    logLine("sys", "WebSocket closed");
    ws = null;
  };

  ws.onerror = (err) => {
    logLine("sys", "WebSocket error", { err: String(err) });
  };

  ws.onmessage = (msg) => {
    if (typeof msg.data === "string") {
      // Text frame: JSON control/transcript event.
      const event = JSON.parse(msg.data);
      handleEvent(event);
    } else {
      // Binary frame: 16-bit PCM audio chunk for playback.
      const audioBuf = msg.data;
      const int16 = new Int16Array(audioBuf);
      schedulePlayback(int16);
      logLine("audio", `Audio ${Math.round((int16.length / targetSampleRate) * 1000)}ms`);
    }
  };
}
|
||||
|
||||
// Politely stop the session before closing, if a socket is open.
function disconnect() {
  const isOpen = ws && ws.readyState === WebSocket.OPEN;
  if (isOpen) {
    sendCommand({ type: "session.stop", reason: "client_disconnect" });
    ws.close();
  }
  ws = null;
  setStatus(false, "Disconnected");
}
|
||||
|
||||
// Send a JSON command over the socket; logs and drops it when not connected.
function sendCommand(cmd) {
  const ready = ws && ws.readyState === WebSocket.OPEN;
  if (!ready) {
    logLine("sys", "Not connected");
    return;
  }
  ws.send(JSON.stringify(cmd));
  logLine("sys", `→ ${cmd.type}`, cmd);
}
|
||||
|
||||
// Dispatch one server event: log it, then update chat UI and playback
// state according to its type. Unknown types are only logged.
function handleEvent(event) {
  const type = event.type || "unknown";
  logLine("event", type, event);
  if (type === "hello.ack") {
    // Handshake done: declare our audio format and start the session.
    sendCommand({
      type: "session.start",
      audio: { encoding: "pcm_s16le", sample_rate_hz: targetSampleRate, channels: 1 },
    });
  }
  if (type === "transcript.final") {
    if (event.text) {
      setInterim("You", "");
      addChat("You", event.text);
    }
  }
  if (type === "transcript.delta" && event.text) {
    // User transcript deltas arrive as full interim text, not increments.
    setInterim("You", event.text);
  }
  if (type === "assistant.response.final") {
    if (event.text) {
      setInterim("AI", "");
      addChat("AI", event.text);
    }
  }
  if (type === "assistant.response.delta" && event.text) {
    // Assistant deltas are incremental; accumulate before displaying.
    interimAiText += event.text;
    setInterim("AI", interimAiText);
  }
  if (type === "output.audio.start") {
    // New bot audio: stop any previous playback to avoid overlap
    stopPlayback();
    discardAudio = false;
    interimAiText = "";
  }
  if (type === "input.speech_started") {
    // User started speaking: clear any in-flight audio to avoid overlap
    stopPlayback();
  }
  if (type === "response.interrupted") {
    stopPlayback();
  }
}
|
||||
|
||||
// Start capturing microphone audio and stream it to the server as raw
// 16-bit little-endian PCM frames. Requires an open WebSocket connection.
async function startMic() {
  if (!ws || ws.readyState !== WebSocket.OPEN) {
    logLine("sys", "Connect before starting mic");
    return;
  }
  await ensureAudioContext();
  const deviceId = inputSelect.value || undefined;
  // Pin capture to the selected input device when one is chosen;
  // otherwise let the browser pick the default microphone.
  micStream = await navigator.mediaDevices.getUserMedia({
    audio: deviceId ? { deviceId: { exact: deviceId } } : true,
  });
  micSource = audioCtx.createMediaStreamSource(micStream);
  // NOTE(review): ScriptProcessorNode is deprecated in favor of
  // AudioWorklet, but is kept here for broad browser support.
  processor = audioCtx.createScriptProcessor(2048, 1, 1);
  processor.onaudioprocess = (e) => {
    // Drop frames captured after the socket closed.
    if (!ws || ws.readyState !== WebSocket.OPEN) return;
    const input = e.inputBuffer.getChannelData(0);
    // Convert from the context's native rate to the server's target rate,
    // then from float32 [-1, 1] samples to int16 PCM.
    const downsampled = downsampleBuffer(input, audioCtx.sampleRate, targetSampleRate);
    const pcm16 = floatTo16BitPCM(downsampled);
    ws.send(pcm16.buffer);
  };
  micSource.connect(processor);
  // A ScriptProcessor only runs while connected to a destination.
  processor.connect(audioCtx.destination);
  logLine("sys", "Microphone started");
}
|
||||
|
||||
// Tear down the microphone capture chain and release the device.
function stopMic() {
  const node = processor;
  processor = null;
  if (node) {
    node.disconnect();
  }

  const source = micSource;
  micSource = null;
  if (source) {
    source.disconnect();
  }

  const stream = micStream;
  micStream = null;
  if (stream) {
    // Stopping every track releases the hardware (mic indicator goes off).
    for (const track of stream.getTracks()) {
      track.stop();
    }
  }

  logLine("sys", "Microphone stopped");
}
|
||||
|
||||
// Repopulate the input/output device selectors from the browser's list.
async function refreshDevices() {
  const devices = await navigator.mediaDevices.enumerateDevices();
  inputSelect.innerHTML = "";
  outputSelect.innerHTML = "";

  // Append one <option> for a device; labels are empty until mic
  // permission has been granted, so fall back to a numbered name.
  const addOption = (select, device, fallbackPrefix) => {
    const opt = document.createElement("option");
    opt.value = device.deviceId;
    opt.textContent = device.label || `${fallbackPrefix} ${select.length + 1}`;
    select.appendChild(opt);
  };

  for (const d of devices) {
    if (d.kind === "audioinput") {
      addOption(inputSelect, d, "Mic");
    } else if (d.kind === "audiooutput") {
      addOption(outputSelect, d, "Output");
    }
  }
}
|
||||
|
||||
// Ask for mic permission once so enumerateDevices() reveals labels.
async function requestDeviceAccess() {
  try {
    const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
    // Permission is all we needed; release the capture immediately.
    for (const track of stream.getTracks()) {
      track.stop();
    }
    logLine("sys", "Microphone permission granted");
  } catch (err) {
    logLine("sys", "Microphone permission denied", { err: String(err) });
  }
}
|
||||
|
||||
// Route audio playback to the given output device, where supported.
async function setOutputDevice(deviceId) {
  if (!audioOut.setSinkId) {
    // setSinkId is missing on some browsers (notably older Safari).
    logLine("sys", "setSinkId not supported in this browser");
    return;
  }
  await audioOut.setSinkId(deviceId);
  logLine("sys", `Output device set`, { deviceId });
}
|
||||
|
||||
// --- UI wiring -------------------------------------------------------------

connectBtn.addEventListener("click", connect);
disconnectBtn.addEventListener("click", disconnect);
refreshDevicesBtn.addEventListener("click", async () => {
  // Request permission first so the device list includes labels.
  await requestDeviceAccess();
  await refreshDevices();
});
startMicBtn.addEventListener("click", startMic);
stopMicBtn.addEventListener("click", stopMic);
sendChatBtn.addEventListener("click", () => {
  const text = chatInput.value.trim();
  if (!text) return;
  // Fire-and-forget: warm up the audio context so the reply can play.
  ensureAudioContext();
  addChat("You", text);
  sendCommand({ type: "input.text", text });
  chatInput.value = "";
});
clearLogBtn.addEventListener("click", () => {
  // Clear both the protocol log and the visible conversation state.
  logEl.innerHTML = "";
  chatHistory.innerHTML = "";
  setInterim("You", "");
  setInterim("AI", "");
  interimUserText = "";
  interimAiText = "";
});
inputSelect.addEventListener("change", () => {
  // Restart capture on the newly selected input device.
  if (micStream) {
    stopMic();
    startMic();
  }
});
outputSelect.addEventListener("change", () => setOutputDevice(outputSelect.value));

navigator.mediaDevices.addEventListener("devicechange", refreshDevices);
// Best-effort initial population; fails silently before permission is granted.
refreshDevices().catch(() => {});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
1
models/__init__.py
Normal file
1
models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Data Models Package"""
|
||||
143
models/commands.py
Normal file
143
models/commands.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Protocol command models matching the original active-call API."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InviteCommand(BaseModel):
|
||||
"""Invite command to initiate a call."""
|
||||
|
||||
command: str = Field(default="invite", description="Command type")
|
||||
option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options")
|
||||
|
||||
|
||||
class AcceptCommand(BaseModel):
|
||||
"""Accept command to accept an incoming call."""
|
||||
|
||||
command: str = Field(default="accept", description="Command type")
|
||||
option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options")
|
||||
|
||||
|
||||
class RejectCommand(BaseModel):
|
||||
"""Reject command to reject an incoming call."""
|
||||
|
||||
command: str = Field(default="reject", description="Command type")
|
||||
reason: str = Field(default="", description="Reason for rejection")
|
||||
code: Optional[int] = Field(default=None, description="SIP response code")
|
||||
|
||||
|
||||
class RingingCommand(BaseModel):
|
||||
"""Ringing command to send ringing response."""
|
||||
|
||||
command: str = Field(default="ringing", description="Command type")
|
||||
recorder: Optional[Dict[str, Any]] = Field(default=None, description="Call recording configuration")
|
||||
early_media: bool = Field(default=False, description="Enable early media")
|
||||
ringtone: Optional[str] = Field(default=None, description="Custom ringtone URL")
|
||||
|
||||
|
||||
class TTSCommand(BaseModel):
    """TTS command to convert text to speech.

    Supports both one-shot synthesis and streaming input: with
    ``streaming=True`` the text may arrive in fragments, terminated by a
    message carrying ``end_of_stream=True``.
    """

    command: str = Field(default="tts", description="Command type")
    text: str = Field(..., description="Text to synthesize")
    speaker: Optional[str] = Field(default=None, description="Speaker voice name")
    play_id: Optional[str] = Field(default=None, description="Unique identifier for this TTS session")
    auto_hangup: bool = Field(default=False, description="Auto hangup after TTS completion")
    streaming: bool = Field(default=False, description="Streaming text input")
    end_of_stream: bool = Field(default=False, description="End of streaming input")
    wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)")
    option: Optional[Dict[str, Any]] = Field(default=None, description="TTS provider specific options")
|
||||
|
||||
|
||||
class PlayCommand(BaseModel):
|
||||
"""Play command to play audio from URL."""
|
||||
|
||||
command: str = Field(default="play", description="Command type")
|
||||
url: str = Field(..., description="URL of audio file to play")
|
||||
auto_hangup: bool = Field(default=False, description="Auto hangup after playback")
|
||||
wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)")
|
||||
|
||||
|
||||
class InterruptCommand(BaseModel):
|
||||
"""Interrupt command to interrupt current playback."""
|
||||
|
||||
command: str = Field(default="interrupt", description="Command type")
|
||||
graceful: bool = Field(default=False, description="Wait for current TTS to complete")
|
||||
|
||||
|
||||
class PauseCommand(BaseModel):
|
||||
"""Pause command to pause current playback."""
|
||||
|
||||
command: str = Field(default="pause", description="Command type")
|
||||
|
||||
|
||||
class ResumeCommand(BaseModel):
|
||||
"""Resume command to resume paused playback."""
|
||||
|
||||
command: str = Field(default="resume", description="Command type")
|
||||
|
||||
|
||||
class HangupCommand(BaseModel):
|
||||
"""Hangup command to end the call."""
|
||||
|
||||
command: str = Field(default="hangup", description="Command type")
|
||||
reason: Optional[str] = Field(default=None, description="Reason for hangup")
|
||||
initiator: Optional[str] = Field(default=None, description="Who initiated the hangup")
|
||||
|
||||
|
||||
class HistoryCommand(BaseModel):
|
||||
"""History command to add conversation history."""
|
||||
|
||||
command: str = Field(default="history", description="Command type")
|
||||
speaker: str = Field(..., description="Speaker identifier")
|
||||
text: str = Field(..., description="Conversation text")
|
||||
|
||||
|
||||
class ChatCommand(BaseModel):
|
||||
"""Chat command for text-based conversation."""
|
||||
|
||||
command: str = Field(default="chat", description="Command type")
|
||||
text: str = Field(..., description="Chat text message")
|
||||
|
||||
|
||||
# Command type mapping
|
||||
COMMAND_TYPES = {
|
||||
"invite": InviteCommand,
|
||||
"accept": AcceptCommand,
|
||||
"reject": RejectCommand,
|
||||
"ringing": RingingCommand,
|
||||
"tts": TTSCommand,
|
||||
"play": PlayCommand,
|
||||
"interrupt": InterruptCommand,
|
||||
"pause": PauseCommand,
|
||||
"resume": ResumeCommand,
|
||||
"hangup": HangupCommand,
|
||||
"history": HistoryCommand,
|
||||
"chat": ChatCommand,
|
||||
}
|
||||
|
||||
|
||||
def parse_command(data: Dict[str, Any]) -> BaseModel:
    """
    Parse a command from JSON data.

    Args:
        data: JSON data as dictionary

    Returns:
        Parsed command model

    Raises:
        ValueError: If the 'command' field is missing or unknown
    """
    command_type = data.get("command")
    if not command_type:
        raise ValueError("Missing 'command' field")

    try:
        command_class = COMMAND_TYPES[command_type]
    except KeyError:
        raise ValueError(f"Unknown command type: {command_type}") from None

    return command_class(**data)
|
||||
126
models/config.py
Normal file
126
models/config.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Configuration models for call options."""
|
||||
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VADOption(BaseModel):
|
||||
"""Voice Activity Detection configuration."""
|
||||
|
||||
type: str = Field(default="silero", description="VAD algorithm type (silero, webrtc)")
|
||||
samplerate: int = Field(default=16000, description="Audio sample rate for VAD")
|
||||
speech_padding: int = Field(default=250, description="Speech padding in milliseconds")
|
||||
silence_padding: int = Field(default=100, description="Silence padding in milliseconds")
|
||||
ratio: float = Field(default=0.5, description="Voice detection ratio threshold")
|
||||
voice_threshold: float = Field(default=0.5, description="Voice energy threshold")
|
||||
max_buffer_duration_secs: int = Field(default=50, description="Maximum buffer duration in seconds")
|
||||
silence_timeout: Optional[int] = Field(default=None, description="Silence timeout in milliseconds")
|
||||
endpoint: Optional[str] = Field(default=None, description="Custom VAD service endpoint")
|
||||
secret_key: Optional[str] = Field(default=None, description="VAD service secret key")
|
||||
secret_id: Optional[str] = Field(default=None, description="VAD service secret ID")
|
||||
|
||||
|
||||
class ASROption(BaseModel):
|
||||
"""Automatic Speech Recognition configuration."""
|
||||
|
||||
provider: str = Field(..., description="ASR provider (tencent, aliyun, openai, etc.)")
|
||||
language: Optional[str] = Field(default=None, description="Language code (zh-CN, en-US)")
|
||||
app_id: Optional[str] = Field(default=None, description="Application ID")
|
||||
secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
|
||||
secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
|
||||
model_type: Optional[str] = Field(default=None, description="ASR model type (16k_zh, 8k_en)")
|
||||
buffer_size: Optional[int] = Field(default=None, description="Audio buffer size in bytes")
|
||||
samplerate: Optional[int] = Field(default=None, description="Audio sample rate")
|
||||
endpoint: Optional[str] = Field(default=None, description="Custom ASR service endpoint")
|
||||
extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
|
||||
start_when_answer: bool = Field(default=False, description="Start ASR when call is answered")
|
||||
|
||||
|
||||
class TTSOption(BaseModel):
|
||||
"""Text-to-Speech configuration."""
|
||||
|
||||
samplerate: Optional[int] = Field(default=None, description="TTS output sample rate")
|
||||
provider: str = Field(default="msedge", description="TTS provider (tencent, aliyun, deepgram, msedge)")
|
||||
speed: float = Field(default=1.0, description="Speech speed multiplier")
|
||||
app_id: Optional[str] = Field(default=None, description="Application ID")
|
||||
secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
|
||||
secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
|
||||
volume: Optional[int] = Field(default=None, description="Speech volume level (1-10)")
|
||||
speaker: Optional[str] = Field(default=None, description="Voice speaker name")
|
||||
codec: Optional[str] = Field(default=None, description="Audio codec")
|
||||
subtitle: bool = Field(default=False, description="Enable subtitle generation")
|
||||
emotion: Optional[str] = Field(default=None, description="Speech emotion")
|
||||
endpoint: Optional[str] = Field(default=None, description="Custom TTS service endpoint")
|
||||
extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
|
||||
max_concurrent_tasks: Optional[int] = Field(default=None, description="Max concurrent tasks")
|
||||
|
||||
|
||||
class RecorderOption(BaseModel):
|
||||
"""Call recording configuration."""
|
||||
|
||||
recorder_file: str = Field(..., description="Path to recording file")
|
||||
samplerate: int = Field(default=16000, description="Recording sample rate")
|
||||
ptime: int = Field(default=200, description="Packet time in milliseconds")
|
||||
|
||||
|
||||
class MediaPassOption(BaseModel):
|
||||
"""Media pass-through configuration for external audio processing."""
|
||||
|
||||
url: str = Field(..., description="WebSocket URL for media streaming")
|
||||
input_sample_rate: int = Field(default=16000, description="Sample rate of audio received from WebSocket")
|
||||
output_sample_rate: int = Field(default=16000, description="Sample rate of audio sent to WebSocket")
|
||||
packet_size: int = Field(default=2560, description="Packet size in bytes")
|
||||
ptime: Optional[int] = Field(default=None, description="Buffered playback period in milliseconds")
|
||||
|
||||
|
||||
class SipOption(BaseModel):
|
||||
"""SIP protocol configuration."""
|
||||
|
||||
username: Optional[str] = Field(default=None, description="SIP username")
|
||||
password: Optional[str] = Field(default=None, description="SIP password")
|
||||
realm: Optional[str] = Field(default=None, description="SIP realm/domain")
|
||||
headers: Optional[Dict[str, str]] = Field(default=None, description="Additional SIP headers")
|
||||
|
||||
|
||||
class HandlerRule(BaseModel):
|
||||
"""Handler routing rule."""
|
||||
|
||||
caller: Optional[str] = Field(default=None, description="Caller pattern (regex)")
|
||||
callee: Optional[str] = Field(default=None, description="Callee pattern (regex)")
|
||||
playbook: Optional[str] = Field(default=None, description="Playbook file path")
|
||||
webhook: Optional[str] = Field(default=None, description="Webhook URL")
|
||||
|
||||
|
||||
class CallOption(BaseModel):
    """Comprehensive call configuration options.

    Aggregates every per-call setting: basic call identity, codec choice,
    per-component configs (recorder/ASR/VAD/TTS/media-pass/SIP), timeout
    and networking knobs, and free-form extras.
    """

    # Basic options
    denoise: bool = Field(default=False, description="Enable noise reduction")
    offer: Optional[str] = Field(default=None, description="SDP offer string")
    callee: Optional[str] = Field(default=None, description="Callee SIP URI or phone number")
    caller: Optional[str] = Field(default=None, description="Caller SIP URI or phone number")

    # Audio codec
    codec: str = Field(default="pcm", description="Audio codec (pcm, pcma, pcmu, g722)")

    # Component configurations (each is optional; None disables the component)
    recorder: Optional[RecorderOption] = Field(default=None, description="Call recording config")
    asr: Optional[ASROption] = Field(default=None, description="ASR configuration")
    vad: Optional[VADOption] = Field(default=None, description="VAD configuration")
    tts: Optional[TTSOption] = Field(default=None, description="TTS configuration")
    media_pass: Optional[MediaPassOption] = Field(default=None, description="Media pass-through config")
    sip: Optional[SipOption] = Field(default=None, description="SIP configuration")

    # Timeouts and networking
    handshake_timeout: Optional[int] = Field(default=None, description="Handshake timeout in seconds")
    enable_ipv6: bool = Field(default=False, description="Enable IPv6 support")
    inactivity_timeout: Optional[int] = Field(default=None, description="Inactivity timeout in seconds")

    # EOU configuration (free-form dict; schema defined by the EOU detector)
    eou: Optional[Dict[str, Any]] = Field(default=None, description="End of utterance detection config")

    # Extra parameters
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional custom parameters")

    class Config:
        # Allow constructing by field name as well as by alias.
        populate_by_name = True
|
||||
231
models/events.py
Normal file
231
models/events.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Protocol event models matching the original active-call API."""
|
||||
|
||||
import time
from datetime import datetime
from typing import Optional, Dict, Any

from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def current_timestamp_ms() -> int:
    """Return the current Unix wall-clock time in milliseconds.

    Uses ``time.time()`` directly (consistent with ``ws_v1.now_ms``) instead
    of round-tripping through a naive ``datetime``, which produces the same
    value but relies on local-timezone conversion.
    """
    return int(time.time() * 1000)
|
||||
|
||||
|
||||
# Base Event Model
|
||||
class BaseEvent(BaseModel):
|
||||
"""Base event model."""
|
||||
|
||||
event: str = Field(..., description="Event type")
|
||||
track_id: str = Field(..., description="Unique track identifier")
|
||||
timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")
|
||||
|
||||
|
||||
# Lifecycle Events
|
||||
class IncomingEvent(BaseEvent):
|
||||
"""Incoming call event (SIP only)."""
|
||||
|
||||
event: str = Field(default="incoming", description="Event type")
|
||||
caller: Optional[str] = Field(default=None, description="Caller's SIP URI")
|
||||
callee: Optional[str] = Field(default=None, description="Callee's SIP URI")
|
||||
sdp: Optional[str] = Field(default=None, description="SDP offer from caller")
|
||||
|
||||
|
||||
class AnswerEvent(BaseEvent):
|
||||
"""Call answered event."""
|
||||
|
||||
event: str = Field(default="answer", description="Event type")
|
||||
sdp: Optional[str] = Field(default=None, description="SDP answer from server")
|
||||
|
||||
|
||||
class RejectEvent(BaseEvent):
|
||||
"""Call rejected event."""
|
||||
|
||||
event: str = Field(default="reject", description="Event type")
|
||||
reason: Optional[str] = Field(default=None, description="Rejection reason")
|
||||
code: Optional[int] = Field(default=None, description="SIP response code")
|
||||
|
||||
|
||||
class RingingEvent(BaseEvent):
|
||||
"""Call ringing event."""
|
||||
|
||||
event: str = Field(default="ringing", description="Event type")
|
||||
early_media: bool = Field(default=False, description="Early media available")
|
||||
|
||||
|
||||
class HangupEvent(BaseModel):
    """Call hangup event.

    NOTE: unlike most events this extends BaseModel directly rather than
    BaseEvent, so it carries no required track_id; it reports call-level
    timing metadata instead.
    """

    event: str = Field(default="hangup", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    reason: Optional[str] = Field(default=None, description="Hangup reason")
    initiator: Optional[str] = Field(default=None, description="Who initiated hangup")
    start_time: Optional[str] = Field(default=None, description="Call start time (ISO 8601)")
    hangup_time: Optional[str] = Field(default=None, description="Hangup time (ISO 8601)")
    answer_time: Optional[str] = Field(default=None, description="Answer time (ISO 8601)")
    ringing_time: Optional[str] = Field(default=None, description="Ringing time (ISO 8601)")
    # "from" is a Python keyword, so the field is named from_ and exposed
    # via the alias in serialized payloads.
    from_: Optional[Dict[str, Any]] = Field(default=None, alias="from", description="Caller info")
    to: Optional[Dict[str, Any]] = Field(default=None, description="Callee info")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")

    class Config:
        # Allow constructing with either the field name (from_) or the alias.
        populate_by_name = True
|
||||
|
||||
|
||||
# VAD Events
|
||||
class SpeakingEvent(BaseEvent):
|
||||
"""Speech detected event."""
|
||||
|
||||
event: str = Field(default="speaking", description="Event type")
|
||||
start_time: int = Field(default_factory=current_timestamp_ms, description="Speech start time")
|
||||
|
||||
|
||||
class SilenceEvent(BaseEvent):
|
||||
"""Silence detected event."""
|
||||
|
||||
event: str = Field(default="silence", description="Event type")
|
||||
start_time: int = Field(default_factory=current_timestamp_ms, description="Silence start time")
|
||||
duration: int = Field(default=0, description="Silence duration in milliseconds")
|
||||
|
||||
|
||||
# AI/ASR Events
|
||||
class AsrFinalEvent(BaseEvent):
|
||||
"""ASR final transcription event."""
|
||||
|
||||
event: str = Field(default="asrFinal", description="Event type")
|
||||
index: int = Field(..., description="ASR result sequence number")
|
||||
start_time: Optional[int] = Field(default=None, description="Speech start time")
|
||||
end_time: Optional[int] = Field(default=None, description="Speech end time")
|
||||
text: str = Field(..., description="Transcribed text")
|
||||
|
||||
|
||||
class AsrDeltaEvent(BaseEvent):
|
||||
"""ASR partial transcription event (streaming)."""
|
||||
|
||||
event: str = Field(default="asrDelta", description="Event type")
|
||||
index: int = Field(..., description="ASR result sequence number")
|
||||
start_time: Optional[int] = Field(default=None, description="Speech start time")
|
||||
end_time: Optional[int] = Field(default=None, description="Speech end time")
|
||||
text: str = Field(..., description="Partial transcribed text")
|
||||
|
||||
|
||||
class EouEvent(BaseEvent):
|
||||
"""End of utterance detection event."""
|
||||
|
||||
event: str = Field(default="eou", description="Event type")
|
||||
completed: bool = Field(default=True, description="Whether utterance was completed")
|
||||
|
||||
|
||||
# Audio Track Events
|
||||
class TrackStartEvent(BaseEvent):
|
||||
"""Audio track start event."""
|
||||
|
||||
event: str = Field(default="trackStart", description="Event type")
|
||||
play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")
|
||||
|
||||
|
||||
class TrackEndEvent(BaseEvent):
|
||||
"""Audio track end event."""
|
||||
|
||||
event: str = Field(default="trackEnd", description="Event type")
|
||||
duration: int = Field(..., description="Track duration in milliseconds")
|
||||
ssrc: int = Field(..., description="RTP SSRC identifier")
|
||||
play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")
|
||||
|
||||
|
||||
class InterruptionEvent(BaseEvent):
|
||||
"""Playback interruption event."""
|
||||
|
||||
event: str = Field(default="interruption", description="Event type")
|
||||
play_id: Optional[str] = Field(default=None, description="Play ID that was interrupted")
|
||||
subtitle: Optional[str] = Field(default=None, description="TTS text being played")
|
||||
position: Optional[int] = Field(default=None, description="Word index position")
|
||||
total_duration: Optional[int] = Field(default=None, description="Total TTS duration")
|
||||
current: Optional[int] = Field(default=None, description="Elapsed time when interrupted")
|
||||
|
||||
|
||||
# System Events
|
||||
class ErrorEvent(BaseEvent):
|
||||
"""Error event."""
|
||||
|
||||
event: str = Field(default="error", description="Event type")
|
||||
sender: str = Field(..., description="Component that generated the error")
|
||||
error: str = Field(..., description="Error message")
|
||||
code: Optional[int] = Field(default=None, description="Error code")
|
||||
|
||||
|
||||
class MetricsEvent(BaseModel):
|
||||
"""Performance metrics event."""
|
||||
|
||||
event: str = Field(default="metrics", description="Event type")
|
||||
timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
|
||||
key: str = Field(..., description="Metric key")
|
||||
duration: int = Field(..., description="Duration in milliseconds")
|
||||
data: Optional[Dict[str, Any]] = Field(default=None, description="Additional metric data")
|
||||
|
||||
|
||||
class AddHistoryEvent(BaseModel):
|
||||
"""Conversation history entry added event."""
|
||||
|
||||
event: str = Field(default="addHistory", description="Event type")
|
||||
timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
|
||||
sender: Optional[str] = Field(default=None, description="Component that added history")
|
||||
speaker: str = Field(..., description="Speaker identifier")
|
||||
text: str = Field(..., description="Conversation text")
|
||||
|
||||
|
||||
class DTMFEvent(BaseEvent):
|
||||
"""DTMF tone detected event."""
|
||||
|
||||
event: str = Field(default="dtmf", description="Event type")
|
||||
digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)")
|
||||
|
||||
|
||||
class HeartBeatEvent(BaseModel):
|
||||
"""Server-to-client heartbeat to keep connection alive."""
|
||||
|
||||
event: str = Field(default="heartBeat", description="Event type")
|
||||
timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")
|
||||
|
||||
|
||||
# Event type mapping
|
||||
EVENT_TYPES = {
|
||||
"incoming": IncomingEvent,
|
||||
"answer": AnswerEvent,
|
||||
"reject": RejectEvent,
|
||||
"ringing": RingingEvent,
|
||||
"hangup": HangupEvent,
|
||||
"speaking": SpeakingEvent,
|
||||
"silence": SilenceEvent,
|
||||
"asrFinal": AsrFinalEvent,
|
||||
"asrDelta": AsrDeltaEvent,
|
||||
"eou": EouEvent,
|
||||
"trackStart": TrackStartEvent,
|
||||
"trackEnd": TrackEndEvent,
|
||||
"interruption": InterruptionEvent,
|
||||
"error": ErrorEvent,
|
||||
"metrics": MetricsEvent,
|
||||
"addHistory": AddHistoryEvent,
|
||||
"dtmf": DTMFEvent,
|
||||
"heartBeat": HeartBeatEvent,
|
||||
}
|
||||
|
||||
|
||||
def create_event(event_type: str, **kwargs) -> BaseModel:
    """
    Create an event model.

    Args:
        event_type: Type of event to create
        **kwargs: Event fields

    Returns:
        Event model instance

    Raises:
        ValueError: If event type is unknown
    """
    try:
        event_class = EVENT_TYPES[event_type]
    except KeyError:
        raise ValueError(f"Unknown event type: {event_type}") from None

    return event_class(event=event_type, **kwargs)
|
||||
73
models/ws_v1.py
Normal file
73
models/ws_v1.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""WS v1 protocol message models and helpers."""
|
||||
|
||||
import time
from typing import Optional, Dict, Any, Literal

from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def now_ms() -> int:
    """Current unix timestamp in milliseconds.

    The ``time`` import is hoisted to module level (it was function-scoped),
    avoiding a redundant import-machinery hit on every call.
    """
    return int(time.time() * 1000)
|
||||
|
||||
|
||||
# Client -> Server messages
|
||||
class HelloMessage(BaseModel):
|
||||
type: Literal["hello"]
|
||||
version: str = Field(..., description="Protocol version, currently v1")
|
||||
auth: Optional[Dict[str, str]] = Field(default=None, description="Auth payload, e.g. {'apiKey': '...'}")
|
||||
|
||||
|
||||
class SessionStartMessage(BaseModel):
|
||||
type: Literal["session.start"]
|
||||
audio: Optional[Dict[str, Any]] = Field(default=None, description="Optional audio format metadata")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional session metadata")
|
||||
|
||||
|
||||
class SessionStopMessage(BaseModel):
|
||||
type: Literal["session.stop"]
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class InputTextMessage(BaseModel):
|
||||
type: Literal["input.text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseCancelMessage(BaseModel):
|
||||
type: Literal["response.cancel"]
|
||||
graceful: bool = False
|
||||
|
||||
|
||||
class ToolCallResultsMessage(BaseModel):
|
||||
type: Literal["tool_call.results"]
|
||||
results: list[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
CLIENT_MESSAGE_TYPES = {
|
||||
"hello": HelloMessage,
|
||||
"session.start": SessionStartMessage,
|
||||
"session.stop": SessionStopMessage,
|
||||
"input.text": InputTextMessage,
|
||||
"response.cancel": ResponseCancelMessage,
|
||||
"tool_call.results": ToolCallResultsMessage,
|
||||
}
|
||||
|
||||
|
||||
def parse_client_message(data: Dict[str, Any]) -> BaseModel:
    """Parse and validate a WS v1 client message.

    Raises:
        ValueError: If the 'type' field is missing or not a known message type.
    """
    msg_type = data.get("type")
    if not msg_type:
        raise ValueError("Missing 'type' field")

    try:
        msg_class = CLIENT_MESSAGE_TYPES[msg_type]
    except KeyError:
        raise ValueError(f"Unknown client message type: {msg_type}") from None

    return msg_class(**data)
|
||||
|
||||
|
||||
# Server -> Client event helpers
|
||||
def ev(event_type: str, **payload: Any) -> Dict[str, Any]:
    """Create a WS v1 server event payload.

    The payload may override the default "type"/"timestamp" keys, matching
    dict.update semantics.
    """
    return {"type": event_type, "timestamp": now_ms(), **payload}
|
||||
6
processors/__init__.py
Normal file
6
processors/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Audio Processors Package"""
|
||||
|
||||
from processors.eou import EouDetector
|
||||
from processors.vad import SileroVAD, VADProcessor
|
||||
|
||||
__all__ = ["EouDetector", "SileroVAD", "VADProcessor"]
|
||||
80
processors/eou.py
Normal file
80
processors/eou.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""End-of-Utterance Detection."""
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class EouDetector:
    """
    End-of-utterance detector. Fires EOU only after continuous silence for
    silence_threshold_ms. Short pauses between sentences do not trigger EOU
    because speech resets the silence timer (one EOU per turn).
    """

    def __init__(self, silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize EOU detector.

        Args:
            silence_threshold_ms: How long silence must last to trigger EOU (default 1000ms)
            min_speech_duration_ms: Minimum speech duration to consider valid (default 250ms)
        """
        # Thresholds converted to seconds for comparison with time.time() deltas.
        self.threshold = silence_threshold_ms / 1000.0
        self.min_speech = min_speech_duration_ms / 1000.0
        # Original millisecond values, kept for introspection.
        self._silence_threshold_ms = silence_threshold_ms
        self._min_speech_duration_ms = min_speech_duration_ms

        # State
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time: Optional[float] = None
        self.triggered = False

    def process(self, vad_status: str, force_eligible: bool = False) -> bool:
        """
        Process VAD status and detect end of utterance.

        Input: "Speech" or "Silence" (from VAD).
        Output: True if EOU detected, False otherwise.

        Short breaks between phrases reset the silence clock when speech
        resumes, so only one EOU is emitted after the user truly stops.
        """
        timestamp = time.time()

        if vad_status == "Speech":
            if not self.is_speaking:
                # Transition into a new speech segment.
                self.is_speaking = True
                self.speech_start_time = timestamp
                self.triggered = False
            # Any speech resets silence timer — short pause + more speech = one utterance
            self.silence_start_time = None
            return False

        if vad_status != "Silence":
            # Unknown status: no state change.
            return False

        if not self.is_speaking:
            return False

        if self.silence_start_time is None:
            self.silence_start_time = timestamp

        spoke_for = self.silence_start_time - self.speech_start_time
        if spoke_for < self.min_speech and not force_eligible:
            # Too short to be a real utterance; discard the segment.
            self.is_speaking = False
            self.silence_start_time = None
            return False

        quiet_for = timestamp - self.silence_start_time
        if quiet_for >= self.threshold and not self.triggered:
            # Continuous silence long enough: emit exactly one EOU per turn.
            self.triggered = True
            self.is_speaking = False
            self.silence_start_time = None
            return True

        return False

    def reset(self) -> None:
        """Reset EOU detector state."""
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time = None
        self.triggered = False
||||
168
processors/tracks.py
Normal file
168
processors/tracks.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Audio track processing for WebRTC."""
|
||||
|
||||
import asyncio
|
||||
import fractions
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
# Try to import aiortc (optional for WebRTC functionality)
|
||||
try:
|
||||
from aiortc import AudioStreamTrack
|
||||
AIORTC_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIORTC_AVAILABLE = False
|
||||
AudioStreamTrack = object # Dummy class for type hints
|
||||
|
||||
# Try to import PyAV (optional for audio resampling)
|
||||
try:
|
||||
from av import AudioFrame, AudioResampler
|
||||
AV_AVAILABLE = True
|
||||
except ImportError:
|
||||
AV_AVAILABLE = False
|
||||
# Create dummy classes for type hints
|
||||
class AudioFrame:
|
||||
pass
|
||||
class AudioResampler:
|
||||
pass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Audio track that resamples input to 16kHz mono PCM.

    Wraps an existing MediaStreamTrack and converts its output
    to 16kHz mono 16-bit PCM format for the pipeline.
    """

    def __init__(self, track, target_sample_rate: int = 16000):
        """
        Initialize resampled track.

        Args:
            track: Source MediaStreamTrack
            target_sample_rate: Target sample rate (default: 16000)

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used")

        super().__init__()
        self.track = track
        self.target_sample_rate = target_sample_rate

        if AV_AVAILABLE:
            self.resampler = AudioResampler(
                format="s16",
                layout="mono",
                rate=target_sample_rate
            )
        else:
            logger.warning("PyAV not available, audio resampling disabled")
            self.resampler = None

        # Frames produced by the resampler but not yet delivered to the caller.
        # PyAV >= 9 returns a *list* from resample() — possibly empty (internal
        # buffering) or holding several frames — so the old single-frame
        # handling raised AttributeError on the list.
        self._pending = []
        self._closed = False

    async def recv(self):
        """
        Receive and resample the next audio frame.

        Returns:
            Resampled AudioFrame at 16kHz mono; the source frame unchanged
            when PyAV is unavailable.

        Raises:
            RuntimeError: If the track has been closed.
        """
        if self._closed:
            raise RuntimeError("Track is closed")

        if not (AV_AVAILABLE and self.resampler):
            # No resampler: pass source frames through as-is.
            return await self.track.recv()

        # Pull source frames until the resampler emits at least one output
        # frame; keep extras so no audio is dropped between recv() calls.
        while not self._pending:
            frame = await self.track.recv()
            out = self.resampler.resample(frame)
            # Older PyAV returned a single frame; normalize to a list.
            self._pending.extend(out if isinstance(out, list) else [out])

        resampled_frame = self._pending.pop(0)
        # Ensure the frame advertises the pipeline's expected rate.
        resampled_frame.sample_rate = self.target_sample_rate
        return resampled_frame

    async def stop(self) -> None:
        """Stop the track and cleanup resources."""
        self._closed = True
        self._pending = []
        # Drop the resampler reference instead of `del` so later attribute
        # access cannot raise AttributeError.
        if getattr(self, 'resampler', None) is not None:
            self.resampler = None
        logger.debug("Resampled track stopped")
|
||||
|
||||
|
||||
class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Synthetic audio track that generates a sine wave.

    Useful for testing without requiring real audio input.
    """

    def __init__(self, sample_rate: int = 16000, frequency: int = 440):
        """
        Initialize sine wave track.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            frequency: Sine wave frequency in Hz (default: 440)

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - SineWaveTrack cannot be used")

        super().__init__()
        self.sample_rate = sample_rate
        self.frequency = frequency
        self.counter = 0          # running sample count; doubles as pts
        self._stopped = False

    async def recv(self):
        """
        Generate the next 20ms audio frame of the sine wave.

        Returns:
            AudioFrame with sine wave data (or a plain dict when PyAV is
            not installed).

        Raises:
            RuntimeError: If the track has been stopped.
        """
        if self._stopped:
            raise RuntimeError("Track is stopped")

        # 20ms worth of samples per frame.
        n_samples = int(self.sample_rate * 0.02)
        frame_pts = self.counter
        tb = fractions.Fraction(1, self.sample_rate)

        # Phase-continuous time axis across frames (counter keeps the phase).
        start = self.counter / self.sample_rate
        stop = (self.counter + n_samples) / self.sample_rate
        t = np.linspace(start, stop, n_samples, endpoint=False)

        # Half-amplitude sine, quantized to Int16 PCM.
        pcm = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16)

        self.counter += n_samples

        if not AV_AVAILABLE:
            # Plain data structure when PyAV is absent.
            return {
                'data': pcm,
                'sample_rate': self.sample_rate,
                'pts': frame_pts,
                'time_base': tb
            }

        frame = AudioFrame.from_ndarray(pcm.reshape(1, -1), format='s16', layout='mono')
        frame.pts = frame_pts
        frame.time_base = tb
        frame.sample_rate = self.sample_rate
        return frame

    def stop(self) -> None:
        """Stop the track."""
        self._stopped = True
|
||||
221
processors/vad.py
Normal file
221
processors/vad.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Voice Activity Detection using Silero VAD."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Tuple, Optional
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# Try to import onnxruntime (optional for VAD functionality)
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
ONNX_AVAILABLE = True
|
||||
except ImportError:
|
||||
ONNX_AVAILABLE = False
|
||||
ort = None
|
||||
logger.warning("onnxruntime not available - VAD will be disabled")
|
||||
|
||||
|
||||
class SileroVAD:
    """
    Voice Activity Detection using the Silero VAD model.

    Detects speech in audio chunks using the Silero VAD ONNX model and
    returns "Speech" or "Silence" per chunk. When the model file is missing,
    fails to load, or onnxruntime is not installed, it degrades to an
    adaptive energy-based detector instead of raising.
    """

    def __init__(self, model_path: str = "data/vad/silero_vad.onnx", sample_rate: int = 16000):
        """
        Initialize Silero VAD.

        Args:
            model_path: Path to Silero VAD ONNX model
            sample_rate: Audio sample rate (must be 16kHz for Silero VAD)
        """
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.session = None

        # Initialize ALL internal state unconditionally. The previous version
        # only set these after a successful model load, so the energy-fallback
        # path and reset() raised AttributeError exactly when the model was
        # missing or onnxruntime was absent — the cases the fallback exists for.
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.min_chunk_size = 512          # Silero consumes fixed 512-sample windows at 16kHz
        self.last_label = "Silence"
        self.last_probability = 0.0
        self._energy_noise_floor = 1e-4    # adaptive floor for the energy fallback

        if not os.path.exists(model_path):
            logger.warning(f"VAD model not found at {model_path}. VAD will be disabled.")
            return

        if not ONNX_AVAILABLE:
            logger.warning("onnxruntime not available - VAD will be disabled")
            return

        try:
            self.session = ort.InferenceSession(model_path)
            logger.info(f"Loaded Silero VAD model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load VAD model: {e}")
            self.session = None

    def _reset_state(self):
        """Reset the recurrent model state. Silero VAD V4+ expects shape [2, 1, 128]."""
        self._state = np.zeros((2, 1, 128), dtype=np.float32)
        self._sr = np.array([self.sample_rate], dtype=np.int64)

    def process_audio(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Tuple[str, float]:
        """
        Process an audio chunk and detect speech.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
            chunk_size_ms: Chunk duration in milliseconds (ignored for buffering logic)

        Returns:
            Tuple of (label, probability) where label is "Speech" or "Silence"
        """
        # session is None whenever the model file was missing, the load failed,
        # or onnxruntime is not installed — use the energy fallback.
        if self.session is None:
            return self._energy_vad(pcm_bytes)

        # Convert 16-bit PCM to float32 in [-1.0, 1.0).
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0

        # Accumulate; the model consumes fixed-size windows regardless of
        # the caller's chunking.
        self.buffer = np.concatenate((self.buffer, audio_float))

        # Run inference on every complete 512-sample window in the buffer.
        while len(self.buffer) >= self.min_chunk_size:
            chunk = self.buffer[:self.min_chunk_size]
            self.buffer = self.buffer[self.min_chunk_size:]

            # Input tensor shape: [batch, samples] -> [1, 512]
            input_tensor = chunk.reshape(1, -1)

            try:
                ort_inputs = {
                    'input': input_tensor,
                    'state': self._state,
                    'sr': self._sr
                }
                # Outputs: probability, state
                out, self._state = self.session.run(None, ort_inputs)

                self.last_probability = float(out[0][0])
                self.last_label = "Speech" if self.last_probability >= 0.5 else "Silence"
            except Exception as e:
                logger.error(f"VAD inference error: {e}")
                # Surface the model's expected input names to help diagnose
                # a schema mismatch.
                try:
                    inputs = [x.name for x in self.session.get_inputs()]
                    logger.error(f"Model expects inputs: {inputs}")
                except Exception:
                    pass
                # Fail open: report speech so audio is not dropped on errors.
                return "Speech", 1.0

        return self.last_label, self.last_probability

    def _energy_vad(self, pcm_bytes: bytes) -> Tuple[str, float]:
        """Fallback energy-based VAD with an adaptive noise floor."""
        if not pcm_bytes:
            return "Silence", 0.0
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
        if audio_int16.size == 0:
            return "Silence", 0.0
        audio_float = audio_int16.astype(np.float32) / 32768.0
        rms = float(np.sqrt(np.mean(audio_float * audio_float)))

        # Update adaptive noise floor (falls quickly, rises slowly).
        if rms < self._energy_noise_floor:
            self._energy_noise_floor = 0.95 * self._energy_noise_floor + 0.05 * rms
        else:
            self._energy_noise_floor = 0.995 * self._energy_noise_floor + 0.005 * rms

        # Compute an SNR-like ratio and map it onto [0, 1].
        denom = max(self._energy_noise_floor, 1e-6)
        snr = max(0.0, (rms - denom) / denom)
        probability = min(1.0, snr / 3.0)  # ~3x above noise => strong speech
        label = "Speech" if probability >= 0.5 else "Silence"
        return label, probability

    def reset(self) -> None:
        """Reset VAD internal state."""
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.last_label = "Silence"
        self.last_probability = 0.0
|
||||
|
||||
|
||||
class VADProcessor:
    """
    High-level VAD processor with state management.

    Tracks speech/silence state and emits events on transitions.
    """

    def __init__(self, vad_model: SileroVAD, threshold: float = 0.5):
        """
        Initialize VAD processor.

        Args:
            vad_model: Silero VAD model instance
            threshold: Speech detection threshold
        """
        self.vad = vad_model
        self.threshold = threshold
        self.is_speaking = False
        self.speech_start_time: Optional[float] = None
        self.silence_start_time: Optional[float] = None

    def process(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Optional[Tuple[str, float]]:
        """
        Process an audio chunk and detect state changes.

        Args:
            pcm_bytes: PCM audio data
            chunk_size_ms: Chunk duration in milliseconds

        Returns:
            Tuple of (event_type, probability) if state changed, None otherwise
        """
        _label, probability = self.vad.process_audio(pcm_bytes, chunk_size_ms)

        # Thresholding here (rather than trusting the model's label) lets the
        # caller tune sensitivity independently.
        speech_now = probability >= self.threshold
        if speech_now == self.is_speaking:
            return None  # no transition

        now = asyncio.get_event_loop().time()
        self.is_speaking = speech_now
        if speech_now:
            # Silence -> Speech
            self.speech_start_time = now
            self.silence_start_time = None
            return ("speaking", probability)

        # Speech -> Silence
        self.silence_start_time = now
        self.speech_start_time = None
        return ("silence", probability)

    def reset(self) -> None:
        """Reset VAD state."""
        self.vad.reset()
        self.is_speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
|
||||
134
pyproject.toml
Normal file
134
pyproject.toml
Normal file
@@ -0,0 +1,134 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=68.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "py-active-call-cc"
|
||||
version = "0.1.0"
|
||||
description = "Python Active-Call: Real-time audio streaming with WebSocket and WebRTC"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = {text = "MIT"}
|
||||
authors = [
|
||||
{name = "Your Name", email = "your.email@example.com"}
|
||||
]
|
||||
keywords = ["webrtc", "websocket", "audio", "voip", "real-time"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Communications :: Telephony",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/yourusername/py-active-call-cc"
|
||||
Documentation = "https://github.com/yourusername/py-active-call-cc/blob/main/README.md"
|
||||
Repository = "https://github.com/yourusername/py-active-call-cc.git"
|
||||
Issues = "https://github.com/yourusername/py-active-call-cc/issues"
|
||||
|
||||
[tool.setuptools.packages.find]
where = ["."]
# NOTE(review): this commit also adds `processors` and `services` packages;
# without them listed here, built distributions would omit those modules.
include = ["app*", "processors*", "services*"]
exclude = ["tests*", "scripts*", "reference*"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
target-version = ['py311']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
/(
|
||||
# directories
|
||||
\.eggs
|
||||
| \.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| build
|
||||
| dist
|
||||
| reference
|
||||
)/
|
||||
'''
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py311"
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by black)
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
]
|
||||
exclude = [
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
"reference",
|
||||
]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"__init__.py" = ["F401"] # unused imports
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = false
|
||||
disallow_incomplete_defs = false
|
||||
check_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
warn_no_return = true
|
||||
strict_equality = true
|
||||
exclude = [
|
||||
"venv",
|
||||
"reference",
|
||||
"build",
|
||||
"dist",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"aiortc.*",
|
||||
"av.*",
|
||||
"onnxruntime.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "7.0"
|
||||
addopts = "-ra -q --strict-markers --strict-config"
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
asyncio_mode = "auto"
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"integration: marks tests as integration tests",
|
||||
]
|
||||
37
requirements.txt
Normal file
37
requirements.txt
Normal file
@@ -0,0 +1,37 @@
|
||||
# Web Framework
|
||||
fastapi>=0.109.0
|
||||
uvicorn[standard]>=0.27.0
|
||||
websockets>=12.0
|
||||
python-multipart>=0.0.6
|
||||
|
||||
# WebRTC (optional - for WebRTC transport)
|
||||
aiortc>=1.6.0
|
||||
|
||||
# Audio Processing
|
||||
av>=12.1.0
|
||||
numpy>=1.26.3
|
||||
onnxruntime>=1.16.3
|
||||
|
||||
# Configuration
|
||||
pydantic>=2.5.3
|
||||
pydantic-settings>=2.1.0
|
||||
python-dotenv>=1.0.0
|
||||
toml>=0.10.2
|
||||
|
||||
# Logging
|
||||
loguru>=0.7.2
|
||||
|
||||
# HTTP Client
|
||||
aiohttp>=3.9.1
|
||||
|
||||
# AI Services - LLM
|
||||
openai>=1.0.0
|
||||
|
||||
# AI Services - TTS
|
||||
edge-tts>=6.1.0
|
||||
pydub>=0.25.0 # For audio format conversion
|
||||
|
||||
# Microphone client dependencies
|
||||
sounddevice>=0.4.6
|
||||
soundfile>=0.12.1
|
||||
pyaudio>=0.2.13 # More reliable audio on Windows
|
||||
1
scripts/README.md
Normal file
1
scripts/README.md
Normal file
@@ -0,0 +1 @@
|
||||
# Development Script
|
||||
311
scripts/generate_test_audio/generate_test_audio.py
Normal file
311
scripts/generate_test_audio/generate_test_audio.py
Normal file
@@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate test audio file with utterances using SiliconFlow TTS API.
|
||||
|
||||
Creates a 16kHz mono WAV file with real speech segments separated by
|
||||
configurable silence (for VAD/testing).
|
||||
|
||||
Usage:
|
||||
python generate_test_audio.py [OPTIONS]
|
||||
|
||||
Options:
|
||||
-o, --output PATH Output WAV path (default: data/audio_examples/two_utterances_16k.wav)
|
||||
-u, --utterance TEXT Utterance text; repeat for multiple (ignored if -j is set)
|
||||
-j, --json PATH JSON file: array of strings or {"utterances": [...]}
|
||||
--silence-ms MS Silence in ms between utterances (default: 500)
|
||||
--lead-silence-ms MS Silence in ms at start (default: 200)
|
||||
--trail-silence-ms MS Silence in ms at end (default: 300)
|
||||
|
||||
Examples:
|
||||
# Default utterances and output
|
||||
python generate_test_audio.py
|
||||
|
||||
# Custom output path
|
||||
python generate_test_audio.py -o out.wav
|
||||
|
||||
# Utterances from command line
|
||||
python generate_test_audio.py -u "Hello" -u "World" -o test.wav
|
||||
|
||||
# Utterances from a JSON file
python generate_test_audio.py -j utterances.json -o test.wav
|
||||
|
||||
# Custom silence (1s between utterances)
|
||||
python generate_test_audio.py -u "One" -u "Two" --silence-ms 1000 -o test.wav
|
||||
|
||||
Requires SILICONFLOW_API_KEY in .env.
|
||||
"""
|
||||
|
||||
import wave
|
||||
import struct
|
||||
import argparse
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# Load .env file from project root
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
load_dotenv(project_root / ".env")
|
||||
|
||||
|
||||
# SiliconFlow TTS Configuration
|
||||
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/audio/speech"
|
||||
SILICONFLOW_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
|
||||
|
||||
# Available voices
|
||||
VOICES = {
|
||||
"alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
|
||||
"anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
|
||||
"bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
|
||||
"benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
|
||||
"charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
|
||||
"claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
|
||||
"david": "FunAudioLLM/CosyVoice2-0.5B:david",
|
||||
"diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
|
||||
}
|
||||
|
||||
|
||||
def generate_silence(duration_ms: int, sample_rate: int = 16000) -> bytes:
    """Generate silence as PCM bytes (16-bit mono: two zero bytes per sample)."""
    sample_count = int(sample_rate * (duration_ms / 1000.0))
    return bytes(2 * sample_count)
|
||||
|
||||
|
||||
async def synthesize_speech(
    text: str,
    api_key: str,
    voice: str = "anna",
    sample_rate: int = 16000,
    speed: float = 1.0
) -> bytes:
    """
    Synthesize speech using the SiliconFlow TTS API.

    Args:
        text: Text to synthesize
        api_key: SiliconFlow API key
        voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
        sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
        speed: Speech speed (0.25 to 4.0)

    Returns:
        PCM audio bytes (16-bit signed, little-endian)

    Raises:
        RuntimeError: On any non-200 API response.
    """
    request_body = {
        "model": SILICONFLOW_MODEL,
        # Short names ("anna") map to full voice ids; unknown names pass through.
        "voice": VOICES.get(voice, voice),
        "input": text,
        "response_format": "pcm",
        "sample_rate": sample_rate,
        "stream": False,
        "speed": speed
    }
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    async with aiohttp.ClientSession() as http:
        async with http.post(
            SILICONFLOW_API_URL, json=request_body, headers=request_headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"SiliconFlow TTS error: {response.status} - {error_text}")
            return await response.read()
|
||||
|
||||
|
||||
async def generate_test_audio(
    output_path: str,
    utterances: list[str],
    silence_ms: int = 500,
    lead_silence_ms: int = 200,
    trail_silence_ms: int = 300,
    voice: str = "anna",
    sample_rate: int = 16000,
    speed: float = 1.0
):
    """
    Generate test audio with multiple utterances separated by silence.

    Synthesizes each utterance via SiliconFlow TTS, interleaves silence
    segments, and writes the result as a 16-bit mono PCM WAV file.

    Args:
        output_path: Path to save the WAV file
        utterances: List of text strings for each utterance
        silence_ms: Silence duration between utterances (milliseconds)
        lead_silence_ms: Silence at the beginning (milliseconds)
        trail_silence_ms: Silence at the end (milliseconds)
        voice: TTS voice to use
        sample_rate: Audio sample rate
        speed: TTS speech speed

    Raises:
        ValueError: If SILICONFLOW_API_KEY is not set in the environment.
        RuntimeError: Propagated from synthesize_speech on API errors.
    """
    api_key = os.getenv("SILICONFLOW_API_KEY")
    if not api_key:
        raise ValueError(
            "SILICONFLOW_API_KEY not found in environment.\n"
            "Please set it in your .env file:\n"
            " SILICONFLOW_API_KEY=your-api-key-here"
        )

    print(f"Using SiliconFlow TTS API")
    print(f" Voice: {voice}")
    print(f" Sample rate: {sample_rate}Hz")
    print(f" Speed: {speed}x")
    print()

    # Ordered PCM segments (speech and silence) to concatenate at the end.
    segments = []

    # Lead-in silence
    if lead_silence_ms > 0:
        segments.append(generate_silence(lead_silence_ms, sample_rate))
        print(f" [silence: {lead_silence_ms}ms]")

    # Generate each utterance with silence between
    for i, text in enumerate(utterances):
        print(f" Synthesizing utterance {i + 1}: \"{text}\"")
        audio = await synthesize_speech(
            text=text,
            api_key=api_key,
            voice=voice,
            sample_rate=sample_rate,
            speed=speed
        )
        segments.append(audio)

        # Add silence between utterances (not after the last one)
        if i < len(utterances) - 1:
            segments.append(generate_silence(silence_ms, sample_rate))
            print(f" [silence: {silence_ms}ms]")

    # Trail silence
    if trail_silence_ms > 0:
        segments.append(generate_silence(trail_silence_ms, sample_rate))
        print(f" [silence: {trail_silence_ms}ms]")

    # Concatenate all segments
    audio_data = b''.join(segments)

    # Write WAV file
    with wave.open(output_path, 'wb') as wf:
        wf.setnchannels(1)  # Mono
        wf.setsampwidth(2)  # 16-bit
        wf.setframerate(sample_rate)
        wf.writeframes(audio_data)

    # 2 bytes per 16-bit mono sample.
    duration_sec = len(audio_data) / (sample_rate * 2)
    print()
    print(f"Generated: {output_path}")
    print(f" Duration: {duration_sec:.2f}s")
    print(f" Sample rate: {sample_rate}Hz")
    print(f" Format: 16-bit mono PCM WAV")
    print(f" Size: {len(audio_data):,} bytes")
|
||||
|
||||
|
||||
def load_utterances_from_json(path: Path) -> list[str]:
    """
    Load utterances from a JSON file.

    Accepts either:
    - A JSON array: ["utterance 1", "utterance 2"]
    - A JSON object with "utterances" key: {"utterances": ["a", "b"]}

    Raises:
        ValueError: If the JSON is neither of the accepted shapes.
    """
    with open(path, encoding="utf-8") as fh:
        parsed = json.load(fh)

    if isinstance(parsed, list):
        return [str(item) for item in parsed]
    if isinstance(parsed, dict) and "utterances" in parsed:
        return [str(item) for item in parsed["utterances"]]

    raise ValueError(
        f"JSON file must be an array of strings or an object with 'utterances' key. "
        f"Got: {type(parsed).__name__}"
    )
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with fields: output (Path), utterances
        (list[str] | None), json (Path | None), silence_ms,
        lead_silence_ms, trail_silence_ms (int).
    """
    script_dir = Path(__file__).parent
    # NOTE(review): script_dir is scripts/generate_test_audio/, so this default
    # resolves to scripts/data/audio_examples/..., while the module docstring
    # advertises data/audio_examples/ at the project root — confirm which
    # location is intended.
    default_output = script_dir.parent / "data" / "audio_examples" / "two_utterances_16k.wav"

    parser = argparse.ArgumentParser(description="Generate test audio with SiliconFlow TTS (utterances + silence).")
    parser.add_argument(
        "-o", "--output",
        type=Path,
        default=default_output,
        help=f"Output WAV file path (default: {default_output})"
    )
    # Repeatable flag: each -u appends one utterance.
    parser.add_argument(
        "-u", "--utterance",
        action="append",
        dest="utterances",
        metavar="TEXT",
        help="Utterance text (repeat for multiple). Ignored if --json is set."
    )
    parser.add_argument(
        "-j", "--json",
        type=Path,
        metavar="PATH",
        help="JSON file with utterances: array of strings or object with 'utterances' key"
    )
    parser.add_argument(
        "--silence-ms",
        type=int,
        default=500,
        metavar="MS",
        help="Silence in ms between utterances (default: 500)"
    )
    parser.add_argument(
        "--lead-silence-ms",
        type=int,
        default=200,
        metavar="MS",
        help="Silence in ms at start of file (default: 200)"
    )
    parser.add_argument(
        "--trail-silence-ms",
        type=int,
        default=300,
        metavar="MS",
        help="Silence in ms at end of file (default: 300)"
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def _resolve_utterances(args):
    """Pick utterances by precedence: --json file, then -u flags, then defaults."""
    if args.json is not None:
        if not args.json.is_file():
            raise FileNotFoundError(f"Utterances JSON file not found: {args.json}")
        texts = load_utterances_from_json(args.json)
        if not texts:
            raise ValueError(f"JSON file has no utterances: {args.json}")
        return texts
    if args.utterances:
        return args.utterances
    return [
        "Hello, how are you doing today?",
        "I'm doing great, thank you for asking!"
    ]


async def main():
    """Main entry point: parse CLI options and synthesize the test WAV."""
    args = parse_args()
    out = args.output
    out.parent.mkdir(parents=True, exist_ok=True)

    await generate_test_audio(
        output_path=str(out),
        utterances=_resolve_utterances(args),
        silence_ms=args.silence_ms,
        lead_silence_ms=args.lead_silence_ms,
        trail_silence_ms=args.trail_silence_ms,
        voice="anna",
        sample_rate=16000,
        speed=1.0
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
51
services/__init__.py
Normal file
51
services/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""AI Services package.
|
||||
|
||||
Provides ASR, LLM, TTS, and Realtime API services for voice conversation.
|
||||
"""
|
||||
|
||||
from services.base import (
|
||||
ServiceState,
|
||||
ASRResult,
|
||||
LLMMessage,
|
||||
TTSChunk,
|
||||
BaseASRService,
|
||||
BaseLLMService,
|
||||
BaseTTSService,
|
||||
)
|
||||
from services.llm import OpenAILLMService, MockLLMService
|
||||
from services.tts import EdgeTTSService, MockTTSService
|
||||
from services.asr import BufferedASRService, MockASRService
|
||||
from services.openai_compatible_asr import OpenAICompatibleASRService, SiliconFlowASRService
|
||||
from services.openai_compatible_tts import OpenAICompatibleTTSService, SiliconFlowTTSService
|
||||
from services.streaming_tts_adapter import StreamingTTSAdapter
|
||||
from services.realtime import RealtimeService, RealtimeConfig, RealtimePipeline
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
"ServiceState",
|
||||
"ASRResult",
|
||||
"LLMMessage",
|
||||
"TTSChunk",
|
||||
"BaseASRService",
|
||||
"BaseLLMService",
|
||||
"BaseTTSService",
|
||||
# LLM
|
||||
"OpenAILLMService",
|
||||
"MockLLMService",
|
||||
# TTS
|
||||
"EdgeTTSService",
|
||||
"MockTTSService",
|
||||
# ASR
|
||||
"BufferedASRService",
|
||||
"MockASRService",
|
||||
"OpenAICompatibleASRService",
|
||||
"SiliconFlowASRService",
|
||||
# TTS (SiliconFlow)
|
||||
"OpenAICompatibleTTSService",
|
||||
"SiliconFlowTTSService",
|
||||
"StreamingTTSAdapter",
|
||||
# Realtime
|
||||
"RealtimeService",
|
||||
"RealtimeConfig",
|
||||
"RealtimePipeline",
|
||||
]
|
||||
147
services/asr.py
Normal file
147
services/asr.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""ASR (Automatic Speech Recognition) Service implementations.
|
||||
|
||||
Provides speech-to-text capabilities with streaming support.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncIterator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseASRService, ASRResult, ServiceState
|
||||
|
||||
# Try to import websockets for streaming ASR
|
||||
try:
|
||||
import websockets
|
||||
WEBSOCKETS_AVAILABLE = True
|
||||
except ImportError:
|
||||
WEBSOCKETS_AVAILABLE = False
|
||||
|
||||
|
||||
class BufferedASRService(BaseASRService):
    """
    Buffered ASR service that accumulates audio and provides
    a simple text accumulator for use with EOU detection.

    This is a lightweight implementation that works with the
    existing VAD + EOU pattern without requiring external ASR.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        language: str = "en"
    ):
        super().__init__(sample_rate=sample_rate, language=language)

        # Raw PCM accumulated via send_audio().
        self._audio_buffer: bytes = b""
        # Latest transcript text supplied by an external recognizer.
        self._current_text: str = ""
        # Unbounded queue feeding receive_transcripts().
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()

    async def connect(self) -> None:
        """No connection needed for buffered ASR."""
        self.state = ServiceState.CONNECTED
        logger.info("Buffered ASR service connected")

    async def disconnect(self) -> None:
        """Clear buffers on disconnect."""
        self._audio_buffer = b""
        self._current_text = ""
        self.state = ServiceState.DISCONNECTED
        logger.info("Buffered ASR service disconnected")

    async def send_audio(self, audio: bytes) -> None:
        """Buffer audio for later processing."""
        self._audio_buffer += audio

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield transcription results.

        Polls the queue with a short timeout so the generator stays
        responsive to task cancellation.
        """
        while True:
            try:
                result = await asyncio.wait_for(
                    self._transcript_queue.get(),
                    timeout=0.1
                )
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

    def set_text(self, text: str) -> None:
        """
        Set the current transcript text directly.

        This allows external integration (e.g., Whisper, other ASR)
        to provide transcripts.
        """
        self._current_text = text
        result = ASRResult(text=text, is_final=False)
        # FIX: was asyncio.create_task(self._transcript_queue.put(result)).
        # That requires a running event loop and the unreferenced task could
        # be garbage-collected before running. put_nowait() is synchronous
        # and never raises QueueFull on an unbounded queue.
        self._transcript_queue.put_nowait(result)

    def get_and_clear_text(self) -> str:
        """Get accumulated text and clear both text and audio buffers."""
        text = self._current_text
        self._current_text = ""
        self._audio_buffer = b""
        return text

    def get_audio_buffer(self) -> bytes:
        """Get accumulated audio buffer."""
        return self._audio_buffer

    def clear_audio_buffer(self) -> None:
        """Clear audio buffer."""
        self._audio_buffer = b""
|
||||
|
||||
|
||||
class MockASRService(BaseASRService):
    """
    Mock ASR service for testing without actual recognition.
    """

    def __init__(self, sample_rate: int = 16000, language: str = "en"):
        super().__init__(sample_rate=sample_rate, language=language)
        self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
        # Canned transcripts cycled through by trigger_transcript().
        self._mock_texts = [
            "Hello, how are you?",
            "That's interesting.",
            "Tell me more about that.",
            "I understand.",
        ]
        self._text_index = 0

    async def connect(self) -> None:
        self.state = ServiceState.CONNECTED
        logger.info("Mock ASR service connected")

    async def disconnect(self) -> None:
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock ASR service disconnected")

    async def send_audio(self, audio: bytes) -> None:
        """Mock audio processing - generates fake transcripts periodically."""
        pass

    def trigger_transcript(self) -> None:
        """Manually trigger a transcript (for testing)."""
        text = self._mock_texts[self._text_index % len(self._mock_texts)]
        self._text_index += 1

        result = ASRResult(text=text, is_final=True, confidence=0.95)
        # FIX: was asyncio.create_task(self._transcript_queue.put(result)).
        # put_nowait() needs no running event loop, cannot be dropped by the
        # garbage collector, and never raises QueueFull on an unbounded queue.
        self._transcript_queue.put_nowait(result)

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Yield transcription results."""
        while True:
            try:
                result = await asyncio.wait_for(
                    self._transcript_queue.get(),
                    timeout=0.1
                )
                yield result
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
|
||||
253
services/base.py
Normal file
253
services/base.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Base classes for AI services.
|
||||
|
||||
Defines abstract interfaces for ASR, LLM, and TTS services,
|
||||
inspired by pipecat's service architecture and active-call's
|
||||
StreamEngine pattern.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any, Literal
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ServiceState(Enum):
    """Service connection state shared by all ASR/LLM/TTS services."""
    DISCONNECTED = "disconnected"  # initial state / after disconnect()
    CONNECTING = "connecting"      # connection in progress
    CONNECTED = "connected"        # ready for use
    ERROR = "error"                # service hit an unrecoverable error
|
||||
|
||||
|
||||
@dataclass
class ASRResult:
    """A single speech-to-text transcription result (partial or final)."""
    text: str
    is_final: bool = False
    confidence: float = 1.0
    language: Optional[str] = None
    start_time: Optional[float] = None
    end_time: Optional[float] = None

    def __str__(self) -> str:
        """Render as "[FINAL] ..." or "[PARTIAL] ..." for logging."""
        if self.is_final:
            return f"[FINAL] {self.text}"
        return f"[PARTIAL] {self.text}"
|
||||
|
||||
|
||||
@dataclass
class LLMMessage:
    """One turn in an LLM conversation."""
    role: str  # "system", "user", "assistant", "function"
    content: str
    name: Optional[str] = None  # For function calls
    function_call: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the dict shape expected by chat-completion APIs.

        Optional fields are included only when set (truthy), matching
        provider expectations.
        """
        payload: Dict[str, Any] = {"role": self.role, "content": self.content}
        if self.name:
            payload["name"] = self.name
        if self.function_call:
            payload["function_call"] = self.function_call
        return payload
|
||||
|
||||
|
||||
@dataclass
class LLMStreamEvent:
    """Structured LLM stream event.

    One payload field is set per event type: "text_delta" carries `text`,
    "tool_call" carries `tool_call`, and "done" carries neither.
    """

    type: Literal["text_delta", "tool_call", "done"]
    text: Optional[str] = None
    tool_call: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class TTSChunk:
    """A chunk of synthesized audio produced by a TTS service."""
    audio: bytes  # PCM audio data
    sample_rate: int = 16000   # Hz
    channels: int = 1          # mono by default
    bits_per_sample: int = 16  # sample width of `audio`
    is_final: bool = False     # marks the last chunk of a synthesis
    text_offset: Optional[int] = None  # Character offset in original text
|
||||
|
||||
|
||||
class BaseASRService(ABC):
    """
    Abstract base class for ASR (Speech-to-Text) services.

    Supports both streaming and non-streaming transcription.
    """

    def __init__(self, sample_rate: int = 16000, language: str = "en"):
        self.sample_rate = sample_rate  # input PCM sample rate in Hz
        self.language = language        # language code for recognition
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Establish connection to ASR service."""
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        """Close connection to ASR service."""
        pass

    @abstractmethod
    async def send_audio(self, audio: bytes) -> None:
        """
        Send audio chunk for transcription.

        Args:
            audio: PCM audio data (16-bit, mono)
        """
        pass

    @abstractmethod
    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """
        Receive transcription results.

        Yields:
            ASRResult objects as they become available
        """
        pass

    async def transcribe(self, audio: bytes) -> ASRResult:
        """
        Transcribe a complete audio buffer (non-streaming).

        Args:
            audio: Complete PCM audio data

        Returns:
            Final ASRResult

        NOTE(review): this waits for the first result with is_final=True;
        if an implementation's receive_transcripts() never emits a final
        result, this call will not return — confirm all implementations
        eventually produce a final transcript.
        """
        # Default implementation using streaming
        await self.send_audio(audio)
        async for result in self.receive_transcripts():
            if result.is_final:
                return result
        # Reached only if the transcript stream ends without a final result.
        return ASRResult(text="", is_final=True)
|
||||
|
||||
|
||||
class BaseLLMService(ABC):
    """
    Abstract base class for LLM (Language Model) services.

    Supports streaming responses for real-time conversation.
    """

    def __init__(self, model: str = "gpt-4"):
        self.model = model  # model identifier passed to the provider
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Initialize LLM service connection."""
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        """Close LLM service connection."""
        pass

    @abstractmethod
    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Complete response text
        """
        pass

    @abstractmethod
    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """
        Generate response in streaming mode.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Yields:
            Stream events (text delta/tool call/done)
        """
        pass
|
||||
|
||||
|
||||
class BaseTTSService(ABC):
    """
    Abstract base class for TTS (Text-to-Speech) services.

    Supports streaming audio synthesis for low-latency playback.
    """

    def __init__(
        self,
        voice: str = "default",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        self.voice = voice              # provider-specific voice identifier
        self.sample_rate = sample_rate  # output PCM sample rate in Hz
        self.speed = speed              # speech speed multiplier
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Initialize TTS service connection."""
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        """Close TTS service connection."""
        pass

    @abstractmethod
    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio for text (non-streaming).

        Args:
            text: Text to synthesize

        Returns:
            Complete PCM audio data
        """
        pass

    @abstractmethod
    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects as audio is generated
        """
        pass

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (for barge-in support).

        Default implementation is a no-op; subclasses override it when the
        underlying provider supports aborting an in-flight request.
        """
        pass
|
||||
443
services/llm.py
Normal file
443
services/llm.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""LLM (Large Language Model) Service implementations.
|
||||
|
||||
Provides OpenAI-compatible LLM integration with streaming support
|
||||
for real-time voice conversation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import search_knowledge_context
|
||||
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
||||
|
||||
# Try to import openai
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
OPENAI_AVAILABLE = False
|
||||
logger.warning("openai package not available - LLM service will be disabled")
|
||||
|
||||
|
||||
class OpenAILLMService(BaseLLMService):
    """
    OpenAI-compatible LLM service.

    Supports streaming responses for low-latency voice conversation.
    Works with OpenAI API, Azure OpenAI, and compatible APIs.

    Optionally injects retrieved knowledge-base snippets (RAG) as an extra
    system message before each request, and forwards tool schemas so the
    model can emit function calls.
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        system_prompt: Optional[str] = None,
        knowledge_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize OpenAI LLM service.

        Args:
            model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            base_url: Custom API base URL (for Azure or compatible APIs)
            system_prompt: Default system prompt for conversations
            knowledge_config: Optional RAG settings (keys like "kbId",
                "nResults", "enabled") consumed by _with_knowledge_context()
        """
        super().__init__(model=model)

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.base_url = base_url or os.getenv("OPENAI_API_URL")
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation."
        )

        self.client: Optional[AsyncOpenAI] = None
        # Set by cancel() to break out of an in-flight generate_stream() loop.
        self._cancel_event = asyncio.Event()
        self._knowledge_config: Dict[str, Any] = knowledge_config or {}
        self._tool_schemas: List[Dict[str, Any]] = []

    # RAG tuning knobs (class-level constants).
    _RAG_DEFAULT_RESULTS = 5       # snippets fetched when config gives no count
    _RAG_MAX_RESULTS = 8           # hard cap on snippets requested
    _RAG_MAX_CONTEXT_CHARS = 4000  # total character budget for injected context

    async def connect(self) -> None:
        """Initialize OpenAI client.

        Raises:
            RuntimeError: if the openai package is not installed.
            ValueError: if no API key was provided or found in the env.
        """
        if not OPENAI_AVAILABLE:
            raise RuntimeError("openai package not installed")

        if not self.api_key:
            raise ValueError("OpenAI API key not provided")

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"OpenAI LLM service connected: model={self.model}")

    async def disconnect(self) -> None:
        """Close OpenAI client."""
        if self.client:
            await self.client.close()
            self.client = None
        self.state = ServiceState.DISCONNECTED
        logger.info("OpenAI LLM service disconnected")

    def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Prepare messages list with system prompt."""
        result = []

        # Add system prompt if not already present
        has_system = any(m.role == "system" for m in messages)
        if not has_system and self.system_prompt:
            result.append({"role": "system", "content": self.system_prompt})

        # Add all messages
        for msg in messages:
            result.append(msg.to_dict())

        return result

    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        """Update runtime knowledge retrieval config."""
        self._knowledge_config = config or {}

    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        """Update runtime tool schemas.

        Accepts either fully-wrapped OpenAI tool entries ({"type": "function",
        "function": {...}}) or bare {"name", "description", "parameters"}
        dicts, which are wrapped into the OpenAI shape. Invalid entries are
        silently dropped.
        """
        self._tool_schemas = []
        if not isinstance(schemas, list):
            return
        for item in schemas:
            if not isinstance(item, dict):
                continue
            fn = item.get("function")
            if isinstance(fn, dict) and fn.get("name"):
                # Already in OpenAI tool format — keep as-is.
                self._tool_schemas.append(item)
            elif item.get("name"):
                # Bare schema — wrap into the OpenAI tool envelope.
                self._tool_schemas.append(
                    {
                        "type": "function",
                        "function": {
                            "name": str(item.get("name")),
                            "description": str(item.get("description") or ""),
                            "parameters": item.get("parameters") or {"type": "object", "properties": {}},
                        },
                    }
                )

    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """Best-effort int conversion; returns `default` on failure."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    def _resolve_kb_id(self) -> Optional[str]:
        """Extract the knowledge-base id from config, tolerating key aliases."""
        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
        kb_id = str(
            cfg.get("kbId")
            or cfg.get("knowledgeBaseId")
            or cfg.get("knowledge_base_id")
            or ""
        ).strip()
        return kb_id or None

    def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
        """Format retrieved snippets into a system-prompt block.

        Enforces the _RAG_MAX_CONTEXT_CHARS budget and annotates each
        snippet with its source document/chunk and retrieval distance when
        available. Returns None when nothing usable was retrieved.
        """
        if not results:
            return None

        lines = [
            "You have retrieved the following knowledge base snippets.",
            "Use them only when relevant to the latest user request.",
            "If snippets are insufficient, say you are not sure instead of guessing.",
            "",
        ]

        used_chars = 0
        used_count = 0
        for item in results:
            content = str(item.get("content") or "").strip()
            if not content:
                continue
            if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
                break

            metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
            doc_id = metadata.get("document_id")
            chunk_index = metadata.get("chunk_index")
            distance = item.get("distance")

            source_parts = []
            if doc_id:
                source_parts.append(f"doc={doc_id}")
            if chunk_index is not None:
                source_parts.append(f"chunk={chunk_index}")
            source = f" ({', '.join(source_parts)})" if source_parts else ""

            distance_text = ""
            try:
                if distance is not None:
                    distance_text = f", distance={float(distance):.4f}"
            except (TypeError, ValueError):
                distance_text = ""

            # Truncate to the remaining character budget.
            remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
            snippet = content[:remaining].strip()
            if not snippet:
                continue

            used_count += 1
            lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
            used_chars += len(snippet)

        if used_count == 0:
            return None

        return "\n".join(lines)

    async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
        """Return messages with a RAG system message inserted when configured.

        Returns the input list unchanged when retrieval is disabled, no KB id
        is configured, there is no user turn to query with, or retrieval
        yields no usable context.
        """
        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
        enabled = cfg.get("enabled", True)
        if isinstance(enabled, str):
            # Accept common string spellings of "disabled" from loose configs.
            enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
        if not enabled:
            return messages

        kb_id = self._resolve_kb_id()
        if not kb_id:
            return messages

        # The most recent user turn is used as the retrieval query.
        latest_user = ""
        for msg in reversed(messages):
            if msg.role == "user":
                latest_user = (msg.content or "").strip()
                break
        if not latest_user:
            return messages

        n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
        n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))

        results = await search_knowledge_context(
            kb_id=kb_id,
            query=latest_user,
            n_results=n_results,
        )
        prompt = self._build_knowledge_prompt(results)
        if not prompt:
            return messages

        logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
        rag_system = LLMMessage(role="system", content=prompt)
        # Keep any caller-supplied system prompt first, RAG context second.
        if messages and messages[0].role == "system":
            return [messages[0], rag_system, *messages[1:]]
        return [rag_system, *messages]

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Complete response text

        Raises:
            RuntimeError: if connect() has not been called.
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")

        # Inject RAG context (if configured) before preparing the payload.
        rag_messages = await self._with_knowledge_context(messages)
        prepared = self._prepare_messages(rag_messages)

        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens
            )

            content = response.choices[0].message.content or ""
            logger.debug(f"LLM response: {content[:100]}...")
            return content

        except Exception as e:
            logger.error(f"LLM generation error: {e}")
            raise

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """
        Generate response in streaming mode.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Yields:
            Structured stream events

        NOTE(review): if the provider stream ends without a finish_reason in
        {"tool_calls", "stop", "length", "content_filter"} (or after a
        cancel()), no "done" event is yielded — confirm callers tolerate
        this.
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")

        rag_messages = await self._with_knowledge_context(messages)
        prepared = self._prepare_messages(rag_messages)
        self._cancel_event.clear()
        # Accumulates incremental tool-call deltas, keyed by tool-call index.
        tool_accumulator: Dict[int, Dict[str, str]] = {}
        openai_tools = self._tool_schemas or None

        try:
            create_args: Dict[str, Any] = dict(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
            )
            if openai_tools:
                create_args["tools"] = openai_tools
                create_args["tool_choice"] = "auto"
            stream = await self.client.chat.completions.create(**create_args)

            async for chunk in stream:
                # Check for cancellation
                if self._cancel_event.is_set():
                    logger.info("LLM stream cancelled")
                    break

                if not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = getattr(choice, "delta", None)
                if delta and getattr(delta, "content", None):
                    content = delta.content
                    yield LLMStreamEvent(type="text_delta", text=content)

                # OpenAI streams function calls via incremental tool_calls deltas.
                tool_calls = getattr(delta, "tool_calls", None) if delta else None
                if tool_calls:
                    for tc in tool_calls:
                        index = getattr(tc, "index", 0) or 0
                        item = tool_accumulator.setdefault(
                            int(index),
                            {"id": "", "name": "", "arguments": ""},
                        )
                        tc_id = getattr(tc, "id", None)
                        if tc_id:
                            item["id"] = str(tc_id)
                        fn = getattr(tc, "function", None)
                        if fn:
                            fn_name = getattr(fn, "name", None)
                            if fn_name:
                                item["name"] = str(fn_name)
                            fn_args = getattr(fn, "arguments", None)
                            if fn_args:
                                # Argument JSON arrives in fragments; concatenate.
                                item["arguments"] += str(fn_args)

                finish_reason = getattr(choice, "finish_reason", None)
                if finish_reason == "tool_calls" and tool_accumulator:
                    # Emit completed tool calls in index order.
                    for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
                        call_name = payload.get("name", "").strip()
                        if not call_name:
                            continue
                        # Synthesize an id when the provider omitted one.
                        call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
                        yield LLMStreamEvent(
                            type="tool_call",
                            tool_call={
                                "id": call_id,
                                "type": "function",
                                "function": {
                                    "name": call_name,
                                    "arguments": payload.get("arguments", "") or "{}",
                                },
                            },
                        )
                    yield LLMStreamEvent(type="done")
                    return

                if finish_reason in {"stop", "length", "content_filter"}:
                    yield LLMStreamEvent(type="done")
                    return

        except asyncio.CancelledError:
            logger.info("LLM stream cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise

    def cancel(self) -> None:
        """Cancel ongoing generation."""
        self._cancel_event.set()
|
||||
|
||||
|
||||
class MockLLMService(BaseLLMService):
    """Mock LLM service for testing without API calls.

    Cycles through a fixed set of canned replies and simulates both
    request latency and token-by-token streaming.
    """

    def __init__(self, response_delay: float = 0.5):
        super().__init__(model="mock")
        self.response_delay = response_delay
        # Canned replies, served round-robin by generate().
        self.responses = [
            "Hello! How can I help you today?",
            "That's an interesting question. Let me think about it.",
            "I understand. Is there anything else you'd like to know?",
            "Great! I'm here if you need anything else.",
        ]
        self._response_index = 0

    async def connect(self) -> None:
        self.state = ServiceState.CONNECTED
        logger.info("Mock LLM service connected")

    async def disconnect(self) -> None:
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock LLM service disconnected")

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        # Simulate request latency, then serve the next canned reply.
        await asyncio.sleep(self.response_delay)
        reply = self.responses[self._response_index % len(self.responses)]
        self._response_index += 1
        return reply

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        reply = await self.generate(messages, temperature, max_tokens)

        # Emit the reply word by word, with a separate space delta between
        # words, pausing briefly after each word to mimic streaming.
        for position, word in enumerate(reply.split()):
            if position:
                yield LLMStreamEvent(type="text_delta", text=" ")
            yield LLMStreamEvent(type="text_delta", text=word)
            await asyncio.sleep(0.05)  # Simulate streaming delay
        yield LLMStreamEvent(type="done")
|
||||
321
services/openai_compatible_asr.py
Normal file
321
services/openai_compatible_asr.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""OpenAI-compatible ASR (Automatic Speech Recognition) Service.
|
||||
|
||||
Uses the SiliconFlow API for speech-to-text transcription.
|
||||
API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcriptions
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import wave
|
||||
from typing import AsyncIterator, Optional, Callable, Awaitable
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
logger.warning("aiohttp not available - OpenAICompatibleASRService will not work")
|
||||
|
||||
from services.base import BaseASRService, ASRResult, ServiceState
|
||||
|
||||
|
||||
class OpenAICompatibleASRService(BaseASRService):
|
||||
"""
|
||||
OpenAI-compatible ASR service for speech-to-text transcription.
|
||||
|
||||
Features:
|
||||
- Buffers incoming audio chunks
|
||||
- Provides interim transcriptions periodically (for streaming to client)
|
||||
- Final transcription on EOU
|
||||
|
||||
API Details:
|
||||
- Endpoint: POST https://api.siliconflow.cn/v1/audio/transcriptions
|
||||
- Models: FunAudioLLM/SenseVoiceSmall (default), TeleAI/TeleSpeechASR
|
||||
- Input: Audio file (multipart/form-data)
|
||||
- Output: {"text": "transcribed text"}
|
||||
"""
|
||||
|
||||
# Supported models
|
||||
MODELS = {
|
||||
"sensevoice": "FunAudioLLM/SenseVoiceSmall",
|
||||
"telespeech": "TeleAI/TeleSpeechASR",
|
||||
}
|
||||
|
||||
API_URL = "https://api.siliconflow.cn/v1/audio/transcriptions"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "FunAudioLLM/SenseVoiceSmall",
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
interim_interval_ms: int = 500, # How often to send interim results
|
||||
min_audio_for_interim_ms: int = 300, # Min audio before first interim
|
||||
on_transcript: Optional[Callable[[str, bool], Awaitable[None]]] = None
|
||||
):
|
||||
"""
|
||||
Initialize OpenAI-compatible ASR service.
|
||||
|
||||
Args:
|
||||
api_key: Provider API key
|
||||
model: ASR model name or alias
|
||||
sample_rate: Audio sample rate (16000 recommended)
|
||||
language: Language code (auto for automatic detection)
|
||||
interim_interval_ms: How often to generate interim transcriptions
|
||||
min_audio_for_interim_ms: Minimum audio duration before first interim
|
||||
on_transcript: Callback for transcription results (text, is_final)
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, language=language)
|
||||
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
|
||||
|
||||
self.api_key = api_key
|
||||
self.model = self.MODELS.get(model.lower(), model)
|
||||
self.interim_interval_ms = interim_interval_ms
|
||||
self.min_audio_for_interim_ms = min_audio_for_interim_ms
|
||||
self.on_transcript = on_transcript
|
||||
|
||||
# Session
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Audio buffer
|
||||
self._audio_buffer: bytes = b""
|
||||
self._current_text: str = ""
|
||||
self._last_interim_time: float = 0
|
||||
|
||||
# Transcript queue for async iteration
|
||||
self._transcript_queue: asyncio.Queue[ASRResult] = asyncio.Queue()
|
||||
|
||||
# Background task for interim results
|
||||
self._interim_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
logger.info(f"OpenAICompatibleASRService initialized with model: {self.model}")
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to the service."""
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
)
|
||||
self._running = True
|
||||
self.state = ServiceState.CONNECTED
|
||||
logger.info("OpenAICompatibleASRService connected")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect and cleanup."""
|
||||
self._running = False
|
||||
|
||||
if self._interim_task:
|
||||
self._interim_task.cancel()
|
||||
try:
|
||||
await self._interim_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._interim_task = None
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self._audio_buffer = b""
|
||||
self._current_text = ""
|
||||
self.state = ServiceState.DISCONNECTED
|
||||
logger.info("OpenAICompatibleASRService disconnected")
|
||||
|
||||
async def send_audio(self, audio: bytes) -> None:
|
||||
"""
|
||||
Buffer incoming audio data.
|
||||
|
||||
Args:
|
||||
audio: PCM audio data (16-bit, mono)
|
||||
"""
|
||||
self._audio_buffer += audio
|
||||
|
||||
async def transcribe_buffer(self, is_final: bool = False) -> Optional[str]:
    """
    Transcribe current audio buffer via one HTTP request to the ASR API.

    The buffered PCM is wrapped in an in-memory WAV container and uploaded
    as multipart/form-data. On success the text is cached, callbacks are
    fired, and the result is queued for ``receive_transcripts`` consumers.

    Args:
        is_final: Whether this is the final transcription for the utterance

    Returns:
        Transcribed text, or None if not enough audio / not connected /
        the API call failed.
    """
    if not self._session:
        logger.warning("ASR session not connected")
        return None

    # Check minimum audio duration (16-bit mono PCM -> 2 bytes per sample).
    audio_duration_ms = len(self._audio_buffer) / (self.sample_rate * 2) * 1000

    # Interim passes are skipped until enough audio has accumulated to be
    # worth an API round-trip.
    if not is_final and audio_duration_ms < self.min_audio_for_interim_ms:
        return None

    if audio_duration_ms < 100:  # Less than 100ms - too short
        return None

    try:
        # Convert PCM to WAV in memory (the API expects a file upload).
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(self.sample_rate)
            wav_file.writeframes(self._audio_buffer)

        wav_buffer.seek(0)
        wav_data = wav_buffer.read()

        # Send to API as a multipart form upload.
        form_data = aiohttp.FormData()
        form_data.add_field(
            'file',
            wav_data,
            filename='audio.wav',
            content_type='audio/wav'
        )
        form_data.add_field('model', self.model)

        async with self._session.post(self.API_URL, data=form_data) as response:
            if response.status == 200:
                result = await response.json()
                text = result.get("text", "").strip()

                if text:
                    self._current_text = text

                    # Notify via callback
                    if self.on_transcript:
                        await self.on_transcript(text, is_final)

                    # Queue result for receive_transcripts() consumers.
                    await self._transcript_queue.put(
                        ASRResult(text=text, is_final=is_final)
                    )

                logger.debug(f"ASR {'final' if is_final else 'interim'}: {text[:50]}...")
                return text
            else:
                error_text = await response.text()
                logger.error(f"ASR API error {response.status}: {error_text}")
                return None

    except Exception as e:
        # Best-effort: network/parse failures are logged, not raised, so the
        # audio pipeline keeps running.
        logger.error(f"ASR transcription error: {e}")
        return None
|
||||
|
||||
async def get_final_transcription(self) -> str:
    """Run one final transcription pass over the buffered audio, then reset.

    Intended to be called once end-of-utterance (EOU) is detected.

    Returns:
        The final transcript; falls back to the last interim text when the
        final API pass yields nothing.
    """
    final_text = await self.transcribe_buffer(is_final=True)
    if not final_text:
        final_text = self._current_text

    # Reset per-utterance state for the next turn.
    self._audio_buffer = b""
    self._current_text = ""

    return final_text
|
||||
|
||||
def get_and_clear_text(self) -> str:
|
||||
"""
|
||||
Get accumulated text and clear buffer.
|
||||
|
||||
Compatible with BufferedASRService interface.
|
||||
"""
|
||||
text = self._current_text
|
||||
self._current_text = ""
|
||||
self._audio_buffer = b""
|
||||
return text
|
||||
|
||||
def get_audio_buffer(self) -> bytes:
|
||||
"""Get current audio buffer."""
|
||||
return self._audio_buffer
|
||||
|
||||
def get_audio_duration_ms(self) -> float:
|
||||
"""Get current audio buffer duration in milliseconds."""
|
||||
return len(self._audio_buffer) / (self.sample_rate * 2) * 1000
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
"""Clear audio and text buffers."""
|
||||
self._audio_buffer = b""
|
||||
self._current_text = ""
|
||||
|
||||
async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
    """
    Async iterator for transcription results.

    Yields:
        ASRResult with text and is_final flag
    """
    while self._running:
        try:
            # Short timeout so the loop re-checks self._running regularly
            # instead of blocking forever on an empty queue.
            result = await asyncio.wait_for(
                self._transcript_queue.get(),
                timeout=0.1
            )
            yield result
        except asyncio.TimeoutError:
            # No result within the polling window; loop and re-check state.
            continue
        except asyncio.CancelledError:
            break
|
||||
|
||||
async def start_interim_transcription(self) -> None:
    """Spawn the background interim-transcription loop if not already running.

    The loop periodically transcribes buffered audio so the user gets
    near-real-time feedback while still speaking.
    """
    task = self._interim_task
    if task is not None and not task.done():
        # Loop already active; nothing to do.
        return

    self._interim_task = asyncio.create_task(self._interim_loop())
|
||||
|
||||
async def stop_interim_transcription(self) -> None:
|
||||
"""Stop interim transcription task."""
|
||||
if self._interim_task:
|
||||
self._interim_task.cancel()
|
||||
try:
|
||||
await self._interim_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._interim_task = None
|
||||
|
||||
async def _interim_loop(self) -> None:
    """Background loop for interim transcriptions.

    Sleeps for ``interim_interval_ms`` between passes and only issues an
    API call when enough new audio has accumulated, throttling request
    frequency independently of the sleep granularity.
    """
    import time

    while self._running:
        try:
            await asyncio.sleep(self.interim_interval_ms / 1000)

            # Check if we have enough new audio
            current_time = time.time()
            time_since_last = (current_time - self._last_interim_time) * 1000

            # Throttle: only transcribe when a full interval has elapsed
            # since the previous interim pass.
            if time_since_last >= self.interim_interval_ms:
                audio_duration = self.get_audio_duration_ms()

                if audio_duration >= self.min_audio_for_interim_ms:
                    await self.transcribe_buffer(is_final=False)
                    self._last_interim_time = current_time

        except asyncio.CancelledError:
            break
        except Exception as e:
            # Keep the loop alive on transient errors (e.g. network).
            logger.error(f"Interim transcription error: {e}")
|
||||
|
||||
|
||||
# Backward-compatible alias: legacy code imports SiliconFlowASRService; it now
# resolves to the provider-agnostic OpenAI-compatible implementation.
SiliconFlowASRService = OpenAICompatibleASRService
|
||||
324
services/openai_compatible_tts.py
Normal file
324
services/openai_compatible_tts.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""OpenAI-compatible TTS Service with streaming support.
|
||||
|
||||
Uses SiliconFlow's CosyVoice2 or MOSS-TTSD models for low-latency
|
||||
text-to-speech synthesis with streaming.
|
||||
|
||||
API Docs: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from typing import AsyncIterator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseTTSService, TTSChunk, ServiceState
|
||||
from services.streaming_tts_adapter import StreamingTTSAdapter # backward-compatible re-export
|
||||
|
||||
|
||||
class OpenAICompatibleTTSService(BaseTTSService):
    """
    OpenAI-compatible TTS service with streaming support.

    Supports CosyVoice2-0.5B and MOSS-TTSD-v0.5 models.
    """

    # Available voices: short name -> fully-qualified "model:voice" id.
    VOICES = {
        "alex": "FunAudioLLM/CosyVoice2-0.5B:alex",
        "anna": "FunAudioLLM/CosyVoice2-0.5B:anna",
        "bella": "FunAudioLLM/CosyVoice2-0.5B:bella",
        "benjamin": "FunAudioLLM/CosyVoice2-0.5B:benjamin",
        "charles": "FunAudioLLM/CosyVoice2-0.5B:charles",
        "claire": "FunAudioLLM/CosyVoice2-0.5B:claire",
        "david": "FunAudioLLM/CosyVoice2-0.5B:david",
        "diana": "FunAudioLLM/CosyVoice2-0.5B:diana",
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        voice: str = "anna",
        model: str = "FunAudioLLM/CosyVoice2-0.5B",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize OpenAI-compatible TTS service.

        Args:
            api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var)
            voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
            model: Model name
            sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
            speed: Speech speed (0.25 to 4.0)
        """
        # Resolve voice name (case-insensitive), and normalize "model:VoiceId" suffix.
        resolved_voice = (voice or "").strip()
        voice_lookup = resolved_voice.lower()
        if voice_lookup in self.VOICES:
            full_voice = self.VOICES[voice_lookup]
        elif ":" in resolved_voice:
            # Caller passed a fully-qualified "model:voice" id; keep the model
            # part but lowercase a recognized voice id for consistency.
            model_part, voice_part = resolved_voice.split(":", 1)
            normalized_voice_part = voice_part.strip().lower()
            if normalized_voice_part in self.VOICES:
                full_voice = f"{(model_part or model).strip()}:{normalized_voice_part}"
            else:
                full_voice = resolved_voice
        else:
            # Unknown bare name: pass through unchanged and let the API decide.
            full_voice = resolved_voice

        super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)

        self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
        self.model = model
        self.api_url = "https://api.siliconflow.cn/v1/audio/speech"

        # Lazily created in connect(); None while disconnected.
        self._session: Optional[aiohttp.ClientSession] = None
        # Set by cancel() to abort an in-flight synthesize_stream().
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Initialize HTTP session."""
        if not self.api_key:
            raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")

        self._session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"SiliconFlow TTS service ready: voice={self.voice}, model={self.model}")

    async def disconnect(self) -> None:
        """Close HTTP session."""
        if self._session:
            await self._session.close()
            self._session = None
        self.state = ServiceState.DISCONNECTED
        logger.info("SiliconFlow TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Synthesize complete audio for text by concatenating the stream."""
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio
        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio

        Raises:
            RuntimeError: If connect() has not been called.
        """
        if not self._session:
            raise RuntimeError("TTS service not connected")

        if not text.strip():
            return

        self._cancel_event.clear()

        payload = {
            "model": self.model,
            "input": text,
            "voice": self.voice,
            "response_format": "pcm",
            "sample_rate": self.sample_rate,
            "stream": True,
            "speed": self.speed
        }

        try:
            async with self._session.post(self.api_url, json=payload) as response:
                if response.status != 200:
                    # API errors end the stream silently (logged) rather than raise.
                    error_text = await response.text()
                    logger.error(f"SiliconFlow TTS error: {response.status} - {error_text}")
                    return

                # Stream audio chunks
                chunk_size = self.sample_rate * 2 // 10  # 100ms chunks (16-bit mono)
                buffer = b""
                pending_chunk = None

                async for chunk in response.content.iter_any():
                    if self._cancel_event.is_set():
                        logger.info("TTS synthesis cancelled")
                        return

                    buffer += chunk

                    # Yield complete chunks
                    while len(buffer) >= chunk_size:
                        audio_chunk = buffer[:chunk_size]
                        buffer = buffer[chunk_size:]

                        # Keep one full chunk buffered so we can always tag the true
                        # last full chunk as final when stream length is an exact multiple.
                        if pending_chunk is not None:
                            yield TTSChunk(
                                audio=pending_chunk,
                                sample_rate=self.sample_rate,
                                is_final=False
                            )
                        pending_chunk = audio_chunk

                # Flush pending chunk(s) and remaining tail.
                if pending_chunk is not None:
                    if buffer:
                        # A tail remains, so the pending chunk is not the last one.
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=False
                        )
                        pending_chunk = None
                    else:
                        # Stream length was an exact multiple: pending chunk is final.
                        yield TTSChunk(
                            audio=pending_chunk,
                            sample_rate=self.sample_rate,
                            is_final=True
                        )
                        pending_chunk = None

                if buffer:
                    yield TTSChunk(
                        audio=buffer,
                        sample_rate=self.sample_rate,
                        is_final=True
                    )

        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def cancel(self) -> None:
        """Cancel ongoing synthesis."""
        self._cancel_event.set()
|
||||
|
||||
|
||||
class StreamingTTSAdapter:
    """
    Adapter for streaming LLM text to TTS with sentence-level chunking.

    This reduces latency by starting TTS as soon as a complete sentence
    is received from the LLM, rather than waiting for the full response.

    NOTE(review): this class shadows the StreamingTTSAdapter imported from
    services.streaming_tts_adapter at the top of this module — confirm which
    definition callers are meant to get.
    """

    # Sentence delimiters (covers both ASCII and CJK punctuation; commas are
    # included so long clauses start synthesis early).
    SENTENCE_ENDS = {',', '。', '!', '?', '.', '!', '?', '\n'}

    def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
        self.tts_service = tts_service
        self.transport = transport
        self.session_id = session_id
        # Text received from the LLM but not yet spoken.
        self._buffer = ""
        self._cancel_event = asyncio.Event()
        self._is_speaking = False

    def _is_non_sentence_period(self, text: str, idx: int) -> bool:
        """Check whether '.' should NOT be treated as a sentence delimiter."""
        if text[idx] != ".":
            return False

        # Decimal/version segment: 1.2, v1.2.3
        if idx > 0 and idx < len(text) - 1 and text[idx - 1].isdigit() and text[idx + 1].isdigit():
            return True

        # Number abbreviations: No.1 / No. 1
        left_start = idx - 1
        while left_start >= 0 and text[left_start].isalpha():
            left_start -= 1
        left_token = text[left_start + 1:idx].lower()
        if left_token == "no":
            # Skip whitespace after the period, then require a digit.
            j = idx + 1
            while j < len(text) and text[j].isspace():
                j += 1
            if j < len(text) and text[j].isdigit():
                return True

        return False

    async def process_text_chunk(self, text_chunk: str) -> None:
        """
        Process a text chunk from LLM and trigger TTS when sentence is complete.

        Args:
            text_chunk: Text chunk from LLM streaming
        """
        if self._cancel_event.is_set():
            return

        self._buffer += text_chunk

        # Check for sentence completion
        while True:
            split_idx = -1
            for i, char in enumerate(self._buffer):
                if char == "." and self._is_non_sentence_period(self._buffer, i):
                    continue
                if char in self.SENTENCE_ENDS:
                    split_idx = i
                    break
            if split_idx < 0:
                break

            # Swallow any run of consecutive delimiters (e.g. "?!", "。。").
            end_idx = split_idx + 1
            while end_idx < len(self._buffer) and self._buffer[end_idx] in self.SENTENCE_ENDS:
                end_idx += 1

            sentence = self._buffer[:end_idx].strip()
            self._buffer = self._buffer[end_idx:]

            # Skip punctuation-only fragments.
            if sentence and any(ch.isalnum() for ch in sentence):
                await self._speak_sentence(sentence)

    async def flush(self) -> None:
        """Flush remaining buffer (speak any trailing partial sentence)."""
        if self._buffer.strip() and not self._cancel_event.is_set():
            await self._speak_sentence(self._buffer.strip())
        self._buffer = ""

    async def _speak_sentence(self, text: str) -> None:
        """Synthesize and send a sentence."""
        if not text or self._cancel_event.is_set():
            return

        self._is_speaking = True

        try:
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._cancel_event.is_set():
                    break
                await self.transport.send_audio(chunk.audio)
                await asyncio.sleep(0.01)  # Prevent flooding
        except Exception as e:
            # Log and swallow so one failed sentence doesn't kill the turn.
            logger.error(f"TTS speak error: {e}")
        finally:
            self._is_speaking = False

    def cancel(self) -> None:
        """Cancel ongoing speech."""
        self._cancel_event.set()
        self._buffer = ""

    def reset(self) -> None:
        """Reset for new turn."""
        self._cancel_event.clear()
        self._buffer = ""
        self._is_speaking = False

    @property
    def is_speaking(self) -> bool:
        # True while _speak_sentence is streaming audio to the transport.
        return self._is_speaking
|
||||
|
||||
|
||||
# Backward-compatible alias: legacy code imports SiliconFlowTTSService; it now
# resolves to the provider-agnostic OpenAI-compatible implementation.
SiliconFlowTTSService = OpenAICompatibleTTSService
|
||||
548
services/realtime.py
Normal file
548
services/realtime.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""OpenAI Realtime API Service.
|
||||
|
||||
Provides true duplex voice conversation using OpenAI's Realtime API,
|
||||
similar to active-call's RealtimeProcessor. This bypasses the need for
|
||||
separate ASR/LLM/TTS services by handling everything server-side.
|
||||
|
||||
The Realtime API provides:
|
||||
- Server-side VAD with turn detection
|
||||
- Streaming speech-to-text
|
||||
- Streaming LLM responses
|
||||
- Streaming text-to-speech
|
||||
- Function calling support
|
||||
- Barge-in/interruption handling
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
from typing import Optional, Dict, Any, Callable, Awaitable, List
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import websockets
|
||||
WEBSOCKETS_AVAILABLE = True
|
||||
except ImportError:
|
||||
WEBSOCKETS_AVAILABLE = False
|
||||
logger.warning("websockets not available - Realtime API will be disabled")
|
||||
|
||||
|
||||
class RealtimeState(Enum):
    """Realtime API connection state."""
    DISCONNECTED = "disconnected"  # initial state, and after disconnect()
    CONNECTING = "connecting"      # connect() in progress
    CONNECTED = "connected"        # websocket open and session configured
    ERROR = "error"                # connection or receive-loop failure
|
||||
|
||||
|
||||
@dataclass
class RealtimeConfig:
    """Configuration for OpenAI Realtime API."""

    # API Configuration
    api_key: Optional[str] = None  # RealtimeService falls back to OPENAI_API_KEY env var
    model: str = "gpt-4o-realtime-preview"
    endpoint: Optional[str] = None  # For Azure or custom endpoints

    # Voice Configuration
    voice: str = "alloy"  # alloy, echo, shimmer, etc.
    instructions: str = (
        "You are a helpful, friendly voice assistant. "
        "Keep your responses concise and conversational."
    )

    # Turn Detection (Server-side VAD); passed verbatim in session.update.
    turn_detection: Optional[Dict[str, Any]] = field(default_factory=lambda: {
        "type": "server_vad",
        "threshold": 0.5,
        "prefix_padding_ms": 300,
        "silence_duration_ms": 500
    })

    # Audio Configuration
    input_audio_format: str = "pcm16"
    output_audio_format: str = "pcm16"

    # Tools/Functions (OpenAI tool schemas; only sent when non-empty)
    tools: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class RealtimeService:
    """
    OpenAI Realtime API service for true duplex voice conversation.

    This service handles the entire voice conversation pipeline:
    1. Audio input → Server-side VAD → Speech-to-text
    2. Text → LLM processing → Response generation
    3. Response → Text-to-speech → Audio output

    Events emitted:
    - on_audio: Audio output from the assistant
    - on_transcript: Text transcript (user or assistant)
    - on_speech_started: User started speaking
    - on_speech_stopped: User stopped speaking
    - on_response_started: Assistant started responding
    - on_response_done: Assistant finished responding
    - on_function_call: Function call requested
    - on_error: Error occurred
    """

    def __init__(self, config: Optional[RealtimeConfig] = None):
        """
        Initialize Realtime API service.

        Args:
            config: Realtime configuration (uses defaults if not provided)
        """
        self.config = config or RealtimeConfig()
        # Environment fallback so callers may omit the key from the config.
        self.config.api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")

        self.state = RealtimeState.DISCONNECTED
        self._ws = None
        self._receive_task: Optional[asyncio.Task] = None
        self._cancel_event = asyncio.Event()

        # Event callbacks: event name -> list of async callables.
        self._callbacks: Dict[str, List[Callable]] = {
            "on_audio": [],
            "on_transcript": [],
            "on_speech_started": [],
            "on_speech_stopped": [],
            "on_response_started": [],
            "on_response_done": [],
            "on_function_call": [],
            "on_error": [],
            "on_interrupted": [],
        }

        logger.debug(f"RealtimeService initialized with model={self.config.model}")

    def on(self, event: str, callback: Callable[..., Awaitable[None]]) -> None:
        """
        Register event callback.

        Unknown event names are silently ignored (only keys present in
        self._callbacks are accepted).

        Args:
            event: Event name
            callback: Async callback function
        """
        if event in self._callbacks:
            self._callbacks[event].append(callback)

    async def _emit(self, event: str, *args, **kwargs) -> None:
        """Emit event to all registered callbacks.

        Callback exceptions are logged and swallowed so one failing
        subscriber cannot break the receive loop.
        """
        for callback in self._callbacks.get(event, []):
            try:
                await callback(*args, **kwargs)
            except Exception as e:
                logger.error(f"Event callback error ({event}): {e}")

    async def connect(self) -> None:
        """Connect to OpenAI Realtime API.

        Raises:
            RuntimeError: If the websockets package is missing.
            ValueError: If no API key is available.
        """
        if not WEBSOCKETS_AVAILABLE:
            raise RuntimeError("websockets package not installed")

        if not self.config.api_key:
            raise ValueError("OpenAI API key not provided")

        self.state = RealtimeState.CONNECTING

        # Build URL
        if self.config.endpoint:
            # Azure or custom endpoint
            url = f"{self.config.endpoint}/openai/realtime?api-version=2024-10-01-preview&deployment={self.config.model}"
        else:
            # OpenAI endpoint
            url = f"wss://api.openai.com/v1/realtime?model={self.config.model}"

        # Build headers (Azure uses "api-key"; OpenAI uses bearer auth + beta flag).
        headers = {}
        if self.config.endpoint:
            headers["api-key"] = self.config.api_key
        else:
            headers["Authorization"] = f"Bearer {self.config.api_key}"
            headers["OpenAI-Beta"] = "realtime=v1"

        try:
            logger.info(f"Connecting to Realtime API: {url}")
            # NOTE(review): "extra_headers" is the websockets<14 keyword; newer
            # releases renamed it to "additional_headers" — confirm the pinned
            # websockets version supports this call.
            self._ws = await websockets.connect(url, extra_headers=headers)

            # Send session configuration
            await self._configure_session()

            # Start receive loop
            self._receive_task = asyncio.create_task(self._receive_loop())

            self.state = RealtimeState.CONNECTED
            logger.info("Realtime API connected successfully")

        except Exception as e:
            self.state = RealtimeState.ERROR
            logger.error(f"Realtime API connection failed: {e}")
            raise

    async def _configure_session(self) -> None:
        """Send session configuration to server."""
        session_config = {
            "type": "session.update",
            "session": {
                "modalities": ["text", "audio"],
                "instructions": self.config.instructions,
                "voice": self.config.voice,
                "input_audio_format": self.config.input_audio_format,
                "output_audio_format": self.config.output_audio_format,
                "turn_detection": self.config.turn_detection,
            }
        }

        if self.config.tools:
            session_config["session"]["tools"] = self.config.tools

        await self._send(session_config)
        logger.debug("Session configuration sent")

    async def _send(self, data: Dict[str, Any]) -> None:
        """Send JSON data to server (no-op when the socket is not open)."""
        if self._ws:
            await self._ws.send(json.dumps(data))

    async def send_audio(self, audio_bytes: bytes) -> None:
        """
        Send audio to the Realtime API.

        Args:
            audio_bytes: PCM audio data (16-bit, mono, 24kHz by default)
        """
        if self.state != RealtimeState.CONNECTED:
            return

        # Encode audio as base64 (the API carries audio inside JSON).
        audio_b64 = base64.standard_b64encode(audio_bytes).decode()

        await self._send({
            "type": "input_audio_buffer.append",
            "audio": audio_b64
        })

    async def send_text(self, text: str) -> None:
        """
        Send text input (bypassing audio).

        Args:
            text: User text input
        """
        if self.state != RealtimeState.CONNECTED:
            return

        # Create a conversation item with user text
        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": text}]
            }
        })

        # Trigger response
        await self._send({"type": "response.create"})

    async def cancel_response(self) -> None:
        """Cancel the current response (for barge-in)."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "response.cancel"})
        logger.debug("Response cancelled")

    async def commit_audio(self) -> None:
        """Commit the audio buffer and trigger response."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "input_audio_buffer.commit"})
        await self._send({"type": "response.create"})

    async def clear_audio_buffer(self) -> None:
        """Clear the input audio buffer."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "input_audio_buffer.clear"})

    async def submit_function_result(self, call_id: str, result: str) -> None:
        """
        Submit function call result.

        Args:
            call_id: The function call ID
            result: JSON string result
        """
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result
            }
        })

        # Trigger response with the function result
        await self._send({"type": "response.create"})

    async def _receive_loop(self) -> None:
        """Receive and process messages from the Realtime API.

        Runs until the socket closes, the task is cancelled, or an
        unexpected error occurs; updates self.state accordingly.
        """
        if not self._ws:
            return

        try:
            async for message in self._ws:
                try:
                    data = json.loads(message)
                    await self._handle_event(data)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON received: {message[:100]}")

        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
        except websockets.ConnectionClosed as e:
            logger.info(f"WebSocket closed: {e}")
            self.state = RealtimeState.DISCONNECTED
        except Exception as e:
            logger.error(f"Receive loop error: {e}")
            self.state = RealtimeState.ERROR

    async def _handle_event(self, data: Dict[str, Any]) -> None:
        """Handle incoming event from Realtime API.

        Dispatches each server event type to the matching on_* callbacks;
        unknown event types are logged at debug level and ignored.
        """
        event_type = data.get("type", "unknown")

        # Audio delta - streaming audio output
        if event_type == "response.audio.delta":
            if "delta" in data:
                audio_bytes = base64.standard_b64decode(data["delta"])
                await self._emit("on_audio", audio_bytes)

        # Audio transcript delta - streaming text
        elif event_type == "response.audio_transcript.delta":
            if "delta" in data:
                await self._emit("on_transcript", data["delta"], "assistant", False)

        # Audio transcript done
        elif event_type == "response.audio_transcript.done":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "assistant", True)

        # Input audio transcript (user speech)
        elif event_type == "conversation.item.input_audio_transcription.completed":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "user", True)

        # Speech started (server VAD detected speech)
        elif event_type == "input_audio_buffer.speech_started":
            await self._emit("on_speech_started", data.get("audio_start_ms", 0))

        # Speech stopped
        elif event_type == "input_audio_buffer.speech_stopped":
            await self._emit("on_speech_stopped", data.get("audio_end_ms", 0))

        # Response started
        elif event_type == "response.created":
            await self._emit("on_response_started", data.get("response", {}))

        # Response done
        elif event_type == "response.done":
            await self._emit("on_response_done", data.get("response", {}))

        # Function call
        elif event_type == "response.function_call_arguments.done":
            call_id = data.get("call_id")
            name = data.get("name")
            arguments = data.get("arguments", "{}")
            await self._emit("on_function_call", call_id, name, arguments)

        # Error
        elif event_type == "error":
            error = data.get("error", {})
            logger.error(f"Realtime API error: {error}")
            await self._emit("on_error", error)

        # Session events
        elif event_type == "session.created":
            logger.info("Session created")
        elif event_type == "session.updated":
            logger.debug("Session updated")

        else:
            logger.debug(f"Unhandled event type: {event_type}")

    async def disconnect(self) -> None:
        """Disconnect from Realtime API.

        Cancels the receive loop, closes the websocket, and resets state.
        Safe to call when already disconnected.
        """
        self._cancel_event.set()

        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass

        if self._ws:
            await self._ws.close()
            self._ws = None

        self.state = RealtimeState.DISCONNECTED
        logger.info("Realtime API disconnected")
|
||||
|
||||
|
||||
class RealtimePipeline:
|
||||
"""
|
||||
Pipeline adapter for RealtimeService.
|
||||
|
||||
Provides a compatible interface with DuplexPipeline but uses
|
||||
OpenAI Realtime API for all processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport,
|
||||
session_id: str,
|
||||
config: Optional[RealtimeConfig] = None
|
||||
):
|
||||
"""
|
||||
Initialize Realtime pipeline.
|
||||
|
||||
Args:
|
||||
transport: Transport for sending audio/events
|
||||
session_id: Session identifier
|
||||
config: Realtime configuration
|
||||
"""
|
||||
self.transport = transport
|
||||
self.session_id = session_id
|
||||
|
||||
self.service = RealtimeService(config)
|
||||
|
||||
# Register callbacks
|
||||
self.service.on("on_audio", self._on_audio)
|
||||
self.service.on("on_transcript", self._on_transcript)
|
||||
self.service.on("on_speech_started", self._on_speech_started)
|
||||
self.service.on("on_speech_stopped", self._on_speech_stopped)
|
||||
self.service.on("on_response_started", self._on_response_started)
|
||||
self.service.on("on_response_done", self._on_response_done)
|
||||
self.service.on("on_error", self._on_error)
|
||||
|
||||
self._is_speaking = False
|
||||
self._running = True
|
||||
|
||||
logger.info(f"RealtimePipeline initialized for session {session_id}")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline."""
|
||||
await self.service.connect()
|
||||
|
||||
async def process_audio(self, pcm_bytes: bytes) -> None:
|
||||
"""
|
||||
Process incoming audio.
|
||||
|
||||
Note: Realtime API expects 24kHz audio by default.
|
||||
You may need to resample from 16kHz.
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
# TODO: Resample from 16kHz to 24kHz if needed
|
||||
await self.service.send_audio(pcm_bytes)
|
||||
|
||||
async def process_text(self, text: str) -> None:
|
||||
"""Process text input."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
await self.service.send_text(text)
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Interrupt current response."""
|
||||
await self.service.cancel_response()
|
||||
await self.transport.send_event({
|
||||
"event": "interrupt",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup resources."""
|
||||
self._running = False
|
||||
await self.service.disconnect()
|
||||
|
||||
# Event handlers
|
||||
|
||||
async def _on_audio(self, audio_bytes: bytes) -> None:
|
||||
"""Handle audio output."""
|
||||
await self.transport.send_audio(audio_bytes)
|
||||
|
||||
async def _on_transcript(self, text: str, role: str, is_final: bool) -> None:
|
||||
"""Handle transcript."""
|
||||
logger.info(f"[{role.upper()}] {text[:50]}..." if len(text) > 50 else f"[{role.upper()}] {text}")
|
||||
|
||||
async def _on_speech_started(self, start_ms: int) -> None:
|
||||
"""Handle user speech start."""
|
||||
self._is_speaking = True
|
||||
await self.transport.send_event({
|
||||
"event": "speaking",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"startTime": start_ms
|
||||
})
|
||||
|
||||
# Cancel any ongoing response (barge-in)
|
||||
await self.service.cancel_response()
|
||||
|
||||
async def _on_speech_stopped(self, end_ms: int) -> None:
|
||||
"""Handle user speech stop."""
|
||||
self._is_speaking = False
|
||||
await self.transport.send_event({
|
||||
"event": "silence",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"duration": end_ms
|
||||
})
|
||||
|
||||
async def _on_response_started(self, response: Dict) -> None:
    """Announce the start of an assistant audio track to the client."""
    payload = {
        "event": "trackStart",
        "trackId": self.session_id,
        "timestamp": self._get_timestamp_ms(),
    }
    await self.transport.send_event(payload)
|
||||
|
||||
async def _on_response_done(self, response: Dict) -> None:
    """Announce the end of an assistant audio track to the client."""
    payload = {
        "event": "trackEnd",
        "trackId": self.session_id,
        "timestamp": self._get_timestamp_ms(),
    }
    await self.transport.send_event(payload)
|
||||
|
||||
async def _on_error(self, error: Dict) -> None:
    """Forward a realtime-service error to the client as an error event."""
    payload = {
        "event": "error",
        "trackId": self.session_id,
        "timestamp": self._get_timestamp_ms(),
        "sender": "realtime",
        "error": str(error),
    }
    await self.transport.send_event(payload)
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
|
||||
@property
def is_speaking(self) -> bool:
    """Whether the user is currently mid-utterance (set by VAD callbacks)."""
    return self._is_speaking
|
||||
8
services/siliconflow_asr.py
Normal file
8
services/siliconflow_asr.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Backward-compatible imports for legacy siliconflow_asr module."""
|
||||
|
||||
from services.openai_compatible_asr import OpenAICompatibleASRService
|
||||
|
||||
# Backward-compatible alias
|
||||
SiliconFlowASRService = OpenAICompatibleASRService
|
||||
|
||||
__all__ = ["OpenAICompatibleASRService", "SiliconFlowASRService"]
|
||||
8
services/siliconflow_tts.py
Normal file
8
services/siliconflow_tts.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Backward-compatible imports for legacy siliconflow_tts module."""
|
||||
|
||||
from services.openai_compatible_tts import OpenAICompatibleTTSService, StreamingTTSAdapter
|
||||
|
||||
# Backward-compatible alias
|
||||
SiliconFlowTTSService = OpenAICompatibleTTSService
|
||||
|
||||
__all__ = ["OpenAICompatibleTTSService", "SiliconFlowTTSService", "StreamingTTSAdapter"]
|
||||
86
services/streaming_text.py
Normal file
86
services/streaming_text.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Shared text chunking helpers for streaming TTS."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def is_non_sentence_period(text: str, idx: int) -> bool:
    """Return True when the '.' at *idx* should NOT end a sentence.

    Two cases are recognized:
    * a dot sandwiched between digits (decimals, versions: "1.2", "v1.2.3");
    * the abbreviation "No." when followed (possibly after spaces) by a digit.
    """
    if not (0 <= idx < len(text)) or text[idx] != ".":
        return False

    # Digit on both sides -> decimal / version segment.
    if 0 < idx < len(text) - 1:
        if text[idx - 1].isdigit() and text[idx + 1].isdigit():
            return True

    # "No." abbreviation: scan the alphabetic run that ends at the dot.
    run_start = idx - 1
    while run_start >= 0 and text[run_start].isalpha():
        run_start -= 1
    if text[run_start + 1:idx].lower() == "no":
        after = idx + 1
        while after < len(text) and text[after].isspace():
            after += 1
        if after < len(text) and text[after].isdigit():
            return True

    return False
|
||||
|
||||
|
||||
def has_spoken_content(text: str) -> bool:
    """Return True if *text* has at least one alphanumeric (pronounceable) char."""
    for ch in text:
        if ch.isalnum():
            return True
    return False
|
||||
|
||||
|
||||
def extract_tts_sentence(
    text_buffer: str,
    *,
    end_chars: frozenset[str],
    trailing_chars: frozenset[str],
    closers: frozenset[str],
    min_split_spoken_chars: int = 0,
    hold_trailing_at_buffer_end: bool = False,
    force: bool = False,
) -> Optional[tuple[str, str]]:
    """Extract one TTS-ready sentence from *text_buffer*.

    Scans for the first sentence delimiter (skipping '.' that
    is_non_sentence_period() classifies as non-terminal), swallows any
    trailing delimiters and closing punctuation, and returns
    ``(sentence, remainder)``.  Returns ``None`` when no complete sentence
    can be cut yet.

    Keyword args:
        end_chars: characters that may terminate a sentence.
        trailing_chars: delimiters absorbed after the terminator ("!!", "?!").
        closers: closing punctuation (quotes/brackets) absorbed after that.
        min_split_spoken_chars: unless *force*, refuse to cut a sentence with
            fewer alphanumeric chars than this (keeps scanning further).
        hold_trailing_at_buffer_end: unless *force*, do not cut when the
            delimiter run reaches the end of the buffer (more trailing
            punctuation may still arrive from the stream).
        force: flush mode — ignore the two hold-back heuristics above.
    """
    if not text_buffer:
        return None

    buf_len = len(text_buffer)
    scan_from = 0

    while True:
        # First real sentence terminator at or after scan_from, if any.
        delimiter = next(
            (
                i
                for i in range(scan_from, buf_len)
                if text_buffer[i] in end_chars
                and not (text_buffer[i] == "." and is_non_sentence_period(text_buffer, i))
            ),
            None,
        )
        if delimiter is None:
            return None

        # Absorb runs of extra delimiters, then closing punctuation.
        cut = delimiter + 1
        while cut < buf_len and text_buffer[cut] in trailing_chars:
            cut += 1
        while cut < buf_len and text_buffer[cut] in closers:
            cut += 1

        # The delimiter run touches the buffer end: maybe more is coming.
        if hold_trailing_at_buffer_end and not force and cut >= buf_len:
            return None

        candidate = text_buffer[:cut].strip()
        audible = sum(1 for ch in candidate if ch.isalnum())

        too_short = (
            not force
            and min_split_spoken_chars > 0
            and 0 < audible < min_split_spoken_chars
            and cut < buf_len
        )
        if too_short:
            # Fragment too short to speak on its own; extend the search.
            scan_from = cut
            continue

        return candidate, text_buffer[cut:]
|
||||
95
services/streaming_tts_adapter.py
Normal file
95
services/streaming_tts_adapter.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Backend-agnostic streaming adapter from LLM text to TTS audio."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseTTSService
|
||||
from services.streaming_text import extract_tts_sentence, has_spoken_content
|
||||
|
||||
|
||||
class StreamingTTSAdapter:
    """
    Adapter for streaming LLM text to TTS with sentence-level chunking.

    This reduces latency by starting TTS as soon as a complete sentence
    is received from the LLM, rather than waiting for the full response.

    Lifecycle: feed deltas via process_text_chunk(), call flush() at end
    of turn, cancel() on barge-in, and reset() before the next turn.
    """

    # Delimiters that may end a sentence (CJK + ASCII + newline).
    SENTENCE_ENDS = {"。", "!", "?", ".", "!", "?", "\n"}
    # No closing punctuation (quotes/brackets) is absorbed after a delimiter.
    SENTENCE_CLOSERS = frozenset()

    def __init__(self, tts_service: BaseTTSService, transport, session_id: str):
        # Downstream synthesis backend.
        self.tts_service = tts_service
        # Transport used to push synthesized audio to the client.
        self.transport = transport
        self.session_id = session_id
        # Accumulates LLM deltas until a full sentence can be cut.
        self._buffer = ""
        # Set by cancel(); checked at every await point to stop promptly.
        self._cancel_event = asyncio.Event()
        self._is_speaking = False

    async def process_text_chunk(self, text_chunk: str) -> None:
        """
        Process a text chunk from LLM and trigger TTS when sentence is complete.

        Args:
            text_chunk: Text chunk from LLM streaming
        """
        if self._cancel_event.is_set():
            return

        self._buffer += text_chunk

        # Check for sentence completion; several sentences may have arrived
        # in one chunk, so keep extracting until the buffer has no more.
        while True:
            split_result = extract_tts_sentence(
                self._buffer,
                end_chars=frozenset(self.SENTENCE_ENDS),
                trailing_chars=frozenset(self.SENTENCE_ENDS),
                closers=self.SENTENCE_CLOSERS,
                force=False,
            )
            if not split_result:
                break

            sentence, self._buffer = split_result
            # Skip punctuation-only fragments — nothing to pronounce.
            if sentence and has_spoken_content(sentence):
                await self._speak_sentence(sentence)

    async def flush(self) -> None:
        """Flush remaining buffer (speak the trailing partial sentence)."""
        if self._buffer.strip() and not self._cancel_event.is_set():
            await self._speak_sentence(self._buffer.strip())
        self._buffer = ""

    async def _speak_sentence(self, text: str) -> None:
        """Synthesize and send a sentence, honoring cancellation mid-stream."""
        if not text or self._cancel_event.is_set():
            return

        self._is_speaking = True

        try:
            async for chunk in self.tts_service.synthesize_stream(text):
                if self._cancel_event.is_set():
                    break
                await self.transport.send_audio(chunk.audio)
                await asyncio.sleep(0.01)  # Prevent flooding
        except Exception as e:
            # Best-effort: a failed sentence must not kill the whole turn.
            logger.error(f"TTS speak error: {e}")
        finally:
            self._is_speaking = False

    def cancel(self) -> None:
        """Cancel ongoing speech (barge-in); discards any buffered text."""
        self._cancel_event.set()
        self._buffer = ""

    def reset(self) -> None:
        """Reset for new turn: clears cancellation, buffer and speaking flag."""
        self._cancel_event.clear()
        self._buffer = ""
        self._is_speaking = False

    @property
    def is_speaking(self) -> bool:
        # True while _speak_sentence is streaming audio out.
        return self._is_speaking
|
||||
271
services/tts.py
Normal file
271
services/tts.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""TTS (Text-to-Speech) Service implementations.
|
||||
|
||||
Provides multiple TTS backend options including edge-tts (free)
|
||||
and placeholder for cloud services.
|
||||
"""
|
||||
|
||||
import os
|
||||
import io
|
||||
import asyncio
|
||||
import struct
|
||||
from typing import AsyncIterator, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.base import BaseTTSService, TTSChunk, ServiceState
|
||||
|
||||
# Try to import edge-tts
|
||||
try:
|
||||
import edge_tts
|
||||
EDGE_TTS_AVAILABLE = True
|
||||
except ImportError:
|
||||
EDGE_TTS_AVAILABLE = False
|
||||
logger.warning("edge-tts not available - EdgeTTS service will be disabled")
|
||||
|
||||
|
||||
class EdgeTTSService(BaseTTSService):
    """
    Microsoft Edge TTS service.

    Uses edge-tts library for free, high-quality speech synthesis.
    Supports streaming for low-latency playback.

    Output is 16-bit mono PCM at ``sample_rate``; edge-tts itself emits MP3,
    which is decoded via pydub or, as a fallback, an ffmpeg subprocess.
    """

    # Voice mapping for common languages: lets callers pass a bare language
    # code ("en", "zh-CN") instead of a full neural-voice name.
    VOICE_MAP = {
        "en": "en-US-JennyNeural",
        "en-US": "en-US-JennyNeural",
        "en-GB": "en-GB-SoniaNeural",
        "zh": "zh-CN-XiaoxiaoNeural",
        "zh-CN": "zh-CN-XiaoxiaoNeural",
        "zh-TW": "zh-TW-HsiaoChenNeural",
        "ja": "ja-JP-NanamiNeural",
        "ko": "ko-KR-SunHiNeural",
        "fr": "fr-FR-DeniseNeural",
        "de": "de-DE-KatjaNeural",
        "es": "es-ES-ElviraNeural",
    }

    def __init__(
        self,
        voice: str = "en-US-JennyNeural",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        """
        Initialize Edge TTS service.

        Args:
            voice: Voice name (e.g., "en-US-JennyNeural") or language code (e.g., "en")
            sample_rate: Target sample rate (will be resampled)
            speed: Speech speed multiplier (1.0 = normal)
        """
        # Resolve voice from language code if needed
        if voice in self.VOICE_MAP:
            voice = self.VOICE_MAP[voice]

        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)
        # Set by cancel(); polled between chunks to abort synthesis early.
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Edge TTS doesn't require explicit connection; just validate deps."""
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts package not installed")
        self.state = ServiceState.CONNECTED
        logger.info(f"Edge TTS service ready: voice={self.voice}")

    async def disconnect(self) -> None:
        """Edge TTS doesn't require explicit disconnection."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Edge TTS service disconnected")

    def _get_rate_string(self) -> str:
        """Convert speed multiplier to a signed rate string for edge-tts."""
        # edge-tts uses percentage format: "+0%", "-10%", "+20%"
        percentage = int((self.speed - 1.0) * 100)
        if percentage >= 0:
            return f"+{percentage}%"
        return f"{percentage}%"

    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio for text.

        Args:
            text: Text to synthesize

        Returns:
            PCM audio data (16-bit, mono, 16kHz)
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts not available")

        # Collect all chunks from the streaming path.
        audio_data = b""
        async for chunk in self.synthesize_stream(text):
            audio_data += chunk.audio

        return audio_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """
        Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize

        Yields:
            TTSChunk objects with PCM audio (100ms per chunk; the last
            chunk has is_final=True)
        """
        if not EDGE_TTS_AVAILABLE:
            raise RuntimeError("edge-tts not available")

        self._cancel_event.clear()

        try:
            communicate = edge_tts.Communicate(
                text,
                voice=self.voice,
                rate=self._get_rate_string()
            )

            # edge-tts outputs MP3, we need to decode to PCM.
            # NOTE: the full MP3 is buffered before conversion, so the first
            # audio byte is only yielded once edge-tts finishes the sentence.
            mp3_data = b""

            async for chunk in communicate.stream():
                # Check for cancellation between network chunks.
                if self._cancel_event.is_set():
                    logger.info("TTS synthesis cancelled")
                    return

                if chunk["type"] == "audio":
                    mp3_data += chunk["data"]

            # Convert MP3 to PCM
            if mp3_data:
                pcm_data = await self._convert_mp3_to_pcm(mp3_data)
                if pcm_data:
                    # Yield in chunks for streaming playback
                    chunk_size = self.sample_rate * 2 // 10  # 100ms of 16-bit mono
                    for i in range(0, len(pcm_data), chunk_size):
                        if self._cancel_event.is_set():
                            return

                        chunk_data = pcm_data[i:i + chunk_size]
                        yield TTSChunk(
                            audio=chunk_data,
                            sample_rate=self.sample_rate,
                            is_final=(i + chunk_size >= len(pcm_data))
                        )

        except asyncio.CancelledError:
            logger.info("TTS synthesis cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"TTS synthesis error: {e}")
            raise

    async def _convert_mp3_to_pcm(self, mp3_data: bytes) -> bytes:
        """
        Convert MP3 audio to PCM (16-bit mono at self.sample_rate).

        Uses pydub when available, else falls back to an ffmpeg subprocess.
        Returns b"" on conversion failure.
        """
        try:
            # Try using pydub (requires ffmpeg)
            from pydub import AudioSegment

            # Load MP3 from bytes
            audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))

            # Convert to target format
            audio = audio.set_frame_rate(self.sample_rate)
            audio = audio.set_channels(1)
            audio = audio.set_sample_width(2)  # 16-bit

            # Export as raw PCM
            return audio.raw_data

        except ImportError:
            logger.warning("pydub not available, trying fallback")
            # Fallback: Use subprocess to call ffmpeg directly
            return await self._ffmpeg_convert(mp3_data)
        except Exception as e:
            # Best-effort: callers treat b"" as "no audio".
            logger.error(f"Audio conversion error: {e}")
            return b""

    async def _ffmpeg_convert(self, mp3_data: bytes) -> bytes:
        """Convert MP3 to raw s16le mono PCM using an ffmpeg subprocess."""
        try:
            process = await asyncio.create_subprocess_exec(
                "ffmpeg",
                "-i", "pipe:0",
                "-f", "s16le",
                "-acodec", "pcm_s16le",
                "-ar", str(self.sample_rate),
                "-ac", "1",
                "pipe:1",
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.DEVNULL
            )

            stdout, _ = await process.communicate(input=mp3_data)
            return stdout

        except Exception as e:
            # Includes FileNotFoundError when ffmpeg is not installed.
            logger.error(f"ffmpeg conversion error: {e}")
            return b""

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (takes effect at the next chunk boundary)."""
        self._cancel_event.set()
|
||||
|
||||
|
||||
class MockTTSService(BaseTTSService):
    """
    Mock TTS backend for tests: produces silence instead of real speech.

    Output duration scales with word count (~100 ms per word) so
    timing-dependent code paths can still be exercised without a real
    synthesis engine.
    """

    def __init__(
        self,
        voice: str = "mock",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        super().__init__(voice=voice, sample_rate=sample_rate, speed=speed)

    async def connect(self) -> None:
        """No real connection is made; only the state flag is flipped."""
        self.state = ServiceState.CONNECTED
        logger.info("Mock TTS service connected")

    async def disconnect(self) -> None:
        """Flip the state flag back; nothing to tear down."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock TTS service disconnected")

    async def synthesize(self, text: str) -> bytes:
        """Return 16-bit mono silence, ~100 ms per whitespace-separated word."""
        n_words = len(text.split())
        n_samples = int(self.sample_rate * (n_words * 100) / 1000)

        # Zero bytes == digital silence; two bytes per 16-bit sample.
        return bytes(n_samples * 2)

    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Yield the silence in 100 ms chunks, pausing briefly between them."""
        pcm = await self.synthesize(text)

        step = self.sample_rate * 2 // 10  # bytes per 100 ms
        for offset in range(0, len(pcm), step):
            yield TTSChunk(
                audio=pcm[offset:offset + step],
                sample_rate=self.sample_rate,
                is_final=(offset + step >= len(pcm)),
            )
            await asyncio.sleep(0.05)  # Simulate processing time
|
||||
331
tests/test_tool_call_flow.py
Normal file
331
tests/test_tool_call_flow.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from models.ws_v1 import ToolCallResultsMessage, parse_client_message
|
||||
from services.base import LLMStreamEvent
|
||||
|
||||
|
||||
class _DummySileroVAD:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process_audio(self, _pcm: bytes) -> float:
|
||||
return 0.0
|
||||
|
||||
|
||||
class _DummyVADProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process(self, _speech_prob: float):
|
||||
return "Silence", 0.0
|
||||
|
||||
|
||||
class _DummyEouDetector:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def process(self, _vad_status: str) -> bool:
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeTransport:
|
||||
async def send_event(self, _event: Dict[str, Any]) -> None:
|
||||
return None
|
||||
|
||||
async def send_audio(self, _audio: bytes) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeTTS:
|
||||
async def synthesize_stream(self, _text: str):
|
||||
if False:
|
||||
yield None
|
||||
|
||||
|
||||
class _FakeASR:
|
||||
async def connect(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeLLM:
|
||||
def __init__(self, rounds: List[List[LLMStreamEvent]]):
|
||||
self._rounds = rounds
|
||||
self._call_index = 0
|
||||
|
||||
async def generate_stream(self, _messages, temperature=0.7, max_tokens=None):
|
||||
idx = self._call_index
|
||||
self._call_index += 1
|
||||
events = self._rounds[idx] if idx < len(self._rounds) else [LLMStreamEvent(type="done")]
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
|
||||
def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]:
    """Build a DuplexPipeline wired to test doubles.

    VAD/EOU classes are patched out so no models are loaded, the LLM replays
    *llm_rounds*, and outbound events are captured into the returned list
    instead of being sent (speech synthesis is a no-op).

    Returns:
        (pipeline, events) — events accumulates every dict the pipeline
        would have emitted via _send_event.
    """
    monkeypatch.setattr("core.duplex_pipeline.SileroVAD", _DummySileroVAD)
    monkeypatch.setattr("core.duplex_pipeline.VADProcessor", _DummyVADProcessor)
    monkeypatch.setattr("core.duplex_pipeline.EouDetector", _DummyEouDetector)

    pipeline = DuplexPipeline(
        transport=_FakeTransport(),
        session_id="s_test",
        llm_service=_FakeLLM(llm_rounds),
        tts_service=_FakeTTS(),
        asr_service=_FakeASR(),
    )
    events: List[Dict[str, Any]] = []

    # Capture events instead of sending them; signature mirrors _send_event.
    async def _capture_event(event: Dict[str, Any], priority: int = 20):
        events.append(event)

    # Skip synthesis entirely; signature mirrors _speak_sentence.
    async def _noop_speak(_text: str, fade_in_ms: int = 0, fade_out_ms: int = 8):
        return None

    monkeypatch.setattr(pipeline, "_send_event", _capture_event)
    monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak)
    return pipeline, events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ws_message_parses_tool_call_results():
    """A raw tool_call.results dict parses into ToolCallResultsMessage."""
    msg = parse_client_message(
        {
            "type": "tool_call.results",
            "results": [{"tool_call_id": "call_1", "status": {"code": 200, "message": "ok"}}],
        }
    )
    assert isinstance(msg, ToolCallResultsMessage)
    assert msg.results[0]["tool_call_id"] == "call_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_turn_without_tool_keeps_streaming(monkeypatch):
    """A plain text turn emits deltas and a final, and no tool-call event."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="hello "),
                LLMStreamEvent(type="text_delta", text="world."),
                LLMStreamEvent(type="done"),
            ]
        ],
    )

    await pipeline._handle_turn("hi")

    event_types = [e.get("type") for e in events]
    assert "assistant.response.delta" in event_types
    assert "assistant.response.final" in event_types
    assert "assistant.tool_call" not in event_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "metadata",
    [
        # Either explicit text-only output mode...
        {"output": {"mode": "text"}},
        # ...or disabling the TTS service should suppress audio events.
        {"services": {"tts": {"enabled": False}}},
    ],
)
async def test_text_output_mode_skips_audio_events(monkeypatch, metadata):
    """Text-only runtime overrides keep text events but drop audio ones."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="hello "),
                LLMStreamEvent(type="text_delta", text="world."),
                LLMStreamEvent(type="done"),
            ]
        ],
    )
    pipeline.apply_runtime_overrides(metadata)

    await pipeline._handle_turn("hi")

    event_types = [e.get("type") for e in events]
    assert "assistant.response.delta" in event_types
    assert "assistant.response.final" in event_types
    assert "output.audio.start" not in event_types
    assert "output.audio.end" not in event_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_turn_with_tool_call_then_results(monkeypatch):
    """Client tool flow: tool_call is emitted, results resume the turn.

    Round 1 yields a client-executed tool call; once results are delivered
    via handle_tool_call_results, round 2's text becomes the final answer.
    """
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(type="text_delta", text="let me check."),
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_ok",
                        "executor": "client",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="it's sunny."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )

    # Run the turn concurrently; poll (up to ~1s) until the pipeline has
    # surfaced the tool call so we can answer it.
    task = asyncio.create_task(pipeline._handle_turn("weather?"))
    for _ in range(200):
        if any(e.get("type") == "assistant.tool_call" for e in events):
            break
        await asyncio.sleep(0.005)

    await pipeline.handle_tool_call_results(
        [
            {
                "tool_call_id": "call_ok",
                "name": "weather",
                "output": {"temp": 21},
                "status": {"code": 200, "message": "ok"},
            }
        ]
    )
    await task

    assert any(e.get("type") == "assistant.tool_call" for e in events)
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "it's sunny" in finals[-1].get("text", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_turn_with_tool_call_timeout(monkeypatch):
    """When client tool results never arrive, the turn falls back to round 2."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_timeout",
                        "executor": "client",
                        "type": "function",
                        "function": {"name": "search", "arguments": "{\"query\":\"x\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="fallback answer."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    # Shrink the wait so the timeout path triggers immediately.
    pipeline._TOOL_WAIT_TIMEOUT_SECONDS = 0.01

    await pipeline._handle_turn("query")

    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "fallback answer" in finals[-1].get("text", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_duplicate_tool_results_are_ignored(monkeypatch):
    """The first result for a tool_call_id wins; later duplicates are dropped."""
    pipeline, _events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]])

    await pipeline.handle_tool_call_results(
        [{"tool_call_id": "call_dup", "output": {"value": 1}, "status": {"code": 200, "message": "ok"}}]
    )
    await pipeline.handle_tool_call_results(
        [{"tool_call_id": "call_dup", "output": {"value": 2}, "status": {"code": 200, "message": "ok"}}]
    )
    result = await pipeline._wait_for_single_tool_result("call_dup")

    # Still the first delivery's payload.
    assert result.get("output", {}).get("value") == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_server_calculator_emits_tool_result(monkeypatch):
    """A server-executed calculator call emits assistant.tool_result with 1+2=3."""
    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_calc",
                        "executor": "server",
                        "type": "function",
                        "function": {"name": "calculator", "arguments": "{\"expression\":\"1+2\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="done."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )

    await pipeline._handle_turn("calc")

    tool_results = [e for e in events if e.get("type") == "assistant.tool_result"]
    assert tool_results
    payload = tool_results[-1].get("result", {})
    assert payload.get("status", {}).get("code") == 200
    assert payload.get("output", {}).get("result") == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_server_tool_timeout_emits_504_and_continues(monkeypatch):
    """A slow server tool times out with status 504, and the turn still finishes."""

    # Tool executor that outlives the (shrunk) server-tool timeout below.
    async def _slow_execute(_call):
        await asyncio.sleep(0.05)
        return {
            "tool_call_id": "call_slow",
            "name": "weather",
            "output": {"ok": True},
            "status": {"code": 200, "message": "ok"},
        }

    monkeypatch.setattr("core.duplex_pipeline.execute_server_tool", _slow_execute)

    pipeline, events = _build_pipeline(
        monkeypatch,
        [
            [
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": "call_slow",
                        "executor": "server",
                        "type": "function",
                        "function": {"name": "weather", "arguments": "{\"city\":\"hz\"}"},
                    },
                ),
                LLMStreamEvent(type="done"),
            ],
            [
                LLMStreamEvent(type="text_delta", text="timeout fallback."),
                LLMStreamEvent(type="done"),
            ],
        ],
    )
    # 10ms timeout < 50ms tool duration -> guaranteed timeout.
    pipeline._SERVER_TOOL_TIMEOUT_SECONDS = 0.01

    await pipeline._handle_turn("weather?")

    tool_results = [e for e in events if e.get("type") == "assistant.tool_result"]
    assert tool_results
    payload = tool_results[-1].get("result", {})
    assert payload.get("status", {}).get("code") == 504
    finals = [e for e in events if e.get("type") == "assistant.response.final"]
    assert finals
    assert "timeout fallback" in finals[-1].get("text", "")
|
||||
57
tests/test_tool_executor.py
Normal file
57
tests/test_tool_executor.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
|
||||
from core.tool_executor import execute_server_tool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_code_interpreter_simple_expression():
    """code_interpreter evaluates a benign expression and returns its value."""
    result = await execute_server_tool(
        {
            "id": "call_ci_ok",
            "function": {
                "name": "code_interpreter",
                "arguments": '{"code":"sum([1, 2, 3]) + 4"}',
            },
        }
    )
    assert result["status"]["code"] == 200
    assert result["output"]["result"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_code_interpreter_blocks_import_and_io():
    """code_interpreter rejects __import__/os.system with 422 invalid_code."""
    result = await execute_server_tool(
        {
            "id": "call_ci_bad",
            "function": {
                "name": "code_interpreter",
                "arguments": '{"code":"__import__(\\"os\\").system(\\"ls\\")"}',
            },
        }
    )
    assert result["status"]["code"] == 422
    assert result["status"]["message"] == "invalid_code"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_current_time_uses_local_system_clock(monkeypatch):
    """current_time answers from the local clock, never the resource backend."""

    # Any resource fetch would mean the tool took the wrong path.
    async def _should_not_be_called(_tool_id):
        raise AssertionError("fetch_tool_resource should not be called for current_time")

    monkeypatch.setattr("core.tool_executor.fetch_tool_resource", _should_not_be_called)

    result = await execute_server_tool(
        {
            "id": "call_time_ok",
            "function": {
                "name": "current_time",
                "arguments": "{}",
            },
        }
    )

    assert result["status"]["code"] == 200
    assert result["status"]["message"] == "ok"
    assert "local_time" in result["output"]
    assert "iso" in result["output"]
    assert "timestamp" in result["output"]
|
||||
1
utils/__init__.py
Normal file
1
utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Utilities Package"""
|
||||
83
utils/logging.py
Normal file
83
utils/logging.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Logging configuration utilities."""
|
||||
|
||||
import sys
from pathlib import Path
from typing import Optional

from loguru import logger
|
||||
|
||||
|
||||
def setup_logging(
    log_level: str = "INFO",
    log_format: str = "text",
    log_to_file: bool = True,
    log_dir: str = "logs"
):
    """
    Configure structured logging with loguru.

    Args:
        log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
        log_format: Format type ("json" for serialized records, else text)
        log_to_file: Whether to also log to a daily-rotated file
        log_dir: Directory for log files (created if missing, parents included)

    Returns:
        The configured loguru logger.
    """
    # Remove loguru's default stderr handler so we fully control the sinks.
    logger.remove()

    use_json = log_format == "json"

    # Console handler.
    if use_json:
        logger.add(
            sys.stdout,
            format="{message}",
            level=log_level,
            serialize=True,
            colorize=False
        )
    else:
        logger.add(
            sys.stdout,
            format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
            level=log_level,
            colorize=True
        )

    # File handler: daily rotation, 7-day retention, zip compression.
    if log_to_file:
        log_path = Path(log_dir)
        # parents=True: the original mkdir(exist_ok=True) raised
        # FileNotFoundError for nested paths such as "var/logs/app".
        log_path.mkdir(parents=True, exist_ok=True)

        # Options shared by both file sinks (kept in one place to avoid
        # the JSON/text branches drifting apart).
        rotation_opts = dict(
            level=log_level,
            rotation="1 day",
            retention="7 days",
            compression="zip",
        )
        if use_json:
            logger.add(
                log_path / "active_call_{time:YYYY-MM-DD}.log",
                format="{message}",
                serialize=True,
                **rotation_opts
            )
        else:
            logger.add(
                log_path / "active_call_{time:YYYY-MM-DD}.log",
                format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
                **rotation_opts
            )

    return logger
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None):
    """
    Get a logger instance.

    Args:
        name: Optional logical name; when given, it is bound into every
            record as the ``name`` extra field.  (The previous annotation
            ``name: str = None`` was an invalid implicit Optional.)

    Returns:
        A loguru logger, bound to *name* when one is provided.
    """
    if name:
        return logger.bind(name=name)
    return logger
|
||||
Reference in New Issue
Block a user