208 lines
5.9 KiB
Python
208 lines
5.9 KiB
Python
"""Transport layer for WebSocket and WebRTC communication."""
|
|
|
|
import asyncio
|
|
import json
|
|
from abc import ABC, abstractmethod
|
|
from typing import Optional
|
|
from fastapi import WebSocket
|
|
from loguru import logger
|
|
|
|
# aiortc is an optional dependency: WebRTC support is switched on only when
# the import succeeds. RTCPeerConnection is always bound (to None when
# aiortc is missing) so later references to the name never fail.
try:
    from aiortc import RTCPeerConnection
except ImportError:
    AIORTC_AVAILABLE = False
    RTCPeerConnection = None  # Type hint placeholder
else:
    AIORTC_AVAILABLE = True
|
|
|
|
|
|
class BaseTransport(ABC):
    """
    Common interface shared by every transport implementation.

    Concrete subclasses must provide ``send_event``, ``send_audio`` and
    ``close``; the class cannot be instantiated until all three exist.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Deliver a JSON event to the client.

        Args:
            event: Event payload as a dictionary
        """

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Deliver a chunk of audio to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """

    @abstractmethod
    async def close(self) -> None:
        """Shut down the transport and release its resources."""
|
|
|
|
|
|
class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Handles mixed text/binary frames over a single WebSocket connection.
    All sends are serialized through an ``asyncio.Lock`` to prevent frame
    interleaving between concurrent callers.

    Once a send on the socket fails, the transport marks itself closed and
    every later send is dropped with a warning instead of being attempted.
    Send methods log errors rather than raising them.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        self._closed = False
        # Set on the first close() call so the underlying socket is closed at
        # most once, even if close() is invoked repeatedly.
        self._close_requested = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        If ``event`` cannot be serialized to JSON, the error is logged and the
        event is dropped. The transport stays open, because the failure lies
        in the payload, not in the connection.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        # Serialize outside the lock and outside the socket-error handler so
        # that a bad payload neither holds the lock nor marks the transport closed.
        try:
            payload = json.dumps(event)
        except (TypeError, ValueError) as e:
            logger.error(f"Error serializing event: {e}")
            return

        async with self.lock:
            # Another sender may have failed and closed the transport while
            # this one was waiting for the lock.
            if self._closed:
                logger.warning("Attempted to send event on closed transport")
                return
            try:
                await self.ws.send_text(payload)
            except Exception as e:
                logger.error(f"Error sending event: {e}")
                self._closed = True
                return

        event_name = event.get("event", "unknown") if isinstance(event, dict) else "unknown"
        logger.debug(f"Sent event: {event_name}")

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket as a binary frame.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        async with self.lock:
            # Re-check: the transport may have been closed while waiting.
            if self._closed:
                logger.warning("Attempted to send audio on closed transport")
                return
            try:
                await self.ws.send_bytes(pcm_bytes)
            except Exception as e:
                logger.error(f"Error sending audio: {e}")
                self._closed = True

    async def close(self) -> None:
        """
        Close the WebSocket connection.

        Idempotent: the underlying socket is closed at most once, and later
        calls only keep the transport marked closed. Errors raised while
        closing the socket are logged, not propagated.
        """
        self._closed = True
        if self._close_requested:
            return
        self._close_requested = True
        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed
|
|
|
|
|
|
class WebRtcTransport(BaseTransport):
    """
    WebRTC transport for WebRTC audio streaming.

    Uses a WebSocket for signaling/JSON events and an RTCPeerConnection for
    media. Outbound audio is not implemented yet (see ``send_audio``).
    Send methods log errors rather than raising them.
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport

        Raises:
            RuntimeError: if aiortc could not be imported.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")

        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False
        # Set on the first close() call so the resources are torn down once.
        self._close_requested = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        If ``event`` cannot be serialized to JSON, the error is logged and the
        event is dropped without closing the transport. A failure on the socket
        itself marks the transport closed.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        # A serialization error is a payload problem, not a connection failure.
        try:
            payload = json.dumps(event)
        except (TypeError, ValueError) as e:
            logger.error(f"Error serializing event: {e}")
            return

        try:
            await self.ws.send_text(payload)
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True
            return

        event_name = event.get("event", "unknown") if isinstance(event, dict) else "unknown"
        logger.debug(f"Sent event: {event_name}")

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track (not yet implemented).

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading. No such
        track is wired up yet, so the bytes are currently only logged and
        then discarded; they are not forwarded to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # Placeholder: requires a custom MediaStreamTrack implementation.
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")

        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """
        Close the peer connection and the signaling WebSocket.

        Each resource is closed independently, so a failure while closing the
        peer connection does not leave the WebSocket open. Idempotent: the
        resources are closed at most once. Errors are logged, not propagated.
        """
        self._closed = True
        if self._close_requested:
            return
        self._close_requested = True

        try:
            await self.pc.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")
|