- Removed legacy agent profile settings from the .env.example and README, streamlining the configuration process. - Introduced a new local YAML configuration adapter for assistant settings, allowing for easier management of assistant profiles. - Updated backend integration documentation to clarify the behavior of assistant config sourcing based on backend URL settings. - Adjusted various service implementations to directly utilize API keys from the new configuration structure. - Enhanced test coverage for the new local YAML adapter and its integration with backend services.
450 lines
16 KiB
Python
450 lines
16 KiB
Python
"""LLM (Large Language Model) Service implementations.
|
|
|
|
Provides OpenAI-compatible LLM integration with streaming support
|
|
for real-time voice conversation.
|
|
"""
|
|
|
|
import os
|
|
import asyncio
|
|
import uuid
|
|
from typing import AsyncIterator, Optional, List, Dict, Any, Callable, Awaitable
|
|
from loguru import logger
|
|
|
|
from app.backend_adapters import build_backend_adapter_from_settings
|
|
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
|
|
|
# Optional dependency: the OpenAI SDK may be absent in minimal deployments.
# OPENAI_AVAILABLE gates OpenAILLMService.connect(), which raises if False.
try:
    from openai import AsyncOpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    logger.warning("openai package not available - LLM service will be disabled")
|
|
|
|
|
|
class OpenAILLMService(BaseLLMService):
    """
    OpenAI-compatible LLM service.

    Supports streaming responses for low-latency voice conversation.
    Works with OpenAI API, Azure OpenAI, and compatible APIs.

    Optionally injects knowledge-base (RAG) context as an extra system
    message before each request, and can surface tool calls streamed by
    the model as structured ``LLMStreamEvent(type="tool_call")`` events.
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        system_prompt: Optional[str] = None,
        knowledge_config: Optional[Dict[str, Any]] = None,
        knowledge_searcher: Optional[Callable[..., Awaitable[List[Dict[str, Any]]]]] = None,
    ):
        """
        Initialize OpenAI LLM service.

        Args:
            model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
            api_key: Provider API key (no environment fallback here; the
                caller is expected to supply it)
            base_url: Custom API base URL (for Azure or compatible APIs).
                Falls back to the LLM_API_URL, then OPENAI_API_URL env vars.
            system_prompt: Default system prompt for conversations
            knowledge_config: Runtime RAG settings dict (keys such as
                "enabled", "kbId"/"knowledgeBaseId"/"knowledge_base_id",
                "nResults"); replaceable later via set_knowledge_config()
            knowledge_searcher: Async callable(kb_id=, query=, n_results=)
                returning snippet dicts. Defaults to the backend adapter's
                search_knowledge_context when not provided.
        """
        super().__init__(model=model)

        self.api_key = api_key
        # Explicit argument wins; otherwise fall back to environment overrides.
        self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL")
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation."
        )

        # Created in connect(); None while disconnected.
        self.client: Optional[AsyncOpenAI] = None
        # Set by cancel() to abort an in-flight generate_stream() loop.
        self._cancel_event = asyncio.Event()
        self._knowledge_config: Dict[str, Any] = knowledge_config or {}
        if knowledge_searcher is None:
            adapter = build_backend_adapter_from_settings()
            self._knowledge_searcher = adapter.search_knowledge_context
        else:
            self._knowledge_searcher = knowledge_searcher
        # OpenAI-format tool schemas; populated via set_tool_schemas().
        self._tool_schemas: List[Dict[str, Any]] = []

    # RAG retrieval limits (class-level constants).
    _RAG_DEFAULT_RESULTS = 5       # default number of KB chunks to request
    _RAG_MAX_RESULTS = 8           # hard cap on requested chunks
    _RAG_MAX_CONTEXT_CHARS = 4000  # total character budget for injected snippets

    async def connect(self) -> None:
        """Initialize OpenAI client.

        Raises:
            RuntimeError: if the openai package is not installed.
            ValueError: if no API key was provided at construction time.
        """
        if not OPENAI_AVAILABLE:
            raise RuntimeError("openai package not installed")

        if not self.api_key:
            raise ValueError("OpenAI API key not provided")

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"OpenAI LLM service connected: model={self.model}")

    async def disconnect(self) -> None:
        """Close OpenAI client and mark the service disconnected."""
        if self.client:
            await self.client.close()
            self.client = None
        self.state = ServiceState.DISCONNECTED
        logger.info("OpenAI LLM service disconnected")

    def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Prepare messages list with system prompt.

        Prepends the default system prompt only when the caller has not
        supplied any system message of their own.
        """
        result = []

        # Add system prompt if not already present
        has_system = any(m.role == "system" for m in messages)
        if not has_system and self.system_prompt:
            result.append({"role": "system", "content": self.system_prompt})

        # Add all messages
        for msg in messages:
            result.append(msg.to_dict())

        return result

    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        """Update runtime knowledge retrieval config (None clears it)."""
        self._knowledge_config = config or {}

    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        """Update runtime tool schemas.

        Accepts either full OpenAI tool entries ({"type": "function",
        "function": {...}}) or flat {"name", "description", "parameters"}
        dicts, which are wrapped into the OpenAI shape. Entries without a
        name, and non-dict items, are dropped; non-list input clears all
        schemas.
        """
        self._tool_schemas = []
        if not isinstance(schemas, list):
            return
        for item in schemas:
            if not isinstance(item, dict):
                continue
            fn = item.get("function")
            if isinstance(fn, dict) and fn.get("name"):
                # Already in OpenAI tool format; keep as-is.
                self._tool_schemas.append(item)
            elif item.get("name"):
                # Flat schema: wrap into the OpenAI "function" envelope.
                self._tool_schemas.append(
                    {
                        "type": "function",
                        "function": {
                            "name": str(item.get("name")),
                            "description": str(item.get("description") or ""),
                            "parameters": item.get("parameters") or {"type": "object", "properties": {}},
                        },
                    }
                )

    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """Best-effort int() conversion; returns *default* on failure."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    def _resolve_kb_id(self) -> Optional[str]:
        """Return the configured knowledge-base id, or None when unset.

        Checks the camelCase and snake_case key variants used by different
        config sources.
        """
        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
        kb_id = str(
            cfg.get("kbId")
            or cfg.get("knowledgeBaseId")
            or cfg.get("knowledge_base_id")
            or ""
        ).strip()
        return kb_id or None

    def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
        """Format retrieved snippets into a system-prompt string.

        Each usable snippet is numbered and tagged with its source document,
        chunk index, and retrieval distance when available. Total snippet
        text is truncated to _RAG_MAX_CONTEXT_CHARS. Returns None when no
        snippet yields any usable text.
        """
        if not results:
            return None

        lines = [
            "You have retrieved the following knowledge base snippets.",
            "Use them only when relevant to the latest user request.",
            "If snippets are insufficient, say you are not sure instead of guessing.",
            "",
        ]

        used_chars = 0
        used_count = 0
        for item in results:
            content = str(item.get("content") or "").strip()
            if not content:
                continue
            if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
                # Character budget exhausted; ignore remaining results.
                break

            metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
            doc_id = metadata.get("document_id")
            chunk_index = metadata.get("chunk_index")
            distance = item.get("distance")

            # Build an optional "(doc=..., chunk=...)" provenance tag.
            source_parts = []
            if doc_id:
                source_parts.append(f"doc={doc_id}")
            if chunk_index is not None:
                source_parts.append(f"chunk={chunk_index}")
            source = f" ({', '.join(source_parts)})" if source_parts else ""

            # Distance is formatted defensively: non-numeric values are ignored.
            distance_text = ""
            try:
                if distance is not None:
                    distance_text = f", distance={float(distance):.4f}"
            except (TypeError, ValueError):
                distance_text = ""

            # Trim the snippet to whatever budget remains.
            remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
            snippet = content[:remaining].strip()
            if not snippet:
                continue

            used_count += 1
            lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
            used_chars += len(snippet)

        if used_count == 0:
            return None

        return "\n".join(lines)

    async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
        """Return *messages* with a RAG system message injected when applicable.

        Retrieval is skipped (returning the input unchanged) when RAG is
        disabled in config, no kb id is configured, there is no non-empty
        user message, or the search yields no usable snippets. When an
        existing system message leads the list, the RAG message is inserted
        immediately after it; otherwise it is prepended.
        """
        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
        # "enabled" defaults to True; string values like "false"/"0"/"off"/"no"
        # (any case) are treated as disabled.
        enabled = cfg.get("enabled", True)
        if isinstance(enabled, str):
            enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
        if not enabled:
            return messages

        kb_id = self._resolve_kb_id()
        if not kb_id:
            return messages

        # Query with the most recent user utterance only.
        latest_user = ""
        for msg in reversed(messages):
            if msg.role == "user":
                latest_user = (msg.content or "").strip()
                break
        if not latest_user:
            return messages

        n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
        n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))

        results = await self._knowledge_searcher(
            kb_id=kb_id,
            query=latest_user,
            n_results=n_results,
        )
        prompt = self._build_knowledge_prompt(results)
        if not prompt:
            return messages

        logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
        rag_system = LLMMessage(role="system", content=prompt)
        if messages and messages[0].role == "system":
            return [messages[0], rag_system, *messages[1:]]
        return [rag_system, *messages]

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete (non-streaming) response.

        RAG context is injected first via _with_knowledge_context().

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Complete response text ("" when the model returns no content)

        Raises:
            RuntimeError: if connect() has not been called.
            Exception: any provider error is logged and re-raised.
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")

        rag_messages = await self._with_knowledge_context(messages)
        prepared = self._prepare_messages(rag_messages)

        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens
            )

            content = response.choices[0].message.content or ""
            logger.debug(f"LLM response: {content[:100]}...")
            return content

        except Exception as e:
            logger.error(f"LLM generation error: {e}")
            raise

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """
        Generate response in streaming mode.

        Text chunks are yielded as "text_delta" events. Tool-call deltas
        (streamed incrementally by OpenAI, keyed by index) are accumulated
        and flushed as complete "tool_call" events when the model finishes
        with reason "tool_calls". A "done" event terminates the stream on
        finish reasons "tool_calls", "stop", "length" and "content_filter".

        NOTE(review): when cancel() interrupts the loop, the generator ends
        without yielding a "done" event — confirm consumers handle this.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Yields:
            Structured stream events

        Raises:
            RuntimeError: if connect() has not been called.
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")

        rag_messages = await self._with_knowledge_context(messages)
        prepared = self._prepare_messages(rag_messages)
        self._cancel_event.clear()
        # index -> partially accumulated {"id", "name", "arguments"}.
        tool_accumulator: Dict[int, Dict[str, str]] = {}
        openai_tools = self._tool_schemas or None

        try:
            create_args: Dict[str, Any] = dict(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
            )
            if openai_tools:
                create_args["tools"] = openai_tools
                create_args["tool_choice"] = "auto"
            stream = await self.client.chat.completions.create(**create_args)

            async for chunk in stream:
                # Check for cancellation
                if self._cancel_event.is_set():
                    logger.info("LLM stream cancelled")
                    break

                if not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = getattr(choice, "delta", None)
                if delta and getattr(delta, "content", None):
                    content = delta.content
                    yield LLMStreamEvent(type="text_delta", text=content)

                # OpenAI streams function calls via incremental tool_calls deltas.
                tool_calls = getattr(delta, "tool_calls", None) if delta else None
                if tool_calls:
                    for tc in tool_calls:
                        index = getattr(tc, "index", 0) or 0
                        item = tool_accumulator.setdefault(
                            int(index),
                            {"id": "", "name": "", "arguments": ""},
                        )
                        tc_id = getattr(tc, "id", None)
                        if tc_id:
                            item["id"] = str(tc_id)
                        fn = getattr(tc, "function", None)
                        if fn:
                            fn_name = getattr(fn, "name", None)
                            if fn_name:
                                item["name"] = str(fn_name)
                            fn_args = getattr(fn, "arguments", None)
                            if fn_args:
                                # Argument JSON arrives in fragments; concatenate.
                                item["arguments"] += str(fn_args)

                finish_reason = getattr(choice, "finish_reason", None)
                if finish_reason == "tool_calls" and tool_accumulator:
                    # Flush accumulated calls in index order.
                    for _, payload in sorted(tool_accumulator.items(), key=lambda row: row[0]):
                        call_name = payload.get("name", "").strip()
                        if not call_name:
                            continue
                        # Synthesize an id when the provider omitted one.
                        call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
                        yield LLMStreamEvent(
                            type="tool_call",
                            tool_call={
                                "id": call_id,
                                "type": "function",
                                "function": {
                                    "name": call_name,
                                    "arguments": payload.get("arguments", "") or "{}",
                                },
                            },
                        )
                    yield LLMStreamEvent(type="done")
                    return

                if finish_reason in {"stop", "length", "content_filter"}:
                    yield LLMStreamEvent(type="done")
                    return

        except asyncio.CancelledError:
            logger.info("LLM stream cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise

    def cancel(self) -> None:
        """Cancel ongoing generation (checked once per streamed chunk)."""
        self._cancel_event.set()
|
|
|
|
|
|
class MockLLMService(BaseLLMService):
    """
    Mock LLM service for testing without API calls.

    Cycles through a fixed set of canned replies, simulating both the
    request latency and word-by-word streaming of a real provider.
    """

    def __init__(self, response_delay: float = 0.5):
        super().__init__(model="mock")
        self.response_delay = response_delay
        # Canned replies, served round-robin.
        self.responses = [
            "Hello! How can I help you today?",
            "That's an interesting question. Let me think about it.",
            "I understand. Is there anything else you'd like to know?",
            "Great! I'm here if you need anything else.",
        ]
        self._response_index = 0

    async def connect(self) -> None:
        """Mark the mock service as connected."""
        self.state = ServiceState.CONNECTED
        logger.info("Mock LLM service connected")

    async def disconnect(self) -> None:
        """Mark the mock service as disconnected."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock LLM service disconnected")

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """Return the next canned reply after a simulated delay."""
        await asyncio.sleep(self.response_delay)
        current = self._response_index
        self._response_index = current + 1
        return self.responses[current % len(self.responses)]

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """Stream the next canned reply one word at a time."""
        reply = await self.generate(messages, temperature, max_tokens)

        # Emit words separated by single-space deltas, pausing between words.
        first = True
        for word in reply.split():
            if not first:
                yield LLMStreamEvent(type="text_delta", text=" ")
            first = False
            yield LLMStreamEvent(type="text_delta", text=word)
            await asyncio.sleep(0.05)  # Simulate streaming delay
        yield LLMStreamEvent(type="done")
|