"""Workflow runtime helpers for session-level node routing.

MVP goals:
- Parse workflow graph payload from WS session.start metadata
- Track current node
- Evaluate edge conditions on each assistant turn completion
- Provide per-node runtime metadata overrides (prompt/greeting/services)
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import json
|
|
import re
|
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
|
|
|
from loguru import logger
|
|
|
|
|
|
# Canonical runtime node types keyed by the lowercased raw payload type.
# Both the UI-facing aliases ("conversation", "human") and the canonical
# names map onto the same runtime type; unknown types fall back to
# "assistant" via _normalize_node_type.
_NODE_TYPE_MAP = {
    "conversation": "assistant",
    "assistant": "assistant",
    "human": "human_transfer",
    "human_transfer": "human_transfer",
    "tool": "tool",
    "end": "end",
    "start": "start",
}
|
|
|
|
|
|
def _normalize_node_type(raw_type: Any) -> str:
    """Map a raw payload node type onto a canonical runtime type.

    Missing, empty, or unknown types resolve to "assistant".
    """
    key = str(raw_type or "")
    key = key.strip().lower()
    if key in _NODE_TYPE_MAP:
        return _NODE_TYPE_MAP[key]
    return "assistant"
|
|
|
|
|
|
def _safe_str(value: Any) -> str:
|
|
if value is None:
|
|
return ""
|
|
return str(value)
|
|
|
|
|
|
def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]:
|
|
if not isinstance(raw, dict):
|
|
if label:
|
|
return {"type": "contains", "source": "user", "value": str(label)}
|
|
return {"type": "always"}
|
|
|
|
condition = dict(raw)
|
|
condition_type = str(condition.get("type", "always")).strip().lower()
|
|
if not condition_type:
|
|
condition_type = "always"
|
|
condition["type"] = condition_type
|
|
condition["source"] = str(condition.get("source", "user")).strip().lower() or "user"
|
|
return condition
|
|
|
|
|
|
@dataclass
class WorkflowNodeDef:
    """Static definition of a single workflow node, parsed from the payload."""

    id: str  # unique node id within the workflow (falls back to name / node_<i>)
    name: str  # display name; falls back to id
    node_type: str  # normalized type: assistant / human_transfer / tool / end / start
    is_start: bool  # True when payload flags isStart or node_type == "start"
    prompt: Optional[str]  # node-level system prompt; None when the key was absent
    message_plan: Dict[str, Any]  # e.g. {"firstMessage": ...}; {} when absent
    assistant_id: Optional[str]  # optional assistant override id
    assistant: Dict[str, Any]  # per-node assistant config (systemPrompt/greeting/services)
    tool: Optional[Dict[str, Any]]  # tool config for tool nodes; None otherwise
    raw: Dict[str, Any]  # original payload dict, retained for debugging/extension
|
|
|
|
|
|
@dataclass
class WorkflowEdgeDef:
    """Static definition of a directed edge between two workflow nodes."""

    id: str  # edge id; synthesized when the payload omits one
    from_node_id: str  # source node id (must exist in the graph)
    to_node_id: str  # target node id (must exist in the graph)
    label: Optional[str]  # optional human-readable label; doubles as a keyword fallback
    condition: Dict[str, Any]  # normalized condition (see _normalize_condition)
    priority: int  # lower sorts first; defaults to 100
    order: int  # declaration index; tie-breaker after priority
    raw: Dict[str, Any]  # original payload dict, retained for debugging/extension
|
|
|
|
|
|
@dataclass
class WorkflowTransition:
    """A resolved transition: the edge taken and the node it leads to."""

    edge: WorkflowEdgeDef  # edge whose condition matched
    node: WorkflowNodeDef  # resolved target node
|
|
|
|
|
|
# Async callback used to resolve "llm" edge conditions. It receives the
# current node, the candidate llm-conditioned edges, and a context dict with
# "userText" / "assistantText" keys, and returns the selected edge id or
# target node id — or None to decline routing.
LlmRouter = Callable[
    [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]],
    Awaitable[Optional[str]],
]
|
|
|
|
|
|
class WorkflowRunner:
    """In-memory workflow graph for a single active session.

    Holds parsed node/edge definitions, tracks the current node, and
    evaluates edge conditions after each turn to decide transitions.
    """

    def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]):
        self.workflow_id = workflow_id
        self.name = name
        # Index nodes by id for O(1) lookup during routing.
        self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes}
        self._edges = edges
        # None until bootstrap() selects the start node.
        self.current_node_id: Optional[str] = None

    @classmethod
    def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]:
        """Build a runner from a raw workflow payload (WS session.start metadata).

        Returns None when the payload is not a dict or yields no usable nodes.
        Malformed nodes/edges are skipped rather than failing the whole graph;
        edges referencing unknown nodes are dropped.
        """
        if not isinstance(payload, dict):
            return None

        raw_nodes = payload.get("nodes")
        raw_edges = payload.get("edges")
        if not isinstance(raw_nodes, list) or len(raw_nodes) == 0:
            return None

        nodes: List[WorkflowNodeDef] = []
        for i, raw in enumerate(raw_nodes):
            if not isinstance(raw, dict):
                continue

            node_id = _safe_str(raw.get("id") or raw.get("name") or f"node_{i + 1}").strip() or f"node_{i + 1}"
            node_name = _safe_str(raw.get("name") or node_id).strip() or node_id
            node_type = _normalize_node_type(raw.get("type"))
            is_start = bool(raw.get("isStart")) or node_type == "start"

            # Distinguish "no prompt key" (None) from an explicit empty prompt.
            prompt: Optional[str] = None
            if "prompt" in raw:
                prompt = _safe_str(raw.get("prompt"))

            message_plan = raw.get("messagePlan")
            if not isinstance(message_plan, dict):
                message_plan = {}

            assistant_cfg = raw.get("assistant")
            if not isinstance(assistant_cfg, dict):
                assistant_cfg = {}

            tool_cfg = raw.get("tool")
            if not isinstance(tool_cfg, dict):
                tool_cfg = None

            assistant_id = raw.get("assistantId")
            if assistant_id is not None:
                assistant_id = _safe_str(assistant_id).strip() or None

            nodes.append(
                WorkflowNodeDef(
                    id=node_id,
                    name=node_name,
                    node_type=node_type,
                    is_start=is_start,
                    prompt=prompt,
                    message_plan=message_plan,
                    assistant_id=assistant_id,
                    assistant=assistant_cfg,
                    tool=tool_cfg,
                    raw=raw,
                )
            )

        if not nodes:
            return None

        node_ids = {node.id for node in nodes}
        edges: List[WorkflowEdgeDef] = []
        for i, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []):
            if not isinstance(raw, dict):
                continue

            from_node_id = _safe_str(
                raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
            ).strip()
            to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip()
            if not from_node_id or not to_node_id:
                continue
            # Drop edges that reference nodes we did not parse.
            if from_node_id not in node_ids or to_node_id not in node_ids:
                continue

            label = raw.get("label")
            if label is not None:
                label = _safe_str(label)

            condition = _normalize_condition(raw.get("condition"), label=label)

            priority = 100
            try:
                priority = int(raw.get("priority", 100))
            except (TypeError, ValueError):
                priority = 100

            edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}"

            edges.append(
                WorkflowEdgeDef(
                    id=edge_id,
                    from_node_id=from_node_id,
                    to_node_id=to_node_id,
                    label=label,
                    condition=condition,
                    priority=priority,
                    order=i,
                    raw=raw,
                )
            )

        workflow_id = _safe_str(payload.get("id") or "workflow")
        workflow_name = _safe_str(payload.get("name") or workflow_id)
        return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges)

    def bootstrap(self) -> Optional[WorkflowNodeDef]:
        """Select and store the initial node; returns it, or None if unresolvable."""
        start_node = self._resolve_start_node()
        if not start_node:
            return None
        self.current_node_id = start_node.id
        return start_node

    @property
    def current_node(self) -> Optional[WorkflowNodeDef]:
        """The node the session is currently on, or None before bootstrap()."""
        if not self.current_node_id:
            return None
        return self._nodes.get(self.current_node_id)

    def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]:
        """All edges leaving *node_id*, sorted by (priority, declaration order)."""
        edges = [edge for edge in self._edges if edge.from_node_id == node_id]
        return sorted(edges, key=lambda edge: (edge.priority, edge.order))

    def next_default_transition(self) -> Optional[WorkflowTransition]:
        """First unconditional ("always"/"default") transition from the current node."""
        node = self.current_node
        if not node:
            return None
        return self._first_default_transition_from(node.id)

    async def route(
        self,
        *,
        user_text: str,
        assistant_text: str,
        llm_router: Optional[LlmRouter] = None,
    ) -> Optional[WorkflowTransition]:
        """Pick the next transition after a completed turn, or None to stay put.

        Evaluation order:
        1. Deterministic conditions (always/contains/equals/regex/json), in edge order.
        2. "llm" edges, delegated to *llm_router* when one is provided.
        3. The first unconditional ("always"/"default") edge as a fallback.
        """
        node = self.current_node
        if not node:
            return None

        outgoing = self.outgoing_edges(node.id)
        if not outgoing:
            return None

        llm_edges: List[WorkflowEdgeDef] = []
        for edge in outgoing:
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type == "llm":
                # Deferred: only consulted if no deterministic edge matches.
                llm_edges.append(edge)
                continue
            if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text):
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)

        if llm_edges and llm_router:
            selection = await llm_router(
                node,
                llm_edges,
                {
                    "userText": user_text,
                    "assistantText": assistant_text,
                },
            )
            if selection:
                # The router may answer with either an edge id or a target node id.
                for edge in llm_edges:
                    if selection in {edge.id, edge.to_node_id}:
                        target = self._nodes.get(edge.to_node_id)
                        if target:
                            return WorkflowTransition(edge=edge, node=target)

        # Fall back to the first unconditional edge, if any.
        return self._first_default_transition_from(node.id)

    def apply_transition(self, transition: WorkflowTransition) -> None:
        """Commit *transition*: the target node becomes the current node."""
        self.current_node_id = transition.node.id

    def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]:
        """Build per-node runtime overrides for the session.

        Only keys the node actually overrides appear in the result:
        systemPrompt, greeting, services, assistantId.
        """
        assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {}
        message_plan = node.message_plan if isinstance(node.message_plan, dict) else {}
        metadata: Dict[str, Any] = {}

        # Node-level prompt wins over assistant-level systemPrompt/prompt.
        if node.prompt is not None:
            metadata["systemPrompt"] = node.prompt
        elif "systemPrompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("systemPrompt"))
        elif "prompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("prompt"))

        # messagePlan.firstMessage wins over assistant greeting/opener.
        first_message = message_plan.get("firstMessage")
        if first_message is not None:
            metadata["greeting"] = _safe_str(first_message)
        elif "greeting" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("greeting"))
        elif "opener" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("opener"))

        services = assistant_cfg.get("services")
        if isinstance(services, dict):
            metadata["services"] = services

        if node.assistant_id:
            metadata["assistantId"] = node.assistant_id

        return metadata

    def _resolve_start_node(self) -> Optional[WorkflowNodeDef]:
        """Resolve the node the session should begin on.

        Preference order: an explicit isStart node, then a "start"-typed node
        (following its chain of default edges, bounded to 8 hops with cycle
        detection), then the first assistant node, then any node at all.
        """
        explicit_start = next((node for node in self._nodes.values() if node.is_start), None)
        if not explicit_start:
            explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None)

        if explicit_start:
            # If a dedicated start node exists, try to move to its first default target.
            if explicit_start.node_type == "start":
                visited = {explicit_start.id}
                current = explicit_start
                for _ in range(8):
                    transition = self._first_default_transition_from(current.id)
                    if not transition:
                        return current
                    current = transition.node
                    if current.id in visited:
                        break
                    visited.add(current.id)
                return current
            return explicit_start

        assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None)
        if assistant_node:
            return assistant_node
        return next(iter(self._nodes.values()), None)

    def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]:
        """First outgoing edge of *node_id* with an unconditional condition type."""
        for edge in self.outgoing_edges(node_id):
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type in {"", "always", "default"}:
                node = self._nodes.get(edge.to_node_id)
                if node:
                    return WorkflowTransition(edge=edge, node=node)
        return None

    def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool:
        """Evaluate a deterministic edge condition against the latest turn texts.

        Supported types: always/default, contains, equals, regex, json.
        "llm" conditions are intentionally NOT handled here (see route()).
        Unknown condition types never match.
        """
        condition = edge.condition or {"type": "always"}
        cond_type = str(condition.get("type", "always")).strip().lower()
        source = str(condition.get("source", "user")).strip().lower()

        if cond_type in {"", "always", "default"}:
            return True

        # "source" selects which side of the turn is inspected.
        text = assistant_text if source == "assistant" else user_text
        text_lower = (text or "").lower()

        if cond_type == "contains":
            values: List[str] = []
            if isinstance(condition.get("values"), list):
                values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()]
            if not values:
                # Fall back to a single keyword, or ultimately the edge label.
                single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower()
                if single:
                    values = [single]
            if not values:
                return False
            return any(value in text_lower for value in values)

        if cond_type == "equals":
            expected = _safe_str(condition.get("value") or "").strip().lower()
            return bool(expected) and text_lower == expected

        if cond_type == "regex":
            pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip()
            if not pattern:
                return False
            try:
                return bool(re.search(pattern, text or "", re.IGNORECASE))
            except re.error:
                logger.warning(f"Invalid workflow regex condition: {pattern}")
                return False

        if cond_type == "json":
            expected_raw = _safe_str(condition.get("value") or "").strip()
            if not expected_raw:
                return False
            try:
                obj = json.loads(text or "")
            except Exception:
                return False
            # BUGFIX: the previous `str(obj) == value` compared a Python repr
            # against JSON text, so valid matches failed (str(True) == "True"
            # never equals the JSON literal "true"; dict reprs use single
            # quotes vs JSON double quotes). Compare parsed JSON values when
            # the expected side parses; keep the old string compare as a
            # backward-compatible fallback for non-JSON expected values.
            try:
                expected_obj = json.loads(expected_raw)
            except Exception:
                return str(obj) == expected_raw
            return obj == expected_obj

        return False
|