Refactor project structure and enhance backend integration
- Expanded package inclusion in `pyproject.toml` to support new modules. - Introduced new `adapters` and `protocol` packages for better organization. - Added backend adapter implementations for control plane integration. - Updated main application imports to reflect new package structure. - Removed deprecated core components and adjusted documentation accordingly. - Enhanced architecture documentation to clarify the new runtime and integration layers.
This commit is contained in:
1
engine/runtime/__init__.py
Normal file
1
engine/runtime/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Runtime package."""
|
||||
279
engine/runtime/conversation.py
Normal file
279
engine/runtime/conversation.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Conversation management for voice AI.
|
||||
|
||||
Handles conversation context, turn-taking, and message history
|
||||
for multi-turn voice conversations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
|
||||
from providers.common.base import LLMMessage
|
||||
|
||||
|
||||
class ConversationState(Enum):
    """State of the conversation.

    Transitions are driven by ``ConversationManager.set_state`` and
    broadcast to registered state-change callbacks.
    """

    IDLE = "idle"  # Waiting for user input
    LISTENING = "listening"  # User is speaking
    PROCESSING = "processing"  # Processing user input (LLM)
    SPEAKING = "speaking"  # Bot is speaking
    INTERRUPTED = "interrupted"  # Bot was interrupted
|
||||
|
||||
|
||||
@dataclass
class ConversationTurn:
    """A single turn in the conversation."""

    role: str  # "user" or "assistant"
    text: str  # Final transcript / response text for this turn
    # Duration of the spoken audio for this turn, if known.
    audio_duration_ms: Optional[int] = None
    # NOTE(review): default_factory calls asyncio.get_event_loop() at
    # construction time; this is deprecated outside a running loop on
    # Python 3.10+ — confirm turns are only created inside the event loop.
    timestamp: float = field(default_factory=lambda: asyncio.get_event_loop().time())
    # True when the assistant turn was cut short by user barge-in.
    was_interrupted: bool = False
|
||||
|
||||
|
||||
class ConversationManager:
    """
    Manages conversation state and history.

    Provides:
    - Message history for LLM context
    - Turn management
    - State tracking
    - Event callbacks for state changes
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_history: int = 20,
        greeting: Optional[str] = None
    ):
        """
        Initialize conversation manager.

        Args:
            system_prompt: System prompt for LLM
            max_history: Maximum number of turns to keep
            greeting: Optional greeting message when conversation starts
        """
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation. "
            "If you don't understand something, ask for clarification."
        )
        self.max_history = max_history
        self.greeting = greeting

        # State
        self.state = ConversationState.IDLE
        self.turns: List[ConversationTurn] = []

        # Callbacks
        self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
        self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []

        # Current turn tracking (accumulated partial text before a turn closes)
        self._current_user_text: str = ""
        self._current_assistant_text: str = ""

        logger.info("ConversationManager initialized")

    def on_state_change(
        self,
        callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
    ) -> None:
        """Register callback for state changes."""
        self._state_callbacks.append(callback)

    def on_turn_complete(
        self,
        callback: Callable[[ConversationTurn], Awaitable[None]]
    ) -> None:
        """Register callback for turn completion."""
        self._turn_callbacks.append(callback)

    async def _notify_turn_callbacks(self, turn: ConversationTurn) -> None:
        """Invoke all registered turn callbacks, isolating their failures.

        A failing callback is logged and skipped so one bad listener cannot
        break turn bookkeeping for the rest.
        """
        for callback in self._turn_callbacks:
            try:
                await callback(turn)
            except Exception as e:
                logger.error(f"Turn callback error: {e}")

    async def set_state(self, new_state: ConversationState) -> None:
        """Set conversation state and notify listeners.

        No-op (and no callbacks fire) when the state is unchanged.
        """
        if new_state != self.state:
            old_state = self.state
            self.state = new_state
            logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")

            for callback in self._state_callbacks:
                try:
                    await callback(old_state, new_state)
                except Exception as e:
                    logger.error(f"State callback error: {e}")

    def get_messages(self) -> List[LLMMessage]:
        """
        Get conversation history as LLM messages.

        Returns:
            List of LLMMessage objects including system prompt
        """
        messages = [LLMMessage(role="system", content=self.system_prompt)]

        # Add conversation history (bounded by max_history most recent turns)
        for turn in self.turns[-self.max_history:]:
            messages.append(LLMMessage(role=turn.role, content=turn.text))

        # Add current (not yet committed) user text if any
        if self._current_user_text:
            messages.append(LLMMessage(role="user", content=self._current_user_text))

        return messages

    async def start_user_turn(self) -> None:
        """Signal that user has started speaking."""
        await self.set_state(ConversationState.LISTENING)
        self._current_user_text = ""

    async def update_user_text(self, text: str, is_final: bool = False) -> None:
        """
        Update current user text (from ASR).

        Args:
            text: Transcribed text
            is_final: Whether this is the final transcript
        """
        self._current_user_text = text

    async def end_user_turn(self, text: str) -> None:
        """
        End user turn and add to history.

        Empty/whitespace-only text is dropped without creating a turn.

        Args:
            text: Final user text
        """
        if text.strip():
            turn = ConversationTurn(role="user", text=text.strip())
            self.turns.append(turn)
            await self._notify_turn_callbacks(turn)
            logger.info(f"User: {text[:50]}...")

        self._current_user_text = ""
        await self.set_state(ConversationState.PROCESSING)

    async def start_assistant_turn(self) -> None:
        """Signal that assistant has started speaking."""
        await self.set_state(ConversationState.SPEAKING)
        self._current_assistant_text = ""

    async def update_assistant_text(self, text: str) -> None:
        """
        Update current assistant text (streaming).

        Args:
            text: Text chunk from LLM
        """
        self._current_assistant_text += text

    async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
        """
        End assistant turn and add to history.

        Args:
            was_interrupted: Whether the turn was interrupted by user
        """
        text = self._current_assistant_text.strip()
        if text:
            turn = ConversationTurn(
                role="assistant",
                text=text,
                was_interrupted=was_interrupted
            )
            self.turns.append(turn)
            await self._notify_turn_callbacks(turn)

            status = " (interrupted)" if was_interrupted else ""
            logger.info(f"Assistant{status}: {text[:50]}...")

        self._current_assistant_text = ""

        if was_interrupted:
            # A new user turn may already be active (LISTENING) when interrupted.
            # Avoid overriding it back to INTERRUPTED, which can stall EOU flow.
            if self.state != ConversationState.LISTENING:
                await self.set_state(ConversationState.INTERRUPTED)
        else:
            await self.set_state(ConversationState.IDLE)

    async def add_assistant_turn(self, text: str, was_interrupted: bool = False) -> None:
        """Append an assistant turn directly without mutating conversation state."""
        content = text.strip()
        if not content:
            return

        turn = ConversationTurn(
            role="assistant",
            text=content,
            was_interrupted=was_interrupted,
        )
        self.turns.append(turn)
        await self._notify_turn_callbacks(turn)
        logger.info(f"Assistant (injected): {content[:50]}...")

    async def interrupt(self) -> None:
        """Handle interruption (barge-in)."""
        if self.state == ConversationState.SPEAKING:
            await self.end_assistant_turn(was_interrupted=True)

    def reset(self) -> None:
        """Reset conversation history."""
        self.turns = []
        self._current_user_text = ""
        self._current_assistant_text = ""
        self.state = ConversationState.IDLE
        logger.info("Conversation reset")

    @property
    def turn_count(self) -> int:
        """Get number of turns in conversation."""
        return len(self.turns)

    @property
    def last_user_text(self) -> Optional[str]:
        """Get last user text."""
        for turn in reversed(self.turns):
            if turn.role == "user":
                return turn.text
        return None

    @property
    def last_assistant_text(self) -> Optional[str]:
        """Get last assistant text."""
        for turn in reversed(self.turns):
            if turn.role == "assistant":
                return turn.text
        return None

    def get_context_summary(self) -> Dict[str, Any]:
        """Get a summary of conversation context."""
        return {
            "state": self.state.value,
            "turn_count": self.turn_count,
            "last_user": self.last_user_text,
            "last_assistant": self.last_assistant_text,
            "current_user": self._current_user_text or None,
            "current_assistant": self._current_assistant_text or None
        }
|
||||
134
engine/runtime/events.py
Normal file
134
engine/runtime/events.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Event bus for pub/sub communication between components."""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Dict, List, Any, Optional
|
||||
from collections import defaultdict
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class EventBus:
    """
    Async event bus for pub/sub communication.

    Similar to the original Rust implementation's broadcast channel.
    Components can subscribe to specific event types and receive events asynchronously.
    """

    def __init__(self):
        """Initialize the event bus."""
        # event_type -> list of callbacks (sync or async callables)
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        # NOTE(review): lock is currently unused by any method — confirm it is
        # reserved for future coordinated mutation or remove it.
        self._lock = asyncio.Lock()
        self._running = True

    def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Subscribe to an event type.

        Args:
            event_type: Type of event to subscribe to (e.g., "speaking", "silence")
            callback: Async callback function that receives event data
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
            return

        self._subscribers[event_type].append(callback)
        logger.debug(f"Subscribed to event type: {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Unsubscribe from an event type.

        Args:
            event_type: Type of event to unsubscribe from
            callback: Callback function to remove
        """
        # Use .get() so looking up an unknown event type does not create a
        # spurious empty entry in the defaultdict.
        callbacks = self._subscribers.get(event_type)
        if callbacks and callback in callbacks:
            callbacks.remove(callback)
            logger.debug(f"Unsubscribed from event type: {event_type}")

    async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribers.

        Subscribers run concurrently; individual failures are isolated in
        _call_subscriber and do not abort delivery to others.

        Args:
            event_type: Type of event to publish
            event_data: Event data to send to subscribers
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
            return

        # Get subscribers for this event type (.get avoids defaultdict insertion)
        subscribers = self._subscribers.get(event_type, [])

        if not subscribers:
            logger.debug(f"No subscribers for event type: {event_type}")
            return

        # Notify all subscribers concurrently
        tasks = []
        for callback in subscribers:
            try:
                # Create task for each subscriber
                task = asyncio.create_task(self._call_subscriber(callback, event_data))
                tasks.append(task)
            except Exception as e:
                logger.error(f"Error creating task for subscriber: {e}")

        # Wait for all subscribers to complete
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")

    async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
        """
        Call a subscriber callback with error handling.

        Args:
            callback: Subscriber callback function
            event_data: Event data to pass to callback
        """
        try:
            # Check if callback is a coroutine function
            if asyncio.iscoroutinefunction(callback):
                await callback(event_data)
            else:
                callback(event_data)
        except Exception as e:
            logger.error(f"Error in subscriber callback: {e}", exc_info=True)

    async def close(self) -> None:
        """Close the event bus and stop processing events."""
        self._running = False
        self._subscribers.clear()
        logger.info("Event bus closed")

    @property
    def is_running(self) -> bool:
        """Check if the event bus is running."""
        return self._running
|
||||
|
||||
|
||||
# Lazily-created process-wide event bus singleton.
_event_bus: Optional[EventBus] = None


def get_event_bus() -> EventBus:
    """
    Get the global event bus instance.

    Returns:
        EventBus instance
    """
    global _event_bus
    bus = _event_bus
    if bus is None:
        bus = EventBus()
        _event_bus = bus
    return bus


def reset_event_bus() -> None:
    """Reset the global event bus (mainly for testing)."""
    global _event_bus
    _event_bus = None
|
||||
1
engine/runtime/history/__init__.py
Normal file
1
engine/runtime/history/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Runtime history package."""
|
||||
246
engine/runtime/history/bridge.py
Normal file
246
engine/runtime/history/bridge.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Async history bridge for non-blocking transcript persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from runtime.ports import ConversationHistoryStore
|
||||
|
||||
|
||||
@dataclass
class _HistoryTranscriptJob:
    """One queued transcript write destined for the history store."""

    call_id: str  # Control-plane call record ID the segment belongs to
    turn_index: int  # Monotonic ordinal of the turn within the call
    speaker: str  # "human" or "ai" (mapped from role in enqueue_turn)
    content: str  # Stripped transcript text
    start_ms: int  # Estimated utterance start, ms since call start
    end_ms: int  # Utterance end, ms since call start
    duration_ms: int  # end_ms - start_ms, clamped to >= 1
|
||||
|
||||
|
||||
class SessionHistoryBridge:
    """Session-scoped buffered history writer with background retries.

    Transcript turns are queued via ``enqueue_turn`` without blocking the
    audio path; a single background worker task drains the queue and writes
    each turn to the ``ConversationHistoryStore``, retrying with exponential
    backoff. ``finalize`` waits briefly for the queue to drain, then closes
    the remote call record and stops the worker.
    """

    # Sentinel object pushed onto the queue to tell the worker loop to exit.
    _STOP_SENTINEL = object()

    def __init__(
        self,
        *,
        history_writer: ConversationHistoryStore | None,
        enabled: bool,
        queue_max_size: int,
        retry_max_attempts: int,
        retry_backoff_sec: float,
        finalize_drain_timeout_sec: float,
    ):
        self._history_writer = history_writer
        # Fully disabled when no writer is configured, regardless of the flag.
        self._enabled = bool(enabled and history_writer is not None)
        # Clamp configuration values into safe ranges.
        self._queue_max_size = max(1, int(queue_max_size))
        self._retry_max_attempts = max(0, int(retry_max_attempts))
        self._retry_backoff_sec = max(0.0, float(retry_backoff_sec))
        self._finalize_drain_timeout_sec = max(0.0, float(finalize_drain_timeout_sec))

        # Per-call state; populated by start_call().
        self._call_id: Optional[str] = None
        self._turn_index: int = 0
        self._started_mono: Optional[float] = None  # monotonic start timestamp
        self._finalized: bool = False
        self._worker_task: Optional[asyncio.Task] = None
        self._finalize_lock = asyncio.Lock()  # makes finalize() idempotent
        self._queue: asyncio.Queue[_HistoryTranscriptJob | object] = asyncio.Queue(maxsize=self._queue_max_size)

    @property
    def enabled(self) -> bool:
        # True only when enabled AND a history writer was supplied.
        return self._enabled

    @property
    def call_id(self) -> Optional[str]:
        # Control-plane call ID, or None before start_call() succeeds.
        return self._call_id

    async def start_call(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str,
    ) -> Optional[str]:
        """Create remote call record and start background worker.

        Idempotent: returns the existing call ID if one is already active.
        Returns None when disabled or record creation fails.
        """
        if not self._enabled or self._call_id:
            return self._call_id

        call_id = await self._history_writer.create_call_record(
            user_id=user_id,
            assistant_id=assistant_id,
            source=source,
        )
        if not call_id:
            return None

        self._call_id = str(call_id)
        self._turn_index = 0
        self._finalized = False
        self._started_mono = time.monotonic()
        self._ensure_worker()
        return self._call_id

    def elapsed_ms(self) -> int:
        """Milliseconds elapsed since start_call(); 0 if not started."""
        if self._started_mono is None:
            return 0
        return max(0, int((time.monotonic() - self._started_mono) * 1000))

    def enqueue_turn(self, *, role: str, text: str) -> bool:
        """Queue one transcript write without blocking the caller.

        Returns True when the job was queued, False when disabled, not
        started, already finalized, text is empty, or the queue is full
        (the turn is dropped with a warning in that case).
        """
        if not self._enabled or not self._call_id or self._finalized:
            return False

        content = str(text or "").strip()
        if not content:
            return False

        # Map conversation roles onto the history store's speaker labels.
        speaker = "human" if str(role or "").strip().lower() == "user" else "ai"
        end_ms = self.elapsed_ms()
        # Rough duration estimate from text length: 80 ms per character,
        # clamped to [300 ms, 12 s].
        estimated_duration_ms = max(300, min(12000, len(content) * 80))
        start_ms = max(0, end_ms - estimated_duration_ms)

        job = _HistoryTranscriptJob(
            call_id=self._call_id,
            turn_index=self._turn_index,
            speaker=speaker,
            content=content,
            start_ms=start_ms,
            end_ms=end_ms,
            duration_ms=max(1, end_ms - start_ms),
        )
        # Index advances even if the put below fails, preserving ordinal gaps
        # for dropped turns.
        self._turn_index += 1
        self._ensure_worker()

        try:
            self._queue.put_nowait(job)
            return True
        except asyncio.QueueFull:
            logger.warning(
                "History queue full; dropping transcript call_id={} turn={}",
                self._call_id,
                job.turn_index,
            )
            return False

    async def finalize(self, *, status: str) -> bool:
        """Finalize history record once; waits briefly for queue drain."""
        if not self._enabled or not self._call_id:
            return False

        # Lock serializes concurrent finalize() calls; a second caller
        # observes _finalized and returns early.
        async with self._finalize_lock:
            if self._finalized:
                return True

            await self._drain_queue()
            ok = await self._history_writer.finalize_call_record(
                call_id=self._call_id,
                status=status,
                duration_seconds=self.duration_seconds(),
            )
            if ok:
                self._finalized = True
                await self._stop_worker()
            return ok

    async def shutdown(self) -> None:
        """Stop worker task and release queue resources."""
        await self._stop_worker()

    def duration_seconds(self) -> int:
        """Whole seconds elapsed since start_call(); 0 if not started."""
        if self._started_mono is None:
            return 0
        return max(0, int(time.monotonic() - self._started_mono))

    def _ensure_worker(self) -> None:
        # (Re)start the background worker if it is missing or has exited.
        if self._worker_task and not self._worker_task.done():
            return
        self._worker_task = asyncio.create_task(self._worker_loop())

    async def _drain_queue(self) -> None:
        # Best-effort wait for pending writes; bounded by the configured
        # drain timeout so finalize() cannot hang indefinitely.
        if self._finalize_drain_timeout_sec <= 0:
            return
        try:
            await asyncio.wait_for(self._queue.join(), timeout=self._finalize_drain_timeout_sec)
        except asyncio.TimeoutError:
            logger.warning("History queue drain timed out after {}s", self._finalize_drain_timeout_sec)

    async def _stop_worker(self) -> None:
        """Signal the worker to exit and wait for it, escalating to cancel."""
        task = self._worker_task
        if not task:
            return
        if task.done():
            self._worker_task = None
            return

        # Prefer a graceful stop: push the sentinel so the worker returns
        # after finishing in-flight work.
        sent = False
        try:
            self._queue.put_nowait(self._STOP_SENTINEL)
            sent = True
        except asyncio.QueueFull:
            pass

        if not sent:
            # Queue full: wait briefly for space, then give up and cancel.
            try:
                await asyncio.wait_for(self._queue.put(self._STOP_SENTINEL), timeout=0.5)
            except asyncio.TimeoutError:
                task.cancel()

        try:
            await asyncio.wait_for(task, timeout=1.5)
        except asyncio.TimeoutError:
            # Worker did not exit in time — force-cancel and swallow the
            # resulting CancelledError.
            task.cancel()
            try:
                await task
            except Exception:
                pass
        except asyncio.CancelledError:
            pass
        finally:
            self._worker_task = None

    async def _worker_loop(self) -> None:
        """Consume queued jobs until the stop sentinel is received."""
        while True:
            item = await self._queue.get()
            try:
                if item is self._STOP_SENTINEL:
                    return

                assert isinstance(item, _HistoryTranscriptJob)
                await self._write_with_retry(item)
            except Exception as exc:
                # Keep the loop alive; one bad job must not kill the worker.
                logger.warning("History worker write failed unexpectedly: {}", exc)
            finally:
                # task_done() runs for the sentinel too, so join() can complete.
                self._queue.task_done()

    async def _write_with_retry(self, job: _HistoryTranscriptJob) -> bool:
        """Write one transcript job, retrying with exponential backoff.

        Returns True on success; False after exhausting all retries.
        """
        for attempt in range(self._retry_max_attempts + 1):
            ok = await self._history_writer.add_transcript(
                call_id=job.call_id,
                turn_index=job.turn_index,
                speaker=job.speaker,
                content=job.content,
                start_ms=job.start_ms,
                end_ms=job.end_ms,
                duration_ms=job.duration_ms,
            )
            if ok:
                return True

            if attempt >= self._retry_max_attempts:
                logger.warning(
                    "History write dropped after retries call_id={} turn={}",
                    job.call_id,
                    job.turn_index,
                )
                return False

            if self._retry_backoff_sec > 0:
                # Exponential backoff: base * 2^attempt.
                await asyncio.sleep(self._retry_backoff_sec * (2**attempt))
        return False
|
||||
1
engine/runtime/pipeline/__init__.py
Normal file
1
engine/runtime/pipeline/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Runtime pipeline package."""
|
||||
13
engine/runtime/pipeline/asr_flow.py
Normal file
13
engine/runtime/pipeline/asr_flow.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""ASR flow helpers extracted from the duplex pipeline.
|
||||
|
||||
This module is intentionally lightweight for phase-wise migration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from providers.common.base import ASRResult
|
||||
|
||||
|
||||
def is_final_result(result: ASRResult) -> bool:
    """Report whether *result* carries a final (non-interim) transcript."""
    final_flag = result.is_final
    return bool(final_flag)
|
||||
6
engine/runtime/pipeline/constants.py
Normal file
6
engine/runtime/pipeline/constants.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Shared constants for the runtime duplex pipeline."""
|
||||
|
||||
TRACK_AUDIO_IN = "audio_in"
|
||||
TRACK_AUDIO_OUT = "audio_out"
|
||||
TRACK_CONTROL = "control"
|
||||
PCM_FRAME_BYTES = 640 # 16k mono pcm_s16le, 20ms
|
||||
2764
engine/runtime/pipeline/duplex.py
Normal file
2764
engine/runtime/pipeline/duplex.py
Normal file
File diff suppressed because it is too large
Load Diff
12
engine/runtime/pipeline/events_out.py
Normal file
12
engine/runtime/pipeline/events_out.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Output-event shaping helpers for the runtime pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def assistant_text_delta_event(text: str, **extra: Any) -> Dict[str, Any]:
    """Build a normalized assistant text delta payload.

    Extra keyword arguments are merged in last, so they may override the
    base keys.
    """
    base: Dict[str, Any] = {"type": "assistant.text.delta", "text": str(text)}
    return {**base, **extra}
|
||||
8
engine/runtime/pipeline/interrupts.py
Normal file
8
engine/runtime/pipeline/interrupts.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Interruption-related helpers extracted from the duplex pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def should_interrupt(min_duration_ms: int, detected_ms: int) -> bool:
    """Decide whether interruption conditions are met.

    A negative minimum duration is treated as zero.
    """
    threshold = int(min_duration_ms)
    if threshold < 0:
        threshold = 0
    return int(detected_ms) >= threshold
|
||||
13
engine/runtime/pipeline/llm_flow.py
Normal file
13
engine/runtime/pipeline/llm_flow.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""LLM flow helpers extracted from the duplex pipeline.
|
||||
|
||||
This module is intentionally lightweight for phase-wise migration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from providers.common.base import LLMStreamEvent
|
||||
|
||||
|
||||
def is_done_event(event: LLMStreamEvent) -> bool:
    """Return whether an LLM stream event signals completion."""
    event_type = str(event.type)
    return event_type == "done"
|
||||
13
engine/runtime/pipeline/tooling.py
Normal file
13
engine/runtime/pipeline/tooling.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Tooling helpers extracted from the duplex pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def normalize_tool_name(name: Any, aliases: dict[str, str]) -> str:
    """Normalize tool name with alias mapping.

    Falsy/empty names normalize to "". Otherwise the stripped name is
    looked up in *aliases*, falling back to itself.
    """
    cleaned = str(name or "").strip()
    return aliases.get(cleaned, cleaned) if cleaned else ""
|
||||
15
engine/runtime/pipeline/tts_flow.py
Normal file
15
engine/runtime/pipeline/tts_flow.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""TTS flow helpers extracted from the duplex pipeline.
|
||||
|
||||
This module is intentionally lightweight for phase-wise migration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from providers.common.base import TTSChunk
|
||||
|
||||
|
||||
def chunk_duration_ms(chunk: TTSChunk) -> float:
    """Estimate chunk duration in milliseconds for pcm16 mono."""
    rate = chunk.sample_rate
    if rate > 0:
        # 2 bytes per sample (pcm_s16le, single channel).
        samples = len(chunk.audio) / 2.0
        return samples / float(rate) * 1000.0
    return 0.0
|
||||
32
engine/runtime/ports/__init__.py
Normal file
32
engine/runtime/ports/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Port interfaces for runtime integration boundaries."""
|
||||
|
||||
from runtime.ports.asr import ASRBufferControl, ASRInterimControl, ASRPort, ASRServiceSpec
|
||||
from runtime.ports.control_plane import (
|
||||
AssistantRuntimeConfigProvider,
|
||||
ControlPlaneGateway,
|
||||
ConversationHistoryStore,
|
||||
KnowledgeRetriever,
|
||||
ToolCatalog,
|
||||
)
|
||||
from runtime.ports.llm import LLMCancellable, LLMPort, LLMRuntimeConfigurable, LLMServiceSpec
|
||||
from runtime.ports.service_factory import RealtimeServiceFactory
|
||||
from runtime.ports.tts import TTSPort, TTSServiceSpec
|
||||
|
||||
__all__ = [
|
||||
"ASRPort",
|
||||
"ASRServiceSpec",
|
||||
"ASRInterimControl",
|
||||
"ASRBufferControl",
|
||||
"AssistantRuntimeConfigProvider",
|
||||
"ControlPlaneGateway",
|
||||
"ConversationHistoryStore",
|
||||
"KnowledgeRetriever",
|
||||
"ToolCatalog",
|
||||
"LLMCancellable",
|
||||
"LLMPort",
|
||||
"LLMRuntimeConfigurable",
|
||||
"LLMServiceSpec",
|
||||
"RealtimeServiceFactory",
|
||||
"TTSPort",
|
||||
"TTSServiceSpec",
|
||||
]
|
||||
64
engine/runtime/ports/asr.py
Normal file
64
engine/runtime/ports/asr.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""ASR extension port contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator, Awaitable, Callable, Optional, Protocol
|
||||
|
||||
from providers.common.base import ASRResult
|
||||
|
||||
# Transcript notification hook: receives (text, is_final) and awaits.
TranscriptCallback = Callable[[str, bool], Awaitable[None]]


@dataclass(frozen=True)
class ASRServiceSpec:
    """Resolved runtime configuration for ASR service creation."""

    provider: str  # ASR provider identifier
    sample_rate: int  # Input audio sample rate in Hz
    language: str = "auto"  # Recognition language; "auto" means detect
    api_key: Optional[str] = None  # Provider credential, if required
    api_url: Optional[str] = None  # Provider endpoint override
    model: Optional[str] = None  # Provider-specific model name
    interim_interval_ms: int = 500  # Interval between interim results
    min_audio_for_interim_ms: int = 300  # Minimum audio before first interim
    on_transcript: Optional[TranscriptCallback] = None  # Optional transcript hook
|
||||
|
||||
|
||||
class ASRPort(Protocol):
    """Port for speech recognition providers.

    Structural (duck-typed) interface: any provider exposing these four
    async methods satisfies it.
    """

    async def connect(self) -> None:
        """Establish connection to ASR provider."""

    async def disconnect(self) -> None:
        """Release ASR resources."""

    async def send_audio(self, audio: bytes) -> None:
        """Push one PCM audio chunk for recognition."""

    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Stream partial/final recognition results."""
|
||||
|
||||
|
||||
class ASRInterimControl(Protocol):
    """Optional extension for explicit interim transcription control.

    Providers implement this in addition to ASRPort when they support
    toggling interim (partial) result emission.
    """

    async def start_interim_transcription(self) -> None:
        """Start interim transcription loop if supported."""

    async def stop_interim_transcription(self) -> None:
        """Stop interim transcription loop if supported."""
|
||||
|
||||
|
||||
class ASRBufferControl(Protocol):
    """Optional extension for explicit ASR buffer lifecycle control.

    Providers implement this in addition to ASRPort when the caller needs
    to manage the provider-side utterance buffer directly.
    """

    def clear_buffer(self) -> None:
        """Clear provider-side ASR buffer."""

    async def get_final_transcription(self) -> str:
        """Return final transcription for the current utterance."""

    def get_and_clear_text(self) -> str:
        """Return buffered text and clear internal state."""
|
||||
83
engine/runtime/ports/control_plane.py
Normal file
83
engine/runtime/ports/control_plane.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Control-plane integration ports.
|
||||
|
||||
These interfaces define the boundary between engine runtime logic and
|
||||
control-plane capabilities (config lookup, history persistence, retrieval,
|
||||
and tool resource discovery).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
||||
|
||||
class AssistantRuntimeConfigProvider(Protocol):
    """Port for loading trusted assistant runtime configuration."""

    async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]:
        """Fetch assistant configuration payload.

        Returns None when no configuration is available for *assistant_id*.
        """
|
||||
|
||||
|
||||
class ConversationHistoryStore(Protocol):
    """Port for persisting call and transcript history.

    Lifecycle: create_call_record -> add_transcript (repeated) ->
    finalize_call_record.
    """

    async def create_call_record(
        self,
        *,
        user_id: int,
        assistant_id: Optional[str],
        source: str = "debug",
    ) -> Optional[str]:
        """Create a call record and return control-plane call ID.

        Returns None when the record could not be created.
        """

    async def add_transcript(
        self,
        *,
        call_id: str,
        turn_index: int,
        speaker: str,
        content: str,
        start_ms: int,
        end_ms: int,
        confidence: Optional[float] = None,
        duration_ms: Optional[int] = None,
    ) -> bool:
        """Append one transcript turn segment.

        Returns True on success so callers can retry failed writes.
        """

    async def finalize_call_record(
        self,
        *,
        call_id: str,
        status: str,
        duration_seconds: int,
    ) -> bool:
        """Finalize a call record.

        Returns True on success.
        """
|
||||
|
||||
|
||||
class KnowledgeRetriever(Protocol):
    """Port for RAG / knowledge retrieval operations."""

    async def search_knowledge_context(
        self,
        *,
        kb_id: str,
        query: str,
        n_results: int = 5,
    ) -> List[Dict[str, Any]]:
        """Search a knowledge source and return ranked snippets.

        Args:
            kb_id: Knowledge base identifier to search
            query: Free-text search query
            n_results: Maximum number of snippets to return
        """
|
||||
|
||||
|
||||
class ToolCatalog(Protocol):
    """Port for resolving tool metadata/configuration."""

    async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]:
        """Fetch tool resource configuration.

        Args:
            tool_id: Identifier of the tool to resolve.

        Returns:
            The tool's resource/config dict, or ``None`` when the tool is
            not found.
        """
|
||||
|
||||
|
||||
class ControlPlaneGateway(
    AssistantRuntimeConfigProvider,
    ConversationHistoryStore,
    KnowledgeRetriever,
    ToolCatalog,
    Protocol,
):
    """Composite control-plane gateway used by engine services.

    Structural union of the four narrower ports: any object implementing
    all of their methods satisfies this protocol. Prefer depending on the
    narrowest port that covers a component's needs; use this composite only
    where a single gateway object is passed around.
    """
|
||||
67
engine/runtime/ports/llm.py
Normal file
67
engine/runtime/ports/llm.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""LLM extension port contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Protocol
|
||||
|
||||
from providers.common.base import LLMMessage, LLMStreamEvent
|
||||
|
||||
# Shape of an async knowledge-search callable: awaits to a list of snippet
# dicts. `Callable[...]` leaves the parameter list unconstrained so adapters
# with keyword-only signatures fit.
KnowledgeRetrieverFn = Callable[..., Awaitable[List[Dict[str, Any]]]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMServiceSpec:
    """Resolved runtime configuration for LLM service creation.

    Immutable value object handed to the service factory; all provider
    resolution happens before this is constructed.
    """

    provider: str  # Provider identifier consumed by the service factory
    model: str  # Model name passed through to the provider
    api_key: Optional[str] = None  # Credential, when the provider needs one
    base_url: Optional[str] = None  # Endpoint override, when applicable
    system_prompt: Optional[str] = None  # System prompt to seed conversations
    temperature: float = 0.7  # Default sampling temperature
    knowledge_config: Dict[str, Any] = field(default_factory=dict)  # RAG settings (schema provider-defined)
    knowledge_searcher: Optional[KnowledgeRetrieverFn] = None  # Async search hook for retrieval-augmented prompts
|
||||
|
||||
|
||||
class LLMPort(Protocol):
    """Port for LLM providers.

    Structural interface implemented by provider adapters; runtime code
    depends only on these methods.
    """

    async def connect(self) -> None:
        """Establish connection to LLM provider."""

    async def disconnect(self) -> None:
        """Release LLM resources."""

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Generate a complete assistant response.

        Args:
            messages: Conversation history to condition on.
            temperature: Sampling temperature.
            max_tokens: Optional generation cap; ``None`` uses the provider
                default.

        Returns:
            The full response text.
        """

    # NOTE(review): declared `async def` with an AsyncIterator return —
    # implementers are presumably async generators (`async def` + `yield`);
    # confirm adapters match this calling convention.
    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> AsyncIterator[LLMStreamEvent]:
        """Generate streaming assistant response events."""
|
||||
|
||||
|
||||
class LLMCancellable(Protocol):
    """Optional extension for interrupting in-flight LLM generation.

    Adapters that support cancellation implement this in addition to
    ``LLMPort``; callers should feature-test before relying on it.
    """

    def cancel(self) -> None:
        """Cancel an in-flight generation request.

        Note: synchronous — the caller is not expected to await completion
        of the cancellation.
        """
|
||||
|
||||
|
||||
class LLMRuntimeConfigurable(Protocol):
    """Optional extension for runtime config updates.

    Allows reconfiguring a live LLM service (knowledge retrieval and tool
    schemas) without recreating it.
    """

    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        """Apply runtime knowledge retrieval settings.

        Passing ``None`` presumably disables retrieval — confirm against
        the implementing adapter.
        """

    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        """Apply runtime tool schemas used for tool calling."""
|
||||
22
engine/runtime/ports/service_factory.py
Normal file
22
engine/runtime/ports/service_factory.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Factory port for creating runtime ASR/LLM/TTS services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from runtime.ports.asr import ASRPort, ASRServiceSpec
|
||||
from runtime.ports.llm import LLMPort, LLMServiceSpec
|
||||
from runtime.ports.tts import TTSPort, TTSServiceSpec
|
||||
|
||||
|
||||
class RealtimeServiceFactory(Protocol):
    """Port for provider-specific service construction.

    Maps resolved specs to concrete ASR/LLM/TTS service instances, keeping
    provider wiring out of runtime session logic. Construction is
    synchronous; connection setup happens later via each port's
    ``connect``-style methods.
    """

    def create_llm_service(self, spec: LLMServiceSpec) -> LLMPort:
        """Create an LLM service instance from a resolved spec."""

    def create_tts_service(self, spec: TTSServiceSpec) -> TTSPort:
        """Create a TTS service instance from a resolved spec."""

    def create_asr_service(self, spec: ASRServiceSpec) -> ASRPort:
        """Create an ASR service instance from a resolved spec."""
|
||||
41
engine/runtime/ports/tts.py
Normal file
41
engine/runtime/ports/tts.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""TTS extension port contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator, Optional, Protocol
|
||||
|
||||
from providers.common.base import TTSChunk
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TTSServiceSpec:
    """Resolved runtime configuration for TTS service creation.

    Immutable value object consumed by the service factory.
    """

    provider: str  # Provider identifier consumed by the service factory
    voice: str  # Voice name/ID understood by the provider
    sample_rate: int  # Output PCM sample rate in Hz
    speed: float = 1.0  # Speech-rate multiplier (1.0 = provider default)
    api_key: Optional[str] = None  # Credential, when the provider needs one
    api_url: Optional[str] = None  # Endpoint override, when applicable
    model: Optional[str] = None  # Provider model name, when selectable
    mode: str = "commit"  # Synthesis mode; semantics are adapter-defined — TODO confirm
|
||||
|
||||
|
||||
class TTSPort(Protocol):
    """Port for speech synthesis providers."""

    async def connect(self) -> None:
        """Establish connection to TTS provider."""

    async def disconnect(self) -> None:
        """Release TTS resources."""

    async def synthesize(self, text: str) -> bytes:
        """Synthesize complete PCM payload for text.

        Returns:
            Raw PCM audio bytes for the full utterance.
        """

    # NOTE(review): AsyncIterator return on an `async def` — implementers are
    # presumably async generators; confirm adapters match.
    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Stream synthesized PCM chunks for text."""

    async def cancel(self) -> None:
        """Cancel an in-flight synthesis request."""
|
||||
1
engine/runtime/session/__init__.py
Normal file
1
engine/runtime/session/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Runtime session package."""
|
||||
10
engine/runtime/session/lifecycle.py
Normal file
10
engine/runtime/session/lifecycle.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Lifecycle helper utilities for runtime sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utc_now_iso() -> str:
    """Return the current UTC time as an ISO 8601 string (offset-aware)."""
    now = datetime.now(tz=timezone.utc)
    return now.isoformat()
|
||||
1229
engine/runtime/session/manager.py
Normal file
1229
engine/runtime/session/manager.py
Normal file
File diff suppressed because it is too large
Load Diff
9
engine/runtime/session/metadata.py
Normal file
9
engine/runtime/session/metadata.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Metadata helpers extracted from session manager."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Pattern
|
||||
|
||||
# Valid dynamic-variable key: identifier-style (letter/underscore start,
# then letters/digits/underscores), 1-64 characters total, full-string match.
DYNAMIC_VARIABLE_KEY_RE: Pattern[str] = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
# Matches "{{ name }}" placeholders (optional inner whitespace); group 1
# captures the variable name. Note: unlike the key regex, this does not cap
# the name length.
DYNAMIC_VARIABLE_PLACEHOLDER_RE: Pattern[str] = re.compile(r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}")
|
||||
12
engine/runtime/session/workflow_bridge.py
Normal file
12
engine/runtime/session/workflow_bridge.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Workflow bridge helpers for runtime session orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from workflow.runner import WorkflowRunner
|
||||
|
||||
|
||||
def has_active_workflow(workflow_runner: Optional[WorkflowRunner]) -> bool:
    """Return whether a workflow runner exists and has a current node."""
    if not workflow_runner:
        return False
    return workflow_runner.current_node is not None
|
||||
247
engine/runtime/transports.py
Normal file
247
engine/runtime/transports.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Transport layer for WebSocket and WebRTC communication."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
from loguru import logger
|
||||
|
||||
# Try to import aiortc (optional for WebRTC functionality).
# When absent, AIORTC_AVAILABLE gates WebRtcTransport construction and
# RTCPeerConnection stays None so annotations/references still resolve.
try:
    from aiortc import RTCPeerConnection
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    RTCPeerConnection = None  # Type hint placeholder
|
||||
|
||||
|
||||
class BaseTransport(ABC):
    """
    Abstract base class for transports.

    All transports must implement send_event and send_audio methods.
    Implementations in this module also expose a `close()` method and are
    expected to tolerate sends after the peer has gone away (log-and-drop
    rather than raise).
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event to the client.

        Args:
            event: Event data as dictionary
        """
        pass

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the transport and cleanup resources."""
        pass
|
||||
|
||||
|
||||
class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Handles mixed text/binary frames over a WebSocket connection. An
    asyncio.Lock serializes sends to prevent frame interleaving. All send
    paths are log-and-drop after disconnect: they mark the transport closed
    instead of raising, so callers never crash on a departed peer.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        self._closed = False

    def _ws_disconnected(self) -> bool:
        """Best-effort check for websocket disconnection state."""
        return (
            self.ws.client_state == WebSocketState.DISCONNECTED
            or self.ws.application_state == WebSocketState.DISCONNECTED
        )

    async def _locked_send(self, send_op, label: str) -> bool:
        """
        Run one send operation under the frame lock with shared error handling.

        Previously this try/except structure was duplicated verbatim in
        send_event and send_audio; it is factored out so the two stay in sync.

        Args:
            send_op: Zero-arg callable returning the awaitable send coroutine.
            label: "event" or "audio", used in log messages.

        Returns:
            True if the send completed; False if it was skipped or failed
            (the transport is marked closed in that case).
        """
        async with self.lock:
            try:
                await send_op()
                return True
            except RuntimeError as e:
                # Starlette raises RuntimeError("...close message has been
                # sent...") when sending after close; treat that as a benign
                # race, anything else as a real error.
                self._closed = True
                if self._ws_disconnected() or "close message has been sent" in str(e):
                    logger.debug(f"Skip sending {label} on closed websocket: {e!r}")
                    return False
                logger.error(f"Error sending {label}: {e!r}")
                return False
            except Exception as e:
                self._closed = True
                if self._ws_disconnected():
                    logger.debug(f"Skip sending {label} on disconnected websocket: {e!r}")
                    return False
                logger.error(f"Error sending {label}: {e!r}")
                return False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        Args:
            event: Event data as dictionary
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send event on closed transport")
            self._closed = True
            return

        if await self._locked_send(lambda: self.ws.send_text(json.dumps(event)), "event"):
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send audio on closed transport")
            self._closed = True
            return

        await self._locked_send(lambda: self.ws.send_bytes(pcm_bytes), "audio")

    async def close(self) -> None:
        """Close the WebSocket connection (idempotent)."""
        if self._closed:
            return

        self._closed = True
        if self._ws_disconnected():
            return

        try:
            await self.ws.close()
        except RuntimeError as e:
            # Already closed by another task/path; safe to ignore.
            if "close message has been sent" in str(e):
                logger.debug(f"WebSocket already closed: {e}")
                return
            logger.error(f"Error closing WebSocket: {e}")
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed
|
||||
|
||||
|
||||
class WebRtcTransport(BaseTransport):
    """
    WebRTC transport: WebSocket for signaling, RTCPeerConnection for media.

    Events travel over the signaling WebSocket; audio is intended to go
    through a MediaStreamTrack (see send_audio — currently a placeholder).
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport

        Raises:
            RuntimeError: if aiortc is not installed (AIORTC_AVAILABLE False).
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")

        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        Marks the transport closed on any send failure (log-and-drop; never
        raises to the caller).

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        try:
            await self.ws.send_text(json.dumps(event))
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track.

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading.

        WARNING: placeholder implementation — no audio actually reaches the
        peer yet; the bytes are only logged.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")

        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """Close the WebRTC connection.

        NOTE(review): if pc.close() raises, ws.close() is skipped (both sit
        in one try block) — confirm that is acceptable here.
        """
        self._closed = True
        try:
            await self.pc.close()
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")
|
||||
Reference in New Issue
Block a user