"""Transport layer for WebSocket and WebRTC communication.""" import asyncio import json from abc import ABC, abstractmethod from typing import Optional from fastapi import WebSocket from starlette.websockets import WebSocketState from loguru import logger # Try to import aiortc (optional for WebRTC functionality) try: from aiortc import RTCPeerConnection AIORTC_AVAILABLE = True except ImportError: AIORTC_AVAILABLE = False RTCPeerConnection = None # Type hint placeholder class BaseTransport(ABC): """ Abstract base class for transports. All transports must implement send_event and send_audio methods. """ @abstractmethod async def send_event(self, event: dict) -> None: """ Send a JSON event to the client. Args: event: Event data as dictionary """ pass @abstractmethod async def send_audio(self, pcm_bytes: bytes) -> None: """ Send audio data to the client. Args: pcm_bytes: PCM audio data (16-bit, mono, 16kHz) """ pass @abstractmethod async def close(self) -> None: """Close the transport and cleanup resources.""" pass class SocketTransport(BaseTransport): """ WebSocket transport for raw audio streaming. Handles mixed text/binary frames over WebSocket connection. Uses asyncio.Lock to prevent frame interleaving. """ def __init__(self, websocket: WebSocket): """ Initialize WebSocket transport. Args: websocket: FastAPI WebSocket connection """ self.ws = websocket self.lock = asyncio.Lock() # Prevent frame interleaving self._closed = False def _ws_disconnected(self) -> bool: """Best-effort check for websocket disconnection state.""" return ( self.ws.client_state == WebSocketState.DISCONNECTED or self.ws.application_state == WebSocketState.DISCONNECTED ) async def send_event(self, event: dict) -> None: """ Send a JSON event via WebSocket. Args: event: Event data as dictionary """ if self._closed or self._ws_disconnected(): logger.warning("Attempted to send event on closed transport") self._closed = True return async with self.lock: try: await self.ws.send_text(json.dumps(event)) logger.debug(f"Sent event: {event.get('event', 'unknown')}") except RuntimeError as e: self._closed = True if self._ws_disconnected() or "close message has been sent" in str(e): logger.debug(f"Skip sending event on closed websocket: {e!r}") return logger.error(f"Error sending event: {e!r}") except Exception as e: self._closed = True if self._ws_disconnected(): logger.debug(f"Skip sending event on disconnected websocket: {e!r}") return logger.error(f"Error sending event: {e!r}") async def send_audio(self, pcm_bytes: bytes) -> None: """ Send PCM audio data via WebSocket. Args: pcm_bytes: PCM audio data (16-bit, mono, 16kHz) """ if self._closed or self._ws_disconnected(): logger.warning("Attempted to send audio on closed transport") self._closed = True return async with self.lock: try: await self.ws.send_bytes(pcm_bytes) except RuntimeError as e: self._closed = True if self._ws_disconnected() or "close message has been sent" in str(e): logger.debug(f"Skip sending audio on closed websocket: {e!r}") return logger.error(f"Error sending audio: {e!r}") except Exception as e: self._closed = True if self._ws_disconnected(): logger.debug(f"Skip sending audio on disconnected websocket: {e!r}") return logger.error(f"Error sending audio: {e!r}") async def close(self) -> None: """Close the WebSocket connection.""" if self._closed: return self._closed = True if self._ws_disconnected(): return try: await self.ws.close() except RuntimeError as e: # Already closed by another task/path; safe to ignore. if "close message has been sent" in str(e): logger.debug(f"WebSocket already closed: {e}") return logger.error(f"Error closing WebSocket: {e}") except Exception as e: logger.error(f"Error closing WebSocket: {e}") @property def is_closed(self) -> bool: """Check if the transport is closed.""" return self._closed class WebRtcTransport(BaseTransport): """ WebRTC transport for WebRTC audio streaming. Uses WebSocket for signaling and RTCPeerConnection for media. """ def __init__(self, websocket: WebSocket, pc): """ Initialize WebRTC transport. Args: websocket: FastAPI WebSocket connection for signaling pc: RTCPeerConnection for media transport """ if not AIORTC_AVAILABLE: raise RuntimeError("aiortc is not available - WebRTC transport cannot be used") self.ws = websocket self.pc = pc self.outbound_track = None # MediaStreamTrack for outbound audio self._closed = False async def send_event(self, event: dict) -> None: """ Send a JSON event via WebSocket signaling. Args: event: Event data as dictionary """ if self._closed: logger.warning("Attempted to send event on closed transport") return try: await self.ws.send_text(json.dumps(event)) logger.debug(f"Sent event: {event.get('event', 'unknown')}") except Exception as e: logger.error(f"Error sending event: {e}") self._closed = True async def send_audio(self, pcm_bytes: bytes) -> None: """ Send audio data via WebRTC track. Note: In WebRTC, you don't send bytes directly. You push frames to a MediaStreamTrack that the peer connection is reading. Args: pcm_bytes: PCM audio data (16-bit, mono, 16kHz) """ if self._closed: logger.warning("Attempted to send audio on closed transport") return # This would require a custom MediaStreamTrack implementation # For now, we'll log this as a placeholder logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes") # TODO: Implement outbound audio track if needed # if self.outbound_track: # await self.outbound_track.add_frame(pcm_bytes) async def close(self) -> None: """Close the WebRTC connection.""" self._closed = True try: await self.pc.close() await self.ws.close() except Exception as e: logger.error(f"Error closing WebRTC transport: {e}") @property def is_closed(self) -> bool: """Check if the transport is closed.""" return self._closed def set_outbound_track(self, track): """ Set the outbound audio track for sending audio to client. Args: track: MediaStreamTrack for outbound audio """ self.outbound_track = track logger.debug("Set outbound track for WebRTC transport")