"""FastAPI application with WebSocket and WebRTC endpoints."""

import asyncio
import json
import time
import uuid
from typing import Dict, Any, Optional, List, Set

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger

# Try to import aiortc (optional for WebRTC functionality)
try:
    from aiortc import RTCPeerConnection, RTCSessionDescription

    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    logger.warning("aiortc not available - WebRTC endpoint will be disabled")

from app.config import settings
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session
from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus

# How often the heartbeat/timeout watchdog wakes up (seconds). This is also the
# resolution of both the heartbeat cadence and the inactivity timeout.
_HEARTBEAT_CHECK_INTERVAL_SEC = 5


async def heartbeat_and_timeout_task(
    transport: BaseTransport,
    session: Session,
    session_id: str,
    last_received_at: List[float],
    last_heartbeat_at: List[float],
    inactivity_timeout_sec: int,
    heartbeat_interval_sec: int,
) -> None:
    """Send periodic heartBeat events and close idle sessions.

    Every ``_HEARTBEAT_CHECK_INTERVAL_SEC`` seconds the task:

    * exits if ``transport.is_closed`` is true;
    * calls ``session.cleanup()`` and exits if no client message has been
      recorded for more than ``inactivity_timeout_sec`` seconds;
    * otherwise sends a ``heartBeat`` event (epoch-millisecond timestamp) when
      at least ``heartbeat_interval_sec`` seconds passed since the last one.
      A failed send ends the task.

    Args:
        transport: Transport used to send the heartbeat event.
        session: Session to clean up on inactivity.
        session_id: Identifier used in log messages.
        last_received_at: One-element list used as a shared mutable cell; the
            caller's receive loop stores the ``time.monotonic()`` of the latest
            client message in ``[0]``.
        last_heartbeat_at: One-element list holding the ``time.monotonic()`` of
            the last heartbeat sent; updated by this task.
        inactivity_timeout_sec: Idle time after which the session is closed.
        heartbeat_interval_sec: Minimum spacing between heartbeat events.
    """
    while True:
        await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC)
        if transport.is_closed:
            break
        now = time.monotonic()
        if now - last_received_at[0] > inactivity_timeout_sec:
            logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
            await session.cleanup()
            break
        if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
            try:
                await transport.send_event({
                    "event": "heartBeat",
                    "timestamp": int(time.time() * 1000),
                })
                last_heartbeat_at[0] = now
            except Exception as e:
                logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
                break


# Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Active sessions storage, keyed by session id.
active_sessions: Dict[str, Session] = {}

# Configure logging
logger.remove()
logger.add(
    "./logs/active_call_{time}.log",
    rotation="1 day",
    retention="7 days",
    level=settings.log_level,
    format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
)
logger.add(
    lambda msg: print(msg, end=""),
    level=settings.log_level,
    format="{time:HH:mm:ss} | {level: <8} | {message}"
)


async def _release_session(session_id: str) -> bool:
    """Remove a session from the registry and clean it up, at most once.

    The entry is popped *before* awaiting ``cleanup()``. Any concurrent caller
    (``kill_call``, an endpoint's ``finally`` block, shutdown) then sees the
    session as already gone. It skips it instead of cleaning up twice or
    hitting ``KeyError`` on a later ``del``.

    Returns:
        True if this call removed the session, False if it was not registered.
    """
    session = active_sessions.pop(session_id, None)
    if session is None:
        return False
    await session.cleanup()
    return True


async def _cancel_and_wait(task: asyncio.Task, session_id: str) -> None:
    """Cancel a background task and wait for it to finish.

    If the task had already ended with an exception, that exception is logged
    rather than re-raised. Otherwise it would escape the caller's ``finally``
    block and skip removal of the session from ``active_sessions``.
    """
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.warning(f"Session {session_id}: background task failed: {e}")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "sessions": len(active_sessions)}


@app.get("/iceservers")
async def get_ice_servers():
    """Get ICE servers configuration for WebRTC."""
    return settings.ice_servers_list


@app.get("/call/lists")
async def list_calls():
    """List all active calls."""
    return {
        "calls": [
            {
                "id": session_id,
                "state": session.state,
                "created_at": session.created_at
            }
            for session_id, session in active_sessions.items()
        ]
    }


@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
    """Kill a specific active call.

    Raises:
        HTTPException: 404 if no session with ``session_id`` is registered.
    """
    if not await _release_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    return True


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for raw audio streaming.

    Accepts mixed text/binary frames:
    - Text frames: JSON commands
    - Binary frames: PCM audio data (16kHz, 16-bit, mono)
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create transport and session
    transport = SocketTransport(websocket)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebSocket connection established: {session_id}")

    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    try:
        # Receive loop
        while True:
            message = await websocket.receive()
            # receive() returns the raw ASGI message instead of raising on
            # disconnect; calling it again afterwards raises RuntimeError.
            if message["type"] == "websocket.disconnect":
                raise WebSocketDisconnect(code=message.get("code", 1000))
            last_received_at[0] = time.monotonic()

            # Per the ASGI spec both keys may be present with one set to None.
            audio = message.get("bytes")
            text = message.get("text")
            if audio is not None:
                # Handle binary audio data
                await session.handle_audio(audio)
            elif text is not None:
                # Handle text commands
                await session.handle_text(text)

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")
    except Exception as e:
        logger.error(f"WebSocket error: {e}", exc_info=True)
    finally:
        await _cancel_and_wait(hb_task, session_id)

        # Cleanup session (skipped if kill_call/shutdown already released it)
        if await _release_session(session_id):
            logger.info(f"Session {session_id} removed")


@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
    """
    WebRTC endpoint for WebRTC audio streaming.

    Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
    Text frames that carry an ``sdp``/``type`` JSON object are treated as SDP;
    every other text frame is forwarded to ``session.handle_text``.
    """
    # Check if aiortc is available
    if not AIORTC_AVAILABLE:
        await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
        logger.warning("WebRTC connection attempted but aiortc is not available")
        return

    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create WebRTC peer connection
    pc = RTCPeerConnection()

    # Create transport and session
    transport = WebRtcTransport(websocket, pc)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebRTC connection established: {session_id}")

    # NOTE(review): only signaling text refreshes last_received_at here; RTP
    # audio does not (unlike /ws, where audio frames count as activity).
    # Confirm clients send periodic text, or idle-timeout may end live calls.
    last_received_at: List[float] = [time.monotonic()]
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    # Strong references to the audio pump tasks. The event loop keeps only
    # weak references, so an unreferenced task may be garbage-collected
    # mid-call. The set also lets teardown cancel them.
    audio_tasks: Set[asyncio.Task] = set()

    # Track handler for incoming audio
    @pc.on("track")
    def on_track(track):
        logger.info(f"Track received: {track.kind}")
        if track.kind != "audio":
            return

        # Wrap track with resampler
        wrapped_track = Resampled16kTrack(track)

        # Pull frames from the track and feed them to the session as raw bytes
        async def pull_audio():
            try:
                while True:
                    frame = await wrapped_track.recv()
                    pcm_bytes = frame.to_ndarray().tobytes()
                    await session.handle_audio(pcm_bytes)
            except Exception as e:
                logger.error(f"Error pulling audio from track: {e}")

        task = asyncio.create_task(pull_audio())
        audio_tasks.add(task)
        task.add_done_callback(audio_tasks.discard)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"Connection state: {pc.connectionState}")
        if pc.connectionState in ("failed", "closed"):
            await session.cleanup()

    try:
        # Signaling loop
        while True:
            message = await websocket.receive()
            # receive() returns the raw ASGI message instead of raising on
            # disconnect; calling it again afterwards raises RuntimeError.
            if message["type"] == "websocket.disconnect":
                raise WebSocketDisconnect(code=message.get("code", 1000))

            text = message.get("text")
            if text is None:
                continue
            last_received_at[0] = time.monotonic()

            # Unparseable or non-object payloads are not SDP; let the session
            # handle them like any other command (as /ws does).
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                data = None

            if isinstance(data, dict) and "sdp" in data and "type" in data:
                # Handle SDP offer/answer
                logger.info(f"Received SDP {data['type']}")

                # Set remote description
                offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
                await pc.setRemoteDescription(offer)

                # Create and set local description, then send the answer back
                if data["type"] == "offer":
                    answer = await pc.createAnswer()
                    await pc.setLocalDescription(answer)

                    await websocket.send_text(json.dumps({
                        "event": "answer",
                        "trackId": session_id,
                        # Epoch milliseconds, consistent with heartBeat events.
                        "timestamp": int(time.time() * 1000),
                        "sdp": pc.localDescription.sdp
                    }))

                    logger.info("Sent SDP answer")
            else:
                # Handle other commands
                await session.handle_text(text)

    except WebSocketDisconnect:
        logger.info(f"WebRTC WebSocket disconnected: {session_id}")
    except Exception as e:
        logger.error(f"WebRTC error: {e}", exc_info=True)
    finally:
        await _cancel_and_wait(hb_task, session_id)

        # Stop audio pumps before tearing down the peer connection.
        for task in list(audio_tasks):
            task.cancel()
        if audio_tasks:
            await asyncio.gather(*audio_tasks, return_exceptions=True)

        # Cleanup; a failing close must not prevent session removal.
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"WebRTC session {session_id}: error closing peer connection: {e}")

        if await _release_session(session_id):
            logger.info(f"WebRTC session {session_id} removed")


@app.on_event("startup")
async def startup_event():
    """Run on application startup."""
    logger.info("Starting Python Active-Call server")
    logger.info(f"Server: {settings.host}:{settings.port}")
    logger.info(f"Sample rate: {settings.sample_rate} Hz")
    logger.info(f"VAD model: {settings.vad_model_path}")


@app.on_event("shutdown")
async def shutdown_event():
    """Run on application shutdown: clean up all sessions, then close the event bus."""
    logger.info("Shutting down Python Active-Call server")

    # Snapshot and clear the registry first. Each cleanup() awaits, and endpoint
    # finally blocks may touch active_sessions meanwhile; iterating the live
    # dict would raise "dictionary changed size during iteration".
    sessions = list(active_sessions.items())
    active_sessions.clear()

    # Best-effort: one failing session must not block the rest of shutdown.
    for session_id, session in sessions:
        try:
            await session.cleanup()
        except Exception as e:
            logger.warning(f"Session {session_id}: cleanup during shutdown failed: {e}")

    # Close event bus
    event_bus = get_event_bus()
    await event_bus.close()
    reset_event_bus()

    logger.info("Server shutdown complete")


if __name__ == "__main__":
    import uvicorn

    # Create logs directory
    import os
    os.makedirs("logs", exist_ok=True)

    # Run server
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=True,
        log_level=settings.log_level.lower()
    )