Files
AI-VideoAssistant/engine/runtime/transports.py
Xin Wang 7e0b777923 Refactor project structure and enhance backend integration
- Expanded package inclusion in `pyproject.toml` to support new modules.
- Introduced new `adapters` and `protocol` packages for better organization.
- Added backend adapter implementations for control plane integration.
- Updated main application imports to reflect new package structure.
- Removed deprecated core components and adjusted documentation accordingly.
- Enhanced architecture documentation to clarify the new runtime and integration layers.
2026-03-06 09:51:56 +08:00

248 lines
7.6 KiB
Python

"""Transport layer for WebSocket and WebRTC communication."""
import asyncio
import json
from abc import ABC, abstractmethod
from typing import Optional
from fastapi import WebSocket
from starlette.websockets import WebSocketState
from loguru import logger
# Try to import aiortc (optional for WebRTC functionality)
try:
from aiortc import RTCPeerConnection
AIORTC_AVAILABLE = True
except ImportError:
AIORTC_AVAILABLE = False
RTCPeerConnection = None # Type hint placeholder
class BaseTransport(ABC):
    """Common interface shared by every client transport.

    Concrete transports must provide three coroutines: ``send_event`` for
    JSON control messages, ``send_audio`` for raw PCM chunks, and ``close``
    for teardown.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """Deliver one event to the client.

        Args:
            event: Event payload as a dictionary.
        """

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """Deliver one chunk of audio to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz).
        """

    @abstractmethod
    async def close(self) -> None:
        """Shut the transport down and release whatever it holds."""
class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Handles mixed text/binary frames over one WebSocket connection: events
    are sent as JSON text frames and PCM audio as binary frames. Sends are
    serialized with an asyncio.Lock to prevent frame interleaving.

    Any send failure marks the transport closed; later sends are dropped
    with a warning instead of raising.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        # Set once the transport must no longer send (close() was called or
        # a send failed). Exposed through is_closed.
        self._closed = False
        # Set once close() has run. Kept separate from _closed so that a
        # failed send does not make close() skip closing the socket itself.
        self._close_started = False

    def _ws_disconnected(self) -> bool:
        """Best-effort check for websocket disconnection state."""
        return (
            self.ws.client_state == WebSocketState.DISCONNECTED
            or self.ws.application_state == WebSocketState.DISCONNECTED
        )

    def _can_send(self, kind: str) -> bool:
        """
        Check whether a send should be attempted at all.

        Logs a warning and marks the transport closed when it is already
        closed or the websocket reports a disconnected state.

        Args:
            kind: Label used in the log message ("event" or "audio").

        Returns:
            True if the caller may proceed with the send.
        """
        if self._closed or self._ws_disconnected():
            logger.warning(f"Attempted to send {kind} on closed transport")
            self._closed = True
            return False
        return True

    async def _locked_send(self, kind: str, send) -> bool:
        """
        Perform one websocket send under the lock with shared error handling.

        Args:
            kind: Label used in log messages ("event" or "audio").
            send: Zero-argument callable returning the send coroutine. A
                factory is taken so the coroutine is only created once the
                lock is held.

        Returns:
            True if the send completed; False if it failed, in which case
            the transport has been marked closed.
        """
        async with self.lock:
            try:
                await send()
                return True
            except RuntimeError as e:
                self._closed = True
                # A RuntimeError about an already-sent close message means
                # another path closed the socket; that is expected, not an error.
                if self._ws_disconnected() or "close message has been sent" in str(e):
                    logger.debug(f"Skip sending {kind} on closed websocket: {e!r}")
                    return False
                logger.error(f"Error sending {kind}: {e!r}")
            except Exception as e:
                self._closed = True
                if self._ws_disconnected():
                    logger.debug(f"Skip sending {kind} on disconnected websocket: {e!r}")
                    return False
                logger.error(f"Error sending {kind}: {e!r}")
        return False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        An event that cannot be JSON-serialized is logged and dropped. It
        does not mark the transport closed, because the socket itself is
        still usable.

        Args:
            event: Event data as dictionary
        """
        if not self._can_send("event"):
            return
        try:
            payload = json.dumps(event)
        except (TypeError, ValueError) as e:
            # Bad payload is a caller bug, not a transport failure.
            logger.error(f"Error serializing event: {e!r}")
            return
        if await self._locked_send("event", lambda: self.ws.send_text(payload)):
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if not self._can_send("audio"):
            return
        await self._locked_send("audio", lambda: self.ws.send_bytes(pcm_bytes))

    async def close(self) -> None:
        """
        Close the WebSocket connection.

        Only the first call does anything. That call marks the transport
        closed and, unless the websocket already reports itself disconnected,
        calls ws.close(). It does so even if an earlier send failure had
        already set is_closed.
        """
        if self._close_started:
            return
        self._close_started = True
        self._closed = True
        if self._ws_disconnected():
            return
        try:
            await self.ws.close()
        except RuntimeError as e:
            # Already closed by another task/path; safe to ignore.
            if "close message has been sent" in str(e):
                logger.debug(f"WebSocket already closed: {e}")
                return
            logger.error(f"Error closing WebSocket: {e}")
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed
class WebRtcTransport(BaseTransport):
    """
    WebRTC transport for WebRTC audio streaming.

    Uses the WebSocket for signaling (JSON events) and an RTCPeerConnection
    for media. Outbound audio is not wired up yet: send_audio only logs.
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport

        Raises:
            RuntimeError: If aiortc could not be imported.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")
        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        # Set once sending must stop (close() was called or a send failed).
        self._closed = False
        # Set once close() has run; kept separate from _closed so that a
        # failed send does not prevent close() from tearing down pc and ws.
        self._close_started = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        An event that cannot be JSON-serialized is logged and dropped without
        marking the transport closed. Any error raised by the socket send
        marks the transport closed.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return
        try:
            payload = json.dumps(event)
        except (TypeError, ValueError) as e:
            # Bad payload is a caller bug, not a transport failure.
            logger.error(f"Error serializing event: {e!r}")
            return
        try:
            await self.ws.send_text(payload)
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True
            return
        logger.debug(f"Sent event: {event.get('event', 'unknown')}")

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track.

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return
        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")
        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """
        Close the peer connection and the signaling WebSocket.

        Only the first call does anything. The two resources are closed
        independently, so a failure while closing the peer connection does
        not leave the WebSocket open.
        """
        if self._close_started:
            return
        self._close_started = True
        self._closed = True
        try:
            await self.pc.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")
        try:
            await self.ws.close()
        except RuntimeError as e:
            # Already closed by another task/path; safe to ignore.
            if "close message has been sent" in str(e):
                logger.debug(f"WebSocket already closed: {e}")
                return
            logger.error(f"Error closing WebRTC transport: {e}")
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")