Use decoupled way for backend client
This commit is contained in:
@@ -15,7 +15,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@@ -86,7 +86,16 @@ class DuplexPipeline:
|
||||
tts_service: Optional[BaseTTSService] = None,
|
||||
asr_service: Optional[BaseASRService] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
greeting: Optional[str] = None
|
||||
greeting: Optional[str] = None,
|
||||
knowledge_searcher: Optional[
|
||||
Callable[..., Awaitable[List[Dict[str, Any]]]]
|
||||
] = None,
|
||||
tool_resource_resolver: Optional[
|
||||
Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
|
||||
] = None,
|
||||
server_tool_executor: Optional[
|
||||
Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
] = None,
|
||||
):
|
||||
"""
|
||||
Initialize duplex pipeline.
|
||||
@@ -127,6 +136,9 @@ class DuplexPipeline:
|
||||
self.llm_service = llm_service
|
||||
self.tts_service = tts_service
|
||||
self.asr_service = asr_service # Will be initialized in start()
|
||||
self._knowledge_searcher = knowledge_searcher
|
||||
self._tool_resource_resolver = tool_resource_resolver
|
||||
self._server_tool_executor = server_tool_executor
|
||||
|
||||
# Track last sent transcript to avoid duplicates
|
||||
self._last_sent_transcript = ""
|
||||
@@ -215,6 +227,18 @@ class DuplexPipeline:
|
||||
self._pending_llm_delta: str = ""
|
||||
self._last_llm_delta_emit_ms: float = 0.0
|
||||
|
||||
if self._server_tool_executor is None:
|
||||
if self._tool_resource_resolver:
|
||||
async def _executor(call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return await execute_server_tool(
|
||||
call,
|
||||
tool_resource_fetcher=self._tool_resource_resolver,
|
||||
)
|
||||
|
||||
self._server_tool_executor = _executor
|
||||
else:
|
||||
self._server_tool_executor = execute_server_tool
|
||||
|
||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||
|
||||
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
|
||||
@@ -559,6 +583,7 @@ class DuplexPipeline:
|
||||
base_url=llm_base_url,
|
||||
model=llm_model,
|
||||
knowledge_config=self._resolved_knowledge_config(),
|
||||
knowledge_searcher=self._knowledge_searcher,
|
||||
)
|
||||
else:
|
||||
logger.warning("LLM provider unsupported or API key missing - using mock LLM")
|
||||
@@ -1491,7 +1516,7 @@ class DuplexPipeline:
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
execute_server_tool(call),
|
||||
self._server_tool_executor(call),
|
||||
timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
Reference in New Issue
Block a user