From bbeffa89edbc8728a7da2acb5ce3be2c103dcf5d Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Tue, 10 Feb 2026 08:12:46 +0800 Subject: [PATCH] Update workflow feature with codex --- api/app/routers/workflows.py | 37 ++- api/app/schemas.py | 152 +++++++++++-- api/tests/test_workflows.py | 167 ++++++++++++++ engine/core/session.py | 270 +++++++++++++++++++++- engine/core/workflow_runner.py | 402 +++++++++++++++++++++++++++++++++ web/pages/WorkflowEditor.tsx | 301 +++++++++++++++++++++++- web/services/backendApi.ts | 24 +- web/types.ts | 20 +- 8 files changed, 1334 insertions(+), 39 deletions(-) create mode 100644 api/tests/test_workflows.py create mode 100644 engine/core/workflow_runner.py diff --git a/api/app/routers/workflows.py b/api/app/routers/workflows.py index c6f9c9f..0c945af 100644 --- a/api/app/routers/workflows.py +++ b/api/app/routers/workflows.py @@ -2,14 +2,30 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session import uuid from datetime import datetime +from typing import Any, Dict, List, Tuple from ..db import get_db from ..models import Workflow -from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut +from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut, WorkflowNode, WorkflowEdge router = APIRouter(prefix="/workflows", tags=["Workflows"]) +def _normalize_graph_payload(nodes: List[Any], edges: List[Any]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """Normalize graph payload to canonical dict structures.""" + parsed_nodes: List[WorkflowNode] = [] + for node in nodes: + parsed_nodes.append(node if isinstance(node, WorkflowNode) else WorkflowNode.model_validate(node)) + + parsed_edges: List[WorkflowEdge] = [] + for edge in edges: + parsed_edges.append(edge if isinstance(edge, WorkflowEdge) else WorkflowEdge.model_validate(edge)) + + normalized_nodes = [node.model_dump() for node in parsed_nodes] + normalized_edges = [edge.model_dump() for edge in parsed_edges] + return 
normalized_nodes, normalized_edges + + @router.get("") def list_workflows( page: int = 1, @@ -27,16 +43,17 @@ def list_workflows( @router.post("", response_model=WorkflowOut) def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)): """创建工作流""" + nodes, edges = _normalize_graph_payload(data.nodes, data.edges) workflow = Workflow( id=str(uuid.uuid4())[:8], user_id=1, name=data.name, - node_count=data.nodeCount, + node_count=data.nodeCount or len(nodes), created_at=data.createdAt or datetime.utcnow().isoformat(), updated_at=data.updatedAt or "", global_prompt=data.globalPrompt, - nodes=data.nodes, - edges=data.edges, + nodes=nodes, + edges=edges, ) db.add(workflow) db.commit() @@ -60,7 +77,7 @@ def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db) if not workflow: raise HTTPException(status_code=404, detail="Workflow not found") - update_data = data.model_dump(exclude_unset=True) + update_data = data.model_dump(exclude_unset=True, exclude={"nodes", "edges"}) field_map = { "nodeCount": "node_count", "globalPrompt": "global_prompt", @@ -68,6 +85,16 @@ def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db) for field, value in update_data.items(): setattr(workflow, field_map.get(field, field), value) + if data.nodes is not None or data.edges is not None: + existing_nodes = workflow.nodes if isinstance(workflow.nodes, list) else [] + existing_edges = workflow.edges if isinstance(workflow.edges, list) else [] + input_nodes = data.nodes if data.nodes is not None else existing_nodes + input_edges = data.edges if data.edges is not None else existing_edges + nodes, edges = _normalize_graph_payload(input_nodes, input_edges) + workflow.nodes = nodes + workflow.edges = edges + workflow.node_count = len(nodes) + workflow.updated_at = datetime.utcnow().isoformat() db.commit() db.refresh(workflow) diff --git a/api/app/schemas.py b/api/app/schemas.py index 0b40d82..457476b 100644 --- a/api/app/schemas.py +++ 
b/api/app/schemas.py @@ -1,7 +1,7 @@ from datetime import datetime from enum import Enum -from typing import List, Optional -from pydantic import BaseModel +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, ConfigDict, Field, model_validator # ============ Enums ============ @@ -410,24 +410,82 @@ class KnowledgeStats(BaseModel): # ============ Workflow ============ class WorkflowNode(BaseModel): - name: str - type: str + model_config = ConfigDict(extra="allow") + + id: Optional[str] = None + name: str = "" + type: str = "assistant" isStart: Optional[bool] = None - metadata: dict + metadata: Dict[str, Any] = Field(default_factory=dict) prompt: Optional[str] = None - messagePlan: Optional[dict] = None - variableExtractionPlan: Optional[dict] = None - tool: Optional[dict] = None - globalNodePlan: Optional[dict] = None + messagePlan: Optional[Dict[str, Any]] = None + variableExtractionPlan: Optional[Dict[str, Any]] = None + tool: Optional[Dict[str, Any]] = None + globalNodePlan: Optional[Dict[str, Any]] = None + assistantId: Optional[str] = None + assistant: Optional[Dict[str, Any]] = None + + @model_validator(mode="before") + @classmethod + def _normalize_legacy_node(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + raw = dict(data) + node_id = raw.get("id") or raw.get("name") + if not node_id: + node_id = f"node_{abs(hash(str(raw))) % 100000}" + raw["id"] = str(node_id) + raw["name"] = str(raw.get("name") or raw["id"]) + + node_type = str(raw.get("type") or "assistant").lower() + if node_type == "conversation": + node_type = "assistant" + elif node_type == "human": + node_type = "human_transfer" + elif node_type not in {"start", "assistant", "tool", "human_transfer", "end"}: + node_type = "assistant" + raw["type"] = node_type + + metadata = raw.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + if "position" not in metadata and isinstance(raw.get("position"), dict): + metadata["position"] = 
raw.get("position") + raw["metadata"] = metadata + + if raw.get("isStart") is None and node_type == "start": + raw["isStart"] = True + return raw class WorkflowEdge(BaseModel): - from_: str - to: str - label: Optional[str] = None + model_config = ConfigDict(extra="allow") - class Config: - populate_by_name = True + id: Optional[str] = None + fromNodeId: str + toNodeId: str + label: Optional[str] = None + condition: Optional[Dict[str, Any]] = None + priority: int = 100 + + @model_validator(mode="before") + @classmethod + def _normalize_legacy_edge(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + raw = dict(data) + from_node = raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source") + to_node = raw.get("toNodeId") or raw.get("to") or raw.get("target") + raw["fromNodeId"] = str(from_node or "") + raw["toNodeId"] = str(to_node or "") + if raw.get("id") is None: + raw["id"] = f"e_{raw['fromNodeId']}_{raw['toNodeId']}" + if raw.get("condition") is None: + if raw.get("label"): + raw["condition"] = {"type": "contains", "source": "user", "value": str(raw["label"])} + else: + raw["condition"] = {"type": "always"} + return raw class WorkflowBase(BaseModel): @@ -436,29 +494,85 @@ class WorkflowBase(BaseModel): createdAt: str = "" updatedAt: str = "" globalPrompt: Optional[str] = None - nodes: List[dict] = [] - edges: List[dict] = [] + nodes: List[WorkflowNode] = Field(default_factory=list) + edges: List[WorkflowEdge] = Field(default_factory=list) class WorkflowCreate(WorkflowBase): - pass + @model_validator(mode="after") + def _validate_graph(self) -> "WorkflowCreate": + _validate_workflow_graph(self.nodes, self.edges) + return self class WorkflowUpdate(BaseModel): name: Optional[str] = None nodeCount: Optional[int] = None - nodes: Optional[List[dict]] = None - edges: Optional[List[dict]] = None + nodes: Optional[List[WorkflowNode]] = None + edges: Optional[List[WorkflowEdge]] = None globalPrompt: Optional[str] = None + 
@model_validator(mode="after") + def _validate_partial_graph(self) -> "WorkflowUpdate": + if self.nodes is not None and self.edges is not None: + _validate_workflow_graph(self.nodes, self.edges) + return self + class WorkflowOut(WorkflowBase): id: str + @model_validator(mode="before") + @classmethod + def _normalize_db_fields(cls, data: Any) -> Any: + if isinstance(data, dict): + raw = dict(data) + else: + raw = { + "id": getattr(data, "id", None), + "name": getattr(data, "name", None), + "node_count": getattr(data, "node_count", None), + "created_at": getattr(data, "created_at", None), + "updated_at": getattr(data, "updated_at", None), + "global_prompt": getattr(data, "global_prompt", None), + "nodes": getattr(data, "nodes", None), + "edges": getattr(data, "edges", None), + } + + if "nodeCount" not in raw and raw.get("node_count") is not None: + raw["nodeCount"] = raw["node_count"] + if "createdAt" not in raw and raw.get("created_at") is not None: + raw["createdAt"] = raw["created_at"] + if "updatedAt" not in raw and raw.get("updated_at") is not None: + raw["updatedAt"] = raw["updated_at"] + if "globalPrompt" not in raw and raw.get("global_prompt") is not None: + raw["globalPrompt"] = raw["global_prompt"] + return raw + class Config: from_attributes = True +def _validate_workflow_graph(nodes: List[WorkflowNode], edges: List[WorkflowEdge]) -> None: + if not nodes: + raise ValueError("Workflow must include at least one node") + + node_ids = [node.id for node in nodes if node.id] + if len(node_ids) != len(set(node_ids)): + raise ValueError("Workflow node ids must be unique") + + starts = [node for node in nodes if node.isStart or node.type == "start"] + if not starts: + raise ValueError("Workflow must define a start node (isStart=true or type=start)") + + known = set(node_ids) + for edge in edges: + if edge.fromNodeId not in known: + raise ValueError(f"Workflow edge fromNodeId not found: {edge.fromNodeId}") + if edge.toNodeId not in known: + raise 
ValueError(f"Workflow edge toNodeId not found: {edge.toNodeId}") + + # ============ Call Record ============ class TranscriptSegment(BaseModel): turnIndex: int diff --git a/api/tests/test_workflows.py b/api/tests/test_workflows.py new file mode 100644 index 0000000..3cb74a4 --- /dev/null +++ b/api/tests/test_workflows.py @@ -0,0 +1,167 @@ +"""Tests for workflow graph schema and router behavior.""" + + +class TestWorkflowAPI: + """Workflow CRUD and graph validation test cases.""" + + def _minimal_nodes(self): + return [ + { + "id": "start_1", + "name": "start_1", + "type": "start", + "isStart": True, + "metadata": {"position": {"x": 80, "y": 80}}, + }, + { + "id": "assistant_1", + "name": "assistant_1", + "type": "assistant", + "metadata": {"position": {"x": 280, "y": 80}}, + "prompt": "You are the first assistant node.", + }, + ] + + def test_create_workflow_with_canonical_graph(self, client): + payload = { + "name": "Canonical Graph", + "nodes": self._minimal_nodes(), + "edges": [ + { + "id": "edge_start_assistant", + "fromNodeId": "start_1", + "toNodeId": "assistant_1", + "condition": {"type": "always"}, + } + ], + } + + resp = client.post("/api/workflows", json=payload) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "Canonical Graph" + assert data["nodeCount"] == 2 + assert data["nodes"][0]["id"] == "start_1" + assert data["edges"][0]["fromNodeId"] == "start_1" + assert data["edges"][0]["toNodeId"] == "assistant_1" + + def test_create_workflow_with_legacy_graph(self, client): + payload = { + "name": "Legacy Graph", + "nodes": [ + { + "name": "legacy_start", + "type": "conversation", + "isStart": True, + "metadata": {"position": {"x": 100, "y": 100}}, + }, + { + "name": "legacy_human", + "type": "human", + "metadata": {"position": {"x": 300, "y": 100}}, + }, + ], + "edges": [ + { + "from": "legacy_start", + "to": "legacy_human", + "label": "人工", + } + ], + } + + resp = client.post("/api/workflows", json=payload) + assert 
resp.status_code == 200 + data = resp.json() + assert data["nodes"][0]["type"] == "assistant" + assert data["nodes"][1]["type"] == "human_transfer" + assert data["edges"][0]["fromNodeId"] == "legacy_start" + assert data["edges"][0]["toNodeId"] == "legacy_human" + assert data["edges"][0]["condition"]["type"] == "contains" + + def test_create_workflow_without_start_node_fails(self, client): + payload = { + "name": "No Start", + "nodes": [ + {"id": "node_1", "name": "node_1", "type": "assistant", "metadata": {"position": {"x": 0, "y": 0}}}, + ], + "edges": [], + } + resp = client.post("/api/workflows", json=payload) + assert resp.status_code == 422 + + def test_create_workflow_with_invalid_edge_fails(self, client): + payload = { + "name": "Bad Edge", + "nodes": self._minimal_nodes(), + "edges": [ + {"id": "edge_bad", "fromNodeId": "missing", "toNodeId": "assistant_1", "condition": {"type": "always"}}, + ], + } + resp = client.post("/api/workflows", json=payload) + assert resp.status_code == 422 + + def test_update_workflow_nodes_and_edges(self, client): + create_payload = { + "name": "Before Update", + "nodes": self._minimal_nodes(), + "edges": [ + { + "id": "edge_start_assistant", + "fromNodeId": "start_1", + "toNodeId": "assistant_1", + "condition": {"type": "always"}, + } + ], + } + create_resp = client.post("/api/workflows", json=create_payload) + assert create_resp.status_code == 200 + workflow_id = create_resp.json()["id"] + + update_payload = { + "name": "After Update", + "nodes": [ + { + "id": "start_1", + "name": "start_1", + "type": "start", + "isStart": True, + "metadata": {"position": {"x": 50, "y": 50}}, + }, + { + "id": "assistant_2", + "name": "assistant_2", + "type": "assistant", + "metadata": {"position": {"x": 250, "y": 50}}, + "prompt": "new prompt", + }, + { + "id": "end_1", + "name": "end_1", + "type": "end", + "metadata": {"position": {"x": 450, "y": 50}}, + }, + ], + "edges": [ + { + "id": "edge_start_assistant2", + "fromNodeId": "start_1", + 
"toNodeId": "assistant_2", + "condition": {"type": "always"}, + }, + { + "id": "edge_assistant2_end", + "fromNodeId": "assistant_2", + "toNodeId": "end_1", + "condition": {"type": "contains", "source": "user", "value": "结束"}, + }, + ], + } + + update_resp = client.put(f"/api/workflows/{workflow_id}", json=update_payload) + assert update_resp.status_code == 200 + updated = update_resp.json() + assert updated["name"] == "After Update" + assert updated["nodeCount"] == 3 + assert len(updated["nodes"]) == 3 + assert len(updated["edges"]) == 2 diff --git a/engine/core/session.py b/engine/core/session.py index 6c59bd7..c4fb193 100644 --- a/engine/core/session.py +++ b/engine/core/session.py @@ -4,8 +4,9 @@ import asyncio import uuid import json import time +import re from enum import Enum -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List from loguru import logger from app.backend_client import ( @@ -16,7 +17,9 @@ from app.backend_client import ( from core.transports import BaseTransport from core.duplex_pipeline import DuplexPipeline from core.conversation import ConversationTurn +from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef from app.config import settings +from services.base import LLMMessage from models.ws_v1 import ( parse_client_message, ev, @@ -81,6 +84,9 @@ class Session: self._history_finalized: bool = False self._cleanup_lock = asyncio.Lock() self._cleaned_up = False + self.workflow_runner: Optional[WorkflowRunner] = None + self._workflow_last_user_text: str = "" + self._workflow_initial_node: Optional[WorkflowNodeDef] = None self.pipeline.conversation.on_turn_complete(self._on_turn_complete) @@ -223,6 +229,7 @@ class Session: return metadata = message.metadata or {} + metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata)) # Create history call record early so later turn callbacks can append transcripts. 
await self._start_history_bridge(metadata) @@ -246,6 +253,26 @@ class Session: audio=message.audio or {}, ) ) + if self.workflow_runner and self._workflow_initial_node: + await self.transport.send_event( + ev( + "workflow.started", + sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + workflowName=self.workflow_runner.name, + nodeId=self._workflow_initial_node.id, + ) + ) + await self.transport.send_event( + ev( + "workflow.node.entered", + sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + nodeId=self._workflow_initial_node.id, + nodeName=self._workflow_initial_node.name, + nodeType=self._workflow_initial_node.node_type, + ) + ) async def _handle_session_stop(self, reason: Optional[str]) -> None: """Handle session stop.""" @@ -334,7 +361,14 @@ class Session: logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})") async def _on_turn_complete(self, turn: ConversationTurn) -> None: - """Persist completed turns to backend call transcripts.""" + """Process workflow transitions and persist completed turns to history.""" + if turn.text and turn.text.strip(): + role = (turn.role or "").lower() + if role == "user": + self._workflow_last_user_text = turn.text.strip() + elif role == "assistant": + await self._maybe_advance_workflow(turn.text.strip()) + if not self._history_call_id: return if not turn.text or not turn.text.strip(): @@ -377,3 +411,235 @@ class Session: ) if ok: self._history_finalized = True + + def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """Parse workflow payload and return initial runtime overrides.""" + payload = metadata.get("workflow") + self.workflow_runner = WorkflowRunner.from_payload(payload) + self._workflow_initial_node = None + if not self.workflow_runner: + return {} + + node = self.workflow_runner.bootstrap() + if not node: + logger.warning(f"Session {self.id} workflow payload had no resolvable start node") + self.workflow_runner = None + 
return {} + + self._workflow_initial_node = node + logger.info( + "Session {} workflow enabled: workflow={} start_node={}", + self.id, + self.workflow_runner.workflow_id, + node.id, + ) + return self.workflow_runner.build_runtime_metadata(node) + + async def _maybe_advance_workflow(self, assistant_text: str) -> None: + """Attempt node transfer after assistant turn finalization.""" + if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED: + return + + transition = await self.workflow_runner.route( + user_text=self._workflow_last_user_text, + assistant_text=assistant_text, + llm_router=self._workflow_llm_route, + ) + if not transition: + return + + await self._apply_workflow_transition(transition, reason="rule_match") + + # Auto-advance through utility nodes when default edges are present. + max_auto_hops = 6 + auto_hops = 0 + while self.workflow_runner and self.ws_state != WsSessionState.STOPPED: + current = self.workflow_runner.current_node + if not current or current.node_type not in {"start", "tool"}: + break + + next_default = self.workflow_runner.next_default_transition() + if not next_default: + break + + auto_hops += 1 + await self._apply_workflow_transition(next_default, reason="auto") + if auto_hops >= max_auto_hops: + logger.warning( + "Session {} workflow auto-advance reached hop limit (possible cycle)", + self.id, + ) + break + + async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None: + """Apply graph transition and emit workflow lifecycle events.""" + if not self.workflow_runner: + return + + self.workflow_runner.apply_transition(transition) + node = transition.node + edge = transition.edge + + await self.transport.send_event( + ev( + "workflow.edge.taken", + sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + edgeId=edge.id, + fromNodeId=edge.from_node_id, + toNodeId=edge.to_node_id, + reason=reason, + ) + ) + await self.transport.send_event( + ev( + "workflow.node.entered", + 
sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + nodeId=node.id, + nodeName=node.name, + nodeType=node.node_type, + ) + ) + + node_runtime = self.workflow_runner.build_runtime_metadata(node) + if node_runtime: + self.pipeline.apply_runtime_overrides(node_runtime) + + if node.node_type == "tool": + await self.transport.send_event( + ev( + "workflow.tool.requested", + sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + nodeId=node.id, + tool=node.tool or {}, + ) + ) + return + + if node.node_type == "human_transfer": + await self.transport.send_event( + ev( + "workflow.human_transfer", + sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + nodeId=node.id, + ) + ) + await self._handle_session_stop("workflow_human_transfer") + return + + if node.node_type == "end": + await self.transport.send_event( + ev( + "workflow.ended", + sessionId=self.id, + workflowId=self.workflow_runner.workflow_id, + nodeId=node.id, + ) + ) + await self._handle_session_stop("workflow_end") + + async def _workflow_llm_route( + self, + node: WorkflowNodeDef, + candidates: List[WorkflowEdgeDef], + context: Dict[str, str], + ) -> Optional[str]: + """LLM-based edge routing for condition.type == 'llm' edges.""" + llm_service = self.pipeline.llm_service + if not llm_service: + return None + + candidate_rows = [ + { + "edgeId": edge.id, + "toNodeId": edge.to_node_id, + "label": edge.label, + "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None, + } + for edge in candidates + ] + system_prompt = ( + "You are a workflow router. Pick exactly one edge. " + "Return JSON only: {\"edgeId\":\"...\"}." 
+ ) + user_prompt = json.dumps( + { + "nodeId": node.id, + "nodeName": node.name, + "userText": context.get("userText", ""), + "assistantText": context.get("assistantText", ""), + "candidates": candidate_rows, + }, + ensure_ascii=False, + ) + + try: + reply = await llm_service.generate( + [ + LLMMessage(role="system", content=system_prompt), + LLMMessage(role="user", content=user_prompt), + ], + temperature=0.0, + max_tokens=64, + ) + except Exception as exc: + logger.warning(f"Session {self.id} workflow llm routing failed: {exc}") + return None + + if not reply: + return None + + edge_ids = {edge.id for edge in candidates} + node_ids = {edge.to_node_id for edge in candidates} + + parsed = self._extract_json_obj(reply) + if isinstance(parsed, dict): + edge_id = parsed.get("edgeId") or parsed.get("id") + node_id = parsed.get("toNodeId") or parsed.get("nodeId") + if isinstance(edge_id, str) and edge_id in edge_ids: + return edge_id + if isinstance(node_id, str) and node_id in node_ids: + return node_id + + token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True) + lowered_reply = reply.lower() + for token in token_candidates: + if token.lower() in lowered_reply: + return token + return None + + def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: + """Merge node-level metadata overrides into session.start metadata.""" + merged = dict(base or {}) + if not overrides: + return merged + for key, value in overrides.items(): + if key == "services" and isinstance(value, dict): + existing = merged.get("services") + merged_services = dict(existing) if isinstance(existing, dict) else {} + merged_services.update(value) + merged["services"] = merged_services + else: + merged[key] = value + return merged + + def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]: + """Best-effort extraction of a JSON object from freeform text.""" + try: + parsed = json.loads(text) + if isinstance(parsed, dict): + return 
parsed + except Exception: + pass + + match = re.search(r"\{.*\}", text, re.DOTALL) + if not match: + return None + try: + parsed = json.loads(match.group(0)) + return parsed if isinstance(parsed, dict) else None + except Exception: + return None diff --git a/engine/core/workflow_runner.py b/engine/core/workflow_runner.py new file mode 100644 index 0000000..2ad7ded --- /dev/null +++ b/engine/core/workflow_runner.py @@ -0,0 +1,402 @@ +"""Workflow runtime helpers for session-level node routing. + +MVP goals: +- Parse workflow graph payload from WS session.start metadata +- Track current node +- Evaluate edge conditions on each assistant turn completion +- Provide per-node runtime metadata overrides (prompt/greeting/services) +""" + +from __future__ import annotations + +from dataclasses import dataclass +import json +import re +from typing import Any, Awaitable, Callable, Dict, List, Optional + +from loguru import logger + + +_NODE_TYPE_MAP = { + "conversation": "assistant", + "assistant": "assistant", + "human": "human_transfer", + "human_transfer": "human_transfer", + "tool": "tool", + "end": "end", + "start": "start", +} + + +def _normalize_node_type(raw_type: Any) -> str: + value = str(raw_type or "").strip().lower() + return _NODE_TYPE_MAP.get(value, "assistant") + + +def _safe_str(value: Any) -> str: + if value is None: + return "" + return str(value) + + +def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]: + if not isinstance(raw, dict): + if label: + return {"type": "contains", "source": "user", "value": str(label)} + return {"type": "always"} + + condition = dict(raw) + condition_type = str(condition.get("type", "always")).strip().lower() + if not condition_type: + condition_type = "always" + condition["type"] = condition_type + condition["source"] = str(condition.get("source", "user")).strip().lower() or "user" + return condition + + +@dataclass +class WorkflowNodeDef: + id: str + name: str + node_type: str + is_start: bool + prompt: 
Optional[str] + message_plan: Dict[str, Any] + assistant_id: Optional[str] + assistant: Dict[str, Any] + tool: Optional[Dict[str, Any]] + raw: Dict[str, Any] + + +@dataclass +class WorkflowEdgeDef: + id: str + from_node_id: str + to_node_id: str + label: Optional[str] + condition: Dict[str, Any] + priority: int + order: int + raw: Dict[str, Any] + + +@dataclass +class WorkflowTransition: + edge: WorkflowEdgeDef + node: WorkflowNodeDef + + +LlmRouter = Callable[ + [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]], + Awaitable[Optional[str]], +] + + +class WorkflowRunner: + """In-memory workflow graph for a single active session.""" + + def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]): + self.workflow_id = workflow_id + self.name = name + self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes} + self._edges = edges + self.current_node_id: Optional[str] = None + + @classmethod + def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]: + if not isinstance(payload, dict): + return None + + raw_nodes = payload.get("nodes") + raw_edges = payload.get("edges") + if not isinstance(raw_nodes, list) or len(raw_nodes) == 0: + return None + + nodes: List[WorkflowNodeDef] = [] + for i, raw in enumerate(raw_nodes): + if not isinstance(raw, dict): + continue + + node_id = _safe_str(raw.get("id") or raw.get("name") or f"node_{i + 1}").strip() or f"node_{i + 1}" + node_name = _safe_str(raw.get("name") or node_id).strip() or node_id + node_type = _normalize_node_type(raw.get("type")) + is_start = bool(raw.get("isStart")) or node_type == "start" + + prompt: Optional[str] = None + if "prompt" in raw: + prompt = _safe_str(raw.get("prompt")) + + message_plan = raw.get("messagePlan") + if not isinstance(message_plan, dict): + message_plan = {} + + assistant_cfg = raw.get("assistant") + if not isinstance(assistant_cfg, dict): + assistant_cfg = {} + + tool_cfg = raw.get("tool") + if not 
isinstance(tool_cfg, dict): + tool_cfg = None + + assistant_id = raw.get("assistantId") + if assistant_id is not None: + assistant_id = _safe_str(assistant_id).strip() or None + + nodes.append( + WorkflowNodeDef( + id=node_id, + name=node_name, + node_type=node_type, + is_start=is_start, + prompt=prompt, + message_plan=message_plan, + assistant_id=assistant_id, + assistant=assistant_cfg, + tool=tool_cfg, + raw=raw, + ) + ) + + if not nodes: + return None + + node_ids = {node.id for node in nodes} + edges: List[WorkflowEdgeDef] = [] + for i, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []): + if not isinstance(raw, dict): + continue + + from_node_id = _safe_str( + raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source") + ).strip() + to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip() + if not from_node_id or not to_node_id: + continue + if from_node_id not in node_ids or to_node_id not in node_ids: + continue + + label = raw.get("label") + if label is not None: + label = _safe_str(label) + + condition = _normalize_condition(raw.get("condition"), label=label) + + priority = 100 + try: + priority = int(raw.get("priority", 100)) + except (TypeError, ValueError): + priority = 100 + + edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}" + + edges.append( + WorkflowEdgeDef( + id=edge_id, + from_node_id=from_node_id, + to_node_id=to_node_id, + label=label, + condition=condition, + priority=priority, + order=i, + raw=raw, + ) + ) + + workflow_id = _safe_str(payload.get("id") or "workflow") + workflow_name = _safe_str(payload.get("name") or workflow_id) + return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges) + + def bootstrap(self) -> Optional[WorkflowNodeDef]: + start_node = self._resolve_start_node() + if not start_node: + return None + self.current_node_id = start_node.id + return start_node + + @property + def 
current_node(self) -> Optional[WorkflowNodeDef]: + if not self.current_node_id: + return None + return self._nodes.get(self.current_node_id) + + def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]: + edges = [edge for edge in self._edges if edge.from_node_id == node_id] + return sorted(edges, key=lambda edge: (edge.priority, edge.order)) + + def next_default_transition(self) -> Optional[WorkflowTransition]: + node = self.current_node + if not node: + return None + for edge in self.outgoing_edges(node.id): + cond_type = str(edge.condition.get("type", "always")).strip().lower() + if cond_type in {"", "always", "default"}: + target = self._nodes.get(edge.to_node_id) + if target: + return WorkflowTransition(edge=edge, node=target) + return None + + async def route( + self, + *, + user_text: str, + assistant_text: str, + llm_router: Optional[LlmRouter] = None, + ) -> Optional[WorkflowTransition]: + node = self.current_node + if not node: + return None + + outgoing = self.outgoing_edges(node.id) + if not outgoing: + return None + + llm_edges: List[WorkflowEdgeDef] = [] + for edge in outgoing: + cond_type = str(edge.condition.get("type", "always")).strip().lower() + if cond_type == "llm": + llm_edges.append(edge) + continue + if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text): + target = self._nodes.get(edge.to_node_id) + if target: + return WorkflowTransition(edge=edge, node=target) + + if llm_edges and llm_router: + selection = await llm_router( + node, + llm_edges, + { + "userText": user_text, + "assistantText": assistant_text, + }, + ) + if selection: + for edge in llm_edges: + if selection in {edge.id, edge.to_node_id}: + target = self._nodes.get(edge.to_node_id) + if target: + return WorkflowTransition(edge=edge, node=target) + + for edge in outgoing: + cond_type = str(edge.condition.get("type", "always")).strip().lower() + if cond_type in {"", "always", "default"}: + target = self._nodes.get(edge.to_node_id) + if 
target: + return WorkflowTransition(edge=edge, node=target) + return None + + def apply_transition(self, transition: WorkflowTransition) -> None: + self.current_node_id = transition.node.id + + def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]: + assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {} + message_plan = node.message_plan if isinstance(node.message_plan, dict) else {} + metadata: Dict[str, Any] = {} + + if node.prompt is not None: + metadata["systemPrompt"] = node.prompt + elif "systemPrompt" in assistant_cfg: + metadata["systemPrompt"] = _safe_str(assistant_cfg.get("systemPrompt")) + elif "prompt" in assistant_cfg: + metadata["systemPrompt"] = _safe_str(assistant_cfg.get("prompt")) + + first_message = message_plan.get("firstMessage") + if first_message is not None: + metadata["greeting"] = _safe_str(first_message) + elif "greeting" in assistant_cfg: + metadata["greeting"] = _safe_str(assistant_cfg.get("greeting")) + elif "opener" in assistant_cfg: + metadata["greeting"] = _safe_str(assistant_cfg.get("opener")) + + services = assistant_cfg.get("services") + if isinstance(services, dict): + metadata["services"] = services + + if node.assistant_id: + metadata["assistantId"] = node.assistant_id + + return metadata + + def _resolve_start_node(self) -> Optional[WorkflowNodeDef]: + explicit_start = next((node for node in self._nodes.values() if node.is_start), None) + if not explicit_start: + explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None) + + if explicit_start: + # If a dedicated start node exists, try to move to its first default target. 
+ if explicit_start.node_type == "start": + visited = {explicit_start.id} + current = explicit_start + for _ in range(8): + transition = self._first_default_transition_from(current.id) + if not transition: + return current + current = transition.node + if current.id in visited: + break + visited.add(current.id) + return current + return explicit_start + + assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None) + if assistant_node: + return assistant_node + return next(iter(self._nodes.values()), None) + + def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]: + for edge in self.outgoing_edges(node_id): + cond_type = str(edge.condition.get("type", "always")).strip().lower() + if cond_type in {"", "always", "default"}: + node = self._nodes.get(edge.to_node_id) + if node: + return WorkflowTransition(edge=edge, node=node) + return None + + def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool: + condition = edge.condition or {"type": "always"} + cond_type = str(condition.get("type", "always")).strip().lower() + source = str(condition.get("source", "user")).strip().lower() + + if cond_type in {"", "always", "default"}: + return True + + text = assistant_text if source == "assistant" else user_text + text_lower = (text or "").lower() + + if cond_type == "contains": + values: List[str] = [] + if isinstance(condition.get("values"), list): + values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()] + if not values: + single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower() + if single: + values = [single] + if not values: + return False + return any(value in text_lower for value in values) + + if cond_type == "equals": + expected = _safe_str(condition.get("value") or "").strip().lower() + return bool(expected) and text_lower == expected + + if cond_type == "regex": + 
pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip() + if not pattern: + return False + try: + return bool(re.search(pattern, text or "", re.IGNORECASE)) + except re.error: + logger.warning(f"Invalid workflow regex condition: {pattern}") + return False + + if cond_type == "json": + value = _safe_str(condition.get("value") or "").strip() + if not value: + return False + try: + obj = json.loads(text or "") + except Exception: + return False + return str(obj) == value + + return False diff --git a/web/pages/WorkflowEditor.tsx b/web/pages/WorkflowEditor.tsx index 184c80d..a1b3a4e 100644 --- a/web/pages/WorkflowEditor.tsx +++ b/web/pages/WorkflowEditor.tsx @@ -1,5 +1,5 @@ -import React, { useState, useRef, useEffect } from 'react'; +import React, { useState, useRef, useEffect, useMemo } from 'react'; import { useNavigate, useParams, useSearchParams } from 'react-router-dom'; import { ArrowLeft, Play, Save, Rocket, Plus, Bot, UserCheck, Wrench, Ban, Zap, X, Copy, MousePointer2 } from 'lucide-react'; import { Button, Input, Badge } from '../components/UI'; @@ -7,10 +7,21 @@ import { Assistant, WorkflowNode, WorkflowEdge, Workflow } from '../types'; import { DebugDrawer } from './Assistants'; import { createWorkflow, fetchAssistants, fetchWorkflowById, updateWorkflow } from '../services/backendApi'; +const toWorkflowNodeType = (type: WorkflowNode['type']): WorkflowNode['type'] => { + if (type === 'conversation') return 'assistant'; + if (type === 'human') return 'human_transfer'; + return type; +}; + +const nodeRef = (node: WorkflowNode): string => node.id || node.name; +const edgeFromRef = (edge: WorkflowEdge): string => edge.fromNodeId || edge.from; +const edgeToRef = (edge: WorkflowEdge): string => edge.toNodeId || edge.to; + const getTemplateNodes = (templateType: string | null): WorkflowNode[] => { if (templateType === 'lead') { return [ { + id: 'introduction', name: 'introduction', type: 'conversation', isStart: true, @@ -19,6 +30,7 
@@ const getTemplateNodes = (templateType: string | null): WorkflowNode[] => { messagePlan: { firstMessage: "Hello, this is Morgan from GrowthPartners. Do you have a few minutes to chat?" } }, { + id: 'need_discovery', name: 'need_discovery', type: 'conversation', metadata: { position: { x: 450, y: 250 } }, @@ -31,6 +43,7 @@ const getTemplateNodes = (templateType: string | null): WorkflowNode[] => { } }, { + id: 'hangup_node', name: 'hangup_node', type: 'end', metadata: { position: { x: 450, y: 550 } }, @@ -44,6 +57,7 @@ const getTemplateNodes = (templateType: string | null): WorkflowNode[] => { } return [ { + id: 'start_node', name: 'start_node', type: 'conversation', isStart: true, @@ -80,6 +94,66 @@ export const WorkflowEditorPage: React.FC = () => { const panStart = useRef({ x: 0, y: 0 }); const selectedNode = nodes.find(n => n.name === selectedNodeName); + const selectedNodeRef = selectedNode ? nodeRef(selectedNode) : null; + const outgoingEdges = selectedNodeRef + ? edges.filter((edge) => edgeFromRef(edge) === selectedNodeRef) + : []; + const resolvedEdges = useMemo(() => { + return edges + .map((edge, index) => { + const from = nodes.find((node) => nodeRef(node) === edgeFromRef(edge)); + const to = nodes.find((node) => nodeRef(node) === edgeToRef(edge)); + if (!from || !to) return null; + return { + key: edge.id || `${edgeFromRef(edge)}->${edgeToRef(edge)}:${index}`, + edge, + from, + to, + }; + }) + .filter((item): item is { key: string; edge: WorkflowEdge; from: WorkflowNode; to: WorkflowNode } => Boolean(item)); + }, [edges, nodes]); + + const workflowRuntimeMetadata = useMemo(() => { + return { + id: id || 'draft_workflow', + name, + nodes: nodes.map((node) => { + const assistant = node.assistantId + ? 
assistants.find((item) => item.id === node.assistantId) + : undefined; + return { + id: nodeRef(node), + name: node.name || nodeRef(node), + type: toWorkflowNodeType(node.type), + isStart: node.isStart, + prompt: node.prompt, + messagePlan: node.messagePlan, + assistantId: node.assistantId, + assistant: assistant + ? { + systemPrompt: node.prompt || assistant.prompt || '', + greeting: node.messagePlan?.firstMessage || assistant.opener || '', + } + : undefined, + tool: node.tool, + metadata: node.metadata, + }; + }), + edges: edges.map((edge, index) => { + const fromNodeId = edgeFromRef(edge); + const toNodeId = edgeToRef(edge); + return { + id: edge.id || `edge_${index + 1}`, + fromNodeId, + toNodeId, + label: edge.label, + priority: edge.priority ?? 100, + condition: edge.condition || (edge.label ? { type: 'contains', source: 'user', value: edge.label } : { type: 'always' }), + }; + }), + }; + }, [assistants, edges, id, name, nodes]); // Scroll Zoom handler const handleWheel = (e: React.WheelEvent) => { @@ -172,8 +246,10 @@ export const WorkflowEditorPage: React.FC = () => { }, [id]); const addNode = (type: WorkflowNode['type']) => { + const nodeId = `${type}_${Date.now()}`; const newNode: WorkflowNode = { - name: `${type}_${Date.now()}`, + id: nodeId, + name: nodeId, type, metadata: { position: { x: (300 - panOffset.x) / zoom, y: (300 - panOffset.y) / zoom } }, prompt: type === 'conversation' ? '输入该节点的 Prompt...' : '', @@ -185,7 +261,95 @@ export const WorkflowEditorPage: React.FC = () => { const updateNodeData = (field: string, value: any) => { if (!selectedNodeName) return; - setNodes(prev => prev.map(n => n.name === selectedNodeName ? 
{ ...n, [field]: value } : n)); + setNodes(prev => { + const currentNode = prev.find((n) => n.name === selectedNodeName); + if (!currentNode) return prev; + + const oldRef = nodeRef(currentNode); + const updatedNodes = prev.map((node) => { + if (node.name !== selectedNodeName) { + if (field === 'isStart' && value === true) { + return { ...node, isStart: false }; + } + return node; + } + if (field === 'isStart' && value === true) { + return { ...node, isStart: true }; + } + return { ...node, [field]: value }; + }); + + if (field === 'name') { + const renamed = updatedNodes.find((n) => n.name === value); + const newRef = renamed ? nodeRef(renamed) : String(value); + setEdges((prevEdges) => + prevEdges.map((edge) => { + const from = edgeFromRef(edge); + const to = edgeToRef(edge); + if (from !== oldRef && to !== oldRef) return edge; + const nextFrom = from === oldRef ? newRef : from; + const nextTo = to === oldRef ? newRef : to; + return { + ...edge, + fromNodeId: nextFrom, + toNodeId: nextTo, + from: nextFrom, + to: nextTo, + }; + }) + ); + } + + return updatedNodes; + }); + }; + + const addEdgeFromSelected = () => { + if (!selectedNode) return; + const fromNodeId = nodeRef(selectedNode); + const target = nodes.find((node) => nodeRef(node) !== fromNodeId); + if (!target) return; + const toNodeId = nodeRef(target); + const edgeId = `edge_${Date.now()}`; + setEdges((prev) => [ + ...prev, + { + id: edgeId, + fromNodeId, + toNodeId, + from: fromNodeId, + to: toNodeId, + condition: { type: 'always' }, + }, + ]); + }; + + const updateOutgoingEdge = (edgeId: string, patch: Partial<WorkflowEdge>) => { + setEdges((prev) => + prev.map((edge, index) => { + const idForCompare = edge.id || `${edgeFromRef(edge)}->${edgeToRef(edge)}:${index}`; + if (idForCompare !== edgeId) return edge; + const next = { ...edge, ...patch }; + const fromNodeId = edgeFromRef(next); + const toNodeId = edgeToRef(next); + return { + ...next, + fromNodeId, + toNodeId, + from: fromNodeId, + to: toNodeId, + }; + }) + ); 
+ }; + + const removeOutgoingEdge = (edgeId: string) => { + setEdges((prev) => + prev.filter((edge, index) => { + const idForCompare = edge.id || `${edgeFromRef(edge)}->${edgeToRef(edge)}:${index}`; + return idForCompare !== edgeId; + }) + ); }; const handleSave = async () => { @@ -286,9 +450,29 @@ export const WorkflowEditorPage: React.FC = () => { transformOrigin: '0 0' }} > + + {resolvedEdges.map(({ key, from, to, edge }) => { + const x1 = from.metadata.position.x + 112; + const y1 = from.metadata.position.y + 88; + const x2 = to.metadata.position.x + 112; + const y2 = to.metadata.position.y; + const midY = (y1 + y2) / 2; + const d = `M ${x1} ${y1} C ${x1} ${midY}, ${x2} ${midY}, ${x2} ${y2}`; + return ( + + + {(edge.label || edge.condition?.value) && ( + + {edge.label || edge.condition?.value} + + )} + + ); + })} + {nodes.map(node => (
handleNodeMouseDown(e, node.name)} style={{ left: node.metadata.position.x, top: node.metadata.position.y }} className={`absolute w-56 p-4 rounded-xl border bg-card/70 backdrop-blur-sm cursor-grab active:cursor-grabbing group transition-shadow ${selectedNodeName === node.name ? 'border-primary shadow-[0_0_30px_rgba(6,182,212,0.3)]' : 'border-white/10 hover:border-white/30'}`} @@ -324,7 +508,13 @@ export const WorkflowEditorPage: React.FC = () => { return (
); @@ -369,8 +559,23 @@ export const WorkflowEditorPage: React.FC = () => { {selectedNode.type.toUpperCase()}
- {selectedNode.type === 'conversation' && ( + {(selectedNode.type === 'conversation' || selectedNode.type === 'assistant' || selectedNode.type === 'start') && ( <> +
+ + +