Move RAG to LLM service
This commit is contained in:
@@ -9,6 +9,7 @@ import asyncio
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import search_knowledge_context
|
||||
from services.base import BaseLLMService, LLMMessage, ServiceState
|
||||
|
||||
# Try to import openai
|
||||
@@ -33,7 +34,8 @@ class OpenAILLMService(BaseLLMService):
|
||||
model: str = "gpt-4o-mini",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
system_prompt: Optional[str] = None,
|
||||
knowledge_config: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize OpenAI LLM service.
|
||||
@@ -56,6 +58,11 @@ class OpenAILLMService(BaseLLMService):
|
||||
|
||||
self.client: Optional[AsyncOpenAI] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
|
||||
|
||||
_RAG_DEFAULT_RESULTS = 5
|
||||
_RAG_MAX_RESULTS = 8
|
||||
_RAG_MAX_CONTEXT_CHARS = 4000
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize OpenAI client."""
|
||||
@@ -94,6 +101,118 @@ class OpenAILLMService(BaseLLMService):
|
||||
result.append(msg.to_dict())
|
||||
|
||||
return result
|
||||
|
||||
def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
    """Replace the runtime knowledge-retrieval configuration.

    Passing ``None`` (or any falsy value) resets the config to an empty dict.
    """
    self._knowledge_config = {} if not config else config
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _resolve_kb_id(self) -> Optional[str]:
|
||||
cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
|
||||
kb_id = str(
|
||||
cfg.get("kbId")
|
||||
or cfg.get("knowledgeBaseId")
|
||||
or cfg.get("knowledge_base_id")
|
||||
or ""
|
||||
).strip()
|
||||
return kb_id or None
|
||||
|
||||
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
|
||||
if not results:
|
||||
return None
|
||||
|
||||
lines = [
|
||||
"You have retrieved the following knowledge base snippets.",
|
||||
"Use them only when relevant to the latest user request.",
|
||||
"If snippets are insufficient, say you are not sure instead of guessing.",
|
||||
"",
|
||||
]
|
||||
|
||||
used_chars = 0
|
||||
used_count = 0
|
||||
for item in results:
|
||||
content = str(item.get("content") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
|
||||
break
|
||||
|
||||
metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
|
||||
doc_id = metadata.get("document_id")
|
||||
chunk_index = metadata.get("chunk_index")
|
||||
distance = item.get("distance")
|
||||
|
||||
source_parts = []
|
||||
if doc_id:
|
||||
source_parts.append(f"doc={doc_id}")
|
||||
if chunk_index is not None:
|
||||
source_parts.append(f"chunk={chunk_index}")
|
||||
source = f" ({', '.join(source_parts)})" if source_parts else ""
|
||||
|
||||
distance_text = ""
|
||||
try:
|
||||
if distance is not None:
|
||||
distance_text = f", distance={float(distance):.4f}"
|
||||
except (TypeError, ValueError):
|
||||
distance_text = ""
|
||||
|
||||
remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
|
||||
snippet = content[:remaining].strip()
|
||||
if not snippet:
|
||||
continue
|
||||
|
||||
used_count += 1
|
||||
lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
|
||||
used_chars += len(snippet)
|
||||
|
||||
if used_count == 0:
|
||||
return None
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
    """Optionally prepend retrieved knowledge-base context as a system message.

    Returns *messages* unchanged when retrieval is disabled, no knowledge base
    is configured, there is no user query, retrieval fails, or nothing useful
    is found. Otherwise inserts a RAG system message after any existing
    leading system prompt (or at the front).
    """
    cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}

    # "enabled" may arrive as a bool or a string flag from client config.
    enabled = cfg.get("enabled", True)
    if isinstance(enabled, str):
        enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
    if not enabled:
        return messages

    kb_id = self._resolve_kb_id()
    if not kb_id:
        return messages

    # Retrieve against the most recent user turn only.
    latest_user = ""
    for msg in reversed(messages):
        if msg.role == "user":
            latest_user = (msg.content or "").strip()
            break
    if not latest_user:
        return messages

    n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
    n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))

    # RAG injection is best-effort: a backend/retrieval failure must not
    # abort the whole LLM call, so fall back to the original messages.
    try:
        results = await search_knowledge_context(
            kb_id=kb_id,
            query=latest_user,
            n_results=n_results,
        )
    except Exception as exc:
        logger.warning(f"RAG retrieval failed (kb_id={kb_id}): {exc}")
        return messages

    prompt = self._build_knowledge_prompt(results)
    if not prompt:
        return messages

    logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
    rag_system = LLMMessage(role="system", content=prompt)
    # Keep the caller's own system prompt first so the RAG context augments it.
    if messages and messages[0].role == "system":
        return [messages[0], rag_system, *messages[1:]]
    return [rag_system, *messages]
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
@@ -115,7 +234,8 @@ class OpenAILLMService(BaseLLMService):
|
||||
if not self.client:
|
||||
raise RuntimeError("LLM service not connected")
|
||||
|
||||
prepared = self._prepare_messages(messages)
|
||||
rag_messages = await self._with_knowledge_context(messages)
|
||||
prepared = self._prepare_messages(rag_messages)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
@@ -153,7 +273,8 @@ class OpenAILLMService(BaseLLMService):
|
||||
if not self.client:
|
||||
raise RuntimeError("LLM service not connected")
|
||||
|
||||
prepared = self._prepare_messages(messages)
|
||||
rag_messages = await self._with_knowledge_context(messages)
|
||||
prepared = self._prepare_messages(rag_messages)
|
||||
self._cancel_event.clear()
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user