Update code for tool call

This commit is contained in:
Xin Wang
2026-02-10 16:28:20 +08:00
parent 539cf2fda2
commit 4b8da32787
7 changed files with 676 additions and 91 deletions

View File

@@ -6,11 +6,12 @@ for real-time voice conversation.
import os
import asyncio
import uuid
from typing import AsyncIterator, Optional, List, Dict, Any
from loguru import logger
from app.backend_client import search_knowledge_context
from services.base import BaseLLMService, LLMMessage, ServiceState
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
# Try to import openai
try:
@@ -59,6 +60,7 @@ class OpenAILLMService(BaseLLMService):
self.client: Optional[AsyncOpenAI] = None
self._cancel_event = asyncio.Event()
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
self._tool_schemas: List[Dict[str, Any]] = []
_RAG_DEFAULT_RESULTS = 5
_RAG_MAX_RESULTS = 8
@@ -106,6 +108,29 @@ class OpenAILLMService(BaseLLMService):
"""Update runtime knowledge retrieval config."""
self._knowledge_config = config or {}
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
"""Update runtime tool schemas."""
self._tool_schemas = []
if not isinstance(schemas, list):
return
for item in schemas:
if not isinstance(item, dict):
continue
fn = item.get("function")
if isinstance(fn, dict) and fn.get("name"):
self._tool_schemas.append(item)
elif item.get("name"):
self._tool_schemas.append(
{
"type": "function",
"function": {
"name": str(item.get("name")),
"description": str(item.get("description") or ""),
"parameters": item.get("parameters") or {"type": "object", "properties": {}},
},
}
)
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
try:
@@ -258,7 +283,7 @@ class OpenAILLMService(BaseLLMService):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
) -> AsyncIterator[LLMStreamEvent]:
"""
Generate response in streaming mode.
@@ -268,7 +293,7 @@ class OpenAILLMService(BaseLLMService):
max_tokens: Maximum tokens to generate
Yields:
Text chunks as they are generated
Structured stream events
"""
if not self.client:
raise RuntimeError("LLM service not connected")
@@ -276,25 +301,82 @@ class OpenAILLMService(BaseLLMService):
rag_messages = await self._with_knowledge_context(messages)
prepared = self._prepare_messages(rag_messages)
self._cancel_event.clear()
tool_accumulator: Dict[int, Dict[str, str]] = {}
openai_tools = self._tool_schemas or None
try:
stream = await self.client.chat.completions.create(
create_args: Dict[str, Any] = dict(
model=self.model,
messages=prepared,
temperature=temperature,
max_tokens=max_tokens,
stream=True
stream=True,
)
if openai_tools:
create_args["tools"] = openai_tools
create_args["tool_choice"] = "auto"
stream = await self.client.chat.completions.create(**create_args)
async for chunk in stream:
# Check for cancellation
if self._cancel_event.is_set():
logger.info("LLM stream cancelled")
break
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content
if not chunk.choices:
continue
choice = chunk.choices[0]
delta = getattr(choice, "delta", None)
if delta and getattr(delta, "content", None):
content = delta.content
yield LLMStreamEvent(type="text_delta", text=content)
# OpenAI streams function calls via incremental tool_calls deltas.
tool_calls = getattr(delta, "tool_calls", None) if delta else None
if tool_calls:
for tc in tool_calls:
index = getattr(tc, "index", 0) or 0
item = tool_accumulator.setdefault(
int(index),
{"id": "", "name": "", "arguments": ""},
)
tc_id = getattr(tc, "id", None)
if tc_id:
item["id"] = str(tc_id)
fn = getattr(tc, "function", None)
if fn:
fn_name = getattr(fn, "name", None)
if fn_name:
item["name"] = str(fn_name)
fn_args = getattr(fn, "arguments", None)
if fn_args:
item["arguments"] += str(fn_args)
finish_reason = getattr(choice, "finish_reason", None)
if finish_reason == "tool_calls" and tool_accumulator:
for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
call_name = payload.get("name", "").strip()
if not call_name:
continue
call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
yield LLMStreamEvent(
type="tool_call",
tool_call={
"id": call_id,
"type": "function",
"function": {
"name": call_name,
"arguments": payload.get("arguments", "") or "{}",
},
},
)
yield LLMStreamEvent(type="done")
return
if finish_reason in {"stop", "length", "content_filter"}:
yield LLMStreamEvent(type="done")
return
except asyncio.CancelledError:
logger.info("LLM stream cancelled via asyncio")
@@ -348,13 +430,14 @@ class MockLLMService(BaseLLMService):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
) -> AsyncIterator[LLMStreamEvent]:
response = await self.generate(messages, temperature, max_tokens)
# Stream word by word
words = response.split()
for i, word in enumerate(words):
if i > 0:
yield " "
yield word
yield LLMStreamEvent(type="text_delta", text=" ")
yield LLMStreamEvent(type="text_delta", text=word)
await asyncio.sleep(0.05) # Simulate streaming delay
yield LLMStreamEvent(type="done")