Enhance DuplexPipeline to support follow-up context for manual opener tool calls

- Introduced logic to trigger a follow-up turn when the manual opener greeting is empty.
- Updated `_execute_manual_opener_tool_calls` to return structured tool call and result data.
- Added `_build_manual_opener_follow_up_context` method to construct context for follow-up turns.
- Modified `_handle_turn` to accept system context for improved conversation management.
- Enhanced tests to validate the new follow-up behavior and ensure proper context handling.
This commit is contained in:
Xin Wang
2026-03-02 14:27:44 +08:00
parent fb017f9952
commit 3aa9e0f432
2 changed files with 115 additions and 4 deletions

View File

@@ -983,10 +983,11 @@ class DuplexPipeline:
logger.info("Initial generated opener started with tool-calling path")
return
manual_opener_execution: Dict[str, List[Dict[str, Any]]] = {"toolCalls": [], "toolResults": []}
if not self._generated_opener_enabled() and self._resolved_manual_opener_tool_calls():
self._start_turn()
self._start_response()
await self._execute_manual_opener_tool_calls()
manual_opener_execution = await self._execute_manual_opener_tool_calls()
greeting_to_speak = self.conversation.greeting
if self._generated_opener_enabled():
@@ -996,6 +997,16 @@ class DuplexPipeline:
self.conversation.greeting = generated_greeting
if not greeting_to_speak:
if (
not self._generated_opener_enabled()
and manual_opener_execution.get("toolCalls")
and not (self._current_turn_task and not self._current_turn_task.done())
):
follow_up_context = self._build_manual_opener_follow_up_context(manual_opener_execution)
self._current_turn_task = asyncio.create_task(
self._handle_turn("", system_context=follow_up_context)
)
logger.info("Initial manual opener follow-up started")
return
if not self._current_turn_id:
@@ -1840,10 +1851,23 @@ class DuplexPipeline:
merged.update(args)
return merged
async def _execute_manual_opener_tool_calls(self) -> None:
def _build_manual_opener_follow_up_context(self, payload: Dict[str, List[Dict[str, Any]]]) -> str:
tool_calls = payload.get("toolCalls") if isinstance(payload.get("toolCalls"), list) else []
tool_results = payload.get("toolResults") if isinstance(payload.get("toolResults"), list) else []
return (
"Initial opener tool calls already executed. Continue with a natural assistant follow-up. "
"If tool results include user selections or values, use them in your response. "
"Never expose internal tool ids or raw payloads.\n"
f"opener_tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
f"opener_tool_results={json.dumps(tool_results, ensure_ascii=False)}"
)
async def _execute_manual_opener_tool_calls(self) -> Dict[str, List[Dict[str, Any]]]:
calls = self._resolved_manual_opener_tool_calls()
tool_calls_for_context: List[Dict[str, Any]] = []
tool_results_for_context: List[Dict[str, Any]] = []
if not calls:
return
return {"toolCalls": tool_calls_for_context, "toolResults": tool_results_for_context}
for call in calls:
tool_name = str(call.get("toolName") or "").strip()
@@ -1864,6 +1888,17 @@ class DuplexPipeline:
},
}
executor = self._tool_executor(tool_call)
tool_calls_for_context.append(
{
"tool_call_id": call_id,
"tool_name": tool_name,
"tool_id": tool_id,
"tool_display_name": tool_display_name,
"arguments": merged_tool_arguments,
"wait_for_response": wait_for_response,
"executor": executor,
}
)
await self._send_event(
{
@@ -1893,6 +1928,7 @@ class DuplexPipeline:
if wait_for_response:
result = await self._wait_for_single_tool_result(call_id)
await self._emit_tool_result(result, source="client")
tool_results_for_context.append(result if isinstance(result, dict) else {"tool_call_id": call_id})
continue
call_for_executor = dict(tool_call)
@@ -1917,6 +1953,9 @@ class DuplexPipeline:
"status": {"code": 504, "message": "server_tool_timeout"},
}
await self._emit_tool_result(result, source="server")
tool_results_for_context.append(result if isinstance(result, dict) else {"tool_call_id": call_id})
return {"toolCalls": tool_calls_for_context, "toolResults": tool_results_for_context}
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
status = result.get("status") if isinstance(result.get("status"), dict) else {}
@@ -2048,7 +2087,7 @@ class DuplexPipeline:
)
return LLMStreamEvent(type="done")
async def _handle_turn(self, user_text: str) -> None:
async def _handle_turn(self, user_text: str, system_context: Optional[str] = None) -> None:
"""
Handle a complete conversation turn.
@@ -2071,6 +2110,8 @@ class DuplexPipeline:
full_response = ""
messages = self.conversation.get_messages()
if system_context and system_context.strip():
messages = [*messages, LLMMessage(role="system", content=system_context.strip())]
max_rounds = 3
await self.conversation.start_assistant_turn()