Update code for tool call
This commit is contained in:
@@ -7,7 +7,7 @@ StreamEngine pattern.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any, Literal
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -52,6 +52,15 @@ class LLMMessage:
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
class LLMStreamEvent:
    """Structured LLM stream event.

    A single item produced by a streaming LLM call. ``type`` is the
    discriminator; the remaining fields are optional payloads that default
    to ``None`` when they do not apply to the event.
    """

    # Event discriminator: "text_delta", "tool_call", or "done".
    type: Literal["text_delta", "tool_call", "done"]
    # Text fragment carried by the event, if any.
    text: Optional[str] = None
    # Tool-call payload carried by the event, if any.
    tool_call: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSChunk:
|
||||
"""TTS audio chunk."""
|
||||
@@ -170,7 +179,7 @@ class BaseLLMService(ABC):
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
) -> AsyncIterator[LLMStreamEvent]:
|
||||
"""
|
||||
Generate response in streaming mode.
|
||||
|
||||
@@ -180,7 +189,7 @@ class BaseLLMService(ABC):
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Yields:
|
||||
Text chunks as they are generated
|
||||
Stream events (text delta/tool call/done)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ for real-time voice conversation.
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import search_knowledge_context
|
||||
from services.base import BaseLLMService, LLMMessage, ServiceState
|
||||
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
||||
|
||||
# Try to import openai
|
||||
try:
|
||||
@@ -59,6 +60,7 @@ class OpenAILLMService(BaseLLMService):
|
||||
self.client: Optional[AsyncOpenAI] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
|
||||
self._tool_schemas: List[Dict[str, Any]] = []
|
||||
|
||||
_RAG_DEFAULT_RESULTS = 5
|
||||
_RAG_MAX_RESULTS = 8
|
||||
@@ -106,6 +108,29 @@ class OpenAILLMService(BaseLLMService):
|
||||
"""Update runtime knowledge retrieval config."""
|
||||
self._knowledge_config = config or {}
|
||||
|
||||
def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
    """Update runtime tool schemas.

    Replaces any previously stored schemas with a normalized copy of
    ``schemas``. Two input shapes are accepted per entry:

    * wrapped: ``{"function": {"name": ...}, ...}``; stored as a shallow
      copy, with ``"type": "function"`` filled in when it is missing or empty;
    * flat: ``{"name": ..., "description": ..., "parameters": ...}``;
      wrapped into ``{"type": "function", "function": {...}}``.

    Entries that are not dicts, or that carry no function name, are
    skipped. A non-list ``schemas`` (including ``None``) clears the stored
    schemas.

    Args:
        schemas: Tool definitions to install, or ``None`` to clear them.
    """
    # Always start from a clean slate so invalid input clears old tools.
    self._tool_schemas = []
    if not isinstance(schemas, list):
        return

    for item in schemas:
        if not isinstance(item, dict):
            continue

        fn = item.get("function")
        if isinstance(fn, dict) and fn.get("name"):
            # Copy so the caller's dict is neither aliased nor mutated,
            # and make sure the entry carries a tool type.
            normalized = dict(item)
            if not normalized.get("type"):
                normalized["type"] = "function"
            self._tool_schemas.append(normalized)
        elif item.get("name"):
            # Flat form: wrap it and supply defaults for optional fields.
            self._tool_schemas.append(
                {
                    "type": "function",
                    "function": {
                        "name": str(item.get("name")),
                        "description": str(item.get("description") or ""),
                        "parameters": item.get("parameters")
                        or {"type": "object", "properties": {}},
                    },
                }
            )
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
@@ -258,7 +283,7 @@ class OpenAILLMService(BaseLLMService):
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
) -> AsyncIterator[LLMStreamEvent]:
|
||||
"""
|
||||
Generate response in streaming mode.
|
||||
|
||||
@@ -268,7 +293,7 @@ class OpenAILLMService(BaseLLMService):
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Yields:
|
||||
Text chunks as they are generated
|
||||
Structured stream events
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("LLM service not connected")
|
||||
@@ -276,25 +301,82 @@ class OpenAILLMService(BaseLLMService):
|
||||
rag_messages = await self._with_knowledge_context(messages)
|
||||
prepared = self._prepare_messages(rag_messages)
|
||||
self._cancel_event.clear()
|
||||
tool_accumulator: Dict[int, Dict[str, str]] = {}
|
||||
openai_tools = self._tool_schemas or None
|
||||
|
||||
try:
|
||||
stream = await self.client.chat.completions.create(
|
||||
create_args: Dict[str, Any] = dict(
|
||||
model=self.model,
|
||||
messages=prepared,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
stream=True,
|
||||
)
|
||||
if openai_tools:
|
||||
create_args["tools"] = openai_tools
|
||||
create_args["tool_choice"] = "auto"
|
||||
stream = await self.client.chat.completions.create(**create_args)
|
||||
|
||||
async for chunk in stream:
|
||||
# Check for cancellation
|
||||
if self._cancel_event.is_set():
|
||||
logger.info("LLM stream cancelled")
|
||||
break
|
||||
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
yield content
|
||||
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
delta = getattr(choice, "delta", None)
|
||||
if delta and getattr(delta, "content", None):
|
||||
content = delta.content
|
||||
yield LLMStreamEvent(type="text_delta", text=content)
|
||||
|
||||
# OpenAI streams function calls via incremental tool_calls deltas.
|
||||
tool_calls = getattr(delta, "tool_calls", None) if delta else None
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
index = getattr(tc, "index", 0) or 0
|
||||
item = tool_accumulator.setdefault(
|
||||
int(index),
|
||||
{"id": "", "name": "", "arguments": ""},
|
||||
)
|
||||
tc_id = getattr(tc, "id", None)
|
||||
if tc_id:
|
||||
item["id"] = str(tc_id)
|
||||
fn = getattr(tc, "function", None)
|
||||
if fn:
|
||||
fn_name = getattr(fn, "name", None)
|
||||
if fn_name:
|
||||
item["name"] = str(fn_name)
|
||||
fn_args = getattr(fn, "arguments", None)
|
||||
if fn_args:
|
||||
item["arguments"] += str(fn_args)
|
||||
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
if finish_reason == "tool_calls" and tool_accumulator:
|
||||
for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
|
||||
call_name = payload.get("name", "").strip()
|
||||
if not call_name:
|
||||
continue
|
||||
call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
|
||||
yield LLMStreamEvent(
|
||||
type="tool_call",
|
||||
tool_call={
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call_name,
|
||||
"arguments": payload.get("arguments", "") or "{}",
|
||||
},
|
||||
},
|
||||
)
|
||||
yield LLMStreamEvent(type="done")
|
||||
return
|
||||
|
||||
if finish_reason in {"stop", "length", "content_filter"}:
|
||||
yield LLMStreamEvent(type="done")
|
||||
return
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("LLM stream cancelled via asyncio")
|
||||
@@ -348,13 +430,14 @@ class MockLLMService(BaseLLMService):
|
||||
messages: List[LLMMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> AsyncIterator[str]:
|
||||
) -> AsyncIterator[LLMStreamEvent]:
|
||||
response = await self.generate(messages, temperature, max_tokens)
|
||||
|
||||
# Stream word by word
|
||||
words = response.split()
|
||||
for i, word in enumerate(words):
|
||||
if i > 0:
|
||||
yield " "
|
||||
yield word
|
||||
yield LLMStreamEvent(type="text_delta", text=" ")
|
||||
yield LLMStreamEvent(type="text_delta", text=word)
|
||||
await asyncio.sleep(0.05) # Simulate streaming delay
|
||||
yield LLMStreamEvent(type="done")
|
||||
|
||||
Reference in New Issue
Block a user