"""Transport layer for WebSocket and WebRTC communication.""" import asyncio import json from abc import ABC, abstractmethod from typing import Optional from fastapi import WebSocket from loguru import logger # Try to import aiortc (optional for WebRTC functionality) try: from aiortc import RTCPeerConnection AIORTC_AVAILABLE = True except ImportError: AIORTC_AVAILABLE = False RTCPeerConnection = None # Type hint placeholder class BaseTransport(ABC): """ Abstract base class for transports. All transports must implement send_event and send_audio methods. """ @abstractmethod async def send_event(self, event: dict) -> None: """ Send a JSON event to the client. Args: event: Event data as dictionary """ pass @abstractmethod async def send_audio(self, pcm_bytes: bytes) -> None: """ Send audio data to the client. Args: pcm_bytes: PCM audio data (16-bit, mono, 16kHz) """ pass @abstractmethod async def close(self) -> None: """Close the transport and cleanup resources.""" pass class SocketTransport(BaseTransport): """ WebSocket transport for raw audio streaming. Handles mixed text/binary frames over WebSocket connection. Uses asyncio.Lock to prevent frame interleaving. """ def __init__(self, websocket: WebSocket): """ Initialize WebSocket transport. Args: websocket: FastAPI WebSocket connection """ self.ws = websocket self.lock = asyncio.Lock() # Prevent frame interleaving self._closed = False async def send_event(self, event: dict) -> None: """ Send a JSON event via WebSocket. Args: event: Event data as dictionary """ if self._closed: logger.warning("Attempted to send event on closed transport") return async with self.lock: try: await self.ws.send_text(json.dumps(event)) logger.debug(f"Sent event: {event.get('event', 'unknown')}") except Exception as e: logger.error(f"Error sending event: {e}") self._closed = True async def send_audio(self, pcm_bytes: bytes) -> None: """ Send PCM audio data via WebSocket. Args: pcm_bytes: PCM audio data (16-bit, mono, 16kHz) """ if self._closed: logger.warning("Attempted to send audio on closed transport") return async with self.lock: try: await self.ws.send_bytes(pcm_bytes) except Exception as e: logger.error(f"Error sending audio: {e}") self._closed = True async def close(self) -> None: """Close the WebSocket connection.""" self._closed = True try: await self.ws.close() except Exception as e: logger.error(f"Error closing WebSocket: {e}") @property def is_closed(self) -> bool: """Check if the transport is closed.""" return self._closed class WebRtcTransport(BaseTransport): """ WebRTC transport for WebRTC audio streaming. Uses WebSocket for signaling and RTCPeerConnection for media. """ def __init__(self, websocket: WebSocket, pc): """ Initialize WebRTC transport. Args: websocket: FastAPI WebSocket connection for signaling pc: RTCPeerConnection for media transport """ if not AIORTC_AVAILABLE: raise RuntimeError("aiortc is not available - WebRTC transport cannot be used") self.ws = websocket self.pc = pc self.outbound_track = None # MediaStreamTrack for outbound audio self._closed = False async def send_event(self, event: dict) -> None: """ Send a JSON event via WebSocket signaling. Args: event: Event data as dictionary """ if self._closed: logger.warning("Attempted to send event on closed transport") return try: await self.ws.send_text(json.dumps(event)) logger.debug(f"Sent event: {event.get('event', 'unknown')}") except Exception as e: logger.error(f"Error sending event: {e}") self._closed = True async def send_audio(self, pcm_bytes: bytes) -> None: """ Send audio data via WebRTC track. Note: In WebRTC, you don't send bytes directly. You push frames to a MediaStreamTrack that the peer connection is reading. Args: pcm_bytes: PCM audio data (16-bit, mono, 16kHz) """ if self._closed: logger.warning("Attempted to send audio on closed transport") return # This would require a custom MediaStreamTrack implementation # For now, we'll log this as a placeholder logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes") # TODO: Implement outbound audio track if needed # if self.outbound_track: # await self.outbound_track.add_frame(pcm_bytes) async def close(self) -> None: """Close the WebRTC connection.""" self._closed = True try: await self.pc.close() await self.ws.close() except Exception as e: logger.error(f"Error closing WebRTC transport: {e}") @property def is_closed(self) -> bool: """Check if the transport is closed.""" return self._closed def set_outbound_track(self, track): """ Set the outbound audio track for sending audio to client. Args: track: MediaStreamTrack for outbound audio """ self.outbound_track = track logger.debug("Set outbound track for WebRTC transport")