Init commit
This commit is contained in:
20
core/__init__.py
Normal file
20
core/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Core Components Package"""
|
||||
|
||||
from core.events import EventBus, get_event_bus
|
||||
from core.transports import BaseTransport, SocketTransport, WebRtcTransport
|
||||
from core.session import Session
|
||||
from core.conversation import ConversationManager, ConversationState, ConversationTurn
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
|
||||
# Public names re-exported by the core package.
__all__ = [
    "EventBus",
    "get_event_bus",
    "BaseTransport",
    "SocketTransport",
    "WebRtcTransport",
    "Session",
    "ConversationManager",
    "ConversationState",
    "ConversationTurn",
    "DuplexPipeline",
]
|
||||
279
core/conversation.py
Normal file
279
core/conversation.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Conversation management for voice AI.
|
||||
|
||||
Handles conversation context, turn-taking, and message history
|
||||
for multi-turn voice conversations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
|
||||
from services.base import LLMMessage
|
||||
|
||||
|
||||
class ConversationState(Enum):
    """Lifecycle states of a voice conversation.

    The pipeline moves through these as turns alternate between the
    user and the assistant.
    """

    # Waiting for user input
    IDLE = "idle"
    # User is speaking
    LISTENING = "listening"
    # Processing user input (LLM)
    PROCESSING = "processing"
    # Bot is speaking
    SPEAKING = "speaking"
    # Bot was interrupted
    INTERRUPTED = "interrupted"
|
||||
|
||||
|
||||
def _turn_timestamp() -> float:
    """Return a monotonic timestamp for a new turn.

    Prefers the running event loop's clock; falls back to time.monotonic()
    when no loop is running (both read the same monotonic clock on CPython).
    The previous default used asyncio.get_event_loop(), which is deprecated
    and eventually raises when called without a running event loop — e.g.
    when a turn is constructed from synchronous code or tests.
    """
    try:
        return asyncio.get_running_loop().time()
    except RuntimeError:
        import time
        return time.monotonic()


@dataclass
class ConversationTurn:
    """A single turn in the conversation."""
    role: str  # "user" or "assistant"
    text: str
    # Duration of the spoken audio, when known (milliseconds)
    audio_duration_ms: Optional[int] = None
    # Monotonic creation time; see _turn_timestamp for why not loop.time() directly
    timestamp: float = field(default_factory=_turn_timestamp)
    # True when the user barged in before this turn finished
    was_interrupted: bool = False
|
||||
|
||||
|
||||
class ConversationManager:
    """
    Manages conversation state and history.

    Provides:
    - Message history for LLM context
    - Turn management
    - State tracking
    - Event callbacks for state changes
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_history: int = 20,
        greeting: Optional[str] = None
    ):
        """
        Initialize conversation manager.

        Args:
            system_prompt: System prompt for LLM
            max_history: Maximum number of turns included in LLM context
                (the full turn list is still retained for properties/history)
            greeting: Optional greeting message when conversation starts
        """
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation. "
            "If you don't understand something, ask for clarification."
        )
        self.max_history = max_history
        self.greeting = greeting

        # State
        self.state = ConversationState.IDLE
        self.turns: List[ConversationTurn] = []

        # Callbacks
        self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
        self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []

        # Current turn tracking (partial text before the turn is finalized)
        self._current_user_text: str = ""
        self._current_assistant_text: str = ""

        logger.info("ConversationManager initialized")

    def on_state_change(
        self,
        callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
    ) -> None:
        """Register callback for state changes (receives old_state, new_state)."""
        self._state_callbacks.append(callback)

    def on_turn_complete(
        self,
        callback: Callable[[ConversationTurn], Awaitable[None]]
    ) -> None:
        """Register callback for turn completion."""
        self._turn_callbacks.append(callback)

    async def set_state(self, new_state: ConversationState) -> None:
        """Set conversation state and notify listeners (no-op if unchanged)."""
        if new_state != self.state:
            old_state = self.state
            self.state = new_state
            logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")

            for callback in self._state_callbacks:
                try:
                    await callback(old_state, new_state)
                except Exception as e:
                    # A failing listener must not break the state transition.
                    logger.error(f"State callback error: {e}")

    async def _notify_turn(self, turn: ConversationTurn) -> None:
        """Invoke all turn-complete callbacks, isolating listener failures.

        Shared by end_user_turn / end_assistant_turn / add_assistant_turn,
        which previously each duplicated this loop.
        """
        for callback in self._turn_callbacks:
            try:
                await callback(turn)
            except Exception as e:
                logger.error(f"Turn callback error: {e}")

    def get_messages(self) -> List[LLMMessage]:
        """
        Get conversation history as LLM messages.

        Returns:
            List of LLMMessage objects including system prompt
        """
        messages = [LLMMessage(role="system", content=self.system_prompt)]

        # Add conversation history (bounded context window)
        for turn in self.turns[-self.max_history:]:
            messages.append(LLMMessage(role=turn.role, content=turn.text))

        # Add current user text if any
        if self._current_user_text:
            messages.append(LLMMessage(role="user", content=self._current_user_text))

        return messages

    async def start_user_turn(self) -> None:
        """Signal that user has started speaking."""
        await self.set_state(ConversationState.LISTENING)
        self._current_user_text = ""

    async def update_user_text(self, text: str, is_final: bool = False) -> None:
        """
        Update current user text (from ASR).

        Args:
            text: Transcribed text
            is_final: Whether this is the final transcript
        """
        self._current_user_text = text

    async def end_user_turn(self, text: str) -> None:
        """
        End user turn and add to history.

        Empty/whitespace-only text is not recorded, but state still
        advances to PROCESSING.

        Args:
            text: Final user text
        """
        if text.strip():
            turn = ConversationTurn(role="user", text=text.strip())
            self.turns.append(turn)
            await self._notify_turn(turn)
            logger.info(f"User: {text[:50]}...")

        self._current_user_text = ""
        await self.set_state(ConversationState.PROCESSING)

    async def start_assistant_turn(self) -> None:
        """Signal that assistant has started speaking."""
        await self.set_state(ConversationState.SPEAKING)
        self._current_assistant_text = ""

    async def update_assistant_text(self, text: str) -> None:
        """
        Update current assistant text (streaming).

        Args:
            text: Text chunk from LLM (appended to the running buffer)
        """
        self._current_assistant_text += text

    async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
        """
        End assistant turn and add to history.

        Args:
            was_interrupted: Whether the turn was interrupted by user
        """
        text = self._current_assistant_text.strip()
        if text:
            turn = ConversationTurn(
                role="assistant",
                text=text,
                was_interrupted=was_interrupted
            )
            self.turns.append(turn)
            await self._notify_turn(turn)

            status = " (interrupted)" if was_interrupted else ""
            logger.info(f"Assistant{status}: {text[:50]}...")

        self._current_assistant_text = ""

        if was_interrupted:
            # A new user turn may already be active (LISTENING) when interrupted.
            # Avoid overriding it back to INTERRUPTED, which can stall EOU flow.
            if self.state != ConversationState.LISTENING:
                await self.set_state(ConversationState.INTERRUPTED)
        else:
            await self.set_state(ConversationState.IDLE)

    async def add_assistant_turn(self, text: str, was_interrupted: bool = False) -> None:
        """Append an assistant turn directly without mutating conversation state."""
        content = text.strip()
        if not content:
            return

        turn = ConversationTurn(
            role="assistant",
            text=content,
            was_interrupted=was_interrupted,
        )
        self.turns.append(turn)
        await self._notify_turn(turn)
        logger.info(f"Assistant (injected): {content[:50]}...")

    async def interrupt(self) -> None:
        """Handle interruption (barge-in): finalize the active assistant turn."""
        if self.state == ConversationState.SPEAKING:
            await self.end_assistant_turn(was_interrupted=True)

    def reset(self) -> None:
        """Reset conversation history and return to IDLE."""
        self.turns = []
        self._current_user_text = ""
        self._current_assistant_text = ""
        self.state = ConversationState.IDLE
        logger.info("Conversation reset")

    @property
    def turn_count(self) -> int:
        """Get number of turns in conversation."""
        return len(self.turns)

    @property
    def last_user_text(self) -> Optional[str]:
        """Get last user text, or None if the user has not spoken yet."""
        for turn in reversed(self.turns):
            if turn.role == "user":
                return turn.text
        return None

    @property
    def last_assistant_text(self) -> Optional[str]:
        """Get last assistant text, or None if the assistant has not spoken."""
        for turn in reversed(self.turns):
            if turn.role == "assistant":
                return turn.text
        return None

    def get_context_summary(self) -> Dict[str, Any]:
        """Get a summary of conversation context (for logging/inspection)."""
        return {
            "state": self.state.value,
            "turn_count": self.turn_count,
            "last_user": self.last_user_text,
            "last_assistant": self.last_assistant_text,
            "current_user": self._current_user_text or None,
            "current_assistant": self._current_assistant_text or None
        }
|
||||
1507
core/duplex_pipeline.py
Normal file
1507
core/duplex_pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
134
core/events.py
Normal file
134
core/events.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Event bus for pub/sub communication between components."""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Dict, List, Any, Optional
|
||||
from collections import defaultdict
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class EventBus:
    """
    Async event bus for pub/sub communication.

    Similar to the original Rust implementation's broadcast channel.
    Components can subscribe to specific event types and receive events asynchronously.
    """

    def __init__(self):
        """Initialize the event bus."""
        # event type -> list of callbacks (sync or async)
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        # NOTE(review): this lock is currently never acquired anywhere in the
        # class; kept so external code that touches the attribute keeps working.
        self._lock = asyncio.Lock()
        self._running = True

    def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Subscribe to an event type.

        Args:
            event_type: Type of event to subscribe to (e.g., "speaking", "silence")
            callback: Async callback function that receives event data
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
            return

        self._subscribers[event_type].append(callback)
        logger.debug(f"Subscribed to event type: {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Unsubscribe from an event type.

        Args:
            event_type: Type of event to unsubscribe from
            callback: Callback function to remove
        """
        if callback in self._subscribers[event_type]:
            self._subscribers[event_type].remove(callback)
            logger.debug(f"Unsubscribed from event type: {event_type}")

    async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribers.

        Subscribers run concurrently; a failing subscriber does not prevent
        others from receiving the event.

        Args:
            event_type: Type of event to publish
            event_data: Event data to send to subscribers
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
            return

        # Get subscribers for this event type (.get avoids creating an empty
        # defaultdict entry for unknown event types)
        subscribers = self._subscribers.get(event_type, [])

        if not subscribers:
            logger.debug(f"No subscribers for event type: {event_type}")
            return

        # Notify all subscribers concurrently
        tasks = []
        for callback in subscribers:
            try:
                # Create task for each subscriber
                task = asyncio.create_task(self._call_subscriber(callback, event_data))
                tasks.append(task)
            except Exception as e:
                logger.error(f"Error creating task for subscriber: {e}")

        # Wait for all subscribers to complete; exceptions are already caught
        # inside _call_subscriber, return_exceptions=True is belt-and-braces.
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")

    async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
        """
        Call a subscriber callback with error handling.

        Args:
            callback: Subscriber callback function (sync or async)
            event_data: Event data to pass to callback
        """
        try:
            # Check if callback is a coroutine function
            if asyncio.iscoroutinefunction(callback):
                await callback(event_data)
            else:
                callback(event_data)
        except Exception as e:
            # Fix: the previous logger.error(..., exc_info=True) used the
            # stdlib-logging kwarg, which loguru does not honor, so the
            # traceback was silently dropped. logger.exception() is loguru's
            # way to attach the active exception's traceback.
            logger.exception(f"Error in subscriber callback: {e}")

    async def close(self) -> None:
        """Close the event bus and stop processing events."""
        self._running = False
        self._subscribers.clear()
        logger.info("Event bus closed")

    @property
    def is_running(self) -> bool:
        """Check if the event bus is running."""
        return self._running
|
||||
|
||||
|
||||
# Global event bus instance
|
||||
_event_bus: Optional[EventBus] = None
|
||||
|
||||
|
||||
def get_event_bus() -> EventBus:
    """Return the process-wide event bus, creating it lazily on first use.

    Returns:
        EventBus instance
    """
    global _event_bus
    bus = _event_bus
    if bus is None:
        bus = EventBus()
        _event_bus = bus
    return bus
|
||||
|
||||
|
||||
def reset_event_bus() -> None:
    """Drop the cached global bus so the next get_event_bus() builds a
    fresh instance (mainly for testing)."""
    global _event_bus
    _event_bus = None
|
||||
648
core/session.py
Normal file
648
core/session.py
Normal file
@@ -0,0 +1,648 @@
|
||||
"""Session management for active calls."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any, List
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import (
|
||||
create_history_call_record,
|
||||
add_history_transcript,
|
||||
finalize_history_call_record,
|
||||
)
|
||||
from core.transports import BaseTransport
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from core.conversation import ConversationTurn
|
||||
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
|
||||
from app.config import settings
|
||||
from services.base import LLMMessage
|
||||
from models.ws_v1 import (
|
||||
parse_client_message,
|
||||
ev,
|
||||
HelloMessage,
|
||||
SessionStartMessage,
|
||||
SessionStopMessage,
|
||||
InputTextMessage,
|
||||
ResponseCancelMessage,
|
||||
ToolCallResultsMessage,
|
||||
)
|
||||
|
||||
|
||||
class WsSessionState(str, Enum):
    """Protocol state machine for WS sessions.

    Sessions progress WAIT_HELLO -> WAIT_START -> ACTIVE -> STOPPED.
    """

    WAIT_HELLO = "wait_hello"    # awaiting the initial hello handshake
    WAIT_START = "wait_start"    # hello accepted, awaiting session.start
    ACTIVE = "active"            # session running, audio/messages accepted
    STOPPED = "stopped"          # terminal state
|
||||
|
||||
|
||||
class Session:
|
||||
"""
|
||||
Manages a single call session.
|
||||
|
||||
Handles command routing, audio processing, and session lifecycle.
|
||||
Uses full duplex voice conversation pipeline.
|
||||
"""
|
||||
|
||||
    def __init__(self, session_id: str, transport: BaseTransport, use_duplex: Optional[bool] = None):
        """
        Initialize session.

        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
            use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
        """
        self.id = session_id
        self.transport = transport
        # Tri-state override: explicit True/False wins; None falls back to config.
        self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled

        # Full-duplex voice pipeline bound to this session's transport.
        self.pipeline = DuplexPipeline(
            transport=transport,
            session_id=session_id,
            system_prompt=settings.duplex_system_prompt,
            greeting=settings.duplex_greeting
        )

        # Session state
        self.created_at = None
        self.state = "created"  # Legacy call state for /call/lists
        self.ws_state = WsSessionState.WAIT_HELLO
        self._pipeline_started = False
        self.protocol_version: Optional[str] = None
        self.authenticated: bool = False

        # Track IDs
        self.current_track_id: Optional[str] = str(uuid.uuid4())
        # History-bridge bookkeeping (backend call record + transcript turns)
        self._history_call_id: Optional[str] = None
        self._history_turn_index: int = 0
        self._history_call_started_mono: Optional[float] = None
        self._history_finalized: bool = False
        # Guards cleanup() against concurrent/duplicate invocation
        self._cleanup_lock = asyncio.Lock()
        self._cleaned_up = False
        # Optional workflow graph driving per-node prompt/tool routing
        self.workflow_runner: Optional[WorkflowRunner] = None
        self._workflow_last_user_text: str = ""
        self._workflow_initial_node: Optional[WorkflowNodeDef] = None

        # Persist completed turns (and advance the workflow) as they finish.
        self.pipeline.conversation.on_turn_complete(self._on_turn_complete)

        logger.info(f"Session {self.id} created (duplex={self.use_duplex})")
|
||||
|
||||
    async def handle_text(self, text_data: str) -> None:
        """
        Handle incoming text data (WS v1 JSON control messages).

        Parses the JSON payload, validates it into a typed client message,
        and dispatches it. Every failure path reports a typed error event to
        the client rather than raising to the caller.

        Args:
            text_data: JSON text data
        """
        try:
            data = json.loads(text_data)
            message = parse_client_message(data)
            await self._handle_v1_message(message)

        except json.JSONDecodeError as e:
            logger.error(f"Session {self.id} JSON decode error: {e}")
            await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")

        except ValueError as e:
            # parse_client_message rejects structurally invalid messages with ValueError.
            logger.error(f"Session {self.id} command parse error: {e}")
            await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")

        except Exception as e:
            # Catch-all boundary: report as a server-side error, keep session alive.
            logger.error(f"Session {self.id} handle_text error: {e}", exc_info=True)
            await self._send_error("server", f"Internal error: {e}", "server.internal")
|
||||
|
||||
    async def handle_audio(self, audio_bytes: bytes) -> None:
        """
        Handle incoming audio data.

        Audio is only accepted while the session is ACTIVE; frames arriving
        earlier get a protocol.order error. Pipeline failures are logged but
        never propagated (a bad frame must not kill the session).

        Args:
            audio_bytes: PCM audio data
        """
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                "Audio received before session.start",
                "protocol.order",
            )
            return

        try:
            await self.pipeline.process_audio(audio_bytes)
        except Exception as e:
            logger.error(f"Session {self.id} handle_audio error: {e}", exc_info=True)
|
||||
|
||||
    async def _handle_v1_message(self, message: Any) -> None:
        """Route validated WS v1 message to handlers.

        Enforces the protocol ordering: ``hello`` must arrive first,
        ``session.start`` second, and all other message types are accepted
        only while the session is ACTIVE.

        Args:
            message: A parsed WS v1 client message (typed model instance).
        """
        msg_type = message.type
        logger.info(f"Session {self.id} received message: {msg_type}")

        if isinstance(message, HelloMessage):
            await self._handle_hello(message)
            return

        # All messages below require hello handshake first
        if self.ws_state == WsSessionState.WAIT_HELLO:
            await self._send_error(
                "client",
                "Expected hello message first",
                "protocol.order",
            )
            return

        if isinstance(message, SessionStartMessage):
            await self._handle_session_start(message)
            return

        # All messages below require active session
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                f"Message '{msg_type}' requires active session",
                "protocol.order",
            )
            return

        if isinstance(message, InputTextMessage):
            await self.pipeline.process_text(message.text)
        elif isinstance(message, ResponseCancelMessage):
            if message.graceful:
                # Graceful cancel: let the in-flight response finish; no interrupt.
                logger.info(f"Session {self.id} graceful response.cancel")
            else:
                await self.pipeline.interrupt()
        elif isinstance(message, ToolCallResultsMessage):
            await self.pipeline.handle_tool_call_results(message.results)
        elif isinstance(message, SessionStopMessage):
            await self._handle_session_stop(message.reason)
        else:
            await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")
|
||||
|
||||
    async def _handle_hello(self, message: HelloMessage) -> None:
        """Handle initial hello/auth/version negotiation.

        On success transitions WAIT_HELLO -> WAIT_START and acks with
        ``hello.ack``. Version mismatches and auth failures send a typed
        error, close the transport, and mark the session STOPPED.
        """
        if self.ws_state != WsSessionState.WAIT_HELLO:
            await self._send_error("client", "Duplicate hello", "protocol.order")
            return

        # Hard version check: any mismatch closes the connection outright.
        if message.version != settings.ws_protocol_version:
            await self._send_error(
                "client",
                f"Unsupported protocol version '{message.version}'",
                "protocol.version_unsupported",
            )
            await self.transport.close()
            self.ws_state = WsSessionState.STOPPED
            return

        auth_payload = message.auth or {}
        api_key = auth_payload.get("apiKey")
        jwt = auth_payload.get("jwt")

        # If a server-side API key is configured it must match exactly.
        # Otherwise, when auth is required, presence of any apiKey or jwt is
        # enough here — NOTE(review): jwt content is not validated in this
        # method; confirm deeper validation happens elsewhere.
        if settings.ws_api_key:
            if api_key != settings.ws_api_key:
                await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
                await self.transport.close()
                self.ws_state = WsSessionState.STOPPED
                return
        elif settings.ws_require_auth and not (api_key or jwt):
            await self._send_error("auth", "Authentication required", "auth.required")
            await self.transport.close()
            self.ws_state = WsSessionState.STOPPED
            return

        self.authenticated = True
        self.protocol_version = message.version
        self.ws_state = WsSessionState.WAIT_START
        await self.transport.send_event(
            ev(
                "hello.ack",
                sessionId=self.id,
                version=self.protocol_version,
            )
        )
|
||||
|
||||
    async def _handle_session_start(self, message: SessionStartMessage) -> None:
        """Handle explicit session start after successful hello.

        Bootstraps the optional workflow, opens the backend history record,
        applies runtime overrides, starts the duplex pipeline (once), then
        acknowledges with ``session.started`` — plus workflow lifecycle
        events when a workflow is enabled.
        """
        if self.ws_state != WsSessionState.WAIT_START:
            await self._send_error("client", "Duplicate session.start", "protocol.order")
            return

        metadata = message.metadata or {}
        # Merge the workflow start-node's runtime overrides into the metadata.
        metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))

        # Create history call record early so later turn callbacks can append transcripts.
        await self._start_history_bridge(metadata)

        # Apply runtime service/prompt overrides from backend if provided
        self.pipeline.apply_runtime_overrides(metadata)

        # Start duplex pipeline (guarded so a duplicate start is a no-op)
        if not self._pipeline_started:
            await self.pipeline.start()
            self._pipeline_started = True
            logger.info(f"Session {self.id} duplex pipeline started")

        self.state = "accepted"
        self.ws_state = WsSessionState.ACTIVE
        await self.transport.send_event(
            ev(
                "session.started",
                sessionId=self.id,
                trackId=self.current_track_id,
                audio=message.audio or {},
            )
        )
        # Announce workflow activation and the initial node to the client.
        if self.workflow_runner and self._workflow_initial_node:
            await self.transport.send_event(
                ev(
                    "workflow.started",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    workflowName=self.workflow_runner.name,
                    nodeId=self._workflow_initial_node.id,
                )
            )
            await self.transport.send_event(
                ev(
                    "workflow.node.entered",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=self._workflow_initial_node.id,
                    nodeName=self._workflow_initial_node.name,
                    nodeType=self._workflow_initial_node.node_type,
                )
            )
|
||||
|
||||
    async def _handle_session_stop(self, reason: Optional[str]) -> None:
        """Handle session stop.

        Emits ``session.stopped``, finalizes the history record, and closes
        the transport. Idempotent: a second call is a no-op.

        Args:
            reason: Optional stop reason; defaults to "client_requested".
        """
        if self.ws_state == WsSessionState.STOPPED:
            return

        stop_reason = reason or "client_requested"
        self.state = "hungup"
        self.ws_state = WsSessionState.STOPPED
        await self.transport.send_event(
            ev(
                "session.stopped",
                sessionId=self.id,
                reason=stop_reason,
            )
        )
        await self._finalize_history(status="connected")
        await self.transport.close()
|
||||
|
||||
async def _send_error(self, sender: str, error_message: str, code: str) -> None:
|
||||
"""
|
||||
Send error event to client.
|
||||
|
||||
Args:
|
||||
sender: Component that generated the error
|
||||
error_message: Error message
|
||||
code: Machine-readable error code
|
||||
"""
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"error",
|
||||
sender=sender,
|
||||
code=code,
|
||||
message=error_message,
|
||||
trackId=self.current_track_id,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""Get current timestamp in milliseconds."""
|
||||
import time
|
||||
return int(time.time() * 1000)
|
||||
|
||||
    async def cleanup(self) -> None:
        """Cleanup session resources.

        Safe to call multiple times and from concurrent tasks: the lock plus
        the _cleaned_up flag guarantee teardown runs exactly once. Finalizes
        the history record, stops the pipeline, and closes the transport.
        """
        async with self._cleanup_lock:
            if self._cleaned_up:
                return

            self._cleaned_up = True
            logger.info(f"Session {self.id} cleaning up")
            await self._finalize_history(status="connected")
            await self.pipeline.cleanup()
            await self.transport.close()
|
||||
|
||||
    async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
        """Initialize backend history call record for this session.

        Reads optional overrides from the ``history`` sub-dict of the session
        metadata, then asks the backend to create a call record. Best-effort:
        if the backend returns no call id, history persistence stays disabled
        for this session.

        Args:
            metadata: session.start metadata (may carry userId/assistantId/source).
        """
        if self._history_call_id:
            # Already bridged (e.g. duplicate start); keep the first record.
            return

        history_meta: Dict[str, Any] = {}
        if isinstance(metadata.get("history"), dict):
            history_meta = metadata["history"]

        # Precedence: history.userId > top-level userId > configured default.
        raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id))
        try:
            user_id = int(raw_user_id)
        except (TypeError, ValueError):
            user_id = settings.history_default_user_id

        assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
        source = str(history_meta.get("source", metadata.get("source", "debug")))

        call_id = await create_history_call_record(
            user_id=user_id,
            assistant_id=str(assistant_id) if assistant_id else None,
            source=source,
        )
        if not call_id:
            return

        self._history_call_id = call_id
        # Monotonic anchor used to compute per-turn start/end offsets later.
        self._history_call_started_mono = time.monotonic()
        self._history_turn_index = 0
        self._history_finalized = False
        logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
|
||||
|
||||
    async def _on_turn_complete(self, turn: ConversationTurn) -> None:
        """Process workflow transitions and persist completed turns to history.

        Invoked by the conversation manager after each finalized turn. User
        turns are remembered for workflow routing; assistant turns may trigger
        a workflow node transition. Non-empty turns are then appended to the
        backend history record with estimated timings.

        Args:
            turn: The finalized conversation turn.
        """
        if turn.text and turn.text.strip():
            role = (turn.role or "").lower()
            if role == "user":
                self._workflow_last_user_text = turn.text.strip()
            elif role == "assistant":
                await self._maybe_advance_workflow(turn.text.strip())

        # No history record, or nothing to persist -> done.
        if not self._history_call_id:
            return
        if not turn.text or not turn.text.strip():
            return

        role = (turn.role or "").lower()
        speaker = "human" if role == "user" else "ai"

        # End offset = elapsed ms since the call record was opened.
        end_ms = 0
        if self._history_call_started_mono is not None:
            end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
        # No real audio timing available here: estimate ~80 ms per character,
        # clamped to [300 ms, 12 s], and back-date the start accordingly.
        estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80))
        start_ms = max(0, end_ms - estimated_duration_ms)

        turn_index = self._history_turn_index
        await add_history_transcript(
            call_id=self._history_call_id,
            turn_index=turn_index,
            speaker=speaker,
            content=turn.text.strip(),
            start_ms=start_ms,
            end_ms=end_ms,
            duration_ms=max(1, end_ms - start_ms),
        )
        self._history_turn_index += 1
|
||||
|
||||
    async def _finalize_history(self, status: str) -> None:
        """Finalize history call record once.

        Args:
            status: Final call status reported to the backend.
        """
        if not self._history_call_id or self._history_finalized:
            return

        duration_seconds = 0
        if self._history_call_started_mono is not None:
            duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))

        ok = await finalize_history_call_record(
            call_id=self._history_call_id,
            status=status,
            duration_seconds=duration_seconds,
        )
        if ok:
            # Only mark finalized on success so a later call can retry.
            self._history_finalized = True
|
||||
|
||||
    def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Parse workflow payload and return initial runtime overrides.

        Args:
            metadata: session.start metadata; its ``workflow`` key holds the graph.

        Returns:
            Runtime override metadata for the resolved start node, or {} when
            no usable workflow was supplied.
        """
        payload = metadata.get("workflow")
        self.workflow_runner = WorkflowRunner.from_payload(payload)
        self._workflow_initial_node = None
        if not self.workflow_runner:
            return {}

        node = self.workflow_runner.bootstrap()
        if not node:
            # A graph without a resolvable entry point is unusable; disable it.
            logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
            self.workflow_runner = None
            return {}

        self._workflow_initial_node = node
        logger.info(
            "Session {} workflow enabled: workflow={} start_node={}",
            self.id,
            self.workflow_runner.workflow_id,
            node.id,
        )
        return self.workflow_runner.build_runtime_metadata(node)
|
||||
|
||||
    async def _maybe_advance_workflow(self, assistant_text: str) -> None:
        """Attempt node transfer after assistant turn finalization.

        Asks the workflow runner to route on the latest user/assistant
        exchange; if a transition fires, applies it and then auto-advances
        through chained utility nodes ("start"/"tool") that expose default
        edges, bounded by a hop limit to guard against graph cycles.

        Args:
            assistant_text: Final text of the assistant turn that just ended.
        """
        if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
            return

        transition = await self.workflow_runner.route(
            user_text=self._workflow_last_user_text,
            assistant_text=assistant_text,
            llm_router=self._workflow_llm_route,
        )
        if not transition:
            return

        await self._apply_workflow_transition(transition, reason="rule_match")

        # Auto-advance through utility nodes when default edges are present.
        max_auto_hops = 6
        auto_hops = 0
        # Transitions may stop the session (end/human_transfer), so re-check
        # runner and state on every iteration.
        while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
            current = self.workflow_runner.current_node
            if not current or current.node_type not in {"start", "tool"}:
                break

            next_default = self.workflow_runner.next_default_transition()
            if not next_default:
                break

            auto_hops += 1
            await self._apply_workflow_transition(next_default, reason="auto")
            if auto_hops >= max_auto_hops:
                logger.warning(
                    "Session {} workflow auto-advance reached hop limit (possible cycle)",
                    self.id,
                )
                break
|
||||
|
||||
async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
    """Apply graph transition and emit workflow lifecycle events.

    Event order is part of the client contract: edge.taken, then
    node.entered, then any node-type-specific event
    (tool.requested / human_transfer / ended).

    Args:
        transition: Edge + target node pair selected by the runner.
        reason: Why the edge was taken (e.g. "rule_match", "auto");
            forwarded verbatim in the edge.taken event.
    """
    if not self.workflow_runner:
        return

    # Move the runner's cursor before emitting events.
    self.workflow_runner.apply_transition(transition)
    node = transition.node
    edge = transition.edge

    await self.transport.send_event(
        ev(
            "workflow.edge.taken",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            edgeId=edge.id,
            fromNodeId=edge.from_node_id,
            toNodeId=edge.to_node_id,
            reason=reason,
        )
    )
    await self.transport.send_event(
        ev(
            "workflow.node.entered",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            nodeId=node.id,
            nodeName=node.name,
            nodeType=node.node_type,
        )
    )

    # Per-node overrides (prompt/greeting/services) take effect immediately.
    node_runtime = self.workflow_runner.build_runtime_metadata(node)
    if node_runtime:
        self.pipeline.apply_runtime_overrides(node_runtime)

    # Tool nodes: hand the tool spec to the client and stop here.
    if node.node_type == "tool":
        await self.transport.send_event(
            ev(
                "workflow.tool.requested",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                tool=node.tool or {},
            )
        )
        return

    # Human-transfer nodes end the automated session.
    if node.node_type == "human_transfer":
        await self.transport.send_event(
            ev(
                "workflow.human_transfer",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        await self._handle_session_stop("workflow_human_transfer")
        return

    # End nodes terminate the session after notifying the client.
    if node.node_type == "end":
        await self.transport.send_event(
            ev(
                "workflow.ended",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        await self._handle_session_stop("workflow_end")
|
||||
|
||||
async def _workflow_llm_route(
    self,
    node: WorkflowNodeDef,
    candidates: List[WorkflowEdgeDef],
    context: Dict[str, str],
) -> Optional[str]:
    """LLM-based edge routing for condition.type == 'llm' edges.

    Asks the pipeline's LLM to pick one candidate edge and parses the
    answer in three decreasingly strict ways: strict JSON object,
    JSON object embedded in freeform text, then a plain substring
    scan for any known edge/node id.

    Args:
        node: Current workflow node whose outgoing edges are routed.
        candidates: Edges whose condition type is "llm".
        context: {"userText": ..., "assistantText": ...} turn context.

    Returns:
        A matched edge id or target-node id, or None when the LLM is
        unavailable, fails, or its reply matches nothing.
    """
    llm_service = self.pipeline.llm_service
    if not llm_service:
        return None

    candidate_rows = [
        {
            "edgeId": edge.id,
            "toNodeId": edge.to_node_id,
            "label": edge.label,
            "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
        }
        for edge in candidates
    ]
    system_prompt = (
        "You are a workflow router. Pick exactly one edge. "
        "Return JSON only: {\"edgeId\":\"...\"}."
    )
    user_prompt = json.dumps(
        {
            "nodeId": node.id,
            "nodeName": node.name,
            "userText": context.get("userText", ""),
            "assistantText": context.get("assistantText", ""),
            "candidates": candidate_rows,
        },
        ensure_ascii=False,
    )

    try:
        # Deterministic, short completion: routing should be cheap.
        reply = await llm_service.generate(
            [
                LLMMessage(role="system", content=system_prompt),
                LLMMessage(role="user", content=user_prompt),
            ],
            temperature=0.0,
            max_tokens=64,
        )
    except Exception as exc:
        logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
        return None

    if not reply:
        return None

    edge_ids = {edge.id for edge in candidates}
    node_ids = {edge.to_node_id for edge in candidates}

    # Preferred path: a JSON object naming the edge (or target node).
    parsed = self._extract_json_obj(reply)
    if isinstance(parsed, dict):
        edge_id = parsed.get("edgeId") or parsed.get("id")
        node_id = parsed.get("toNodeId") or parsed.get("nodeId")
        if isinstance(edge_id, str) and edge_id in edge_ids:
            return edge_id
        if isinstance(node_id, str) and node_id in node_ids:
            return node_id

    # Fallback: scan the reply for any known id, longest first so a
    # longer id is not shadowed by a shorter one it contains.
    token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
    lowered_reply = reply.lower()
    for token in token_candidates:
        if token.lower() in lowered_reply:
            return token
    return None
|
||||
|
||||
def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Merge node-level metadata overrides into session.start metadata.

    Top-level keys from *overrides* replace those in *base*, except
    "services", whose dict value is merged key-by-key instead of
    replaced wholesale. Neither input is mutated.
    """
    result: Dict[str, Any] = dict(base or {})
    for key, value in (overrides or {}).items():
        if key == "services" and isinstance(value, dict):
            current = result.get("services")
            combined = dict(current) if isinstance(current, dict) else {}
            combined.update(value)
            result["services"] = combined
            continue
        result[key] = value
    return result
|
||||
|
||||
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object from freeform text.

    Tries the whole string as JSON first; when that does not yield a
    dict, falls back to the outermost brace-delimited span.
    Returns None when no JSON object can be recovered.
    """
    # Attempt 1: the entire string parses as a JSON object.
    try:
        whole = json.loads(text)
    except Exception:
        whole = None
    if isinstance(whole, dict):
        return whole

    # Attempt 2: widest {...} span embedded in surrounding prose.
    found = re.search(r"\{.*\}", text, re.DOTALL)
    if found is None:
        return None
    try:
        candidate = json.loads(found.group(0))
    except Exception:
        return None
    return candidate if isinstance(candidate, dict) else None
|
||||
340
core/tool_executor.py
Normal file
340
core/tool_executor.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""Server-side tool execution helpers."""
|
||||
|
||||
import asyncio
|
||||
import ast
|
||||
import operator
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.backend_client import fetch_tool_resource
|
||||
|
||||
# Whitelisted binary arithmetic operators for the AST-based evaluators.
# Note: Pow (**) is deliberately absent (trivial DoS via huge exponents) —
# presumably intentional; confirm before adding.
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
}

# Whitelisted unary operators (+x, -x).
_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}

# Builtins callable from sandboxed "code_interpreter" expressions.
_SAFE_EVAL_FUNCS = {
    "abs": abs,
    "round": round,
    "min": min,
    "max": max,
    "sum": sum,
    "len": len,
}
|
||||
|
||||
|
||||
def _validate_safe_expr(node: ast.AST) -> None:
    """Allow only a constrained subset of Python expression nodes.

    Recursively walks the AST and raises ValueError on the first node
    outside the whitelist. Security-sensitive: this gate is what makes
    the later eval() in _safe_eval_python_expr acceptable, so keep the
    whitelist minimal and explicit.

    Raises:
        ValueError: on any operator, name, call target, or node type
            that is not explicitly allowed.
    """
    if isinstance(node, ast.Expression):
        _validate_safe_expr(node.body)
        return

    # Literals (numbers, strings, booleans, None).
    if isinstance(node, ast.Constant):
        return

    # Container displays: validate every element.
    if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        for elt in node.elts:
            _validate_safe_expr(elt)
        return

    if isinstance(node, ast.Dict):
        # A None key means **expansion in a dict display; skip it here.
        for key in node.keys:
            if key is not None:
                _validate_safe_expr(key)
        for value in node.values:
            _validate_safe_expr(value)
        return

    # Arithmetic limited to the _BIN_OPS whitelist (no **, //, bit ops).
    if isinstance(node, ast.BinOp):
        if type(node.op) not in _BIN_OPS:
            raise ValueError("unsupported operator")
        _validate_safe_expr(node.left)
        _validate_safe_expr(node.right)
        return

    if isinstance(node, ast.UnaryOp):
        if type(node.op) not in _UNARY_OPS:
            raise ValueError("unsupported unary operator")
        _validate_safe_expr(node.operand)
        return

    # and/or chains: operands only, operators are inherently safe.
    if isinstance(node, ast.BoolOp):
        for value in node.values:
            _validate_safe_expr(value)
        return

    # Comparisons: all comparison operators are side-effect free.
    if isinstance(node, ast.Compare):
        _validate_safe_expr(node.left)
        for comp in node.comparators:
            _validate_safe_expr(comp)
        return

    # Bare names: only whitelisted functions. (True/False/None are
    # ast.Constant on modern Python; kept here as a belt-and-braces check.)
    if isinstance(node, ast.Name):
        if node.id not in _SAFE_EVAL_FUNCS and node.id not in {"True", "False", "None"}:
            raise ValueError("unknown symbol")
        return

    # Calls: direct, named, whitelisted targets only — no attribute
    # access means no obj.method escape routes.
    if isinstance(node, ast.Call):
        if not isinstance(node.func, ast.Name):
            raise ValueError("unsafe call target")
        if node.func.id not in _SAFE_EVAL_FUNCS:
            raise ValueError("function not allowed")
        for arg in node.args:
            _validate_safe_expr(arg)
        for kw in node.keywords:
            _validate_safe_expr(kw.value)
        return

    # Explicitly reject high-risk nodes (import/attribute/subscript/comprehensions/lambda, etc.)
    raise ValueError("unsupported expression")
|
||||
|
||||
|
||||
def _safe_eval_python_expr(expression: str) -> Any:
    """Parse, validate, then evaluate a single Python expression.

    The expression is parsed in "eval" mode, gated through
    _validate_safe_expr, and only then executed with empty builtins
    plus the explicit _SAFE_EVAL_FUNCS whitelist in scope.
    """
    parsed = ast.parse(expression, mode="eval")
    _validate_safe_expr(parsed)
    code_obj = compile(parsed, "<code_interpreter>", "eval")
    # noqa: S307 - validated AST + empty builtins
    return eval(code_obj, {"__builtins__": {}}, dict(_SAFE_EVAL_FUNCS))
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
    """Recursively coerce *value* into JSON-serializable primitives.

    Scalars pass through, dict keys are stringified, tuples become
    lists, and anything else is replaced by its repr().
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, dict):
        return {str(k): _json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(v) for v in value]
    return repr(value)
|
||||
|
||||
|
||||
def _safe_eval_expr(expression: str) -> float:
    """Evaluate a pure-arithmetic expression by walking its AST.

    Only numeric constants, the _BIN_OPS binary operators, and the
    _UNARY_OPS unary operators are accepted; everything else raises
    ValueError. The result is always a float.
    """
    tree = ast.parse(expression, mode="eval")

    def walk(n: ast.AST) -> float:
        if isinstance(n, ast.Expression):
            return walk(n.body)
        if isinstance(n, ast.Constant) and isinstance(n.value, (int, float)):
            return float(n.value)
        if isinstance(n, ast.BinOp):
            fn = _BIN_OPS.get(type(n.op))
            if fn is None:
                raise ValueError("unsupported operator")
            return float(fn(walk(n.left), walk(n.right)))
        if isinstance(n, ast.UnaryOp):
            fn = _UNARY_OPS.get(type(n.op))
            if fn is None:
                raise ValueError("unsupported unary operator")
            return float(fn(walk(n.operand)))
        raise ValueError("unsupported expression")

    return walk(tree)
|
||||
|
||||
|
||||
def _extract_tool_name(tool_call: Dict[str, Any]) -> str:
    """Return the stripped function name from a tool-call dict, or ""."""
    fn = tool_call.get("function")
    if not isinstance(fn, dict):
        return ""
    return str(fn.get("name") or "").strip()
|
||||
|
||||
|
||||
def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Return the arguments dict of a tool call, decoding JSON strings.

    Accepts either a ready dict or a JSON-encoded string under
    function.arguments; anything else (including malformed or non-dict
    JSON) yields {}.

    Fix: the original ran `import json` inside the try block on every
    call, so the broad `except Exception` conflated import machinery
    errors with decode errors; the import is hoisted out of the try and
    the handler narrowed to ValueError (JSONDecodeError's base class).
    """
    import json

    function_payload = tool_call.get("function")
    if not isinstance(function_payload, dict):
        return {}
    raw = function_payload.get("arguments")
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {}
    text = raw.strip()
    if not text:
        return {}
    try:
        parsed = json.loads(text)
    except ValueError:
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a server-side tool and return normalized result payload.

    Supports three built-ins (calculator, code_interpreter,
    current_time) plus backend-configured HTTP "query" tools fetched
    via fetch_tool_resource. Never raises: every outcome is encoded in
    the returned dict's "status" ({"code": int, "message": str}) and
    "output" keys alongside "tool_call_id" and "name".
    """
    call_id = str(tool_call.get("id") or "").strip()
    tool_name = _extract_tool_name(tool_call)
    args = _extract_tool_args(tool_call)

    # --- Built-in: arithmetic calculator (AST-walked, no eval) ---
    if tool_name == "calculator":
        expression = str(args.get("expression") or "").strip()
        if not expression:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "missing expression"},
                "status": {"code": 400, "message": "bad_request"},
            }
        # Hard length cap guards against pathological inputs.
        if len(expression) > 200:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "error": "expression too long"},
                "status": {"code": 422, "message": "invalid_expression"},
            }
        try:
            value = _safe_eval_expr(expression)
            # Report whole-number results as ints (4, not 4.0).
            if value.is_integer():
                value = int(value)
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "result": value},
                "status": {"code": 200, "message": "ok"},
            }
        except Exception as exc:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"expression": expression, "error": str(exc)},
                "status": {"code": 422, "message": "invalid_expression"},
            }

    # --- Built-in: restricted single-expression "code interpreter" ---
    if tool_name == "code_interpreter":
        code = str(args.get("code") or args.get("expression") or "").strip()
        if not code:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "missing code"},
                "status": {"code": 400, "message": "bad_request"},
            }
        if len(code) > 500:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"error": "code too long"},
                "status": {"code": 422, "message": "invalid_code"},
            }
        try:
            result = _safe_eval_python_expr(code)
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"code": code, "result": _json_safe(result)},
                "status": {"code": 200, "message": "ok"},
            }
        except Exception as exc:
            return {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"code": code, "error": str(exc)},
                "status": {"code": 422, "message": "invalid_code"},
            }

    # --- Built-in: current time in the server's local timezone ---
    if tool_name == "current_time":
        now = datetime.now().astimezone()
        return {
            "tool_call_id": call_id,
            "name": tool_name,
            "output": {
                "local_time": now.strftime("%Y-%m-%d %H:%M:%S"),
                "iso": now.isoformat(),
                "timezone": str(now.tzinfo or ""),
                "timestamp": int(now.timestamp()),
            },
            "status": {"code": 200, "message": "ok"},
        }

    # --- Backend-configured HTTP tools (category "query" only) ---
    if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}:
        resource = await fetch_tool_resource(tool_name)
        if resource and str(resource.get("category") or "") == "query":
            method = str(resource.get("http_method") or "GET").strip().upper()
            if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:
                method = "GET"
            url = str(resource.get("http_url") or "").strip()
            headers = resource.get("http_headers") if isinstance(resource.get("http_headers"), dict) else {}
            timeout_ms = resource.get("http_timeout_ms")
            try:
                # Clamp to at least 1s; fall back to 10s when unparsable.
                timeout_s = max(1.0, float(timeout_ms) / 1000.0)
            except Exception:
                timeout_s = 10.0

            if not url:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"error": "http_url not configured"},
                    "status": {"code": 422, "message": "invalid_tool_config"},
                }

            # GET/DELETE carry args as query params; others as JSON body.
            request_kwargs: Dict[str, Any] = {}
            if method in {"GET", "DELETE"}:
                request_kwargs["params"] = args
            else:
                request_kwargs["json"] = args

            try:
                timeout = aiohttp.ClientTimeout(total=timeout_s)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.request(method, url, headers=headers, **request_kwargs) as resp:
                        content_type = str(resp.headers.get("Content-Type") or "").lower()
                        if "application/json" in content_type:
                            body: Any = await resp.json()
                        else:
                            body = await resp.text()
                        status_code = int(resp.status)
                        if 200 <= status_code < 300:
                            return {
                                "tool_call_id": call_id,
                                "name": tool_name,
                                "output": {
                                    "method": method,
                                    "url": url,
                                    "status_code": status_code,
                                    "response": _json_safe(body),
                                },
                                "status": {"code": 200, "message": "ok"},
                            }
                        # Non-2xx: surface upstream status code as-is.
                        return {
                            "tool_call_id": call_id,
                            "name": tool_name,
                            "output": {
                                "method": method,
                                "url": url,
                                "status_code": status_code,
                                "response": _json_safe(body),
                            },
                            "status": {"code": status_code, "message": "http_error"},
                        }
            except asyncio.TimeoutError:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"method": method, "url": url, "error": "request timeout"},
                    "status": {"code": 504, "message": "http_timeout"},
                }
            except Exception as exc:
                return {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"method": method, "url": url, "error": str(exc)},
                    "status": {"code": 502, "message": "http_request_failed"},
                }

    # Unknown name (or resource missing / not a "query" tool).
    return {
        "tool_call_id": call_id,
        "name": tool_name or "unknown_tool",
        "output": {"message": "server tool not implemented"},
        "status": {"code": 501, "message": "not_implemented"},
    }
|
||||
247
core/transports.py
Normal file
247
core/transports.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Transport layer for WebSocket and WebRTC communication."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
from loguru import logger
|
||||
|
||||
# Try to import aiortc (optional for WebRTC functionality)
|
||||
try:
|
||||
from aiortc import RTCPeerConnection
|
||||
AIORTC_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIORTC_AVAILABLE = False
|
||||
RTCPeerConnection = None # Type hint placeholder
|
||||
|
||||
|
||||
class BaseTransport(ABC):
    """Common interface that every client transport must satisfy.

    A transport delivers JSON events and raw audio to one connected
    client and knows how to tear itself down.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """Deliver a JSON-serializable event dict to the client.

        Args:
            event: Event data as dictionary
        """

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """Deliver raw audio to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """

    @abstractmethod
    async def close(self) -> None:
        """Close the transport and cleanup resources."""
|
||||
|
||||
|
||||
class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Handles mixed text/binary frames over WebSocket connection.
    Uses asyncio.Lock to prevent frame interleaving.

    Refactor: send_event and send_audio previously duplicated the whole
    lock/RuntimeError/Exception handling path; it now lives once in
    _locked_send, with log messages preserved verbatim.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        self._closed = False

    def _ws_disconnected(self) -> bool:
        """Best-effort check for websocket disconnection state."""
        return (
            self.ws.client_state == WebSocketState.DISCONNECTED
            or self.ws.application_state == WebSocketState.DISCONNECTED
        )

    async def _locked_send(self, do_send, kind: str) -> None:
        """Run one send operation under the lock with shared error handling.

        Args:
            do_send: zero-arg coroutine function performing the actual write.
            kind: "event" or "audio"; interpolated into log messages so the
                output matches the pre-refactor wording exactly.
        """
        async with self.lock:
            try:
                await do_send()
            except RuntimeError as e:
                self._closed = True
                # Starlette raises RuntimeError after the close handshake.
                if self._ws_disconnected() or "close message has been sent" in str(e):
                    logger.debug(f"Skip sending {kind} on closed websocket: {e!r}")
                    return
                logger.error(f"Error sending {kind}: {e!r}")
            except Exception as e:
                self._closed = True
                if self._ws_disconnected():
                    logger.debug(f"Skip sending {kind} on disconnected websocket: {e!r}")
                    return
                logger.error(f"Error sending {kind}: {e!r}")

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        Args:
            event: Event data as dictionary
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send event on closed transport")
            self._closed = True
            return

        async def _do() -> None:
            await self.ws.send_text(json.dumps(event))
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")

        await self._locked_send(_do, "event")

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send audio on closed transport")
            self._closed = True
            return

        async def _do() -> None:
            await self.ws.send_bytes(pcm_bytes)

        await self._locked_send(_do, "audio")

    async def close(self) -> None:
        """Close the WebSocket connection."""
        if self._closed:
            return

        self._closed = True
        if self._ws_disconnected():
            return

        try:
            await self.ws.close()
        except RuntimeError as e:
            # Already closed by another task/path; safe to ignore.
            if "close message has been sent" in str(e):
                logger.debug(f"WebSocket already closed: {e}")
                return
            logger.error(f"Error closing WebSocket: {e}")
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed
|
||||
|
||||
|
||||
class WebRtcTransport(BaseTransport):
    """
    WebRTC transport for WebRTC audio streaming.

    Uses WebSocket for signaling and RTCPeerConnection for media.
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport

        Raises:
            RuntimeError: when aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")

        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        try:
            await self.ws.send_text(json.dumps(event))
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track.

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")

        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """Close the WebRTC connection and the signaling websocket.

        Fix: the peer connection and websocket are now closed
        independently; previously both awaits shared one try block, so
        a failure in pc.close() skipped ws.close() and leaked the
        signaling socket.
        """
        self._closed = True
        try:
            await self.pc.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")
        try:
            await self.ws.close()
        except Exception as e:
            logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")
|
||||
402
core/workflow_runner.py
Normal file
402
core/workflow_runner.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""Workflow runtime helpers for session-level node routing.
|
||||
|
||||
MVP goals:
|
||||
- Parse workflow graph payload from WS session.start metadata
|
||||
- Track current node
|
||||
- Evaluate edge conditions on each assistant turn completion
|
||||
- Provide per-node runtime metadata overrides (prompt/greeting/services)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# Accepted payload node-type spellings mapped to canonical internal names.
# Unknown types fall back to "assistant" (see _normalize_node_type).
_NODE_TYPE_MAP = {
    "conversation": "assistant",
    "assistant": "assistant",
    "human": "human_transfer",
    "human_transfer": "human_transfer",
    "tool": "tool",
    "end": "end",
    "start": "start",
}
|
||||
|
||||
|
||||
def _normalize_node_type(raw_type: Any) -> str:
    """Map a payload node type onto its canonical name ("assistant" default)."""
    normalized_key = str(raw_type or "").strip().lower()
    return _NODE_TYPE_MAP.get(normalized_key, "assistant")
|
||||
|
||||
|
||||
def _safe_str(value: Any) -> str:
    """Stringify *value*, mapping None to the empty string."""
    return "" if value is None else str(value)
|
||||
|
||||
|
||||
def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]:
    """Return an edge condition dict with normalized 'type' and 'source'.

    Non-dict input degrades gracefully: a label implies a contains-match
    on user text, otherwise the edge is unconditional.
    """
    if not isinstance(raw, dict):
        if label:
            return {"type": "contains", "source": "user", "value": str(label)}
        return {"type": "always"}

    normalized = dict(raw)
    kind = str(normalized.get("type", "always")).strip().lower() or "always"
    source = str(normalized.get("source", "user")).strip().lower() or "user"
    normalized["type"] = kind
    normalized["source"] = source
    return normalized
|
||||
|
||||
|
||||
@dataclass
class WorkflowNodeDef:
    """A single workflow node parsed from the session.start payload."""
    id: str
    name: str
    node_type: str  # canonical type from _normalize_node_type
    is_start: bool  # explicit isStart flag or node_type == "start"
    prompt: Optional[str]  # node-level system prompt override, if any
    message_plan: Dict[str, Any]  # always a dict ({} when absent)
    assistant_id: Optional[str]
    assistant: Dict[str, Any]  # per-node assistant config ({} when absent)
    tool: Optional[Dict[str, Any]]  # tool spec for "tool" nodes
    raw: Dict[str, Any]  # original payload entry, kept for debugging
|
||||
|
||||
|
||||
@dataclass
class WorkflowEdgeDef:
    """A directed workflow edge parsed from the session.start payload."""
    id: str
    from_node_id: str
    to_node_id: str
    label: Optional[str]
    condition: Dict[str, Any]  # normalized via _normalize_condition
    priority: int  # lower sorts first in outgoing_edges (default 100)
    order: int  # payload position; tie-breaker after priority
    raw: Dict[str, Any]  # original payload entry, kept for debugging
|
||||
|
||||
|
||||
@dataclass
class WorkflowTransition:
    """An edge selected for traversal together with its target node."""
    edge: WorkflowEdgeDef
    node: WorkflowNodeDef
|
||||
|
||||
|
||||
# Async callback used for condition.type == "llm" edges: given the current
# node, the candidate edges, and a {"userText", "assistantText"} context,
# returns a matching edge id / target node id, or None to decline.
LlmRouter = Callable[
    [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]],
    Awaitable[Optional[str]],
]
|
||||
|
||||
|
||||
class WorkflowRunner:
|
||||
"""In-memory workflow graph for a single active session."""
|
||||
|
||||
def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]):
    """Build a runner over an already-parsed node/edge graph.

    Args:
        workflow_id: Identifier echoed in workflow.* events.
        name: Human-readable workflow name.
        nodes: Parsed node definitions (ids assumed unique).
        edges: Parsed edge definitions referencing the node ids.
    """
    self.workflow_id = workflow_id
    self.name = name
    # Node lookup by id; later duplicates would overwrite earlier ones.
    self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes}
    self._edges = edges
    # Cursor into the graph; None until bootstrap() picks a start node.
    self.current_node_id: Optional[str] = None
|
||||
|
||||
@classmethod
def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]:
    """Parse a workflow graph payload into a WorkflowRunner.

    Tolerant parser: malformed node/edge entries are skipped, edge
    endpoints must reference known node ids, and ids/names fall back
    to generated values. Returns None when the payload is not a dict
    or yields no usable nodes.
    """
    if not isinstance(payload, dict):
        return None

    raw_nodes = payload.get("nodes")
    raw_edges = payload.get("edges")
    if not isinstance(raw_nodes, list) or len(raw_nodes) == 0:
        return None

    nodes: List[WorkflowNodeDef] = []
    for i, raw in enumerate(raw_nodes):
        if not isinstance(raw, dict):
            continue

        # Id falls back to name, then to a positional synthetic id.
        node_id = _safe_str(raw.get("id") or raw.get("name") or f"node_{i + 1}").strip() or f"node_{i + 1}"
        node_name = _safe_str(raw.get("name") or node_id).strip() or node_id
        node_type = _normalize_node_type(raw.get("type"))
        is_start = bool(raw.get("isStart")) or node_type == "start"

        # Distinguish "prompt absent" (None) from "prompt present but empty".
        prompt: Optional[str] = None
        if "prompt" in raw:
            prompt = _safe_str(raw.get("prompt"))

        message_plan = raw.get("messagePlan")
        if not isinstance(message_plan, dict):
            message_plan = {}

        assistant_cfg = raw.get("assistant")
        if not isinstance(assistant_cfg, dict):
            assistant_cfg = {}

        tool_cfg = raw.get("tool")
        if not isinstance(tool_cfg, dict):
            tool_cfg = None

        assistant_id = raw.get("assistantId")
        if assistant_id is not None:
            assistant_id = _safe_str(assistant_id).strip() or None

        nodes.append(
            WorkflowNodeDef(
                id=node_id,
                name=node_name,
                node_type=node_type,
                is_start=is_start,
                prompt=prompt,
                message_plan=message_plan,
                assistant_id=assistant_id,
                assistant=assistant_cfg,
                tool=tool_cfg,
                raw=raw,
            )
        )

    if not nodes:
        return None

    node_ids = {node.id for node in nodes}
    edges: List[WorkflowEdgeDef] = []
    for i, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []):
        if not isinstance(raw, dict):
            continue

        # Accept several endpoint key spellings from different producers.
        from_node_id = _safe_str(
            raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
        ).strip()
        to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip()
        if not from_node_id or not to_node_id:
            continue
        # Drop edges pointing at unknown nodes.
        if from_node_id not in node_ids or to_node_id not in node_ids:
            continue

        label = raw.get("label")
        if label is not None:
            label = _safe_str(label)

        condition = _normalize_condition(raw.get("condition"), label=label)

        priority = 100
        try:
            priority = int(raw.get("priority", 100))
        except (TypeError, ValueError):
            priority = 100

        edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}"

        edges.append(
            WorkflowEdgeDef(
                id=edge_id,
                from_node_id=from_node_id,
                to_node_id=to_node_id,
                label=label,
                condition=condition,
                priority=priority,
                order=i,  # payload position, used as the priority tie-breaker
                raw=raw,
            )
        )

    workflow_id = _safe_str(payload.get("id") or "workflow")
    workflow_name = _safe_str(payload.get("name") or workflow_id)
    return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges)
|
||||
|
||||
def bootstrap(self) -> Optional[WorkflowNodeDef]:
    """Position the cursor on the start node and return it (None if absent)."""
    start = self._resolve_start_node()
    if start is None:
        return None
    self.current_node_id = start.id
    return start
|
||||
|
||||
@property
def current_node(self) -> Optional[WorkflowNodeDef]:
    """The node the cursor points at, or None before bootstrap."""
    node_id = self.current_node_id
    return self._nodes.get(node_id) if node_id else None
|
||||
|
||||
def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]:
    """Edges leaving *node_id*, sorted by (priority, payload order)."""
    return sorted(
        (e for e in self._edges if e.from_node_id == node_id),
        key=lambda e: (e.priority, e.order),
    )
|
||||
|
||||
def next_default_transition(self) -> Optional[WorkflowTransition]:
    """First unconditional outgoing edge of the current node, if any.

    An edge counts as unconditional when its condition type is "",
    "always", or "default" and its target node exists.
    """
    node = self.current_node
    if node is None:
        return None
    unconditional = {"", "always", "default"}
    for edge in self.outgoing_edges(node.id):
        kind = str(edge.condition.get("type", "always")).strip().lower()
        if kind not in unconditional:
            continue
        target = self._nodes.get(edge.to_node_id)
        if target:
            return WorkflowTransition(edge=edge, node=target)
    return None
|
||||
|
||||
async def route(
    self,
    *,
    user_text: str,
    assistant_text: str,
    llm_router: Optional[LlmRouter] = None,
) -> Optional[WorkflowTransition]:
    """Choose the next transition from the current node.

    Resolution order:
      1. Deterministic (non-LLM) edge conditions, evaluated in
         (priority, order) order via ``_matches_condition``.
      2. LLM-routed edges, if any exist and ``llm_router`` is supplied;
         the router may name an edge by its id or by its destination
         node id.
      3. The first unconditional ("always"/"default") edge.

    Args:
        user_text: Latest user utterance, fed to edge conditions/router.
        assistant_text: Latest assistant utterance, likewise.
        llm_router: Optional async callable used to pick among
            "llm"-conditioned edges.

    Returns:
        The chosen transition, or None when there is no current node,
        no outgoing edges, or no rule matched.
    """
    node = self.current_node
    if not node:
        return None

    outgoing = self.outgoing_edges(node.id)
    if not outgoing:
        return None

    # Pass 1: deterministic conditions; collect LLM edges for pass 2.
    llm_edges: List[WorkflowEdgeDef] = []
    for edge in outgoing:
        cond_type = str(edge.condition.get("type", "always")).strip().lower()
        if cond_type == "llm":
            llm_edges.append(edge)
            continue
        if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text):
            target = self._nodes.get(edge.to_node_id)
            if target:
                return WorkflowTransition(edge=edge, node=target)

    # Pass 2: let the LLM router pick among the "llm"-conditioned edges.
    if llm_edges and llm_router:
        selection = await llm_router(
            node,
            llm_edges,
            {
                "userText": user_text,
                "assistantText": assistant_text,
            },
        )
        if selection:
            for edge in llm_edges:
                # The router may answer with an edge id or a target node id.
                if selection in {edge.id, edge.to_node_id}:
                    target = self._nodes.get(edge.to_node_id)
                    if target:
                        return WorkflowTransition(edge=edge, node=target)

    # Pass 3: fall back to the first unconditional edge. Delegating to the
    # shared helper keeps this consistent with next_default_transition and
    # _resolve_start_node instead of duplicating the scan inline.
    return self._first_default_transition_from(node.id)
def apply_transition(self, transition: WorkflowTransition) -> None:
    """Advance the workflow pointer to the destination node of ``transition``."""
    destination = transition.node
    self.current_node_id = destination.id
def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]:
    """Derive per-node runtime overrides from a node's configuration.

    Produces a metadata dict that may contain ``systemPrompt``,
    ``greeting``, ``services`` and ``assistantId`` — each only when the
    node (or its embedded assistant config) provides a value.
    """
    cfg = node.assistant if isinstance(node.assistant, dict) else {}
    plan = node.message_plan if isinstance(node.message_plan, dict) else {}
    meta: Dict[str, Any] = {}

    # System prompt: the node-level prompt wins over assistant-config variants.
    if node.prompt is not None:
        meta["systemPrompt"] = node.prompt
    else:
        for key in ("systemPrompt", "prompt"):
            if key in cfg:
                meta["systemPrompt"] = _safe_str(cfg.get(key))
                break

    # Greeting: an explicit firstMessage wins over assistant-config variants.
    opener = plan.get("firstMessage")
    if opener is not None:
        meta["greeting"] = _safe_str(opener)
    else:
        for key in ("greeting", "opener"):
            if key in cfg:
                meta["greeting"] = _safe_str(cfg.get(key))
                break

    services = cfg.get("services")
    if isinstance(services, dict):
        meta["services"] = services

    if node.assistant_id:
        meta["assistantId"] = node.assistant_id

    return meta
def _resolve_start_node(self) -> Optional[WorkflowNodeDef]:
|
||||
explicit_start = next((node for node in self._nodes.values() if node.is_start), None)
|
||||
if not explicit_start:
|
||||
explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None)
|
||||
|
||||
if explicit_start:
|
||||
# If a dedicated start node exists, try to move to its first default target.
|
||||
if explicit_start.node_type == "start":
|
||||
visited = {explicit_start.id}
|
||||
current = explicit_start
|
||||
for _ in range(8):
|
||||
transition = self._first_default_transition_from(current.id)
|
||||
if not transition:
|
||||
return current
|
||||
current = transition.node
|
||||
if current.id in visited:
|
||||
break
|
||||
visited.add(current.id)
|
||||
return current
|
||||
return explicit_start
|
||||
|
||||
assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None)
|
||||
if assistant_node:
|
||||
return assistant_node
|
||||
return next(iter(self._nodes.values()), None)
|
||||
|
||||
def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]:
|
||||
for edge in self.outgoing_edges(node_id):
|
||||
cond_type = str(edge.condition.get("type", "always")).strip().lower()
|
||||
if cond_type in {"", "always", "default"}:
|
||||
node = self._nodes.get(edge.to_node_id)
|
||||
if node:
|
||||
return WorkflowTransition(edge=edge, node=node)
|
||||
return None
|
||||
|
||||
def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool:
|
||||
condition = edge.condition or {"type": "always"}
|
||||
cond_type = str(condition.get("type", "always")).strip().lower()
|
||||
source = str(condition.get("source", "user")).strip().lower()
|
||||
|
||||
if cond_type in {"", "always", "default"}:
|
||||
return True
|
||||
|
||||
text = assistant_text if source == "assistant" else user_text
|
||||
text_lower = (text or "").lower()
|
||||
|
||||
if cond_type == "contains":
|
||||
values: List[str] = []
|
||||
if isinstance(condition.get("values"), list):
|
||||
values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()]
|
||||
if not values:
|
||||
single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower()
|
||||
if single:
|
||||
values = [single]
|
||||
if not values:
|
||||
return False
|
||||
return any(value in text_lower for value in values)
|
||||
|
||||
if cond_type == "equals":
|
||||
expected = _safe_str(condition.get("value") or "").strip().lower()
|
||||
return bool(expected) and text_lower == expected
|
||||
|
||||
if cond_type == "regex":
|
||||
pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip()
|
||||
if not pattern:
|
||||
return False
|
||||
try:
|
||||
return bool(re.search(pattern, text or "", re.IGNORECASE))
|
||||
except re.error:
|
||||
logger.warning(f"Invalid workflow regex condition: {pattern}")
|
||||
return False
|
||||
|
||||
if cond_type == "json":
|
||||
value = _safe_str(condition.get("value") or "").strip()
|
||||
if not value:
|
||||
return False
|
||||
try:
|
||||
obj = json.loads(text or "")
|
||||
except Exception:
|
||||
return False
|
||||
return str(obj) == value
|
||||
|
||||
return False
|
||||
Reference in New Issue
Block a user