Update workflow feature with codex

This commit is contained in:
Xin Wang
2026-02-10 08:12:46 +08:00
parent 6b4391c423
commit bbeffa89ed
8 changed files with 1334 additions and 39 deletions

View File

@@ -4,8 +4,9 @@ import asyncio
import uuid
import json
import time
import re
from enum import Enum
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
@@ -16,7 +17,9 @@ from app.backend_client import (
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from services.base import LLMMessage
from models.ws_v1 import (
parse_client_message,
ev,
@@ -81,6 +84,9 @@ class Session:
self._history_finalized: bool = False
self._cleanup_lock = asyncio.Lock()
self._cleaned_up = False
self.workflow_runner: Optional[WorkflowRunner] = None
self._workflow_last_user_text: str = ""
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
@@ -223,6 +229,7 @@ class Session:
return
metadata = message.metadata or {}
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
# Create history call record early so later turn callbacks can append transcripts.
await self._start_history_bridge(metadata)
@@ -246,6 +253,26 @@ class Session:
audio=message.audio or {},
)
)
if self.workflow_runner and self._workflow_initial_node:
await self.transport.send_event(
ev(
"workflow.started",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self.transport.send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
nodeType=self._workflow_initial_node.node_type,
)
)
async def _handle_session_stop(self, reason: Optional[str]) -> None:
"""Handle session stop."""
@@ -334,7 +361,14 @@ class Session:
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
"""Persist completed turns to backend call transcripts."""
"""Process workflow transitions and persist completed turns to history."""
if turn.text and turn.text.strip():
role = (turn.role or "").lower()
if role == "user":
self._workflow_last_user_text = turn.text.strip()
elif role == "assistant":
await self._maybe_advance_workflow(turn.text.strip())
if not self._history_call_id:
return
if not turn.text or not turn.text.strip():
@@ -377,3 +411,235 @@ class Session:
)
if ok:
self._history_finalized = True
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Parse the workflow payload from session.start metadata.

    Returns the initial runtime overrides (prompt/greeting/services) for the
    workflow's start node, or an empty dict when no usable workflow exists.
    Side effects: sets ``self.workflow_runner`` and
    ``self._workflow_initial_node``.
    """
    self._workflow_initial_node = None
    runner = WorkflowRunner.from_payload(metadata.get("workflow"))
    self.workflow_runner = runner
    if runner is None:
        return {}
    start_node = runner.bootstrap()
    if not start_node:
        # Graph present but unusable — disable workflow mode for this session.
        logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
        self.workflow_runner = None
        return {}
    self._workflow_initial_node = start_node
    logger.info(
        "Session {} workflow enabled: workflow={} start_node={}",
        self.id,
        runner.workflow_id,
        start_node.id,
    )
    return runner.build_runtime_metadata(start_node)
async def _maybe_advance_workflow(self, assistant_text: str) -> None:
    """Attempt a node transfer after an assistant turn is finalized."""
    if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
        return
    transition = await self.workflow_runner.route(
        user_text=self._workflow_last_user_text,
        assistant_text=assistant_text,
        llm_router=self._workflow_llm_route,
    )
    if transition is None:
        return
    await self._apply_workflow_transition(transition, reason="rule_match")
    # Auto-advance through utility nodes when default edges are present.
    # Bounded so a default-edge cycle cannot spin forever.
    max_auto_hops = 6
    for hop in range(1, max_auto_hops + 1):
        if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
            break
        active = self.workflow_runner.current_node
        if active is None or active.node_type not in ("start", "tool"):
            break
        auto_transition = self.workflow_runner.next_default_transition()
        if auto_transition is None:
            break
        await self._apply_workflow_transition(auto_transition, reason="auto")
        if hop >= max_auto_hops:
            logger.warning(
                "Session {} workflow auto-advance reached hop limit (possible cycle)",
                self.id,
            )
            break
async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
    """Apply a graph transition and emit workflow lifecycle events.

    Event order mirrors the traversal: edge.taken first, then node.entered,
    then any node-type-specific event (tool request / human transfer / end).
    Terminal node types also stop the session.
    """
    if not self.workflow_runner:
        return
    self.workflow_runner.apply_transition(transition)
    target = transition.node
    taken_edge = transition.edge
    await self.transport.send_event(
        ev(
            "workflow.edge.taken",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            edgeId=taken_edge.id,
            fromNodeId=taken_edge.from_node_id,
            toNodeId=taken_edge.to_node_id,
            reason=reason,
        )
    )
    await self.transport.send_event(
        ev(
            "workflow.node.entered",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            nodeId=target.id,
            nodeName=target.name,
            nodeType=target.node_type,
        )
    )
    # Swap in the new node's prompt/greeting/services mid-call.
    overrides = self.workflow_runner.build_runtime_metadata(target)
    if overrides:
        self.pipeline.apply_runtime_overrides(overrides)
    if target.node_type == "tool":
        await self.transport.send_event(
            ev(
                "workflow.tool.requested",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=target.id,
                tool=target.tool or {},
            )
        )
    elif target.node_type == "human_transfer":
        await self.transport.send_event(
            ev(
                "workflow.human_transfer",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=target.id,
            )
        )
        await self._handle_session_stop("workflow_human_transfer")
    elif target.node_type == "end":
        await self.transport.send_event(
            ev(
                "workflow.ended",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=target.id,
            )
        )
        await self._handle_session_stop("workflow_end")
async def _workflow_llm_route(
    self,
    node: WorkflowNodeDef,
    candidates: List[WorkflowEdgeDef],
    context: Dict[str, str],
) -> Optional[str]:
    """LLM-based edge routing for condition.type == 'llm' edges.

    Returns a selected edge id (or target node id) from *candidates*, or
    None when no LLM service is configured, the call fails, or the reply
    cannot be mapped to a known candidate.
    """
    service = self.pipeline.llm_service
    if not service:
        return None
    # Describe each candidate edge for the router prompt.
    rows = []
    for candidate in candidates:
        hint = None
        if isinstance(candidate.condition, dict):
            hint = candidate.condition.get("prompt")
        rows.append(
            {
                "edgeId": candidate.id,
                "toNodeId": candidate.to_node_id,
                "label": candidate.label,
                "hint": hint,
            }
        )
    system_prompt = (
        "You are a workflow router. Pick exactly one edge. "
        "Return JSON only: {\"edgeId\":\"...\"}."
    )
    user_prompt = json.dumps(
        {
            "nodeId": node.id,
            "nodeName": node.name,
            "userText": context.get("userText", ""),
            "assistantText": context.get("assistantText", ""),
            "candidates": rows,
        },
        ensure_ascii=False,
    )
    try:
        reply = await service.generate(
            [
                LLMMessage(role="system", content=system_prompt),
                LLMMessage(role="user", content=user_prompt),
            ],
            temperature=0.0,
            max_tokens=64,
        )
    except Exception as exc:
        logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
        return None
    if not reply:
        return None
    valid_edge_ids = {candidate.id for candidate in candidates}
    valid_node_ids = {candidate.to_node_id for candidate in candidates}
    # Preferred path: the model returned parseable JSON.
    parsed = self._extract_json_obj(reply)
    if isinstance(parsed, dict):
        chosen_edge = parsed.get("edgeId") or parsed.get("id")
        if isinstance(chosen_edge, str) and chosen_edge in valid_edge_ids:
            return chosen_edge
        chosen_node = parsed.get("toNodeId") or parsed.get("nodeId")
        if isinstance(chosen_node, str) and chosen_node in valid_node_ids:
            return chosen_node
    # Fallback: scan the raw reply for any known id, longest tokens first
    # so a short id cannot shadow a longer one containing it.
    reply_lower = reply.lower()
    for token in sorted(valid_edge_ids | valid_node_ids, key=len, reverse=True):
        if token.lower() in reply_lower:
            return token
    return None
def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
"""Merge node-level metadata overrides into session.start metadata."""
merged = dict(base or {})
if not overrides:
return merged
for key, value in overrides.items():
if key == "services" and isinstance(value, dict):
existing = merged.get("services")
merged_services = dict(existing) if isinstance(existing, dict) else {}
merged_services.update(value)
merged["services"] = merged_services
else:
merged[key] = value
return merged
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text."""
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return parsed
except Exception:
pass
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
return parsed if isinstance(parsed, dict) else None
except Exception:
return None

View File

@@ -0,0 +1,402 @@
"""Workflow runtime helpers for session-level node routing.
MVP goals:
- Parse workflow graph payload from WS session.start metadata
- Track current node
- Evaluate edge conditions on each assistant turn completion
- Provide per-node runtime metadata overrides (prompt/greeting/services)
"""
from __future__ import annotations
from dataclasses import dataclass
import json
import re
from typing import Any, Awaitable, Callable, Dict, List, Optional
from loguru import logger
_NODE_TYPE_MAP = {
"conversation": "assistant",
"assistant": "assistant",
"human": "human_transfer",
"human_transfer": "human_transfer",
"tool": "tool",
"end": "end",
"start": "start",
}
def _normalize_node_type(raw_type: Any) -> str:
value = str(raw_type or "").strip().lower()
return _NODE_TYPE_MAP.get(value, "assistant")
def _safe_str(value: Any) -> str:
if value is None:
return ""
return str(value)
def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]:
if not isinstance(raw, dict):
if label:
return {"type": "contains", "source": "user", "value": str(label)}
return {"type": "always"}
condition = dict(raw)
condition_type = str(condition.get("type", "always")).strip().lower()
if not condition_type:
condition_type = "always"
condition["type"] = condition_type
condition["source"] = str(condition.get("source", "user")).strip().lower() or "user"
return condition
@dataclass
class WorkflowNodeDef:
    """Parsed workflow graph node."""

    id: str  # unique node id within the workflow graph
    name: str  # human-readable display name (falls back to id)
    node_type: str  # normalized type: assistant|human_transfer|tool|end|start
    is_start: bool  # explicit entry-point marker from the payload
    prompt: Optional[str]  # node-level system prompt override, if provided
    message_plan: Dict[str, Any]  # e.g. {"firstMessage": ...} greeting config
    assistant_id: Optional[str]  # assistant override id for this node
    assistant: Dict[str, Any]  # inline assistant config (prompt/greeting/services)
    tool: Optional[Dict[str, Any]]  # tool invocation config for tool nodes
    raw: Dict[str, Any]  # original payload dict, kept for debugging
@dataclass
class WorkflowEdgeDef:
    """Parsed directed edge between two workflow nodes."""

    id: str  # unique edge id (synthesized when missing from payload)
    from_node_id: str  # source node id
    to_node_id: str  # target node id
    label: Optional[str]  # display label; also the contains-rule fallback value
    condition: Dict[str, Any]  # normalized condition (see _normalize_condition)
    priority: int  # lower value is evaluated first (default 100)
    order: int  # payload position, tie-breaker after priority
    raw: Dict[str, Any]  # original payload dict, kept for debugging
@dataclass
class WorkflowTransition:
    """Result of a routing decision: the edge taken and the node entered."""

    edge: WorkflowEdgeDef  # edge that was traversed
    node: WorkflowNodeDef  # node the session moves to
# Async callback used for condition.type == "llm" edges. It receives the
# current node, the candidate edges, and a context dict with "userText" and
# "assistantText", and returns a chosen edge id (or target node id), or None.
LlmRouter = Callable[
    [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]],
    Awaitable[Optional[str]],
]
class WorkflowRunner:
    """In-memory workflow graph for a single active session.

    Parses the graph payload supplied via session.start metadata, tracks the
    current node, and evaluates edge conditions after each assistant turn to
    decide whether the session should move to another node.
    """

    # Condition types that make an edge unconditionally traversable.
    # Previously this set was spelled out inline in three separate methods;
    # consolidated here so the default-edge rule cannot drift.
    _DEFAULT_CONDITION_TYPES = frozenset({"", "always", "default"})

    def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]):
        self.workflow_id = workflow_id
        self.name = name
        # Later duplicate node ids overwrite earlier ones (payload order wins last).
        self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes}
        self._edges = edges
        self.current_node_id: Optional[str] = None

    @classmethod
    def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]:
        """Build a runner from an untrusted workflow payload.

        Returns None when the payload is not a dict or yields no usable
        nodes. Malformed entries are skipped; edges referencing unknown
        nodes are dropped. Edges are optional.
        """
        if not isinstance(payload, dict):
            return None
        raw_nodes = payload.get("nodes")
        if not isinstance(raw_nodes, list) or len(raw_nodes) == 0:
            return None
        nodes = cls._parse_nodes(raw_nodes)
        if not nodes:
            return None
        raw_edges = payload.get("edges")
        node_ids = {node.id for node in nodes}
        edges = cls._parse_edges(raw_edges if isinstance(raw_edges, list) else [], node_ids)
        workflow_id = _safe_str(payload.get("id") or "workflow")
        workflow_name = _safe_str(payload.get("name") or workflow_id)
        return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges)

    @staticmethod
    def _parse_nodes(raw_nodes: List[Any]) -> List[WorkflowNodeDef]:
        """Parse raw node dicts into WorkflowNodeDef; non-dict entries are skipped."""
        nodes: List[WorkflowNodeDef] = []
        for i, raw in enumerate(raw_nodes):
            if not isinstance(raw, dict):
                continue
            node_id = _safe_str(raw.get("id") or raw.get("name") or f"node_{i + 1}").strip() or f"node_{i + 1}"
            node_name = _safe_str(raw.get("name") or node_id).strip() or node_id
            node_type = _normalize_node_type(raw.get("type"))
            is_start = bool(raw.get("isStart")) or node_type == "start"
            # An explicit "prompt" key wins even when its value is empty/None
            # (both normalize to ""), distinguishing "set" from "absent".
            prompt: Optional[str] = None
            if "prompt" in raw:
                prompt = _safe_str(raw.get("prompt"))
            message_plan = raw.get("messagePlan")
            if not isinstance(message_plan, dict):
                message_plan = {}
            assistant_cfg = raw.get("assistant")
            if not isinstance(assistant_cfg, dict):
                assistant_cfg = {}
            tool_cfg = raw.get("tool")
            if not isinstance(tool_cfg, dict):
                tool_cfg = None
            assistant_id = raw.get("assistantId")
            if assistant_id is not None:
                assistant_id = _safe_str(assistant_id).strip() or None
            nodes.append(
                WorkflowNodeDef(
                    id=node_id,
                    name=node_name,
                    node_type=node_type,
                    is_start=is_start,
                    prompt=prompt,
                    message_plan=message_plan,
                    assistant_id=assistant_id,
                    assistant=assistant_cfg,
                    tool=tool_cfg,
                    raw=raw,
                )
            )
        return nodes

    @staticmethod
    def _parse_edges(raw_edges: List[Any], node_ids: "set[str]") -> List[WorkflowEdgeDef]:
        """Parse raw edge dicts, dropping entries whose endpoints are unknown."""
        edges: List[WorkflowEdgeDef] = []
        for i, raw in enumerate(raw_edges):
            if not isinstance(raw, dict):
                continue
            # Accept several key spellings for endpoints (builder variants).
            from_node_id = _safe_str(
                raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
            ).strip()
            to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip()
            if not from_node_id or not to_node_id:
                continue
            if from_node_id not in node_ids or to_node_id not in node_ids:
                continue
            label = raw.get("label")
            if label is not None:
                label = _safe_str(label)
            condition = _normalize_condition(raw.get("condition"), label=label)
            try:
                priority = int(raw.get("priority", 100))
            except (TypeError, ValueError):
                priority = 100
            edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}"
            edges.append(
                WorkflowEdgeDef(
                    id=edge_id,
                    from_node_id=from_node_id,
                    to_node_id=to_node_id,
                    label=label,
                    condition=condition,
                    priority=priority,
                    order=i,
                    raw=raw,
                )
            )
        return edges

    def bootstrap(self) -> Optional[WorkflowNodeDef]:
        """Resolve the entry node and make it current; None if unresolvable."""
        start_node = self._resolve_start_node()
        if not start_node:
            return None
        self.current_node_id = start_node.id
        return start_node

    @property
    def current_node(self) -> Optional[WorkflowNodeDef]:
        """Definition of the node the session is currently on, if any."""
        if not self.current_node_id:
            return None
        return self._nodes.get(self.current_node_id)

    def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]:
        """Edges leaving *node_id*, ordered by (priority, payload order)."""
        edges = [edge for edge in self._edges if edge.from_node_id == node_id]
        return sorted(edges, key=lambda edge: (edge.priority, edge.order))

    @staticmethod
    def _condition_type(edge: WorkflowEdgeDef) -> str:
        """Normalized condition-type string of an edge."""
        return str(edge.condition.get("type", "always")).strip().lower()

    @classmethod
    def _is_default_edge(cls, edge: WorkflowEdgeDef) -> bool:
        """True when the edge's condition makes it unconditionally traversable."""
        return cls._condition_type(edge) in cls._DEFAULT_CONDITION_TYPES

    def next_default_transition(self) -> Optional[WorkflowTransition]:
        """First default transition out of the current node, if any."""
        node = self.current_node
        if not node:
            return None
        return self._first_default_transition_from(node.id)

    async def route(
        self,
        *,
        user_text: str,
        assistant_text: str,
        llm_router: Optional[LlmRouter] = None,
    ) -> Optional[WorkflowTransition]:
        """Pick the next transition for the current node, or None to stay.

        Evaluation order:
        1. Deterministic conditions (contains/equals/regex/json/default) in
           priority order — cheap rules always win first.
        2. LLM-routed edges via *llm_router*, if any were deferred.
        3. The first default edge as a fallback.
        """
        node = self.current_node
        if not node:
            return None
        outgoing = self.outgoing_edges(node.id)
        if not outgoing:
            return None
        llm_edges: List[WorkflowEdgeDef] = []
        for edge in outgoing:
            if self._condition_type(edge) == "llm":
                llm_edges.append(edge)
                continue
            if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text):
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)
        if llm_edges and llm_router:
            selection = await llm_router(
                node,
                llm_edges,
                {
                    "userText": user_text,
                    "assistantText": assistant_text,
                },
            )
            if selection:
                # The router may answer with either an edge id or a target node id.
                for edge in llm_edges:
                    if selection in {edge.id, edge.to_node_id}:
                        target = self._nodes.get(edge.to_node_id)
                        if target:
                            return WorkflowTransition(edge=edge, node=target)
        return self._first_default_transition_from(node.id)

    def apply_transition(self, transition: WorkflowTransition) -> None:
        """Move the current-node pointer to the transition's target node."""
        self.current_node_id = transition.node.id

    def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]:
        """Per-node runtime overrides (systemPrompt/greeting/services/assistantId)."""
        assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {}
        message_plan = node.message_plan if isinstance(node.message_plan, dict) else {}
        metadata: Dict[str, Any] = {}
        # Node-level prompt wins over the assistant-config variants.
        if node.prompt is not None:
            metadata["systemPrompt"] = node.prompt
        elif "systemPrompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("systemPrompt"))
        elif "prompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("prompt"))
        # messagePlan.firstMessage wins over assistant greeting/opener.
        first_message = message_plan.get("firstMessage")
        if first_message is not None:
            metadata["greeting"] = _safe_str(first_message)
        elif "greeting" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("greeting"))
        elif "opener" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("opener"))
        services = assistant_cfg.get("services")
        if isinstance(services, dict):
            metadata["services"] = services
        if node.assistant_id:
            metadata["assistantId"] = node.assistant_id
        return metadata

    def _resolve_start_node(self) -> Optional[WorkflowNodeDef]:
        """Pick the entry node: explicit start, else an assistant node, else any."""
        explicit_start = next((node for node in self._nodes.values() if node.is_start), None)
        if not explicit_start:
            explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None)
        if explicit_start:
            # A dedicated "start" marker node is skipped: follow default edges
            # (bounded and cycle-guarded) to the first real node.
            if explicit_start.node_type == "start":
                visited = {explicit_start.id}
                current = explicit_start
                for _ in range(8):
                    transition = self._first_default_transition_from(current.id)
                    if not transition:
                        return current
                    current = transition.node
                    if current.id in visited:
                        break
                    visited.add(current.id)
                return current
            return explicit_start
        assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None)
        if assistant_node:
            return assistant_node
        return next(iter(self._nodes.values()), None)

    def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]:
        """First default edge out of *node_id* whose target node exists."""
        for edge in self.outgoing_edges(node_id):
            if self._is_default_edge(edge):
                node = self._nodes.get(edge.to_node_id)
                if node:
                    return WorkflowTransition(edge=edge, node=node)
        return None

    def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool:
        """Evaluate one deterministic edge condition against the turn texts.

        Supported types: always/default, contains, equals, regex, json.
        "source" selects which text is inspected (assistant vs user, the
        default). Unknown types never match.
        """
        condition = edge.condition or {"type": "always"}
        cond_type = str(condition.get("type", "always")).strip().lower()
        source = str(condition.get("source", "user")).strip().lower()
        if cond_type in self._DEFAULT_CONDITION_TYPES:
            return True
        text = assistant_text if source == "assistant" else user_text
        text_lower = (text or "").lower()
        if cond_type == "contains":
            values: List[str] = []
            if isinstance(condition.get("values"), list):
                values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()]
            if not values:
                # Single-value fallback, ultimately falling back to the edge label.
                single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower()
                if single:
                    values = [single]
            if not values:
                return False
            return any(value in text_lower for value in values)
        if cond_type == "equals":
            expected = _safe_str(condition.get("value") or "").strip().lower()
            return bool(expected) and text_lower == expected
        if cond_type == "regex":
            pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip()
            if not pattern:
                return False
            try:
                return bool(re.search(pattern, text or "", re.IGNORECASE))
            except re.error:
                logger.warning(f"Invalid workflow regex condition: {pattern}")
                return False
        if cond_type == "json":
            value = _safe_str(condition.get("value") or "").strip()
            if not value:
                return False
            try:
                obj = json.loads(text or "")
            except Exception:
                return False
            # NOTE(review): compares the str() form of the parsed object, not
            # a re-serialization — intentional MVP behavior, kept as-is.
            return str(obj) == value
        return False