Use a decoupled approach for the backend client

This commit is contained in:
Xin Wang
2026-02-25 17:05:40 +08:00
parent 1cd2da1042
commit 08319a4cc7
15 changed files with 1203 additions and 228 deletions

View File

@@ -15,7 +15,7 @@ import asyncio
import json
import time
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
@@ -86,7 +86,16 @@ class DuplexPipeline:
tts_service: Optional[BaseTTSService] = None,
asr_service: Optional[BaseASRService] = None,
system_prompt: Optional[str] = None,
greeting: Optional[str] = None
greeting: Optional[str] = None,
knowledge_searcher: Optional[
Callable[..., Awaitable[List[Dict[str, Any]]]]
] = None,
tool_resource_resolver: Optional[
Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
] = None,
server_tool_executor: Optional[
Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
] = None,
):
"""
Initialize duplex pipeline.
@@ -127,6 +136,9 @@ class DuplexPipeline:
self.llm_service = llm_service
self.tts_service = tts_service
self.asr_service = asr_service # Will be initialized in start()
self._knowledge_searcher = knowledge_searcher
self._tool_resource_resolver = tool_resource_resolver
self._server_tool_executor = server_tool_executor
# Track last sent transcript to avoid duplicates
self._last_sent_transcript = ""
@@ -215,6 +227,18 @@ class DuplexPipeline:
self._pending_llm_delta: str = ""
self._last_llm_delta_emit_ms: float = 0.0
if self._server_tool_executor is None:
if self._tool_resource_resolver:
async def _executor(call: Dict[str, Any]) -> Dict[str, Any]:
return await execute_server_tool(
call,
tool_resource_fetcher=self._tool_resource_resolver,
)
self._server_tool_executor = _executor
else:
self._server_tool_executor = execute_server_tool
logger.info(f"DuplexPipeline initialized for session {session_id}")
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
@@ -559,6 +583,7 @@ class DuplexPipeline:
base_url=llm_base_url,
model=llm_model,
knowledge_config=self._resolved_knowledge_config(),
knowledge_searcher=self._knowledge_searcher,
)
else:
logger.warning("LLM provider unsupported or API key missing - using mock LLM")
@@ -1491,7 +1516,7 @@ class DuplexPipeline:
try:
result = await asyncio.wait_for(
execute_server_tool(call),
self._server_tool_executor(call),
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError: