Add heartbeat

This commit is contained in:
Xin Wang
2026-02-04 23:16:30 +08:00
parent 77d54d284f
commit b72e09f263
3 changed files with 95 additions and 3 deletions

View File

@@ -84,6 +84,10 @@ class Settings(BaseSettings):
description="ICE servers configuration" description="ICE servers configuration"
) )
# WebSocket heartbeat and inactivity
inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
@property @property
def chunk_size_bytes(self) -> int: def chunk_size_bytes(self) -> int:
"""Calculate chunk size in bytes based on sample rate and duration.""" """Calculate chunk size in bytes based on sample rate and duration."""

View File

@@ -1,8 +1,10 @@
"""FastAPI application with WebSocket and WebRTC endpoints.""" """FastAPI application with WebSocket and WebRTC endpoints."""
import uuid import asyncio
import json import json
from typing import Dict, Any, Optional import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@@ -17,11 +19,49 @@ except ImportError:
logger.warning("aiortc not available - WebRTC endpoint will be disabled") logger.warning("aiortc not available - WebRTC endpoint will be disabled")
from app.config import settings from app.config import settings
from core.transports import SocketTransport, WebRtcTransport from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session from core.session import Session
from processors.tracks import Resampled16kTrack from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus from core.events import get_event_bus, reset_event_bus
# Check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
async def heartbeat_and_timeout_task(
transport: BaseTransport,
session: Session,
session_id: str,
last_received_at: List[float],
last_heartbeat_at: List[float],
inactivity_timeout_sec: int,
heartbeat_interval_sec: int,
) -> None:
"""
Background task: send heartBeat every ~heartbeat_interval_sec and close
connection if no message from client for inactivity_timeout_sec.
"""
while True:
await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC)
if transport.is_closed:
break
now = time.monotonic()
if now - last_received_at[0] > inactivity_timeout_sec:
logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
await session.cleanup()
break
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
try:
await transport.send_event({
"event": "heartBeat",
"timestamp": int(time.time() * 1000),
})
last_heartbeat_at[0] = now
except Exception as e:
logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
break
# Initialize FastAPI # Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0") app = FastAPI(title="Python Active-Call", version="0.1.0")
@@ -112,10 +152,25 @@ async def websocket_endpoint(websocket: WebSocket):
logger.info(f"WebSocket connection established: {session_id}") logger.info(f"WebSocket connection established: {session_id}")
last_received_at: List[float] = [time.monotonic()]
last_heartbeat_at: List[float] = [0.0]
hb_task = asyncio.create_task(
heartbeat_and_timeout_task(
transport,
session,
session_id,
last_received_at,
last_heartbeat_at,
settings.inactivity_timeout_sec,
settings.heartbeat_interval_sec,
)
)
try: try:
# Receive loop # Receive loop
while True: while True:
message = await websocket.receive() message = await websocket.receive()
last_received_at[0] = time.monotonic()
# Handle binary audio data # Handle binary audio data
if "bytes" in message: if "bytes" in message:
@@ -132,6 +187,11 @@ async def websocket_endpoint(websocket: WebSocket):
logger.error(f"WebSocket error: {e}", exc_info=True) logger.error(f"WebSocket error: {e}", exc_info=True)
finally: finally:
hb_task.cancel()
try:
await hb_task
except asyncio.CancelledError:
pass
# Cleanup session # Cleanup session
if session_id in active_sessions: if session_id in active_sessions:
await session.cleanup() await session.cleanup()
@@ -165,6 +225,20 @@ async def webrtc_endpoint(websocket: WebSocket):
logger.info(f"WebRTC connection established: {session_id}") logger.info(f"WebRTC connection established: {session_id}")
last_received_at: List[float] = [time.monotonic()]
last_heartbeat_at: List[float] = [0.0]
hb_task = asyncio.create_task(
heartbeat_and_timeout_task(
transport,
session,
session_id,
last_received_at,
last_heartbeat_at,
settings.inactivity_timeout_sec,
settings.heartbeat_interval_sec,
)
)
# Track handler for incoming audio # Track handler for incoming audio
@pc.on("track") @pc.on("track")
def on_track(track): def on_track(track):
@@ -202,6 +276,7 @@ async def webrtc_endpoint(websocket: WebSocket):
if "text" not in message: if "text" not in message:
continue continue
last_received_at[0] = time.monotonic()
data = json.loads(message["text"]) data = json.loads(message["text"])
# Handle SDP offer/answer # Handle SDP offer/answer
@@ -238,6 +313,11 @@ async def webrtc_endpoint(websocket: WebSocket):
logger.error(f"WebRTC error: {e}", exc_info=True) logger.error(f"WebRTC error: {e}", exc_info=True)
finally: finally:
hb_task.cancel()
try:
await hb_task
except asyncio.CancelledError:
pass
# Cleanup # Cleanup
await pc.close() await pc.close()
if session_id in active_sessions: if session_id in active_sessions:

View File

@@ -179,6 +179,13 @@ class DTMFEvent(BaseEvent):
digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)") digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)")
class HeartBeatEvent(BaseModel):
"""Server-to-client heartbeat to keep connection alive."""
event: str = Field(default="heartBeat", description="Event type")
timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")
# Event type mapping # Event type mapping
EVENT_TYPES = { EVENT_TYPES = {
"incoming": IncomingEvent, "incoming": IncomingEvent,
@@ -198,6 +205,7 @@ EVENT_TYPES = {
"metrics": MetricsEvent, "metrics": MetricsEvent,
"addHistory": AddHistoryEvent, "addHistory": AddHistoryEvent,
"dtmf": DTMFEvent, "dtmf": DTMFEvent,
"heartBeat": HeartBeatEvent,
} }