Init commit

This commit is contained in:
Xin Wang
2026-02-17 10:39:23 +08:00
commit 30eb4397c2
56 changed files with 11983 additions and 0 deletions

20
core/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
"""Core Components Package"""
from core.events import EventBus, get_event_bus
from core.transports import BaseTransport, SocketTransport, WebRtcTransport
from core.session import Session
from core.conversation import ConversationManager, ConversationState, ConversationTurn
from core.duplex_pipeline import DuplexPipeline
# Public API of the ``core`` package: exactly the names re-exported by the
# imports at the top of this module.
__all__ = [
    "EventBus",
    "get_event_bus",
    "BaseTransport",
    "SocketTransport",
    "WebRtcTransport",
    "Session",
    "ConversationManager",
    "ConversationState",
    "ConversationTurn",
    "DuplexPipeline",
]

279
core/conversation.py Normal file
View File

@@ -0,0 +1,279 @@
"""Conversation management for voice AI.
Handles conversation context, turn-taking, and message history
for multi-turn voice conversations.
"""
import asyncio
from typing import List, Optional, Dict, Any, Callable, Awaitable
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger
from services.base import LLMMessage
class ConversationState(Enum):
    """Phases a voice conversation can be in."""

    # Waiting for user input.
    IDLE = "idle"
    # The user is speaking.
    LISTENING = "listening"
    # User input is being processed (LLM).
    PROCESSING = "processing"
    # The bot is speaking.
    SPEAKING = "speaking"
    # The bot was interrupted.
    INTERRUPTED = "interrupted"
def _monotonic_timestamp() -> float:
    """Return the current monotonic clock reading in seconds.

    This is the clock asyncio's default event loop reports from ``loop.time()``,
    but reading it directly does not require an event loop to exist.
    """
    # Function-scope import: this module's import block does not include ``time``.
    import time

    return time.monotonic()


@dataclass
class ConversationTurn:
    """A single turn in the conversation."""

    # "user" or "assistant"
    role: str
    # Text of the turn.
    text: str
    # Optional audio duration in milliseconds.
    audio_duration_ms: Optional[int] = None
    # Monotonic creation time in seconds. Previously obtained through
    # ``asyncio.get_event_loop().time()``, which is deprecated (and can raise)
    # when no event loop is available, e.g. when a turn is built from sync code.
    timestamp: float = field(default_factory=_monotonic_timestamp)
    # True when the turn was cut short by an interruption.
    was_interrupted: bool = False
class ConversationManager:
    """
    Manages conversation state and history.

    Provides:
    - Message history for LLM context
    - Turn management
    - State tracking
    - Event callbacks for state changes
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_history: int = 20,
        greeting: Optional[str] = None
    ):
        """
        Initialize conversation manager.

        Args:
            system_prompt: System prompt for LLM (a built-in default is used when empty)
            max_history: Maximum number of turns included in the LLM context
            greeting: Optional greeting message when conversation starts
        """
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation. "
            "If you don't understand something, ask for clarification."
        )
        self.max_history = max_history
        self.greeting = greeting

        # State
        self.state = ConversationState.IDLE
        self.turns: List[ConversationTurn] = []

        # Callbacks
        self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
        self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []

        # Current turn tracking
        self._current_user_text: str = ""
        self._current_assistant_text: str = ""

        logger.info("ConversationManager initialized")

    def on_state_change(
        self,
        callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
    ) -> None:
        """Register an async callback invoked with (old_state, new_state) on state changes."""
        self._state_callbacks.append(callback)

    def on_turn_complete(
        self,
        callback: Callable[[ConversationTurn], Awaitable[None]]
    ) -> None:
        """Register an async callback invoked with each turn appended to the history."""
        self._turn_callbacks.append(callback)

    async def set_state(self, new_state: ConversationState) -> None:
        """Set conversation state and notify listeners (no-op when the state is unchanged)."""
        if new_state == self.state:
            return
        old_state = self.state
        self.state = new_state
        logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")
        for callback in self._state_callbacks:
            try:
                await callback(old_state, new_state)
            except Exception as e:
                # A failing listener must not prevent the others from running.
                logger.error(f"State callback error: {e}")

    async def _notify_turn_callbacks(self, turn: ConversationTurn) -> None:
        """Invoke every turn-complete callback, logging (not raising) callback errors."""
        for callback in self._turn_callbacks:
            try:
                await callback(turn)
            except Exception as e:
                logger.error(f"Turn callback error: {e}")

    def get_messages(self) -> List[LLMMessage]:
        """
        Get conversation history as LLM messages.

        Returns:
            List of LLMMessage objects: the system prompt, up to ``max_history``
            most recent turns, and the in-progress user text if any.
        """
        messages = [LLMMessage(role="system", content=self.system_prompt)]

        # ``turns[-0:]`` would return the *entire* list, so a non-positive limit
        # has to be handled explicitly to mean "no history".
        recent_turns = self.turns[-self.max_history:] if self.max_history > 0 else []
        messages.extend(LLMMessage(role=turn.role, content=turn.text) for turn in recent_turns)

        # Add current user text if any
        if self._current_user_text:
            messages.append(LLMMessage(role="user", content=self._current_user_text))
        return messages

    async def start_user_turn(self) -> None:
        """Signal that user has started speaking."""
        await self.set_state(ConversationState.LISTENING)
        self._current_user_text = ""

    async def update_user_text(self, text: str, is_final: bool = False) -> None:
        """
        Update current user text (from ASR).

        Args:
            text: Transcribed text (replaces the previous in-progress text)
            is_final: Whether this is the final transcript (currently not used here)
        """
        self._current_user_text = text

    async def end_user_turn(self, text: str) -> None:
        """
        End user turn and add to history.

        Blank text is not recorded, but the state still moves to PROCESSING.

        Args:
            text: Final user text
        """
        if text.strip():
            turn = ConversationTurn(role="user", text=text.strip())
            self.turns.append(turn)
            await self._notify_turn_callbacks(turn)
            logger.info(f"User: {text[:50]}...")
        self._current_user_text = ""
        await self.set_state(ConversationState.PROCESSING)

    async def start_assistant_turn(self) -> None:
        """Signal that assistant has started speaking."""
        await self.set_state(ConversationState.SPEAKING)
        self._current_assistant_text = ""

    async def update_assistant_text(self, text: str) -> None:
        """
        Update current assistant text (streaming).

        Args:
            text: Text chunk from LLM (appended to the in-progress text)
        """
        self._current_assistant_text += text

    async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
        """
        End assistant turn and add to history.

        Args:
            was_interrupted: Whether the turn was interrupted by user
        """
        text = self._current_assistant_text.strip()
        if text:
            turn = ConversationTurn(
                role="assistant",
                text=text,
                was_interrupted=was_interrupted
            )
            self.turns.append(turn)
            await self._notify_turn_callbacks(turn)
            status = " (interrupted)" if was_interrupted else ""
            logger.info(f"Assistant{status}: {text[:50]}...")
        self._current_assistant_text = ""

        if was_interrupted:
            # A new user turn may already be active (LISTENING) when interrupted.
            # Avoid overriding it back to INTERRUPTED, which can stall EOU flow.
            if self.state != ConversationState.LISTENING:
                await self.set_state(ConversationState.INTERRUPTED)
        else:
            await self.set_state(ConversationState.IDLE)

    async def add_assistant_turn(self, text: str, was_interrupted: bool = False) -> None:
        """Append an assistant turn directly without mutating conversation state."""
        content = text.strip()
        if not content:
            return
        turn = ConversationTurn(
            role="assistant",
            text=content,
            was_interrupted=was_interrupted,
        )
        self.turns.append(turn)
        await self._notify_turn_callbacks(turn)
        logger.info(f"Assistant (injected): {content[:50]}...")

    async def interrupt(self) -> None:
        """Handle interruption (barge-in); only acts while the assistant is SPEAKING."""
        if self.state == ConversationState.SPEAKING:
            await self.end_assistant_turn(was_interrupted=True)

    def reset(self) -> None:
        """Reset conversation history and state (state callbacks are not invoked)."""
        self.turns = []
        self._current_user_text = ""
        self._current_assistant_text = ""
        self.state = ConversationState.IDLE
        logger.info("Conversation reset")

    @property
    def turn_count(self) -> int:
        """Get number of turns in conversation."""
        return len(self.turns)

    @property
    def last_user_text(self) -> Optional[str]:
        """Get the text of the most recent user turn, or None if there is none."""
        for turn in reversed(self.turns):
            if turn.role == "user":
                return turn.text
        return None

    @property
    def last_assistant_text(self) -> Optional[str]:
        """Get the text of the most recent assistant turn, or None if there is none."""
        for turn in reversed(self.turns):
            if turn.role == "assistant":
                return turn.text
        return None

    def get_context_summary(self) -> Dict[str, Any]:
        """Get a summary of conversation context."""
        return {
            "state": self.state.value,
            "turn_count": self.turn_count,
            "last_user": self.last_user_text,
            "last_assistant": self.last_assistant_text,
            "current_user": self._current_user_text or None,
            "current_assistant": self._current_assistant_text or None
        }

1507
core/duplex_pipeline.py Normal file

File diff suppressed because it is too large Load Diff

134
core/events.py Normal file
View File

@@ -0,0 +1,134 @@
"""Event bus for pub/sub communication between components."""
import asyncio
from typing import Callable, Dict, List, Any, Optional
from collections import defaultdict
from loguru import logger
class EventBus:
    """
    Async event bus for pub/sub communication.

    Similar to the original Rust implementation's broadcast channel.
    Components can subscribe to specific event types and receive events asynchronously.
    """

    def __init__(self):
        """Initialize the event bus."""
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        self._lock = asyncio.Lock()
        self._running = True

    def subscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Subscribe to an event type.

        Args:
            event_type: Type of event to subscribe to (e.g., "speaking", "silence")
            callback: Sync or async callback function that receives event data
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring subscription to {event_type}")
            return
        self._subscribers[event_type].append(callback)
        logger.debug(f"Subscribed to event type: {event_type}")

    def unsubscribe(self, event_type: str, callback: Callable[[Dict[str, Any]], None]) -> None:
        """
        Unsubscribe from an event type.

        Unknown event types and callbacks that were never registered are ignored.

        Args:
            event_type: Type of event to unsubscribe from
            callback: Callback function to remove
        """
        # Use .get() rather than indexing: indexing the defaultdict would create
        # an empty entry for every event type that was never subscribed to.
        callbacks = self._subscribers.get(event_type)
        if callbacks and callback in callbacks:
            callbacks.remove(callback)
            logger.debug(f"Unsubscribed from event type: {event_type}")

    async def publish(self, event_type: str, event_data: Dict[str, Any]) -> None:
        """
        Publish an event to all subscribers and wait for them to finish.

        Args:
            event_type: Type of event to publish
            event_data: Event data to send to subscribers
        """
        if not self._running:
            logger.warning(f"Event bus is shut down, ignoring event: {event_type}")
            return

        # Get subscribers for this event type
        subscribers = self._subscribers.get(event_type, [])
        if not subscribers:
            logger.debug(f"No subscribers for event type: {event_type}")
            return

        # Notify all subscribers concurrently
        tasks = []
        for callback in subscribers:
            try:
                # Create task for each subscriber
                task = asyncio.create_task(self._call_subscriber(callback, event_data))
                tasks.append(task)
            except Exception as e:
                logger.error(f"Error creating task for subscriber: {e}")

        # Wait for all subscribers to complete
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
            logger.debug(f"Published event '{event_type}' to {len(tasks)} subscribers")

    async def _call_subscriber(self, callback: Callable[[Dict[str, Any]], None], event_data: Dict[str, Any]) -> None:
        """
        Call a subscriber callback with error handling.

        Errors raised by the callback are logged with a traceback and swallowed.

        Args:
            callback: Subscriber callback function
            event_data: Event data to pass to callback
        """
        try:
            result = callback(event_data)
            # Await whatever the callback hands back if it is a coroutine or
            # future. Checking the *result* (instead of the function) also
            # covers plain callables that merely return a coroutine.
            if asyncio.iscoroutine(result) or asyncio.isfuture(result):
                await result
        except Exception as e:
            # loguru has no ``exc_info`` keyword: extra kwargs make it run
            # str.format() on the message, which itself raises if the error
            # text contains braces. Pass the error as a format argument and
            # request the traceback through opt(exception=True).
            logger.opt(exception=True).error("Error in subscriber callback: {}", e)

    async def close(self) -> None:
        """Close the event bus: drop all subscribers and reject further activity."""
        self._running = False
        self._subscribers.clear()
        logger.info("Event bus closed")

    @property
    def is_running(self) -> bool:
        """Check if the event bus is running."""
        return self._running
# Process-wide EventBus singleton, created on first use.
_event_bus: Optional[EventBus] = None


def get_event_bus() -> EventBus:
    """
    Return the shared event bus, creating it on the first call.

    Returns:
        The global EventBus instance
    """
    global _event_bus
    bus = _event_bus
    if bus is None:
        bus = EventBus()
        _event_bus = bus
    return bus
def reset_event_bus() -> None:
    """Forget the global event bus instance (mainly useful in tests)."""
    global _event_bus
    _event_bus = None

648
core/session.py Normal file
View File

@@ -0,0 +1,648 @@
"""Session management for active calls."""
import asyncio
import uuid
import json
import time
import re
from enum import Enum
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
create_history_call_record,
add_history_transcript,
finalize_history_call_record,
)
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from services.base import LLMMessage
from models.ws_v1 import (
parse_client_message,
ev,
HelloMessage,
SessionStartMessage,
SessionStopMessage,
InputTextMessage,
ResponseCancelMessage,
ToolCallResultsMessage,
)
class WsSessionState(str, Enum):
    """States of the WS session protocol state machine."""

    # Connection open, no hello received yet.
    WAIT_HELLO = "wait_hello"
    # Hello accepted, waiting for session.start.
    WAIT_START = "wait_start"
    # Session started.
    ACTIVE = "active"
    # Session stopped.
    STOPPED = "stopped"
class Session:
    """
    Manages a single call session.

    Handles command routing, audio processing, and session lifecycle.
    Uses full duplex voice conversation pipeline.
    """

    def __init__(self, session_id: str, transport: BaseTransport, use_duplex: Optional[bool] = None):
        """
        Initialize session.

        Args:
            session_id: Unique session identifier
            transport: Transport instance for communication
            use_duplex: Whether to use duplex pipeline (defaults to settings.duplex_enabled)
        """
        self.id = session_id
        self.transport = transport
        self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled
        self.pipeline = DuplexPipeline(
            transport=transport,
            session_id=session_id,
            system_prompt=settings.duplex_system_prompt,
            greeting=settings.duplex_greeting
        )

        # Session state
        self.created_at = None
        self.state = "created"  # Legacy call state for /call/lists
        self.ws_state = WsSessionState.WAIT_HELLO
        self._pipeline_started = False
        self.protocol_version: Optional[str] = None
        self.authenticated: bool = False

        # Track IDs
        self.current_track_id: Optional[str] = str(uuid.uuid4())

        # History bridge bookkeeping
        self._history_call_id: Optional[str] = None
        self._history_turn_index: int = 0
        self._history_call_started_mono: Optional[float] = None
        self._history_finalized: bool = False

        # Cleanup guard so cleanup() runs its body at most once
        self._cleanup_lock = asyncio.Lock()
        self._cleaned_up = False

        # Workflow state
        self.workflow_runner: Optional[WorkflowRunner] = None
        self._workflow_last_user_text: str = ""
        self._workflow_initial_node: Optional[WorkflowNodeDef] = None

        self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
        logger.info(f"Session {self.id} created (duplex={self.use_duplex})")

    async def handle_text(self, text_data: str) -> None:
        """
        Handle incoming text data (WS v1 JSON control messages).

        Parse/validation failures are reported to the client as error events
        instead of being raised.

        Args:
            text_data: JSON text data
        """
        try:
            data = json.loads(text_data)
            message = parse_client_message(data)
            await self._handle_v1_message(message)
        except json.JSONDecodeError as e:
            logger.error(f"Session {self.id} JSON decode error: {e}")
            await self._send_error("client", f"Invalid JSON: {e}", "protocol.invalid_json")
        except ValueError as e:
            logger.error(f"Session {self.id} command parse error: {e}")
            await self._send_error("client", f"Invalid message: {e}", "protocol.invalid_message")
        except Exception as e:
            # loguru has no ``exc_info`` keyword: extra kwargs make it run
            # str.format() on the message, which raises when the error text
            # contains braces. Use format arguments plus opt(exception=True).
            logger.opt(exception=True).error("Session {} handle_text error: {}", self.id, e)
            await self._send_error("server", f"Internal error: {e}", "server.internal")

    async def handle_audio(self, audio_bytes: bytes) -> None:
        """
        Handle incoming audio data.

        Audio received while the session is not ACTIVE is rejected with a
        protocol error event.

        Args:
            audio_bytes: PCM audio data
        """
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                "Audio received before session.start",
                "protocol.order",
            )
            return
        try:
            await self.pipeline.process_audio(audio_bytes)
        except Exception as e:
            # See handle_text for why ``exc_info=True`` is not used with loguru.
            logger.opt(exception=True).error("Session {} handle_audio error: {}", self.id, e)

    async def _handle_v1_message(self, message: Any) -> None:
        """Route validated WS v1 message to handlers."""
        msg_type = message.type
        logger.info(f"Session {self.id} received message: {msg_type}")

        if isinstance(message, HelloMessage):
            await self._handle_hello(message)
            return

        # All messages below require hello handshake first
        if self.ws_state == WsSessionState.WAIT_HELLO:
            await self._send_error(
                "client",
                "Expected hello message first",
                "protocol.order",
            )
            return

        if isinstance(message, SessionStartMessage):
            await self._handle_session_start(message)
            return

        # All messages below require active session
        if self.ws_state != WsSessionState.ACTIVE:
            await self._send_error(
                "client",
                f"Message '{msg_type}' requires active session",
                "protocol.order",
            )
            return

        if isinstance(message, InputTextMessage):
            await self.pipeline.process_text(message.text)
        elif isinstance(message, ResponseCancelMessage):
            if message.graceful:
                # Graceful cancel is only logged; the pipeline is not interrupted.
                logger.info(f"Session {self.id} graceful response.cancel")
            else:
                await self.pipeline.interrupt()
        elif isinstance(message, ToolCallResultsMessage):
            await self.pipeline.handle_tool_call_results(message.results)
        elif isinstance(message, SessionStopMessage):
            await self._handle_session_stop(message.reason)
        else:
            await self._send_error("client", f"Unsupported message type: {msg_type}", "protocol.unsupported")

    @staticmethod
    def _api_key_matches(provided: Any, expected: Any) -> bool:
        """Compare a client-supplied API key with the configured one in constant time."""
        # Function-scope import: this module's import block does not include hmac.
        import hmac

        if not isinstance(provided, str) or not isinstance(expected, str):
            # compare_digest needs matching str/bytes operands; fall back to
            # plain equality for anything else (e.g. a missing key -> None).
            return provided == expected
        # Encode first: compare_digest rejects non-ASCII ``str`` arguments.
        return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))

    async def _handle_hello(self, message: HelloMessage) -> None:
        """Handle initial hello/auth/version negotiation."""
        if self.ws_state != WsSessionState.WAIT_HELLO:
            await self._send_error("client", "Duplicate hello", "protocol.order")
            return

        if message.version != settings.ws_protocol_version:
            await self._send_error(
                "client",
                f"Unsupported protocol version '{message.version}'",
                "protocol.version_unsupported",
            )
            await self.transport.close()
            self.ws_state = WsSessionState.STOPPED
            return

        auth_payload = message.auth or {}
        api_key = auth_payload.get("apiKey")
        jwt = auth_payload.get("jwt")

        if settings.ws_api_key:
            # Constant-time comparison so response timing does not leak how
            # much of a guessed key is correct.
            if not self._api_key_matches(api_key, settings.ws_api_key):
                await self._send_error("auth", "Invalid API key", "auth.invalid_api_key")
                await self.transport.close()
                self.ws_state = WsSessionState.STOPPED
                return
        elif settings.ws_require_auth and not (api_key or jwt):
            # NOTE(review): only the *presence* of an apiKey/jwt is checked on
            # this path; the JWT is not verified here. Confirm that validation
            # happens elsewhere before relying on ``authenticated``.
            await self._send_error("auth", "Authentication required", "auth.required")
            await self.transport.close()
            self.ws_state = WsSessionState.STOPPED
            return

        self.authenticated = True
        self.protocol_version = message.version
        self.ws_state = WsSessionState.WAIT_START
        await self.transport.send_event(
            ev(
                "hello.ack",
                sessionId=self.id,
                version=self.protocol_version,
            )
        )

    async def _handle_session_start(self, message: SessionStartMessage) -> None:
        """Handle explicit session start after successful hello."""
        if self.ws_state != WsSessionState.WAIT_START:
            await self._send_error("client", "Duplicate session.start", "protocol.order")
            return

        metadata = message.metadata or {}
        metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))

        # Create history call record early so later turn callbacks can append transcripts.
        await self._start_history_bridge(metadata)

        # Apply runtime service/prompt overrides from backend if provided
        self.pipeline.apply_runtime_overrides(metadata)

        # Start duplex pipeline
        if not self._pipeline_started:
            await self.pipeline.start()
            self._pipeline_started = True
            logger.info(f"Session {self.id} duplex pipeline started")

        self.state = "accepted"
        self.ws_state = WsSessionState.ACTIVE
        await self.transport.send_event(
            ev(
                "session.started",
                sessionId=self.id,
                trackId=self.current_track_id,
                audio=message.audio or {},
            )
        )

        if self.workflow_runner and self._workflow_initial_node:
            await self.transport.send_event(
                ev(
                    "workflow.started",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    workflowName=self.workflow_runner.name,
                    nodeId=self._workflow_initial_node.id,
                )
            )
            await self.transport.send_event(
                ev(
                    "workflow.node.entered",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=self._workflow_initial_node.id,
                    nodeName=self._workflow_initial_node.name,
                    nodeType=self._workflow_initial_node.node_type,
                )
            )

    async def _handle_session_stop(self, reason: Optional[str]) -> None:
        """Handle session stop (no-op when the session is already STOPPED)."""
        if self.ws_state == WsSessionState.STOPPED:
            return

        stop_reason = reason or "client_requested"
        self.state = "hungup"
        self.ws_state = WsSessionState.STOPPED
        await self.transport.send_event(
            ev(
                "session.stopped",
                sessionId=self.id,
                reason=stop_reason,
            )
        )
        await self._finalize_history(status="connected")
        await self.transport.close()

    async def _send_error(self, sender: str, error_message: str, code: str) -> None:
        """
        Send error event to client.

        Args:
            sender: Component that generated the error
            error_message: Error message
            code: Machine-readable error code
        """
        await self.transport.send_event(
            ev(
                "error",
                sender=sender,
                code=code,
                message=error_message,
                trackId=self.current_track_id,
            )
        )

    def _get_timestamp_ms(self) -> int:
        """Get current wall-clock timestamp in milliseconds."""
        return int(time.time() * 1000)

    async def cleanup(self) -> None:
        """Cleanup session resources (the body runs at most once per session)."""
        async with self._cleanup_lock:
            if self._cleaned_up:
                return
            self._cleaned_up = True

        logger.info(f"Session {self.id} cleaning up")
        # try/finally chain: a failure while finalizing history (or cleaning the
        # pipeline) must not leave the pipeline/transport unreleased. The
        # original exception still propagates to the caller.
        try:
            await self._finalize_history(status="connected")
        finally:
            try:
                await self.pipeline.cleanup()
            finally:
                await self.transport.close()

    async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None:
        """Initialize backend history call record for this session."""
        if self._history_call_id:
            return

        history_meta: Dict[str, Any] = {}
        if isinstance(metadata.get("history"), dict):
            history_meta = metadata["history"]

        # Values under metadata["history"] win over top-level metadata keys.
        raw_user_id = history_meta.get("userId", metadata.get("userId", settings.history_default_user_id))
        try:
            user_id = int(raw_user_id)
        except (TypeError, ValueError):
            user_id = settings.history_default_user_id

        assistant_id = history_meta.get("assistantId", metadata.get("assistantId"))
        source = str(history_meta.get("source", metadata.get("source", "debug")))

        call_id = await create_history_call_record(
            user_id=user_id,
            assistant_id=str(assistant_id) if assistant_id else None,
            source=source,
        )
        if not call_id:
            return

        self._history_call_id = call_id
        self._history_call_started_mono = time.monotonic()
        self._history_turn_index = 0
        self._history_finalized = False
        logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")

    async def _on_turn_complete(self, turn: ConversationTurn) -> None:
        """Process workflow transitions and persist completed turns to history."""
        text = turn.text.strip() if turn.text else ""
        role = (turn.role or "").lower()

        if text:
            if role == "user":
                self._workflow_last_user_text = text
            elif role == "assistant":
                await self._maybe_advance_workflow(text)

        if not self._history_call_id:
            return
        if not text:
            return

        speaker = "human" if role == "user" else "ai"

        end_ms = 0
        if self._history_call_started_mono is not None:
            end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000))
        # No real audio timing is available here: estimate 80 ms per character,
        # clamped to the 300..12000 ms range.
        estimated_duration_ms = max(300, min(12000, len(text) * 80))
        start_ms = max(0, end_ms - estimated_duration_ms)

        turn_index = self._history_turn_index
        await add_history_transcript(
            call_id=self._history_call_id,
            turn_index=turn_index,
            speaker=speaker,
            content=text,
            start_ms=start_ms,
            end_ms=end_ms,
            duration_ms=max(1, end_ms - start_ms),
        )
        self._history_turn_index += 1

    async def _finalize_history(self, status: str) -> None:
        """Finalize history call record once (retried on later calls if the backend call fails)."""
        if not self._history_call_id or self._history_finalized:
            return

        duration_seconds = 0
        if self._history_call_started_mono is not None:
            duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono))

        ok = await finalize_history_call_record(
            call_id=self._history_call_id,
            status=status,
            duration_seconds=duration_seconds,
        )
        if ok:
            self._history_finalized = True

    def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Parse workflow payload and return initial runtime overrides."""
        payload = metadata.get("workflow")
        self.workflow_runner = WorkflowRunner.from_payload(payload)
        self._workflow_initial_node = None
        if not self.workflow_runner:
            return {}

        node = self.workflow_runner.bootstrap()
        if not node:
            logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
            self.workflow_runner = None
            return {}

        self._workflow_initial_node = node
        logger.info(
            "Session {} workflow enabled: workflow={} start_node={}",
            self.id,
            self.workflow_runner.workflow_id,
            node.id,
        )
        return self.workflow_runner.build_runtime_metadata(node)

    async def _maybe_advance_workflow(self, assistant_text: str) -> None:
        """Attempt node transfer after assistant turn finalization."""
        if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
            return

        transition = await self.workflow_runner.route(
            user_text=self._workflow_last_user_text,
            assistant_text=assistant_text,
            llm_router=self._workflow_llm_route,
        )
        if not transition:
            return

        await self._apply_workflow_transition(transition, reason="rule_match")

        # Auto-advance through utility nodes when default edges are present.
        max_auto_hops = 6
        auto_hops = 0
        while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
            current = self.workflow_runner.current_node
            if not current or current.node_type not in {"start", "tool"}:
                break
            next_default = self.workflow_runner.next_default_transition()
            if not next_default:
                break
            auto_hops += 1
            await self._apply_workflow_transition(next_default, reason="auto")
            if auto_hops >= max_auto_hops:
                logger.warning(
                    "Session {} workflow auto-advance reached hop limit (possible cycle)",
                    self.id,
                )
                break

    async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
        """Apply graph transition and emit workflow lifecycle events."""
        if not self.workflow_runner:
            return

        self.workflow_runner.apply_transition(transition)
        node = transition.node
        edge = transition.edge

        await self.transport.send_event(
            ev(
                "workflow.edge.taken",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                edgeId=edge.id,
                fromNodeId=edge.from_node_id,
                toNodeId=edge.to_node_id,
                reason=reason,
            )
        )
        await self.transport.send_event(
            ev(
                "workflow.node.entered",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                nodeName=node.name,
                nodeType=node.node_type,
            )
        )

        node_runtime = self.workflow_runner.build_runtime_metadata(node)
        if node_runtime:
            self.pipeline.apply_runtime_overrides(node_runtime)

        if node.node_type == "tool":
            await self.transport.send_event(
                ev(
                    "workflow.tool.requested",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                    tool=node.tool or {},
                )
            )
            return

        if node.node_type == "human_transfer":
            await self.transport.send_event(
                ev(
                    "workflow.human_transfer",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_human_transfer")
            return

        if node.node_type == "end":
            await self.transport.send_event(
                ev(
                    "workflow.ended",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_end")

    async def _workflow_llm_route(
        self,
        node: WorkflowNodeDef,
        candidates: List[WorkflowEdgeDef],
        context: Dict[str, str],
    ) -> Optional[str]:
        """
        LLM-based edge routing for condition.type == 'llm' edges.

        Returns:
            A candidate edge id or target node id picked from the LLM reply, or
            None when no LLM service is available, the call fails, or the reply
            names none of the candidates.
        """
        llm_service = self.pipeline.llm_service
        if not llm_service:
            return None

        candidate_rows = [
            {
                "edgeId": edge.id,
                "toNodeId": edge.to_node_id,
                "label": edge.label,
                "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
            }
            for edge in candidates
        ]
        system_prompt = (
            "You are a workflow router. Pick exactly one edge. "
            "Return JSON only: {\"edgeId\":\"...\"}."
        )
        user_prompt = json.dumps(
            {
                "nodeId": node.id,
                "nodeName": node.name,
                "userText": context.get("userText", ""),
                "assistantText": context.get("assistantText", ""),
                "candidates": candidate_rows,
            },
            ensure_ascii=False,
        )

        try:
            reply = await llm_service.generate(
                [
                    LLMMessage(role="system", content=system_prompt),
                    LLMMessage(role="user", content=user_prompt),
                ],
                temperature=0.0,
                max_tokens=64,
            )
        except Exception as exc:
            logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
            return None

        if not reply:
            return None

        edge_ids = {edge.id for edge in candidates}
        node_ids = {edge.to_node_id for edge in candidates}

        # Preferred path: the reply is (or contains) a JSON object naming the choice.
        parsed = self._extract_json_obj(reply)
        if isinstance(parsed, dict):
            edge_id = parsed.get("edgeId") or parsed.get("id")
            node_id = parsed.get("toNodeId") or parsed.get("nodeId")
            if isinstance(edge_id, str) and edge_id in edge_ids:
                return edge_id
            if isinstance(node_id, str) and node_id in node_ids:
                return node_id

        # Fallback: case-insensitive substring search, longest ids first so a
        # short id that is a prefix of a longer one does not win by accident.
        token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
        lowered_reply = reply.lower()
        for token in token_candidates:
            if token.lower() in lowered_reply:
                return token
        return None

    def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """
        Merge node-level metadata overrides into session.start metadata.

        Returns a new dict; "services" dicts are merged key-by-key, every other
        key in ``overrides`` replaces the base value.
        """
        merged = dict(base or {})
        if not overrides:
            return merged
        for key, value in overrides.items():
            if key == "services" and isinstance(value, dict):
                existing = merged.get("services")
                merged_services = dict(existing) if isinstance(existing, dict) else {}
                merged_services.update(value)
                merged["services"] = merged_services
            else:
                merged[key] = value
        return merged

    def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
        """Best-effort extraction of a JSON object from freeform text."""
        # First try: the whole text is a JSON object.
        try:
            parsed = json.loads(text)
            if isinstance(parsed, dict):
                return parsed
        except Exception:
            pass

        # Second try: the span from the first "{" to the last "}".
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if not match:
            return None
        try:
            parsed = json.loads(match.group(0))
            return parsed if isinstance(parsed, dict) else None
        except Exception:
            return None

340
core/tool_executor.py Normal file
View File

@@ -0,0 +1,340 @@
"""Server-side tool execution helpers."""
import asyncio
import ast
import operator
from datetime import datetime
from typing import Any, Dict
import aiohttp
from app.backend_client import fetch_tool_resource
# Binary operator AST node types mapped to the function that implements them.
# Operators missing from this table are rejected by the validators/evaluators below.
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
}
# Unary operator AST node types mapped to the function that implements them.
_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# The only callables/names an expression may reference; also the namespace
# handed to eval() by _safe_eval_python_expr.
_SAFE_EVAL_FUNCS = {
    "abs": abs,
    "round": round,
    "min": min,
    "max": max,
    "sum": sum,
    "len": len,
}
def _validate_safe_expr(node: ast.AST) -> None:
    """Recursively check that *node* uses only the whitelisted expression subset.

    Raises:
        ValueError: if any node, operator, name or call target is not allowed.
    """
    if isinstance(node, ast.Constant):
        return

    if isinstance(node, ast.Expression):
        _validate_safe_expr(node.body)
    elif isinstance(node, (ast.List, ast.Tuple, ast.Set)):
        for element in node.elts:
            _validate_safe_expr(element)
    elif isinstance(node, ast.Dict):
        # A ``None`` key marks a ``**mapping`` unpacking entry; its value is
        # still validated below.
        for dict_key in node.keys:
            if dict_key is not None:
                _validate_safe_expr(dict_key)
        for dict_value in node.values:
            _validate_safe_expr(dict_value)
    elif isinstance(node, ast.BinOp):
        if type(node.op) not in _BIN_OPS:
            raise ValueError("unsupported operator")
        _validate_safe_expr(node.left)
        _validate_safe_expr(node.right)
    elif isinstance(node, ast.UnaryOp):
        if type(node.op) not in _UNARY_OPS:
            raise ValueError("unsupported unary operator")
        _validate_safe_expr(node.operand)
    elif isinstance(node, ast.BoolOp):
        for operand in node.values:
            _validate_safe_expr(operand)
    elif isinstance(node, ast.Compare):
        _validate_safe_expr(node.left)
        for comparator in node.comparators:
            _validate_safe_expr(comparator)
    elif isinstance(node, ast.Name):
        allowed = node.id in _SAFE_EVAL_FUNCS or node.id in {"True", "False", "None"}
        if not allowed:
            raise ValueError("unknown symbol")
    elif isinstance(node, ast.Call):
        target = node.func
        if not isinstance(target, ast.Name):
            raise ValueError("unsafe call target")
        if target.id not in _SAFE_EVAL_FUNCS:
            raise ValueError("function not allowed")
        for positional in node.args:
            _validate_safe_expr(positional)
        for keyword in node.keywords:
            _validate_safe_expr(keyword.value)
    else:
        # Everything else (import/attribute/subscript/comprehensions/lambda, etc.)
        # is rejected outright.
        raise ValueError("unsupported expression")
def _safe_eval_python_expr(expression: str) -> Any:
    """Validate *expression* against the safe subset, then evaluate it.

    The expression is evaluated with empty builtins and only the
    ``_SAFE_EVAL_FUNCS`` names available.
    """
    parsed = ast.parse(expression, mode="eval")
    _validate_safe_expr(parsed)
    code = compile(parsed, "<code_interpreter>", "eval")
    sandbox_globals = {"__builtins__": {}}
    sandbox_locals = dict(_SAFE_EVAL_FUNCS)
    return eval(code, sandbox_globals, sandbox_locals)  # noqa: S307 - validated AST + empty builtins
def _json_safe(value: Any) -> Any:
    """Convert *value* into something built only from JSON-friendly types.

    Scalars pass through, lists/tuples become lists, dict keys are stringified,
    and anything else is replaced by its ``repr()``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, dict):
        return {str(key): _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return repr(value)
def _safe_eval_expr(expression: str) -> float:
    """Evaluate a pure arithmetic expression and return the result as a float.

    Only numeric literals combined with operators registered in ``_BIN_OPS``
    and ``_UNARY_OPS`` are accepted. The AST is walked manually; nothing is
    passed to ``eval``.

    Args:
        expression: Arithmetic expression source text.

    Returns:
        The computed value converted to ``float``.

    Raises:
        SyntaxError: If the text cannot be parsed as an expression.
        ValueError: If the expression contains an unsupported node, operator
            or a non-numeric literal (including ``True``/``False``).
    """
    tree = ast.parse(expression, mode="eval")

    def _eval(node: ast.AST) -> float:
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant):
            literal = node.value
            # bool is a subclass of int; exclude it explicitly so that
            # ``True``/``False`` are not silently treated as 1 and 0.
            if isinstance(literal, (int, float)) and not isinstance(literal, bool):
                return float(literal)
            raise ValueError("unsupported expression")
        if isinstance(node, ast.BinOp):
            op = _BIN_OPS.get(type(node.op))
            if not op:
                raise ValueError("unsupported operator")
            return float(op(_eval(node.left), _eval(node.right)))
        if isinstance(node, ast.UnaryOp):
            op = _UNARY_OPS.get(type(node.op))
            if not op:
                raise ValueError("unsupported unary operator")
            return float(op(_eval(node.operand)))
        raise ValueError("unsupported expression")

    return _eval(tree)
def _extract_tool_name(tool_call: Dict[str, Any]) -> str:
    """Return the stripped ``function.name`` of a tool call.

    Yields an empty string when ``function`` is not a dict or when the name
    is missing or falsy.
    """
    fn_spec = tool_call.get("function")
    if not isinstance(fn_spec, dict):
        return ""
    raw_name = fn_spec.get("name")
    return str(raw_name if raw_name else "").strip()
def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Return the arguments of a tool call as a dict.

    ``function.arguments`` may already be a dict (returned as-is) or a JSON
    string that decodes to an object. Every other shape - missing payload,
    blank string, invalid JSON, JSON that is not an object - yields ``{}``.
    """
    fn_spec = tool_call.get("function")
    payload = fn_spec.get("arguments") if isinstance(fn_spec, dict) else None

    if isinstance(payload, dict):
        return payload
    if not isinstance(payload, str):
        return {}

    trimmed = payload.strip()
    if trimmed == "":
        return {}

    try:
        import json

        decoded = json.loads(trimmed)
    except Exception:
        # Best effort: malformed arguments are treated as "no arguments".
        return {}
    if isinstance(decoded, dict):
        return decoded
    return {}
async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a server-side tool and return normalized result payload.

    The built-in tools ``calculator``, ``code_interpreter`` and
    ``current_time`` are handled in-process. Any other named tool is looked
    up with ``fetch_tool_resource``; when the resource's ``category`` is
    ``"query"`` its configured HTTP endpoint is called with the tool
    arguments (query string for GET/DELETE, JSON body otherwise).

    Args:
        tool_call: Tool-call dict. ``id`` is read directly; the tool name and
            arguments are taken from its ``function`` entry.

    Returns:
        A dict with the keys ``tool_call_id``, ``name``, ``output`` and
        ``status`` (``{"code": int, "message": str}``). Evaluation and HTTP
        failures are reported through ``output``/``status`` instead of being
        raised. Tools that cannot be handled yield status 501.
    """
    import math  # function-scope import: only needed for the finite-result check

    call_id = str(tool_call.get("id") or "").strip()
    tool_name = _extract_tool_name(tool_call)
    args = _extract_tool_args(tool_call)

    def _reply(output: Dict[str, Any], code: int, message: str) -> Dict[str, Any]:
        # Single place that builds the normalized result envelope.
        return {
            "tool_call_id": call_id,
            "name": tool_name,
            "output": output,
            "status": {"code": code, "message": message},
        }

    if tool_name == "calculator":
        expression = str(args.get("expression") or "").strip()
        if not expression:
            return _reply({"error": "missing expression"}, 400, "bad_request")
        if len(expression) > 200:
            return _reply(
                {"expression": expression, "error": "expression too long"},
                422,
                "invalid_expression",
            )
        try:
            value = _safe_eval_expr(expression)
            # inf/nan would be serialized as non-standard JSON tokens
            # (Infinity/NaN); report them as an invalid expression instead.
            if not math.isfinite(value):
                raise ValueError("result is not a finite number")
            if value.is_integer():
                value = int(value)
        except Exception as exc:
            return _reply(
                {"expression": expression, "error": str(exc)},
                422,
                "invalid_expression",
            )
        return _reply({"expression": expression, "result": value}, 200, "ok")

    if tool_name == "code_interpreter":
        code = str(args.get("code") or args.get("expression") or "").strip()
        if not code:
            return _reply({"error": "missing code"}, 400, "bad_request")
        if len(code) > 500:
            return _reply({"error": "code too long"}, 422, "invalid_code")
        try:
            result = _safe_eval_python_expr(code)
            return _reply({"code": code, "result": _json_safe(result)}, 200, "ok")
        except Exception as exc:
            return _reply({"code": code, "error": str(exc)}, 422, "invalid_code")

    if tool_name == "current_time":
        now = datetime.now().astimezone()
        return _reply(
            {
                "local_time": now.strftime("%Y-%m-%d %H:%M:%S"),
                "iso": now.isoformat(),
                "timezone": str(now.tzinfo or ""),
                "timestamp": int(now.timestamp()),
            },
            200,
            "ok",
        )

    # Every built-in branch above has returned, so any remaining non-empty
    # name refers to an externally configured tool.
    if tool_name:
        resource = await fetch_tool_resource(tool_name)
        if resource and str(resource.get("category") or "") == "query":
            method = str(resource.get("http_method") or "GET").strip().upper()
            if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}:
                method = "GET"
            url = str(resource.get("http_url") or "").strip()
            raw_headers = resource.get("http_headers")
            headers = raw_headers if isinstance(raw_headers, dict) else {}
            try:
                # Configured in milliseconds; never go below one second.
                timeout_s = max(1.0, float(resource.get("http_timeout_ms")) / 1000.0)
            except (TypeError, ValueError, OverflowError):
                timeout_s = 10.0

            if not url:
                return _reply(
                    {"error": "http_url not configured"},
                    422,
                    "invalid_tool_config",
                )

            # GET/DELETE carry the arguments in the query string, the other
            # verbs send them as a JSON body.
            request_kwargs: Dict[str, Any] = (
                {"params": args} if method in {"GET", "DELETE"} else {"json": args}
            )
            try:
                timeout = aiohttp.ClientTimeout(total=timeout_s)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.request(method, url, headers=headers, **request_kwargs) as resp:
                        content_type = str(resp.headers.get("Content-Type") or "").lower()
                        if "application/json" in content_type:
                            body: Any = await resp.json()
                        else:
                            body = await resp.text()
                        status_code = int(resp.status)
                succeeded = 200 <= status_code < 300
                return _reply(
                    {
                        "method": method,
                        "url": url,
                        "status_code": status_code,
                        "response": _json_safe(body),
                    },
                    # Any 2xx is normalized to 200/"ok"; other codes pass through.
                    200 if succeeded else status_code,
                    "ok" if succeeded else "http_error",
                )
            except asyncio.TimeoutError:
                return _reply(
                    {"method": method, "url": url, "error": "request timeout"},
                    504,
                    "http_timeout",
                )
            except Exception as exc:
                return _reply(
                    {"method": method, "url": url, "error": str(exc)},
                    502,
                    "http_request_failed",
                )

    return {
        "tool_call_id": call_id,
        "name": tool_name or "unknown_tool",
        "output": {"message": "server tool not implemented"},
        "status": {"code": 501, "message": "not_implemented"},
    }

247
core/transports.py Normal file
View File

@@ -0,0 +1,247 @@
"""Transport layer for WebSocket and WebRTC communication."""
import asyncio
import json
from abc import ABC, abstractmethod
from typing import Optional
from fastapi import WebSocket
from starlette.websockets import WebSocketState
from loguru import logger
# Try to import aiortc (optional for WebRTC functionality)
try:
from aiortc import RTCPeerConnection
AIORTC_AVAILABLE = True
except ImportError:
AIORTC_AVAILABLE = False
RTCPeerConnection = None # Type hint placeholder
class BaseTransport(ABC):
    """
    Contract shared by every client transport.

    Concrete transports must provide ``send_event``, ``send_audio`` and
    ``close``.
    """

    @abstractmethod
    async def send_event(self, event: dict) -> None:
        """
        Deliver a JSON event to the client.

        Args:
            event: Event data as dictionary
        """

    @abstractmethod
    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Deliver audio data to the client.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """

    @abstractmethod
    async def close(self) -> None:
        """Shut the transport down and release its resources."""
class SocketTransport(BaseTransport):
    """
    WebSocket transport for raw audio streaming.

    Handles mixed text/binary frames over WebSocket connection.
    Uses asyncio.Lock to prevent frame interleaving.

    Send methods never raise: failures are logged and the transport is
    marked closed so later sends become no-ops.
    """

    def __init__(self, websocket: WebSocket):
        """
        Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection
        """
        self.ws = websocket
        self.lock = asyncio.Lock()  # Prevent frame interleaving
        self._closed = False

    def _ws_disconnected(self) -> bool:
        """Best-effort check for websocket disconnection state."""
        return (
            self.ws.client_state == WebSocketState.DISCONNECTED
            or self.ws.application_state == WebSocketState.DISCONNECTED
        )

    def _handle_send_failure(self, kind: str, exc: Exception) -> None:
        """Mark the transport closed after a failed send and log the cause.

        Failures that only mean "the socket is already gone" are logged at
        debug level; anything else is logged as an error.

        Args:
            kind: ``"event"`` or ``"audio"``; used in the log messages.
            exc: The exception raised by the websocket send call.
        """
        self._closed = True
        if isinstance(exc, RuntimeError):
            if self._ws_disconnected() or "close message has been sent" in str(exc):
                logger.debug(f"Skip sending {kind} on closed websocket: {exc!r}")
                return
        elif self._ws_disconnected():
            logger.debug(f"Skip sending {kind} on disconnected websocket: {exc!r}")
            return
        logger.error(f"Error sending {kind}: {exc!r}")

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket.

        Args:
            event: Event data as dictionary
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send event on closed transport")
            self._closed = True
            return

        # Serialize before touching the socket: an unserializable event is a
        # caller error and must not mark a healthy connection as closed.
        try:
            payload = json.dumps(event)
        except (TypeError, ValueError, RecursionError) as e:
            logger.error(f"Error sending event: {e!r}")
            return

        async with self.lock:
            try:
                await self.ws.send_text(payload)
                logger.debug(f"Sent event: {event.get('event', 'unknown')}")
            except Exception as e:
                self._handle_send_failure("event", e)

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send PCM audio data via WebSocket.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed or self._ws_disconnected():
            logger.warning("Attempted to send audio on closed transport")
            self._closed = True
            return

        async with self.lock:
            try:
                await self.ws.send_bytes(pcm_bytes)
            except Exception as e:
                self._handle_send_failure("audio", e)

    async def close(self) -> None:
        """Close the WebSocket connection (no-op if already closed)."""
        if self._closed:
            return

        self._closed = True
        if self._ws_disconnected():
            return

        try:
            await self.ws.close()
        except RuntimeError as e:
            # Already closed by another task/path; safe to ignore.
            if "close message has been sent" in str(e):
                logger.debug(f"WebSocket already closed: {e}")
                return
            logger.error(f"Error closing WebSocket: {e}")
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed
class WebRtcTransport(BaseTransport):
    """
    WebRTC transport for WebRTC audio streaming.

    Uses WebSocket for signaling and RTCPeerConnection for media.
    """

    def __init__(self, websocket: WebSocket, pc):
        """
        Initialize WebRTC transport.

        Args:
            websocket: FastAPI WebSocket connection for signaling
            pc: RTCPeerConnection for media transport

        Raises:
            RuntimeError: If aiortc could not be imported.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc is not available - WebRTC transport cannot be used")

        self.ws = websocket
        self.pc = pc
        self.outbound_track = None  # MediaStreamTrack for outbound audio
        self._closed = False

    async def send_event(self, event: dict) -> None:
        """
        Send a JSON event via WebSocket signaling.

        Args:
            event: Event data as dictionary
        """
        if self._closed:
            logger.warning("Attempted to send event on closed transport")
            return

        # Serialize first: an unserializable event is a caller error and
        # must not mark a healthy signaling channel as closed.
        try:
            payload = json.dumps(event)
        except (TypeError, ValueError, RecursionError) as e:
            logger.error(f"Error sending event: {e}")
            return

        try:
            await self.ws.send_text(payload)
            logger.debug(f"Sent event: {event.get('event', 'unknown')}")
        except Exception as e:
            logger.error(f"Error sending event: {e}")
            self._closed = True

    async def send_audio(self, pcm_bytes: bytes) -> None:
        """
        Send audio data via WebRTC track.

        Note: In WebRTC, you don't send bytes directly. You push frames
        to a MediaStreamTrack that the peer connection is reading.
        This method is currently a placeholder that only logs the size.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        if self._closed:
            logger.warning("Attempted to send audio on closed transport")
            return

        # This would require a custom MediaStreamTrack implementation
        # For now, we'll log this as a placeholder
        logger.debug(f"Audio bytes queued for WebRTC track: {len(pcm_bytes)} bytes")

        # TODO: Implement outbound audio track if needed
        # if self.outbound_track:
        #     await self.outbound_track.add_frame(pcm_bytes)

    async def close(self) -> None:
        """Close the WebRTC connection and its signaling WebSocket."""
        # ``_closed`` is not used as an early-return guard here: send_event
        # sets it after a failed send, and the peer connection must still be
        # released in that case.
        self._closed = True
        # Close each resource in its own try block so that a failure while
        # closing the peer connection cannot leak the WebSocket.
        for resource in (self.pc, self.ws):
            try:
                await resource.close()
            except Exception as e:
                logger.error(f"Error closing WebRTC transport: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if the transport is closed."""
        return self._closed

    def set_outbound_track(self, track):
        """
        Set the outbound audio track for sending audio to client.

        Args:
            track: MediaStreamTrack for outbound audio
        """
        self.outbound_track = track
        logger.debug("Set outbound track for WebRTC transport")

402
core/workflow_runner.py Normal file
View File

@@ -0,0 +1,402 @@
"""Workflow runtime helpers for session-level node routing.
MVP goals:
- Parse workflow graph payload from WS session.start metadata
- Track current node
- Evaluate edge conditions on each assistant turn completion
- Provide per-node runtime metadata overrides (prompt/greeting/services)
"""
from __future__ import annotations
from dataclasses import dataclass
import json
import re
from typing import Any, Awaitable, Callable, Dict, List, Optional
from loguru import logger
# Accepted spellings of a node type (lower-case) -> canonical runtime type.
_NODE_TYPE_MAP = {
    "conversation": "assistant",
    "assistant": "assistant",
    "human": "human_transfer",
    "human_transfer": "human_transfer",
    "tool": "tool",
    "end": "end",
    "start": "start",
}


def _normalize_node_type(raw_type: Any) -> str:
    """Map a raw node ``type`` value onto its canonical runtime type.

    The lookup ignores case and surrounding whitespace. Missing, falsy or
    unrecognized values fall back to ``"assistant"``.
    """
    key = str(raw_type if raw_type else "").strip().lower()
    if key in _NODE_TYPE_MAP:
        return _NODE_TYPE_MAP[key]
    return "assistant"
def _safe_str(value: Any) -> str:
    """Return ``str(value)``, except that ``None`` becomes the empty string."""
    return "" if value is None else str(value)
def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]:
    """Return an edge condition in canonical dict form.

    Args:
        raw: The edge's ``condition`` value from the payload. Anything that
            is not a dict is treated as "no condition given".
        label: The edge label, used as a keyword when no condition is given.

    Returns:
        - For a non-dict *raw*: a ``contains`` condition on the user text
          built from *label* when the label is non-empty, otherwise
          ``{"type": "always"}``.
        - For a dict *raw*: a shallow copy whose ``type`` and ``source`` are
          lower-cased, stripped strings, defaulting to ``"always"`` and
          ``"user"`` when missing, ``None`` or blank. The input dict is not
          modified.
    """
    if not isinstance(raw, dict):
        if label:
            return {"type": "contains", "source": "user", "value": str(label)}
        return {"type": "always"}

    condition = dict(raw)
    raw_type = condition.get("type")
    raw_source = condition.get("source")
    # An explicit null must fall back to the default; str(None) would
    # otherwise produce the bogus literal "none".
    condition["type"] = ("" if raw_type is None else str(raw_type)).strip().lower() or "always"
    condition["source"] = ("" if raw_source is None else str(raw_source)).strip().lower() or "user"
    return condition
@dataclass
class WorkflowNodeDef:
    """Parsed definition of a single workflow node."""

    # Unique node identifier; used as the key in the runner's node table.
    id: str
    # Display name of the node.
    name: str
    # Canonical node type as produced by _normalize_node_type.
    node_type: str
    # True when this node is flagged as the workflow entry point.
    is_start: bool
    # Node-level prompt; None means no prompt was supplied for the node.
    prompt: Optional[str]
    # Message plan dict; build_runtime_metadata reads "firstMessage" from it.
    message_plan: Dict[str, Any]
    # Optional identifier of the assistant attached to this node.
    assistant_id: Optional[str]
    # Assistant configuration dict (prompt/greeting/services overrides).
    assistant: Dict[str, Any]
    # Tool configuration dict, or None when the node has none.
    tool: Optional[Dict[str, Any]]
    # The original payload dict this definition was parsed from.
    raw: Dict[str, Any]
@dataclass
class WorkflowEdgeDef:
    """Parsed definition of a directed edge between two workflow nodes."""

    # Edge identifier.
    id: str
    # Id of the node the edge leaves from.
    from_node_id: str
    # Id of the node the edge leads to.
    to_node_id: str
    # Optional label from the payload.
    label: Optional[str]
    # Condition dict; its "type" entry selects how the edge is evaluated.
    condition: Dict[str, Any]
    # Sort key for outgoing edges: lower values are evaluated first.
    priority: int
    # Position of the edge in the payload; tie-breaker after priority.
    order: int
    # The original payload dict this definition was parsed from.
    raw: Dict[str, Any]
@dataclass
class WorkflowTransition:
    """A selected edge together with the node it leads to."""

    # The edge that was chosen.
    edge: WorkflowEdgeDef
    # The target node of that edge.
    node: WorkflowNodeDef
# Async callback used by WorkflowRunner.route to choose among "llm" edges.
# It receives the current node, the candidate edges and a context dict with
# the keys "userText" and "assistantText"; it returns the id of the chosen
# edge or of its target node, or None when nothing is selected.
LlmRouter = Callable[
    [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]],
    Awaitable[Optional[str]],
]
class WorkflowRunner:
    """In-memory workflow graph for a single active session.

    Holds the parsed nodes and edges, tracks the current node and selects
    the next transition. Outgoing edges are evaluated in ``(priority,
    order)`` order: rule-based conditions first, then ``llm`` edges through
    the optional router callback, and finally unconditional ("default")
    edges as a fallback.
    """

    # Condition types that mark an edge as unconditional.
    _DEFAULT_CONDITION_TYPES = frozenset({"", "always", "default"})

    def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]):
        """Store the graph; call ``bootstrap`` to select the first node."""
        self.workflow_id = workflow_id
        self.name = name
        self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes}
        self._edges = edges
        self.current_node_id: Optional[str] = None

    @classmethod
    def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]:
        """Build a runner from a workflow payload dict.

        Args:
            payload: Dict with a non-empty ``nodes`` list and an optional
                ``edges`` list; ``id`` and ``name`` are optional.

        Returns:
            A runner, or ``None`` when the payload is not a dict or contains
            no usable node. Non-dict entries and edges whose endpoints are
            missing or unknown are skipped.
        """
        if not isinstance(payload, dict):
            return None

        raw_nodes = payload.get("nodes")
        if not isinstance(raw_nodes, list) or not raw_nodes:
            return None

        nodes: List[WorkflowNodeDef] = [
            cls._parse_node(raw, index)
            for index, raw in enumerate(raw_nodes)
            if isinstance(raw, dict)
        ]
        if not nodes:
            return None

        node_ids = {node.id for node in nodes}
        raw_edges = payload.get("edges")
        edges: List[WorkflowEdgeDef] = []
        for index, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []):
            if not isinstance(raw, dict):
                continue
            edge = cls._parse_edge(raw, index, node_ids)
            if edge is not None:
                edges.append(edge)

        workflow_id = _safe_str(payload.get("id") or "workflow")
        workflow_name = _safe_str(payload.get("name") or workflow_id)
        return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges)

    @staticmethod
    def _parse_node(raw: Dict[str, Any], index: int) -> WorkflowNodeDef:
        """Convert one raw node dict into a ``WorkflowNodeDef``."""
        fallback_id = f"node_{index + 1}"
        node_id = _safe_str(raw.get("id") or raw.get("name") or fallback_id).strip() or fallback_id
        node_name = _safe_str(raw.get("name") or node_id).strip() or node_id
        node_type = _normalize_node_type(raw.get("type"))
        is_start = bool(raw.get("isStart")) or node_type == "start"

        # Distinguish "no prompt key" (None) from an explicitly empty prompt.
        prompt: Optional[str] = _safe_str(raw.get("prompt")) if "prompt" in raw else None

        message_plan = raw.get("messagePlan")
        if not isinstance(message_plan, dict):
            message_plan = {}

        assistant_cfg = raw.get("assistant")
        if not isinstance(assistant_cfg, dict):
            assistant_cfg = {}

        tool_cfg = raw.get("tool")
        if not isinstance(tool_cfg, dict):
            tool_cfg = None

        assistant_id = raw.get("assistantId")
        if assistant_id is not None:
            assistant_id = _safe_str(assistant_id).strip() or None

        return WorkflowNodeDef(
            id=node_id,
            name=node_name,
            node_type=node_type,
            is_start=is_start,
            prompt=prompt,
            message_plan=message_plan,
            assistant_id=assistant_id,
            assistant=assistant_cfg,
            tool=tool_cfg,
            raw=raw,
        )

    @staticmethod
    def _parse_edge(raw: Dict[str, Any], index: int, node_ids: set) -> Optional[WorkflowEdgeDef]:
        """Convert one raw edge dict; return ``None`` for unusable edges."""
        from_node_id = _safe_str(
            raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
        ).strip()
        to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip()
        if not from_node_id or not to_node_id:
            return None
        if from_node_id not in node_ids or to_node_id not in node_ids:
            return None

        label = raw.get("label")
        if label is not None:
            label = _safe_str(label)
        condition = _normalize_condition(raw.get("condition"), label=label)

        try:
            priority = int(raw.get("priority", 100))
        except (TypeError, ValueError):
            priority = 100

        edge_id = (
            _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{index + 1}").strip()
            or f"e_{index + 1}"
        )
        return WorkflowEdgeDef(
            id=edge_id,
            from_node_id=from_node_id,
            to_node_id=to_node_id,
            label=label,
            condition=condition,
            priority=priority,
            order=index,
            raw=raw,
        )

    @staticmethod
    def _condition_type(edge: WorkflowEdgeDef) -> str:
        """Return the edge's condition type, lower-cased and stripped."""
        return str(edge.condition.get("type", "always")).strip().lower()

    def bootstrap(self) -> Optional[WorkflowNodeDef]:
        """Select the initial node, make it current and return it."""
        start_node = self._resolve_start_node()
        if not start_node:
            return None
        self.current_node_id = start_node.id
        return start_node

    @property
    def current_node(self) -> Optional[WorkflowNodeDef]:
        """The node the session is currently on, or ``None``."""
        if not self.current_node_id:
            return None
        return self._nodes.get(self.current_node_id)

    def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]:
        """Return the edges leaving *node_id*, sorted by (priority, order)."""
        edges = [edge for edge in self._edges if edge.from_node_id == node_id]
        return sorted(edges, key=lambda edge: (edge.priority, edge.order))

    def next_default_transition(self) -> Optional[WorkflowTransition]:
        """Return the first unconditional transition from the current node."""
        node = self.current_node
        if not node:
            return None
        return self._first_default_transition_from(node.id)

    async def route(
        self,
        *,
        user_text: str,
        assistant_text: str,
        llm_router: Optional[LlmRouter] = None,
    ) -> Optional[WorkflowTransition]:
        """Pick the next transition from the current node.

        Evaluation order:
          1. Rule-based edges (contains/equals/regex/json) in priority order.
          2. ``llm`` edges, via *llm_router* when one is supplied.
          3. The first unconditional edge, as a fallback.

        The runner's current node is not changed; use ``apply_transition``.

        Args:
            user_text: Latest user utterance.
            assistant_text: Latest assistant reply.
            llm_router: Optional async callback that chooses among the
                ``llm`` edges and returns an edge id or target node id.

        Returns:
            The selected transition, or ``None`` when no edge applies.
        """
        node = self.current_node
        if not node:
            return None

        outgoing = self.outgoing_edges(node.id)
        if not outgoing:
            return None

        llm_edges: List[WorkflowEdgeDef] = []
        for edge in outgoing:
            cond_type = self._condition_type(edge)
            if cond_type == "llm":
                llm_edges.append(edge)
                continue
            if cond_type in self._DEFAULT_CONDITION_TYPES:
                # Unconditional edges are a fallback. Matching them here would
                # pre-empt LLM routing and make the final step unreachable.
                continue
            if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text):
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)

        if llm_edges and llm_router:
            selection = await llm_router(
                node,
                llm_edges,
                {
                    "userText": user_text,
                    "assistantText": assistant_text,
                },
            )
            if selection:
                for edge in llm_edges:
                    if selection in {edge.id, edge.to_node_id}:
                        target = self._nodes.get(edge.to_node_id)
                        if target:
                            return WorkflowTransition(edge=edge, node=target)

        return self._first_default_transition_from(node.id)

    def apply_transition(self, transition: WorkflowTransition) -> None:
        """Make the transition's target node the current node."""
        self.current_node_id = transition.node.id

    def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]:
        """Build the per-node runtime overrides.

        Returns a dict that may contain ``systemPrompt``, ``greeting``,
        ``services`` and ``assistantId``. Node-level values (``prompt``,
        ``messagePlan.firstMessage``) take precedence over the values in the
        node's assistant configuration.
        """
        assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {}
        message_plan = node.message_plan if isinstance(node.message_plan, dict) else {}
        metadata: Dict[str, Any] = {}

        if node.prompt is not None:
            metadata["systemPrompt"] = node.prompt
        elif "systemPrompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("systemPrompt"))
        elif "prompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("prompt"))

        first_message = message_plan.get("firstMessage")
        if first_message is not None:
            metadata["greeting"] = _safe_str(first_message)
        elif "greeting" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("greeting"))
        elif "opener" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("opener"))

        services = assistant_cfg.get("services")
        if isinstance(services, dict):
            metadata["services"] = services

        if node.assistant_id:
            metadata["assistantId"] = node.assistant_id

        return metadata

    def _resolve_start_node(self) -> Optional[WorkflowNodeDef]:
        """Choose the node a session should begin on.

        Preference order: a node flagged as start, then the first
        ``assistant`` node, then the first node of the graph.
        """
        explicit_start = next((node for node in self._nodes.values() if node.is_start), None)
        if not explicit_start:
            explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None)

        if explicit_start:
            # If a dedicated start node exists, try to move to its first default target.
            if explicit_start.node_type == "start":
                # NOTE(review): this walk follows default edges for up to 8
                # hops and does not stop at the first non-start node, so a
                # chain of unconditional edges is skipped through at
                # bootstrap - confirm this is intended.
                visited = {explicit_start.id}
                current = explicit_start
                for _ in range(8):
                    transition = self._first_default_transition_from(current.id)
                    if not transition:
                        return current
                    current = transition.node
                    if current.id in visited:
                        break
                    visited.add(current.id)
                return current
            return explicit_start

        assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None)
        if assistant_node:
            return assistant_node
        return next(iter(self._nodes.values()), None)

    def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]:
        """Return the first unconditional transition leaving *node_id*."""
        for edge in self.outgoing_edges(node_id):
            if self._condition_type(edge) in self._DEFAULT_CONDITION_TYPES:
                node = self._nodes.get(edge.to_node_id)
                if node:
                    return WorkflowTransition(edge=edge, node=node)
        return None

    def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool:
        """Evaluate a rule-based edge condition against the latest turn.

        The condition's ``source`` selects the assistant text when it is
        ``"assistant"`` and the user text otherwise. Supported types are
        ``always``/``default``, ``contains``, ``equals``, ``regex`` and
        ``json``; any other type never matches.
        """
        condition = edge.condition or {"type": "always"}
        cond_type = str(condition.get("type", "always")).strip().lower()
        source = str(condition.get("source", "user")).strip().lower()

        if cond_type in self._DEFAULT_CONDITION_TYPES:
            return True

        text = assistant_text if source == "assistant" else user_text
        text_lower = (text or "").lower()

        if cond_type == "contains":
            values: List[str] = []
            if isinstance(condition.get("values"), list):
                values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()]
            if not values:
                # Fall back to a single keyword, then to the edge label.
                single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower()
                if single:
                    values = [single]
            if not values:
                return False
            return any(value in text_lower for value in values)

        if cond_type == "equals":
            expected = _safe_str(condition.get("value") or "").strip().lower()
            # Strip the text as well so both sides are normalized alike.
            return bool(expected) and text_lower.strip() == expected

        if cond_type == "regex":
            pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip()
            if not pattern:
                return False
            try:
                return bool(re.search(pattern, text or "", re.IGNORECASE))
            except re.error:
                logger.warning(f"Invalid workflow regex condition: {pattern}")
                return False

        if cond_type == "json":
            value = _safe_str(condition.get("value") or "").strip()
            if not value:
                return False
            try:
                obj = json.loads(text or "")
            except Exception:
                return False
            # NOTE(review): compares Python's str() of the decoded value, so
            # JSON ``true`` is "True" and objects use Python dict repr -
            # confirm this is the intended comparison.
            return str(obj) == value

        return False