Compare commits

...

3 Commits

Author SHA1 Message Date
Xin Wang
bbeffa89ed Update workflow feature with codex 2026-02-10 08:12:46 +08:00
Xin Wang
6b4391c423 Implement KB features with codex 2026-02-10 07:35:08 +08:00
Xin Wang
ed1f7fc8b0 TTS model from select to input 2026-02-10 00:27:59 +08:00
14 changed files with 1556 additions and 55 deletions

View File

@@ -68,6 +68,14 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
}
warnings.append(f"Voice resource not found: {assistant.voice}")
if assistant.knowledge_base_id:
metadata["knowledgeBaseId"] = assistant.knowledge_base_id
metadata["knowledge"] = {
"enabled": True,
"kbId": assistant.knowledge_base_id,
"nResults": 5,
}
return {
"assistantId": assistant.id,
"sessionStartMetadata": metadata,
@@ -75,6 +83,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
"llmModelId": assistant.llm_model_id,
"asrModelId": assistant.asr_model_id,
"voiceId": assistant.voice,
"knowledgeBaseId": assistant.knowledge_base_id,
},
"warnings": warnings,
}

View File

@@ -2,14 +2,30 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
import uuid
from datetime import datetime
from typing import Any, Dict, List, Tuple
from ..db import get_db
from ..models import Workflow
from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut
from ..schemas import WorkflowCreate, WorkflowUpdate, WorkflowOut, WorkflowNode, WorkflowEdge
router = APIRouter(prefix="/workflows", tags=["Workflows"])
def _normalize_graph_payload(nodes: List[Any], edges: List[Any]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Normalize graph payload to canonical dict structures.

    Accepts either already-parsed schema instances or raw dicts for both
    nodes and edges, validates them through the schema models, and returns
    plain dict representations suitable for JSON-column storage.
    """
    def _to_dicts(items: List[Any], model: Any) -> List[Dict[str, Any]]:
        # Pass through existing instances; validate raw payloads.
        validated = [
            item if isinstance(item, model) else model.model_validate(item)
            for item in items
        ]
        return [entry.model_dump() for entry in validated]

    return _to_dicts(nodes, WorkflowNode), _to_dicts(edges, WorkflowEdge)
@router.get("")
def list_workflows(
page: int = 1,
@@ -27,16 +43,17 @@ def list_workflows(
@router.post("", response_model=WorkflowOut)
def create_workflow(data: WorkflowCreate, db: Session = Depends(get_db)):
"""创建工作流"""
nodes, edges = _normalize_graph_payload(data.nodes, data.edges)
workflow = Workflow(
id=str(uuid.uuid4())[:8],
user_id=1,
name=data.name,
node_count=data.nodeCount,
node_count=data.nodeCount or len(nodes),
created_at=data.createdAt or datetime.utcnow().isoformat(),
updated_at=data.updatedAt or "",
global_prompt=data.globalPrompt,
nodes=data.nodes,
edges=data.edges,
nodes=nodes,
edges=edges,
)
db.add(workflow)
db.commit()
@@ -60,7 +77,7 @@ def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
update_data = data.model_dump(exclude_unset=True)
update_data = data.model_dump(exclude_unset=True, exclude={"nodes", "edges"})
field_map = {
"nodeCount": "node_count",
"globalPrompt": "global_prompt",
@@ -68,6 +85,16 @@ def update_workflow(id: str, data: WorkflowUpdate, db: Session = Depends(get_db)
for field, value in update_data.items():
setattr(workflow, field_map.get(field, field), value)
if data.nodes is not None or data.edges is not None:
existing_nodes = workflow.nodes if isinstance(workflow.nodes, list) else []
existing_edges = workflow.edges if isinstance(workflow.edges, list) else []
input_nodes = data.nodes if data.nodes is not None else existing_nodes
input_edges = data.edges if data.edges is not None else existing_edges
nodes, edges = _normalize_graph_payload(input_nodes, input_edges)
workflow.nodes = nodes
workflow.edges = edges
workflow.node_count = len(nodes)
workflow.updated_at = datetime.utcnow().isoformat()
db.commit()
db.refresh(workflow)

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
# ============ Enums ============
@@ -410,24 +410,82 @@ class KnowledgeStats(BaseModel):
# ============ Workflow ============
class WorkflowNode(BaseModel):
name: str
type: str
model_config = ConfigDict(extra="allow")
id: Optional[str] = None
name: str = ""
type: str = "assistant"
isStart: Optional[bool] = None
metadata: dict
metadata: Dict[str, Any] = Field(default_factory=dict)
prompt: Optional[str] = None
messagePlan: Optional[dict] = None
variableExtractionPlan: Optional[dict] = None
tool: Optional[dict] = None
globalNodePlan: Optional[dict] = None
messagePlan: Optional[Dict[str, Any]] = None
variableExtractionPlan: Optional[Dict[str, Any]] = None
tool: Optional[Dict[str, Any]] = None
globalNodePlan: Optional[Dict[str, Any]] = None
assistantId: Optional[str] = None
assistant: Optional[Dict[str, Any]] = None
@model_validator(mode="before")
@classmethod
def _normalize_legacy_node(cls, data: Any) -> Any:
    # Pre-validation hook: upgrade legacy node payloads to the canonical shape.
    # Non-dict input is passed through for pydantic to handle/reject.
    if not isinstance(data, dict):
        return data
    raw = dict(data)  # shallow copy so the caller's payload is not mutated
    # Ensure a stable id: prefer explicit id, then name, then a hash-derived
    # fallback. NOTE(review): str hashing is salted per process, so the
    # fallback id is not stable across runs — confirm this is acceptable.
    node_id = raw.get("id") or raw.get("name")
    if not node_id:
        node_id = f"node_{abs(hash(str(raw))) % 100000}"
    raw["id"] = str(node_id)
    raw["name"] = str(raw.get("name") or raw["id"])
    # Map legacy type aliases onto the canonical node-type vocabulary;
    # anything unrecognized degrades to "assistant".
    node_type = str(raw.get("type") or "assistant").lower()
    if node_type == "conversation":
        node_type = "assistant"
    elif node_type == "human":
        node_type = "human_transfer"
    elif node_type not in {"start", "assistant", "tool", "human_transfer", "end"}:
        node_type = "assistant"
    raw["type"] = node_type
    # Normalize metadata to a dict and fold a top-level "position" into it.
    metadata = raw.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    if "position" not in metadata and isinstance(raw.get("position"), dict):
        metadata["position"] = raw.get("position")
    raw["metadata"] = metadata
    # Start nodes are implicitly isStart=True unless explicitly set.
    if raw.get("isStart") is None and node_type == "start":
        raw["isStart"] = True
    return raw
class WorkflowEdge(BaseModel):
from_: str
to: str
label: Optional[str] = None
model_config = ConfigDict(extra="allow")
class Config:
populate_by_name = True
id: Optional[str] = None
fromNodeId: str
toNodeId: str
label: Optional[str] = None
condition: Optional[Dict[str, Any]] = None
priority: int = 100
@model_validator(mode="before")
@classmethod
def _normalize_legacy_edge(cls, data: Any) -> Any:
    # Pre-validation hook: accept legacy edge field names and fill defaults.
    if not isinstance(data, dict):
        return data
    raw = dict(data)  # shallow copy to avoid mutating the caller's payload
    # Accept several legacy aliases for the endpoint fields.
    from_node = raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
    to_node = raw.get("toNodeId") or raw.get("to") or raw.get("target")
    raw["fromNodeId"] = str(from_node or "")
    raw["toNodeId"] = str(to_node or "")
    # Derive a deterministic edge id from the endpoints when missing.
    if raw.get("id") is None:
        raw["id"] = f"e_{raw['fromNodeId']}_{raw['toNodeId']}"
    # Legacy edges carried only a label; translate it into a condition:
    # labelled edges become contains-matches on user text, else "always".
    if raw.get("condition") is None:
        if raw.get("label"):
            raw["condition"] = {"type": "contains", "source": "user", "value": str(raw["label"])}
        else:
            raw["condition"] = {"type": "always"}
    return raw
class WorkflowBase(BaseModel):
@@ -436,29 +494,85 @@ class WorkflowBase(BaseModel):
createdAt: str = ""
updatedAt: str = ""
globalPrompt: Optional[str] = None
nodes: List[dict] = []
edges: List[dict] = []
nodes: List[WorkflowNode] = Field(default_factory=list)
edges: List[WorkflowEdge] = Field(default_factory=list)
class WorkflowCreate(WorkflowBase):
pass
@model_validator(mode="after")
def _validate_graph(self) -> "WorkflowCreate":
    # Run full structural graph validation once nodes/edges are parsed.
    _validate_workflow_graph(self.nodes, self.edges)
    return self
class WorkflowUpdate(BaseModel):
name: Optional[str] = None
nodeCount: Optional[int] = None
nodes: Optional[List[dict]] = None
edges: Optional[List[dict]] = None
nodes: Optional[List[WorkflowNode]] = None
edges: Optional[List[WorkflowEdge]] = None
globalPrompt: Optional[str] = None
@model_validator(mode="after")
def _validate_partial_graph(self) -> "WorkflowUpdate":
    # Partial updates may carry only nodes or only edges; cross-validate
    # the graph only when both sides are present in the payload.
    if self.nodes is not None and self.edges is not None:
        _validate_workflow_graph(self.nodes, self.edges)
    return self
class WorkflowOut(WorkflowBase):
id: str
@model_validator(mode="before")
@classmethod
def _normalize_db_fields(cls, data: Any) -> Any:
    # Accept either a plain dict or an ORM row object and map snake_case
    # DB column names onto the camelCase fields of the schema.
    if isinstance(data, dict):
        raw = dict(data)
    else:
        # ORM path: read attributes defensively so missing columns become None.
        raw = {
            "id": getattr(data, "id", None),
            "name": getattr(data, "name", None),
            "node_count": getattr(data, "node_count", None),
            "created_at": getattr(data, "created_at", None),
            "updated_at": getattr(data, "updated_at", None),
            "global_prompt": getattr(data, "global_prompt", None),
            "nodes": getattr(data, "nodes", None),
            "edges": getattr(data, "edges", None),
        }
    # Backfill camelCase keys only when the payload did not already provide
    # them and the snake_case value is present.
    if "nodeCount" not in raw and raw.get("node_count") is not None:
        raw["nodeCount"] = raw["node_count"]
    if "createdAt" not in raw and raw.get("created_at") is not None:
        raw["createdAt"] = raw["created_at"]
    if "updatedAt" not in raw and raw.get("updated_at") is not None:
        raw["updatedAt"] = raw["updated_at"]
    if "globalPrompt" not in raw and raw.get("global_prompt") is not None:
        raw["globalPrompt"] = raw["global_prompt"]
    return raw
class Config:
from_attributes = True
def _validate_workflow_graph(nodes: List[WorkflowNode], edges: List[WorkflowEdge]) -> None:
if not nodes:
raise ValueError("Workflow must include at least one node")
node_ids = [node.id for node in nodes if node.id]
if len(node_ids) != len(set(node_ids)):
raise ValueError("Workflow node ids must be unique")
starts = [node for node in nodes if node.isStart or node.type == "start"]
if not starts:
raise ValueError("Workflow must define a start node (isStart=true or type=start)")
known = set(node_ids)
for edge in edges:
if edge.fromNodeId not in known:
raise ValueError(f"Workflow edge fromNodeId not found: {edge.fromNodeId}")
if edge.toNodeId not in known:
raise ValueError(f"Workflow edge toNodeId not found: {edge.toNodeId}")
# ============ Call Record ============
class TranscriptSegment(BaseModel):
turnIndex: int

View File

@@ -119,6 +119,14 @@ class TestAssistantAPI:
assert response.status_code == 200
assert response.json()["knowledgeBaseId"] == "non-existent-kb"
assistant_id = response.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["knowledgeBaseId"] == "non-existent-kb"
assert metadata["knowledge"]["enabled"] is True
assert metadata["knowledge"]["kbId"] == "non-existent-kb"
def test_assistant_with_model_references(self, client, sample_assistant_data):
"""Test creating assistant with model references"""
sample_assistant_data.update({

167
api/tests/test_workflows.py Normal file
View File

@@ -0,0 +1,167 @@
"""Tests for workflow graph schema and router behavior."""
class TestWorkflowAPI:
    """Workflow CRUD and graph validation test cases."""

    def _minimal_nodes(self):
        # Smallest valid graph: one start node plus one assistant node.
        return [
            {
                "id": "start_1",
                "name": "start_1",
                "type": "start",
                "isStart": True,
                "metadata": {"position": {"x": 80, "y": 80}},
            },
            {
                "id": "assistant_1",
                "name": "assistant_1",
                "type": "assistant",
                "metadata": {"position": {"x": 280, "y": 80}},
                "prompt": "You are the first assistant node.",
            },
        ]

    def test_create_workflow_with_canonical_graph(self, client):
        # Canonical payload (explicit ids, fromNodeId/toNodeId) round-trips as-is.
        payload = {
            "name": "Canonical Graph",
            "nodes": self._minimal_nodes(),
            "edges": [
                {
                    "id": "edge_start_assistant",
                    "fromNodeId": "start_1",
                    "toNodeId": "assistant_1",
                    "condition": {"type": "always"},
                }
            ],
        }
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code == 200
        data = resp.json()
        assert data["name"] == "Canonical Graph"
        # nodeCount is derived from the node list on create.
        assert data["nodeCount"] == 2
        assert data["nodes"][0]["id"] == "start_1"
        assert data["edges"][0]["fromNodeId"] == "start_1"
        assert data["edges"][0]["toNodeId"] == "assistant_1"

    def test_create_workflow_with_legacy_graph(self, client):
        # Legacy payload (name-only nodes, from/to/label edges) must be
        # normalized: type aliases remapped, labels turned into conditions.
        payload = {
            "name": "Legacy Graph",
            "nodes": [
                {
                    "name": "legacy_start",
                    "type": "conversation",
                    "isStart": True,
                    "metadata": {"position": {"x": 100, "y": 100}},
                },
                {
                    "name": "legacy_human",
                    "type": "human",
                    "metadata": {"position": {"x": 300, "y": 100}},
                },
            ],
            "edges": [
                {
                    "from": "legacy_start",
                    "to": "legacy_human",
                    "label": "人工",
                }
            ],
        }
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code == 200
        data = resp.json()
        # "conversation" -> "assistant", "human" -> "human_transfer".
        assert data["nodes"][0]["type"] == "assistant"
        assert data["nodes"][1]["type"] == "human_transfer"
        assert data["edges"][0]["fromNodeId"] == "legacy_start"
        assert data["edges"][0]["toNodeId"] == "legacy_human"
        # A labelled legacy edge becomes a contains-match condition.
        assert data["edges"][0]["condition"]["type"] == "contains"

    def test_create_workflow_without_start_node_fails(self, client):
        # Graph validation rejects payloads with no start node (422).
        payload = {
            "name": "No Start",
            "nodes": [
                {"id": "node_1", "name": "node_1", "type": "assistant", "metadata": {"position": {"x": 0, "y": 0}}},
            ],
            "edges": [],
        }
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code == 422

    def test_create_workflow_with_invalid_edge_fails(self, client):
        # Edges referencing unknown node ids are rejected (422).
        payload = {
            "name": "Bad Edge",
            "nodes": self._minimal_nodes(),
            "edges": [
                {"id": "edge_bad", "fromNodeId": "missing", "toNodeId": "assistant_1", "condition": {"type": "always"}},
            ],
        }
        resp = client.post("/api/workflows", json=payload)
        assert resp.status_code == 422

    def test_update_workflow_nodes_and_edges(self, client):
        # Updating the graph replaces nodes/edges and refreshes nodeCount.
        create_payload = {
            "name": "Before Update",
            "nodes": self._minimal_nodes(),
            "edges": [
                {
                    "id": "edge_start_assistant",
                    "fromNodeId": "start_1",
                    "toNodeId": "assistant_1",
                    "condition": {"type": "always"},
                }
            ],
        }
        create_resp = client.post("/api/workflows", json=create_payload)
        assert create_resp.status_code == 200
        workflow_id = create_resp.json()["id"]
        update_payload = {
            "name": "After Update",
            "nodes": [
                {
                    "id": "start_1",
                    "name": "start_1",
                    "type": "start",
                    "isStart": True,
                    "metadata": {"position": {"x": 50, "y": 50}},
                },
                {
                    "id": "assistant_2",
                    "name": "assistant_2",
                    "type": "assistant",
                    "metadata": {"position": {"x": 250, "y": 50}},
                    "prompt": "new prompt",
                },
                {
                    "id": "end_1",
                    "name": "end_1",
                    "type": "end",
                    "metadata": {"position": {"x": 450, "y": 50}},
                },
            ],
            "edges": [
                {
                    "id": "edge_start_assistant2",
                    "fromNodeId": "start_1",
                    "toNodeId": "assistant_2",
                    "condition": {"type": "always"},
                },
                {
                    "id": "edge_assistant2_end",
                    "fromNodeId": "assistant_2",
                    "toNodeId": "end_1",
                    "condition": {"type": "contains", "source": "user", "value": "结束"},
                },
            ],
        }
        update_resp = client.put(f"/api/workflows/{workflow_id}", json=update_payload)
        assert update_resp.status_code == 200
        updated = update_resp.json()
        assert updated["name"] == "After Update"
        # nodeCount is recomputed from the replacement node list.
        assert updated["nodeCount"] == 3
        assert len(updated["nodes"]) == 3
        assert len(updated["edges"]) == 2

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import aiohttp
from loguru import logger
@@ -146,3 +146,46 @@ async def finalize_history_call_record(
except Exception as exc:
logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
return False
async def search_knowledge_context(
    *,
    kb_id: str,
    query: str,
    n_results: int = 5,
) -> List[Dict[str, Any]]:
    """Search backend knowledge base and return retrieval results.

    Best-effort: any failure (missing base URL, blank input, HTTP error,
    malformed response body) is logged and an empty list is returned
    instead of raising, so callers can always iterate the result.

    Args:
        kb_id: Knowledge base identifier to query.
        query: Retrieval query text; blank queries short-circuit to [].
        n_results: Desired result count; coerced to an int >= 1.

    Returns:
        A list of result dicts from the backend, or [] on any failure.
    """
    base_url = _backend_base_url()
    if not base_url:
        return []
    if not kb_id or not query.strip():
        return []
    # Sanitize n_results: non-numeric input falls back to the default of 5.
    try:
        safe_n_results = max(1, int(n_results))
    except (TypeError, ValueError):
        safe_n_results = 5
    url = f"{base_url}/api/knowledge/search"
    payload: Dict[str, Any] = {
        "kb_id": kb_id,
        "query": query,
        "nResults": safe_n_results,
    }
    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as session:
            async with session.post(url, json=payload) as resp:
                # 404 means the KB id is unknown; treat as "no context".
                if resp.status == 404:
                    logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
                    return []
                resp.raise_for_status()
                data = await resp.json()
                # Defensive parsing: accept only a dict body whose "results"
                # is a list, and keep only dict entries from it.
                if not isinstance(data, dict):
                    return []
                results = data.get("results", [])
                if not isinstance(results, list):
                    return []
                return [r for r in results if isinstance(r, dict)]
    except Exception as exc:
        logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
        return []

View File

@@ -13,11 +13,12 @@ event-driven design.
import asyncio
import time
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
from app.backend_client import search_knowledge_context
from app.config import settings
from core.conversation import ConversationManager, ConversationState
from core.events import get_event_bus
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor
from services.asr import BufferedASRService
from services.base import BaseASRService, BaseLLMService, BaseTTSService
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage
from services.llm import MockLLMService, OpenAILLMService
from services.siliconflow_asr import SiliconFlowASRService
from services.siliconflow_tts import SiliconFlowTTSService
@@ -55,6 +56,9 @@ class DuplexPipeline:
_SENTENCE_TRAILING_CHARS = frozenset({"", "", "", ".", "!", "?", "", "~", "", "\n"})
_SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""})
_MIN_SPLIT_SPOKEN_CHARS = 6
_RAG_DEFAULT_RESULTS = 5
_RAG_MAX_RESULTS = 8
_RAG_MAX_CONTEXT_CHARS = 4000
def __init__(
self,
@@ -156,6 +160,8 @@ class DuplexPipeline:
self._runtime_tts: Dict[str, Any] = {}
self._runtime_system_prompt: Optional[str] = None
self._runtime_greeting: Optional[str] = None
self._runtime_knowledge: Dict[str, Any] = {}
self._runtime_knowledge_base_id: Optional[str] = None
logger.info(f"DuplexPipeline initialized for session {session_id}")
@@ -194,6 +200,18 @@ class DuplexPipeline:
if isinstance(services.get("tts"), dict):
self._runtime_tts = services["tts"]
knowledge_base_id = metadata.get("knowledgeBaseId")
if knowledge_base_id is not None:
kb_id = str(knowledge_base_id).strip()
self._runtime_knowledge_base_id = kb_id or None
knowledge = metadata.get("knowledge")
if isinstance(knowledge, dict):
self._runtime_knowledge = knowledge
kb_id = str(knowledge.get("kbId") or knowledge.get("knowledgeBaseId") or "").strip()
if kb_id:
self._runtime_knowledge_base_id = kb_id
async def start(self) -> None:
"""Start the pipeline and connect services."""
try:
@@ -552,6 +570,103 @@ class DuplexPipeline:
await self.conversation.end_user_turn(user_text)
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
    # Parse value as an int, falling back to default on None or junk input.
    try:
        return int(value)
    except (TypeError, ValueError):
        return default
def _resolve_runtime_kb_id(self) -> Optional[str]:
    # Prefer the explicitly supplied knowledge base id; otherwise fall back
    # to the id embedded in the runtime knowledge config (either key form).
    if self._runtime_knowledge_base_id:
        return self._runtime_knowledge_base_id
    kb_id = str(self._runtime_knowledge.get("kbId") or self._runtime_knowledge.get("knowledgeBaseId") or "").strip()
    return kb_id or None
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
    """Format retrieval results into a system-prompt block for the LLM.

    Returns None when nothing usable remains (empty results or all-blank
    content), so the caller can skip injection. Total snippet text is
    capped at _RAG_MAX_CONTEXT_CHARS.
    """
    if not results:
        return None
    lines = [
        "You have retrieved the following knowledge base snippets.",
        "Use them only when relevant to the latest user request.",
        "If snippets are insufficient, say you are not sure instead of guessing.",
        "",
    ]
    used_chars = 0
    used_count = 0
    for item in results:
        content = str(item.get("content") or "").strip()
        if not content:
            continue
        if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
            break  # context budget exhausted
        # Optional provenance: document id / chunk index from result metadata.
        metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
        doc_id = metadata.get("document_id")
        chunk_index = metadata.get("chunk_index")
        distance = item.get("distance")
        source_parts = []
        if doc_id:
            source_parts.append(f"doc={doc_id}")
        if chunk_index is not None:
            source_parts.append(f"chunk={chunk_index}")
        source = f" ({', '.join(source_parts)})" if source_parts else ""
        distance_text = ""
        try:
            if distance is not None:
                distance_text = f", distance={float(distance):.4f}"
        except (TypeError, ValueError):
            distance_text = ""
        # Truncate the snippet to whatever budget remains.
        remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
        snippet = content[:remaining].strip()
        if not snippet:
            continue
        used_count += 1
        lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
        used_chars += len(snippet)
    if used_count == 0:
        return None
    return "\n".join(lines)
async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]:
    """Assemble LLM messages for a turn, injecting RAG context when enabled.

    Falls back to the plain conversation history when no knowledge base is
    configured, knowledge is disabled, or retrieval yields nothing usable.
    """
    messages = self.conversation.get_messages()
    kb_id = self._resolve_runtime_kb_id()
    if not kb_id:
        return messages
    knowledge_cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {}
    # "enabled" may arrive as a string flag from metadata; treat common
    # falsey spellings as off, everything else as on.
    enabled = knowledge_cfg.get("enabled", True)
    if isinstance(enabled, str):
        enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
    if not enabled:
        return messages
    # Clamp the requested result count to [1, _RAG_MAX_RESULTS].
    n_results = self._coerce_int(knowledge_cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
    n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
    results = await search_knowledge_context(
        kb_id=kb_id,
        query=user_text,
        n_results=n_results,
    )
    prompt = self._build_knowledge_prompt(results)
    if not prompt:
        return messages
    logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
    # Insert the RAG system message right after the primary system prompt,
    # or at the head when the history carries no system message.
    rag_system = LLMMessage(role="system", content=prompt)
    if messages and messages[0].role == "system":
        return [messages[0], rag_system, *messages[1:]]
    return [rag_system, *messages]
async def _handle_turn(self, user_text: str) -> None:
"""
Handle a complete conversation turn.
@@ -567,7 +682,7 @@ class DuplexPipeline:
self._first_audio_sent = False
# Get AI response (streaming)
messages = self.conversation.get_messages()
messages = await self._build_turn_messages(user_text)
full_response = ""
await self.conversation.start_assistant_turn()

View File

@@ -4,8 +4,9 @@ import asyncio
import uuid
import json
import time
import re
from enum import Enum
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List
from loguru import logger
from app.backend_client import (
@@ -16,7 +17,9 @@ from app.backend_client import (
from core.transports import BaseTransport
from core.duplex_pipeline import DuplexPipeline
from core.conversation import ConversationTurn
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
from app.config import settings
from services.base import LLMMessage
from models.ws_v1 import (
parse_client_message,
ev,
@@ -81,6 +84,9 @@ class Session:
self._history_finalized: bool = False
self._cleanup_lock = asyncio.Lock()
self._cleaned_up = False
self.workflow_runner: Optional[WorkflowRunner] = None
self._workflow_last_user_text: str = ""
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
@@ -223,6 +229,7 @@ class Session:
return
metadata = message.metadata or {}
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
# Create history call record early so later turn callbacks can append transcripts.
await self._start_history_bridge(metadata)
@@ -246,6 +253,26 @@ class Session:
audio=message.audio or {},
)
)
if self.workflow_runner and self._workflow_initial_node:
await self.transport.send_event(
ev(
"workflow.started",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
workflowName=self.workflow_runner.name,
nodeId=self._workflow_initial_node.id,
)
)
await self.transport.send_event(
ev(
"workflow.node.entered",
sessionId=self.id,
workflowId=self.workflow_runner.workflow_id,
nodeId=self._workflow_initial_node.id,
nodeName=self._workflow_initial_node.name,
nodeType=self._workflow_initial_node.node_type,
)
)
async def _handle_session_stop(self, reason: Optional[str]) -> None:
"""Handle session stop."""
@@ -334,7 +361,14 @@ class Session:
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
"""Persist completed turns to backend call transcripts."""
"""Process workflow transitions and persist completed turns to history."""
if turn.text and turn.text.strip():
role = (turn.role or "").lower()
if role == "user":
self._workflow_last_user_text = turn.text.strip()
elif role == "assistant":
await self._maybe_advance_workflow(turn.text.strip())
if not self._history_call_id:
return
if not turn.text or not turn.text.strip():
@@ -377,3 +411,235 @@ class Session:
)
if ok:
self._history_finalized = True
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Parse workflow payload and return initial runtime overrides.

    Reads metadata["workflow"], builds a WorkflowRunner, and resolves the
    start node. Returns {} (and clears the runner) when the payload is
    absent, malformed, or has no resolvable start node.
    """
    payload = metadata.get("workflow")
    self.workflow_runner = WorkflowRunner.from_payload(payload)
    self._workflow_initial_node = None
    if not self.workflow_runner:
        return {}
    node = self.workflow_runner.bootstrap()
    if not node:
        # Graph parsed but no start node: disable workflow mode entirely.
        logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
        self.workflow_runner = None
        return {}
    self._workflow_initial_node = node
    logger.info(
        "Session {} workflow enabled: workflow={} start_node={}",
        self.id,
        self.workflow_runner.workflow_id,
        node.id,
    )
    return self.workflow_runner.build_runtime_metadata(node)
async def _maybe_advance_workflow(self, assistant_text: str) -> None:
    """Attempt node transfer after assistant turn finalization."""
    if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
        return
    # Route using the latest user utterance plus the assistant reply.
    transition = await self.workflow_runner.route(
        user_text=self._workflow_last_user_text,
        assistant_text=assistant_text,
        llm_router=self._workflow_llm_route,
    )
    if not transition:
        return
    await self._apply_workflow_transition(transition, reason="rule_match")
    # Auto-advance through utility nodes when default edges are present.
    max_auto_hops = 6  # hop cap guards against cycles of start/tool nodes
    auto_hops = 0
    while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
        current = self.workflow_runner.current_node
        # Only start/tool nodes auto-advance; conversational nodes wait
        # for the next turn.
        if not current or current.node_type not in {"start", "tool"}:
            break
        next_default = self.workflow_runner.next_default_transition()
        if not next_default:
            break
        auto_hops += 1
        await self._apply_workflow_transition(next_default, reason="auto")
        if auto_hops >= max_auto_hops:
            logger.warning(
                "Session {} workflow auto-advance reached hop limit (possible cycle)",
                self.id,
            )
            break
async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
    """Apply graph transition and emit workflow lifecycle events.

    Emits workflow.edge.taken and workflow.node.entered, applies any
    node-level runtime overrides to the pipeline, then handles terminal
    node types (tool request, human transfer, end).
    """
    if not self.workflow_runner:
        return
    self.workflow_runner.apply_transition(transition)
    node = transition.node
    edge = transition.edge
    await self.transport.send_event(
        ev(
            "workflow.edge.taken",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            edgeId=edge.id,
            fromNodeId=edge.from_node_id,
            toNodeId=edge.to_node_id,
            reason=reason,
        )
    )
    await self.transport.send_event(
        ev(
            "workflow.node.entered",
            sessionId=self.id,
            workflowId=self.workflow_runner.workflow_id,
            nodeId=node.id,
            nodeName=node.name,
            nodeType=node.node_type,
        )
    )
    # Per-node overrides (prompt/greeting/services) take effect immediately.
    node_runtime = self.workflow_runner.build_runtime_metadata(node)
    if node_runtime:
        self.pipeline.apply_runtime_overrides(node_runtime)
    if node.node_type == "tool":
        # NOTE(review): tool nodes only emit an event here; no server-side
        # tool execution is visible in this method — confirm client handles it.
        await self.transport.send_event(
            ev(
                "workflow.tool.requested",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                tool=node.tool or {},
            )
        )
        return
    if node.node_type == "human_transfer":
        # Human transfer terminates the automated session.
        await self.transport.send_event(
            ev(
                "workflow.human_transfer",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        await self._handle_session_stop("workflow_human_transfer")
        return
    if node.node_type == "end":
        # End node terminates the session with a workflow_end reason.
        await self.transport.send_event(
            ev(
                "workflow.ended",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
            )
        )
        await self._handle_session_stop("workflow_end")
async def _workflow_llm_route(
    self,
    node: WorkflowNodeDef,
    candidates: List[WorkflowEdgeDef],
    context: Dict[str, str],
) -> Optional[str]:
    """LLM-based edge routing for condition.type == 'llm' edges.

    Returns a matching edge id or target node id, or None when routing
    fails (no LLM service, generation error, or unparseable reply).
    """
    llm_service = self.pipeline.llm_service
    if not llm_service:
        return None
    # Present each candidate edge with optional hint text from its condition.
    candidate_rows = [
        {
            "edgeId": edge.id,
            "toNodeId": edge.to_node_id,
            "label": edge.label,
            "hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
        }
        for edge in candidates
    ]
    system_prompt = (
        "You are a workflow router. Pick exactly one edge. "
        "Return JSON only: {\"edgeId\":\"...\"}."
    )
    user_prompt = json.dumps(
        {
            "nodeId": node.id,
            "nodeName": node.name,
            "userText": context.get("userText", ""),
            "assistantText": context.get("assistantText", ""),
            "candidates": candidate_rows,
        },
        ensure_ascii=False,
    )
    try:
        # Deterministic, short generation: routing only needs a tiny JSON blob.
        reply = await llm_service.generate(
            [
                LLMMessage(role="system", content=system_prompt),
                LLMMessage(role="user", content=user_prompt),
            ],
            temperature=0.0,
            max_tokens=64,
        )
    except Exception as exc:
        logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
        return None
    if not reply:
        return None
    edge_ids = {edge.id for edge in candidates}
    node_ids = {edge.to_node_id for edge in candidates}
    # Preferred path: structured JSON reply carrying an edge or node id.
    parsed = self._extract_json_obj(reply)
    if isinstance(parsed, dict):
        edge_id = parsed.get("edgeId") or parsed.get("id")
        node_id = parsed.get("toNodeId") or parsed.get("nodeId")
        if isinstance(edge_id, str) and edge_id in edge_ids:
            return edge_id
        if isinstance(node_id, str) and node_id in node_ids:
            return node_id
    # Fallback: scan the raw reply for any known id, longest first so a
    # short id that is a substring of a longer one cannot shadow it.
    token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
    lowered_reply = reply.lower()
    for token in token_candidates:
        if token.lower() in lowered_reply:
            return token
    return None
def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
"""Merge node-level metadata overrides into session.start metadata."""
merged = dict(base or {})
if not overrides:
return merged
for key, value in overrides.items():
if key == "services" and isinstance(value, dict):
existing = merged.get("services")
merged_services = dict(existing) if isinstance(existing, dict) else {}
merged_services.update(value)
merged["services"] = merged_services
else:
merged[key] = value
return merged
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
"""Best-effort extraction of a JSON object from freeform text."""
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return parsed
except Exception:
pass
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
return parsed if isinstance(parsed, dict) else None
except Exception:
return None

View File

@@ -0,0 +1,402 @@
"""Workflow runtime helpers for session-level node routing.
MVP goals:
- Parse workflow graph payload from WS session.start metadata
- Track current node
- Evaluate edge conditions on each assistant turn completion
- Provide per-node runtime metadata overrides (prompt/greeting/services)
"""
from __future__ import annotations
from dataclasses import dataclass
import json
import re
from typing import Any, Awaitable, Callable, Dict, List, Optional
from loguru import logger
_NODE_TYPE_MAP = {
"conversation": "assistant",
"assistant": "assistant",
"human": "human_transfer",
"human_transfer": "human_transfer",
"tool": "tool",
"end": "end",
"start": "start",
}
def _normalize_node_type(raw_type: Any) -> str:
value = str(raw_type or "").strip().lower()
return _NODE_TYPE_MAP.get(value, "assistant")
def _safe_str(value: Any) -> str:
if value is None:
return ""
return str(value)
def _normalize_condition(raw: Any, label: Optional[str]) -> Dict[str, Any]:
if not isinstance(raw, dict):
if label:
return {"type": "contains", "source": "user", "value": str(label)}
return {"type": "always"}
condition = dict(raw)
condition_type = str(condition.get("type", "always")).strip().lower()
if not condition_type:
condition_type = "always"
condition["type"] = condition_type
condition["source"] = str(condition.get("source", "user")).strip().lower() or "user"
return condition
@dataclass
class WorkflowNodeDef:
    # Parsed, normalized workflow node used by the session-level runner.
    id: str
    name: str
    node_type: str  # canonical type: start/assistant/tool/human_transfer/end
    is_start: bool
    prompt: Optional[str]
    message_plan: Dict[str, Any]
    assistant_id: Optional[str]
    assistant: Dict[str, Any]
    tool: Optional[Dict[str, Any]]
    raw: Dict[str, Any]  # original payload dict as received
@dataclass
class WorkflowEdgeDef:
    # Parsed, normalized workflow edge with its routing condition.
    id: str
    from_node_id: str
    to_node_id: str
    label: Optional[str]
    condition: Dict[str, Any]  # normalized condition dict (see _normalize_condition)
    priority: int
    order: int  # position of the edge in the original payload list
    raw: Dict[str, Any]  # original payload dict as received
@dataclass
class WorkflowTransition:
    # An edge that fired, paired with the node it leads to.
    edge: WorkflowEdgeDef
    node: WorkflowNodeDef


# Async callback used to pick an edge id / node id for condition.type == "llm":
# receives the current node, the candidate edges, and a string context map.
LlmRouter = Callable[
    [WorkflowNodeDef, List[WorkflowEdgeDef], Dict[str, str]],
    Awaitable[Optional[str]],
]
class WorkflowRunner:
    """In-memory workflow graph for a single active session.

    Built once from the session.start ``workflow`` payload; afterwards it
    tracks the current node and evaluates outgoing-edge conditions each time
    an assistant turn completes.
    """

    def __init__(self, workflow_id: str, name: str, nodes: List[WorkflowNodeDef], edges: List[WorkflowEdgeDef]):
        self.workflow_id = workflow_id
        self.name = name
        # Index nodes by id for O(1) lookups when resolving edge targets.
        # Duplicate ids silently overwrite earlier entries.
        self._nodes: Dict[str, WorkflowNodeDef] = {node.id: node for node in nodes}
        self._edges = edges
        self.current_node_id: Optional[str] = None

    @classmethod
    def from_payload(cls, payload: Any) -> Optional["WorkflowRunner"]:
        """Build a runner from a raw ``workflow`` payload dict.

        Returns None when the payload is not a dict or yields no usable
        nodes; malformed nodes/edges are skipped instead of raising.
        """
        if not isinstance(payload, dict):
            return None
        raw_nodes = payload.get("nodes")
        raw_edges = payload.get("edges")
        if not isinstance(raw_nodes, list) or len(raw_nodes) == 0:
            return None
        nodes: List[WorkflowNodeDef] = []
        for i, raw in enumerate(raw_nodes):
            if not isinstance(raw, dict):
                continue
            # Fall back to name, then a positional id, so every node is addressable.
            node_id = _safe_str(raw.get("id") or raw.get("name") or f"node_{i + 1}").strip() or f"node_{i + 1}"
            node_name = _safe_str(raw.get("name") or node_id).strip() or node_id
            node_type = _normalize_node_type(raw.get("type"))
            is_start = bool(raw.get("isStart")) or node_type == "start"
            # Distinguish "prompt key absent" (None) from "prompt set to empty".
            prompt: Optional[str] = None
            if "prompt" in raw:
                prompt = _safe_str(raw.get("prompt"))
            message_plan = raw.get("messagePlan")
            if not isinstance(message_plan, dict):
                message_plan = {}
            assistant_cfg = raw.get("assistant")
            if not isinstance(assistant_cfg, dict):
                assistant_cfg = {}
            tool_cfg = raw.get("tool")
            if not isinstance(tool_cfg, dict):
                tool_cfg = None
            assistant_id = raw.get("assistantId")
            if assistant_id is not None:
                assistant_id = _safe_str(assistant_id).strip() or None
            nodes.append(
                WorkflowNodeDef(
                    id=node_id,
                    name=node_name,
                    node_type=node_type,
                    is_start=is_start,
                    prompt=prompt,
                    message_plan=message_plan,
                    assistant_id=assistant_id,
                    assistant=assistant_cfg,
                    tool=tool_cfg,
                    raw=raw,
                )
            )
        if not nodes:
            return None
        node_ids = {node.id for node in nodes}
        edges: List[WorkflowEdgeDef] = []
        for i, raw in enumerate(raw_edges if isinstance(raw_edges, list) else []):
            if not isinstance(raw, dict):
                continue
            # Accept the several from/to key spellings emitted by different editors.
            from_node_id = _safe_str(
                raw.get("fromNodeId") or raw.get("from") or raw.get("from_") or raw.get("source")
            ).strip()
            to_node_id = _safe_str(raw.get("toNodeId") or raw.get("to") or raw.get("target")).strip()
            if not from_node_id or not to_node_id:
                continue
            # Drop edges that reference nodes not present in the parsed set.
            if from_node_id not in node_ids or to_node_id not in node_ids:
                continue
            label = raw.get("label")
            if label is not None:
                label = _safe_str(label)
            condition = _normalize_condition(raw.get("condition"), label=label)
            priority = 100
            try:
                priority = int(raw.get("priority", 100))
            except (TypeError, ValueError):
                priority = 100
            edge_id = _safe_str(raw.get("id") or f"e_{from_node_id}_{to_node_id}_{i + 1}").strip() or f"e_{i + 1}"
            edges.append(
                WorkflowEdgeDef(
                    id=edge_id,
                    from_node_id=from_node_id,
                    to_node_id=to_node_id,
                    label=label,
                    condition=condition,
                    priority=priority,
                    order=i,
                    raw=raw,
                )
            )
        workflow_id = _safe_str(payload.get("id") or "workflow")
        workflow_name = _safe_str(payload.get("name") or workflow_id)
        return cls(workflow_id=workflow_id, name=workflow_name, nodes=nodes, edges=edges)

    def bootstrap(self) -> Optional[WorkflowNodeDef]:
        """Resolve the entry node, make it current, and return it (or None)."""
        start_node = self._resolve_start_node()
        if not start_node:
            return None
        self.current_node_id = start_node.id
        return start_node

    @property
    def current_node(self) -> Optional[WorkflowNodeDef]:
        """The node the session currently sits on, if any."""
        if not self.current_node_id:
            return None
        return self._nodes.get(self.current_node_id)

    def outgoing_edges(self, node_id: str) -> List[WorkflowEdgeDef]:
        """Edges leaving ``node_id``, ordered by priority then payload order."""
        edges = [edge for edge in self._edges if edge.from_node_id == node_id]
        return sorted(edges, key=lambda edge: (edge.priority, edge.order))

    def next_default_transition(self) -> Optional[WorkflowTransition]:
        """First always/default edge out of the current node, if one exists."""
        node = self.current_node
        if not node:
            return None
        for edge in self.outgoing_edges(node.id):
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type in {"", "always", "default"}:
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)
        return None

    async def route(
        self,
        *,
        user_text: str,
        assistant_text: str,
        llm_router: Optional[LlmRouter] = None,
    ) -> Optional[WorkflowTransition]:
        """Select the next transition for the latest turn, or None to stay put.

        Resolution order: first matching non-LLM condition (priority order),
        then the optional ``llm_router`` over the collected "llm" edges, and
        finally the first always/default edge as a fallback.
        """
        node = self.current_node
        if not node:
            return None
        outgoing = self.outgoing_edges(node.id)
        if not outgoing:
            return None
        llm_edges: List[WorkflowEdgeDef] = []
        for edge in outgoing:
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type == "llm":
                # Defer LLM-scored edges; cheap text conditions win first.
                llm_edges.append(edge)
                continue
            if self._matches_condition(edge, user_text=user_text, assistant_text=assistant_text):
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)
        if llm_edges and llm_router:
            selection = await llm_router(
                node,
                llm_edges,
                {
                    "userText": user_text,
                    "assistantText": assistant_text,
                },
            )
            if selection:
                # The router may answer with either an edge id or a target node id.
                for edge in llm_edges:
                    if selection in {edge.id, edge.to_node_id}:
                        target = self._nodes.get(edge.to_node_id)
                        if target:
                            return WorkflowTransition(edge=edge, node=target)
        # Fallback: take the first unconditional edge, if any.
        for edge in outgoing:
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type in {"", "always", "default"}:
                target = self._nodes.get(edge.to_node_id)
                if target:
                    return WorkflowTransition(edge=edge, node=target)
        return None

    def apply_transition(self, transition: WorkflowTransition) -> None:
        """Commit a transition returned by route()/next_default_transition()."""
        self.current_node_id = transition.node.id

    def build_runtime_metadata(self, node: WorkflowNodeDef) -> Dict[str, Any]:
        """Build per-node session metadata overrides (prompt/greeting/services).

        Node-level fields win over the inline assistant config; only keys
        that resolve to a value are emitted.
        """
        assistant_cfg = node.assistant if isinstance(node.assistant, dict) else {}
        message_plan = node.message_plan if isinstance(node.message_plan, dict) else {}
        metadata: Dict[str, Any] = {}
        if node.prompt is not None:
            metadata["systemPrompt"] = node.prompt
        elif "systemPrompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("systemPrompt"))
        elif "prompt" in assistant_cfg:
            metadata["systemPrompt"] = _safe_str(assistant_cfg.get("prompt"))
        first_message = message_plan.get("firstMessage")
        if first_message is not None:
            metadata["greeting"] = _safe_str(first_message)
        elif "greeting" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("greeting"))
        elif "opener" in assistant_cfg:
            metadata["greeting"] = _safe_str(assistant_cfg.get("opener"))
        services = assistant_cfg.get("services")
        if isinstance(services, dict):
            metadata["services"] = services
        if node.assistant_id:
            metadata["assistantId"] = node.assistant_id
        return metadata

    def _resolve_start_node(self) -> Optional[WorkflowNodeDef]:
        """Pick the entry node: explicit start, else first assistant node, else any.

        A dedicated "start"-type node is walked through its default edges
        (bounded at 8 hops, with cycle detection) so the session lands on a
        concrete node rather than the marker.
        """
        explicit_start = next((node for node in self._nodes.values() if node.is_start), None)
        if not explicit_start:
            explicit_start = next((node for node in self._nodes.values() if node.node_type == "start"), None)
        if explicit_start:
            # If a dedicated start node exists, try to move to its first default target.
            if explicit_start.node_type == "start":
                visited = {explicit_start.id}
                current = explicit_start
                for _ in range(8):
                    transition = self._first_default_transition_from(current.id)
                    if not transition:
                        return current
                    current = transition.node
                    if current.id in visited:
                        break  # cycle detected; stop walking
                    visited.add(current.id)
                return current
            return explicit_start
        assistant_node = next((node for node in self._nodes.values() if node.node_type == "assistant"), None)
        if assistant_node:
            return assistant_node
        return next(iter(self._nodes.values()), None)

    def _first_default_transition_from(self, node_id: str) -> Optional[WorkflowTransition]:
        """First always/default edge out of ``node_id`` whose target exists."""
        for edge in self.outgoing_edges(node_id):
            cond_type = str(edge.condition.get("type", "always")).strip().lower()
            if cond_type in {"", "always", "default"}:
                node = self._nodes.get(edge.to_node_id)
                if node:
                    return WorkflowTransition(edge=edge, node=node)
        return None

    def _matches_condition(self, edge: WorkflowEdgeDef, *, user_text: str, assistant_text: str) -> bool:
        """Evaluate a non-LLM edge condition against the latest turn's texts."""
        condition = edge.condition or {"type": "always"}
        cond_type = str(condition.get("type", "always")).strip().lower()
        source = str(condition.get("source", "user")).strip().lower()
        if cond_type in {"", "always", "default"}:
            return True
        # Match against one transcript side, case-insensitively.
        text = assistant_text if source == "assistant" else user_text
        text_lower = (text or "").lower()
        if cond_type == "contains":
            values: List[str] = []
            if isinstance(condition.get("values"), list):
                values = [_safe_str(v).strip().lower() for v in condition["values"] if _safe_str(v).strip()]
            if not values:
                # Fall back to a single keyword, or the edge label.
                single = _safe_str(condition.get("value") or condition.get("keyword") or edge.label).strip().lower()
                if single:
                    values = [single]
            if not values:
                return False
            return any(value in text_lower for value in values)
        if cond_type == "equals":
            expected = _safe_str(condition.get("value") or "").strip().lower()
            return bool(expected) and text_lower == expected
        if cond_type == "regex":
            pattern = _safe_str(condition.get("value") or condition.get("pattern") or "").strip()
            if not pattern:
                return False
            try:
                return bool(re.search(pattern, text or "", re.IGNORECASE))
            except re.error:
                logger.warning(f"Invalid workflow regex condition: {pattern}")
                return False
        if cond_type == "json":
            value = _safe_str(condition.get("value") or "").strip()
            if not value:
                return False
            try:
                obj = json.loads(text or "")
            except Exception:
                return False
            # NOTE(review): compares str() of the parsed object against the raw
            # condition value, so only exact repr matches succeed — confirm intent.
            return str(obj) == value
        return False

View File

@@ -988,10 +988,21 @@ export const DebugDrawer: React.FC<{
isOpen: boolean;
onClose: () => void;
assistant: Assistant;
voices: Voice[];
llmModels: LLMModel[];
asrModels: ASRModel[];
}> = ({ isOpen, onClose, assistant, voices, llmModels, asrModels }) => {
voices?: Voice[];
llmModels?: LLMModel[];
asrModels?: ASRModel[];
sessionMetadataExtras?: Record<string, any>;
onProtocolEvent?: (event: Record<string, any>) => void;
}> = ({
isOpen,
onClose,
assistant,
voices = [],
llmModels = [],
asrModels = [],
sessionMetadataExtras,
onProtocolEvent,
}) => {
const TARGET_SAMPLE_RATE = 16000;
const downsampleTo16k = (input: Float32Array, inputSampleRate: number): Float32Array => {
if (inputSampleRate === TARGET_SAMPLE_RATE) return input;
@@ -1474,6 +1485,10 @@ export const DebugDrawer: React.FC<{
const warnings: string[] = [];
const services: Record<string, any> = {};
const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt';
const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim();
const knowledge = knowledgeBaseId
? { enabled: true, kbId: knowledgeBaseId, nResults: 5 }
: { enabled: false };
if (isExternalLlm) {
services.llm = {
@@ -1541,6 +1556,8 @@ export const DebugDrawer: React.FC<{
sessionStartMetadata: {
systemPrompt: assistant.prompt || '',
greeting: assistant.opener || '',
knowledgeBaseId,
knowledge,
services,
history: {
assistantId: assistant.id,
@@ -1556,7 +1573,10 @@ export const DebugDrawer: React.FC<{
const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => {
const localResolved = buildLocalResolvedRuntime();
setResolvedConfigView(JSON.stringify(localResolved, null, 2));
return localResolved.sessionStartMetadata;
return {
...localResolved.sessionStartMetadata,
...(sessionMetadataExtras || {}),
};
};
const closeWs = () => {
@@ -1622,6 +1642,9 @@ export const DebugDrawer: React.FC<{
}
const type = payload?.type;
if (onProtocolEvent) {
onProtocolEvent(payload);
}
if (type === 'hello.ack') {
ws.send(
JSON.stringify({

View File

@@ -5,6 +5,11 @@ import { Voice } from '../types';
import { createVoice, deleteVoice, fetchVoices, previewVoice, updateVoice } from '../services/backendApi';
// Default SiliconFlow TTS model preselected for new voices.
const SILICONFLOW_DEFAULT_MODEL = 'FunAudioLLM/CosyVoice2-0.5B';
// Known model ids surfaced as <datalist> suggestions; the model field itself
// is free text so newer models can be typed in directly.
const SILICONFLOW_MODEL_SUGGESTIONS = [
  'FunAudioLLM/CosyVoice2-0.5B',
  'fishaudio/fish-speech-1.5',
  'fishaudio/fish-speech-1.4',
];
const buildSiliconflowVoiceKey = (rawId: string, model: string): string => {
const id = (rawId || '').trim();
@@ -408,15 +413,18 @@ const AddVoiceModal: React.FC<{
<div className="grid grid-cols-2 gap-4">
<div className="space-y-1.5">
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block"> (Model)</label>
<select
className="flex h-9 w-full rounded-md border-0 bg-white/5 px-3 py-1 text-sm shadow-sm transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-primary/50 text-foreground [&>option]:bg-card"
<Input
className="h-9 bg-white/5 font-mono text-xs"
value={sfModel}
onChange={(e) => setSfModel(e.target.value)}
>
<option value="FunAudioLLM/CosyVoice2-0.5B">FunAudioLLM/CosyVoice2-0.5B</option>
<option value="fishaudio/fish-speech-1.5">fishaudio/fish-speech-1.5</option>
<option value="fishaudio/fish-speech-1.4">fishaudio/fish-speech-1.4</option>
</select>
placeholder="例如: FunAudioLLM/CosyVoice2-0.5B"
list="siliconflow-model-options"
/>
<datalist id="siliconflow-model-options">
{SILICONFLOW_MODEL_SUGGESTIONS.map((m) => (
<option key={m} value={m} />
))}
</datalist>
</div>
<div className="space-y-1.5">
<label className="text-[10px] font-black text-muted-foreground uppercase tracking-widest block"> ID (Voice)</label>

View File

@@ -1,5 +1,5 @@
import React, { useState, useRef, useEffect } from 'react';
import React, { useState, useRef, useEffect, useMemo } from 'react';
import { useNavigate, useParams, useSearchParams } from 'react-router-dom';
import { ArrowLeft, Play, Save, Rocket, Plus, Bot, UserCheck, Wrench, Ban, Zap, X, Copy, MousePointer2 } from 'lucide-react';
import { Button, Input, Badge } from '../components/UI';
@@ -7,10 +7,21 @@ import { Assistant, WorkflowNode, WorkflowEdge, Workflow } from '../types';
import { DebugDrawer } from './Assistants';
import { createWorkflow, fetchAssistants, fetchWorkflowById, updateWorkflow } from '../services/backendApi';
/** Map legacy editor node types ('conversation'/'human') onto canonical runtime types. */
const toWorkflowNodeType = (type: WorkflowNode['type']): WorkflowNode['type'] => {
  switch (type) {
    case 'conversation':
      return 'assistant';
    case 'human':
      return 'human_transfer';
    default:
      return type;
  }
};

/** Stable reference for a node: its id when set, otherwise its name. */
const nodeRef = (node: WorkflowNode): string => {
  return node.id || node.name;
};

/** Source-node reference of an edge, tolerating the legacy `from` field. */
const edgeFromRef = (edge: WorkflowEdge): string => {
  return edge.fromNodeId || edge.from;
};

/** Target-node reference of an edge, tolerating the legacy `to` field. */
const edgeToRef = (edge: WorkflowEdge): string => {
  return edge.toNodeId || edge.to;
};
const getTemplateNodes = (templateType: string | null): WorkflowNode[] => {
if (templateType === 'lead') {
return [
{
id: 'introduction',
name: 'introduction',
type: 'conversation',
isStart: true,
@@ -19,6 +30,7 @@ const getTemplateNodes = (templateType: string | null): WorkflowNode[] => {
messagePlan: { firstMessage: "Hello, this is Morgan from GrowthPartners. Do you have a few minutes to chat?" }
},
{
id: 'need_discovery',
name: 'need_discovery',
type: 'conversation',
metadata: { position: { x: 450, y: 250 } },
@@ -31,6 +43,7 @@ const getTemplateNodes = (templateType: string | null): WorkflowNode[] => {
}
},
{
id: 'hangup_node',
name: 'hangup_node',
type: 'end',
metadata: { position: { x: 450, y: 550 } },
@@ -44,6 +57,7 @@ const getTemplateNodes = (templateType: string | null): WorkflowNode[] => {
}
return [
{
id: 'start_node',
name: 'start_node',
type: 'conversation',
isStart: true,
@@ -80,6 +94,66 @@ export const WorkflowEditorPage: React.FC = () => {
const panStart = useRef({ x: 0, y: 0 });
const selectedNode = nodes.find(n => n.name === selectedNodeName);
const selectedNodeRef = selectedNode ? nodeRef(selectedNode) : null;
const outgoingEdges = selectedNodeRef
? edges.filter((edge) => edgeFromRef(edge) === selectedNodeRef)
: [];
const resolvedEdges = useMemo(() => {
return edges
.map((edge, index) => {
const from = nodes.find((node) => nodeRef(node) === edgeFromRef(edge));
const to = nodes.find((node) => nodeRef(node) === edgeToRef(edge));
if (!from || !to) return null;
return {
key: edge.id || `${edgeFromRef(edge)}->${edgeToRef(edge)}:${index}`,
edge,
from,
to,
};
})
.filter((item): item is { key: string; edge: WorkflowEdge; from: WorkflowNode; to: WorkflowNode } => Boolean(item));
}, [edges, nodes]);
const workflowRuntimeMetadata = useMemo(() => {
return {
id: id || 'draft_workflow',
name,
nodes: nodes.map((node) => {
const assistant = node.assistantId
? assistants.find((item) => item.id === node.assistantId)
: undefined;
return {
id: nodeRef(node),
name: node.name || nodeRef(node),
type: toWorkflowNodeType(node.type),
isStart: node.isStart,
prompt: node.prompt,
messagePlan: node.messagePlan,
assistantId: node.assistantId,
assistant: assistant
? {
systemPrompt: node.prompt || assistant.prompt || '',
greeting: node.messagePlan?.firstMessage || assistant.opener || '',
}
: undefined,
tool: node.tool,
metadata: node.metadata,
};
}),
edges: edges.map((edge, index) => {
const fromNodeId = edgeFromRef(edge);
const toNodeId = edgeToRef(edge);
return {
id: edge.id || `edge_${index + 1}`,
fromNodeId,
toNodeId,
label: edge.label,
priority: edge.priority ?? 100,
condition: edge.condition || (edge.label ? { type: 'contains', source: 'user', value: edge.label } : { type: 'always' }),
};
}),
};
}, [assistants, edges, id, name, nodes]);
// Scroll Zoom handler
const handleWheel = (e: React.WheelEvent) => {
@@ -172,8 +246,10 @@ export const WorkflowEditorPage: React.FC = () => {
}, [id]);
const addNode = (type: WorkflowNode['type']) => {
const nodeId = `${type}_${Date.now()}`;
const newNode: WorkflowNode = {
name: `${type}_${Date.now()}`,
id: nodeId,
name: nodeId,
type,
metadata: { position: { x: (300 - panOffset.x) / zoom, y: (300 - panOffset.y) / zoom } },
prompt: type === 'conversation' ? '输入该节点的 Prompt...' : '',
@@ -185,7 +261,95 @@ export const WorkflowEditorPage: React.FC = () => {
const updateNodeData = (field: string, value: any) => {
if (!selectedNodeName) return;
setNodes(prev => prev.map(n => n.name === selectedNodeName ? { ...n, [field]: value } : n));
setNodes(prev => {
const currentNode = prev.find((n) => n.name === selectedNodeName);
if (!currentNode) return prev;
const oldRef = nodeRef(currentNode);
const updatedNodes = prev.map((node) => {
if (node.name !== selectedNodeName) {
if (field === 'isStart' && value === true) {
return { ...node, isStart: false };
}
return node;
}
if (field === 'isStart' && value === true) {
return { ...node, isStart: true };
}
return { ...node, [field]: value };
});
if (field === 'name') {
const renamed = updatedNodes.find((n) => n.name === value);
const newRef = renamed ? nodeRef(renamed) : String(value);
setEdges((prevEdges) =>
prevEdges.map((edge) => {
const from = edgeFromRef(edge);
const to = edgeToRef(edge);
if (from !== oldRef && to !== oldRef) return edge;
const nextFrom = from === oldRef ? newRef : from;
const nextTo = to === oldRef ? newRef : to;
return {
...edge,
fromNodeId: nextFrom,
toNodeId: nextTo,
from: nextFrom,
to: nextTo,
};
})
);
}
return updatedNodes;
});
};
const addEdgeFromSelected = () => {
if (!selectedNode) return;
const fromNodeId = nodeRef(selectedNode);
const target = nodes.find((node) => nodeRef(node) !== fromNodeId);
if (!target) return;
const toNodeId = nodeRef(target);
const edgeId = `edge_${Date.now()}`;
setEdges((prev) => [
...prev,
{
id: edgeId,
fromNodeId,
toNodeId,
from: fromNodeId,
to: toNodeId,
condition: { type: 'always' },
},
]);
};
const updateOutgoingEdge = (edgeId: string, patch: Partial<WorkflowEdge>) => {
setEdges((prev) =>
prev.map((edge, index) => {
const idForCompare = edge.id || `${edgeFromRef(edge)}->${edgeToRef(edge)}:${index}`;
if (idForCompare !== edgeId) return edge;
const next = { ...edge, ...patch };
const fromNodeId = edgeFromRef(next);
const toNodeId = edgeToRef(next);
return {
...next,
fromNodeId,
toNodeId,
from: fromNodeId,
to: toNodeId,
};
})
);
};
const removeOutgoingEdge = (edgeId: string) => {
setEdges((prev) =>
prev.filter((edge, index) => {
const idForCompare = edge.id || `${edgeFromRef(edge)}->${edgeToRef(edge)}:${index}`;
return idForCompare !== edgeId;
})
);
};
const handleSave = async () => {
@@ -286,9 +450,29 @@ export const WorkflowEditorPage: React.FC = () => {
transformOrigin: '0 0'
}}
>
<svg className="absolute inset-0 pointer-events-none overflow-visible">
{resolvedEdges.map(({ key, from, to, edge }) => {
const x1 = from.metadata.position.x + 112;
const y1 = from.metadata.position.y + 88;
const x2 = to.metadata.position.x + 112;
const y2 = to.metadata.position.y;
const midY = (y1 + y2) / 2;
const d = `M ${x1} ${y1} C ${x1} ${midY}, ${x2} ${midY}, ${x2} ${y2}`;
return (
<g key={key}>
<path d={d} stroke="rgba(148,163,184,0.55)" strokeWidth={2} fill="none" />
{(edge.label || edge.condition?.value) && (
<text x={(x1 + x2) / 2} y={midY - 6} fill="rgba(226,232,240,0.8)" fontSize={10} textAnchor="middle">
{edge.label || edge.condition?.value}
</text>
)}
</g>
);
})}
</svg>
{nodes.map(node => (
<div
key={node.name}
key={nodeRef(node)}
onMouseDown={(e) => handleNodeMouseDown(e, node.name)}
style={{ left: node.metadata.position.x, top: node.metadata.position.y }}
className={`absolute w-56 p-4 rounded-xl border bg-card/70 backdrop-blur-sm cursor-grab active:cursor-grabbing group transition-shadow ${selectedNodeName === node.name ? 'border-primary shadow-[0_0_30px_rgba(6,182,212,0.3)]' : 'border-white/10 hover:border-white/30'}`}
@@ -324,7 +508,13 @@ export const WorkflowEditorPage: React.FC = () => {
return (
<div
key={i}
className={`absolute w-3 h-2 rounded-sm ${n.type === 'conversation' ? 'bg-primary' : n.type === 'end' ? 'bg-destructive' : 'bg-white/40'}`}
className={`absolute w-3 h-2 rounded-sm ${
n.type === 'conversation' || n.type === 'assistant' || n.type === 'start'
? 'bg-primary'
: n.type === 'end'
? 'bg-destructive'
: 'bg-white/40'
}`}
style={{ left: `${20 + mx}%`, top: `${20 + my}%` }}
></div>
);
@@ -369,8 +559,23 @@ export const WorkflowEditorPage: React.FC = () => {
<Badge variant="outline" className="w-fit">{selectedNode.type.toUpperCase()}</Badge>
</div>
{selectedNode.type === 'conversation' && (
{(selectedNode.type === 'conversation' || selectedNode.type === 'assistant' || selectedNode.type === 'start') && (
<>
<div className="space-y-2">
<label className="text-[10px] text-muted-foreground uppercase font-mono tracking-widest"></label>
<select
value={selectedNode.assistantId || ''}
onChange={(e) => updateNodeData('assistantId', e.target.value || undefined)}
className="w-full h-8 bg-white/5 border border-white/10 rounded-md px-2 text-xs text-white focus:outline-none focus:ring-1 focus:ring-primary/50"
>
<option value=""> Prompt</option>
{assistants.map((assistant) => (
<option key={assistant.id} value={assistant.id}>
{assistant.name}
</option>
))}
</select>
</div>
<div className="space-y-2">
<label className="text-[10px] text-muted-foreground uppercase font-mono tracking-widest">Prompt ()</label>
<textarea
@@ -410,6 +615,73 @@ export const WorkflowEditorPage: React.FC = () => {
<span className="text-[10px] text-muted-foreground group-hover:text-primary transition-colors uppercase font-mono tracking-widest"> (Start Node)</span>
</label>
</div>
<div className="pt-4 border-t border-white/5 space-y-3">
<div className="flex items-center justify-between">
<label className="text-[10px] text-muted-foreground uppercase font-mono tracking-widest"></label>
<Button variant="outline" size="sm" className="h-7 text-[11px]" onClick={addEdgeFromSelected}>
<Plus className="w-3 h-3 mr-1" />
</Button>
</div>
{outgoingEdges.length === 0 && (
<p className="text-[11px] text-muted-foreground"></p>
)}
{outgoingEdges.map((edge, index) => {
const edgeId = edge.id || `${edgeFromRef(edge)}->${edgeToRef(edge)}:${index}`;
const keyword = edge.condition?.value || edge.label || '';
return (
<div key={edgeId} className="rounded-lg border border-white/10 p-3 space-y-2 bg-white/5">
<div className="flex items-center justify-between">
<span className="text-[10px] uppercase tracking-widest text-muted-foreground"> #{index + 1}</span>
<button
className="text-[10px] text-destructive hover:underline"
onClick={() => removeOutgoingEdge(edgeId)}
>
</button>
</div>
<div className="space-y-1">
<label className="text-[10px] text-muted-foreground"></label>
<select
value={edgeToRef(edge)}
onChange={(e) =>
updateOutgoingEdge(edgeId, {
toNodeId: e.target.value,
to: e.target.value,
})
}
className="w-full h-8 bg-black/20 border border-white/10 rounded-md px-2 text-xs text-white focus:outline-none"
>
{nodes
.filter((node) => nodeRef(node) !== selectedNodeRef)
.map((node) => (
<option key={nodeRef(node)} value={nodeRef(node)}>
{node.name}
</option>
))}
</select>
</div>
<div className="space-y-1">
<label className="text-[10px] text-muted-foreground">=always</label>
<Input
value={keyword}
onChange={(e) => {
const v = e.target.value;
updateOutgoingEdge(edgeId, {
label: v || undefined,
condition: v
? { type: 'contains', source: 'user', value: v }
: { type: 'always' },
});
}}
className="h-8 text-xs"
placeholder="例如:退款 / 投诉 / 结束"
/>
</div>
</div>
);
})}
</div>
</div>
</div>
)}
@@ -418,6 +690,15 @@ export const WorkflowEditorPage: React.FC = () => {
<DebugDrawer
isOpen={isDebugOpen}
onClose={() => setIsDebugOpen(false)}
sessionMetadataExtras={{ workflow: workflowRuntimeMetadata }}
onProtocolEvent={(event) => {
if (event?.type !== 'workflow.node.entered') return;
const incomingNodeId = String(event.nodeId || '');
const matched = nodes.find((node) => nodeRef(node) === incomingNodeId || node.name === incomingNodeId);
if (matched) {
setSelectedNodeName(matched.name);
}
}}
assistant={assistants[0] || {
id: 'debug',
name: 'Debug Assistant',
@@ -429,7 +710,10 @@ export const WorkflowEditorPage: React.FC = () => {
voice: '',
speed: 1,
hotwords: [],
}}
}}
voices={[]}
llmModels={[]}
asrModels={[]}
/>
</div>
);
@@ -438,7 +722,10 @@ export const WorkflowEditorPage: React.FC = () => {
const NodeIcon = ({ type }: { type: WorkflowNode['type'] }) => {
switch (type) {
case 'conversation': return <Bot className="h-4 w-4 text-primary" />;
case 'assistant': return <Bot className="h-4 w-4 text-primary" />;
case 'start': return <Bot className="h-4 w-4 text-cyan-300" />;
case 'human': return <UserCheck className="h-4 w-4 text-orange-400" />;
case 'human_transfer': return <UserCheck className="h-4 w-4 text-orange-400" />;
case 'tool': return <Wrench className="h-4 w-4 text-purple-400" />;
case 'end': return <Ban className="h-4 w-4 text-destructive" />;
default: return <MousePointer2 className="h-4 w-4" />;

View File

@@ -107,10 +107,19 @@ const mapTool = (raw: AnyRecord): Tool => ({
});
const mapWorkflowNode = (raw: AnyRecord): WorkflowNode => ({
name: readField(raw, ['name'], ''),
type: readField(raw, ['type'], 'conversation') as 'conversation' | 'tool' | 'human' | 'end',
id: readField(raw, ['id'], ''),
name: readField(raw, ['name'], String(readField(raw, ['id'], ''))),
type: readField(raw, ['type'], 'assistant') as WorkflowNode['type'],
isStart: readField(raw, ['isStart', 'is_start'], undefined),
metadata: readField(raw, ['metadata'], { position: { x: 200, y: 200 } }),
metadata: (() => {
const metadata = readField(raw, ['metadata'], null);
if (metadata && typeof metadata === 'object') return metadata;
const position = readField(raw, ['position'], null);
if (position && typeof position === 'object') return { position };
return { position: { x: 200, y: 200 } };
})(),
assistantId: readField(raw, ['assistantId', 'assistant_id'], undefined),
assistant: readField(raw, ['assistant'], undefined),
prompt: readField(raw, ['prompt'], ''),
messagePlan: readField(raw, ['messagePlan', 'message_plan'], undefined),
variableExtractionPlan: readField(raw, ['variableExtractionPlan', 'variable_extraction_plan'], undefined),
@@ -119,9 +128,14 @@ const mapWorkflowNode = (raw: AnyRecord): WorkflowNode => ({
});
const mapWorkflowEdge = (raw: AnyRecord): WorkflowEdge => ({
from: readField(raw, ['from', 'from_'], ''),
to: readField(raw, ['to'], ''),
id: readField(raw, ['id'], undefined),
fromNodeId: readField(raw, ['fromNodeId', 'from', 'from_', 'source'], ''),
toNodeId: readField(raw, ['toNodeId', 'to', 'target'], ''),
from: readField(raw, ['fromNodeId', 'from', 'from_', 'source'], ''),
to: readField(raw, ['toNodeId', 'to', 'target'], ''),
label: readField(raw, ['label'], undefined),
condition: readField(raw, ['condition'], undefined),
priority: Number(readField(raw, ['priority'], 100)),
});
const mapWorkflow = (raw: AnyRecord): Workflow => ({

View File

@@ -91,13 +91,26 @@ export interface Workflow {
globalPrompt?: string;
}
// Canonical runtime node types plus the legacy editor aliases
// ('conversation' -> assistant, 'human' -> human_transfer).
export type WorkflowNodeType = 'start' | 'assistant' | 'tool' | 'human_transfer' | 'end' | 'conversation' | 'human';

/**
 * Transition rule attached to a workflow edge, evaluated after each turn.
 * `source` selects which transcript side (user or assistant) the text
 * condition is matched against; it defaults to the user side at runtime.
 */
export interface WorkflowCondition {
  type: 'always' | 'contains' | 'equals' | 'regex' | 'llm' | 'default';
  source?: 'user' | 'assistant';
  value?: string; // single keyword / pattern / expected value
  values?: string[]; // multiple keywords for `contains`
  prompt?: string; // routing prompt when type === 'llm'
}
export interface WorkflowNode {
id?: string;
name: string;
type: 'conversation' | 'tool' | 'human' | 'end';
type: WorkflowNodeType;
isStart?: boolean;
metadata: {
position: { x: number; y: number };
};
assistantId?: string;
assistant?: Record<string, any>;
prompt?: string;
messagePlan?: {
firstMessage?: string;
@@ -125,9 +138,14 @@ export interface WorkflowNode {
}
export interface WorkflowEdge {
id?: string;
fromNodeId?: string;
toNodeId?: string;
from: string;
to: string;
label?: string;
condition?: WorkflowCondition;
priority?: number;
}
export enum TabValue {