Files
AI-VideoAssistant/engine/app/main.py
Xin Wang 935f2fbd1f Refactor assistant configuration management and update documentation
- Removed legacy agent profile settings from the .env.example and README, streamlining the configuration process.
- Introduced a new local YAML configuration adapter for assistant settings, allowing for easier management of assistant profiles.
- Updated backend integration documentation to clarify the behavior of assistant config sourcing based on backend URL settings.
- Adjusted various service implementations to directly utilize API keys from the new configuration structure.
- Enhanced test coverage for the new local YAML adapter and its integration with backend services.
2026-03-05 21:24:15 +08:00

412 lines
12 KiB
Python

"""FastAPI application with WebSocket and WebRTC endpoints."""
import asyncio
import json
import time
import uuid
from pathlib import Path
from typing import Dict, Any, Optional, List
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from loguru import logger
# aiortc is optional: when it cannot be imported, the /webrtc endpoint refuses
# connections while the plain /ws endpoint keeps working.
try:
    from aiortc import RTCPeerConnection, RTCSessionDescription
except ImportError:
    AIORTC_AVAILABLE = False
    logger.warning("aiortc not available - WebRTC endpoint will be disabled")
else:
    AIORTC_AVAILABLE = True
from app.config import settings
from app.backend_adapters import build_backend_adapter_from_settings
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session
from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus
# Check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
async def heartbeat_and_timeout_task(
    transport: BaseTransport,
    session: Session,
    session_id: str,
    last_received_at: List[float],
    last_heartbeat_at: List[float],
    inactivity_timeout_sec: int,
    heartbeat_interval_sec: int,
) -> None:
    """
    Background task: send a heartbeat roughly every ``heartbeat_interval_sec``
    and clean up the session if no client message arrived within
    ``inactivity_timeout_sec``.

    Checks run every ``_HEARTBEAT_CHECK_INTERVAL_SEC`` seconds, so both
    intervals are effectively rounded up to that granularity. The loop stops
    when the transport reports ``is_closed``, after an inactivity timeout, or
    when sending a heartbeat fails.

    Args:
        transport: Transport whose ``is_closed`` flag ends the loop.
        session: Session used to send heartbeats and cleaned up on timeout.
        session_id: Identifier used in log messages.
        last_received_at: Single-element list holding the ``time.monotonic()``
            timestamp of the last client message; written by the caller.
        last_heartbeat_at: Single-element list holding the monotonic timestamp
            of the last heartbeat sent; updated by this task.
        inactivity_timeout_sec: Idle time after which the session is cleaned up.
        heartbeat_interval_sec: Minimum time between two heartbeats.

    A failure in ``session.cleanup()`` on timeout is logged instead of
    propagated, so a caller that later awaits this task during its own
    teardown does not get that error re-raised there.
    """
    while True:
        await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC)
        if transport.is_closed:
            break
        now = time.monotonic()
        if now - last_received_at[0] > inactivity_timeout_sec:
            logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
            try:
                await session.cleanup()
            except Exception as e:
                # Don't let the task die with this error: the endpoint awaits the
                # task in its finally block and would otherwise skip its own cleanup.
                logger.warning(f"Session {session_id}: cleanup after inactivity timeout failed: {e}")
            break
        if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
            try:
                await session.send_heartbeat()
                last_heartbeat_at[0] = now
            except Exception as e:
                # Most likely the connection is gone; stop heartbeating quietly.
                logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
                break
# Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0")

# Bundled browser client, served by the "/" and "/client" routes.
_WEB_CLIENT_PATH = Path(__file__).resolve().parent.parent.joinpath("examples", "web_client.html")

# Configure CORS: origins come from settings; methods and headers are unrestricted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Registry of live calls, keyed by session id (filled by the /ws and /webrtc endpoints).
active_sessions: Dict[str, Session] = {}
# Single backend adapter instance shared by every Session created below.
backend_gateway = build_backend_adapter_from_settings()
# Configure logging: remove every previously installed loguru handler, then add
# a daily-rotated file sink (kept for 7 days) and a console sink, both filtered
# at settings.log_level.
logger.remove()
logger.add(
    sink="./logs/active_call_{time}.log",
    format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
    level=settings.log_level,
    rotation="1 day",
    retention="7 days",
)
logger.add(
    sink=lambda formatted: print(formatted, end=""),
    format="{time:HH:mm:ss} | {level: <8} | {message}",
    level=settings.log_level,
)
@app.get("/health")
async def health_check():
    """Report liveness together with the number of currently tracked sessions."""
    session_count = len(active_sessions)
    return {"status": "healthy", "sessions": session_count}
@app.get("/")
async def web_client_root():
    """Return the bundled HTML web client, or a 404 if the file is missing on disk."""
    if _WEB_CLIENT_PATH.exists():
        return FileResponse(_WEB_CLIENT_PATH)
    raise HTTPException(status_code=404, detail="Web client not found")
@app.get("/client")
async def web_client_alias():
    """Serve the same web client as ``/`` under the ``/client`` path."""
    # Same existence check and response as the root route.
    return await web_client_root()
@app.get("/iceservers")
async def get_ice_servers():
    """Return the ICE server configuration WebRTC clients should use (from settings)."""
    ice_servers = settings.ice_servers_list
    return ice_servers
@app.get("/call/lists")
async def list_calls():
    """Describe every call currently tracked in ``active_sessions``.

    Returns:
        ``{"calls": [...]}`` where each entry holds the session id plus the
        session's ``state`` and ``created_at`` attributes, in registry order.
    """
    calls = []
    for sid, sess in active_sessions.items():
        calls.append({"id": sid, "state": sess.state, "created_at": sess.created_at})
    return {"calls": calls}
@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
    """Kill a specific active call.

    The session is removed from ``active_sessions`` *before* its cleanup is
    awaited. While ``cleanup()`` runs, the connection's own endpoint coroutine
    can reach its teardown code; taking ownership of the entry first means
    exactly one party cleans it up and nobody hits a ``KeyError`` on a second
    delete.

    Args:
        session_id: Id of the call to terminate.

    Returns:
        ``True`` once the session has been cleaned up.

    Raises:
        HTTPException: 404 if no active session has this id.
    """
    session = active_sessions.pop(session_id, None)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    await session.cleanup()
    return True
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for raw audio streaming.

    Accepts mixed text/binary frames:
    - Text frames: JSON commands
    - Binary frames: PCM audio data (16kHz, 16-bit, mono)

    The optional ``assistant_id`` query parameter is stripped and forwarded to
    the Session (a blank value is treated as absent). A background heartbeat
    task runs for the lifetime of the connection.
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())
    assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None

    # Create transport and session
    transport = SocketTransport(websocket)
    session = Session(
        session_id,
        transport,
        backend_gateway=backend_gateway,
        assistant_id=assistant_id,
    )
    active_sessions[session_id] = session
    logger.info(f"WebSocket connection established: {session_id} assistant_id={assistant_id or '-'}")

    # Single-element lists act as mutable cells shared with the heartbeat task.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    try:
        # Receive loop
        while True:
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                logger.info(f"WebSocket disconnected: {session_id}")
                break
            last_received_at[0] = time.monotonic()

            # The ASGI spec allows a receive event to carry both "bytes" and
            # "text" keys with the unused one set to None, so test the values
            # rather than key presence.
            audio = message.get("bytes")
            text = message.get("text")
            if audio is not None:
                await session.handle_audio(audio)
            elif text is not None:
                await session.handle_text(text)
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")
    except Exception as e:
        # loguru has no `exc_info` kwarg; logger.exception records the traceback.
        logger.exception(f"WebSocket error: {e}")
    finally:
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The heartbeat task may already have finished with its own error;
            # awaiting it re-raises that error, which must not skip cleanup below.
            logger.warning(f"Session {session_id}: heartbeat task failed: {e}")

        # Cleanup session. Pop before awaiting cleanup so that a concurrent
        # kill_call/shutdown never sees (and double-deletes) the same entry.
        if active_sessions.pop(session_id, None) is not None:
            await session.cleanup()
            logger.info(f"Session {session_id} removed")
@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
    """
    WebRTC endpoint for WebRTC audio streaming.

    Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
    Text frames whose JSON has both ``sdp`` and ``type`` are applied to the peer
    connection; an ``offer`` is answered with an ``answer`` event. Every other
    JSON text frame is forwarded to ``Session.handle_text``; non-text frames are
    ignored. If aiortc is unavailable the socket is closed with code 1011
    before being accepted.
    """
    # Check if aiortc is available
    if not AIORTC_AVAILABLE:
        await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
        logger.warning("WebRTC connection attempted but aiortc is not available")
        return

    await websocket.accept()
    session_id = str(uuid.uuid4())
    assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None

    # Create WebRTC peer connection
    pc = RTCPeerConnection()

    # Create transport and session
    transport = WebRtcTransport(websocket, pc)
    session = Session(
        session_id,
        transport,
        backend_gateway=backend_gateway,
        assistant_id=assistant_id,
    )
    active_sessions[session_id] = session
    logger.info(f"WebRTC connection established: {session_id} assistant_id={assistant_id or '-'}")

    # Single-element lists act as mutable cells shared with the heartbeat task.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    # Strong references to the per-track audio pump tasks: the event loop only
    # keeps weak references to tasks, so an unreferenced task may be garbage
    # collected mid-execution. The set also lets teardown cancel them.
    audio_tasks = set()

    # Track handler for incoming audio
    @pc.on("track")
    def on_track(track):
        logger.info(f"Track received: {track.kind}")
        if track.kind == "audio":
            # Wrap track with resampler
            wrapped_track = Resampled16kTrack(track)

            # Pull frames from the track and feed them to the session until
            # recv() or handle_audio() raises.
            async def pull_audio():
                try:
                    while True:
                        frame = await wrapped_track.recv()
                        # Convert frame to bytes
                        pcm_bytes = frame.to_ndarray().tobytes()
                        # Feed to session
                        await session.handle_audio(pcm_bytes)
                except Exception as e:
                    logger.error(f"Error pulling audio from track: {e}")

            task = asyncio.create_task(pull_audio())
            audio_tasks.add(task)
            task.add_done_callback(audio_tasks.discard)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"Connection state: {pc.connectionState}")
        if pc.connectionState in ("failed", "closed"):
            await session.cleanup()

    try:
        # Signaling loop
        while True:
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                # Calling receive() again after a disconnect raises RuntimeError,
                # so leave the loop here instead of skipping the message.
                logger.info(f"WebRTC WebSocket disconnected: {session_id}")
                break
            text = message.get("text")
            if text is None:
                continue
            last_received_at[0] = time.monotonic()
            data = json.loads(text)

            # Handle SDP offer/answer
            if "sdp" in data and "type" in data:
                logger.info(f"Received SDP {data['type']}")

                # Set remote description
                offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
                await pc.setRemoteDescription(offer)

                # Create and set local description
                if data["type"] == "offer":
                    answer = await pc.createAnswer()
                    await pc.setLocalDescription(answer)

                    # Send answer back
                    await websocket.send_text(json.dumps({
                        "event": "answer",
                        "trackId": session_id,
                        "timestamp": int(asyncio.get_running_loop().time() * 1000),
                        "sdp": pc.localDescription.sdp
                    }))
                    logger.info("Sent SDP answer")
            else:
                # Handle other commands
                await session.handle_text(text)
    except WebSocketDisconnect:
        logger.info(f"WebRTC WebSocket disconnected: {session_id}")
    except Exception as e:
        # loguru has no `exc_info` kwarg; logger.exception records the traceback.
        logger.exception(f"WebRTC error: {e}")
    finally:
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The heartbeat task may already have finished with its own error;
            # awaiting it re-raises that error, which must not skip cleanup below.
            logger.warning(f"Session {session_id}: heartbeat task failed: {e}")

        # Stop the audio pumps before tearing down the peer connection.
        pending_audio = list(audio_tasks)
        for audio_task in pending_audio:
            audio_task.cancel()
        if pending_audio:
            await asyncio.gather(*pending_audio, return_exceptions=True)

        # Cleanup
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"WebRTC session {session_id}: closing peer connection failed: {e}")
        # Pop before awaiting cleanup so a concurrent kill_call/shutdown never
        # sees (and double-deletes) the same entry.
        if active_sessions.pop(session_id, None) is not None:
            await session.cleanup()
            logger.info(f"WebRTC session {session_id} removed")
@app.on_event("startup")
async def startup_event():
    """Log the effective server configuration once the application starts."""
    startup_lines = (
        "Starting Python Active-Call server",
        f"Server: {settings.host}:{settings.port}",
        f"Sample rate: {settings.sample_rate} Hz",
        f"VAD model: {settings.vad_model_path}",
        "Assistant runtime config source: backend when BACKEND_URL is set, "
        "otherwise local YAML by assistant_id from ASSISTANT_LOCAL_CONFIG_DIR",
    )
    for line in startup_lines:
        logger.info(line)
@app.on_event("shutdown")
async def shutdown_event():
    """Run on application shutdown.

    Cleans up every tracked session, then closes and resets the global event
    bus. A failing session cleanup is logged and does not stop the remaining
    sessions or the event bus from being shut down.
    """
    logger.info("Shutting down Python Active-Call server")

    # Cleanup all sessions. Iterate over a snapshot of the ids: awaiting
    # cleanup() yields to other coroutines (e.g. endpoint teardown) that remove
    # entries, and mutating the live dict mid-iteration raises RuntimeError.
    for session_id in list(active_sessions):
        session = active_sessions.pop(session_id, None)
        if session is None:
            # Already removed, and cleaned up, by whoever popped it.
            continue
        try:
            await session.cleanup()
        except Exception as e:
            logger.exception(f"Session {session_id}: cleanup during shutdown failed: {e}")

    # Close event bus
    event_bus = get_event_bus()
    await event_bus.close()
    reset_event_bus()

    logger.info("Server shutdown complete")
if __name__ == "__main__":
    import uvicorn

    # Ensure the log directory exists before the server starts writing to it.
    Path("logs").mkdir(exist_ok=True)

    # Launch uvicorn against the import string (required for reload=True).
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=True,
        log_level=settings.log_level.lower(),
    )