Move rag to llm service

This commit is contained in:
Xin Wang
2026-02-10 13:47:08 +08:00
parent d2aaba999b
commit 539cf2fda2
2 changed files with 146 additions and 107 deletions

View File

@@ -13,12 +13,11 @@ event-driven design.
import asyncio import asyncio
import time import time
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from app.backend_client import search_knowledge_context
from app.config import settings from app.config import settings
from core.conversation import ConversationManager, ConversationState from core.conversation import ConversationManager, ConversationState
from core.events import get_event_bus from core.events import get_event_bus
@@ -27,7 +26,7 @@ from models.ws_v1 import ev
from processors.eou import EouDetector from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor from processors.vad import SileroVAD, VADProcessor
from services.asr import BufferedASRService from services.asr import BufferedASRService
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage from services.base import BaseASRService, BaseLLMService, BaseTTSService
from services.llm import MockLLMService, OpenAILLMService from services.llm import MockLLMService, OpenAILLMService
from services.siliconflow_asr import SiliconFlowASRService from services.siliconflow_asr import SiliconFlowASRService
from services.siliconflow_tts import SiliconFlowTTSService from services.siliconflow_tts import SiliconFlowTTSService
@@ -56,9 +55,6 @@ class DuplexPipeline:
_SENTENCE_TRAILING_CHARS = frozenset({"", "", "", ".", "!", "?", "", "~", "", "\n"}) _SENTENCE_TRAILING_CHARS = frozenset({"", "", "", ".", "!", "?", "", "~", "", "\n"})
_SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""}) _SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""})
_MIN_SPLIT_SPOKEN_CHARS = 6 _MIN_SPLIT_SPOKEN_CHARS = 6
_RAG_DEFAULT_RESULTS = 5
_RAG_MAX_RESULTS = 8
_RAG_MAX_CONTEXT_CHARS = 4000
def __init__( def __init__(
self, self,
@@ -212,6 +208,9 @@ class DuplexPipeline:
if kb_id: if kb_id:
self._runtime_knowledge_base_id = kb_id self._runtime_knowledge_base_id = kb_id
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
async def start(self) -> None: async def start(self) -> None:
"""Start the pipeline and connect services.""" """Start the pipeline and connect services."""
try: try:
@@ -226,12 +225,16 @@ class DuplexPipeline:
self.llm_service = OpenAILLMService( self.llm_service = OpenAILLMService(
api_key=llm_api_key, api_key=llm_api_key,
base_url=llm_base_url, base_url=llm_base_url,
model=llm_model model=llm_model,
knowledge_config=self._resolved_knowledge_config(),
) )
else: else:
logger.warning("No OpenAI API key - using mock LLM") logger.warning("No OpenAI API key - using mock LLM")
self.llm_service = MockLLMService() self.llm_service = MockLLMService()
if hasattr(self.llm_service, "set_knowledge_config"):
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
await self.llm_service.connect() await self.llm_service.connect()
# Connect TTS service # Connect TTS service
@@ -570,102 +573,17 @@ class DuplexPipeline:
await self.conversation.end_user_turn(user_text) await self.conversation.end_user_turn(user_text)
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text)) self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
@staticmethod def _resolved_knowledge_config(self) -> Dict[str, Any]:
def _coerce_int(value: Any, default: int) -> int: cfg: Dict[str, Any] = {}
try: if isinstance(self._runtime_knowledge, dict):
return int(value) cfg.update(self._runtime_knowledge)
except (TypeError, ValueError): kb_id = self._runtime_knowledge_base_id or str(
return default cfg.get("kbId") or cfg.get("knowledgeBaseId") or ""
).strip()
def _resolve_runtime_kb_id(self) -> Optional[str]: if kb_id:
if self._runtime_knowledge_base_id: cfg["kbId"] = kb_id
return self._runtime_knowledge_base_id cfg.setdefault("enabled", True)
kb_id = str(self._runtime_knowledge.get("kbId") or self._runtime_knowledge.get("knowledgeBaseId") or "").strip() return cfg
return kb_id or None
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
if not results:
return None
lines = [
"You have retrieved the following knowledge base snippets.",
"Use them only when relevant to the latest user request.",
"If snippets are insufficient, say you are not sure instead of guessing.",
"",
]
used_chars = 0
used_count = 0
for item in results:
content = str(item.get("content") or "").strip()
if not content:
continue
if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
break
metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
doc_id = metadata.get("document_id")
chunk_index = metadata.get("chunk_index")
distance = item.get("distance")
source_parts = []
if doc_id:
source_parts.append(f"doc={doc_id}")
if chunk_index is not None:
source_parts.append(f"chunk={chunk_index}")
source = f" ({', '.join(source_parts)})" if source_parts else ""
distance_text = ""
try:
if distance is not None:
distance_text = f", distance={float(distance):.4f}"
except (TypeError, ValueError):
distance_text = ""
remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
snippet = content[:remaining].strip()
if not snippet:
continue
used_count += 1
lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
used_chars += len(snippet)
if used_count == 0:
return None
return "\n".join(lines)
async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]:
messages = self.conversation.get_messages()
kb_id = self._resolve_runtime_kb_id()
if not kb_id:
return messages
knowledge_cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {}
enabled = knowledge_cfg.get("enabled", True)
if isinstance(enabled, str):
enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
if not enabled:
return messages
n_results = self._coerce_int(knowledge_cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
results = await search_knowledge_context(
kb_id=kb_id,
query=user_text,
n_results=n_results,
)
prompt = self._build_knowledge_prompt(results)
if not prompt:
return messages
logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
rag_system = LLMMessage(role="system", content=prompt)
if messages and messages[0].role == "system":
return [messages[0], rag_system, *messages[1:]]
return [rag_system, *messages]
async def _handle_turn(self, user_text: str) -> None: async def _handle_turn(self, user_text: str) -> None:
""" """
@@ -682,7 +600,7 @@ class DuplexPipeline:
self._first_audio_sent = False self._first_audio_sent = False
# Get AI response (streaming) # Get AI response (streaming)
messages = await self._build_turn_messages(user_text) messages = self.conversation.get_messages()
full_response = "" full_response = ""
await self.conversation.start_assistant_turn() await self.conversation.start_assistant_turn()

View File

@@ -9,6 +9,7 @@ import asyncio
from typing import AsyncIterator, Optional, List, Dict, Any from typing import AsyncIterator, Optional, List, Dict, Any
from loguru import logger from loguru import logger
from app.backend_client import search_knowledge_context
from services.base import BaseLLMService, LLMMessage, ServiceState from services.base import BaseLLMService, LLMMessage, ServiceState
# Try to import openai # Try to import openai
@@ -33,7 +34,8 @@ class OpenAILLMService(BaseLLMService):
model: str = "gpt-4o-mini", model: str = "gpt-4o-mini",
api_key: Optional[str] = None, api_key: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
system_prompt: Optional[str] = None system_prompt: Optional[str] = None,
knowledge_config: Optional[Dict[str, Any]] = None,
): ):
""" """
Initialize OpenAI LLM service. Initialize OpenAI LLM service.
@@ -56,6 +58,11 @@ class OpenAILLMService(BaseLLMService):
self.client: Optional[AsyncOpenAI] = None self.client: Optional[AsyncOpenAI] = None
self._cancel_event = asyncio.Event() self._cancel_event = asyncio.Event()
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
_RAG_DEFAULT_RESULTS = 5
_RAG_MAX_RESULTS = 8
_RAG_MAX_CONTEXT_CHARS = 4000
async def connect(self) -> None: async def connect(self) -> None:
"""Initialize OpenAI client.""" """Initialize OpenAI client."""
@@ -94,6 +101,118 @@ class OpenAILLMService(BaseLLMService):
result.append(msg.to_dict()) result.append(msg.to_dict())
return result return result
def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
"""Update runtime knowledge retrieval config."""
self._knowledge_config = config or {}
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
def _resolve_kb_id(self) -> Optional[str]:
cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
kb_id = str(
cfg.get("kbId")
or cfg.get("knowledgeBaseId")
or cfg.get("knowledge_base_id")
or ""
).strip()
return kb_id or None
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
if not results:
return None
lines = [
"You have retrieved the following knowledge base snippets.",
"Use them only when relevant to the latest user request.",
"If snippets are insufficient, say you are not sure instead of guessing.",
"",
]
used_chars = 0
used_count = 0
for item in results:
content = str(item.get("content") or "").strip()
if not content:
continue
if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
break
metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
doc_id = metadata.get("document_id")
chunk_index = metadata.get("chunk_index")
distance = item.get("distance")
source_parts = []
if doc_id:
source_parts.append(f"doc={doc_id}")
if chunk_index is not None:
source_parts.append(f"chunk={chunk_index}")
source = f" ({', '.join(source_parts)})" if source_parts else ""
distance_text = ""
try:
if distance is not None:
distance_text = f", distance={float(distance):.4f}"
except (TypeError, ValueError):
distance_text = ""
remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
snippet = content[:remaining].strip()
if not snippet:
continue
used_count += 1
lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
used_chars += len(snippet)
if used_count == 0:
return None
return "\n".join(lines)
async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
enabled = cfg.get("enabled", True)
if isinstance(enabled, str):
enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
if not enabled:
return messages
kb_id = self._resolve_kb_id()
if not kb_id:
return messages
latest_user = ""
for msg in reversed(messages):
if msg.role == "user":
latest_user = (msg.content or "").strip()
break
if not latest_user:
return messages
n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
results = await search_knowledge_context(
kb_id=kb_id,
query=latest_user,
n_results=n_results,
)
prompt = self._build_knowledge_prompt(results)
if not prompt:
return messages
logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
rag_system = LLMMessage(role="system", content=prompt)
if messages and messages[0].role == "system":
return [messages[0], rag_system, *messages[1:]]
return [rag_system, *messages]
async def generate( async def generate(
self, self,
@@ -115,7 +234,8 @@ class OpenAILLMService(BaseLLMService):
if not self.client: if not self.client:
raise RuntimeError("LLM service not connected") raise RuntimeError("LLM service not connected")
prepared = self._prepare_messages(messages) rag_messages = await self._with_knowledge_context(messages)
prepared = self._prepare_messages(rag_messages)
try: try:
response = await self.client.chat.completions.create( response = await self.client.chat.completions.create(
@@ -153,7 +273,8 @@ class OpenAILLMService(BaseLLMService):
if not self.client: if not self.client:
raise RuntimeError("LLM service not connected") raise RuntimeError("LLM service not connected")
prepared = self._prepare_messages(messages) rag_messages = await self._with_knowledge_context(messages)
prepared = self._prepare_messages(rag_messages)
self._cancel_event.clear() self._cancel_event.clear()
try: try: