Move rag to llm service
This commit is contained in:
@@ -13,12 +13,11 @@ event-driven design.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from app.backend_client import search_knowledge_context
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from core.conversation import ConversationManager, ConversationState
|
from core.conversation import ConversationManager, ConversationState
|
||||||
from core.events import get_event_bus
|
from core.events import get_event_bus
|
||||||
@@ -27,7 +26,7 @@ from models.ws_v1 import ev
|
|||||||
from processors.eou import EouDetector
|
from processors.eou import EouDetector
|
||||||
from processors.vad import SileroVAD, VADProcessor
|
from processors.vad import SileroVAD, VADProcessor
|
||||||
from services.asr import BufferedASRService
|
from services.asr import BufferedASRService
|
||||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage
|
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
||||||
from services.llm import MockLLMService, OpenAILLMService
|
from services.llm import MockLLMService, OpenAILLMService
|
||||||
from services.siliconflow_asr import SiliconFlowASRService
|
from services.siliconflow_asr import SiliconFlowASRService
|
||||||
from services.siliconflow_tts import SiliconFlowTTSService
|
from services.siliconflow_tts import SiliconFlowTTSService
|
||||||
@@ -56,9 +55,6 @@ class DuplexPipeline:
|
|||||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||||
_RAG_DEFAULT_RESULTS = 5
|
|
||||||
_RAG_MAX_RESULTS = 8
|
|
||||||
_RAG_MAX_CONTEXT_CHARS = 4000
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -212,6 +208,9 @@ class DuplexPipeline:
|
|||||||
if kb_id:
|
if kb_id:
|
||||||
self._runtime_knowledge_base_id = kb_id
|
self._runtime_knowledge_base_id = kb_id
|
||||||
|
|
||||||
|
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||||||
|
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the pipeline and connect services."""
|
"""Start the pipeline and connect services."""
|
||||||
try:
|
try:
|
||||||
@@ -226,12 +225,16 @@ class DuplexPipeline:
|
|||||||
self.llm_service = OpenAILLMService(
|
self.llm_service = OpenAILLMService(
|
||||||
api_key=llm_api_key,
|
api_key=llm_api_key,
|
||||||
base_url=llm_base_url,
|
base_url=llm_base_url,
|
||||||
model=llm_model
|
model=llm_model,
|
||||||
|
knowledge_config=self._resolved_knowledge_config(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("No OpenAI API key - using mock LLM")
|
logger.warning("No OpenAI API key - using mock LLM")
|
||||||
self.llm_service = MockLLMService()
|
self.llm_service = MockLLMService()
|
||||||
|
|
||||||
|
if hasattr(self.llm_service, "set_knowledge_config"):
|
||||||
|
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||||||
|
|
||||||
await self.llm_service.connect()
|
await self.llm_service.connect()
|
||||||
|
|
||||||
# Connect TTS service
|
# Connect TTS service
|
||||||
@@ -570,102 +573,17 @@ class DuplexPipeline:
|
|||||||
await self.conversation.end_user_turn(user_text)
|
await self.conversation.end_user_turn(user_text)
|
||||||
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||||||
|
|
||||||
@staticmethod
|
def _resolved_knowledge_config(self) -> Dict[str, Any]:
|
||||||
def _coerce_int(value: Any, default: int) -> int:
|
cfg: Dict[str, Any] = {}
|
||||||
try:
|
if isinstance(self._runtime_knowledge, dict):
|
||||||
return int(value)
|
cfg.update(self._runtime_knowledge)
|
||||||
except (TypeError, ValueError):
|
kb_id = self._runtime_knowledge_base_id or str(
|
||||||
return default
|
cfg.get("kbId") or cfg.get("knowledgeBaseId") or ""
|
||||||
|
).strip()
|
||||||
def _resolve_runtime_kb_id(self) -> Optional[str]:
|
if kb_id:
|
||||||
if self._runtime_knowledge_base_id:
|
cfg["kbId"] = kb_id
|
||||||
return self._runtime_knowledge_base_id
|
cfg.setdefault("enabled", True)
|
||||||
kb_id = str(self._runtime_knowledge.get("kbId") or self._runtime_knowledge.get("knowledgeBaseId") or "").strip()
|
return cfg
|
||||||
return kb_id or None
|
|
||||||
|
|
||||||
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
    """Render retrieved knowledge-base hits as one system-prompt string.

    Args:
        results: Retrieval hits. A usable hit is a dict with a ``content``
            value and, optionally, ``metadata`` (``document_id``,
            ``chunk_index``) and ``distance``.

    Returns:
        A fixed instruction header followed by one numbered line per snippet,
        or ``None`` when no hit contributes any text. Total snippet text is
        capped at ``self._RAG_MAX_CONTEXT_CHARS`` characters.
    """
    if not results:
        return None

    lines = [
        "You have retrieved the following knowledge base snippets.",
        "Use them only when relevant to the latest user request.",
        "If snippets are insufficient, say you are not sure instead of guessing.",
        "",
    ]

    used_chars = 0
    used_count = 0
    for item in results:
        if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
            break
        # A malformed hit (None, str, ...) is skipped instead of raising
        # AttributeError on ``.get``.
        if not isinstance(item, dict):
            continue
        content = str(item.get("content") or "").strip()
        if not content:
            continue

        metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
        doc_id = metadata.get("document_id")
        chunk_index = metadata.get("chunk_index")
        distance = item.get("distance")

        source_parts = []
        if doc_id:
            source_parts.append(f"doc={doc_id}")
        if chunk_index is not None:
            source_parts.append(f"chunk={chunk_index}")
        source = f" ({', '.join(source_parts)})" if source_parts else ""

        distance_text = ""
        if distance is not None:
            try:
                distance_text = f", distance={float(distance):.4f}"
            except (TypeError, ValueError, OverflowError):
                # An unparseable distance is simply omitted from the label.
                distance_text = ""

        # Truncate the last snippet so the overall budget is never exceeded.
        remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
        snippet = content[:remaining].strip()
        if not snippet:
            continue

        used_count += 1
        lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
        used_chars += len(snippet)

    if used_count == 0:
        return None

    return "\n".join(lines)
|
|
||||||
|
|
||||||
async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]:
    """Return the conversation messages, augmented with RAG context if any.

    Args:
        user_text: The latest user utterance; used as the retrieval query.

    Returns:
        ``self.conversation.get_messages()`` unchanged when no knowledge base
        is configured, retrieval is disabled, retrieval fails, or nothing
        usable is retrieved. Otherwise the same messages with one extra
        ``system`` message carrying the snippets, placed directly after a
        leading system message (or first when there is none).
    """
    messages = self.conversation.get_messages()
    kb_id = self._resolve_runtime_kb_id()
    if not kb_id:
        return messages

    knowledge_cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {}
    enabled = knowledge_cfg.get("enabled", True)
    if isinstance(enabled, str):
        # Accept common textual "off" spellings coming from JSON/query params.
        enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
    if not enabled:
        return messages

    n_results = self._coerce_int(knowledge_cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
    n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))

    try:
        results = await search_knowledge_context(
            kb_id=kb_id,
            query=user_text,
            n_results=n_results,
        )
    except Exception as exc:
        # Retrieval is an optional augmentation: a backend failure must not
        # abort the turn. Task cancellation is not an ``Exception`` subclass
        # and still propagates.
        logger.warning(f"RAG retrieval failed (kb_id={kb_id}): {exc}")
        return messages

    prompt = self._build_knowledge_prompt(results)
    if not prompt:
        return messages

    logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
    rag_system = LLMMessage(role="system", content=prompt)
    if messages and messages[0].role == "system":
        return [messages[0], rag_system, *messages[1:]]
    return [rag_system, *messages]
|
|
||||||
|
|
||||||
async def _handle_turn(self, user_text: str) -> None:
|
async def _handle_turn(self, user_text: str) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -682,7 +600,7 @@ class DuplexPipeline:
|
|||||||
self._first_audio_sent = False
|
self._first_audio_sent = False
|
||||||
|
|
||||||
# Get AI response (streaming)
|
# Get AI response (streaming)
|
||||||
messages = await self._build_turn_messages(user_text)
|
messages = self.conversation.get_messages()
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
await self.conversation.start_assistant_turn()
|
await self.conversation.start_assistant_turn()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import asyncio
|
|||||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from app.backend_client import search_knowledge_context
|
||||||
from services.base import BaseLLMService, LLMMessage, ServiceState
|
from services.base import BaseLLMService, LLMMessage, ServiceState
|
||||||
|
|
||||||
# Try to import openai
|
# Try to import openai
|
||||||
@@ -33,7 +34,8 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
model: str = "gpt-4o-mini",
|
model: str = "gpt-4o-mini",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
system_prompt: Optional[str] = None
|
system_prompt: Optional[str] = None,
|
||||||
|
knowledge_config: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize OpenAI LLM service.
|
Initialize OpenAI LLM service.
|
||||||
@@ -56,6 +58,11 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
|
|
||||||
self.client: Optional[AsyncOpenAI] = None
|
self.client: Optional[AsyncOpenAI] = None
|
||||||
self._cancel_event = asyncio.Event()
|
self._cancel_event = asyncio.Event()
|
||||||
|
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
|
||||||
|
|
||||||
|
_RAG_DEFAULT_RESULTS = 5
|
||||||
|
_RAG_MAX_RESULTS = 8
|
||||||
|
_RAG_MAX_CONTEXT_CHARS = 4000
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Initialize OpenAI client."""
|
"""Initialize OpenAI client."""
|
||||||
@@ -95,6 +102,118 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
    """Update runtime knowledge retrieval config.

    Args:
        config: New config mapping; ``None`` (or any non-dict value) resets
            the config to an empty dict.
    """
    # Store a shallow copy so later mutations of the caller's dict cannot
    # silently change this service's retrieval settings.
    self._knowledge_config = dict(config) if isinstance(config, dict) else {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_int(value: Any, default: int) -> int:
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
def _resolve_kb_id(self) -> Optional[str]:
    """Return the configured knowledge-base id, or ``None`` when unset.

    The id is read from ``self._knowledge_config`` under the first of
    ``kbId``, ``knowledgeBaseId`` or ``knowledge_base_id`` whose value is
    non-empty after stripping whitespace.
    """
    cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
    # Strip each alias before testing it, so a whitespace-only value under an
    # earlier key does not hide a valid id under a later one.
    for key in ("kbId", "knowledgeBaseId", "knowledge_base_id"):
        candidate = str(cfg.get(key) or "").strip()
        if candidate:
            return candidate
    return None
|
||||||
|
|
||||||
|
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
    """Render retrieved knowledge-base hits as one system-prompt string.

    Args:
        results: Retrieval hits. A usable hit is a dict with a ``content``
            value and, optionally, ``metadata`` (``document_id``,
            ``chunk_index``) and ``distance``.

    Returns:
        A fixed instruction header followed by one numbered line per snippet,
        or ``None`` when no hit contributes any text. Total snippet text is
        capped at ``self._RAG_MAX_CONTEXT_CHARS`` characters.
    """
    if not results:
        return None

    lines = [
        "You have retrieved the following knowledge base snippets.",
        "Use them only when relevant to the latest user request.",
        "If snippets are insufficient, say you are not sure instead of guessing.",
        "",
    ]

    used_chars = 0
    used_count = 0
    for item in results:
        if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
            break
        # A malformed hit (None, str, ...) is skipped instead of raising
        # AttributeError on ``.get``.
        if not isinstance(item, dict):
            continue
        content = str(item.get("content") or "").strip()
        if not content:
            continue

        metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
        doc_id = metadata.get("document_id")
        chunk_index = metadata.get("chunk_index")
        distance = item.get("distance")

        source_parts = []
        if doc_id:
            source_parts.append(f"doc={doc_id}")
        if chunk_index is not None:
            source_parts.append(f"chunk={chunk_index}")
        source = f" ({', '.join(source_parts)})" if source_parts else ""

        distance_text = ""
        if distance is not None:
            try:
                distance_text = f", distance={float(distance):.4f}"
            except (TypeError, ValueError, OverflowError):
                # An unparseable distance is simply omitted from the label.
                distance_text = ""

        # Truncate the last snippet so the overall budget is never exceeded.
        remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
        snippet = content[:remaining].strip()
        if not snippet:
            continue

        used_count += 1
        lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
        used_chars += len(snippet)

    if used_count == 0:
        return None

    return "\n".join(lines)
|
||||||
|
|
||||||
|
async def _with_knowledge_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
    """Return *messages*, augmented with retrieved knowledge-base context.

    The most recent ``user`` message is used as the retrieval query.

    Args:
        messages: Conversation messages about to be sent to the model.

    Returns:
        *messages* unchanged when retrieval is disabled, no knowledge-base id
        is configured, there is no non-empty user message, retrieval fails,
        or nothing usable is retrieved. Otherwise a new list with one extra
        ``system`` message carrying the snippets, placed directly after a
        leading system message (or first when there is none).
    """
    cfg = self._knowledge_config if isinstance(self._knowledge_config, dict) else {}
    enabled = cfg.get("enabled", True)
    if isinstance(enabled, str):
        # Accept common textual "off" spellings coming from JSON/query params.
        enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
    if not enabled:
        return messages

    kb_id = self._resolve_kb_id()
    if not kb_id:
        return messages

    latest_user = ""
    for msg in reversed(messages):
        if msg.role == "user":
            latest_user = (msg.content or "").strip()
            break
    if not latest_user:
        return messages

    n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
    n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))

    try:
        results = await search_knowledge_context(
            kb_id=kb_id,
            query=latest_user,
            n_results=n_results,
        )
    except Exception as exc:
        # Retrieval is an optional augmentation: a backend failure must not
        # fail the whole generation. Task cancellation is not an
        # ``Exception`` subclass and still propagates.
        logger.warning(f"RAG retrieval failed (kb_id={kb_id}): {exc}")
        return messages

    prompt = self._build_knowledge_prompt(results)
    if not prompt:
        return messages

    logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
    rag_system = LLMMessage(role="system", content=prompt)
    if messages and messages[0].role == "system":
        return [messages[0], rag_system, *messages[1:]]
    return [rag_system, *messages]
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
messages: List[LLMMessage],
|
messages: List[LLMMessage],
|
||||||
@@ -115,7 +234,8 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
if not self.client:
|
if not self.client:
|
||||||
raise RuntimeError("LLM service not connected")
|
raise RuntimeError("LLM service not connected")
|
||||||
|
|
||||||
prepared = self._prepare_messages(messages)
|
rag_messages = await self._with_knowledge_context(messages)
|
||||||
|
prepared = self._prepare_messages(rag_messages)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.client.chat.completions.create(
|
response = await self.client.chat.completions.create(
|
||||||
@@ -153,7 +273,8 @@ class OpenAILLMService(BaseLLMService):
|
|||||||
if not self.client:
|
if not self.client:
|
||||||
raise RuntimeError("LLM service not connected")
|
raise RuntimeError("LLM service not connected")
|
||||||
|
|
||||||
prepared = self._prepare_messages(messages)
|
rag_messages = await self._with_knowledge_context(messages)
|
||||||
|
prepared = self._prepare_messages(rag_messages)
|
||||||
self._cancel_event.clear()
|
self._cancel_event.clear()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user