diff --git a/engine/app/main.py b/engine/app/main.py index 593d534..259204c 100644 --- a/engine/app/main.py +++ b/engine/app/main.py @@ -190,6 +190,12 @@ async def websocket_endpoint(websocket: WebSocket): # Receive loop while True: message = await websocket.receive() + message_type = message.get("type") + + if message_type == "websocket.disconnect": + logger.info(f"WebSocket disconnected: {session_id}") + break + last_received_at[0] = time.monotonic() # Handle binary audio data diff --git a/engine/core/session.py b/engine/core/session.py index 460cd6d..6c59bd7 100644 --- a/engine/core/session.py +++ b/engine/core/session.py @@ -1,5 +1,6 @@ """Session management for active calls.""" +import asyncio import uuid import json import time @@ -78,6 +79,8 @@ class Session: self._history_turn_index: int = 0 self._history_call_started_mono: Optional[float] = None self._history_finalized: bool = False + self._cleanup_lock = asyncio.Lock() + self._cleaned_up = False self.pipeline.conversation.on_turn_complete(self._on_turn_complete) @@ -288,10 +291,15 @@ class Session: async def cleanup(self) -> None: """Cleanup session resources.""" - logger.info(f"Session {self.id} cleaning up") - await self._finalize_history(status="connected") - await self.pipeline.cleanup() - await self.transport.close() + async with self._cleanup_lock: + if self._cleaned_up: + return + + self._cleaned_up = True + logger.info(f"Session {self.id} cleaning up") + await self._finalize_history(status="connected") + await self.pipeline.cleanup() + await self.transport.close() async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None: """Initialize backend history call record for this session.""" diff --git a/engine/core/transports.py b/engine/core/transports.py index 6945225..3df04ba 100644 --- a/engine/core/transports.py +++ b/engine/core/transports.py @@ -5,6 +5,7 @@ import json from abc import ABC, abstractmethod from typing import Optional from fastapi import WebSocket +from starlette.websockets import WebSocketState from loguru import logger # Try to import aiortc (optional for WebRTC functionality) @@ -107,9 +108,24 @@ class SocketTransport(BaseTransport): async def close(self) -> None: """Close the WebSocket connection.""" + if self._closed: + return + self._closed = True + if ( + self.ws.client_state == WebSocketState.DISCONNECTED + or self.ws.application_state == WebSocketState.DISCONNECTED + ): + return + try: await self.ws.close() + except RuntimeError as e: + # Already closed by another task/path; safe to ignore. + if "close message has been sent" in str(e): + logger.debug(f"WebSocket already closed: {e}") + return + logger.error(f"Error closing WebSocket: {e}") except Exception as e: logger.error(f"Error closing WebSocket: {e}")