Integrate eou and vad

This commit is contained in:
Xin Wang
2026-01-29 13:57:12 +08:00
parent 4cb267a288
commit cd90b4fb37
25 changed files with 2592 additions and 297 deletions

1
app/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Active-Call Application Package"""

81
app/config.py Normal file
View File

@@ -0,0 +1,81 @@
"""Configuration management using Pydantic settings."""
from typing import List, Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
import json
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore"
)
# Server Configuration
host: str = Field(default="0.0.0.0", description="Server host address")
port: int = Field(default=8000, description="Server port")
external_ip: Optional[str] = Field(default=None, description="External IP for NAT traversal")
# Audio Configuration
sample_rate: int = Field(default=16000, description="Audio sample rate in Hz")
chunk_size_ms: int = Field(default=20, description="Audio chunk duration in milliseconds")
default_codec: str = Field(default="pcm", description="Default audio codec")
# VAD Configuration
vad_type: str = Field(default="silero", description="VAD algorithm type")
vad_model_path: str = Field(default="data/vad/silero_vad.onnx", description="Path to VAD model")
vad_threshold: float = Field(default=0.5, description="VAD detection threshold")
vad_min_speech_duration_ms: int = Field(default=250, description="Minimum speech duration in milliseconds")
vad_eou_threshold_ms: int = Field(default=400, description="End of utterance (silence) threshold in milliseconds")
# Logging
log_level: str = Field(default="INFO", description="Logging level")
log_format: str = Field(default="json", description="Log format (json or text)")
# CORS
cors_origins: str = Field(
default='["http://localhost:3000", "http://localhost:8080"]',
description="CORS allowed origins"
)
# ICE Servers (WebRTC)
ice_servers: str = Field(
default='[{"urls": "stun:stun.l.google.com:19302"}]',
description="ICE servers configuration"
)
@property
def chunk_size_bytes(self) -> int:
"""Calculate chunk size in bytes based on sample rate and duration."""
# 16-bit (2 bytes) per sample, mono channel
return int(self.sample_rate * 2 * (self.chunk_size_ms / 1000.0))
@property
def cors_origins_list(self) -> List[str]:
"""Parse CORS origins from JSON string."""
try:
return json.loads(self.cors_origins)
except json.JSONDecodeError:
return ["http://localhost:3000", "http://localhost:8080"]
@property
def ice_servers_list(self) -> List[dict]:
"""Parse ICE servers from JSON string."""
try:
return json.loads(self.ice_servers)
except json.JSONDecodeError:
return [{"urls": "stun:stun.l.google.com:19302"}]
# Global settings instance
settings = Settings()
def get_settings() -> Settings:
"""Get application settings instance."""
return settings

290
app/main.py Normal file
View File

@@ -0,0 +1,290 @@
"""FastAPI application with WebSocket and WebRTC endpoints."""
import uuid
import json
from typing import Dict, Any, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
# Try to import aiortc (optional for WebRTC functionality)
try:
from aiortc import RTCPeerConnection, RTCSessionDescription
AIORTC_AVAILABLE = True
except ImportError:
AIORTC_AVAILABLE = False
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
from app.config import settings
from core.transports import SocketTransport, WebRtcTransport
from core.session import Session
from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus
# Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0")
# Configure CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Active sessions storage
active_sessions: Dict[str, Session] = {}
# Configure logging
logger.remove()
logger.add(
"../logs/active_call_{time}.log",
rotation="1 day",
retention="7 days",
level=settings.log_level,
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
)
logger.add(
lambda msg: print(msg, end=""),
level=settings.log_level,
format="{time:HH:mm:ss} | {level: <8} | {message}"
)
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {"status": "healthy", "sessions": len(active_sessions)}
@app.get("/iceservers")
async def get_ice_servers():
"""Get ICE servers configuration for WebRTC."""
return settings.ice_servers_list
@app.get("/call/lists")
async def list_calls():
"""List all active calls."""
return {
"calls": [
{
"id": session_id,
"state": session.state,
"created_at": session.created_at
}
for session_id, session in active_sessions.items()
]
}
@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
"""Kill a specific active call."""
if session_id not in active_sessions:
raise HTTPException(status_code=404, detail="Session not found")
session = active_sessions[session_id]
await session.cleanup()
del active_sessions[session_id]
return True
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""
WebSocket endpoint for raw audio streaming.
Accepts mixed text/binary frames:
- Text frames: JSON commands
- Binary frames: PCM audio data (16kHz, 16-bit, mono)
"""
await websocket.accept()
session_id = str(uuid.uuid4())
# Create transport and session
transport = SocketTransport(websocket)
session = Session(session_id, transport)
active_sessions[session_id] = session
logger.info(f"WebSocket connection established: {session_id}")
try:
# Receive loop
while True:
message = await websocket.receive()
# Handle binary audio data
if "bytes" in message:
await session.handle_audio(message["bytes"])
# Handle text commands
elif "text" in message:
await session.handle_text(message["text"])
except WebSocketDisconnect:
logger.info(f"WebSocket disconnected: {session_id}")
except Exception as e:
logger.error(f"WebSocket error: {e}", exc_info=True)
finally:
# Cleanup session
if session_id in active_sessions:
await session.cleanup()
del active_sessions[session_id]
logger.info(f"Session {session_id} removed")
@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
"""
WebRTC endpoint for WebRTC audio streaming.
Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
"""
# Check if aiortc is available
if not AIORTC_AVAILABLE:
await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
logger.warning("WebRTC connection attempted but aiortc is not available")
return
await websocket.accept()
session_id = str(uuid.uuid4())
# Create WebRTC peer connection
pc = RTCPeerConnection()
# Create transport and session
transport = WebRtcTransport(websocket, pc)
session = Session(session_id, transport)
active_sessions[session_id] = session
logger.info(f"WebRTC connection established: {session_id}")
# Track handler for incoming audio
@pc.on("track")
def on_track(track):
logger.info(f"Track received: {track.kind}")
if track.kind == "audio":
# Wrap track with resampler
wrapped_track = Resampled16kTrack(track)
# Create task to pull audio from track
async def pull_audio():
try:
while True:
frame = await wrapped_track.recv()
# Convert frame to bytes
pcm_bytes = frame.to_ndarray().tobytes()
# Feed to session
await session.handle_audio(pcm_bytes)
except Exception as e:
logger.error(f"Error pulling audio from track: {e}")
asyncio.create_task(pull_audio())
@pc.on("connectionstatechange")
async def on_connectionstatechange():
logger.info(f"Connection state: {pc.connectionState}")
if pc.connectionState == "failed" or pc.connectionState == "closed":
await session.cleanup()
try:
# Signaling loop
while True:
message = await websocket.receive()
if "text" not in message:
continue
data = json.loads(message["text"])
# Handle SDP offer/answer
if "sdp" in data and "type" in data:
logger.info(f"Received SDP {data['type']}")
# Set remote description
offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
await pc.setRemoteDescription(offer)
# Create and set local description
if data["type"] == "offer":
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
# Send answer back
await websocket.send_text(json.dumps({
"event": "answer",
"trackId": session_id,
"timestamp": int(asyncio.get_event_loop().time() * 1000),
"sdp": pc.localDescription.sdp
}))
logger.info(f"Sent SDP answer")
else:
# Handle other commands
await session.handle_text(message["text"])
except WebSocketDisconnect:
logger.info(f"WebRTC WebSocket disconnected: {session_id}")
except Exception as e:
logger.error(f"WebRTC error: {e}", exc_info=True)
finally:
# Cleanup
await pc.close()
if session_id in active_sessions:
await session.cleanup()
del active_sessions[session_id]
logger.info(f"WebRTC session {session_id} removed")
@app.on_event("startup")
async def startup_event():
"""Run on application startup."""
logger.info("Starting Python Active-Call server")
logger.info(f"Server: {settings.host}:{settings.port}")
logger.info(f"Sample rate: {settings.sample_rate} Hz")
logger.info(f"VAD model: {settings.vad_model_path}")
@app.on_event("shutdown")
async def shutdown_event():
"""Run on application shutdown."""
logger.info("Shutting down Python Active-Call server")
# Cleanup all sessions
for session_id, session in active_sessions.items():
await session.cleanup()
# Close event bus
event_bus = get_event_bus()
await event_bus.close()
reset_event_bus()
logger.info("Server shutdown complete")
if __name__ == "__main__":
import uvicorn
# Create logs directory
import os
os.makedirs("logs", exist_ok=True)
# Run server
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
reload=True,
log_level=settings.log_level.lower()
)