# AI-VideoAssistant/api/app/routers/workflows.py
# Workflow CRUD router (snapshot 2026-02-10 08:12:46 +08:00; 113 lines, 4.0 KiB).
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
import uuid
from datetime import datetime
from typing import Any, Dict, List, Tuple
from ..db import get_db
from ..models import Workflow
from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut, WorkflowNode, WorkflowEdge
# Router exposing CRUD endpoints for workflows under the /workflows prefix.
router = APIRouter(prefix="/workflows", tags=["Workflows"])
def _normalize_graph_payload(nodes: List[Any], edges: List[Any]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Normalize a graph payload to canonical dict structures.

    Accepts nodes/edges as either already-parsed schema instances or raw
    dicts; raw items are validated through the pydantic schemas so the
    returned dicts are safe to store in the JSON columns.

    Args:
        nodes: ``WorkflowNode`` instances or raw dicts describing nodes.
        edges: ``WorkflowEdge`` instances or raw dicts describing edges.

    Returns:
        A ``(node_dicts, edge_dicts)`` tuple of plain dictionaries.

    Raises:
        pydantic.ValidationError: if a raw item fails schema validation.
    """
    # Comprehensions replace the original append loops (same validation path).
    parsed_nodes = [
        node if isinstance(node, WorkflowNode) else WorkflowNode.model_validate(node)
        for node in nodes
    ]
    parsed_edges = [
        edge if isinstance(edge, WorkflowEdge) else WorkflowEdge.model_validate(edge)
        for edge in edges
    ]
    return (
        [node.model_dump() for node in parsed_nodes],
        [edge.model_dump() for edge in parsed_edges],
    )
@router.get("")
def list_workflows(
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List workflows, newest first, with page/limit pagination.

    Returns a dict with ``total`` (row count), the effective ``page`` and
    ``limit``, and ``list`` (the page of Workflow rows).
    """
    # Fix: clamp non-positive paging values — page=0 or negative would
    # produce a negative OFFSET, which errors on several SQL backends.
    page = max(page, 1)
    limit = max(limit, 1)
    query = db.query(Workflow)
    total = query.count()
    workflows = (
        query.order_by(Workflow.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return {"total": total, "page": page, "limit": limit, "list": workflows}
@router.post("", response_model=WorkflowOut)
def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)):
    """Create a workflow from the validated payload and persist it.

    Nodes/edges are normalized to plain dicts before storage; client-supplied
    ``createdAt``/``updatedAt``/``nodeCount`` win over server defaults.
    Returns the refreshed ORM row (serialized via WorkflowOut).
    """
    nodes, edges = _normalize_graph_payload(data.nodes, data.edges)
    # Compute the timestamp once so created_at and updated_at agree.
    now = datetime.utcnow().isoformat()
    workflow = Workflow(
        # NOTE(review): an 8-char uuid4 prefix can collide; fine for small
        # datasets, but consider the full hex if ids must be globally unique.
        id=str(uuid.uuid4())[:8],
        user_id=1,  # TODO: replace hard-coded user once auth is wired in
        name=data.name,
        node_count=data.nodeCount or len(nodes),
        created_at=data.createdAt or now,
        # Fix: default updated_at to the creation timestamp instead of "",
        # matching the ISO strings update_workflow writes on later edits.
        updated_at=data.updatedAt or now,
        global_prompt=data.globalPrompt,
        nodes=nodes,
        edges=edges,
    )
    db.add(workflow)
    db.commit()
    db.refresh(workflow)
    return workflow
@router.get("/{id}", response_model=WorkflowOut)
def get_workflow(id: str, db: Session = Depends(get_db)):
    """Return the workflow with the given id, or raise 404 if absent."""
    workflow = db.query(Workflow).filter_by(id=id).first()
    if workflow is None:
        raise HTTPException(status_code=404, detail="Workflow not found")
    return workflow
@router.put("/{id}", response_model=WorkflowOut)
def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to a workflow; raise 404 if it does not exist.

    Scalar fields set in the payload are written through (camelCase names
    remapped to ORM columns); if nodes and/or edges are supplied, both are
    re-normalized and node_count is recomputed. updated_at is always bumped.
    """
    workflow = db.query(Workflow).filter(Workflow.id == id).first()
    if not workflow:
        raise HTTPException(status_code=404, detail="Workflow not found")
    # camelCase schema fields -> snake_case ORM columns; anything not in the
    # map passes through under its own name.
    column_names = {"nodeCount": "node_count", "globalPrompt": "global_prompt"}
    changes = data.model_dump(exclude_unset=True, exclude={"nodes", "edges"})
    for key, value in changes.items():
        setattr(workflow, column_names.get(key, key), value)
    # Graph payload: when either side is supplied, re-normalize both, using
    # the stored value for whichever side was omitted.
    if not (data.nodes is None and data.edges is None):
        stored_nodes = workflow.nodes if isinstance(workflow.nodes, list) else []
        stored_edges = workflow.edges if isinstance(workflow.edges, list) else []
        raw_nodes = stored_nodes if data.nodes is None else data.nodes
        raw_edges = stored_edges if data.edges is None else data.edges
        norm_nodes, norm_edges = _normalize_graph_payload(raw_nodes, raw_edges)
        workflow.nodes = norm_nodes
        workflow.edges = norm_edges
        # Recomputed count overrides any nodeCount written above.
        workflow.node_count = len(norm_nodes)
    workflow.updated_at = datetime.utcnow().isoformat()
    db.commit()
    db.refresh(workflow)
    return workflow
@router.delete("/{id}")
def delete_workflow(id: str, db: Session = Depends(get_db)):
    """Delete the workflow with the given id; raise 404 if absent."""
    workflow = db.query(Workflow).filter_by(id=id).first()
    if workflow is None:
        raise HTTPException(status_code=404, detail="Workflow not found")
    db.delete(workflow)
    db.commit()
    return {"message": "Deleted successfully"}