Integrate eou and vad
This commit is contained in:
290
app/main.py
Normal file
290
app/main.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""FastAPI application with WebSocket and WebRTC endpoints."""
|
||||
|
||||
import uuid
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
|
||||
# Try to import aiortc (optional for WebRTC functionality)
|
||||
try:
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
AIORTC_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIORTC_AVAILABLE = False
|
||||
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
|
||||
|
||||
from app.config import settings
|
||||
from core.transports import SocketTransport, WebRtcTransport
|
||||
from core.session import Session
|
||||
from processors.tracks import Resampled16kTrack
|
||||
from core.events import get_event_bus, reset_event_bus
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(title="Python Active-Call", version="0.1.0")
|
||||
|
||||
# Configure CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins_list,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Active sessions storage
|
||||
active_sessions: Dict[str, Session] = {}
|
||||
|
||||
# Configure logging
|
||||
logger.remove()
|
||||
logger.add(
|
||||
"../logs/active_call_{time}.log",
|
||||
rotation="1 day",
|
||||
retention="7 days",
|
||||
level=settings.log_level,
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
|
||||
)
|
||||
logger.add(
|
||||
lambda msg: print(msg, end=""),
|
||||
level=settings.log_level,
|
||||
format="{time:HH:mm:ss} | {level: <8} | {message}"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "sessions": len(active_sessions)}
|
||||
|
||||
|
||||
@app.get("/iceservers")
|
||||
async def get_ice_servers():
|
||||
"""Get ICE servers configuration for WebRTC."""
|
||||
return settings.ice_servers_list
|
||||
|
||||
|
||||
@app.get("/call/lists")
|
||||
async def list_calls():
|
||||
"""List all active calls."""
|
||||
return {
|
||||
"calls": [
|
||||
{
|
||||
"id": session_id,
|
||||
"state": session.state,
|
||||
"created_at": session.created_at
|
||||
}
|
||||
for session_id, session in active_sessions.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@app.post("/call/kill/{session_id}")
|
||||
async def kill_call(session_id: str):
|
||||
"""Kill a specific active call."""
|
||||
if session_id not in active_sessions:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
session = active_sessions[session_id]
|
||||
await session.cleanup()
|
||||
del active_sessions[session_id]
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
"""
|
||||
WebSocket endpoint for raw audio streaming.
|
||||
|
||||
Accepts mixed text/binary frames:
|
||||
- Text frames: JSON commands
|
||||
- Binary frames: PCM audio data (16kHz, 16-bit, mono)
|
||||
"""
|
||||
await websocket.accept()
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Create transport and session
|
||||
transport = SocketTransport(websocket)
|
||||
session = Session(session_id, transport)
|
||||
active_sessions[session_id] = session
|
||||
|
||||
logger.info(f"WebSocket connection established: {session_id}")
|
||||
|
||||
try:
|
||||
# Receive loop
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
|
||||
# Handle binary audio data
|
||||
if "bytes" in message:
|
||||
await session.handle_audio(message["bytes"])
|
||||
|
||||
# Handle text commands
|
||||
elif "text" in message:
|
||||
await session.handle_text(message["text"])
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected: {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# Cleanup session
|
||||
if session_id in active_sessions:
|
||||
await session.cleanup()
|
||||
del active_sessions[session_id]
|
||||
|
||||
logger.info(f"Session {session_id} removed")
|
||||
|
||||
|
||||
@app.websocket("/webrtc")
|
||||
async def webrtc_endpoint(websocket: WebSocket):
|
||||
"""
|
||||
WebRTC endpoint for WebRTC audio streaming.
|
||||
|
||||
Uses WebSocket for signaling (SDP exchange) and WebRTC for media transport.
|
||||
"""
|
||||
# Check if aiortc is available
|
||||
if not AIORTC_AVAILABLE:
|
||||
await websocket.close(code=1011, reason="WebRTC not available - aiortc/av not installed")
|
||||
logger.warning("WebRTC connection attempted but aiortc is not available")
|
||||
return
|
||||
await websocket.accept()
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Create WebRTC peer connection
|
||||
pc = RTCPeerConnection()
|
||||
|
||||
# Create transport and session
|
||||
transport = WebRtcTransport(websocket, pc)
|
||||
session = Session(session_id, transport)
|
||||
active_sessions[session_id] = session
|
||||
|
||||
logger.info(f"WebRTC connection established: {session_id}")
|
||||
|
||||
# Track handler for incoming audio
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
logger.info(f"Track received: {track.kind}")
|
||||
|
||||
if track.kind == "audio":
|
||||
# Wrap track with resampler
|
||||
wrapped_track = Resampled16kTrack(track)
|
||||
|
||||
# Create task to pull audio from track
|
||||
async def pull_audio():
|
||||
try:
|
||||
while True:
|
||||
frame = await wrapped_track.recv()
|
||||
# Convert frame to bytes
|
||||
pcm_bytes = frame.to_ndarray().tobytes()
|
||||
# Feed to session
|
||||
await session.handle_audio(pcm_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"Error pulling audio from track: {e}")
|
||||
|
||||
asyncio.create_task(pull_audio())
|
||||
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
logger.info(f"Connection state: {pc.connectionState}")
|
||||
if pc.connectionState == "failed" or pc.connectionState == "closed":
|
||||
await session.cleanup()
|
||||
|
||||
try:
|
||||
# Signaling loop
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
|
||||
if "text" not in message:
|
||||
continue
|
||||
|
||||
data = json.loads(message["text"])
|
||||
|
||||
# Handle SDP offer/answer
|
||||
if "sdp" in data and "type" in data:
|
||||
logger.info(f"Received SDP {data['type']}")
|
||||
|
||||
# Set remote description
|
||||
offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"])
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
# Create and set local description
|
||||
if data["type"] == "offer":
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer)
|
||||
|
||||
# Send answer back
|
||||
await websocket.send_text(json.dumps({
|
||||
"event": "answer",
|
||||
"trackId": session_id,
|
||||
"timestamp": int(asyncio.get_event_loop().time() * 1000),
|
||||
"sdp": pc.localDescription.sdp
|
||||
}))
|
||||
|
||||
logger.info(f"Sent SDP answer")
|
||||
|
||||
else:
|
||||
# Handle other commands
|
||||
await session.handle_text(message["text"])
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebRTC WebSocket disconnected: {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebRTC error: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
await pc.close()
|
||||
if session_id in active_sessions:
|
||||
await session.cleanup()
|
||||
del active_sessions[session_id]
|
||||
|
||||
logger.info(f"WebRTC session {session_id} removed")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Run on application startup."""
|
||||
logger.info("Starting Python Active-Call server")
|
||||
logger.info(f"Server: {settings.host}:{settings.port}")
|
||||
logger.info(f"Sample rate: {settings.sample_rate} Hz")
|
||||
logger.info(f"VAD model: {settings.vad_model_path}")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Run on application shutdown."""
|
||||
logger.info("Shutting down Python Active-Call server")
|
||||
|
||||
# Cleanup all sessions
|
||||
for session_id, session in active_sessions.items():
|
||||
await session.cleanup()
|
||||
|
||||
# Close event bus
|
||||
event_bus = get_event_bus()
|
||||
await event_bus.close()
|
||||
reset_event_bus()
|
||||
|
||||
logger.info("Server shutdown complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# Create logs directory
|
||||
import os
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
|
||||
# Run server
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=True,
|
||||
log_level=settings.log_level.lower()
|
||||
)
|
||||
Reference in New Issue
Block a user