diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index 44ae3ba..0e5227e 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -68,6 +68,14 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: } warnings.append(f"Voice resource not found: {assistant.voice}") + if assistant.knowledge_base_id: + metadata["knowledgeBaseId"] = assistant.knowledge_base_id + metadata["knowledge"] = { + "enabled": True, + "kbId": assistant.knowledge_base_id, + "nResults": 5, + } + return { "assistantId": assistant.id, "sessionStartMetadata": metadata, @@ -75,6 +83,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict: "llmModelId": assistant.llm_model_id, "asrModelId": assistant.asr_model_id, "voiceId": assistant.voice, + "knowledgeBaseId": assistant.knowledge_base_id, }, "warnings": warnings, } diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py index 9d6b200..106ff10 100644 --- a/api/tests/test_assistants.py +++ b/api/tests/test_assistants.py @@ -119,6 +119,14 @@ class TestAssistantAPI: assert response.status_code == 200 assert response.json()["knowledgeBaseId"] == "non-existent-kb" + assistant_id = response.json()["id"] + runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config") + assert runtime_resp.status_code == 200 + metadata = runtime_resp.json()["sessionStartMetadata"] + assert metadata["knowledgeBaseId"] == "non-existent-kb" + assert metadata["knowledge"]["enabled"] is True + assert metadata["knowledge"]["kbId"] == "non-existent-kb" + def test_assistant_with_model_references(self, client, sample_assistant_data): """Test creating assistant with model references""" sample_assistant_data.update({ diff --git a/engine/app/backend_client.py b/engine/app/backend_client.py index 4d8aa8c..9bd4e77 100644 --- a/engine/app/backend_client.py +++ b/engine/app/backend_client.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, 
Dict, Optional +from typing import Any, Dict, List, Optional import aiohttp from loguru import logger @@ -146,3 +146,46 @@ async def finalize_history_call_record( except Exception as exc: logger.warning(f"Failed to finalize history call record ({call_id}): {exc}") return False + + +async def search_knowledge_context( + *, + kb_id: str, + query: str, + n_results: int = 5, +) -> List[Dict[str, Any]]: + """Search backend knowledge base and return retrieval results.""" + base_url = _backend_base_url() + if not base_url: + return [] + if not kb_id or not query.strip(): + return [] + try: + safe_n_results = max(1, int(n_results)) + except (TypeError, ValueError): + safe_n_results = 5 + + url = f"{base_url}/api/knowledge/search" + payload: Dict[str, Any] = { + "kb_id": kb_id, + "query": query, + "nResults": safe_n_results, + } + + try: + async with aiohttp.ClientSession(timeout=_timeout()) as session: + async with session.post(url, json=payload) as resp: + if resp.status == 404: + logger.warning(f"Knowledge base not found for retrieval: {kb_id}") + return [] + resp.raise_for_status() + data = await resp.json() + if not isinstance(data, dict): + return [] + results = data.get("results", []) + if not isinstance(results, list): + return [] + return [r for r in results if isinstance(r, dict)] + except Exception as exc: + logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}") + return [] diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 561997f..6197667 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -13,11 +13,12 @@ event-driven design. 
import asyncio import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np from loguru import logger +from app.backend_client import search_knowledge_context from app.config import settings from core.conversation import ConversationManager, ConversationState from core.events import get_event_bus @@ -26,7 +27,7 @@ from models.ws_v1 import ev from processors.eou import EouDetector from processors.vad import SileroVAD, VADProcessor from services.asr import BufferedASRService -from services.base import BaseASRService, BaseLLMService, BaseTTSService +from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage from services.llm import MockLLMService, OpenAILLMService from services.siliconflow_asr import SiliconFlowASRService from services.siliconflow_tts import SiliconFlowTTSService @@ -55,6 +56,9 @@ class DuplexPipeline: _SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"}) _SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"}) _MIN_SPLIT_SPOKEN_CHARS = 6 + _RAG_DEFAULT_RESULTS = 5 + _RAG_MAX_RESULTS = 8 + _RAG_MAX_CONTEXT_CHARS = 4000 def __init__( self, @@ -156,6 +160,8 @@ class DuplexPipeline: self._runtime_tts: Dict[str, Any] = {} self._runtime_system_prompt: Optional[str] = None self._runtime_greeting: Optional[str] = None + self._runtime_knowledge: Dict[str, Any] = {} + self._runtime_knowledge_base_id: Optional[str] = None logger.info(f"DuplexPipeline initialized for session {session_id}") @@ -194,6 +200,18 @@ class DuplexPipeline: if isinstance(services.get("tts"), dict): self._runtime_tts = services["tts"] + knowledge_base_id = metadata.get("knowledgeBaseId") + if knowledge_base_id is not None: + kb_id = str(knowledge_base_id).strip() + self._runtime_knowledge_base_id = kb_id or None + + knowledge = metadata.get("knowledge") + if isinstance(knowledge, dict): + self._runtime_knowledge = knowledge 
+ kb_id = str(knowledge.get("kbId") or knowledge.get("knowledgeBaseId") or "").strip() + if kb_id: + self._runtime_knowledge_base_id = kb_id + async def start(self) -> None: """Start the pipeline and connect services.""" try: @@ -552,6 +570,103 @@ class DuplexPipeline: await self.conversation.end_user_turn(user_text) self._current_turn_task = asyncio.create_task(self._handle_turn(user_text)) + @staticmethod + def _coerce_int(value: Any, default: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + def _resolve_runtime_kb_id(self) -> Optional[str]: + if self._runtime_knowledge_base_id: + return self._runtime_knowledge_base_id + kb_id = str(self._runtime_knowledge.get("kbId") or self._runtime_knowledge.get("knowledgeBaseId") or "").strip() + return kb_id or None + + def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]: + if not results: + return None + + lines = [ + "You have retrieved the following knowledge base snippets.", + "Use them only when relevant to the latest user request.", + "If snippets are insufficient, say you are not sure instead of guessing.", + "", + ] + + used_chars = 0 + used_count = 0 + for item in results: + content = str(item.get("content") or "").strip() + if not content: + continue + if used_chars >= self._RAG_MAX_CONTEXT_CHARS: + break + + metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {} + doc_id = metadata.get("document_id") + chunk_index = metadata.get("chunk_index") + distance = item.get("distance") + + source_parts = [] + if doc_id: + source_parts.append(f"doc={doc_id}") + if chunk_index is not None: + source_parts.append(f"chunk={chunk_index}") + source = f" ({', '.join(source_parts)})" if source_parts else "" + + distance_text = "" + try: + if distance is not None: + distance_text = f", distance={float(distance):.4f}" + except (TypeError, ValueError): + distance_text = "" + + remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars + 
snippet = content[:remaining].strip() + if not snippet: + continue + + used_count += 1 + lines.append(f"[{used_count}{source}{distance_text}] {snippet}") + used_chars += len(snippet) + + if used_count == 0: + return None + + return "\n".join(lines) + + async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]: + messages = self.conversation.get_messages() + kb_id = self._resolve_runtime_kb_id() + if not kb_id: + return messages + + knowledge_cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {} + enabled = knowledge_cfg.get("enabled", True) + if isinstance(enabled, str): + enabled = enabled.strip().lower() not in {"false", "0", "off", "no"} + if not enabled: + return messages + + n_results = self._coerce_int(knowledge_cfg.get("nResults"), self._RAG_DEFAULT_RESULTS) + n_results = max(1, min(n_results, self._RAG_MAX_RESULTS)) + + results = await search_knowledge_context( + kb_id=kb_id, + query=user_text, + n_results=n_results, + ) + prompt = self._build_knowledge_prompt(results) + if not prompt: + return messages + + logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})") + rag_system = LLMMessage(role="system", content=prompt) + if messages and messages[0].role == "system": + return [messages[0], rag_system, *messages[1:]] + return [rag_system, *messages] + async def _handle_turn(self, user_text: str) -> None: """ Handle a complete conversation turn. 
@@ -567,7 +682,7 @@ class DuplexPipeline: self._first_audio_sent = False # Get AI response (streaming) - messages = self.conversation.get_messages() + messages = await self._build_turn_messages(user_text) full_response = "" await self.conversation.start_assistant_turn() diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index a564c28..d9a060f 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -988,10 +988,21 @@ export const DebugDrawer: React.FC<{ isOpen: boolean; onClose: () => void; assistant: Assistant; - voices: Voice[]; - llmModels: LLMModel[]; - asrModels: ASRModel[]; -}> = ({ isOpen, onClose, assistant, voices, llmModels, asrModels }) => { + voices?: Voice[]; + llmModels?: LLMModel[]; + asrModels?: ASRModel[]; + sessionMetadataExtras?: Record<string, unknown>; + onProtocolEvent?: (event: Record<string, unknown>) => void; +}> = ({ + isOpen, + onClose, + assistant, + voices = [], + llmModels = [], + asrModels = [], + sessionMetadataExtras, + onProtocolEvent, +}) => { const TARGET_SAMPLE_RATE = 16000; const downsampleTo16k = (input: Float32Array, inputSampleRate: number): Float32Array => { if (inputSampleRate === TARGET_SAMPLE_RATE) return input; @@ -1474,6 +1485,10 @@ export const DebugDrawer: React.FC<{ const warnings: string[] = []; const services: Record<string, unknown> = {}; const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt'; + const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim(); + const knowledge = knowledgeBaseId + ? 
{ enabled: true, kbId: knowledgeBaseId, nResults: 5 } : { enabled: false }; if (isExternalLlm) { services.llm = { @@ -1541,6 +1556,8 @@ export const DebugDrawer: React.FC<{ sessionStartMetadata: { systemPrompt: assistant.prompt || '', greeting: assistant.opener || '', + knowledgeBaseId, + knowledge, services, history: { assistantId: assistant.id, @@ -1556,7 +1573,10 @@ export const DebugDrawer: React.FC<{ const fetchRuntimeMetadata = async (): Promise<Record<string, unknown>> => { const localResolved = buildLocalResolvedRuntime(); setResolvedConfigView(JSON.stringify(localResolved, null, 2)); - return localResolved.sessionStartMetadata; + return { + ...localResolved.sessionStartMetadata, + ...(sessionMetadataExtras || {}), + }; }; const closeWs = () => { @@ -1622,6 +1642,9 @@ export const DebugDrawer: React.FC<{ } const type = payload?.type; + if (onProtocolEvent) { + onProtocolEvent(payload); + } if (type === 'hello.ack') { ws.send( JSON.stringify({