Implement KB features with codex

This commit is contained in:
Xin Wang
2026-02-10 07:35:08 +08:00
parent ed1f7fc8b0
commit 6b4391c423
5 changed files with 207 additions and 9 deletions

View File

@@ -68,6 +68,14 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
} }
warnings.append(f"Voice resource not found: {assistant.voice}") warnings.append(f"Voice resource not found: {assistant.voice}")
if assistant.knowledge_base_id:
metadata["knowledgeBaseId"] = assistant.knowledge_base_id
metadata["knowledge"] = {
"enabled": True,
"kbId": assistant.knowledge_base_id,
"nResults": 5,
}
return { return {
"assistantId": assistant.id, "assistantId": assistant.id,
"sessionStartMetadata": metadata, "sessionStartMetadata": metadata,
@@ -75,6 +83,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
"llmModelId": assistant.llm_model_id, "llmModelId": assistant.llm_model_id,
"asrModelId": assistant.asr_model_id, "asrModelId": assistant.asr_model_id,
"voiceId": assistant.voice, "voiceId": assistant.voice,
"knowledgeBaseId": assistant.knowledge_base_id,
}, },
"warnings": warnings, "warnings": warnings,
} }

View File

@@ -119,6 +119,14 @@ class TestAssistantAPI:
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["knowledgeBaseId"] == "non-existent-kb" assert response.json()["knowledgeBaseId"] == "non-existent-kb"
assistant_id = response.json()["id"]
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
assert runtime_resp.status_code == 200
metadata = runtime_resp.json()["sessionStartMetadata"]
assert metadata["knowledgeBaseId"] == "non-existent-kb"
assert metadata["knowledge"]["enabled"] is True
assert metadata["knowledge"]["kbId"] == "non-existent-kb"
def test_assistant_with_model_references(self, client, sample_assistant_data): def test_assistant_with_model_references(self, client, sample_assistant_data):
"""Test creating assistant with model references""" """Test creating assistant with model references"""
sample_assistant_data.update({ sample_assistant_data.update({

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
import aiohttp import aiohttp
from loguru import logger from loguru import logger
@@ -146,3 +146,46 @@ async def finalize_history_call_record(
except Exception as exc: except Exception as exc:
logger.warning(f"Failed to finalize history call record ({call_id}): {exc}") logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
return False return False
async def search_knowledge_context(
    *,
    kb_id: str,
    query: str,
    n_results: int = 5,
) -> List[Dict[str, Any]]:
    """Query the backend knowledge-base search endpoint.

    Returns the list of retrieval-result dicts from the backend, or an
    empty list on any failure: missing backend URL, blank kb_id/query,
    HTTP errors, or a malformed response body. Never raises.
    """
    base_url = _backend_base_url()
    # Nothing sensible to send without a backend and non-blank inputs.
    if not base_url or not kb_id or not query.strip():
        return []

    try:
        capped_results = max(1, int(n_results))
    except (TypeError, ValueError):
        capped_results = 5

    request_body: Dict[str, Any] = {
        "kb_id": kb_id,
        "query": query,
        "nResults": capped_results,
    }

    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as session:
            async with session.post(f"{base_url}/api/knowledge/search", json=request_body) as resp:
                # A 404 means the KB id is unknown — log and degrade gracefully.
                if resp.status == 404:
                    logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
                    return []
                resp.raise_for_status()
                body = await resp.json()
    except Exception as exc:
        logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
        return []

    # Defensive parsing: only accept a dict body with a list of dict results.
    if not isinstance(body, dict):
        return []
    hits = body.get("results", [])
    if not isinstance(hits, list):
        return []
    return [hit for hit in hits if isinstance(hit, dict)]

View File

@@ -13,11 +13,12 @@ event-driven design.
import asyncio import asyncio
import time import time
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from app.backend_client import search_knowledge_context
from app.config import settings from app.config import settings
from core.conversation import ConversationManager, ConversationState from core.conversation import ConversationManager, ConversationState
from core.events import get_event_bus from core.events import get_event_bus
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
from processors.eou import EouDetector from processors.eou import EouDetector
from processors.vad import SileroVAD, VADProcessor from processors.vad import SileroVAD, VADProcessor
from services.asr import BufferedASRService from services.asr import BufferedASRService
from services.base import BaseASRService, BaseLLMService, BaseTTSService from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage
from services.llm import MockLLMService, OpenAILLMService from services.llm import MockLLMService, OpenAILLMService
from services.siliconflow_asr import SiliconFlowASRService from services.siliconflow_asr import SiliconFlowASRService
from services.siliconflow_tts import SiliconFlowTTSService from services.siliconflow_tts import SiliconFlowTTSService
@@ -55,6 +56,9 @@ class DuplexPipeline:
_SENTENCE_TRAILING_CHARS = frozenset({"", "", "", ".", "!", "?", "", "~", "", "\n"}) _SENTENCE_TRAILING_CHARS = frozenset({"", "", "", ".", "!", "?", "", "~", "", "\n"})
_SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""}) _SENTENCE_CLOSERS = frozenset({'"', "'", "", "", ")", "]", "}", "", "", "", "", ""})
_MIN_SPLIT_SPOKEN_CHARS = 6 _MIN_SPLIT_SPOKEN_CHARS = 6
_RAG_DEFAULT_RESULTS = 5
_RAG_MAX_RESULTS = 8
_RAG_MAX_CONTEXT_CHARS = 4000
def __init__( def __init__(
self, self,
@@ -156,6 +160,8 @@ class DuplexPipeline:
self._runtime_tts: Dict[str, Any] = {} self._runtime_tts: Dict[str, Any] = {}
self._runtime_system_prompt: Optional[str] = None self._runtime_system_prompt: Optional[str] = None
self._runtime_greeting: Optional[str] = None self._runtime_greeting: Optional[str] = None
self._runtime_knowledge: Dict[str, Any] = {}
self._runtime_knowledge_base_id: Optional[str] = None
logger.info(f"DuplexPipeline initialized for session {session_id}") logger.info(f"DuplexPipeline initialized for session {session_id}")
@@ -194,6 +200,18 @@ class DuplexPipeline:
if isinstance(services.get("tts"), dict): if isinstance(services.get("tts"), dict):
self._runtime_tts = services["tts"] self._runtime_tts = services["tts"]
knowledge_base_id = metadata.get("knowledgeBaseId")
if knowledge_base_id is not None:
kb_id = str(knowledge_base_id).strip()
self._runtime_knowledge_base_id = kb_id or None
knowledge = metadata.get("knowledge")
if isinstance(knowledge, dict):
self._runtime_knowledge = knowledge
kb_id = str(knowledge.get("kbId") or knowledge.get("knowledgeBaseId") or "").strip()
if kb_id:
self._runtime_knowledge_base_id = kb_id
async def start(self) -> None: async def start(self) -> None:
"""Start the pipeline and connect services.""" """Start the pipeline and connect services."""
try: try:
@@ -552,6 +570,103 @@ class DuplexPipeline:
await self.conversation.end_user_turn(user_text) await self.conversation.end_user_turn(user_text)
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text)) self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
def _resolve_runtime_kb_id(self) -> Optional[str]:
if self._runtime_knowledge_base_id:
return self._runtime_knowledge_base_id
kb_id = str(self._runtime_knowledge.get("kbId") or self._runtime_knowledge.get("knowledgeBaseId") or "").strip()
return kb_id or None
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
if not results:
return None
lines = [
"You have retrieved the following knowledge base snippets.",
"Use them only when relevant to the latest user request.",
"If snippets are insufficient, say you are not sure instead of guessing.",
"",
]
used_chars = 0
used_count = 0
for item in results:
content = str(item.get("content") or "").strip()
if not content:
continue
if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
break
metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
doc_id = metadata.get("document_id")
chunk_index = metadata.get("chunk_index")
distance = item.get("distance")
source_parts = []
if doc_id:
source_parts.append(f"doc={doc_id}")
if chunk_index is not None:
source_parts.append(f"chunk={chunk_index}")
source = f" ({', '.join(source_parts)})" if source_parts else ""
distance_text = ""
try:
if distance is not None:
distance_text = f", distance={float(distance):.4f}"
except (TypeError, ValueError):
distance_text = ""
remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
snippet = content[:remaining].strip()
if not snippet:
continue
used_count += 1
lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
used_chars += len(snippet)
if used_count == 0:
return None
return "\n".join(lines)
async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]:
    """Assemble the LLM message list for this turn, injecting RAG context when configured.

    Falls back to the plain conversation history whenever no knowledge
    base is configured, retrieval is disabled, or retrieval yields no
    usable context.
    """
    history = self.conversation.get_messages()

    kb_id = self._resolve_runtime_kb_id()
    if not kb_id:
        return history

    cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {}
    flag = cfg.get("enabled", True)
    if isinstance(flag, str):
        # String flags from metadata: anything but an explicit "off" value counts as enabled.
        flag = flag.strip().lower() not in {"false", "0", "off", "no"}
    if not flag:
        return history

    limit = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
    limit = max(1, min(limit, self._RAG_MAX_RESULTS))

    chunks = await search_knowledge_context(
        kb_id=kb_id,
        query=user_text,
        n_results=limit,
    )
    context_prompt = self._build_knowledge_prompt(chunks)
    if not context_prompt:
        return history

    logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(chunks)})")
    injected = LLMMessage(role="system", content=context_prompt)
    # Keep the assistant's own system prompt first; slot the RAG context right after it.
    if history and history[0].role == "system":
        return [history[0], injected, *history[1:]]
    return [injected, *history]
async def _handle_turn(self, user_text: str) -> None: async def _handle_turn(self, user_text: str) -> None:
""" """
Handle a complete conversation turn. Handle a complete conversation turn.
@@ -567,7 +682,7 @@ class DuplexPipeline:
self._first_audio_sent = False self._first_audio_sent = False
# Get AI response (streaming) # Get AI response (streaming)
messages = self.conversation.get_messages() messages = await self._build_turn_messages(user_text)
full_response = "" full_response = ""
await self.conversation.start_assistant_turn() await self.conversation.start_assistant_turn()

View File

@@ -988,10 +988,21 @@ export const DebugDrawer: React.FC<{
isOpen: boolean; isOpen: boolean;
onClose: () => void; onClose: () => void;
assistant: Assistant; assistant: Assistant;
voices: Voice[]; voices?: Voice[];
llmModels: LLMModel[]; llmModels?: LLMModel[];
asrModels: ASRModel[]; asrModels?: ASRModel[];
}> = ({ isOpen, onClose, assistant, voices, llmModels, asrModels }) => { sessionMetadataExtras?: Record<string, any>;
onProtocolEvent?: (event: Record<string, any>) => void;
}> = ({
isOpen,
onClose,
assistant,
voices = [],
llmModels = [],
asrModels = [],
sessionMetadataExtras,
onProtocolEvent,
}) => {
const TARGET_SAMPLE_RATE = 16000; const TARGET_SAMPLE_RATE = 16000;
const downsampleTo16k = (input: Float32Array, inputSampleRate: number): Float32Array => { const downsampleTo16k = (input: Float32Array, inputSampleRate: number): Float32Array => {
if (inputSampleRate === TARGET_SAMPLE_RATE) return input; if (inputSampleRate === TARGET_SAMPLE_RATE) return input;
@@ -1474,6 +1485,10 @@ export const DebugDrawer: React.FC<{
const warnings: string[] = []; const warnings: string[] = [];
const services: Record<string, any> = {}; const services: Record<string, any> = {};
const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt'; const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt';
const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim();
const knowledge = knowledgeBaseId
? { enabled: true, kbId: knowledgeBaseId, nResults: 5 }
: { enabled: false };
if (isExternalLlm) { if (isExternalLlm) {
services.llm = { services.llm = {
@@ -1541,6 +1556,8 @@ export const DebugDrawer: React.FC<{
sessionStartMetadata: { sessionStartMetadata: {
systemPrompt: assistant.prompt || '', systemPrompt: assistant.prompt || '',
greeting: assistant.opener || '', greeting: assistant.opener || '',
knowledgeBaseId,
knowledge,
services, services,
history: { history: {
assistantId: assistant.id, assistantId: assistant.id,
@@ -1556,7 +1573,10 @@ export const DebugDrawer: React.FC<{
const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => { const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => {
const localResolved = buildLocalResolvedRuntime(); const localResolved = buildLocalResolvedRuntime();
setResolvedConfigView(JSON.stringify(localResolved, null, 2)); setResolvedConfigView(JSON.stringify(localResolved, null, 2));
return localResolved.sessionStartMetadata; return {
...localResolved.sessionStartMetadata,
...(sessionMetadataExtras || {}),
};
}; };
const closeWs = () => { const closeWs = () => {
@@ -1622,6 +1642,9 @@ export const DebugDrawer: React.FC<{
} }
const type = payload?.type; const type = payload?.type;
if (onProtocolEvent) {
onProtocolEvent(payload);
}
if (type === 'hello.ack') { if (type === 'hello.ack') {
ws.send( ws.send(
JSON.stringify({ JSON.stringify({