From 539cf2fda231ef353d12be80ffddb8d903d6f4d4 Mon Sep 17 00:00:00 2001
From: Xin Wang
Date: Tue, 10 Feb 2026 13:47:08 +0800
Subject: [PATCH] Move RAG to LLM service

---
 engine/core/duplex_pipeline.py | 126 ++++++--------------------
 engine/services/llm.py         | 127 ++++++++++++++++++++++++++++++++-
 2 files changed, 146 insertions(+), 107 deletions(-)

diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py
index 6197667..3d2a1dd 100644
--- a/engine/core/duplex_pipeline.py
+++ b/engine/core/duplex_pipeline.py
@@ -13,12 +13,11 @@ event-driven design.
 import asyncio
 import time
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 from loguru import logger
 
-from app.backend_client import search_knowledge_context
 from app.config import settings
 from core.conversation import ConversationManager, ConversationState
 from core.events import get_event_bus
@@ -27,7 +26,7 @@ from models.ws_v1 import ev
 from processors.eou import EouDetector
 from processors.vad import SileroVAD, VADProcessor
 from services.asr import BufferedASRService
-from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage
+from services.base import BaseASRService, BaseLLMService, BaseTTSService
 from services.llm import MockLLMService, OpenAILLMService
 from services.siliconflow_asr import SiliconFlowASRService
 from services.siliconflow_tts import SiliconFlowTTSService
@@ -56,9 +55,6 @@ class DuplexPipeline:
     _SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
     _SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
     _MIN_SPLIT_SPOKEN_CHARS = 6
-    _RAG_DEFAULT_RESULTS = 5
-    _RAG_MAX_RESULTS = 8
-    _RAG_MAX_CONTEXT_CHARS = 4000
 
     def __init__(
         self,
@@ -212,6 +208,9 @@ class DuplexPipeline:
         if kb_id:
             self._runtime_knowledge_base_id = kb_id
 
+        if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
+            self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
+
     async def start(self) -> None:
         """Start the pipeline and connect services."""
         try:
@@ -226,12 +225,16 @@ class DuplexPipeline:
                 self.llm_service = OpenAILLMService(
                     api_key=llm_api_key,
                     base_url=llm_base_url,
-                    model=llm_model
+                    model=llm_model,
+                    knowledge_config=self._resolved_knowledge_config(),
                 )
             else:
                 logger.warning("No OpenAI API key - using mock LLM")
                 self.llm_service = MockLLMService()
 
+            if hasattr(self.llm_service, "set_knowledge_config"):
+                self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
+
             await self.llm_service.connect()
 
             # Connect TTS service
@@ -570,102 +573,17 @@ class DuplexPipeline:
         await self.conversation.end_user_turn(user_text)
         self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
 
-    @staticmethod
-    def _coerce_int(value: Any, default: int) -> int:
-        try:
-            return int(value)
-        except (TypeError, ValueError):
-            return default
-
-    def _resolve_runtime_kb_id(self) -> Optional[str]:
-        if self._runtime_knowledge_base_id:
-            return self._runtime_knowledge_base_id
-        kb_id = str(self._runtime_knowledge.get("kbId") or self._runtime_knowledge.get("knowledgeBaseId") or "").strip()
-        return kb_id or None
-
-    def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
-        if not results:
-            return None
-
-        lines = [
-            "You have retrieved the following knowledge base snippets.",
-            "Use them only when relevant to the latest user request.",
-            "If snippets are insufficient, say you are not sure instead of guessing.",
-            "",
-        ]
-
-        used_chars = 0
-        used_count = 0
-        for item in results:
-            content = str(item.get("content") or "").strip()
-            if not content:
-                continue
-            if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
-                break
-
-            metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
-            doc_id = metadata.get("document_id")
-            chunk_index = metadata.get("chunk_index")
-            distance = item.get("distance")
-
-            source_parts = []
-            if doc_id:
-                source_parts.append(f"doc={doc_id}")
-            if chunk_index is not None:
-                source_parts.append(f"chunk={chunk_index}")
-            source = f" ({', '.join(source_parts)})" if source_parts else ""
-
-            distance_text = ""
-            try:
-                if distance is not None:
-                    distance_text = f", distance={float(distance):.4f}"
-            except (TypeError, ValueError):
-                distance_text = ""
-
-            remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
-            snippet = content[:remaining].strip()
-            if not snippet:
-                continue
-
-            used_count += 1
-            lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
-            used_chars += len(snippet)
-
-        if used_count == 0:
-            return None
-
-        return "\n".join(lines)
-
-    async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]:
-        messages = self.conversation.get_messages()
-        kb_id = self._resolve_runtime_kb_id()
-        if not kb_id:
-            return messages
-
-        knowledge_cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {}
-        enabled = knowledge_cfg.get("enabled", True)
-        if isinstance(enabled, str):
-            enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
-        if not enabled:
-            return messages
-
-        n_results = self._coerce_int(knowledge_cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
-        n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
-
-        results = await search_knowledge_context(
-            kb_id=kb_id,
-            query=user_text,
-            n_results=n_results,
-        )
-        prompt = self._build_knowledge_prompt(results)
-        if not prompt:
-            return messages
-
-        logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
-        rag_system = LLMMessage(role="system", content=prompt)
-        if messages and messages[0].role == "system":
-            return [messages[0], rag_system, *messages[1:]]
-        return [rag_system, *messages]
+    def _resolved_knowledge_config(self) -> Dict[str, Any]:
+        cfg: Dict[str, Any] = {}
+        if isinstance(self._runtime_knowledge, dict):
+            cfg.update(self._runtime_knowledge)
+        kb_id = self._runtime_knowledge_base_id or str(
+            cfg.get("kbId") or cfg.get("knowledgeBaseId") or ""
+        ).strip()
+        if kb_id:
+            cfg["kbId"] = kb_id
+        cfg.setdefault("enabled", True)
+        return cfg
 
     async def _handle_turn(self, user_text: str) -> None:
         """
@@ -682,7 +600,7 @@
             self._first_audio_sent = False
 
             # Get AI response (streaming)
-            messages = await self._build_turn_messages(user_text)
+            messages = self.conversation.get_messages()
             full_response = ""
 
             await self.conversation.start_assistant_turn()
diff --git a/engine/services/llm.py b/engine/services/llm.py
index e1d99a8..6496a69 100644
--- a/engine/services/llm.py
+++ b/engine/services/llm.py
@@ -9,6 +9,7 @@ import asyncio
 from typing import AsyncIterator, Optional, List, Dict, Any
 
 from loguru import logger
+from app.backend_client import search_knowledge_context
 from services.base import BaseLLMService, LLMMessage, ServiceState
 
 # Try to import openai
@@ -33,7 +34,8 @@ class OpenAILLMService(BaseLLMService):
         model: str = "gpt-4o-mini",
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
-        system_prompt: Optional[str] = None
+        system_prompt: Optional[str] = None,
+        knowledge_config: Optional[Dict[str, Any]] = None,
     ):
         """
         Initialize OpenAI LLM service.
@@ -56,6 +58,11 @@
 
         self.client: Optional[AsyncOpenAI] = None
         self._cancel_event = asyncio.Event()
+        self._knowledge_config: Dict[str, Any] = knowledge_config or {}
+
+    _RAG_DEFAULT_RESULTS = 5
+    _RAG_MAX_RESULTS = 8
+    _RAG_MAX_CONTEXT_CHARS = 4000
 
     async def connect(self) -> None:
         """Initialize OpenAI client."""
@@ -94,6 +101,118 @@
             result.append(msg.to_dict())
 
         return result
+
+    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
+        """Update runtime knowledge retrieval config."""
+        self._knowledge_config = config or {}
+
+    @staticmethod
+    def _coerce_int(value: Any, default: int) -> int:
+        try:
+            return int(value)
+        except (TypeError, ValueError):
+            return default
+
+    def _resolve_kb_id(self) -> Optional[str]:
+        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
+        kb_id = str(
+            cfg.get("kbId")
+            or cfg.get("knowledgeBaseId")
+            or cfg.get("knowledge_base_id")
+            or ""
+        ).strip()
+        return kb_id or None
+
+    def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
+        if not results:
+            return None
+
+        lines = [
+            "You have retrieved the following knowledge base snippets.",
+            "Use them only when relevant to the latest user request.",
+            "If snippets are insufficient, say you are not sure instead of guessing.",
+            "",
+        ]
+
+        used_chars = 0
+        used_count = 0
+        for item in results:
+            content = str(item.get("content") or "").strip()
+            if not content:
+                continue
+            if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
+                break
+
+            metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
+            doc_id = metadata.get("document_id")
+            chunk_index = metadata.get("chunk_index")
+            distance = item.get("distance")
+
+            source_parts = []
+            if doc_id:
+                source_parts.append(f"doc={doc_id}")
+            if chunk_index is not None:
+                source_parts.append(f"chunk={chunk_index}")
+            source = f" ({', '.join(source_parts)})" if source_parts else ""
+
+            distance_text = ""
+            try:
+                if distance is not None:
+                    distance_text = f", distance={float(distance):.4f}"
+            except (TypeError, ValueError):
+                distance_text = ""
+
+            remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
+            snippet = content[:remaining].strip()
+            if not snippet:
+                continue
+
+            used_count += 1
+            lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
+            used_chars += len(snippet)
+
+        if used_count == 0:
+            return None
+
+        return "\n".join(lines)
+
+    async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
+        cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
+        enabled = cfg.get("enabled", True)
+        if isinstance(enabled, str):
+            enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
+        if not enabled:
+            return messages
+
+        kb_id = self._resolve_kb_id()
+        if not kb_id:
+            return messages
+
+        latest_user = ""
+        for msg in reversed(messages):
+            if msg.role == "user":
+                latest_user = (msg.content or "").strip()
+                break
+        if not latest_user:
+            return messages
+
+        n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
+        n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
+
+        results = await search_knowledge_context(
+            kb_id=kb_id,
+            query=latest_user,
+            n_results=n_results,
+        )
+        prompt = self._build_knowledge_prompt(results)
+        if not prompt:
+            return messages
+
+        logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
+        rag_system = LLMMessage(role="system", content=prompt)
+        if messages and messages[0].role == "system":
+            return [messages[0], rag_system, *messages[1:]]
+        return [rag_system, *messages]
 
     async def generate(
         self,
@@ -115,7 +234,8 @@
         if not self.client:
             raise RuntimeError("LLM service not connected")
 
-        prepared = self._prepare_messages(messages)
+        rag_messages = await self._with_knowledge_context(messages)
+        prepared = self._prepare_messages(rag_messages)
 
         try:
             response = await self.client.chat.completions.create(
@@ -153,7 +273,8 @@
         if not self.client:
             raise RuntimeError("LLM service not connected")
 
-        prepared = self._prepare_messages(messages)
+        rag_messages = await self._with_knowledge_context(messages)
+        prepared = self._prepare_messages(rag_messages)
         self._cancel_event.clear()
 
         try:
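
A minimal usage sketch of the relocated RAG hook, for reviewers; it is not part of
the patch. It assumes generate() accepts the message list as its first argument and
returns the completion text, and that it runs from the engine/ package root; the
kbId value and API key are placeholders.

    import asyncio

    from services.base import LLMMessage
    from services.llm import OpenAILLMService


    async def main() -> None:
        # Same config shape that DuplexPipeline._resolved_knowledge_config() builds.
        llm = OpenAILLMService(
            model="gpt-4o-mini",
            api_key="sk-placeholder",  # placeholder, not a real key
            knowledge_config={"kbId": "kb-demo", "nResults": 3, "enabled": True},
        )
        await llm.connect()

        # generate() now routes through _with_knowledge_context(), which retrieves
        # snippets for the latest user message and prepends them as a system
        # message (after any existing system prompt) before calling the API.
        reply = await llm.generate(
            [LLMMessage(role="user", content="What is our refund policy?")]
        )
        print(reply)

        # Runtime reconfiguration, as update_runtime_options() uses it; string
        # flags such as "off" are coerced by the enabled check.
        llm.set_knowledge_config({"kbId": "kb-demo", "enabled": "off"})


    asyncio.run(main())

Note the design consequence: since retrieval now happens inside the service on
every generate() call, any caller of OpenAILLMService gets RAG injection for
free, while the pipeline only forwards configuration.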