Integrate eou and vad
This commit is contained in:
1
core/__init__.py
Normal file
1
core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core Components Package"""
|
||||
134
core/events.py
Normal file
134
core/events.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Event bus for pub/sub communication between components."""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Dict, List, Any, Optional
|
||||
from collections import defaultdict
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""
|
||||
Async event bus for pub/sub communication.
|
||||
|
||||
Similar to the original Rust implementation's broadcast channel.
|
||||
Components can subscribe to specific event types and receive events asynchronously.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the event bus."""
|
||||
self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
|
||||
self._lock = asyncio.Lock()
|
||||
self._running = True
|
||||
|
||||
def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
|
||||
"""
|
||||
Subscribe to an event type.
|
||||
|
||||
Args:
|
||||
event_type: Type of event to subscribe to (e.g., "speaking", "silence")
|
||||
callback: Async callback function that receives event data
|
||||
"""
|
||||
if not self._running:
|
||||
logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
|
||||
return
|
||||
|
||||
self._subscribers[event_type].append(callback)
|
||||
logger.debug(f"Subscribed to event type: {event_type}")
|
||||
|
||||
def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
|
||||
"""
|
||||
Unsubscribe from an event type.
|
||||
|
||||
Args:
|
||||
event_type: Type of event to unsubscribe from
|
||||
callback: Callback function to remove
|
||||
"""
|
||||
if callback in self._subscribers[event_type]:
|
||||
self._subscribers[event_type].remove(callback)
|
||||
logger.debug(f"Unsubscribed from event type: {event_type}")
|
||||
|
||||
async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Publish an event to all subscribers.
|
||||
|
||||
Args:
|
||||
event_type: Type of event to publish
|
||||
event_data: Event data to send to subscribers
|
||||
"""
|
||||
if not self._running:
|
||||
logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
|
||||
return
|
||||
|
||||
# Get subscribers for this event type
|
||||
subscribers = self._subscribers.get(event_type, [])
|
||||
|
||||
if not subscribers:
|
||||
logger.debug(f"No subscribers for event type: {event_type}")
|
||||
return
|
||||
|
||||
# Notify all subscribers concurrently
|
||||
tasks = []
|
||||
for callback in subscribers:
|
||||
try:
|
||||
# Create task for each subscriber
|
||||
task = asyncio.create_task(self._call_subscriber(callback, event_data))
|
||||
tasks.append(task)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating task for subscriber: {e}")
|
||||
|
||||
# Wait for all subscribers to complete
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")
|
||||
|
||||
async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Call a subscriber callback with error handling.
|
||||
|
||||
Args:
|
||||
callback: Subscriber callback function
|
||||
event_data: Event data to pass to callback
|
||||
"""
|
||||
try:
|
||||
# Check if callback is a coroutine function
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(event_data)
|
||||
else:
|
||||
callback(event_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in subscriber callback: {e}", exc_info=True)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the event bus and stop processing events."""
|
||||
self._running = False
|
||||
self._subscribers.clear()
|
||||
logger.info("Event bus closed")
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the event bus is running."""
|
||||
return self._running
|
||||
|
||||
|
||||
# Global event bus instance
|
||||
_event_bus: Optional[EventBus] = None
|
||||
|
||||
|
||||
def get_event_bus() -> EventBus:
|
||||
"""
|
||||
Get the global event bus instance.
|
||||
|
||||
Returns:
|
||||
EventBus instance
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is None:
|
||||
_event_bus = EventBus()
|
||||
return _event_bus
|
||||
|
||||
|
||||
def reset_event_bus() -> None:
|
||||
"""Reset the global event bus (mainly for testing)."""
|
||||
global _event_bus
|
||||
_event_bus = None
|
||||
131
core/pipeline.py
Normal file
131
core/pipeline.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Audio processing pipeline."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
from core.transports import BaseTransport
|
||||
from core.events import EventBus, get_event_bus
|
||||
from processors.vad import VADProcessor, SileroVAD
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class AudioPipeline:
|
||||
"""
|
||||
Audio processing pipeline.
|
||||
|
||||
Processes incoming audio through VAD and emits events.
|
||||
"""
|
||||
|
||||
def __init__(self, transport: BaseTransport, session_id: str):
|
||||
"""
|
||||
Initialize audio pipeline.
|
||||
|
||||
Args:
|
||||
transport: Transport instance for sending events/audio
|
||||
session_id: Session identifier for event tracking
|
||||
"""
|
||||
self.transport = transport
|
||||
self.session_id = session_id
|
||||
self.event_bus = get_event_bus()
|
||||
|
||||
# Initialize VAD
|
||||
self.vad_model = SileroVAD(
|
||||
model_path=settings.vad_model_path,
|
||||
sample_rate=settings.sample_rate
|
||||
)
|
||||
self.vad_processor = VADProcessor(
|
||||
vad_model=self.vad_model,
|
||||
threshold=settings.vad_threshold,
|
||||
silence_threshold_ms=settings.vad_eou_threshold_ms,
|
||||
min_speech_duration_ms=settings.vad_min_speech_duration_ms
|
||||
)
|
||||
|
||||
# State
|
||||
self.is_bot_speaking = False
|
||||
self.interrupt_signal = asyncio.Event()
|
||||
self._running = True
|
||||
|
||||
logger.info(f"Audio pipeline initialized for session {session_id}")
|
||||
|
||||
async def process_input(self, pcm_bytes: bytes) -> None:
|
||||
"""
|
||||
Process incoming audio chunk.
|
||||
|
||||
Args:
|
||||
pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Process through VAD
|
||||
result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)
|
||||
|
||||
if result:
|
||||
event_type, probability = result
|
||||
|
||||
# Emit event through event bus
|
||||
await self.event_bus.publish(event_type, {
|
||||
"trackId": self.session_id,
|
||||
"probability": probability
|
||||
})
|
||||
|
||||
# Send event to client
|
||||
if event_type == "speaking":
|
||||
logger.info(f"User speaking started (session {self.session_id})")
|
||||
await self.transport.send_event({
|
||||
"event": "speaking",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"startTime": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
elif event_type == "silence":
|
||||
logger.info(f"User speaking stopped (session {self.session_id})")
|
||||
await self.transport.send_event({
|
||||
"event": "silence",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"startTime": self._get_timestamp_ms(),
|
||||
"duration": 0 # TODO: Calculate actual duration
|
||||
})
|
||||
|
||||
elif event_type == "eou":
|
||||
logger.info(f"EOU detected (session {self.session_id})")
|
||||
await self.transport.send_event({
|
||||
"event": "eou",
|
||||
"trackId": self.session_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline processing error: {e}", exc_info=True)
|
||||
|
||||
async def process_text_input(self, text: str) -> None:
|
||||
"""
|
||||
Process text input (chat command).
|
||||
|
||||
Args:
|
||||
text: Text input
|
||||
"""
|
||||
logger.info(f"Processing text input: {text[:50]}...")
|
||||
# TODO: Implement text processing (LLM integration, etc.)
|
||||
# For now, just log it
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Interrupt current audio playback."""
|
||||
if self.is_bot_speaking:
|
||||
self.interrupt_signal.set()
|
||||
logger.info(f"Pipeline interrupted for session {self.session_id}")
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup pipeline resources."""
|
||||
logger.info(f"Cleaning up pipeline for session {self.session_id}")
|
||||
self._running = False
|
||||
self.interrupt_signal.set()
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
266
core/session.py
Normal file
266
core/session.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Session management for active calls."""
|
||||
|
||||
import uuid
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from core.transports import BaseTransport
|
||||
from core.pipeline import AudioPipeline
|
||||
from models.commands import parse_command, TTSCommand, ChatCommand, InterruptCommand, HangupCommand
|
||||
|
||||
|
||||
class Session:
|
||||
"""
|
||||
Manages a single call session.
|
||||
|
||||
Handles command routing, audio processing, and session lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, transport: BaseTransport):
|
||||
"""
|
||||
Initialize session.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier
|
||||
transport: Transport instance for communication
|
||||
"""
|
||||
self.id = session_id
|
||||
self.transport = transport
|
||||
self.pipeline = AudioPipeline(transport, session_id)
|
||||
|
||||
# Session state
|
||||
self.created_at = None
|
||||
self.state = "created" # created, invited, accepted, ringing, hungup
|
||||
|
||||
# Track IDs
|
||||
self.current_track_id: Optional[str] = str(uuid.uuid4())
|
||||
|
||||
logger.info(f"Session {self.id} created")
|
||||
|
||||
async def handle_text(self, text_data: str) -> None:
|
||||
"""
|
||||
Handle incoming text data (JSON commands).
|
||||
|
||||
Args:
|
||||
text_data: JSON text data
|
||||
"""
|
||||
try:
|
||||
data = json.loads(text_data)
|
||||
command = parse_command(data)
|
||||
command_type = command.command
|
||||
|
||||
logger.info(f"Session {self.id} received command: {command_type}")
|
||||
|
||||
# Route command to appropriate handler
|
||||
if command_type == "invite":
|
||||
await self._handle_invite(data)
|
||||
|
||||
elif command_type == "accept":
|
||||
await self._handle_accept(data)
|
||||
|
||||
elif command_type == "reject":
|
||||
await self._handle_reject(data)
|
||||
|
||||
elif command_type == "ringing":
|
||||
await self._handle_ringing(data)
|
||||
|
||||
elif command_type == "tts":
|
||||
await self._handle_tts(command)
|
||||
|
||||
elif command_type == "play":
|
||||
await self._handle_play(data)
|
||||
|
||||
elif command_type == "interrupt":
|
||||
await self._handle_interrupt(command)
|
||||
|
||||
elif command_type == "pause":
|
||||
await self._handle_pause()
|
||||
|
||||
elif command_type == "resume":
|
||||
await self._handle_resume()
|
||||
|
||||
elif command_type == "hangup":
|
||||
await self._handle_hangup(command)
|
||||
|
||||
elif command_type == "history":
|
||||
await self._handle_history(data)
|
||||
|
||||
elif command_type == "chat":
|
||||
await self._handle_chat(command)
|
||||
|
||||
else:
|
||||
logger.warning(f"Session {self.id} unknown command: {command_type}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Session {self.id} JSON decode error: {e}")
|
||||
await self._send_error("client", f"Invalid JSON: {e}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Session {self.id} command parse error: {e}")
|
||||
await self._send_error("client", f"Invalid command: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
|
||||
await self._send_error("server", f"Internal error: {e}")
|
||||
|
||||
async def handle_audio(self, audio_bytes: bytes) -> None:
|
||||
"""
|
||||
Handle incoming audio data.
|
||||
|
||||
Args:
|
||||
audio_bytes: PCM audio data
|
||||
"""
|
||||
try:
|
||||
await self.pipeline.process_input(audio_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||
|
||||
async def _handle_invite(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle invite command."""
|
||||
self.state = "invited"
|
||||
option = data.get("option", {})
|
||||
|
||||
# Send answer event
|
||||
await self.transport.send_event({
|
||||
"event": "answer",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms()
|
||||
})
|
||||
|
||||
logger.info(f"Session {self.id} invited with codec: {option.get('codec', 'pcm')}")
|
||||
|
||||
async def _handle_accept(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle accept command."""
|
||||
self.state = "accepted"
|
||||
logger.info(f"Session {self.id} accepted")
|
||||
|
||||
async def _handle_reject(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle reject command."""
|
||||
self.state = "rejected"
|
||||
reason = data.get("reason", "Rejected")
|
||||
logger.info(f"Session {self.id} rejected: {reason}")
|
||||
|
||||
async def _handle_ringing(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle ringing command."""
|
||||
self.state = "ringing"
|
||||
logger.info(f"Session {self.id} ringing")
|
||||
|
||||
async def _handle_tts(self, command: TTSCommand) -> None:
|
||||
"""Handle TTS command."""
|
||||
logger.info(f"Session {self.id} TTS: {command.text[:50]}...")
|
||||
|
||||
# Send track start event
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"playId": command.play_id
|
||||
})
|
||||
|
||||
# TODO: Implement actual TTS synthesis
|
||||
# For now, just send track end event
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"duration": 1000,
|
||||
"ssrc": 0,
|
||||
"playId": command.play_id
|
||||
})
|
||||
|
||||
async def _handle_play(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle play command."""
|
||||
url = data.get("url", "")
|
||||
logger.info(f"Session {self.id} play: {url}")
|
||||
|
||||
# Send track start event
|
||||
await self.transport.send_event({
|
||||
"event": "trackStart",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"playId": url
|
||||
})
|
||||
|
||||
# TODO: Implement actual audio playback
|
||||
# For now, just send track end event
|
||||
await self.transport.send_event({
|
||||
"event": "trackEnd",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"duration": 1000,
|
||||
"ssrc": 0,
|
||||
"playId": url
|
||||
})
|
||||
|
||||
async def _handle_interrupt(self, command: InterruptCommand) -> None:
|
||||
"""Handle interrupt command."""
|
||||
if command.graceful:
|
||||
logger.info(f"Session {self.id} graceful interrupt")
|
||||
else:
|
||||
logger.info(f"Session {self.id} immediate interrupt")
|
||||
await self.pipeline.interrupt()
|
||||
|
||||
async def _handle_pause(self) -> None:
|
||||
"""Handle pause command."""
|
||||
logger.info(f"Session {self.id} paused")
|
||||
|
||||
async def _handle_resume(self) -> None:
|
||||
"""Handle resume command."""
|
||||
logger.info(f"Session {self.id} resumed")
|
||||
|
||||
async def _handle_hangup(self, command: HangupCommand) -> None:
|
||||
"""Handle hangup command."""
|
||||
self.state = "hungup"
|
||||
reason = command.reason or "User requested"
|
||||
logger.info(f"Session {self.id} hung up: {reason}")
|
||||
|
||||
# Send hangup event
|
||||
await self.transport.send_event({
|
||||
"event": "hangup",
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"reason": reason,
|
||||
"initiator": command.initiator or "user"
|
||||
})
|
||||
|
||||
# Close transport
|
||||
await self.transport.close()
|
||||
|
||||
async def _handle_history(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle history command."""
|
||||
speaker = data.get("speaker", "unknown")
|
||||
text = data.get("text", "")
|
||||
logger.info(f"Session {self.id} history [{speaker}]: {text[:50]}...")
|
||||
|
||||
async def _handle_chat(self, command: ChatCommand) -> None:
|
||||
"""Handle chat command."""
|
||||
logger.info(f"Session {self.id} chat: {command.text[:50]}...")
|
||||
# Process text input through pipeline
|
||||
await self.pipeline.process_text_input(command.text)
|
||||
|
||||
async def _send_error(self, sender: str, error_message: str) -> None:
|
||||
"""
|
||||
Send error event to client.
|
||||
|
||||
Args:
|
||||
sender: Component that generated the error
|
||||
error_message: Error message
|
||||
"""
|
||||
await self.transport.send_event({
|
||||
"event": "error",
|
||||
"trackId": self.current_track_id,
|
||||
"timestamp": self._get_timestamp_ms(),
|
||||
"sender": sender,
|
||||
"error": error_message
|
||||
})
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup session resources."""
|
||||
logger.info(f"Session {self.id} cleaning up")
|
||||
await self.pipeline.cleanup()
|
||||
await self.transport.close()
|
||||
207
core/transports.py
Normal file
207
core/transports.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Transport layer for WebSocket and WebRTC communication."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from fastapi import WebSocket
|
||||
from loguru import logger
|
||||
|
||||
# Try to import aiortc (optional for WebRTC functionality)
|
||||
try:
|
||||
from aiortc import RTCPeerConnection
|
||||
AIORTC_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIORTC_AVAILABLE = False
|
||||
RTCPeerConnection = None # Type hint placeholder
|
||||
|
||||
|
||||
class BaseTransport(ABC):
|
||||
"""
|
||||
Abstract base class for transports.
|
||||
|
||||
All transports must implement send_event and send_audio methods.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def send_event(self, event: dict) -> None:
|
||||
"""
|
||||
Send a JSON event to the client.
|
||||
|
||||
Args:
|
||||
event: Event data as dictionary
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_audio(self, pcm_bytes: bytes) -> None:
|
||||
"""
|
||||
Send audio data to the client.
|
||||
|
||||
Args:
|
||||
pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the transport and cleanup resources."""
|
||||
pass
|
||||
|
||||
|
||||
class SocketTransport(BaseTransport):
|
||||
"""
|
||||
WebSocket transport for raw audio streaming.
|
||||
|
||||
Handles mixed text/binary frames over WebSocket connection.
|
||||
Uses asyncio.Lock to prevent frame interleaving.
|
||||
"""
|
||||
|
||||
def __init__(self, websocket: WebSocket):
|
||||
"""
|
||||
Initialize WebSocket transport.
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket connection
|
||||
"""
|
||||
self.ws = websocket
|
||||
self.lock = asyncio.Lock() # Prevent frame interleaving
|
||||
self._closed = False
|
||||
|
||||
async def send_event(self, event: dict) -> None:
|
||||
"""
|
||||
Send a JSON event via WebSocket.
|
||||
|
||||
Args:
|
||||
event: Event data as dictionary
|
||||
"""
|
||||
if self._closed:
|
||||
logger.warning("Attempted to send event on closed transport")
|
||||
return
|
||||
|
||||
async with self.lock:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps(event))
|
||||
logger.debug(f"Sent event: {event.get('event', 'unknown')}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending event: {e}")
|
||||
self._closed = True
|
||||
|
||||
async def send_audio(self, pcm_bytes: bytes) -> None:
|
||||
"""
|
||||
Send PCM audio data via WebSocket.
|
||||
|
||||
Args:
|
||||
pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
|
||||
"""
|
||||
if self._closed:
|
||||
logger.warning("Attempted to send audio on closed transport")
|
||||
return
|
||||
|
||||
async with self.lock:
|
||||
try:
|
||||
await self.ws.send_bytes(pcm_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio: {e}")
|
||||
self._closed = True
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket connection."""
|
||||
self._closed = True
|
||||
try:
|
||||
await self.ws.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing WebSocket: {e}")
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if the transport is closed."""
|
||||
return self._closed
|
||||
|
||||
|
||||
class WebRtcTransport(BaseTransport):
|
||||
"""
|
||||
WebRTC transport for WebRTC audio streaming.
|
||||
|
||||
Uses WebSocket for signaling and RTCPeerConnection for media.
|
||||
"""
|
||||
|
||||
def __init__(self, websocket: WebSocket, pc):
|
||||
"""
|
||||
Initialize WebRTC transport.
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket connection for signaling
|
||||
pc: RTCPeerConnection for media transport
|
||||
"""
|
||||
if not AIORTC_AVAILABLE:
|
||||
raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")
|
||||
|
||||
self.ws = websocket
|
||||
self.pc = pc
|
||||
self.outbound_track = None # MediaStreamTrack for outbound audio
|
||||
self._closed = False
|
||||
|
||||
async def send_event(self, event: dict) -> None:
|
||||
"""
|
||||
Send a JSON event via WebSocket signaling.
|
||||
|
||||
Args:
|
||||
event: Event data as dictionary
|
||||
"""
|
||||
if self._closed:
|
||||
logger.warning("Attempted to send event on closed transport")
|
||||
return
|
||||
|
||||
try:
|
||||
await self.ws.send_text(json.dumps(event))
|
||||
logger.debug(f"Sent event: {event.get('event', 'unknown')}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending event: {e}")
|
||||
self._closed = True
|
||||
|
||||
async def send_audio(self, pcm_bytes: bytes) -> None:
|
||||
"""
|
||||
Send audio data via WebRTC track.
|
||||
|
||||
Note: In WebRTC, you don't send bytes directly. You push frames
|
||||
to a MediaStreamTrack that the peer connection is reading.
|
||||
|
||||
Args:
|
||||
pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
|
||||
"""
|
||||
if self._closed:
|
||||
logger.warning("Attempted to send audio on closed transport")
|
||||
return
|
||||
|
||||
# This would require a custom MediaStreamTrack implementation
|
||||
# For now, we'll log this as a placeholder
|
||||
logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")
|
||||
|
||||
# TODO: Implement outbound audio track if needed
|
||||
# if self.outbound_track:
|
||||
# await self.outbound_track.add_frame(pcm_bytes)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebRTC connection."""
|
||||
self._closed = True
|
||||
try:
|
||||
await self.pc.close()
|
||||
await self.ws.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing WebRTC transport: {e}")
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if the transport is closed."""
|
||||
return self._closed
|
||||
|
||||
def set_outbound_track(self, track):
|
||||
"""
|
||||
Set the outbound audio track for sending audio to client.
|
||||
|
||||
Args:
|
||||
track: MediaStreamTrack for outbound audio
|
||||
"""
|
||||
self.outbound_track = track
|
||||
logger.debug("Set outbound track for WebRTC transport")
|
||||
Reference in New Issue
Block a user