Update workflow feature with codex

This commit is contained in:
Xin Wang
2026-02-10 08:12:46 +08:00
parent 6b4391c423
commit bbeffa89ed
8 changed files with 1334 additions and 39 deletions

View File

@@ -2,14 +2,30 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
import uuid
from datetime import datetime
from typing import Any, Dict, List, Tuple
from ..db import get_db
from ..models import Workflow
from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut
from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut, WorkflowNode, WorkflowEdge
router = APIRouter(prefix="/workflows", tags=["Workflows"])
def _normalize_graph_payload(nodes: List[Any], edges: List[Any]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Normalize graph payload to canonical dict structures."""
parsed_nodes: List[WorkflowNode] = []
for node in nodes:
parsed_nodes.append(node if isinstance(node, WorkflowNode) else WorkflowNode.model_validate(node))
parsed_edges: List[WorkflowEdge] = []
for edge in edges:
parsed_edges.append(edge if isinstance(edge, WorkflowEdge) else WorkflowEdge.model_validate(edge))
normalized_nodes = [node.model_dump() for node in parsed_nodes]
normalized_edges = [edge.model_dump() for edge in parsed_edges]
return normalized_nodes, normalized_edges
@router.get("")
def list_workflows(
page: int = 1,
@@ -27,16 +43,17 @@ def list_workflows(
@router.post("", response_model=WorkflowOut)
def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)):
"""创建工作流"""
nodes, edges = _normalize_graph_payload(data.nodes, data.edges)
workflow = Workflow(
id=str(uuid.uuid4())[:8],
user_id=1,
name=data.name,
node_count=data.nodeCount,
node_count=data.nodeCount or len(nodes),
created_at=data.createdAt or datetime.utcnow().isoformat(),
updated_at=data.updatedAt or "",
global_prompt=data.globalPrompt,
nodes=data.nodes,
edges=data.edges,
nodes=nodes,
edges=edges,
)
db.add(workflow)
db.commit()
@@ -60,7 +77,7 @@ def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
update_data = data.model_dump(exclude_unset=True)
update_data = data.model_dump(exclude_unset=True, exclude={"nodes", "edges"})
field_map = {
"nodeCount": "node_count",
"globalPrompt": "global_prompt",
@@ -68,6 +85,16 @@ def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)
for field, value in update_data.items():
setattr(workflow, field_map.get(field, field), value)
if data.nodes is not None or data.edges is not None:
existing_nodes = workflow.nodes if isinstance(workflow.nodes, list) else []
existing_edges = workflow.edges if isinstance(workflow.edges, list) else []
input_nodes = data.nodes if data.nodes is not None else existing_nodes
input_edges = data.edges if data.edges is not None else existing_edges
nodes, edges = _normalize_graph_payload(input_nodes, input_edges)
workflow.nodes = nodes
workflow.edges = edges
workflow.node_count = len(nodes)
workflow.updated_at = datetime.utcnow().isoformat()
db.commit()
db.refresh(workflow)

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
# ============ Enums ============
@@ -410,24 +410,82 @@ class KnowledgeStats(BaseModel):
# ============ Workflow ============
class WorkflowNode(BaseModel):
name: str
type: str
model_config = ConfigDict(extra="allow")
id: Optional[str] = None
name: str = ""
type: str = "assistant"
isStart: Optional[bool] = None
metadata: dict
metadata: Dict[str, Any] = Field(default_factory=dict)
prompt: Optional[str] = None
messagePlan: Optional[dict] = None
variableExtractionPlan: Optional[dict] = None
tool: Optional[dict] = None
globalNodePlan: Optional[dict] = None
messagePlan: Optional[Dict[str, Any]] = None
variableExtractionPlan: Optional[Dict[str, Any]] = None
tool: Optional[Dict[str, Any]] = None
globalNodePlan: Optional[Dict[str, Any]] = None
assistantId: Optional[str] = None
assistant: Optional[Dict[str, Any]] = None
@model_validator(mode="before")
@classmethod
def _normalize_legacy_node(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
raw = dict(data)
node_id = raw.get("id") or raw.get("name")
if not node_id:
node_id = f"node_{abs(hash(str(raw))) % 100000}"
raw["id"] = str(node_id)
raw["name"] = str(raw.get("name") or raw["id"])
node_type = str(raw.get("type") or "assistant").lower()
if node_type == "conversation":
node_type = "assistant"
elif node_type == "human":
node_type = "human_transfer"
elif node_type not in {"start", "assistant", "tool", "human_transfer", "end"}:
node_type = "assistant"
raw["type"] = node_type
metadata = raw.get("metadata")
if not isinstance(metadata, dict):
metadata = {}
if "position" not in metadata and isinstance(raw.get("position"), dict):
metadata["position"] = raw.get("position")
raw["metadata"] = metadata
if raw.get("isStart") is None and node_type == "start":
raw["isStart"] = True
return raw
class WorkflowEdge(BaseModel):
from_: str
to: str
label: Optional[str] = None
model_config = ConfigDict(extra="allow")
class Config:
populate_by_name = True
id: Optional[str] = None
fromNodeId: str
toNodeId: str
label: Optional[str] = None
condition: Optional[Dict[str, Any]] = None
priority: int = 100
@model_validator(mode="before")
@classmethod
def _normalize_legacy_edge(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
raw = dict(data)
from_node = raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
to_node = raw.get("toNodeId") or raw.get("to") or raw.get("target")
raw["fromNodeId"] = str(from_node or "")
raw["toNodeId"] = str(to_node or "")
if raw.get("id") is None:
raw["id"] = f"e_{raw['fromNodeId']}_{raw['toNodeId']}"
if raw.get("condition") is None:
if raw.get("label"):
raw["condition"] = {"type": "contains", "source": "user", "value": str(raw["label"])}
else:
raw["condition"] = {"type": "always"}
return raw
class WorkflowBase(BaseModel):
@@ -436,29 +494,85 @@ class WorkflowBase(BaseModel):
createdAt: str = ""
updatedAt: str = ""
globalPrompt: Optional[str] = None
nodes: List[dict] = []
edges: List[dict] = []
nodes: List[WorkflowNode] = Field(default_factory=list)
edges: List[WorkflowEdge] = Field(default_factory=list)
class WorkflowCreate(WorkflowBase):
pass
@model_validator(mode="after")
def _validate_graph(self) -> "WorkflowCreate":
_validate_workflow_graph(self.nodes, self.edges)
return self
class WorkflowUpdate(BaseModel):
name: Optional[str] = None
nodeCount: Optional[int] = None
nodes: Optional[List[dict]] = None
edges: Optional[List[dict]] = None
nodes: Optional[List[WorkflowNode]] = None
edges: Optional[List[WorkflowEdge]] = None
globalPrompt: Optional[str] = None
@model_validator(mode="after")
def _validate_partial_graph(self) -> "WorkflowUpdate":
if self.nodes is not None and self.edges is not None:
_validate_workflow_graph(self.nodes, self.edges)
return self
class WorkflowOut(WorkflowBase):
id: str
@model_validator(mode="before")
@classmethod
def _normalize_db_fields(cls, data: Any) -> Any:
if isinstance(data, dict):
raw = dict(data)
else:
raw = {
"id": getattr(data, "id", None),
"name": getattr(data, "name", None),
"node_count": getattr(data, "node_count", None),
"created_at": getattr(data, "created_at", None),
"updated_at": getattr(data, "updated_at", None),
"global_prompt": getattr(data, "global_prompt", None),
"nodes": getattr(data, "nodes", None),
"edges": getattr(data, "edges", None),
}
if "nodeCount" not in raw and raw.get("node_count") is not None:
raw["nodeCount"] = raw["node_count"]
if "createdAt" not in raw and raw.get("created_at") is not None:
raw["createdAt"] = raw["created_at"]
if "updatedAt" not in raw and raw.get("updated_at") is not None:
raw["updatedAt"] = raw["updated_at"]
if "globalPrompt" not in raw and raw.get("global_prompt") is not None:
raw["globalPrompt"] = raw["global_prompt"]
return raw
class Config:
from_attributes = True
def _validate_workflow_graph(nodes: List[WorkflowNode], edges: List[WorkflowEdge]) -> None:
if not nodes:
raise ValueError("Workflow must include at least one node")
node_ids = [node.id for node in nodes if node.id]
if len(node_ids) != len(set(node_ids)):
raise ValueError("Workflow node ids must be unique")
starts = [node for node in nodes if node.isStart or node.type == "start"]
if not starts:
raise ValueError("Workflow must define a start node (isStart=true or type=start)")
known = set(node_ids)
for edge in edges:
if edge.fromNodeId not in known:
raise ValueError(f"Workflow edge fromNodeId not found: {edge.fromNodeId}")
if edge.toNodeId not in known:
raise ValueError(f"Workflow edge toNodeId not found: {edge.toNodeId}")
# ============ Call Record ============
class TranscriptSegment(BaseModel):
turnIndex: int

167
api/tests/test_workflows.py Normal file
View File

@@ -0,0 +1,167 @@
"""Tests for workflow graph schema and router behavior."""
class TestWorkflowAPI:
"""Workflow CRUD and graph validation test cases."""
def _minimal_nodes(self):
return [
{
"id": "start_1",
"name": "start_1",
"type": "start",
"isStart": True,
"metadata": {"position": {"x": 80, "y": 80}},
},
{
"id": "assistant_1",
"name": "assistant_1",
"type": "assistant",
"metadata": {"position": {"x": 280, "y": 80}},
"prompt": "You are the first assistant node.",
},
]
def test_create_workflow_with_canonical_graph(self, client):
payload = {
"name": "Canonical Graph",
"nodes": self._minimal_nodes(),
"edges": [
{
"id": "edge_start_assistant",
"fromNodeId": "start_1",
"toNodeId": "assistant_1",
"condition": {"type": "always"},
}
],
}
resp = client.post("/api/workflows", json=payload)
assert resp.status_code == 200
data = resp.json()
assert data["name"] == "Canonical Graph"
assert data["nodeCount"] == 2
assert data["nodes"][0]["id"] == "start_1"
assert data["edges"][0]["fromNodeId"] == "start_1"
assert data["edges"][0]["toNodeId"] == "assistant_1"
def test_create_workflow_with_legacy_graph(self, client):
payload = {
"name": "Legacy Graph",
"nodes": [
{
"name": "legacy_start",
"type": "conversation",
"isStart": True,
"metadata": {"position": {"x": 100, "y": 100}},
},
{
"name": "legacy_human",
"type": "human",
"metadata": {"position": {"x": 300, "y": 100}},
},
],
"edges": [
{
"from": "legacy_start",
"to": "legacy_human",
"label": "人工",
}
],
}
resp = client.post("/api/workflows", json=payload)
assert resp.status_code == 200
data = resp.json()
assert data["nodes"][0]["type"] == "assistant"
assert data["nodes"][1]["type"] == "human_transfer"
assert data["edges"][0]["fromNodeId"] == "legacy_start"
assert data["edges"][0]["toNodeId"] == "legacy_human"
assert data["edges"][0]["condition"]["type"] == "contains"
def test_create_workflow_without_start_node_fails(self, client):
payload = {
"name": "No Start",
"nodes": [
{"id": "node_1", "name": "node_1", "type": "assistant", "metadata": {"position": {"x": 0, "y": 0}}},
],
"edges": [],
}
resp = client.post("/api/workflows", json=payload)
assert resp.status_code == 422
def test_create_workflow_with_invalid_edge_fails(self, client):
payload = {
"name": "Bad Edge",
"nodes": self._minimal_nodes(),
"edges": [
{"id": "edge_bad", "fromNodeId": "missing", "toNodeId": "assistant_1", "condition": {"type": "always"}},
],
}
resp = client.post("/api/workflows", json=payload)
assert resp.status_code == 422
def test_update_workflow_nodes_and_edges(self, client):
create_payload = {
"name": "Before Update",
"nodes": self._minimal_nodes(),
"edges": [
{
"id": "edge_start_assistant",
"fromNodeId": "start_1",
"toNodeId": "assistant_1",
"condition": {"type": "always"},
}
],
}
create_resp = client.post("/api/workflows", json=create_payload)
assert create_resp.status_code == 200
workflow_id = create_resp.json()["id"]
update_payload = {
"name": "After Update",
"nodes": [
{
"id": "start_1",
"name": "start_1",
"type": "start",
"isStart": True,
"metadata": {"position": {"x": 50, "y": 50}},
},
{
"id": "assistant_2",
"name": "assistant_2",
"type": "assistant",
"metadata": {"position": {"x": 250, "y": 50}},
"prompt": "new prompt",
},
{
"id": "end_1",
"name": "end_1",
"type": "end",
"metadata": {"position": {"x": 450, "y": 50}},
},
],
"edges": [
{
"id": "edge_start_assistant2",
"fromNodeId": "start_1",
"toNodeId": "assistant_2",
"condition": {"type": "always"},
},
{
"id": "edge_assistant2_end",
"fromNodeId": "assistant_2",
"toNodeId": "end_1",
"condition": {"type": "contains", "source": "user", "value": "结束"},
},
],
}
update_resp = client.put(f"/api/workflows/{workflow_id}", json=update_payload)
assert update_resp.status_code == 200
updated = update_resp.json()
assert updated["name"] == "After Update"
assert updated["nodeCount"] == 3
assert len(updated["nodes"]) == 3
assert len(updated["edges"]) == 2