Files
AI-VideoAssistant/engine/providers/llm/openai.py
Xin Wang 7e0b777923 Refactor project structure and enhance backend integration
- Expanded package inclusion in `pyproject.toml` to support new modules.
- Introduced new `adapters` and `protocol` packages for better organization.
- Added backend adapter implementations for control plane integration.
- Updated main application imports to reflect new package structure.
- Removed deprecated core components and adjusted documentation accordingly.
- Enhanced architecture documentation to clarify the new runtime and integration layers.
2026-03-06 09:51:56 +08:00

450 lines
16 KiB
Python

"""LLM (Large Language Model) Service implementations.
Provides OpenAI-compatible LLM integration with streaming support
for real-time voice conversation.
"""
import os
import asyncio
import uuid
from typing import AsyncIterator, Optional, List, Dict, Any, Callable, Awaitable
from loguru import logger
from adapters.control_plane.backend import build_backend_adapter_from_settings
from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
# The openai SDK is optional: when it is missing the module still imports,
# and OpenAILLMService.connect() refuses to start instead.
try:
    from openai import AsyncOpenAI
except ImportError:
    OPENAI_AVAILABLE = False
    logger.warning("openai package not available - LLM service will be disabled")
else:
    OPENAI_AVAILABLE = True
class OpenAILLMService(BaseLLMService):
    """
    OpenAI-compatible LLM service.

    Supports streaming responses for low-latency voice conversation.
    Works with OpenAI API, Azure OpenAI, and compatible APIs.

    Optionally injects knowledge-base snippets (retrieval-augmented context)
    into each request and exposes runtime-configurable function-calling tools.
    """

    # Limits for knowledge-base context injection.
    _RAG_DEFAULT_RESULTS = 5  # used when the config has no usable "nResults"
    _RAG_MAX_RESULTS = 8  # upper clamp applied to "nResults"
    _RAG_MAX_CONTEXT_CHARS = 4000  # total snippet characters injected per request

    # finish_reason values after which accumulated tool calls are emitted.
    # OpenAI normally reports "tool_calls", but "stop" is also seen with tool
    # calls (e.g. from OpenAI-compatible servers, or when tool_choice forces a
    # function), so both are accepted when fragments have been accumulated.
    _TOOL_FINISH_REASONS = frozenset({"tool_calls", "stop"})
    # finish_reason values that end a plain-text turn.
    _TEXT_FINISH_REASONS = frozenset({"stop", "length", "content_filter"})

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        system_prompt: Optional[str] = None,
        knowledge_config: Optional[Dict[str, Any]] = None,
        knowledge_searcher: Optional[Callable[..., Awaitable[List[Dict[str, Any]]]]] = None,
    ):
        """
        Initialize OpenAI LLM service.

        Args:
            model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
            api_key: Provider API key (required by :meth:`connect`)
            base_url: Custom API base URL (for Azure or compatible APIs);
                falls back to the LLM_API_URL, then OPENAI_API_URL env vars
            system_prompt: Default system prompt for conversations
            knowledge_config: Initial knowledge-retrieval config; keys read are
                "enabled", "kbId"/"knowledgeBaseId"/"knowledge_base_id" and
                "nResults"
            knowledge_searcher: Async callable invoked as
                ``searcher(kb_id=..., query=..., n_results=...)`` returning a
                list of result dicts; defaults to the backend adapter's
                ``search_knowledge_context``
        """
        super().__init__(model=model)
        self.api_key = api_key
        self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL")
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation."
        )
        self.client: Optional[AsyncOpenAI] = None
        self._cancel_event = asyncio.Event()
        self._knowledge_config: Dict[str, Any] = knowledge_config or {}
        if knowledge_searcher is None:
            adapter = build_backend_adapter_from_settings()
            self._knowledge_searcher = adapter.search_knowledge_context
        else:
            self._knowledge_searcher = knowledge_searcher
        self._tool_schemas: List[Dict[str, Any]] = []

    async def connect(self) -> None:
        """Initialize OpenAI client.

        Raises:
            RuntimeError: If the openai package is not installed.
            ValueError: If no API key was provided.
        """
        if not OPENAI_AVAILABLE:
            raise RuntimeError("openai package not installed")
        if not self.api_key:
            raise ValueError("OpenAI API key not provided")
        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"OpenAI LLM service connected: model={self.model}")

    async def disconnect(self) -> None:
        """Close OpenAI client."""
        if self.client:
            await self.client.close()
            self.client = None
        self.state = ServiceState.DISCONNECTED
        logger.info("OpenAI LLM service disconnected")

    def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Convert messages to API dicts, prepending the default system prompt
        unless the conversation already contains a system message."""
        result = []
        has_system = any(m.role == "system" for m in messages)
        if not has_system and self.system_prompt:
            result.append({"role": "system", "content": self.system_prompt})
        for msg in messages:
            result.append(msg.to_dict())
        return result

    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        """Update runtime knowledge retrieval config (``None`` resets it)."""
        self._knowledge_config = config or {}

    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        """Update runtime tool schemas.

        Accepts OpenAI-style entries (``{"type": "function", "function": {...}}``)
        and flat entries (``{"name", "description", "parameters"}``), which are
        wrapped into the OpenAI shape. Entries without a function name are
        dropped; non-list input clears the tools.
        """
        normalized: List[Dict[str, Any]] = []
        for item in schemas if isinstance(schemas, list) else []:
            if not isinstance(item, dict):
                continue
            fn = item.get("function")
            if isinstance(fn, dict) and fn.get("name"):
                # Every tool entry sent to the API needs "type": "function";
                # add it on a copy (never mutate the caller's dict) if missing.
                normalized.append(item if item.get("type") else {**item, "type": "function"})
            elif item.get("name"):
                normalized.append(
                    {
                        "type": "function",
                        "function": {
                            "name": str(item.get("name")),
                            "description": str(item.get("description") or ""),
                            "parameters": item.get("parameters") or {"type": "object", "properties": {}},
                        },
                    }
                )
        self._tool_schemas = normalized

    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """Return ``int(value)``, or *default* when it cannot be converted
        (``None``, non-numeric strings, NaN or infinite floats)."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    def _resolve_kb_id(self) -> Optional[str]:
        """Return the configured knowledge-base id, or ``None`` if unset/blank."""
        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
        kb_id = str(
            cfg.get("kbId")
            or cfg.get("knowledgeBaseId")
            or cfg.get("knowledge_base_id")
            or ""
        ).strip()
        return kb_id or None

    def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
        """Render retrieval results as a system-prompt section.

        Snippets are numbered and annotated with document id, chunk index and
        distance when available. Their combined length is capped at
        ``_RAG_MAX_CONTEXT_CHARS`` (the fixed instruction header is not counted).
        Malformed (non-dict) or empty results are skipped.

        Returns:
            The prompt text, or ``None`` if no usable snippet remains.
        """
        if not results:
            return None
        lines = [
            "You have retrieved the following knowledge base snippets.",
            "Use them only when relevant to the latest user request.",
            "If snippets are insufficient, say you are not sure instead of guessing.",
            "",
        ]
        used_chars = 0
        used_count = 0
        for item in results:
            # Results come from an external backend; ignore malformed entries
            # instead of failing the whole request.
            if not isinstance(item, dict):
                continue
            content = str(item.get("content") or "").strip()
            if not content:
                continue
            if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
                break
            metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
            doc_id = metadata.get("document_id")
            chunk_index = metadata.get("chunk_index")
            distance = item.get("distance")
            source_parts = []
            if doc_id:
                source_parts.append(f"doc={doc_id}")
            if chunk_index is not None:
                source_parts.append(f"chunk={chunk_index}")
            source = f" ({', '.join(source_parts)})" if source_parts else ""
            distance_text = ""
            try:
                if distance is not None:
                    distance_text = f", distance={float(distance):.4f}"
            except (TypeError, ValueError):
                distance_text = ""
            # Truncate the last snippet so the total stays within budget.
            remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
            snippet = content[:remaining].strip()
            if not snippet:
                continue
            used_count += 1
            lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
            used_chars += len(snippet)
        if used_count == 0:
            return None
        return "\n".join(lines)

    async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
        """Return *messages* with a knowledge-base system message injected.

        Retrieval is skipped (messages returned unchanged) when it is disabled
        in the config, no knowledge-base id is configured, there is no
        non-empty user message, or nothing usable was retrieved. The latest
        user message is the search query. The injected message goes right
        after a leading system message, otherwise first.
        """
        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
        enabled = cfg.get("enabled", True)
        if isinstance(enabled, str):
            enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
        if not enabled:
            return messages
        kb_id = self._resolve_kb_id()
        if not kb_id:
            return messages
        latest_user = ""
        for msg in reversed(messages):
            if msg.role == "user":
                latest_user = (msg.content or "").strip()
                break
        if not latest_user:
            return messages
        n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
        n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
        results = await self._knowledge_searcher(
            kb_id=kb_id,
            query=latest_user,
            n_results=n_results,
        )
        prompt = self._build_knowledge_prompt(results)
        if not prompt:
            return messages
        logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
        rag_system = LLMMessage(role="system", content=prompt)
        if messages and messages[0].role == "system":
            return [messages[0], rag_system, *messages[1:]]
        return [rag_system, *messages]

    def _build_request_args(
        self,
        prepared: List[Dict[str, Any]],
        temperature: float,
        max_tokens: Optional[int],
    ) -> Dict[str, Any]:
        """Build keyword arguments for ``chat.completions.create``.

        ``max_tokens`` is only included when set. An explicit ``null`` means
        "provider default" to OpenAI, but stricter OpenAI-compatible servers
        may reject it, and omitting the field has the same meaning.
        """
        args: Dict[str, Any] = {
            "model": self.model,
            "messages": prepared,
            "temperature": temperature,
        }
        if max_tokens is not None:
            args["max_tokens"] = max_tokens
        return args

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate (provider default if None)

        Returns:
            Complete response text ("" when the model returned no content)

        Raises:
            RuntimeError: If :meth:`connect` has not been called.
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")
        rag_messages = await self._with_knowledge_context(messages)
        prepared = self._prepare_messages(rag_messages)
        try:
            response = await self.client.chat.completions.create(
                **self._build_request_args(prepared, temperature, max_tokens)
            )
            content = response.choices[0].message.content or ""
            logger.debug(f"LLM response: {content[:100]}...")
            return content
        except Exception as e:
            logger.error(f"LLM generation error: {e}")
            raise

    @staticmethod
    def _accumulate_tool_calls(accumulator: Dict[int, Dict[str, str]], tool_calls: Any) -> None:
        """Merge incremental ``tool_calls`` deltas into *accumulator*.

        OpenAI streams each function call as fragments sharing an ``index``:
        id and name arrive once, and argument JSON arrives in pieces that are
        concatenated here.
        """
        for tc in tool_calls or []:
            index = int(getattr(tc, "index", 0) or 0)
            item = accumulator.setdefault(index, {"id": "", "name": "", "arguments": ""})
            tc_id = getattr(tc, "id", None)
            if tc_id:
                item["id"] = str(tc_id)
            fn = getattr(tc, "function", None)
            if fn:
                fn_name = getattr(fn, "name", None)
                if fn_name:
                    item["name"] = str(fn_name)
                fn_args = getattr(fn, "arguments", None)
                if fn_args:
                    item["arguments"] += str(fn_args)

    @staticmethod
    def _tool_call_events(accumulator: Dict[int, Dict[str, str]]) -> List[LLMStreamEvent]:
        """Turn accumulated fragments into ``tool_call`` events ordered by index.

        Entries without a function name are skipped. A missing id is replaced
        by a generated ``call_<hex>`` id, and empty arguments become ``"{}"``.
        """
        events: List[LLMStreamEvent] = []
        for index in sorted(accumulator):
            payload = accumulator[index]
            call_name = payload.get("name", "").strip()
            if not call_name:
                continue
            call_id = payload.get("id", "").strip() or f"call_{uuid.uuid4().hex[:10]}"
            events.append(
                LLMStreamEvent(
                    type="tool_call",
                    tool_call={
                        "id": call_id,
                        "type": "function",
                        "function": {
                            "name": call_name,
                            "arguments": payload.get("arguments", "") or "{}",
                        },
                    },
                )
            )
        return events

    @staticmethod
    async def _close_stream(stream: Any) -> None:
        """Best-effort release of the HTTP response behind a streaming call.

        Failures are logged at debug level only, so cleanup never masks the
        exception (if any) that ended the stream.
        """
        close = getattr(stream, "close", None)
        if close is None:
            return
        try:
            await close()
        except Exception as close_error:
            logger.debug(f"Failed to close LLM stream: {close_error}")

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """
        Generate response in streaming mode.

        Emits ``text_delta`` events as content arrives, then one ``tool_call``
        event per completed function call, and a final ``done`` event. A
        generation stopped via :meth:`cancel` ends without a ``done`` event.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate (provider default if None)

        Yields:
            Structured stream events

        Raises:
            RuntimeError: If :meth:`connect` has not been called.
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")
        # Reset the cancel flag *before* the awaited knowledge lookup: clearing
        # it afterwards would discard a cancel() issued during retrieval.
        self._cancel_event.clear()
        rag_messages = await self._with_knowledge_context(messages)
        if self._cancel_event.is_set():
            logger.info("LLM stream cancelled")
            return
        prepared = self._prepare_messages(rag_messages)

        tool_accumulator: Dict[int, Dict[str, str]] = {}
        stream = None
        try:
            create_args = self._build_request_args(prepared, temperature, max_tokens)
            create_args["stream"] = True
            if self._tool_schemas:
                create_args["tools"] = self._tool_schemas
                create_args["tool_choice"] = "auto"
            stream = await self.client.chat.completions.create(**create_args)
            async for chunk in stream:
                if self._cancel_event.is_set():
                    logger.info("LLM stream cancelled")
                    return
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                delta = getattr(choice, "delta", None)
                if delta is not None:
                    text = getattr(delta, "content", None)
                    if text:
                        yield LLMStreamEvent(type="text_delta", text=text)
                    # OpenAI streams function calls via incremental tool_calls deltas.
                    self._accumulate_tool_calls(tool_accumulator, getattr(delta, "tool_calls", None))
                finish_reason = getattr(choice, "finish_reason", None)
                if finish_reason in self._TOOL_FINISH_REASONS and tool_accumulator:
                    for event in self._tool_call_events(tool_accumulator):
                        yield event
                    yield LLMStreamEvent(type="done")
                    return
                if finish_reason in self._TEXT_FINISH_REASONS:
                    yield LLMStreamEvent(type="done")
                    return
            # The stream ended without a recognised finish_reason: flush any
            # pending tool calls and still signal completion.
            for event in self._tool_call_events(tool_accumulator):
                yield event
            yield LLMStreamEvent(type="done")
        except asyncio.CancelledError:
            logger.info("LLM stream cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise
        finally:
            # Returning early, cancellation, or the consumer closing this
            # generator leaves the response unread; release its connection.
            if stream is not None:
                await self._close_stream(stream)

    def cancel(self) -> None:
        """Cancel ongoing generation (checked between streamed chunks)."""
        self._cancel_event.set()
class MockLLMService(BaseLLMService):
    """
    Offline stand-in for an LLM backend, for tests that must not hit an API.

    Replies are taken round-robin from a fixed list of canned sentences; the
    conversation content, temperature and token limit are ignored.
    """

    def __init__(self, response_delay: float = 0.5):
        super().__init__(model="mock")
        # Seconds slept before every reply, to imitate network latency.
        self.response_delay = response_delay
        # Canned replies, cycled through in order.
        self.responses = [
            "Hello! How can I help you today?",
            "That's an interesting question. Let me think about it.",
            "I understand. Is there anything else you'd like to know?",
            "Great! I'm here if you need anything else.",
        ]
        # Monotonic counter; the reply used is counter % len(responses).
        self._response_index = 0

    async def connect(self) -> None:
        """Mark the mock as connected (no I/O)."""
        self.state = ServiceState.CONNECTED
        logger.info("Mock LLM service connected")

    async def disconnect(self) -> None:
        """Mark the mock as disconnected (no I/O)."""
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock LLM service disconnected")

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """Wait ``response_delay`` seconds, then return the next canned reply."""
        await asyncio.sleep(self.response_delay)
        slot = self._response_index % len(self.responses)
        self._response_index += 1
        return self.responses[slot]

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[LLMStreamEvent]:
        """Stream the next canned reply word by word, then a ``done`` event.

        Words are separated by standalone single-space ``text_delta`` events,
        with a 0.05 s pause after each word.
        """
        reply = await self.generate(messages, temperature, max_tokens)
        separator = ""
        for word in reply.split():
            if separator:
                yield LLMStreamEvent(type="text_delta", text=separator)
            yield LLMStreamEvent(type="text_delta", text=word)
            separator = " "
            await asyncio.sleep(0.05)  # Simulate streaming delay
        yield LLMStreamEvent(type="done")