- Removed the legacy agent profile settings from .env.example and the README, simplifying configuration.
- Introduced a local YAML configuration adapter for assistant settings, making assistant profiles easier to manage.
- Updated the backend integration documentation to clarify how the assistant config source is chosen depending on whether a backend URL is set.
- Adjusted service implementations to read API keys directly from the new configuration structure.
- Expanded test coverage for the local YAML adapter and its integration with backend services.
412 lines
12 KiB
Python
412 lines
12 KiB
Python
"""FastAPI application with WebSocket and WebRTC endpoints."""
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional, List
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, FileResponse
|
|
from loguru import logger
|
|
|
|
# aiortc is an optional dependency: when it cannot be imported, the /webrtc
# endpoint checks AIORTC_AVAILABLE and refuses connections.
try:
    from aiortc import RTCPeerConnection, RTCSessionDescription
except ImportError:
    AIORTC_AVAILABLE = False
    logger.warning("aiortc not available - WebRTC endpoint will be disabled")
else:
    AIORTC_AVAILABLE = True
|
|
|
|
from app.config import settings
|
|
from app.backend_adapters import build_backend_adapter_from_settings
|
|
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
|
|
from core.session import Session
|
|
from processors.tracks import Resampled16kTrack
|
|
from core.events import get_event_bus, reset_event_bus
|
|
|
|
# Default check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5


async def heartbeat_and_timeout_task(
    transport: BaseTransport,
    session: Session,
    session_id: str,
    last_received_at: List[float],
    last_heartbeat_at: List[float],
    inactivity_timeout_sec: int,
    heartbeat_interval_sec: int,
    check_interval_sec: float = _HEARTBEAT_CHECK_INTERVAL_SEC,
) -> None:
    """
    Background task: send heartBeat every ~heartbeat_interval_sec and close
    connection if no message from client for inactivity_timeout_sec.

    The task wakes every ``check_interval_sec`` seconds, so both the heartbeat
    cadence and the inactivity detection have up to that much latency.

    Args:
        transport: Connection transport; the loop stops once ``is_closed``.
        session: Session to heartbeat and, on inactivity, clean up.
        session_id: Used only for log messages.
        last_received_at: One-element list holding the ``time.monotonic()``
            timestamp of the last client message; the caller updates it.
        last_heartbeat_at: One-element list holding the ``time.monotonic()``
            timestamp of the last heartbeat sent; updated here.
        inactivity_timeout_sec: Seconds without a client message before the
            session is cleaned up and the task exits.
        heartbeat_interval_sec: Minimum seconds between heartbeats.
        check_interval_sec: Polling period of this loop.
    """
    while True:
        await asyncio.sleep(check_interval_sec)

        if transport.is_closed:
            break

        now = time.monotonic()

        if now - last_received_at[0] > inactivity_timeout_sec:
            logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
            try:
                await session.cleanup()
            except Exception as e:
                # Do not let the task finish with an exception: the endpoints
                # await this task in their `finally` and only expect
                # cancellation there.
                logger.warning(f"Session {session_id}: cleanup after inactivity timeout failed: {e}")
            break

        if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
            try:
                await session.send_heartbeat()
            except Exception as e:
                logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
                break
            last_heartbeat_at[0] = now
|
|
|
|
|
|
# Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0")

# Bundled browser client, located at <two levels above this file>/examples/web_client.html.
_WEB_CLIENT_PATH = Path(__file__).resolve().parent.parent.joinpath("examples", "web_client.html")

# CORS: allowed origins come from settings; methods and headers are unrestricted.
_cors_kwargs: Dict[str, Any] = dict(
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_kwargs)

# Registry of live sessions keyed by session id; the endpoints add and remove entries.
active_sessions: Dict[str, Session] = {}
# Single backend adapter shared by every Session created below.
backend_gateway = build_backend_adapter_from_settings()

# Logging: drop loguru's default handler, then add a daily-rotated file sink
# (kept for 7 days) and a compact console sink, both at settings.log_level.
_FILE_LOG_FORMAT = "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
_CONSOLE_LOG_FORMAT = "{time:HH:mm:ss} | {level: <8} | {message}"


def _print_log_message(message):
    """Console sink: print an already-formatted record without adding a newline."""
    print(message, end="")


logger.remove()
logger.add(
    "./logs/active_call_{time}.log",
    rotation="1 day",
    retention="7 days",
    level=settings.log_level,
    format=_FILE_LOG_FORMAT,
)
logger.add(_print_log_message, level=settings.log_level, format=_CONSOLE_LOG_FORMAT)
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness together with the number of tracked sessions."""
    session_count = len(active_sessions)
    return {"status": "healthy", "sessions": session_count}
|
|
|
|
|
|
@app.get("/")
async def web_client_root():
    """Serve the bundled HTML web client; respond 404 if the file is missing."""
    if _WEB_CLIENT_PATH.exists():
        return FileResponse(_WEB_CLIENT_PATH)
    raise HTTPException(status_code=404, detail="Web client not found")
|
|
|
|
|
|
@app.get("/client")
async def web_client_alias():
    """Serve the same web client as ``/`` (404 when the file is missing)."""
    # Delegate so both routes share a single existence check and response.
    return await web_client_root()
|
|
|
|
|
|
|
|
|
|
@app.get("/iceservers")
async def get_ice_servers():
    """Return the configured ICE server list used by WebRTC clients."""
    ice_servers = settings.ice_servers_list
    return ice_servers
|
|
|
|
|
|
@app.get("/call/lists")
async def list_calls():
    """Describe every tracked session by id, state and creation time."""
    calls = []
    for sid, sess in active_sessions.items():
        calls.append({"id": sid, "state": sess.state, "created_at": sess.created_at})
    return {"calls": calls}
|
|
|
|
|
|
@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
    """Kill a specific active call.

    Removes the session from ``active_sessions`` and then awaits its
    ``cleanup()``.

    Raises:
        HTTPException: 404 if no session with ``session_id`` is tracked.

    Returns:
        ``True`` once cleanup has completed.
    """
    # Pop atomically *before* awaiting: cleanup() yields to the event loop, and
    # the owning endpoint's `finally` may remove the entry meanwhile, which made
    # the old check-then-`del` sequence raise KeyError (HTTP 500).
    session = active_sessions.pop(session_id, None)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    await session.cleanup()

    return True
|
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for raw audio streaming.

    Accepts mixed text/binary frames:
    - Text frames: JSON commands
    - Binary frames: PCM audio data (16kHz, 16-bit, mono)

    The optional ``assistant_id`` query parameter is passed to the Session.
    A background heartbeat/inactivity task runs for the connection's lifetime;
    on exit the session is removed from ``active_sessions`` and cleaned up.
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())
    assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None

    # Create transport and session
    transport = SocketTransport(websocket)
    session = Session(
        session_id,
        transport,
        backend_gateway=backend_gateway,
        assistant_id=assistant_id,
    )
    active_sessions[session_id] = session

    logger.info(f"WebSocket connection established: {session_id} assistant_id={assistant_id or '-'}")

    # One-element lists so the heartbeat task sees updates made in this loop.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    try:
        # Receive loop
        while True:
            message = await websocket.receive()

            if message.get("type") == "websocket.disconnect":
                logger.info(f"WebSocket disconnected: {session_id}")
                break

            last_received_at[0] = time.monotonic()

            # The ASGI spec allows both "bytes" and "text" keys to be present
            # with one of them None, so test the values rather than the keys.
            audio = message.get("bytes")
            text = message.get("text")
            if audio is not None:
                await session.handle_audio(audio)
            elif text is not None:
                await session.handle_text(text)

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")

    except Exception as e:
        # loguru has no `exc_info` kwarg; `exception()` attaches the traceback.
        logger.exception(f"WebSocket error: {e}")

    finally:
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task may already have finished with its own error; that must
            # not prevent the session from being removed below.
            logger.warning(f"Session {session_id}: heartbeat task failed: {e}")

        # Pop first so the registry entry is removed even if cleanup() raises.
        if active_sessions.pop(session_id, None) is not None:
            try:
                await session.cleanup()
            except Exception as e:
                logger.exception(f"Session {session_id}: cleanup failed: {e}")

        logger.info(f"Session {session_id} removed")
|
|
|
|
|
|
@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
    """
    WebRTC endpoint for WebRTC audio streaming.

    Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.

    Text frames that decode to a JSON object with both ``sdp`` and ``type``
    are applied to the peer connection; an ``offer`` is answered with an
    ``answer`` event. Every other text frame is forwarded verbatim to
    ``Session.handle_text`` (the same as the ``/ws`` endpoint). Binary frames
    on the signaling socket are ignored. Incoming audio tracks are wrapped in
    ``Resampled16kTrack`` and their frames fed to ``Session.handle_audio``.

    If aiortc is not installed, the socket is closed with code 1011 without
    being accepted.
    """
    # Check if aiortc is available
    if not AIORTC_AVAILABLE:
        await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
        logger.warning("WebRTC connection attempted but aiortc is not available")
        return

    await websocket.accept()
    session_id = str(uuid.uuid4())
    assistant_id = str(websocket.query_params.get("assistant_id") or "").strip() or None

    # Create WebRTC peer connection
    pc = RTCPeerConnection()

    # Create transport and session
    transport = WebRtcTransport(websocket, pc)
    session = Session(
        session_id,
        transport,
        backend_gateway=backend_gateway,
        assistant_id=assistant_id,
    )
    active_sessions[session_id] = session

    logger.info(f"WebRTC connection established: {session_id} assistant_id={assistant_id or '-'}")

    # One-element lists so the heartbeat task sees updates made in this loop.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    # The event loop keeps only weak references to tasks, so hold the audio
    # pump tasks here; they are cancelled during teardown.
    audio_tasks = set()

    # Track handler for incoming audio
    @pc.on("track")
    def on_track(track):
        logger.info(f"Track received: {track.kind}")
        if track.kind != "audio":
            return

        # Wrap track with resampler
        wrapped_track = Resampled16kTrack(track)

        async def pull_audio():
            try:
                while True:
                    frame = await wrapped_track.recv()
                    pcm_bytes = frame.to_ndarray().tobytes()
                    await session.handle_audio(pcm_bytes)
            except Exception as e:
                logger.error(f"Error pulling audio from track: {e}")

        task = asyncio.create_task(pull_audio())
        audio_tasks.add(task)
        task.add_done_callback(audio_tasks.discard)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"Connection state: {pc.connectionState}")
        if pc.connectionState in ("failed", "closed"):
            await session.cleanup()

    try:
        # Signaling loop
        while True:
            message = await websocket.receive()

            # Starlette raises RuntimeError if receive() is called again after
            # a disconnect message, so leave the loop here.
            if message.get("type") == "websocket.disconnect":
                logger.info(f"WebRTC WebSocket disconnected: {session_id}")
                break

            # ASGI may include a "text" key set to None for binary frames.
            text = message.get("text")
            if text is None:
                continue

            last_received_at[0] = time.monotonic()

            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                data = None

            # Handle SDP offer/answer
            if isinstance(data, dict) and "sdp" in data and "type" in data:
                logger.info(f"Received SDP {data['type']}")

                # Set remote description
                offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
                await pc.setRemoteDescription(offer)

                # Create and set local description
                if data["type"] == "offer":
                    answer = await pc.createAnswer()
                    await pc.setLocalDescription(answer)

                    # Send answer back
                    await websocket.send_text(json.dumps({
                        "event": "answer",
                        "trackId": session_id,
                        # NOTE(review): loop.time() is the loop's monotonic
                        # clock, not Unix epoch time - confirm what clients expect.
                        "timestamp": int(asyncio.get_running_loop().time() * 1000),
                        "sdp": pc.localDescription.sdp,
                    }))

                    logger.info("Sent SDP answer")
            else:
                # Non-SDP (including non-JSON) text goes to the session, as
                # the /ws endpoint does for every text frame.
                await session.handle_text(text)

    except WebSocketDisconnect:
        logger.info(f"WebRTC WebSocket disconnected: {session_id}")

    except Exception as e:
        # loguru has no `exc_info` kwarg; `exception()` attaches the traceback.
        logger.exception(f"WebRTC error: {e}")

    finally:
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.warning(f"Session {session_id}: heartbeat task failed: {e}")

        # Stop the audio pumps before the peer connection is torn down.
        for task in list(audio_tasks):
            task.cancel()

        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"Session {session_id}: closing peer connection failed: {e}")

        # Pop first so the registry entry is removed even if cleanup() raises.
        if active_sessions.pop(session_id, None) is not None:
            try:
                await session.cleanup()
            except Exception as e:
                logger.exception(f"Session {session_id}: cleanup failed: {e}")

        logger.info(f"WebRTC session {session_id} removed")
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Log a short summary of the effective server configuration at startup."""
    startup_messages = (
        "Starting Python Active-Call server",
        f"Server: {settings.host}:{settings.port}",
        f"Sample rate: {settings.sample_rate} Hz",
        f"VAD model: {settings.vad_model_path}",
        "Assistant runtime config source: backend when BACKEND_URL is set, "
        "otherwise local YAML by assistant_id from ASSISTANT_LOCAL_CONFIG_DIR",
    )
    for text in startup_messages:
        logger.info(text)
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Run on application shutdown.

    Cleans up every tracked session, then closes and resets the event bus.
    A failure while cleaning up one session is logged and does not prevent
    the remaining sessions or the event bus from being shut down.
    """
    logger.info("Shutting down Python Active-Call server")

    # Iterate over a snapshot: cleanup() awaits, and the connection handlers
    # delete their own entries from active_sessions while we are suspended,
    # which made iterating the live dict raise
    # "dictionary changed size during iteration".
    for session_id, session in list(active_sessions.items()):
        try:
            await session.cleanup()
        except Exception as e:
            logger.warning(f"Session {session_id}: cleanup during shutdown failed: {e}")

    # Close event bus
    event_bus = get_event_bus()
    await event_bus.close()
    reset_event_bus()

    logger.info("Server shutdown complete")
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Ensure the log directory exists before the server starts.
    Path("logs").mkdir(parents=True, exist_ok=True)

    uvicorn_options = {
        "host": settings.host,
        "port": settings.port,
        "reload": True,
        "log_level": settings.log_level.lower(),
    }
    # The app is given as an import string, which uvicorn requires when reload is enabled.
    uvicorn.run("app.main:app", **uvicorn_options)
|