391 lines
11 KiB
Python
391 lines
11 KiB
Python
"""FastAPI application with WebSocket and WebRTC endpoints."""
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional, List
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, FileResponse
|
|
from loguru import logger
|
|
|
|
# Try to import aiortc (optional for WebRTC functionality)
|
|
try:
|
|
from aiortc import RTCPeerConnection, RTCSessionDescription
|
|
AIORTC_AVAILABLE = True
|
|
except ImportError:
|
|
AIORTC_AVAILABLE = False
|
|
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
|
|
|
|
from app.config import settings
|
|
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
|
|
from core.session import Session
|
|
from processors.tracks import Resampled16kTrack
|
|
from core.events import get_event_bus, reset_event_bus
|
|
|
|
# Check interval for heartbeat/timeout (seconds)
|
|
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
|
|
|
|
|
|
async def heartbeat_and_timeout_task(
    transport: BaseTransport,
    session: Session,
    session_id: str,
    last_received_at: List[float],
    last_heartbeat_at: List[float],
    inactivity_timeout_sec: int,
    heartbeat_interval_sec: int,
) -> None:
    """
    Background task: send heartBeat every ~heartbeat_interval_sec and close
    connection if no message from client for inactivity_timeout_sec.

    The loop wakes every ``_HEARTBEAT_CHECK_INTERVAL_SEC`` seconds, so both the
    heartbeat period and the timeout are only honoured to that granularity.

    Args:
        transport: Used to push the ``heartBeat`` event; the loop exits as soon
            as ``transport.is_closed`` is true.
        session: Cleaned up when the inactivity timeout fires.
        session_id: Used only in log messages.
        last_received_at: One-element list holding the ``time.monotonic()``
            time of the last client message. The caller updates it; a list is
            used as a shared mutable cell between caller and task.
        last_heartbeat_at: One-element list holding the ``time.monotonic()``
            time of the last heartbeat sent; updated by this task.
        inactivity_timeout_sec: Seconds of client silence before cleanup.
        heartbeat_interval_sec: Minimum seconds between heartbeats.
    """
    while True:
        await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC)
        if transport.is_closed:
            break

        now = time.monotonic()

        if now - last_received_at[0] > inactivity_timeout_sec:
            logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
            try:
                await session.cleanup()
            except Exception as e:
                # Don't let a cleanup failure become this task's result: the
                # endpoints await this task in their `finally`, and a raised
                # exception there would skip removing the session.
                logger.warning(f"Session {session_id}: cleanup after inactivity failed: {e}")
            break

        if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
            try:
                await transport.send_event({
                    "event": "heartBeat",
                    "timestamp": int(time.time() * 1000),
                })
                last_heartbeat_at[0] = now
            except Exception as e:
                # A failed send means the connection is unusable; stop quietly.
                logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
                break
|
|
|
|
|
|
# Application instance and the location of the bundled browser client.
app = FastAPI(title="Python Active-Call", version="0.1.0")
_WEB_CLIENT_PATH = Path(__file__).resolve().parent.parent.joinpath("examples", "web_client.html")

# CORS: allowed origins come from settings; every method and header is
# accepted and credentialed requests are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=settings.cors_origins_list,
)

# Registry of live call sessions keyed by session id. The WebSocket/WebRTC
# endpoints register and remove their own entries; /call/kill removes on demand.
active_sessions: Dict[str, Session] = {}
|
|
|
|
# Logging: drop loguru's default handler, then log to a file that rotates
# daily and is kept for 7 days, and echo a shorter format to stdout.
logger.remove()


def _print_sink(msg):
    """Loguru sink that prints an already-formatted record to stdout."""
    print(msg, end="")


logger.add(
    "./logs/active_call_{time}.log",
    level=settings.log_level,
    format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
    rotation="1 day",
    retention="7 days",
)
logger.add(
    _print_sink,
    level=settings.log_level,
    format="{time:HH:mm:ss} | {level: <8} | {message}",
)
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint: liveness status plus the number of registered sessions."""
    session_count = len(active_sessions)
    return {"status": "healthy", "sessions": session_count}
|
|
|
|
|
|
@app.get("/")
async def web_client_root():
    """Serve the bundled web client HTML page.

    Raises:
        HTTPException: 404 if the client file is not present.
    """
    # is_file() rather than exists(): a directory at this path would pass
    # exists() and then make FileResponse fail with a 500 when sending.
    if not _WEB_CLIENT_PATH.is_file():
        raise HTTPException(status_code=404, detail="Web client not found")
    return FileResponse(_WEB_CLIENT_PATH)
|
|
|
|
|
|
@app.get("/client")
async def web_client_alias():
    """Alias for the web client: serves the same HTML page as ``/``.

    Raises:
        HTTPException: 404 if the client file is not present.
    """
    # is_file() rather than exists(): a directory at this path would pass
    # exists() and then make FileResponse fail with a 500 when sending.
    if not _WEB_CLIENT_PATH.is_file():
        raise HTTPException(status_code=404, detail="Web client not found")
    return FileResponse(_WEB_CLIENT_PATH)
|
|
|
|
|
|
|
|
|
|
@app.get("/iceservers")
async def get_ice_servers():
    """Return the configured ICE servers that WebRTC clients should use."""
    ice_servers = settings.ice_servers_list
    return ice_servers
|
|
|
|
|
|
@app.get("/call/lists")
async def list_calls():
    """List all active calls.

    Returns:
        ``{"calls": [...]}`` with one entry per registered session, holding
        its ``id``, ``state`` and ``created_at``.
    """
    calls = []
    for sid, sess in active_sessions.items():
        calls.append({"id": sid, "state": sess.state, "created_at": sess.created_at})
    return {"calls": calls}
|
|
|
|
|
|
@app.post("/call/kill/{session_id}")
async def kill_call(session_id: str):
    """Kill a specific active call.

    Args:
        session_id: Id of the session to terminate.

    Returns:
        True once the session's cleanup has completed.

    Raises:
        HTTPException: 404 if no active session has this id.
    """
    # Remove the entry *before* awaiting cleanup. While cleanup runs, the
    # endpoint's own `finally` (or a concurrent kill request) may also try to
    # remove this id, and a later `del` would then raise KeyError (HTTP 500).
    session = active_sessions.pop(session_id, None)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    await session.cleanup()
    return True
|
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for raw audio streaming.

    Accepts mixed text/binary frames:
    - Text frames: JSON commands
    - Binary frames: PCM audio data (16kHz, 16-bit, mono)

    Every received frame refreshes the inactivity timer used by the
    background heartbeat task. On exit the session is removed from
    ``active_sessions`` and cleaned up.
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create transport and session
    transport = SocketTransport(websocket)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebSocket connection established: {session_id}")

    # One-element lists act as mutable cells shared with the heartbeat task.
    last_received_at: List[float] = [time.monotonic()]
    # 0.0 makes the first heartbeat go out on the task's first check.
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    try:
        # Receive loop
        while True:
            message = await websocket.receive()
            # receive() returns the raw ASGI message and does not raise on
            # disconnect; calling it again afterwards raises RuntimeError.
            if message["type"] == "websocket.disconnect":
                raise WebSocketDisconnect(message.get("code", 1000))

            last_received_at[0] = time.monotonic()

            # ASGI servers may send both keys with one of them set to None, so
            # test the values rather than key presence.
            audio = message.get("bytes")
            text = message.get("text")
            if audio is not None:
                # Binary audio data
                await session.handle_audio(audio)
            elif text is not None:
                # Text command
                await session.handle_text(text)

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")

    except Exception as e:
        logger.error(f"WebSocket error: {e}", exc_info=True)

    finally:
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task may already have finished with an error; that must not
            # prevent the session cleanup below.
            logger.warning(f"Session {session_id}: heartbeat task failed: {e}")

        # Pop before awaiting cleanup so a concurrent /call/kill cannot make a
        # later `del` raise KeyError; if kill already removed it, it also
        # already ran cleanup.
        if active_sessions.pop(session_id, None) is not None:
            await session.cleanup()

        logger.info(f"Session {session_id} removed")
|
|
|
|
|
|
@app.websocket("/webrtc")
async def webrtc_endpoint(websocket: WebSocket):
    """
    WebRTC endpoint for WebRTC audio streaming.

    Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
    Text messages carrying ``sdp`` and ``type`` are treated as SDP; any other
    text is forwarded to the session as a command. Both signaling messages and
    incoming audio frames refresh the inactivity timer of the heartbeat task.
    """
    # Check if aiortc is available
    if not AIORTC_AVAILABLE:
        await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
        logger.warning("WebRTC connection attempted but aiortc is not available")
        return

    await websocket.accept()
    session_id = str(uuid.uuid4())

    # Create WebRTC peer connection
    pc = RTCPeerConnection()

    # Create transport and session
    transport = WebRtcTransport(websocket, pc)
    session = Session(session_id, transport)
    active_sessions[session_id] = session

    logger.info(f"WebRTC connection established: {session_id}")

    # One-element lists act as mutable cells shared with the heartbeat task.
    last_received_at: List[float] = [time.monotonic()]
    # 0.0 makes the first heartbeat go out on the task's first check.
    last_heartbeat_at: List[float] = [0.0]
    hb_task = asyncio.create_task(
        heartbeat_and_timeout_task(
            transport,
            session,
            session_id,
            last_received_at,
            last_heartbeat_at,
            settings.inactivity_timeout_sec,
            settings.heartbeat_interval_sec,
        )
    )

    # Strong references to the audio-pulling tasks. The event loop holds only
    # weak references to tasks, and these must be cancelled when the call ends.
    audio_tasks = set()

    # Track handler for incoming audio
    @pc.on("track")
    def on_track(track):
        logger.info(f"Track received: {track.kind}")

        if track.kind == "audio":
            # Wrap track with resampler
            wrapped_track = Resampled16kTrack(track)

            async def pull_audio():
                try:
                    while True:
                        frame = await wrapped_track.recv()
                        # Incoming media is client activity (as audio frames
                        # are on /ws); otherwise calls with flowing audio but an
                        # idle signaling channel would hit the inactivity timeout.
                        last_received_at[0] = time.monotonic()
                        # Convert frame to bytes and feed to session
                        pcm_bytes = frame.to_ndarray().tobytes()
                        await session.handle_audio(pcm_bytes)
                except Exception as e:
                    logger.error(f"Error pulling audio from track: {e}")

            task = asyncio.create_task(pull_audio())
            audio_tasks.add(task)
            task.add_done_callback(audio_tasks.discard)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"Connection state: {pc.connectionState}")
        if pc.connectionState in ("failed", "closed"):
            await session.cleanup()

    try:
        # Signaling loop
        while True:
            message = await websocket.receive()
            # receive() returns the raw ASGI message and does not raise on
            # disconnect; calling it again afterwards raises RuntimeError.
            if message["type"] == "websocket.disconnect":
                raise WebSocketDisconnect(message.get("code", 1000))

            text = message.get("text")
            if text is None:
                # Binary frames are not used for signaling.
                continue

            last_received_at[0] = time.monotonic()

            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                data = None

            # Handle SDP offer/answer
            if isinstance(data, dict) and "sdp" in data and "type" in data:
                logger.info(f"Received SDP {data['type']}")

                # Set remote description
                offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
                await pc.setRemoteDescription(offer)

                # Create and set local description
                if data["type"] == "offer":
                    answer = await pc.createAnswer()
                    await pc.setLocalDescription(answer)

                    # Send answer back. Wall-clock epoch milliseconds, matching
                    # the heartBeat event (loop time is not an epoch timestamp).
                    await websocket.send_text(json.dumps({
                        "event": "answer",
                        "trackId": session_id,
                        "timestamp": int(time.time() * 1000),
                        "sdp": pc.localDescription.sdp,
                    }))

                    logger.info("Sent SDP answer")

            else:
                # Any other text (including malformed JSON) is a command; the
                # session parses and validates it, as on the /ws endpoint,
                # instead of one bad message tearing down the whole call.
                await session.handle_text(text)

    except WebSocketDisconnect:
        logger.info(f"WebRTC WebSocket disconnected: {session_id}")

    except Exception as e:
        logger.error(f"WebRTC error: {e}", exc_info=True)

    finally:
        hb_task.cancel()
        try:
            await hb_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task may already have finished with an error; that must not
            # prevent the cleanup below.
            logger.warning(f"WebRTC session {session_id}: heartbeat task failed: {e}")

        # Stop audio pulling before closing the peer connection.
        pending_audio = list(audio_tasks)
        for task in pending_audio:
            task.cancel()
        if pending_audio:
            await asyncio.gather(*pending_audio, return_exceptions=True)

        # Cleanup
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"WebRTC session {session_id}: error closing peer connection: {e}")

        # Pop before awaiting cleanup so a concurrent /call/kill cannot make a
        # later `del` raise KeyError; if kill already removed it, it also
        # already ran cleanup.
        if active_sessions.pop(session_id, None) is not None:
            await session.cleanup()

        logger.info(f"WebRTC session {session_id} removed")
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Log the server's effective configuration when the application starts."""
    for line in (
        "Starting Python Active-Call server",
        f"Server: {settings.host}:{settings.port}",
        f"Sample rate: {settings.sample_rate} Hz",
        f"VAD model: {settings.vad_model_path}",
    ):
        logger.info(line)
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Run on application shutdown.

    Cleans up every registered session, then closes and resets the global
    event bus.
    """
    logger.info("Shutting down Python Active-Call server")

    # Iterate over a snapshot. cleanup() awaits, and the endpoints remove their
    # own entries from `active_sessions` as connections end; mutating the dict
    # mid-iteration would raise "dictionary changed size during iteration".
    for session_id, session in list(active_sessions.items()):
        try:
            await session.cleanup()
        except Exception as e:
            # One failing session must not stop the others or the bus close.
            logger.error(f"Error cleaning up session {session_id} during shutdown: {e}")

    # Close event bus
    event_bus = get_event_bus()
    await event_bus.close()
    reset_event_bus()

    logger.info("Server shutdown complete")
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Make sure the "logs" directory exists (no error if it already does).
    Path("logs").mkdir(parents=True, exist_ok=True)

    # Launch the development server with auto-reload enabled.
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        log_level=settings.log_level.lower(),
        reload=True,
    )
|