Update workflow feature with codex
This commit is contained in:
@@ -4,8 +4,9 @@ import asyncio
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import (
|
||||
@@ -16,7 +17,9 @@ from app.backend_client import (
|
||||
from core.transports import BaseTransport
|
||||
from core.duplex_pipeline import DuplexPipeline
|
||||
from core.conversation import ConversationTurn
|
||||
from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef
|
||||
from app.config import settings
|
||||
from services.base import LLMMessage
|
||||
from models.ws_v1 import (
|
||||
parse_client_message,
|
||||
ev,
|
||||
@@ -81,6 +84,9 @@ class Session:
|
||||
self._history_finalized: bool = False
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
self._cleaned_up = False
|
||||
self.workflow_runner: Optional[WorkflowRunner] = None
|
||||
self._workflow_last_user_text: str = ""
|
||||
self._workflow_initial_node: Optional[WorkflowNodeDef] = None
|
||||
|
||||
self.pipeline.conversation.on_turn_complete(self._on_turn_complete)
|
||||
|
||||
@@ -223,6 +229,7 @@ class Session:
|
||||
return
|
||||
|
||||
metadata = message.metadata or {}
|
||||
metadata = self._merge_runtime_metadata(metadata, self._bootstrap_workflow(metadata))
|
||||
|
||||
# Create history call record early so later turn callbacks can append transcripts.
|
||||
await self._start_history_bridge(metadata)
|
||||
@@ -246,6 +253,26 @@ class Session:
|
||||
audio=message.audio or {},
|
||||
)
|
||||
)
|
||||
if self.workflow_runner and self._workflow_initial_node:
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"workflow.started",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
workflowName=self.workflow_runner.name,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
)
|
||||
)
|
||||
await self.transport.send_event(
|
||||
ev(
|
||||
"workflow.node.entered",
|
||||
sessionId=self.id,
|
||||
workflowId=self.workflow_runner.workflow_id,
|
||||
nodeId=self._workflow_initial_node.id,
|
||||
nodeName=self._workflow_initial_node.name,
|
||||
nodeType=self._workflow_initial_node.node_type,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_session_stop(self, reason: Optional[str]) -> None:
|
||||
"""Handle session stop."""
|
||||
@@ -334,7 +361,14 @@ class Session:
|
||||
logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})")
|
||||
|
||||
async def _on_turn_complete(self, turn: ConversationTurn) -> None:
|
||||
"""Persist completed turns to backend call transcripts."""
|
||||
"""Process workflow transitions and persist completed turns to history."""
|
||||
if turn.text and turn.text.strip():
|
||||
role = (turn.role or "").lower()
|
||||
if role == "user":
|
||||
self._workflow_last_user_text = turn.text.strip()
|
||||
elif role == "assistant":
|
||||
await self._maybe_advance_workflow(turn.text.strip())
|
||||
|
||||
if not self._history_call_id:
|
||||
return
|
||||
if not turn.text or not turn.text.strip():
|
||||
@@ -377,3 +411,235 @@ class Session:
|
||||
)
|
||||
if ok:
|
||||
self._history_finalized = True
|
||||
|
||||
def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Parse workflow payload and return initial runtime overrides."""
|
||||
payload = metadata.get("workflow")
|
||||
self.workflow_runner = WorkflowRunner.from_payload(payload)
|
||||
self._workflow_initial_node = None
|
||||
if not self.workflow_runner:
|
||||
return {}
|
||||
|
||||
node = self.workflow_runner.bootstrap()
|
||||
if not node:
|
||||
logger.warning(f"Session {self.id} workflow payload had no resolvable start node")
|
||||
self.workflow_runner = None
|
||||
return {}
|
||||
|
||||
self._workflow_initial_node = node
|
||||
logger.info(
|
||||
"Session {} workflow enabled: workflow={} start_node={}",
|
||||
self.id,
|
||||
self.workflow_runner.workflow_id,
|
||||
node.id,
|
||||
)
|
||||
return self.workflow_runner.build_runtime_metadata(node)
|
||||
|
||||
async def _maybe_advance_workflow(self, assistant_text: str) -> None:
|
||||
"""Attempt node transfer after assistant turn finalization."""
|
||||
if not self.workflow_runner or self.ws_state == WsSessionState.STOPPED:
|
||||
return
|
||||
|
||||
transition = await self.workflow_runner.route(
|
||||
user_text=self._workflow_last_user_text,
|
||||
assistant_text=assistant_text,
|
||||
llm_router=self._workflow_llm_route,
|
||||
)
|
||||
if not transition:
|
||||
return
|
||||
|
||||
await self._apply_workflow_transition(transition, reason="rule_match")
|
||||
|
||||
# Auto-advance through utility nodes when default edges are present.
|
||||
max_auto_hops = 6
|
||||
auto_hops = 0
|
||||
while self.workflow_runner and self.ws_state != WsSessionState.STOPPED:
|
||||
current = self.workflow_runner.current_node
|
||||
if not current or current.node_type not in {"start", "tool"}:
|
||||
break
|
||||
|
||||
next_default = self.workflow_runner.next_default_transition()
|
||||
if not next_default:
|
||||
break
|
||||
|
||||
auto_hops += 1
|
||||
await self._apply_workflow_transition(next_default, reason="auto")
|
||||
if auto_hops >= max_auto_hops:
|
||||
logger.warning(
|
||||
"Session {} workflow auto-advance reached hop limit (possible cycle)",
|
||||
self.id,
|
||||
)
|
||||
break
|
||||
|
||||
    async def _apply_workflow_transition(self, transition: WorkflowTransition, reason: str) -> None:
        """Apply graph transition and emit workflow lifecycle events.

        Order matters here: the runner state is updated first, then
        ``workflow.edge.taken`` and ``workflow.node.entered`` are emitted,
        then the new node's runtime overrides are pushed into the pipeline,
        and finally any node-type-specific side effect runs (tool request,
        human transfer, or workflow end — the latter two stop the session).

        Args:
            transition: Edge/node pair chosen by the router.
            reason: Why the transition fired (e.g. "rule_match" or "auto");
                forwarded verbatim in the ``workflow.edge.taken`` event.
        """
        if not self.workflow_runner:
            return

        self.workflow_runner.apply_transition(transition)
        node = transition.node
        edge = transition.edge

        # Tell the client which edge fired and which node is now active.
        await self.transport.send_event(
            ev(
                "workflow.edge.taken",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                edgeId=edge.id,
                fromNodeId=edge.from_node_id,
                toNodeId=edge.to_node_id,
                reason=reason,
            )
        )
        await self.transport.send_event(
            ev(
                "workflow.node.entered",
                sessionId=self.id,
                workflowId=self.workflow_runner.workflow_id,
                nodeId=node.id,
                nodeName=node.name,
                nodeType=node.node_type,
            )
        )

        # Apply the entered node's runtime metadata to the live pipeline.
        node_runtime = self.workflow_runner.build_runtime_metadata(node)
        if node_runtime:
            self.pipeline.apply_runtime_overrides(node_runtime)

        if node.node_type == "tool":
            # Surface the tool request to the client; nothing more to do here.
            await self.transport.send_event(
                ev(
                    "workflow.tool.requested",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                    tool=node.tool or {},
                )
            )
            return

        if node.node_type == "human_transfer":
            # Emit the transfer event, then stop this session.
            await self.transport.send_event(
                ev(
                    "workflow.human_transfer",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_human_transfer")
            return

        if node.node_type == "end":
            # Terminal node: announce workflow end and stop the session.
            await self.transport.send_event(
                ev(
                    "workflow.ended",
                    sessionId=self.id,
                    workflowId=self.workflow_runner.workflow_id,
                    nodeId=node.id,
                )
            )
            await self._handle_session_stop("workflow_end")
async def _workflow_llm_route(
|
||||
self,
|
||||
node: WorkflowNodeDef,
|
||||
candidates: List[WorkflowEdgeDef],
|
||||
context: Dict[str, str],
|
||||
) -> Optional[str]:
|
||||
"""LLM-based edge routing for condition.type == 'llm' edges."""
|
||||
llm_service = self.pipeline.llm_service
|
||||
if not llm_service:
|
||||
return None
|
||||
|
||||
candidate_rows = [
|
||||
{
|
||||
"edgeId": edge.id,
|
||||
"toNodeId": edge.to_node_id,
|
||||
"label": edge.label,
|
||||
"hint": edge.condition.get("prompt") if isinstance(edge.condition, dict) else None,
|
||||
}
|
||||
for edge in candidates
|
||||
]
|
||||
system_prompt = (
|
||||
"You are a workflow router. Pick exactly one edge. "
|
||||
"Return JSON only: {\"edgeId\":\"...\"}."
|
||||
)
|
||||
user_prompt = json.dumps(
|
||||
{
|
||||
"nodeId": node.id,
|
||||
"nodeName": node.name,
|
||||
"userText": context.get("userText", ""),
|
||||
"assistantText": context.get("assistantText", ""),
|
||||
"candidates": candidate_rows,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
try:
|
||||
reply = await llm_service.generate(
|
||||
[
|
||||
LLMMessage(role="system", content=system_prompt),
|
||||
LLMMessage(role="user", content=user_prompt),
|
||||
],
|
||||
temperature=0.0,
|
||||
max_tokens=64,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Session {self.id} workflow llm routing failed: {exc}")
|
||||
return None
|
||||
|
||||
if not reply:
|
||||
return None
|
||||
|
||||
edge_ids = {edge.id for edge in candidates}
|
||||
node_ids = {edge.to_node_id for edge in candidates}
|
||||
|
||||
parsed = self._extract_json_obj(reply)
|
||||
if isinstance(parsed, dict):
|
||||
edge_id = parsed.get("edgeId") or parsed.get("id")
|
||||
node_id = parsed.get("toNodeId") or parsed.get("nodeId")
|
||||
if isinstance(edge_id, str) and edge_id in edge_ids:
|
||||
return edge_id
|
||||
if isinstance(node_id, str) and node_id in node_ids:
|
||||
return node_id
|
||||
|
||||
token_candidates = sorted(edge_ids | node_ids, key=len, reverse=True)
|
||||
lowered_reply = reply.lower()
|
||||
for token in token_candidates:
|
||||
if token.lower() in lowered_reply:
|
||||
return token
|
||||
return None
|
||||
|
||||
def _merge_runtime_metadata(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Merge node-level metadata overrides into session.start metadata."""
|
||||
merged = dict(base or {})
|
||||
if not overrides:
|
||||
return merged
|
||||
for key, value in overrides.items():
|
||||
if key == "services" and isinstance(value, dict):
|
||||
existing = merged.get("services")
|
||||
merged_services = dict(existing) if isinstance(existing, dict) else {}
|
||||
merged_services.update(value)
|
||||
merged["services"] = merged_services
|
||||
else:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
def _extract_json_obj(self, text: str) -> Optional[Dict[str, Any]]:
|
||||
"""Best-effort extraction of a JSON object from freeform text."""
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
match = re.search(r"\{.*\}", text, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(match.group(0))
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user