Update workflow feature with codex
This commit is contained in:
@@ -2,14 +2,30 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from ..db import get_db
|
||||
from ..models import Workflow
|
||||
from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut
|
||||
from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut, WorkflowNode, WorkflowEdge
|
||||
|
||||
router = APIRouter(prefix="/workflows", tags=["Workflows"])
|
||||
|
||||
|
||||
def _normalize_graph_payload(nodes: List[Any], edges: List[Any]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Normalize graph payload to canonical dict structures.

    Each entry may already be a ``WorkflowNode``/``WorkflowEdge`` instance or a
    raw dict; raw dicts are run through ``model_validate`` (which also applies
    the schemas' legacy-format normalization) before being dumped back to plain
    dicts suitable for JSON-column storage.

    Raises:
        pydantic.ValidationError: if an entry cannot be validated.
    """
    # Comprehensions instead of append-loops: same validation path, less noise.
    parsed_nodes = [
        node if isinstance(node, WorkflowNode) else WorkflowNode.model_validate(node)
        for node in nodes
    ]
    parsed_edges = [
        edge if isinstance(edge, WorkflowEdge) else WorkflowEdge.model_validate(edge)
        for edge in edges
    ]
    return (
        [node.model_dump() for node in parsed_nodes],
        [edge.model_dump() for edge in parsed_edges],
    )
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_workflows(
|
||||
page: int = 1,
|
||||
@@ -27,16 +43,17 @@ def list_workflows(
|
||||
@router.post("", response_model=WorkflowOut)
|
||||
def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)):
|
||||
"""创建工作流"""
|
||||
nodes, edges = _normalize_graph_payload(data.nodes, data.edges)
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
user_id=1,
|
||||
name=data.name,
|
||||
node_count=data.nodeCount,
|
||||
node_count=data.nodeCount or len(nodes),
|
||||
created_at=data.createdAt or datetime.utcnow().isoformat(),
|
||||
updated_at=data.updatedAt or "",
|
||||
global_prompt=data.globalPrompt,
|
||||
nodes=data.nodes,
|
||||
edges=data.edges,
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
)
|
||||
db.add(workflow)
|
||||
db.commit()
|
||||
@@ -60,7 +77,7 @@ def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
update_data = data.model_dump(exclude_unset=True, exclude={"nodes", "edges"})
|
||||
field_map = {
|
||||
"nodeCount": "node_count",
|
||||
"globalPrompt": "global_prompt",
|
||||
@@ -68,6 +85,16 @@ def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)
|
||||
for field, value in update_data.items():
|
||||
setattr(workflow, field_map.get(field, field), value)
|
||||
|
||||
if data.nodes is not None or data.edges is not None:
|
||||
existing_nodes = workflow.nodes if isinstance(workflow.nodes, list) else []
|
||||
existing_edges = workflow.edges if isinstance(workflow.edges, list) else []
|
||||
input_nodes = data.nodes if data.nodes is not None else existing_nodes
|
||||
input_edges = data.edges if data.edges is not None else existing_edges
|
||||
nodes, edges = _normalize_graph_payload(input_nodes, input_edges)
|
||||
workflow.nodes = nodes
|
||||
workflow.edges = edges
|
||||
workflow.node_count = len(nodes)
|
||||
|
||||
workflow.updated_at = datetime.utcnow().isoformat()
|
||||
db.commit()
|
||||
db.refresh(workflow)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
# ============ Enums ============
|
||||
@@ -410,24 +410,82 @@ class KnowledgeStats(BaseModel):
|
||||
|
||||
# ============ Workflow ============
|
||||
class WorkflowNode(BaseModel):
    """Canonical workflow graph node.

    Extra keys are preserved (``extra="allow"``) so round-tripping legacy
    payloads does not silently drop data.
    """

    model_config = ConfigDict(extra="allow")

    id: Optional[str] = None
    name: str = ""
    type: str = "assistant"
    isStart: Optional[bool] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    prompt: Optional[str] = None
    messagePlan: Optional[Dict[str, Any]] = None
    variableExtractionPlan: Optional[Dict[str, Any]] = None
    tool: Optional[Dict[str, Any]] = None
    globalNodePlan: Optional[Dict[str, Any]] = None
    assistantId: Optional[str] = None
    assistant: Optional[Dict[str, Any]] = None

    @model_validator(mode="before")
    @classmethod
    def _normalize_legacy_node(cls, data: Any) -> Any:
        """Coerce legacy node payloads into the canonical shape.

        Fills in id/name, maps legacy type aliases, hoists a top-level
        ``position`` into ``metadata`` and infers ``isStart`` for start nodes.
        """
        if not isinstance(data, dict):
            return data
        raw = dict(data)

        # Fall back to the name, then a content digest, for the node id.
        node_id = raw.get("id") or raw.get("name")
        if not node_id:
            # hashlib rather than hash(): str hashing is randomized per
            # interpreter process, which would make generated ids unstable
            # across restarts.
            import hashlib

            digest = hashlib.md5(
                repr(sorted(raw.items(), key=lambda kv: str(kv[0]))).encode("utf-8")
            ).hexdigest()
            node_id = f"node_{int(digest[:8], 16) % 100000}"
        raw["id"] = str(node_id)
        raw["name"] = str(raw.get("name") or raw["id"])

        # Map legacy type aliases onto the canonical vocabulary; anything
        # unknown degrades to a plain assistant node.
        node_type = str(raw.get("type") or "assistant").lower()
        if node_type == "conversation":
            node_type = "assistant"
        elif node_type == "human":
            node_type = "human_transfer"
        elif node_type not in {"start", "assistant", "tool", "human_transfer", "end"}:
            node_type = "assistant"
        raw["type"] = node_type

        # Hoist a top-level legacy "position" into metadata without
        # clobbering an existing metadata.position.
        metadata = raw.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
        if "position" not in metadata and isinstance(raw.get("position"), dict):
            metadata["position"] = raw.get("position")
        raw["metadata"] = metadata

        if raw.get("isStart") is None and node_type == "start":
            raw["isStart"] = True
        return raw
|
||||
|
||||
|
||||
class WorkflowEdge(BaseModel):
    """Canonical directed workflow edge with an optional transition condition.

    Extra keys are preserved (``extra="allow"``) so legacy payload data
    survives a round trip.
    """

    model_config = ConfigDict(extra="allow")

    id: Optional[str] = None
    fromNodeId: str
    toNodeId: str
    label: Optional[str] = None
    condition: Optional[Dict[str, Any]] = None
    priority: int = 100

    @model_validator(mode="before")
    @classmethod
    def _normalize_legacy_edge(cls, data: Any) -> Any:
        """Coerce legacy edge payloads into canonical keys.

        Accepts historical endpoint spellings (``from``/``from_``/``source``
        and ``to``/``target``), synthesizes a deterministic id, and derives a
        condition from the label when none is given.
        """
        if not isinstance(data, dict):
            return data
        raw = dict(data)
        from_node = raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
        to_node = raw.get("toNodeId") or raw.get("to") or raw.get("target")
        raw["fromNodeId"] = str(from_node or "")
        raw["toNodeId"] = str(to_node or "")
        if raw.get("id") is None:
            raw["id"] = f"e_{raw['fromNodeId']}_{raw['toNodeId']}"
        if raw.get("condition") is None:
            # A labeled legacy edge becomes a contains-match on the user turn;
            # an unlabeled one always fires.
            if raw.get("label"):
                raw["condition"] = {"type": "contains", "source": "user", "value": str(raw["label"])}
            else:
                raw["condition"] = {"type": "always"}
        return raw
|
||||
|
||||
|
||||
class WorkflowBase(BaseModel):
|
||||
@@ -436,29 +494,85 @@ class WorkflowBase(BaseModel):
|
||||
createdAt: str = ""
|
||||
updatedAt: str = ""
|
||||
globalPrompt: Optional[str] = None
|
||||
nodes: List[dict] = []
|
||||
edges: List[dict] = []
|
||||
nodes: List[WorkflowNode] = Field(default_factory=list)
|
||||
edges: List[WorkflowEdge] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowCreate(WorkflowBase):
    """Payload for creating a workflow.

    Runs full graph validation after field validation so an incomplete or
    inconsistent graph is rejected (422) before anything is persisted.
    """

    @model_validator(mode="after")
    def _validate_graph(self) -> "WorkflowCreate":
        _validate_workflow_graph(self.nodes, self.edges)
        return self
|
||||
|
||||
|
||||
class WorkflowUpdate(BaseModel):
    """Partial workflow update; only provided fields are applied by the router."""

    name: Optional[str] = None
    nodeCount: Optional[int] = None
    nodes: Optional[List[WorkflowNode]] = None
    edges: Optional[List[WorkflowEdge]] = None
    globalPrompt: Optional[str] = None

    @model_validator(mode="after")
    def _validate_partial_graph(self) -> "WorkflowUpdate":
        # Full graph validation is only possible when both sides are present;
        # a one-sided update is merged with the stored graph by the router,
        # which re-validates via _normalize_graph_payload.
        if self.nodes is not None and self.edges is not None:
            _validate_workflow_graph(self.nodes, self.edges)
        return self
|
||||
|
||||
|
||||
class WorkflowOut(WorkflowBase):
    """API representation of a stored workflow.

    Accepts either a dict or an ORM object and maps snake_case DB fields
    onto the camelCase schema fields.
    """

    # pydantic v2 style, consistent with the other models in this module;
    # replaces the deprecated nested `class Config`.
    model_config = ConfigDict(from_attributes=True)

    id: str

    @model_validator(mode="before")
    @classmethod
    def _normalize_db_fields(cls, data: Any) -> Any:
        """Copy snake_case ORM/db keys to their camelCase counterparts."""
        if isinstance(data, dict):
            raw = dict(data)
        else:
            # ORM object: pull the known columns by attribute.
            raw = {
                "id": getattr(data, "id", None),
                "name": getattr(data, "name", None),
                "node_count": getattr(data, "node_count", None),
                "created_at": getattr(data, "created_at", None),
                "updated_at": getattr(data, "updated_at", None),
                "global_prompt": getattr(data, "global_prompt", None),
                "nodes": getattr(data, "nodes", None),
                "edges": getattr(data, "edges", None),
            }

        # An explicit camelCase key from the caller always wins.
        for camel, snake in (
            ("nodeCount", "node_count"),
            ("createdAt", "created_at"),
            ("updatedAt", "updated_at"),
            ("globalPrompt", "global_prompt"),
        ):
            if camel not in raw and raw.get(snake) is not None:
                raw[camel] = raw[snake]
        return raw
|
||||
|
||||
|
||||
def _validate_workflow_graph(nodes: List[WorkflowNode], edges: List[WorkflowEdge]) -> None:
|
||||
if not nodes:
|
||||
raise ValueError("Workflow must include at least one node")
|
||||
|
||||
node_ids = [node.id for node in nodes if node.id]
|
||||
if len(node_ids) != len(set(node_ids)):
|
||||
raise ValueError("Workflow node ids must be unique")
|
||||
|
||||
starts = [node for node in nodes if node.isStart or node.type == "start"]
|
||||
if not starts:
|
||||
raise ValueError("Workflow must define a start node (isStart=true or type=start)")
|
||||
|
||||
known = set(node_ids)
|
||||
for edge in edges:
|
||||
if edge.fromNodeId not in known:
|
||||
raise ValueError(f"Workflow edge fromNodeId not found: {edge.fromNodeId}")
|
||||
if edge.toNodeId not in known:
|
||||
raise ValueError(f"Workflow edge toNodeId not found: {edge.toNodeId}")
|
||||
|
||||
|
||||
# ============ Call Record ============
|
||||
class TranscriptSegment(BaseModel):
|
||||
turnIndex: int
|
||||
|
||||
167
api/tests/test_workflows.py
Normal file
167
api/tests/test_workflows.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Tests for workflow graph schema and router behavior."""
|
||||
|
||||
|
||||
class TestWorkflowAPI:
    """Workflow CRUD and graph validation test cases."""

    def _minimal_nodes(self):
        # Smallest valid graph: a start node plus one assistant node.
        start_node = {
            "id": "start_1",
            "name": "start_1",
            "type": "start",
            "isStart": True,
            "metadata": {"position": {"x": 80, "y": 80}},
        }
        assistant_node = {
            "id": "assistant_1",
            "name": "assistant_1",
            "type": "assistant",
            "metadata": {"position": {"x": 280, "y": 80}},
            "prompt": "You are the first assistant node.",
        }
        return [start_node, assistant_node]

    def test_create_workflow_with_canonical_graph(self, client):
        # Canonical payload should be stored and echoed back verbatim.
        body = {
            "name": "Canonical Graph",
            "nodes": self._minimal_nodes(),
            "edges": [
                {
                    "id": "edge_start_assistant",
                    "fromNodeId": "start_1",
                    "toNodeId": "assistant_1",
                    "condition": {"type": "always"},
                }
            ],
        }

        response = client.post("/api/workflows", json=body)
        assert response.status_code == 200
        result = response.json()
        assert result["name"] == "Canonical Graph"
        assert result["nodeCount"] == 2
        assert result["nodes"][0]["id"] == "start_1"
        assert result["edges"][0]["fromNodeId"] == "start_1"
        assert result["edges"][0]["toNodeId"] == "assistant_1"

    def test_create_workflow_with_legacy_graph(self, client):
        # Legacy shapes (conversation/human types, from/to endpoints, label)
        # must be normalized into the canonical vocabulary.
        body = {
            "name": "Legacy Graph",
            "nodes": [
                {
                    "name": "legacy_start",
                    "type": "conversation",
                    "isStart": True,
                    "metadata": {"position": {"x": 100, "y": 100}},
                },
                {
                    "name": "legacy_human",
                    "type": "human",
                    "metadata": {"position": {"x": 300, "y": 100}},
                },
            ],
            "edges": [
                {
                    "from": "legacy_start",
                    "to": "legacy_human",
                    "label": "人工",
                }
            ],
        }

        response = client.post("/api/workflows", json=body)
        assert response.status_code == 200
        result = response.json()
        assert result["nodes"][0]["type"] == "assistant"
        assert result["nodes"][1]["type"] == "human_transfer"
        assert result["edges"][0]["fromNodeId"] == "legacy_start"
        assert result["edges"][0]["toNodeId"] == "legacy_human"
        assert result["edges"][0]["condition"]["type"] == "contains"

    def test_create_workflow_without_start_node_fails(self, client):
        # A graph with no start node is rejected at validation time.
        body = {
            "name": "No Start",
            "nodes": [
                {"id": "node_1", "name": "node_1", "type": "assistant", "metadata": {"position": {"x": 0, "y": 0}}},
            ],
            "edges": [],
        }
        response = client.post("/api/workflows", json=body)
        assert response.status_code == 422

    def test_create_workflow_with_invalid_edge_fails(self, client):
        # An edge pointing at an unknown node id is rejected.
        body = {
            "name": "Bad Edge",
            "nodes": self._minimal_nodes(),
            "edges": [
                {"id": "edge_bad", "fromNodeId": "missing", "toNodeId": "assistant_1", "condition": {"type": "always"}},
            ],
        }
        response = client.post("/api/workflows", json=body)
        assert response.status_code == 422

    def test_update_workflow_nodes_and_edges(self, client):
        # Create a workflow, then replace its whole graph via PUT.
        initial_body = {
            "name": "Before Update",
            "nodes": self._minimal_nodes(),
            "edges": [
                {
                    "id": "edge_start_assistant",
                    "fromNodeId": "start_1",
                    "toNodeId": "assistant_1",
                    "condition": {"type": "always"},
                }
            ],
        }
        created = client.post("/api/workflows", json=initial_body)
        assert created.status_code == 200
        workflow_id = created.json()["id"]

        new_nodes = [
            {
                "id": "start_1",
                "name": "start_1",
                "type": "start",
                "isStart": True,
                "metadata": {"position": {"x": 50, "y": 50}},
            },
            {
                "id": "assistant_2",
                "name": "assistant_2",
                "type": "assistant",
                "metadata": {"position": {"x": 250, "y": 50}},
                "prompt": "new prompt",
            },
            {
                "id": "end_1",
                "name": "end_1",
                "type": "end",
                "metadata": {"position": {"x": 450, "y": 50}},
            },
        ]
        new_edges = [
            {
                "id": "edge_start_assistant2",
                "fromNodeId": "start_1",
                "toNodeId": "assistant_2",
                "condition": {"type": "always"},
            },
            {
                "id": "edge_assistant2_end",
                "fromNodeId": "assistant_2",
                "toNodeId": "end_1",
                "condition": {"type": "contains", "source": "user", "value": "结束"},
            },
        ]
        replacement = {"name": "After Update", "nodes": new_nodes, "edges": new_edges}

        response = client.put(f"/api/workflows/{workflow_id}", json=replacement)
        assert response.status_code == 200
        updated = response.json()
        assert updated["name"] == "After Update"
        assert updated["nodeCount"] == 3
        assert len(updated["nodes"]) == 3
        assert len(updated["edges"]) == 2
|
||||
Reference in New Issue
Block a user