Move rag to llm service
This commit is contained in:
@@ -13,12 +13,11 @@ event-driven design.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import search_knowledge_context
|
||||
from app.config import settings
|
||||
from core.conversation import ConversationManager, ConversationState
|
||||
from core.events import get_event_bus
|
||||
@@ -27,7 +26,7 @@ from models.ws_v1 import ev
|
||||
from processors.eou import EouDetector
|
||||
from processors.vad import SileroVAD, VADProcessor
|
||||
from services.asr import BufferedASRService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
||||
from services.llm import MockLLMService, OpenAILLMService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
@@ -56,9 +55,6 @@ class DuplexPipeline:
|
||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||
_RAG_DEFAULT_RESULTS = 5
|
||||
_RAG_MAX_RESULTS = 8
|
||||
_RAG_MAX_CONTEXT_CHARS = 4000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -212,6 +208,9 @@ class DuplexPipeline:
|
||||
if kb_id:
|
||||
self._runtime_knowledge_base_id = kb_id
|
||||
|
||||
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline and connect services."""
|
||||
try:
|
||||
@@ -226,12 +225,16 @@ class DuplexPipeline:
|
||||
self.llm_service = OpenAILLMService(
|
||||
api_key=llm_api_key,
|
||||
base_url=llm_base_url,
|
||||
model=llm_model
|
||||
model=llm_model,
|
||||
knowledge_config=self._resolved_knowledge_config(),
|
||||
)
|
||||
else:
|
||||
logger.warning("No OpenAI API key - using mock LLM")
|
||||
self.llm_service = MockLLMService()
|
||||
|
||||
if hasattr(self.llm_service, "set_knowledge_config"):
|
||||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||
|
||||
await self.llm_service.connect()
|
||||
|
||||
# Connect TTS service
|
||||
@@ -570,102 +573,17 @@ class DuplexPipeline:
|
||||
await self.conversation.end_user_turn(user_text)
|
||||
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _resolve_runtime_kb_id(self) -> Optional[str]:
|
||||
if self._runtime_knowledge_base_id:
|
||||
return self._runtime_knowledge_base_id
|
||||
kb_id = str(self._runtime_knowledge.get("kbId") or self._runtime_knowledge.get("knowledgeBaseId") or "").strip()
|
||||
return kb_id or None
|
||||
|
||||
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
|
||||
if not results:
|
||||
return None
|
||||
|
||||
lines = [
|
||||
"You have retrieved the following knowledge base snippets.",
|
||||
"Use them only when relevant to the latest user request.",
|
||||
"If snippets are insufficient, say you are not sure instead of guessing.",
|
||||
"",
|
||||
]
|
||||
|
||||
used_chars = 0
|
||||
used_count = 0
|
||||
for item in results:
|
||||
content = str(item.get("content") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
|
||||
break
|
||||
|
||||
metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
|
||||
doc_id = metadata.get("document_id")
|
||||
chunk_index = metadata.get("chunk_index")
|
||||
distance = item.get("distance")
|
||||
|
||||
source_parts = []
|
||||
if doc_id:
|
||||
source_parts.append(f"doc={doc_id}")
|
||||
if chunk_index is not None:
|
||||
source_parts.append(f"chunk={chunk_index}")
|
||||
source = f" ({', '.join(source_parts)})" if source_parts else ""
|
||||
|
||||
distance_text = ""
|
||||
try:
|
||||
if distance is not None:
|
||||
distance_text = f", distance={float(distance):.4f}"
|
||||
except (TypeError, ValueError):
|
||||
distance_text = ""
|
||||
|
||||
remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
|
||||
snippet = content[:remaining].strip()
|
||||
if not snippet:
|
||||
continue
|
||||
|
||||
used_count += 1
|
||||
lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
|
||||
used_chars += len(snippet)
|
||||
|
||||
if used_count == 0:
|
||||
return None
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]:
|
||||
messages = self.conversation.get_messages()
|
||||
kb_id = self._resolve_runtime_kb_id()
|
||||
if not kb_id:
|
||||
return messages
|
||||
|
||||
knowledge_cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {}
|
||||
enabled = knowledge_cfg.get("enabled", True)
|
||||
if isinstance(enabled, str):
|
||||
enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
|
||||
if not enabled:
|
||||
return messages
|
||||
|
||||
n_results = self._coerce_int(knowledge_cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
|
||||
n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
|
||||
|
||||
results = await search_knowledge_context(
|
||||
kb_id=kb_id,
|
||||
query=user_text,
|
||||
n_results=n_results,
|
||||
)
|
||||
prompt = self._build_knowledge_prompt(results)
|
||||
if not prompt:
|
||||
return messages
|
||||
|
||||
logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
|
||||
rag_system = LLMMessage(role="system", content=prompt)
|
||||
if messages and messages[0].role == "system":
|
||||
return [messages[0], rag_system, *messages[1:]]
|
||||
return [rag_system, *messages]
|
||||
def _resolved_knowledge_config(self) -> Dict[str, Any]:
|
||||
cfg: Dict[str, Any] = {}
|
||||
if isinstance(self._runtime_knowledge, dict):
|
||||
cfg.update(self._runtime_knowledge)
|
||||
kb_id = self._runtime_knowledge_base_id or str(
|
||||
cfg.get("kbId") or cfg.get("knowledgeBaseId") or ""
|
||||
).strip()
|
||||
if kb_id:
|
||||
cfg["kbId"] = kb_id
|
||||
cfg.setdefault("enabled", True)
|
||||
return cfg
|
||||
|
||||
async def _handle_turn(self, user_text: str) -> None:
|
||||
"""
|
||||
@@ -682,7 +600,7 @@ class DuplexPipeline:
|
||||
self._first_audio_sent = False
|
||||
|
||||
# Get AI response (streaming)
|
||||
messages = await self._build_turn_messages(user_text)
|
||||
messages = self.conversation.get_messages()
|
||||
full_response = ""
|
||||
|
||||
await self.conversation.start_assistant_turn()
|
||||
|
||||
Reference in New Issue
Block a user