Update code for tool call

This commit is contained in:
Xin Wang
2026-02-10 16:28:20 +08:00
parent 539cf2fda2
commit 4b8da32787
7 changed files with 676 additions and 91 deletions

View File

@@ -7,7 +7,7 @@ StreamEngine pattern.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional, List, Dict, Any
from typing import AsyncIterator, Optional, List, Dict, Any, Literal
from enum import Enum
@@ -52,6 +52,15 @@ class LLMMessage:
return d
@dataclass
class LLMStreamEvent:
    """One structured event produced by a streaming LLM call.

    Attributes:
        type: Event discriminator; one of ``"text_delta"``, ``"tool_call"``
            or ``"done"``.
        text: Text payload of the event, or ``None`` when the event carries
            no text.
        tool_call: Tool-call payload of the event as a plain dict, or
            ``None`` when the event carries no tool call.
    """

    type: Literal["text_delta", "tool_call", "done"]
    # Both payload slots are optional and default to "absent".
    text: Optional[str] = field(default=None)
    tool_call: Optional[Dict[str, Any]] = field(default=None)
@dataclass
class TTSChunk:
"""TTS audio chunk."""
@@ -170,7 +179,7 @@ class BaseLLMService(ABC):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
) -> AsyncIterator[LLMStreamEvent]:
"""
Generate response in streaming mode.
@@ -180,7 +189,7 @@ class BaseLLMService(ABC):
max_tokens: Maximum tokens to generate
Yields:
Text chunks as they are generated
Stream events (text delta/tool call/done)
"""
pass

View File

@@ -6,11 +6,12 @@ for real-time voice conversation.
import os
import asyncio
import uuid
from typing import AsyncIterator, Optional, List, Dict, Any
from loguru import logger
from app.backend_client import search_knowledge_context
from services.base import BaseLLMService, LLMMessage, ServiceState
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
# Try to import openai
try:
@@ -59,6 +60,7 @@ class OpenAILLMService(BaseLLMService):
self.client: Optional[AsyncOpenAI] = None
self._cancel_event = asyncio.Event()
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
self._tool_schemas: List[Dict[str, Any]] = []
_RAG_DEFAULT_RESULTS = 5
_RAG_MAX_RESULTS = 8
@@ -106,6 +108,29 @@ class OpenAILLMService(BaseLLMService):
"""Update runtime knowledge retrieval config."""
self._knowledge_config = config or {}
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
    """Replace the runtime tool schemas with a normalized copy of *schemas*.

    Two input shapes are accepted per entry:

    * function form: ``{"type": "function", "function": {"name": ...}}``
    * flat form: ``{"name": ..., "description": ..., "parameters": ...}``,
      which is wrapped into the function form.

    Entries that are not dicts, or that carry no usable name, are dropped.
    Passing anything other than a list clears the stored schemas.

    Args:
        schemas: Tool definitions to install, or ``None`` to clear them.
    """
    normalized: List[Dict[str, Any]] = []
    if isinstance(schemas, list):
        for item in schemas:
            if not isinstance(item, dict):
                continue
            fn = item.get("function")
            if isinstance(fn, dict) and fn.get("name"):
                # Shallow-copy the entry and its function dict so that the
                # caller mutating its own objects later cannot silently
                # change what this service sends.
                entry: Dict[str, Any] = {**item, "function": dict(fn)}
                # Each entry of the "tools" array must carry a type; fill it
                # in when the caller omitted it instead of forwarding an
                # entry the API would reject.
                if not entry.get("type"):
                    entry["type"] = "function"
                normalized.append(entry)
            elif item.get("name"):
                normalized.append(
                    {
                        "type": "function",
                        "function": {
                            "name": str(item.get("name")),
                            "description": str(item.get("description") or ""),
                            "parameters": item.get("parameters")
                            or {"type": "object", "properties": {}},
                        },
                    }
                )
    # Assign once at the end so the attribute never holds a half-built list.
    self._tool_schemas = normalized
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
try:
@@ -258,7 +283,7 @@ class OpenAILLMService(BaseLLMService):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
) -> AsyncIterator[LLMStreamEvent]:
"""
Generate response in streaming mode.
@@ -268,7 +293,7 @@ class OpenAILLMService(BaseLLMService):
max_tokens: Maximum tokens to generate
Yields:
Text chunks as they are generated
Structured stream events
"""
if not self.client:
raise RuntimeError("LLM service not connected")
@@ -276,25 +301,82 @@ class OpenAILLMService(BaseLLMService):
rag_messages = await self._with_knowledge_context(messages)
prepared = self._prepare_messages(rag_messages)
self._cancel_event.clear()
tool_accumulator: Dict[int, Dict[str, str]] = {}
openai_tools = self._tool_schemas or None
try:
stream = await self.client.chat.completions.create(
create_args: Dict[str, Any] = dict(
model=self.model,
messages=prepared,
temperature=temperature,
max_tokens=max_tokens,
stream=True
stream=True,
)
if openai_tools:
create_args["tools"] = openai_tools
create_args["tool_choice"] = "auto"
stream = await self.client.chat.completions.create(**create_args)
async for chunk in stream:
# Check for cancellation
if self._cancel_event.is_set():
logger.info("LLM stream cancelled")
break
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content
if not chunk.choices:
continue
choice = chunk.choices[0]
delta = getattr(choice, "delta", None)
if delta and getattr(delta, "content", None):
content = delta.content
yield LLMStreamEvent(type="text_delta", text=content)
# OpenAI streams function calls via incremental tool_calls deltas.
tool_calls = getattr(delta, "tool_calls", None) if delta else None
if tool_calls:
for tc in tool_calls:
index = getattr(tc, "index", 0) or 0
item = tool_accumulator.setdefault(
int(index),
{"id": "", "name": "", "arguments": ""},
)
tc_id = getattr(tc, "id", None)
if tc_id:
item["id"] = str(tc_id)
fn = getattr(tc, "function", None)
if fn:
fn_name = getattr(fn, "name", None)
if fn_name:
item["name"] = str(fn_name)
fn_args = getattr(fn, "arguments", None)
if fn_args:
item["arguments"] += str(fn_args)
finish_reason = getattr(choice, "finish_reason", None)
if finish_reason == "tool_calls" and tool_accumulator:
for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
call_name = payload.get("name", "").strip()
if not call_name:
continue
call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
yield LLMStreamEvent(
type="tool_call",
tool_call={
"id": call_id,
"type": "function",
"function": {
"name": call_name,
"arguments": payload.get("arguments", "") or "{}",
},
},
)
yield LLMStreamEvent(type="done")
return
if finish_reason in {"stop", "length", "content_filter"}:
yield LLMStreamEvent(type="done")
return
except asyncio.CancelledError:
logger.info("LLM stream cancelled via asyncio")
@@ -348,13 +430,14 @@ class MockLLMService(BaseLLMService):
messages: List[LLMMessage],
temperature: float = 0.7,
max_tokens: Optional[int] = None
) -> AsyncIterator[str]:
) -> AsyncIterator[LLMStreamEvent]:
response = await self.generate(messages, temperature, max_tokens)
# Stream word by word
words = response.split()
for i, word in enumerate(words):
if i > 0:
yield " "
yield word
yield LLMStreamEvent(type="text_delta", text=" ")
yield LLMStreamEvent(type="text_delta", text=word)
await asyncio.sleep(0.05) # Simulate streaming delay
yield LLMStreamEvent(type="done")