Add heartbeat

This commit is contained in:
Xin Wang
2026-02-04 23:16:30 +08:00
parent 77d54d284f
commit b72e09f263
3 changed files with 95 additions and 3 deletions

View File

@@ -84,6 +84,10 @@ class Settings(BaseSettings):
description="ICE servers configuration"
)
# WebSocket heartbeat and inactivity
inactivity_timeout_sec: int = Field(default=60, description="Close connection after no message from client (seconds)")
heartbeat_interval_sec: int = Field(default=50, description="Send heartBeat event to client every N seconds")
@property
def chunk_size_bytes(self) -> int:
"""Calculate chunk size in bytes based on sample rate and duration."""

View File

@@ -1,8 +1,10 @@
"""FastAPI application with WebSocket and WebRTC endpoints."""
import uuid
import asyncio
import json
from typing import Dict, Any, Optional
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@@ -17,11 +19,49 @@ except ImportError:
logger.warning("aiortc not available - WebRTC endpoint will be disabled")
from app.config import settings
from core.transports import SocketTransport, WebRtcTransport
from core.transports import SocketTransport, WebRtcTransport, BaseTransport
from core.session import Session
from processors.tracks import Resampled16kTrack
from core.events import get_event_bus, reset_event_bus
# Check interval for heartbeat/timeout (seconds)
_HEARTBEAT_CHECK_INTERVAL_SEC = 5
async def heartbeat_and_timeout_task(
transport: BaseTransport,
session: Session,
session_id: str,
last_received_at: List[float],
last_heartbeat_at: List[float],
inactivity_timeout_sec: int,
heartbeat_interval_sec: int,
) -> None:
"""
Background task: send heartBeat every ~heartbeat_interval_sec and close
connection if no message from client for inactivity_timeout_sec.
"""
while True:
await asyncio.sleep(_HEARTBEAT_CHECK_INTERVAL_SEC)
if transport.is_closed:
break
now = time.monotonic()
if now - last_received_at[0] > inactivity_timeout_sec:
logger.info(f"Session {session_id}: {inactivity_timeout_sec}s no message, closing")
await session.cleanup()
break
if now - last_heartbeat_at[0] >= heartbeat_interval_sec:
try:
await transport.send_event({
"event": "heartBeat",
"timestamp": int(time.time() * 1000),
})
last_heartbeat_at[0] = now
except Exception as e:
logger.debug(f"Session {session_id}: heartbeat send failed: {e}")
break
# Initialize FastAPI
app = FastAPI(title="Python Active-Call", version="0.1.0")
@@ -112,10 +152,25 @@ async def websocket_endpoint(websocket: WebSocket):
logger.info(f"WebSocket connection established: {session_id}")
last_received_at: List[float] = [time.monotonic()]
last_heartbeat_at: List[float] = [0.0]
hb_task = asyncio.create_task(
heartbeat_and_timeout_task(
transport,
session,
session_id,
last_received_at,
last_heartbeat_at,
settings.inactivity_timeout_sec,
settings.heartbeat_interval_sec,
)
)
try:
# Receive loop
while True:
message = await websocket.receive()
last_received_at[0] = time.monotonic()
# Handle binary audio data
if "bytes" in message:
@@ -132,6 +187,11 @@ async def websocket_endpoint(websocket: WebSocket):
logger.error(f"WebSocket error: {e}", exc_info=True)
finally:
hb_task.cancel()
try:
await hb_task
except asyncio.CancelledError:
pass
# Cleanup session
if session_id in active_sessions:
await session.cleanup()
@@ -165,6 +225,20 @@ async def webrtc_endpoint(websocket: WebSocket):
logger.info(f"WebRTC connection established: {session_id}")
last_received_at: List[float] = [time.monotonic()]
last_heartbeat_at: List[float] = [0.0]
hb_task = asyncio.create_task(
heartbeat_and_timeout_task(
transport,
session,
session_id,
last_received_at,
last_heartbeat_at,
settings.inactivity_timeout_sec,
settings.heartbeat_interval_sec,
)
)
# Track handler for incoming audio
@pc.on("track")
def on_track(track):
@@ -202,6 +276,7 @@ async def webrtc_endpoint(websocket: WebSocket):
if "text" not in message:
continue
last_received_at[0] = time.monotonic()
data = json.loads(message["text"])
# Handle SDP offer/answer
@@ -238,6 +313,11 @@ async def webrtc_endpoint(websocket: WebSocket):
logger.error(f"WebRTC error: {e}", exc_info=True)
finally:
hb_task.cancel()
try:
await hb_task
except asyncio.CancelledError:
pass
# Cleanup
await pc.close()
if session_id in active_sessions:

View File

@@ -179,6 +179,13 @@ class DTMFEvent(BaseEvent):
digit: str = Field(..., description="DTMF digit (0-9, *, #, A-D)")
class HeartBeatEvent(BaseModel):
"""Server-to-client heartbeat to keep connection alive."""
event: str = Field(default="heartBeat", description="Event type")
timestamp: int = Field(default_factory=current_timestamp_ms, description="Event timestamp in milliseconds")
# Event type mapping
EVENT_TYPES = {
"incoming": IncomingEvent,
@@ -198,6 +205,7 @@ EVENT_TYPES = {
"metrics": MetricsEvent,
"addHistory": AddHistoryEvent,
"dtmf": DTMFEvent,
"heartBeat": HeartBeatEvent,
}