Integrate EOU and VAD

app/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
"""Active-Call Application Package"""

app/config.py (Normal file, 81 lines)
@@ -0,0 +1,81 @@
"""Configuration management using Pydantic settings."""

from typing import List, Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
import json


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore"
    )

    # Server Configuration
    host: str = Field(default="0.0.0.0", description="Server host address")
    port: int = Field(default=8000, description="Server port")
    external_ip: Optional[str] = Field(default=None, description="External IP for NAT traversal")

    # Audio Configuration
    sample_rate: int = Field(default=16000, description="Audio sample rate in Hz")
    chunk_size_ms: int = Field(default=20, description="Audio chunk duration in milliseconds")
    default_codec: str = Field(default="pcm", description="Default audio codec")

    # VAD Configuration
    vad_type: str = Field(default="silero", description="VAD algorithm type")
    vad_model_path: str = Field(default="data/vad/silero_vad.onnx", description="Path to VAD model")
    vad_threshold: float = Field(default=0.5, description="VAD detection threshold")
    vad_min_speech_duration_ms: int = Field(default=250, description="Minimum speech duration in milliseconds")
    vad_eou_threshold_ms: int = Field(default=400, description="End of utterance (silence) threshold in milliseconds")

    # Logging
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(default="json", description="Log format (json or text)")

    # CORS
    cors_origins: str = Field(
        default='["http://localhost:3000", "http://localhost:8080"]',
        description="CORS allowed origins"
    )

    # ICE Servers (WebRTC)
    ice_servers: str = Field(
        default='[{"urls": "stun:stun.l.google.com:19302"}]',
        description="ICE servers configuration"
    )

    @property
    def chunk_size_bytes(self) -> int:
        """Calculate chunk size in bytes based on sample rate and duration."""
        # 16-bit (2 bytes) per sample, mono channel
        return int(self.sample_rate * 2 * (self.chunk_size_ms / 1000.0))

    @property
    def cors_origins_list(self) -> List[str]:
        """Parse CORS origins from JSON string."""
        try:
            return json.loads(self.cors_origins)
        except json.JSONDecodeError:
            return ["http://localhost:3000", "http://localhost:8080"]

    @property
    def ice_servers_list(self) -> List[dict]:
        """Parse ICE servers from JSON string."""
        try:
            return json.loads(self.ice_servers)
        except json.JSONDecodeError:
            return [{"urls": "stun:stun.l.google.com:19302"}]


# Global settings instance
settings = Settings()


def get_settings() -> Settings:
    """Get application settings instance."""
    return settings
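
As a quick sanity check on the derived values above: at the defaults (16 kHz, 16-bit, mono, 20 ms chunks) one chunk is 16000 * 2 * 0.02 = 640 bytes. A minimal usage sketch, assuming the package is importable and any overrides live in a local .env file:

# Sketch: inspecting the derived settings. Values shown are the defaults
# unless a .env file overrides them.
from app.config import get_settings

settings = get_settings()

# 16000 samples/s * 2 bytes/sample * 0.020 s = 640 bytes per 20 ms chunk
assert settings.chunk_size_bytes == 640

# JSON-in-env strings are parsed lazily with safe fallbacks
print(settings.cors_origins_list)  # ["http://localhost:3000", "http://localhost:8080"]
print(settings.ice_servers_list)   # [{"urls": "stun:stun.l.google.com:19302"}]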

app/main.py (Normal file, 290 lines)
@@ -0,0 +1,290 @@
"""FastAPI application with WebSocket and WebRTC endpoints."""

import uuid
import json
import asyncio
from typing import Dict, Any, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger

# Try to import aiortc (optional for WebRTC functionality)
try:
    from aiortc import RTCPeerConnection, RTCSessionDescription
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    logger.warning("aiortc not available - WebRTC endpoint will be disabled")

from app.config import settings
from core.transports import SocketTransport, WebRtcTransport
from core.session import Session
from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus

# Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Active sessions storage
active_sessions: Dict[str, Session] = {}

# Configure logging (the logs directory is created in the __main__ block below)
logger.remove()
logger.add(
    "logs/active_call_{time}.log",
    rotation="1 day",
    retention="7 days",
    level=settings.log_level,
    format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
)
logger.add(
    lambda msg: print(msg, end=""),
    level=settings.log_level,
    format="{time:HH:mm:ss} | {level: <8} | {message}"
)


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "sessions": len(active_sessions)}


@app.get("/iceservers")
async def get_ice_servers():
    """Get ICE servers configuration for WebRTC."""
    return settings.ice_servers_list


@app.get("/call/lists")
async def list_calls():
    """List all active calls."""
    return {
        "calls": [
            {
                "id": session_id,
                "state": session.state,
                "created_at": session.created_at
            }
            for session_id, session in active_sessions.items()
        ]
    }


@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
    """Kill a specific active call."""
    if session_id not in active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    session = active_sessions[session_id]
    await session.cleanup()
    del active_sessions[session_id]

    return True


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for raw audio streaming.

    Accepts mixed text/binary frames:
    - Text frames: JSON commands
    - Binary frames: PCM audio data (16kHz, 16-bit, mono)
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create transport and session
    transport = SocketTransport(websocket)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebSocket connection established: {session_id}")

    try:
        # Receive loop
        while True:
            message = await websocket.receive()

            # Handle binary audio data
            if "bytes" in message:
                await session.handle_audio(message["bytes"])

            # Handle text commands
            elif "text" in message:
                await session.handle_text(message["text"])

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")

    except Exception:
        logger.exception("WebSocket error")

    finally:
        # Cleanup session
        if session_id in active_sessions:
            await session.cleanup()
            del active_sessions[session_id]

        logger.info(f"Session {session_id} removed")


@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
    """
    WebRTC endpoint for WebRTC audio streaming.

    Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
    """
    # Check if aiortc is available
    if not AIORTC_AVAILABLE:
        await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
        logger.warning("WebRTC connection attempted but aiortc is not available")
        return

    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create WebRTC peer connection
    pc = RTCPeerConnection()

    # Create transport and session
    transport = WebRtcTransport(websocket, pc)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebRTC connection established: {session_id}")

    # Track handler for incoming audio
    @pc.on("track")
    def on_track(track):
        logger.info(f"Track received: {track.kind}")

        if track.kind == "audio":
            # Wrap track with resampler
            wrapped_track = Resampled16kTrack(track)

            # Create task to pull audio from track
            async def pull_audio():
                try:
                    while True:
                        frame = await wrapped_track.recv()
                        # Convert frame to bytes
                        pcm_bytes = frame.to_ndarray().tobytes()
                        # Feed to session
                        await session.handle_audio(pcm_bytes)
                except Exception as e:
                    logger.error(f"Error pulling audio from track: {e}")

            asyncio.create_task(pull_audio())

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"Connection state: {pc.connectionState}")
        if pc.connectionState == "failed" or pc.connectionState == "closed":
            await session.cleanup()

    try:
        # Signaling loop
        while True:
            message = await websocket.receive()

            if "text" not in message:
                continue

            data = json.loads(message["text"])

            # Handle SDP offer/answer
            if "sdp" in data and "type" in data:
                logger.info(f"Received SDP {data['type']}")

                # Set remote description
                offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
                await pc.setRemoteDescription(offer)

                # Create and set local description
                if data["type"] == "offer":
                    answer = await pc.createAnswer()
                    await pc.setLocalDescription(answer)

                    # Send answer back
                    await websocket.send_text(json.dumps({
                        "event": "answer",
                        "trackId": session_id,
                        "timestamp": int(asyncio.get_event_loop().time() * 1000),
                        "sdp": pc.localDescription.sdp
                    }))

                    logger.info("Sent SDP answer")

            else:
                # Handle other commands
                await session.handle_text(message["text"])

    except WebSocketDisconnect:
        logger.info(f"WebRTC WebSocket disconnected: {session_id}")

    except Exception:
        logger.exception("WebRTC error")

    finally:
        # Cleanup
        await pc.close()
        if session_id in active_sessions:
            await session.cleanup()
            del active_sessions[session_id]

        logger.info(f"WebRTC session {session_id} removed")


@app.on_event("startup")
async def startup_event():
    """Run on application startup."""
    logger.info("Starting Python Active-Call server")
    logger.info(f"Server: {settings.host}:{settings.port}")
    logger.info(f"Sample rate: {settings.sample_rate} Hz")
    logger.info(f"VAD model: {settings.vad_model_path}")


@app.on_event("shutdown")
async def shutdown_event():
    """Run on application shutdown."""
    logger.info("Shutting down Python Active-Call server")

    # Cleanup all sessions
    for session_id, session in active_sessions.items():
        await session.cleanup()

    # Close event bus
    event_bus = get_event_bus()
    await event_bus.close()
    reset_event_bus()

    logger.info("Server shutdown complete")


if __name__ == "__main__":
    import uvicorn

    # Create logs directory
    import os
    os.makedirs("logs", exist_ok=True)

    # Run server
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=True,
        log_level=settings.log_level.lower()
    )
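
For reviewers who want to poke the new HTTP surface without streaming audio, here is a minimal smoke test; it is a sketch that assumes the server above is running on localhost:8000 and that aiohttp is installed:

# Sketch: exercise /health, /iceservers and /call/lists with aiohttp.
import asyncio
import aiohttp

async def smoke_test() -> None:
    async with aiohttp.ClientSession("http://localhost:8000") as session:
        async with session.get("/health") as resp:
            print("health:", await resp.json())      # {"status": "healthy", "sessions": 0}
        async with session.get("/iceservers") as resp:
            print("ice servers:", await resp.json())
        async with session.get("/call/lists") as resp:
            print("active calls:", await resp.json())

if __name__ == "__main__":
    asyncio.run(smoke_test())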

core/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
"""Core Components Package"""

core/events.py (Normal file, 134 lines)
@@ -0,0 +1,134 @@
"""Event bus for pub/sub communication between components."""

import asyncio
from typing import Callable, Dict, List, Any, Optional
from collections import defaultdict
from loguru import logger


class EventBus:
    """
    Async event bus for pub/sub communication.

    Similar to the original Rust implementation's broadcast channel.
    Components can subscribe to specific event types and receive events asynchronously.
    """

    def __init__(self):
        """Initialize the event bus."""
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        self._lock = asyncio.Lock()
        self._running = True

    def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Subscribe to an event type.

        Args:
            event_type: Type of event to subscribe to (e.g., "speaking", "silence")
            callback: Sync or async callback function that receives event data
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
            return

        self._subscribers[event_type].append(callback)
        logger.debug(f"Subscribed to event type: {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Unsubscribe from an event type.

        Args:
            event_type: Type of event to unsubscribe from
            callback: Callback function to remove
        """
        if callback in self._subscribers[event_type]:
            self._subscribers[event_type].remove(callback)
            logger.debug(f"Unsubscribed from event type: {event_type}")

    async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribers.

        Args:
            event_type: Type of event to publish
            event_data: Event data to send to subscribers
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
            return

        # Get subscribers for this event type
        subscribers = self._subscribers.get(event_type, [])

        if not subscribers:
            logger.debug(f"No subscribers for event type: {event_type}")
            return

        # Notify all subscribers concurrently
        tasks = []
        for callback in subscribers:
            try:
                # Create task for each subscriber
                task = asyncio.create_task(self._call_subscriber(callback, event_data))
                tasks.append(task)
            except Exception as e:
                logger.error(f"Error creating task for subscriber: {e}")

        # Wait for all subscribers to complete
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")

    async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
        """
        Call a subscriber callback with error handling.

        Args:
            callback: Subscriber callback function
            event_data: Event data to pass to callback
        """
        try:
            # Check if callback is a coroutine function
            if asyncio.iscoroutinefunction(callback):
                await callback(event_data)
            else:
                callback(event_data)
        except Exception:
            logger.exception("Error in subscriber callback")

    async def close(self) -> None:
        """Close the event bus and stop processing events."""
        self._running = False
        self._subscribers.clear()
        logger.info("Event bus closed")

    @property
    def is_running(self) -> bool:
        """Check if the event bus is running."""
        return self._running


# Global event bus instance
_event_bus: Optional[EventBus] = None


def get_event_bus() -> EventBus:
    """
    Get the global event bus instance.

    Returns:
        EventBus instance
    """
    global _event_bus
    if _event_bus is None:
        _event_bus = EventBus()
    return _event_bus


def reset_event_bus() -> None:
    """Reset the global event bus (mainly for testing)."""
    global _event_bus
    _event_bus = None
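
A minimal usage sketch of the bus (nothing beyond the module above is assumed): both sync and async callbacks are accepted, and a failing subscriber does not break the others because each call is wrapped by _call_subscriber:

# Sketch: subscribe a sync and an async callback, then publish one event.
import asyncio
from core.events import get_event_bus, reset_event_bus

def on_speaking_sync(event: dict) -> None:
    print("sync handler:", event)

async def on_speaking_async(event: dict) -> None:
    await asyncio.sleep(0)  # simulate async work
    print("async handler:", event)

async def demo() -> None:
    bus = get_event_bus()
    bus.subscribe("speaking", on_speaking_sync)
    bus.subscribe("speaking", on_speaking_async)
    await bus.publish("speaking", {"trackId": "demo", "probability": 0.93})
    await bus.close()
    reset_event_bus()

asyncio.run(demo())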

core/pipeline.py (Normal file, 131 lines)
@@ -0,0 +1,131 @@
"""Audio processing pipeline."""

import asyncio
import time
from typing import Optional
from loguru import logger

from core.transports import BaseTransport
from core.events import EventBus, get_event_bus
from processors.vad import VADProcessor, SileroVAD
from app.config import settings


class AudioPipeline:
    """
    Audio processing pipeline.

    Processes incoming audio through VAD and emits events.
    """

    def __init__(self, transport: BaseTransport, session_id: str):
        """
        Initialize audio pipeline.

        Args:
            transport: Transport instance for sending events/audio
            session_id: Session identifier for event tracking
        """
        self.transport = transport
        self.session_id = session_id
        self.event_bus = get_event_bus()

        # Initialize VAD
        self.vad_model = SileroVAD(
            model_path=settings.vad_model_path,
            sample_rate=settings.sample_rate
        )
        self.vad_processor = VADProcessor(
            vad_model=self.vad_model,
            threshold=settings.vad_threshold,
            silence_threshold_ms=settings.vad_eou_threshold_ms,
            min_speech_duration_ms=settings.vad_min_speech_duration_ms
        )

        # State
        self.is_bot_speaking = False
        self.interrupt_signal = asyncio.Event()
        self._running = True

        logger.info(f"Audio pipeline initialized for session {session_id}")

    async def process_input(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio chunk.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if not self._running:
            return

        try:
            # Process through VAD
            result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)

            if result:
                event_type, probability = result

                # Emit event through event bus
                await self.event_bus.publish(event_type, {
                    "trackId": self.session_id,
                    "probability": probability
                })

                # Send event to client
                if event_type == "speaking":
                    logger.info(f"User speaking started (session {self.session_id})")
                    await self.transport.send_event({
                        "event": "speaking",
                        "trackId": self.session_id,
                        "timestamp": self._get_timestamp_ms(),
                        "startTime": self._get_timestamp_ms()
                    })

                elif event_type == "silence":
                    logger.info(f"User speaking stopped (session {self.session_id})")
                    await self.transport.send_event({
                        "event": "silence",
                        "trackId": self.session_id,
                        "timestamp": self._get_timestamp_ms(),
                        "startTime": self._get_timestamp_ms(),
                        "duration": 0  # TODO: Calculate actual duration
                    })

                elif event_type == "eou":
                    logger.info(f"EOU detected (session {self.session_id})")
                    await self.transport.send_event({
                        "event": "eou",
                        "trackId": self.session_id,
                        "timestamp": self._get_timestamp_ms()
                    })

        except Exception:
            logger.exception("Pipeline processing error")

    async def process_text_input(self, text: str) -> None:
        """
        Process text input (chat command).

        Args:
            text: Text input
        """
        logger.info(f"Processing text input: {text[:50]}...")
        # TODO: Implement text processing (LLM integration, etc.)
        # For now, just log it

    async def interrupt(self) -> None:
        """Interrupt current audio playback."""
        if self.is_bot_speaking:
            self.interrupt_signal.set()
            logger.info(f"Pipeline interrupted for session {self.session_id}")

    async def cleanup(self) -> None:
        """Cleanup pipeline resources."""
        logger.info(f"Cleaning up pipeline for session {self.session_id}")
        self._running = False
        self.interrupt_signal.set()

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)
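
A test sketch for the pipeline's event path. It assumes the Silero model file at settings.vad_model_path exists and that processors/vad.py (not part of this diff) provides the SileroVAD/VADProcessor interface used above; the transport here is a stand-in that records events instead of sending them:

# Sketch: drive the pipeline with synthetic silence and collect emitted events.
# Assumes data/vad/silero_vad.onnx exists and processors.vad is importable.
import asyncio
from core.transports import BaseTransport
from core.pipeline import AudioPipeline

class RecordingTransport(BaseTransport):
    """Transport stand-in that records events instead of sending them."""
    def __init__(self) -> None:
        self.events = []  # collected event dicts
    async def send_event(self, event: dict) -> None:
        self.events.append(event)
    async def send_audio(self, pcm_bytes: bytes) -> None:
        pass
    async def close(self) -> None:
        pass

async def demo() -> None:
    transport = RecordingTransport()
    pipeline = AudioPipeline(transport, session_id="test")
    silence_chunk = b"\x00" * 640  # one 20 ms chunk of 16-bit silence
    for _ in range(100):           # 2 seconds of silence: expect no speaking event
        await pipeline.process_input(silence_chunk)
    print(transport.events)
    await pipeline.cleanup()

asyncio.run(demo())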

core/session.py (Normal file, 266 lines)
@@ -0,0 +1,266 @@
"""Session management for active calls."""

import uuid
import json
import time
from typing import Optional, Dict, Any
from loguru import logger

from core.transports import BaseTransport
from core.pipeline import AudioPipeline
from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand


class Session:
    """
    Manages a single call session.

    Handles command routing, audio processing, and session lifecycle.
    """

    def __init__(self, session_id: str, transport: BaseTransport):
        """
        Initialize session.

        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
        """
        self.id = session_id
        self.transport = transport
        self.pipeline = AudioPipeline(transport, session_id)

        # Session state
        self.created_at = self._get_timestamp_ms()
        self.state = "created"  # created, invited, accepted, rejected, ringing, hungup

        # Track IDs
        self.current_track_id: Optional[str] = str(uuid.uuid4())

        logger.info(f"Session {self.id} created")

    async def handle_text(self, text_data: str) -> None:
        """
        Handle incoming text data (JSON commands).

        Args:
            text_data: JSON text data
        """
        try:
            data = json.loads(text_data)
            command = parse_command(data)
            command_type = command.command

            logger.info(f"Session {self.id} received command: {command_type}")

            # Route command to appropriate handler
            if command_type == "invite":
                await self._handle_invite(data)

            elif command_type == "accept":
                await self._handle_accept(data)

            elif command_type == "reject":
                await self._handle_reject(data)

            elif command_type == "ringing":
                await self._handle_ringing(data)

            elif command_type == "tts":
                await self._handle_tts(command)

            elif command_type == "play":
                await self._handle_play(data)

            elif command_type == "interrupt":
                await self._handle_interrupt(command)

            elif command_type == "pause":
                await self._handle_pause()

            elif command_type == "resume":
                await self._handle_resume()

            elif command_type == "hangup":
                await self._handle_hangup(command)

            elif command_type == "history":
                await self._handle_history(data)

            elif command_type == "chat":
                await self._handle_chat(command)

            else:
                logger.warning(f"Session {self.id} unknown command: {command_type}")

        except json.JSONDecodeError as e:
            logger.error(f"Session {self.id} JSON decode error: {e}")
            await self._send_error("client", f"Invalid JSON: {e}")

        except ValueError as e:
            logger.error(f"Session {self.id} command parse error: {e}")
            await self._send_error("client", f"Invalid command: {e}")

        except Exception as e:
            logger.exception(f"Session {self.id} handle_text error")
            await self._send_error("server", f"Internal error: {e}")

    async def handle_audio(self, audio_bytes: bytes) -> None:
        """
        Handle incoming audio data.

        Args:
            audio_bytes: PCM audio data
        """
        try:
            await self.pipeline.process_input(audio_bytes)
        except Exception:
            logger.exception(f"Session {self.id} handle_audio error")

    async def _handle_invite(self, data: Dict[str, Any]) -> None:
        """Handle invite command."""
        self.state = "invited"
        option = data.get("option", {})

        # Send answer event
        await self.transport.send_event({
            "event": "answer",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms()
        })

        logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}")

    async def _handle_accept(self, data: Dict[str, Any]) -> None:
        """Handle accept command."""
        self.state = "accepted"
        logger.info(f"Session {self.id} accepted")

    async def _handle_reject(self, data: Dict[str, Any]) -> None:
        """Handle reject command."""
        self.state = "rejected"
        reason = data.get("reason", "Rejected")
        logger.info(f"Session {self.id} rejected: {reason}")

    async def _handle_ringing(self, data: Dict[str, Any]) -> None:
        """Handle ringing command."""
        self.state = "ringing"
        logger.info(f"Session {self.id} ringing")

    async def _handle_tts(self, command: TTSCommand) -> None:
        """Handle TTS command."""
        logger.info(f"Session {self.id} TTS: {command.text[:50]}...")

        # Send track start event
        await self.transport.send_event({
            "event": "trackStart",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "playId": command.play_id
        })

        # TODO: Implement actual TTS synthesis
        # For now, just send track end event
        await self.transport.send_event({
            "event": "trackEnd",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "duration": 1000,
            "ssrc": 0,
            "playId": command.play_id
        })

    async def _handle_play(self, data: Dict[str, Any]) -> None:
        """Handle play command."""
        url = data.get("url", "")
        logger.info(f"Session {self.id} play: {url}")

        # Send track start event
        await self.transport.send_event({
            "event": "trackStart",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "playId": url
        })

        # TODO: Implement actual audio playback
        # For now, just send track end event
        await self.transport.send_event({
            "event": "trackEnd",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "duration": 1000,
            "ssrc": 0,
            "playId": url
        })

    async def _handle_interrupt(self, command: InterruptCommand) -> None:
        """Handle interrupt command."""
        if command.graceful:
            logger.info(f"Session {self.id} graceful interrupt")
        else:
            logger.info(f"Session {self.id} immediate interrupt")
        await self.pipeline.interrupt()

    async def _handle_pause(self) -> None:
        """Handle pause command."""
        logger.info(f"Session {self.id} paused")

    async def _handle_resume(self) -> None:
        """Handle resume command."""
        logger.info(f"Session {self.id} resumed")

    async def _handle_hangup(self, command: HangupCommand) -> None:
        """Handle hangup command."""
        self.state = "hungup"
        reason = command.reason or "User requested"
        logger.info(f"Session {self.id} hung up: {reason}")

        # Send hangup event
        await self.transport.send_event({
            "event": "hangup",
            "timestamp": self._get_timestamp_ms(),
            "reason": reason,
            "initiator": command.initiator or "user"
        })

        # Close transport
        await self.transport.close()

    async def _handle_history(self, data: Dict[str, Any]) -> None:
        """Handle history command."""
        speaker = data.get("speaker", "unknown")
        text = data.get("text", "")
        logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...")

    async def _handle_chat(self, command: ChatCommand) -> None:
        """Handle chat command."""
        logger.info(f"Session {self.id} chat: {command.text[:50]}...")
        # Process text input through pipeline
        await self.pipeline.process_text_input(command.text)

    async def _send_error(self, sender: str, error_message: str) -> None:
        """
        Send error event to client.

        Args:
            sender: Component that generated the error
            error_message: Error message
        """
        await self.transport.send_event({
            "event": "error",
            "trackId": self.current_track_id,
            "timestamp": self._get_timestamp_ms(),
            "sender": sender,
            "error": error_message
        })

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    async def cleanup(self) -> None:
        """Cleanup session resources."""
        logger.info(f"Session {self.id} cleaning up")
        await self.pipeline.cleanup()
        await self.transport.close()
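
For reference, the wire-level command flow the session expects from a client, sketched as the JSON text frames of a happy path (TTS playback is still the stubbed trackStart/trackEnd pair noted in the TODOs above):

# Sketch: the JSON text frames a client sends over /ws, in order.
import json

invite = {"command": "invite", "option": {"codec": "pcm", "samplerate": 16000}}
tts = {"command": "tts", "text": "Hello, how can I help?", "play_id": "greeting-1"}
interrupt = {"command": "interrupt", "graceful": False}
hangup = {"command": "hangup", "reason": "done", "initiator": "user"}

# Each frame is sent as a text message; binary frames in between carry 16 kHz PCM audio.
for frame in (invite, tts, interrupt, hangup):
    print(json.dumps(frame))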

core/transports.py (Normal file, 207 lines)
@@ -0,0 +1,207 @@
"""Transport layer for WebSocket and WebRTC communication."""

import asyncio
import json
from abc import ABC, abstractmethod
from typing import Optional
from fastapi import WebSocket
from loguru import logger

# Try to import aiortc (optional for WebRTC functionality)
try:
    from aiortc import RTCPeerConnection
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    RTCPeerConnection = None  # Type hint placeholder


class BaseTransport(ABC):
    """
    Abstract base class for transports.

    All transports must implement send_event and send_audio methods.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event to the client.

        Args:
            event: Event data as dictionary
        """
        pass

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the transport and cleanup resources."""
        pass


class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Handles mixed text/binary frames over WebSocket connection.
    Uses asyncio.Lock to prevent frame interleaving.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        async with self.lock:
            try:
                await self.ws.send_text(json.dumps(event))
                logger.debug(f"Sent event: {event.get('event', 'unknown')}")
            except Exception as e:
                logger.error(f"Error sending event: {e}")
                self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        async with self.lock:
            try:
                await self.ws.send_bytes(pcm_bytes)
            except Exception as e:
                logger.error(f"Error sending audio: {e}")
                self._closed = True

    async def close(self) -> None:
        """Close the WebSocket connection."""
        self._closed = True
        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed


class WebRtcTransport(BaseTransport):
    """
    WebRTC transport for WebRTC audio streaming.

    Uses WebSocket for signaling and RTCPeerConnection for media.
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")

        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        try:
            await self.ws.send_text(json.dumps(event))
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track.

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")

        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """Close the WebRTC connection."""
        self._closed = True
        try:
            await self.pc.close()
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")
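
A sketch of why SocketTransport holds an asyncio.Lock: concurrent senders are serialized, so a JSON event can never interleave with an audio frame mid-write. FakeWebSocket is a hypothetical stand-in for FastAPI's WebSocket:

# Sketch: concurrent sends through SocketTransport arrive as whole frames.
import asyncio
from core.transports import SocketTransport

class FakeWebSocket:
    """Hypothetical stand-in recording frames instead of writing to a socket."""
    def __init__(self) -> None:
        self.frames = []
    async def send_text(self, text: str) -> None:
        await asyncio.sleep(0.01)  # simulate a slow socket write
        self.frames.append(("text", text))
    async def send_bytes(self, data: bytes) -> None:
        await asyncio.sleep(0.01)
        self.frames.append(("bytes", len(data)))
    async def close(self) -> None:
        pass

async def demo() -> None:
    transport = SocketTransport(FakeWebSocket())
    await asyncio.gather(
        transport.send_event({"event": "speaking", "trackId": "t1"}),
        transport.send_audio(b"\x00" * 640),
        transport.send_event({"event": "eou", "trackId": "t1"}),
    )
    print(transport.ws.frames)  # one complete frame at a time, never interleaved

asyncio.run(demo())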

data/vad/silero_vad.onnx (Normal file, binary; not shown)

examples/mic_client.py (Normal file, 137 lines)
@@ -0,0 +1,137 @@
"""
Microphone WebSocket Client

Connects to the backend WebSocket endpoint and streams audio from the microphone.
Used to test VAD and EOU detection.

Dependencies:
    pip install pyaudio aiohttp
"""

import asyncio
import aiohttp
import pyaudio
import json
import sys
from datetime import datetime

# Configuration
SERVER_URL = "ws://localhost:8000/ws"
SAMPLE_RATE = 16000
CHANNELS = 1
CHUNK_DURATION_MS = 20
CHUNK_SIZE = int(SAMPLE_RATE * (CHUNK_DURATION_MS / 1000.0))  # 320 samples for 20ms
FORMAT = pyaudio.paInt16


async def send_audio_loop(ws, stream):
    """Read from microphone and send to WebSocket."""
    print("🎙️ Microphone streaming started...")
    try:
        while True:
            # PyAudio's read() is blocking, so run it in the default executor
            # to keep the event loop responsive between 20 ms chunks.
            data = await asyncio.get_event_loop().run_in_executor(
                None, lambda: stream.read(CHUNK_SIZE, exception_on_overflow=False)
            )

            await ws.send_bytes(data)
            # No sleep needed here, as the microphone dictates the timing

    except Exception as e:
        print(f"❌ Error in send loop: {e}")


async def receive_loop(ws):
    """Listen for VAD/EOU events."""
    print("👂 Listening for server events...")
    async for msg in ws:
        timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]

        if msg.type == aiohttp.WSMsgType.TEXT:
            try:
                data = json.loads(msg.data)
                event = data.get('event')

                # Highlight VAD/EOU events
                if event == 'speaking':
                    print(f"[{timestamp}] 🗣️ SPEAKING STARTED")
                elif event == 'silence':
                    print(f"[{timestamp}] 🤫 SILENCE DETECTED")
                elif event == 'eou':
                    print(f"[{timestamp}] ✅ END OF UTTERANCE (EOU)")
                elif event == 'error':
                    print(f"[{timestamp}] ❌ ERROR: {data.get('error')}")
                else:
                    print(f"[{timestamp}] 📩 {event}: {str(data)[:100]}")

            except json.JSONDecodeError:
                print(f"[{timestamp}] 📄 Text: {msg.data}")

        elif msg.type == aiohttp.WSMsgType.CLOSED:
            print("❌ Connection closed")
            break
        elif msg.type == aiohttp.WSMsgType.ERROR:
            print("❌ Connection error")
            break


async def main():
    p = pyaudio.PyAudio()

    # Check for input devices
    info = p.get_host_api_info_by_index(0)
    numdevices = info.get('deviceCount')
    if numdevices == 0:
        print("❌ No audio input devices found")
        return

    # Open microphone stream
    try:
        stream = p.open(format=FORMAT,
                        channels=CHANNELS,
                        rate=SAMPLE_RATE,
                        input=True,
                        frames_per_buffer=CHUNK_SIZE)
    except Exception as e:
        print(f"❌ Failed to open microphone: {e}")
        return

    session = aiohttp.ClientSession()

    try:
        print(f"🔌 Connecting to {SERVER_URL}...")
        async with session.ws_connect(SERVER_URL) as ws:
            print("✅ Connected!")

            # 1. Send Invite
            invite_msg = {
                "command": "invite",
                "option": {
                    "codec": "pcm",
                    "samplerate": SAMPLE_RATE
                }
            }
            await ws.send_json(invite_msg)
            print("📤 Sent Invite")

            # 2. Run loops
            await asyncio.gather(
                receive_loop(ws),
                send_audio_loop(ws, stream)
            )

    except aiohttp.ClientConnectorError:
        print(f"❌ Failed to connect to {SERVER_URL}. Is the server running?")
    except KeyboardInterrupt:
        print("\n👋 Stopping...")
    finally:
        stream.stop_stream()
        stream.close()
        p.terminate()
        await session.close()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass

main.py (deleted, 137 lines)
@@ -1,137 +0,0 @@
"""
Step 1: Minimal WebSocket Echo Server

This is the simplest possible WebSocket audio server.
It accepts connections and echoes back events.

What you'll learn:
- How to create a FastAPI WebSocket endpoint
- How to handle mixed text/binary frames
- Basic event sending

Test with:
    python main.py
    python test_client.py
"""

import asyncio
import json
import uuid
from fastapi import FastAPI, WebSocket
from loguru import logger

# Configure logging
logger.remove()
logger.add(lambda msg: print(msg, end=""), level="INFO", format="<green>{time:HH:mm:ss}</green> | {level} | {message}")

# Create FastAPI app
app = FastAPI(title="Voice Gateway - Step 1")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "step": "1_minimal_echo"}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for audio streaming.

    This is a minimal echo server that:
    1. Accepts WebSocket connections
    2. Sends a welcome event
    3. Receives text commands and binary audio
    4. Echoes speaking events back
    """
    await websocket.accept()

    # Generate unique session ID
    session_id = str(uuid.uuid4())
    logger.info(f"[{session_id}] Client connected")

    try:
        # Send welcome event (answer)
        await websocket.send_json({
            "event": "answer",
            "trackId": session_id,
            "timestamp": _get_timestamp_ms()
        })
        logger.info(f"[{session_id}] Sent answer event")

        # Message receive loop
        while True:
            message = await websocket.receive()

            # Handle binary audio data
            if "bytes" in message:
                audio_bytes = message["bytes"]
                logger.info(f"[{session_id}] Received audio: {len(audio_bytes)} bytes")

                # Send speaking event (echo back)
                await websocket.send_json({
                    "event": "speaking",
                    "trackId": session_id,
                    "timestamp": _get_timestamp_ms(),
                    "startTime": _get_timestamp_ms()
                })

            # Handle text commands
            elif "text" in message:
                text_data = message["text"]
                logger.info(f"[{session_id}] Received text: {text_data[:100]}...")

                try:
                    data = json.loads(text_data)
                    command = data.get("command", "unknown")
                    logger.info(f"[{session_id}] Command: {command}")

                    # Handle basic commands
                    if command == "invite":
                        await websocket.send_json({
                            "event": "answer",
                            "trackId": session_id,
                            "timestamp": _get_timestamp_ms()
                        })
                        logger.info(f"[{session_id}] Responded to invite")

                    elif command == "hangup":
                        logger.info(f"[{session_id}] Hangup requested")
                        break

                    elif command == "ping":
                        await websocket.send_json({
                            "event": "pong",
                            "timestamp": _get_timestamp_ms()
                        })

                except json.JSONDecodeError as e:
                    logger.error(f"[{session_id}] Invalid JSON: {e}")

    except Exception as e:
        logger.error(f"[{session_id}] Error: {e}")

    finally:
        logger.info(f"[{session_id}] Connection closed")


def _get_timestamp_ms() -> int:
    """Get current timestamp in milliseconds."""
    import time
    return int(time.time() * 1000)


if __name__ == "__main__":
    import uvicorn

    logger.info("🚀 Starting Step 1: Minimal WebSocket Echo Server")
    logger.info("📡 Server: ws://localhost:8000/ws")
    logger.info("🩺 Health: http://localhost:8000/health")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info"
    )

models/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
"""Data Models Package"""

models/commands.py (Normal file, 143 lines)
@@ -0,0 +1,143 @@
"""Protocol command models matching the original active-call API."""

from typing import Optional, Dict, Any
from pydantic import BaseModel, Field


class InviteCommand(BaseModel):
    """Invite command to initiate a call."""

    command: str = Field(default="invite", description="Command type")
    option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options")


class AcceptCommand(BaseModel):
    """Accept command to accept an incoming call."""

    command: str = Field(default="accept", description="Command type")
    option: Optional[Dict[str, Any]] = Field(default=None, description="Call configuration options")


class RejectCommand(BaseModel):
    """Reject command to reject an incoming call."""

    command: str = Field(default="reject", description="Command type")
    reason: str = Field(default="", description="Reason for rejection")
    code: Optional[int] = Field(default=None, description="SIP response code")


class RingingCommand(BaseModel):
    """Ringing command to send ringing response."""

    command: str = Field(default="ringing", description="Command type")
    recorder: Optional[Dict[str, Any]] = Field(default=None, description="Call recording configuration")
    early_media: bool = Field(default=False, description="Enable early media")
    ringtone: Optional[str] = Field(default=None, description="Custom ringtone URL")


class TTSCommand(BaseModel):
    """TTS command to convert text to speech."""

    command: str = Field(default="tts", description="Command type")
    text: str = Field(..., description="Text to synthesize")
    speaker: Optional[str] = Field(default=None, description="Speaker voice name")
    play_id: Optional[str] = Field(default=None, description="Unique identifier for this TTS session")
    auto_hangup: bool = Field(default=False, description="Auto hangup after TTS completion")
    streaming: bool = Field(default=False, description="Streaming text input")
    end_of_stream: bool = Field(default=False, description="End of streaming input")
    wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)")
    option: Optional[Dict[str, Any]] = Field(default=None, description="TTS provider specific options")


class PlayCommand(BaseModel):
    """Play command to play audio from URL."""

    command: str = Field(default="play", description="Command type")
    url: str = Field(..., description="URL of audio file to play")
    auto_hangup: bool = Field(default=False, description="Auto hangup after playback")
    wait_input_timeout: Optional[int] = Field(default=None, description="Max time to wait for input (seconds)")


class InterruptCommand(BaseModel):
    """Interrupt command to interrupt current playback."""

    command: str = Field(default="interrupt", description="Command type")
    graceful: bool = Field(default=False, description="Wait for current TTS to complete")


class PauseCommand(BaseModel):
    """Pause command to pause current playback."""

    command: str = Field(default="pause", description="Command type")


class ResumeCommand(BaseModel):
    """Resume command to resume paused playback."""

    command: str = Field(default="resume", description="Command type")


class HangupCommand(BaseModel):
    """Hangup command to end the call."""

    command: str = Field(default="hangup", description="Command type")
    reason: Optional[str] = Field(default=None, description="Reason for hangup")
    initiator: Optional[str] = Field(default=None, description="Who initiated the hangup")


class HistoryCommand(BaseModel):
    """History command to add conversation history."""

    command: str = Field(default="history", description="Command type")
    speaker: str = Field(..., description="Speaker identifier")
    text: str = Field(..., description="Conversation text")


class ChatCommand(BaseModel):
    """Chat command for text-based conversation."""

    command: str = Field(default="chat", description="Command type")
    text: str = Field(..., description="Chat text message")


# Command type mapping
COMMAND_TYPES = {
    "invite": InviteCommand,
    "accept": AcceptCommand,
    "reject": RejectCommand,
    "ringing": RingingCommand,
    "tts": TTSCommand,
    "play": PlayCommand,
    "interrupt": InterruptCommand,
    "pause": PauseCommand,
    "resume": ResumeCommand,
    "hangup": HangupCommand,
    "history": HistoryCommand,
    "chat": ChatCommand,
}


def parse_command(data: Dict[str, Any]) -> BaseModel:
    """
    Parse a command from JSON data.

    Args:
        data: JSON data as dictionary

    Returns:
        Parsed command model

    Raises:
        ValueError: If command type is unknown
    """
    command_type = data.get("command")

    if not command_type:
        raise ValueError("Missing 'command' field")

    command_class = COMMAND_TYPES.get(command_type)

    if not command_class:
        raise ValueError(f"Unknown command type: {command_type}")

    return command_class(**data)
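
A usage sketch for parse_command (only the module above is assumed): a known command is validated into its typed model, while an unknown one raises ValueError, which Session.handle_text reports back as an error event:

# Sketch: round-tripping a raw JSON command through parse_command.
import json
from models.commands import parse_command, TTSCommand

raw = '{"command": "tts", "text": "Hello there", "play_id": "p1"}'
cmd = parse_command(json.loads(raw))
assert isinstance(cmd, TTSCommand)
print(cmd.text, cmd.play_id, cmd.auto_hangup)  # Hello there p1 False

try:
    parse_command({"command": "reboot"})
except ValueError as e:
    print(e)  # Unknown command type: reboot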
126
models/config.py
Normal file
126
models/config.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Configuration models for call options."""
|
||||
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VADOption(BaseModel):
|
||||
"""Voice Activity Detection configuration."""
|
||||
|
||||
type: str = Field(default="silero", description="VAD algorithm type (silero, webrtc)")
|
||||
samplerate: int = Field(default=16000, description="Audio sample rate for VAD")
|
||||
speech_padding: int = Field(default=250, description="Speech padding in milliseconds")
|
||||
silence_padding: int = Field(default=100, description="Silence padding in milliseconds")
|
||||
ratio: float = Field(default=0.5, description="Voice detection ratio threshold")
|
||||
voice_threshold: float = Field(default=0.5, description="Voice energy threshold")
|
||||
max_buffer_duration_secs: int = Field(default=50, description="Maximum buffer duration in seconds")
|
||||
silence_timeout: Optional[int] = Field(default=None, description="Silence timeout in milliseconds")
|
||||
endpoint: Optional[str] = Field(default=None, description="Custom VAD service endpoint")
|
||||
secret_key: Optional[str] = Field(default=None, description="VAD service secret key")
|
||||
secret_id: Optional[str] = Field(default=None, description="VAD service secret ID")
|
||||
|
||||
|
||||
class ASROption(BaseModel):
|
||||
"""Automatic Speech Recognition configuration."""
|
||||
|
||||
provider: str = Field(..., description="ASR provider (tencent, aliyun, openai, etc.)")
|
||||
language: Optional[str] = Field(default=None, description="Language code (zh-CN, en-US)")
|
||||
app_id: Optional[str] = Field(default=None, description="Application ID")
|
||||
secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
|
||||
secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
|
||||
model_type: Optional[str] = Field(default=None, description="ASR model type (16k_zh, 8k_en)")
|
||||
buffer_size: Optional[int] = Field(default=None, description="Audio buffer size in bytes")
|
||||
samplerate: Optional[int] = Field(default=None, description="Audio sample rate")
|
||||
endpoint: Optional[str] = Field(default=None, description="Custom ASR service endpoint")
|
||||
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
    start_when_answer: bool = Field(default=False, description="Start ASR when call is answered")


class TTSOption(BaseModel):
    """Text-to-Speech configuration."""

    samplerate: Optional[int] = Field(default=None, description="TTS output sample rate")
    provider: str = Field(default="msedge", description="TTS provider (tencent, aliyun, deepgram, msedge)")
    speed: float = Field(default=1.0, description="Speech speed multiplier")
    app_id: Optional[str] = Field(default=None, description="Application ID")
    secret_id: Optional[str] = Field(default=None, description="Secret ID for authentication")
    secret_key: Optional[str] = Field(default=None, description="Secret key for authentication")
    volume: Optional[int] = Field(default=None, description="Speech volume level (1-10)")
    speaker: Optional[str] = Field(default=None, description="Voice speaker name")
    codec: Optional[str] = Field(default=None, description="Audio codec")
    subtitle: bool = Field(default=False, description="Enable subtitle generation")
    emotion: Optional[str] = Field(default=None, description="Speech emotion")
    endpoint: Optional[str] = Field(default=None, description="Custom TTS service endpoint")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional parameters")
    max_concurrent_tasks: Optional[int] = Field(default=None, description="Max concurrent tasks")


class RecorderOption(BaseModel):
    """Call recording configuration."""

    recorder_file: str = Field(..., description="Path to recording file")
    samplerate: int = Field(default=16000, description="Recording sample rate")
    ptime: int = Field(default=200, description="Packet time in milliseconds")


class MediaPassOption(BaseModel):
    """Media pass-through configuration for external audio processing."""

    url: str = Field(..., description="WebSocket URL for media streaming")
    input_sample_rate: int = Field(default=16000, description="Sample rate of audio received from WebSocket")
    output_sample_rate: int = Field(default=16000, description="Sample rate of audio sent to WebSocket")
    packet_size: int = Field(default=2560, description="Packet size in bytes")
    ptime: Optional[int] = Field(default=None, description="Buffered playback period in milliseconds")


class SipOption(BaseModel):
    """SIP protocol configuration."""

    username: Optional[str] = Field(default=None, description="SIP username")
    password: Optional[str] = Field(default=None, description="SIP password")
    realm: Optional[str] = Field(default=None, description="SIP realm/domain")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional SIP headers")


class HandlerRule(BaseModel):
    """Handler routing rule."""

    caller: Optional[str] = Field(default=None, description="Caller pattern (regex)")
    callee: Optional[str] = Field(default=None, description="Callee pattern (regex)")
    playbook: Optional[str] = Field(default=None, description="Playbook file path")
    webhook: Optional[str] = Field(default=None, description="Webhook URL")


class CallOption(BaseModel):
    """Comprehensive call configuration options."""

    # Basic options
    denoise: bool = Field(default=False, description="Enable noise reduction")
    offer: Optional[str] = Field(default=None, description="SDP offer string")
    callee: Optional[str] = Field(default=None, description="Callee SIP URI or phone number")
    caller: Optional[str] = Field(default=None, description="Caller SIP URI or phone number")

    # Audio codec
    codec: str = Field(default="pcm", description="Audio codec (pcm, pcma, pcmu, g722)")

    # Component configurations
    recorder: Optional[RecorderOption] = Field(default=None, description="Call recording config")
    asr: Optional[ASROption] = Field(default=None, description="ASR configuration")
    vad: Optional[VADOption] = Field(default=None, description="VAD configuration")
    tts: Optional[TTSOption] = Field(default=None, description="TTS configuration")
    media_pass: Optional[MediaPassOption] = Field(default=None, description="Media pass-through config")
    sip: Optional[SipOption] = Field(default=None, description="SIP configuration")

    # Timeouts and networking
    handshake_timeout: Optional[int] = Field(default=None, description="Handshake timeout in seconds")
    enable_ipv6: bool = Field(default=False, description="Enable IPv6 support")
    inactivity_timeout: Optional[int] = Field(default=None, description="Inactivity timeout in seconds")

    # EOU configuration
    eou: Optional[Dict[str, Any]] = Field(default=None, description="End of utterance detection config")

    # Extra parameters
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional custom parameters")

    class Config:
        populate_by_name = True
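As a quick orientation, here is a hedged construction sketch for these models. Field names are taken verbatim from the classes above; the module path models/options.py and the concrete values are assumptions for illustration only:

# Illustrative only: module path and values are assumed, not confirmed by this commit.
from models.options import CallOption, RecorderOption

option = CallOption(
    codec="pcm",
    caller="sip:alice@example.com",
    callee="sip:bob@example.com",
    recorder=RecorderOption(recorder_file="recordings/call.wav"),
    eou={"silence_threshold_ms": 400, "min_speech_duration_ms": 250},
)
print(option.model_dump(exclude_none=True))  # serialize, dropping unset fields (Pydantic v2)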
223
models/events.py
Normal file
223
models/events.py
Normal file
@@ -0,0 +1,223 @@
"""Protocol event models matching the original active-call API."""

from typing import Optional, Dict, Any
from pydantic import BaseModel, Field
from datetime import datetime


def current_timestamp_ms() -> int:
    """Get current timestamp in milliseconds."""
    return int(datetime.now().timestamp() * 1000)


# Base Event Model
class BaseEvent(BaseModel):
    """Base event model."""

    event: str = Field(..., description="Event type")
    track_id: str = Field(..., description="Unique track identifier")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")


# Lifecycle Events
class IncomingEvent(BaseEvent):
    """Incoming call event (SIP only)."""

    event: str = Field(default="incoming", description="Event type")
    caller: Optional[str] = Field(default=None, description="Caller's SIP URI")
    callee: Optional[str] = Field(default=None, description="Callee's SIP URI")
    sdp: Optional[str] = Field(default=None, description="SDP offer from caller")


class AnswerEvent(BaseEvent):
    """Call answered event."""

    event: str = Field(default="answer", description="Event type")
    sdp: Optional[str] = Field(default=None, description="SDP answer from server")


class RejectEvent(BaseEvent):
    """Call rejected event."""

    event: str = Field(default="reject", description="Event type")
    reason: Optional[str] = Field(default=None, description="Rejection reason")
    code: Optional[int] = Field(default=None, description="SIP response code")


class RingingEvent(BaseEvent):
    """Call ringing event."""

    event: str = Field(default="ringing", description="Event type")
    early_media: bool = Field(default=False, description="Early media available")


class HangupEvent(BaseModel):
    """Call hangup event."""

    event: str = Field(default="hangup", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    reason: Optional[str] = Field(default=None, description="Hangup reason")
    initiator: Optional[str] = Field(default=None, description="Who initiated hangup")
    start_time: Optional[str] = Field(default=None, description="Call start time (ISO 8601)")
    hangup_time: Optional[str] = Field(default=None, description="Hangup time (ISO 8601)")
    answer_time: Optional[str] = Field(default=None, description="Answer time (ISO 8601)")
    ringing_time: Optional[str] = Field(default=None, description="Ringing time (ISO 8601)")
    from_: Optional[Dict[str, Any]] = Field(default=None, alias="from", description="Caller info")
    to: Optional[Dict[str, Any]] = Field(default=None, description="Callee info")
    extra: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")

    class Config:
        populate_by_name = True


# VAD Events
class SpeakingEvent(BaseEvent):
    """Speech detected event."""

    event: str = Field(default="speaking", description="Event type")
    start_time: int = Field(default_factory=current_timestamp_ms, description="Speech start time")


class SilenceEvent(BaseEvent):
    """Silence detected event."""

    event: str = Field(default="silence", description="Event type")
    start_time: int = Field(default_factory=current_timestamp_ms, description="Silence start time")
    duration: int = Field(default=0, description="Silence duration in milliseconds")


# AI/ASR Events
class AsrFinalEvent(BaseEvent):
    """ASR final transcription event."""

    event: str = Field(default="asrFinal", description="Event type")
    index: int = Field(..., description="ASR result sequence number")
    start_time: Optional[int] = Field(default=None, description="Speech start time")
    end_time: Optional[int] = Field(default=None, description="Speech end time")
    text: str = Field(..., description="Transcribed text")


class AsrDeltaEvent(BaseEvent):
    """ASR partial transcription event (streaming)."""

    event: str = Field(default="asrDelta", description="Event type")
    index: int = Field(..., description="ASR result sequence number")
    start_time: Optional[int] = Field(default=None, description="Speech start time")
    end_time: Optional[int] = Field(default=None, description="Speech end time")
    text: str = Field(..., description="Partial transcribed text")


class EouEvent(BaseEvent):
    """End of utterance detection event."""

    event: str = Field(default="eou", description="Event type")
    completed: bool = Field(default=True, description="Whether utterance was completed")


# Audio Track Events
class TrackStartEvent(BaseEvent):
    """Audio track start event."""

    event: str = Field(default="trackStart", description="Event type")
    play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")


class TrackEndEvent(BaseEvent):
    """Audio track end event."""

    event: str = Field(default="trackEnd", description="Event type")
    duration: int = Field(..., description="Track duration in milliseconds")
    ssrc: int = Field(..., description="RTP SSRC identifier")
    play_id: Optional[str] = Field(default=None, description="Play ID from TTS/Play command")


class InterruptionEvent(BaseEvent):
    """Playback interruption event."""

    event: str = Field(default="interruption", description="Event type")
    play_id: Optional[str] = Field(default=None, description="Play ID that was interrupted")
    subtitle: Optional[str] = Field(default=None, description="TTS text being played")
    position: Optional[int] = Field(default=None, description="Word index position")
    total_duration: Optional[int] = Field(default=None, description="Total TTS duration")
    current: Optional[int] = Field(default=None, description="Elapsed time when interrupted")


# System Events
class ErrorEvent(BaseEvent):
    """Error event."""

    event: str = Field(default="error", description="Event type")
    sender: str = Field(..., description="Component that generated the error")
    error: str = Field(..., description="Error message")
    code: Optional[int] = Field(default=None, description="Error code")


class MetricsEvent(BaseModel):
    """Performance metrics event."""

    event: str = Field(default="metrics", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    key: str = Field(..., description="Metric key")
    duration: int = Field(..., description="Duration in milliseconds")
    data: Optional[Dict[str, Any]] = Field(default=None, description="Additional metric data")


class AddHistoryEvent(BaseModel):
    """Conversation history entry added event."""

    event: str = Field(default="addHistory", description="Event type")
    timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp")
    sender: Optional[str] = Field(default=None, description="Component that added history")
    speaker: str = Field(..., description="Speaker identifier")
    text: str = Field(..., description="Conversation text")


class DTMFEvent(BaseEvent):
    """DTMF tone detected event."""

    event: str = Field(default="dtmf", description="Event type")
    digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)")


# Event type mapping
EVENT_TYPES = {
    "incoming": IncomingEvent,
    "answer": AnswerEvent,
    "reject": RejectEvent,
    "ringing": RingingEvent,
    "hangup": HangupEvent,
    "speaking": SpeakingEvent,
    "silence": SilenceEvent,
    "asrFinal": AsrFinalEvent,
    "asrDelta": AsrDeltaEvent,
    "eou": EouEvent,
    "trackStart": TrackStartEvent,
    "trackEnd": TrackEndEvent,
    "interruption": InterruptionEvent,
    "error": ErrorEvent,
    "metrics": MetricsEvent,
    "addHistory": AddHistoryEvent,
    "dtmf": DTMFEvent,
}


def create_event(event_type: str, **kwargs) -> BaseModel:
    """
    Create an event model.

    Args:
        event_type: Type of event to create
        **kwargs: Event fields

    Returns:
        Event model instance

    Raises:
        ValueError: If event type is unknown
    """
    event_class = EVENT_TYPES.get(event_type)

    if not event_class:
        raise ValueError(f"Unknown event type: {event_type}")

    return event_class(event=event_type, **kwargs)
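For reference, a minimal sketch of the factory in use. The module path models.events is inferred from the file header above, and Pydantic v2 is assumed, as elsewhere in this commit:

# Sketch only: models.events module path assumed from the file header.
from models.events import create_event

evt = create_event("speaking", track_id="track-1")
print(evt.model_dump())  # event, track_id, timestamp and start_time are filled in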
6
processors/__init__.py
Normal file
6
processors/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""Audio Processors Package"""

from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor

__all__ = ["EouDetector", "SileroVAD", "VADProcessor"]
80
processors/eou.py
Normal file
80
processors/eou.py
Normal file
@@ -0,0 +1,80 @@
"""End-of-Utterance Detection."""

import time
from typing import Optional


class EouDetector:
    """
    End-of-utterance detector. Fires EOU only after continuous silence for
    silence_threshold_ms. Short pauses between sentences do not trigger EOU
    because speech resets the silence timer (one EOU per turn).
    """

    def __init__(self, silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize EOU detector.

        Args:
            silence_threshold_ms: How long silence must last to trigger EOU (default 1000ms)
            min_speech_duration_ms: Minimum speech duration to consider valid (default 250ms)
        """
        self.threshold = silence_threshold_ms / 1000.0
        self.min_speech = min_speech_duration_ms / 1000.0
        self._silence_threshold_ms = silence_threshold_ms
        self._min_speech_duration_ms = min_speech_duration_ms

        # State
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time: Optional[float] = None
        self.triggered = False

    def process(self, vad_status: str) -> bool:
        """
        Process VAD status and detect end of utterance.

        Input: "Speech" or "Silence" (from VAD).
        Output: True if EOU detected, False otherwise.

        Short breaks between phrases reset the silence clock when speech
        resumes, so only one EOU is emitted after the user truly stops.
        """
        now = time.time()

        if vad_status == "Speech":
            if not self.is_speaking:
                self.is_speaking = True
                self.speech_start_time = now
                self.triggered = False
            # Any speech resets the silence timer; short pause + more speech = one utterance
            self.silence_start_time = None
            return False

        if vad_status == "Silence":
            if not self.is_speaking:
                return False
            if self.silence_start_time is None:
                self.silence_start_time = now

            speech_duration = self.silence_start_time - self.speech_start_time
            if speech_duration < self.min_speech:
                self.is_speaking = False
                self.silence_start_time = None
                return False

            silence_duration = now - self.silence_start_time
            if silence_duration >= self.threshold and not self.triggered:
                self.triggered = True
                self.is_speaking = False
                self.silence_start_time = None
                return True

        return False

    def reset(self) -> None:
        """Reset EOU detector state."""
        self.is_speaking = False
        self.speech_start_time = 0.0
        self.silence_start_time = None
        self.triggered = False
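Timing here is wall-clock, so a driving loop has to pace its chunks. The following sketch, with hard-coded labels standing in for real VAD output, triggers exactly one EOU:

# Sketch: hard-coded labels simulate ~400 ms of speech then ~600 ms of silence.
import time
from processors.eou import EouDetector

detector = EouDetector(silence_threshold_ms=400, min_speech_duration_ms=250)
for label in ["Speech"] * 20 + ["Silence"] * 30:
    if detector.process(label):
        print("EOU detected")  # fires once, after ~400 ms of continuous silence
    time.sleep(0.02)  # pace like the real 20 ms audio chunks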
168
processors/tracks.py
Normal file
168
processors/tracks.py
Normal file
@@ -0,0 +1,168 @@
"""Audio track processing for WebRTC."""

import fractions

import numpy as np
from loguru import logger

# Try to import aiortc (optional for WebRTC functionality)
try:
    from aiortc import AudioStreamTrack
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    AudioStreamTrack = object  # Dummy class for type hints

# Try to import PyAV (optional for audio resampling)
try:
    from av import AudioFrame, AudioResampler
    AV_AVAILABLE = True
except ImportError:
    AV_AVAILABLE = False

    # Create dummy classes for type hints
    class AudioFrame:
        pass

    class AudioResampler:
        pass


class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Audio track that resamples input to 16kHz mono PCM.

    Wraps an existing MediaStreamTrack and converts its output
    to 16kHz mono 16-bit PCM format for the pipeline.
    """

    def __init__(self, track, target_sample_rate: int = 16000):
        """
        Initialize resampled track.

        Args:
            track: Source MediaStreamTrack
            target_sample_rate: Target sample rate (default: 16000)
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used")

        super().__init__()
        self.track = track
        self.target_sample_rate = target_sample_rate

        if AV_AVAILABLE:
            self.resampler = AudioResampler(
                format="s16",
                layout="mono",
                rate=target_sample_rate
            )
        else:
            logger.warning("PyAV not available, audio resampling disabled")
            self.resampler = None

        self._closed = False

    async def recv(self):
        """
        Receive and resample next audio frame.

        Returns:
            Resampled AudioFrame at 16kHz mono
        """
        if self._closed:
            raise RuntimeError("Track is closed")

        # Get frame from source track
        frame = await self.track.recv()

        # Resample the frame if AV is available
        if AV_AVAILABLE and self.resampler:
            resampled = self.resampler.resample(frame)
            # PyAV >= 9 returns a list of frames; older versions return a single frame
            if isinstance(resampled, list):
                if not resampled:
                    return frame
                resampled = resampled[0]
            # Ensure the frame reports the correct sample rate
            resampled.sample_rate = self.target_sample_rate
            return resampled
        else:
            # Return frame as-is if AV is not available
            return frame

    async def stop(self) -> None:
        """Stop the track and cleanup resources."""
        self._closed = True
        if hasattr(self, 'resampler') and self.resampler:
            del self.resampler
        logger.debug("Resampled track stopped")


class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Synthetic audio track that generates a sine wave.

    Useful for testing without requiring real audio input.
    """

    def __init__(self, sample_rate: int = 16000, frequency: int = 440):
        """
        Initialize sine wave track.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            frequency: Sine wave frequency in Hz (default: 440)
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - SineWaveTrack cannot be used")

        super().__init__()
        self.sample_rate = sample_rate
        self.frequency = frequency
        self.counter = 0
        self._stopped = False

    async def recv(self):
        """
        Generate next audio frame with sine wave.

        Returns:
            AudioFrame with sine wave data
        """
        if self._stopped:
            raise RuntimeError("Track is stopped")

        # Generate 20ms of audio
        samples = int(self.sample_rate * 0.02)
        pts = self.counter
        time_base = fractions.Fraction(1, self.sample_rate)

        # Sample timestamps for this chunk
        t = np.linspace(
            self.counter / self.sample_rate,
            (self.counter + samples) / self.sample_rate,
            samples,
            endpoint=False
        )

        # Generate sine wave (Int16 PCM)
        data = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16)

        # Update counter
        self.counter += samples

        # Create AudioFrame if AV is available
        if AV_AVAILABLE:
            frame = AudioFrame.from_ndarray(data.reshape(1, -1), format='s16', layout='mono')
            frame.pts = pts
            frame.time_base = time_base
            frame.sample_rate = self.sample_rate
            return frame
        else:
            # Return simple data structure if AV is not available
            return {
                'data': data,
                'sample_rate': self.sample_rate,
                'pts': pts,
                'time_base': time_base
            }

    def stop(self) -> None:
        """Stop the track."""
        self._stopped = True
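A quick smoke test for the synthetic track, assuming aiortc and PyAV are installed (otherwise the constructor raises):

# Sketch: pulls a single frame from the test track.
import asyncio
from processors.tracks import SineWaveTrack

async def main():
    track = SineWaveTrack(sample_rate=16000, frequency=440)
    frame = await track.recv()  # one 20 ms frame = 320 samples at 16 kHz
    print(frame.sample_rate, frame.samples)
    track.stop()

asyncio.run(main())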
213
processors/vad.py
Normal file
213
processors/vad.py
Normal file
@@ -0,0 +1,213 @@
"""Voice Activity Detection using Silero VAD."""

import os
import time
from typing import Tuple, Optional

import numpy as np
from loguru import logger

from processors.eou import EouDetector

# Try to import onnxruntime (optional for VAD functionality)
try:
    import onnxruntime as ort
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False
    ort = None
    logger.warning("onnxruntime not available - VAD will be disabled")


class SileroVAD:
    """
    Voice Activity Detection using Silero VAD model.

    Detects speech in audio chunks using the Silero VAD ONNX model.
    Returns "Speech" or "Silence" for each audio chunk.
    """

    def __init__(self, model_path: str = "data/vad/silero_vad.onnx", sample_rate: int = 16000):
        """
        Initialize Silero VAD.

        Args:
            model_path: Path to Silero VAD ONNX model
            sample_rate: Audio sample rate (must be 16kHz for Silero VAD)
        """
        self.sample_rate = sample_rate
        self.model_path = model_path
        self.session = None

        # Internal state for VAD (initialized up front so reset() and
        # process_audio() stay safe even when the model never loads)
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.min_chunk_size = 512
        self.last_label = "Silence"
        self.last_probability = 0.0

        # Check if model exists
        if not os.path.exists(model_path):
            logger.warning(f"VAD model not found at {model_path}. VAD will be disabled.")
            return

        # Check if onnxruntime is available
        if not ONNX_AVAILABLE:
            logger.warning("onnxruntime not available - VAD will be disabled")
            return

        # Load ONNX model
        try:
            self.session = ort.InferenceSession(model_path)
            logger.info(f"Loaded Silero VAD model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load VAD model: {e}")
            self.session = None

    def _reset_state(self):
        # Silero VAD V4+ expects state shape [2, 1, 128]
        self._state = np.zeros((2, 1, 128), dtype=np.float32)
        self._sr = np.array([self.sample_rate], dtype=np.int64)

    def process_audio(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Tuple[str, float]:
        """
        Process audio chunk and detect speech.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
            chunk_size_ms: Chunk duration in milliseconds (ignored for buffering logic)

        Returns:
            Tuple of (label, probability) where label is "Speech" or "Silence"
        """
        if self.session is None or not ONNX_AVAILABLE:
            # If model not loaded or onnxruntime not available, assume speech
            return "Speech", 1.0

        # Convert bytes to numpy array of int16
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)

        # Normalize to float32 (-1.0 to 1.0)
        audio_float = audio_int16.astype(np.float32) / 32768.0

        # Add to buffer
        self.buffer = np.concatenate((self.buffer, audio_float))

        # Process all complete chunks in the buffer
        while len(self.buffer) >= self.min_chunk_size:
            # Slice exactly 512 samples
            chunk = self.buffer[:self.min_chunk_size]
            self.buffer = self.buffer[self.min_chunk_size:]

            # Prepare inputs
            # Input tensor shape: [batch, samples] -> [1, 512]
            input_tensor = chunk.reshape(1, -1)

            # Run inference
            try:
                ort_inputs = {
                    'input': input_tensor,
                    'state': self._state,
                    'sr': self._sr
                }

                # Outputs: probability, state
                out, self._state = self.session.run(None, ort_inputs)

                # Get probability
                self.last_probability = float(out[0][0])
                self.last_label = "Speech" if self.last_probability >= 0.5 else "Silence"

            except Exception as e:
                logger.error(f"VAD inference error: {e}")
                # Try to determine if it's an input name issue
                try:
                    inputs = [x.name for x in self.session.get_inputs()]
                    logger.error(f"Model expects inputs: {inputs}")
                except Exception:
                    pass
                return "Speech", 1.0

        return self.last_label, self.last_probability

    def reset(self) -> None:
        """Reset VAD internal state."""
        self._reset_state()
        self.buffer = np.array([], dtype=np.float32)
        self.last_label = "Silence"
        self.last_probability = 0.0


class VADProcessor:
    """
    High-level VAD processor with state management.

    Tracks speech/silence state and emits events on transitions.
    """

    def __init__(self, vad_model: SileroVAD, threshold: float = 0.5,
                 silence_threshold_ms: int = 1000, min_speech_duration_ms: int = 250):
        """
        Initialize VAD processor.

        Args:
            vad_model: Silero VAD model instance
            threshold: Speech detection threshold
            silence_threshold_ms: EOU silence threshold in ms (longer = one EOU across short pauses)
            min_speech_duration_ms: EOU min speech duration in ms (ignore very short noises)
        """
        self.vad = vad_model
        self.threshold = threshold
        self._eou_silence_ms = silence_threshold_ms
        self._eou_min_speech_ms = min_speech_duration_ms
        self.is_speaking = False
        self.speech_start_time: Optional[float] = None
        self.silence_start_time: Optional[float] = None
        self.eou_detector = EouDetector(silence_threshold_ms, min_speech_duration_ms)

    def process(self, pcm_bytes: bytes, chunk_size_ms: int = 20) -> Optional[Tuple[str, float]]:
        """
        Process audio chunk and detect state changes.

        Args:
            pcm_bytes: PCM audio data
            chunk_size_ms: Chunk duration in milliseconds

        Returns:
            Tuple of (event_type, probability) if state changed, None otherwise
        """
        label, probability = self.vad.process_audio(pcm_bytes, chunk_size_ms)

        # Check if this is speech based on threshold
        is_speech = probability >= self.threshold

        # Check EOU
        if self.eou_detector.process("Speech" if is_speech else "Silence"):
            return ("eou", probability)

        # State transition: Silence -> Speech
        if is_speech and not self.is_speaking:
            self.is_speaking = True
            self.speech_start_time = time.monotonic()
            self.silence_start_time = None
            return ("speaking", probability)

        # State transition: Speech -> Silence
        elif not is_speech and self.is_speaking:
            self.is_speaking = False
            self.silence_start_time = time.monotonic()
            self.speech_start_time = None
            return ("silence", probability)

        return None

    def reset(self) -> None:
        """Reset VAD state."""
        self.vad.reset()
        self.is_speaking = False
        self.speech_start_time = None
        self.silence_start_time = None
        self.eou_detector = EouDetector(self._eou_silence_ms, self._eou_min_speech_ms)
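Putting the two pieces together, a minimal per-chunk driver might look like this. The model path mirrors the vad_model_path default, and 640-byte chunks correspond to 20 ms of 16 kHz 16-bit mono audio:

# Sketch: feed 20 ms PCM chunks through the processor and react to events.
from processors.vad import SileroVAD, VADProcessor

vad = SileroVAD("data/vad/silero_vad.onnx")
processor = VADProcessor(vad, threshold=0.5, silence_threshold_ms=400)

def on_audio_chunk(pcm_bytes: bytes) -> None:
    result = processor.process(pcm_bytes, chunk_size_ms=20)
    if result is not None:
        event_type, probability = result
        print(f"{event_type} (p={probability:.2f})")  # "speaking", "silence" or "eou"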
134
pyproject.toml
Normal file
134
pyproject.toml
Normal file
@@ -0,0 +1,134 @@
[build-system]
requires = ["setuptools>=68.0"]
build-backend = "setuptools.build_meta"

[project]
name = "py-active-call-cc"
version = "0.1.0"
description = "Python Active-Call: Real-time audio streaming with WebSocket and WebRTC"
readme = "README.md"
requires-python = ">=3.11"
license = {text = "MIT"}
authors = [
    {name = "Your Name", email = "your.email@example.com"}
]
keywords = ["webrtc", "websocket", "audio", "voip", "real-time"]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "Topic :: Communications :: Telephony",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
]

[project.urls]
Homepage = "https://github.com/yourusername/py-active-call-cc"
Documentation = "https://github.com/yourusername/py-active-call-cc/blob/main/README.md"
Repository = "https://github.com/yourusername/py-active-call-cc.git"
Issues = "https://github.com/yourusername/py-active-call-cc/issues"

[tool.setuptools.packages.find]
where = ["."]
# include every top-level package this commit ships, not just app
include = ["app*", "models*", "processors*", "utils*"]
exclude = ["tests*", "scripts*", "reference*"]

[tool.black]
line-length = 100
target-version = ['py311']
include = '\.pyi?$'
extend-exclude = '''
/(
  # directories
  \.eggs
  | \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | build
  | dist
  | reference
)/
'''

[tool.ruff]
line-length = 100
target-version = "py311"
select = [
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "F",   # pyflakes
    "I",   # isort
    "B",   # flake8-bugbear
    "C4",  # flake8-comprehensions
    "UP",  # pyupgrade
]
ignore = [
    "E501",  # line too long (handled by black)
    "B008",  # do not perform function calls in argument defaults
]
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".hg",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "venv",
    "reference",
]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]  # unused imports

[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
check_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
strict_equality = true
exclude = [
    "venv",
    "reference",
    "build",
    "dist",
]

[[tool.mypy.overrides]]
module = [
    "aiortc.*",
    "av.*",
    "onnxruntime.*",
]
ignore_missing_imports = true

[tool.pytest.ini_options]
minversion = "7.0"
addopts = "-ra -q --strict-markers --strict-config"
testpaths = ["tests"]
pythonpath = ["."]
asyncio_mode = "auto"
markers = [
    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
    "integration: marks tests as integration tests",
]
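With this metadata in place, the standard setuptools flow applies; the commands below are the usual generic ones, not project-specific scripts:

    pip install -e .          # editable install of the packages listed above
    pytest -m "not slow"      # run tests, skipping the slow-marked ones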
0
requirements.txt
Normal file
0
requirements.txt
Normal file
166
scripts/test_websocket.py
Normal file
166
scripts/test_websocket.py
Normal file
@@ -0,0 +1,166 @@
"""WebSocket endpoint test client.

Tests the /ws endpoint with sine wave or file audio streaming.
Based on reference/py-active-call/exec/test_ws_endpoint/test_ws.py
"""

import asyncio
import aiohttp
import json
import struct
import math
import argparse
import os
from datetime import datetime

# Configuration
SERVER_URL = "ws://localhost:8000/ws"
SAMPLE_RATE = 16000
FREQUENCY = 440  # 440Hz Sine Wave
CHUNK_DURATION_MS = 20
# 16kHz * 16-bit (2 bytes) * 20ms = 640 bytes per chunk
CHUNK_SIZE_BYTES = int(SAMPLE_RATE * 2 * (CHUNK_DURATION_MS / 1000.0))


def generate_sine_wave(duration_ms=1000):
    """Generate sine wave audio (16kHz mono PCM 16-bit)."""
    num_samples = int(SAMPLE_RATE * (duration_ms / 1000.0))
    audio_data = bytearray()

    for x in range(num_samples):
        # Generate sine wave sample
        value = int(32767.0 * math.sin(2 * math.pi * FREQUENCY * x / SAMPLE_RATE))
        # Pack as little-endian 16-bit integer
        audio_data.extend(struct.pack('<h', value))

    return audio_data


async def receive_loop(ws):
    """Listen for incoming messages from the server."""
    print("👂 Listening for server responses...")
    async for msg in ws:
        timestamp = datetime.now().strftime("%H:%M:%S")

        if msg.type == aiohttp.WSMsgType.TEXT:
            try:
                data = json.loads(msg.data)
                event_type = data.get('event', 'Unknown')
                print(f"[{timestamp}] 📨 Event: {event_type} | {msg.data[:150]}...")
            except json.JSONDecodeError:
                print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...")

        elif msg.type == aiohttp.WSMsgType.BINARY:
            # Received audio chunk back (e.g., TTS or echo)
            print(f"[{timestamp}] 🔊 Audio: {len(msg.data)} bytes", end="\r")

        elif msg.type == aiohttp.WSMsgType.CLOSED:
            print(f"\n[{timestamp}] ❌ Socket Closed")
            break

        elif msg.type == aiohttp.WSMsgType.ERROR:
            print(f"\n[{timestamp}] ⚠️ Socket Error")
            break


async def send_file_loop(ws, file_path):
    """Stream a raw PCM/WAV file to the server."""
    if not os.path.exists(file_path):
        print(f"❌ Error: File '{file_path}' not found.")
        return

    print(f"📂 Streaming file: {file_path} ...")

    with open(file_path, "rb") as f:
        # Skip WAV header if present (first 44 bytes)
        if file_path.endswith('.wav'):
            f.read(44)

        while True:
            chunk = f.read(CHUNK_SIZE_BYTES)
            if not chunk:
                break

            # Send binary frame
            await ws.send_bytes(chunk)

            # Sleep to simulate real-time playback
            await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print(f"\n✅ Finished streaming {file_path}")


async def send_sine_loop(ws):
    """Stream generated sine wave to the server."""
    print("🎙️ Starting Audio Stream (Sine Wave)...")

    # Generate 5 seconds of audio buffer
    audio_buffer = generate_sine_wave(5000)
    cursor = 0

    while cursor < len(audio_buffer):
        chunk = audio_buffer[cursor:cursor + CHUNK_SIZE_BYTES]
        if not chunk:
            break

        await ws.send_bytes(chunk)
        cursor += len(chunk)

        await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print("\n✅ Finished streaming test audio.")


async def run_client(url, file_path=None, use_sine=False):
    """Run the WebSocket test client."""
    session = aiohttp.ClientSession()
    try:
        print(f"🔌 Connecting to {url}...")
        async with session.ws_connect(url) as ws:
            print("✅ Connected!")

            # Send initial invite command
            init_cmd = {
                "command": "invite",
                "option": {
                    "codec": "pcm",
                    "samplerate": SAMPLE_RATE
                }
            }
            await ws.send_json(init_cmd)
            print("📤 Sent Invite Command")

            # Select sender based on args
            if use_sine:
                sender_task = send_sine_loop(ws)
            elif file_path:
                sender_task = send_file_loop(ws, file_path)
            else:
                # Default to sine wave
                sender_task = send_sine_loop(ws)

            # Run send and receive loops in parallel
            await asyncio.gather(
                receive_loop(ws),
                sender_task
            )

    except aiohttp.ClientConnectorError:
        print(f"❌ Connection Failed. Is the server running at {url}?")
    except Exception as e:
        print(f"❌ Error: {e}")
    finally:
        await session.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="WebSocket Audio Test Client")
    parser.add_argument("--url", default=SERVER_URL, help="WebSocket endpoint URL")
    parser.add_argument("--file", help="Path to PCM/WAV file to stream")
    parser.add_argument("--sine", action="store_true", help="Use sine wave generation (default)")
    args = parser.parse_args()

    try:
        asyncio.run(run_client(args.url, args.file, args.sine))
    except KeyboardInterrupt:
        print("\n👋 Client stopped.")
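Typical invocations, derived from the argparse flags above (run from the repo root with the server already up); sample.pcm is a placeholder for any raw 16 kHz 16-bit mono capture:

    python scripts/test_websocket.py --sine
    python scripts/test_websocket.py --file sample.pcm --url ws://localhost:8000/ws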
160
test_client.py
160
test_client.py
@@ -1,160 +0,0 @@
"""
WebSocket Test Client

Tests the WebSocket server with sine wave audio generation.

Usage:
    python test_client.py
    python test_client.py --url ws://localhost:8000/ws
"""

import asyncio
import aiohttp
import json
import struct
import math
import argparse
from datetime import datetime

# Configuration
SERVER_URL = "ws://localhost:8000/ws"
SAMPLE_RATE = 16000
FREQUENCY = 440  # 440Hz sine wave
CHUNK_DURATION_MS = 20
CHUNK_SIZE_BYTES = int(SAMPLE_RATE * 2 * (CHUNK_DURATION_MS / 1000.0))  # 640 bytes


def generate_sine_wave(duration_ms=1000):
    """
    Generate sine wave audio data.

    Format: 16kHz, mono, 16-bit PCM
    """
    num_samples = int(SAMPLE_RATE * (duration_ms / 1000.0))
    audio_data = bytearray()

    for x in range(num_samples):
        # Generate sine wave sample
        value = int(32767.0 * math.sin(2 * math.pi * FREQUENCY * x / SAMPLE_RATE))
        # Pack as little-endian 16-bit signed integer
        audio_data.extend(struct.pack('<h', value))

    return audio_data


async def receive_loop(ws, session_id):
    """
    Listen for incoming messages from the server.
    """
    print("👂 Listening for server responses...")
    async for msg in ws:
        timestamp = datetime.now().strftime("%H:%M:%S")

        if msg.type == aiohttp.WSMsgType.TEXT:
            try:
                data = json.loads(msg.data)
                event_type = data.get('event', 'Unknown')
                print(f"[{timestamp}] 📨 Event: {event_type}")
                print(f"   {json.dumps(data, indent=2)}")
            except json.JSONDecodeError:
                print(f"[{timestamp}] 📨 Text: {msg.data[:100]}...")

        elif msg.type == aiohttp.WSMsgType.BINARY:
            # Received audio chunk back
            print(f"[{timestamp}] 🔊 Audio: {len(msg.data)} bytes")

        elif msg.type == aiohttp.WSMsgType.CLOSED:
            print(f"\n[{timestamp}] ❌ Connection closed")
            break

        elif msg.type == aiohttp.WSMsgType.ERROR:
            print(f"\n[{timestamp}] ⚠️ WebSocket error")
            break


async def send_audio_loop(ws):
    """
    Stream sine wave audio to the server.
    """
    print("🎙️ Starting audio stream (sine wave)...")

    # Generate 5 seconds of audio
    audio_buffer = generate_sine_wave(5000)
    cursor = 0

    while cursor < len(audio_buffer):
        chunk = audio_buffer[cursor:cursor + CHUNK_SIZE_BYTES]
        if not chunk:
            break

        await ws.send_bytes(chunk)
        print(f"📤 Sent audio chunk: {len(chunk)} bytes", end="\r")

        cursor += len(chunk)

        # Sleep to simulate real-time (20ms per chunk)
        await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)

    print("\n✅ Finished streaming audio")


async def run_client(url):
    """
    Run the WebSocket test client.
    """
    session = aiohttp.ClientSession()

    try:
        print(f"🔌 Connecting to {url}...")

        async with session.ws_connect(url) as ws:
            print("✅ Connected!")
            print()

            # Send invite command
            invite_cmd = {
                "command": "invite",
                "option": {
                    "codec": "pcm",
                    "samplerate": SAMPLE_RATE
                }
            }
            await ws.send_json(invite_cmd)
            print("📤 Sent invite command")
            print()

            # Send a ping command
            ping_cmd = {"command": "ping"}
            await ws.send_json(ping_cmd)
            print("📤 Sent ping command")
            print()

            # Wait a moment for responses
            await asyncio.sleep(1)

            # Run audio streaming and receiving in parallel
            await asyncio.gather(
                receive_loop(ws, None),
                send_audio_loop(ws)
            )

    except aiohttp.ClientConnectorError:
        print(f"❌ Connection failed. Is the server running at {url}?")
        print("   Start server with: python main.py")

    except Exception as e:
        print(f"❌ Error: {e}")

    finally:
        await session.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="WebSocket Audio Test Client")
    parser.add_argument("--url", default=SERVER_URL, help="WebSocket URL")
    args = parser.parse_args()

    try:
        asyncio.run(run_client(args.url))
    except KeyboardInterrupt:
        print("\n👋 Client stopped.")
1
utils/__init__.py
Normal file
1
utils/__init__.py
Normal file
@@ -0,0 +1 @@
"""Utilities Package"""
83
utils/logging.py
Normal file
83
utils/logging.py
Normal file
@@ -0,0 +1,83 @@
"""Logging configuration utilities."""

import sys
from pathlib import Path

from loguru import logger


def setup_logging(
    log_level: str = "INFO",
    log_format: str = "text",
    log_to_file: bool = True,
    log_dir: str = "logs"
):
    """
    Configure structured logging with loguru.

    Args:
        log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
        log_format: Format type (json or text)
        log_to_file: Whether to log to file
        log_dir: Directory for log files
    """
    # Remove default handler
    logger.remove()

    # Console handler
    if log_format == "json":
        logger.add(
            sys.stdout,
            format="{message}",
            level=log_level,
            serialize=True,
            colorize=False
        )
    else:
        logger.add(
            sys.stdout,
            format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
            level=log_level,
            colorize=True
        )

    # File handler
    if log_to_file:
        log_path = Path(log_dir)
        log_path.mkdir(exist_ok=True)

        if log_format == "json":
            logger.add(
                log_path / "active_call_{time:YYYY-MM-DD}.log",
                format="{message}",
                level=log_level,
                rotation="1 day",
                retention="7 days",
                compression="zip",
                serialize=True
            )
        else:
            logger.add(
                log_path / "active_call_{time:YYYY-MM-DD}.log",
                format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
                level=log_level,
                rotation="1 day",
                retention="7 days",
                compression="zip"
            )

    return logger


def get_logger(name: str | None = None):
    """
    Get a logger instance.

    Args:
        name: Logger name (optional)

    Returns:
        Logger instance
    """
    if name:
        return logger.bind(name=name)
    return logger
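Startup wiring is then a one-liner; this is a minimal sketch and the argument values are illustrative:

# Sketch: configure console + file logging, then log through the returned handle.
from utils.logging import setup_logging

logger = setup_logging(log_level="INFO", log_format="json", log_to_file=True)
logger.info("active-call server starting")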