Files
AI-VideoAssistant/engine/core/transports.py
2026-02-06 14:01:34 +08:00

208 lines
5.9 KiB
Python

"""Transport layer for WebSocket and WebRTC communication."""
import asyncio
import json
from abc import ABC, abstractmethod
from typing import Optional
from fastapi import WebSocket
from loguru import logger
# aiortc is an optional dependency: it is only required for the WebRTC
# transport. Record whether it could be imported so that callers can check
# AIORTC_AVAILABLE instead of repeating the import attempt.
try:
    from aiortc import RTCPeerConnection
except ImportError:
    # Keep the name bound so references to it do not raise NameError.
    RTCPeerConnection = None  # Type hint placeholder
    AIORTC_AVAILABLE = False
else:
    AIORTC_AVAILABLE = True
class BaseTransport(ABC):
    """
    Interface shared by every client transport.

    Concrete subclasses have to override all three coroutines declared
    here (``send_event``, ``send_audio`` and ``close``); the class cannot
    be instantiated until they do.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Deliver a JSON event to the client.

        Args:
            event: Event data as dictionary
        """

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Deliver a chunk of audio to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """

    @abstractmethod
    async def close(self) -> None:
        """Shut the transport down and release its resources."""
class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Sends JSON events as text frames and PCM audio as binary frames over
    one WebSocket connection. Every write goes through ``self.lock`` so
    only one frame is written at a time.

    Sending is best-effort: failures are logged, never raised. A failed
    write marks the transport as closed, and later sends are dropped with
    a warning.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        self._closed = False
        # Separate from _closed: a failed send sets _closed, but close()
        # must still get one chance to close the underlying socket.
        self._close_attempted = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        An event that cannot be serialized is logged and dropped; it does
        not mark the transport as closed, because the socket is unaffected.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        # Serialize outside the lock and outside the send error handler so
        # that a bad payload is not mistaken for a broken connection.
        try:
            payload = json.dumps(event)
        except (TypeError, ValueError, RecursionError) as e:
            logger.error(f"Error serializing event: {e}")
            return

        async with self.lock:
            # Re-check: the transport may have been closed (or a previous
            # send may have failed) while this call waited for the lock.
            if self._closed:
                logger.warning("Attempted to send event on closed transport")
                return
            try:
                await self.ws.send_text(payload)
                logger.debug(f"Sent event: {event.get('event', 'unknown')}")
            except Exception as e:
                logger.error(f"Error sending event: {e}")
                self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        async with self.lock:
            # Re-check: state may have changed while waiting for the lock.
            if self._closed:
                logger.warning("Attempted to send audio on closed transport")
                return
            try:
                await self.ws.send_bytes(pcm_bytes)
            except Exception as e:
                logger.error(f"Error sending audio: {e}")
                self._closed = True

    async def close(self) -> None:
        """
        Close the WebSocket connection.

        Safe to call more than once: the underlying socket is closed at
        most one time, and errors while closing are logged, not raised.
        """
        self._closed = True
        if self._close_attempted:
            return
        self._close_attempted = True
        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed
class WebRtcTransport(BaseTransport):
    """
    WebRTC transport for WebRTC audio streaming.

    Uses the WebSocket for signaling (JSON events) and holds an
    RTCPeerConnection for media. Outbound audio is not implemented yet:
    ``send_audio`` only logs the payload size.

    Sending is best-effort: failures are logged, never raised.
    """

    def __init__(self, websocket: WebSocket, pc: "RTCPeerConnection"):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport

        Raises:
            RuntimeError: If aiortc could not be imported.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")
        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False
        # Separate from _closed: a failed send sets _closed, but close()
        # must still get one chance to release the connections.
        self._close_attempted = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        An event that cannot be serialized is logged and dropped; it does
        not mark the transport as closed, because the socket is unaffected.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        # Serialize outside the send error handler so that a bad payload
        # is not mistaken for a broken connection.
        try:
            payload = json.dumps(event)
        except (TypeError, ValueError, RecursionError) as e:
            logger.error(f"Error serializing event: {e}")
            return

        try:
            await self.ws.send_text(payload)
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track.

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading.
        This method is currently a placeholder: it only logs the number
        of bytes received and transmits nothing.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # This would require a custom MediaStreamTrack implementation.
        # For now, we only log this as a placeholder.
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")
        # TODO: Implement outbound audio by pushing frames to
        # self.outbound_track once a track implementation exists.

    async def close(self) -> None:
        """
        Close the WebRTC connection.

        The peer connection and the signaling WebSocket are closed
        independently, so a failure closing one does not leave the other
        open. Safe to call more than once; errors are logged, not raised.
        """
        self._closed = True
        if self._close_attempted:
            return
        self._close_attempted = True

        try:
            await self.pc.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

        # Always attempted, even if closing the peer connection failed.
        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track) -> None:
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")