"""Workflow runtime helpers for session-level node routing. MVP goals: - Parse workflow graph payload from WS session.start metadata - Track current node - Evaluate edge conditions on each assistant turn completion - Provide per-node runtime metadata overrides (prompt/greeting/services) """ from __future__ import annotations from dataclasses import dataclass import json import re from typing import Any, Awaitable, Callable, Dict, List, Optional from loguru import logger _NODE_TYPE_MAP = { "conversation": "assistant", "assistant": "assistant", "human": "human_transfer", "human_transfer": "human_transfer", "tool": "tool", "end": "end", "start": "start", } def _normalize_node_type(raw_type: Any) -> str: value = str(raw_type or "").strip().lower() return _NODE_TYPE_MAP.get(value, "assistant") def _safe_str(value: Any) -> str: if value is None: return "" return str(value) def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]: if not isinstance(raw, dict): if label: return {"type": "contains", "source": "user", "value": str(label)} return {"type": "always"} condition = dict(raw) condition_type = str(condition.get("type", "always")).strip().lower() if not condition_type: condition_type = "always" condition["type"] = condition_type condition["source"] = str(condition.get("source", "user")).strip().lower() or "user" return condition @dataclass class WorkflowNodeDef: id: str name: str node_type: str is_start: bool prompt: Optional[str] message_plan: Dict[str, Any] assistant_id: Optional[str] assistant: Dict[str, Any] tool: Optional[Dict[str, Any]] raw: Dict[str, Any] @dataclass class WorkflowEdgeDef: id: str from_node_id: str to_node_id: str label: Optional[str] condition: Dict[str, Any] priority: int order: int raw: Dict[str, Any] @dataclass class WorkflowTransition: edge: WorkflowEdgeDef node: WorkflowNodeDef LlmRouter = Callable[ [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]], Awaitable[Optional[str]], ] class WorkflowRunner: """In-memory workflow graph for a single active session.""" def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]): self.workflow_id = workflow_id self.name = name self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes} self._edges = edges self.current_node_id: Optional[str] = None @classmethod def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]: if not isinstance(payload, dict): return None raw_nodes = payload.get("nodes") raw_edges = payload.get("edges") if not isinstance(raw_nodes, list) or len(raw_nodes) == 0: return None nodes: List[WorkflowNodeDef] = [] for i, raw in enumerate(raw_nodes): if not isinstance(raw, dict): continue node_id = _safe_str(raw.get("id") or raw.get("name") or f"node_{i + 1}").strip() or f"node_{i + 1}" node_name = _safe_str(raw.get("name") or node_id).strip() or node_id node_type = _normalize_node_type(raw.get("type")) is_start = bool(raw.get("isStart")) or node_type == "start" prompt: Optional[str] = None if "prompt" in raw: prompt = _safe_str(raw.get("prompt")) message_plan = raw.get("messagePlan") if not isinstance(message_plan, dict): message_plan = {} assistant_cfg = raw.get("assistant") if not isinstance(assistant_cfg, dict): assistant_cfg = {} tool_cfg = raw.get("tool") if not isinstance(tool_cfg, dict): tool_cfg = None assistant_id = raw.get("assistantId") if assistant_id is not None: assistant_id = _safe_str(assistant_id).strip() or None nodes.append( WorkflowNodeDef( id=node_id, name=node_name, node_type=node_type, is_start=is_start, prompt=prompt, message_plan=message_plan, assistant_id=assistant_id, assistant=assistant_cfg, tool=tool_cfg, raw=raw, ) ) if not nodes: return None node_ids = {node.id for node in nodes} edges: List[WorkflowEdgeDef] = [] for i, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []): if not isinstance(raw, dict): continue from_node_id = _safe_str( raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source") ).strip() to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip() if not from_node_id or not to_node_id: continue if from_node_id not in node_ids or to_node_id not in node_ids: continue label = raw.get("label") if label is not None: label = _safe_str(label) condition = _normalize_condition(raw.get("condition"), label=label) priority = 100 try: priority = int(raw.get("priority", 100)) except (TypeError, ValueError): priority = 100 edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}" edges.append( WorkflowEdgeDef( id=edge_id, from_node_id=from_node_id, to_node_id=to_node_id, label=label, condition=condition, priority=priority, order=i, raw=raw, ) ) workflow_id = _safe_str(payload.get("id") or "workflow") workflow_name = _safe_str(payload.get("name") or workflow_id) return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges) def bootstrap(self) -> Optional[WorkflowNodeDef]: start_node = self._resolve_start_node() if not start_node: return None self.current_node_id = start_node.id return start_node @property def current_node(self) -> Optional[WorkflowNodeDef]: if not self.current_node_id: return None return self._nodes.get(self.current_node_id) def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]: edges = [edge for edge in self._edges if edge.from_node_id == node_id] return sorted(edges, key=lambda edge: (edge.priority, edge.order)) def next_default_transition(self) -> Optional[WorkflowTransition]: node = self.current_node if not node: return None for edge in self.outgoing_edges(node.id): cond_type = str(edge.condition.get("type", "always")).strip().lower() if cond_type in {"", "always", "default"}: target = self._nodes.get(edge.to_node_id) if target: return WorkflowTransition(edge=edge, node=target) return None async def route( self, *, user_text: str, assistant_text: str, llm_router: Optional[LlmRouter] = None, ) -> Optional[WorkflowTransition]: node = self.current_node if not node: return None outgoing = self.outgoing_edges(node.id) if not outgoing: return None llm_edges: List[WorkflowEdgeDef] = [] for edge in outgoing: cond_type = str(edge.condition.get("type", "always")).strip().lower() if cond_type == "llm": llm_edges.append(edge) continue if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text): target = self._nodes.get(edge.to_node_id) if target: return WorkflowTransition(edge=edge, node=target) if llm_edges and llm_router: selection = await llm_router( node, llm_edges, { "userText": user_text, "assistantText": assistant_text, }, ) if selection: for edge in llm_edges: if selection in {edge.id, edge.to_node_id}: target = self._nodes.get(edge.to_node_id) if target: return WorkflowTransition(edge=edge, node=target) for edge in outgoing: cond_type = str(edge.condition.get("type", "always")).strip().lower() if cond_type in {"", "always", "default"}: target = self._nodes.get(edge.to_node_id) if target: return WorkflowTransition(edge=edge, node=target) return None def apply_transition(self, transition: WorkflowTransition) -> None: self.current_node_id = transition.node.id def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]: assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {} message_plan = node.message_plan if isinstance(node.message_plan, dict) else {} metadata: Dict[str, Any] = {} if node.prompt is not None: metadata["systemPrompt"] = node.prompt elif "systemPrompt" in assistant_cfg: metadata["systemPrompt"] = _safe_str(assistant_cfg.get("systemPrompt")) elif "prompt" in assistant_cfg: metadata["systemPrompt"] = _safe_str(assistant_cfg.get("prompt")) first_message = message_plan.get("firstMessage") if first_message is not None: metadata["greeting"] = _safe_str(first_message) elif "greeting" in assistant_cfg: metadata["greeting"] = _safe_str(assistant_cfg.get("greeting")) elif "opener" in assistant_cfg: metadata["greeting"] = _safe_str(assistant_cfg.get("opener")) services = assistant_cfg.get("services") if isinstance(services, dict): metadata["services"] = services if node.assistant_id: metadata["assistantId"] = node.assistant_id return metadata def _resolve_start_node(self) -> Optional[WorkflowNodeDef]: explicit_start = next((node for node in self._nodes.values() if node.is_start), None) if not explicit_start: explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None) if explicit_start: # If a dedicated start node exists, try to move to its first default target. if explicit_start.node_type == "start": visited = {explicit_start.id} current = explicit_start for _ in range(8): transition = self._first_default_transition_from(current.id) if not transition: return current current = transition.node if current.id in visited: break visited.add(current.id) return current return explicit_start assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None) if assistant_node: return assistant_node return next(iter(self._nodes.values()), None) def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]: for edge in self.outgoing_edges(node_id): cond_type = str(edge.condition.get("type", "always")).strip().lower() if cond_type in {"", "always", "default"}: node = self._nodes.get(edge.to_node_id) if node: return WorkflowTransition(edge=edge, node=node) return None def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool: condition = edge.condition or {"type": "always"} cond_type = str(condition.get("type", "always")).strip().lower() source = str(condition.get("source", "user")).strip().lower() if cond_type in {"", "always", "default"}: return True text = assistant_text if source == "assistant" else user_text text_lower = (text or "").lower() if cond_type == "contains": values: List[str] = [] if isinstance(condition.get("values"), list): values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()] if not values: single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower() if single: values = [single] if not values: return False return any(value in text_lower for value in values) if cond_type == "equals": expected = _safe_str(condition.get("value") or "").strip().lower() return bool(expected) and text_lower == expected if cond_type == "regex": pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip() if not pattern: return False try: return bool(re.search(pattern, text or "", re.IGNORECASE)) except re.error: logger.warning(f"Invalid workflow regex condition: {pattern}") return False if cond_type == "json": value = _safe_str(condition.get("value") or "").strip() if not value: return False try: obj = json.loads(text or "") except Exception: return False return str(obj) == value return False