Implement KB features with codex
This commit is contained in:
@@ -68,6 +68,14 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
}
|
||||
warnings.append(f"Voice resource not found: {assistant.voice}")
|
||||
|
||||
if assistant.knowledge_base_id:
|
||||
metadata["knowledgeBaseId"] = assistant.knowledge_base_id
|
||||
metadata["knowledge"] = {
|
||||
"enabled": True,
|
||||
"kbId": assistant.knowledge_base_id,
|
||||
"nResults": 5,
|
||||
}
|
||||
|
||||
return {
|
||||
"assistantId": assistant.id,
|
||||
"sessionStartMetadata": metadata,
|
||||
@@ -75,6 +83,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
||||
"llmModelId": assistant.llm_model_id,
|
||||
"asrModelId": assistant.asr_model_id,
|
||||
"voiceId": assistant.voice,
|
||||
"knowledgeBaseId": assistant.knowledge_base_id,
|
||||
},
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
@@ -119,6 +119,14 @@ class TestAssistantAPI:
|
||||
assert response.status_code == 200
|
||||
assert response.json()["knowledgeBaseId"] == "non-existent-kb"
|
||||
|
||||
assistant_id = response.json()["id"]
|
||||
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
|
||||
assert runtime_resp.status_code == 200
|
||||
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||
assert metadata["knowledgeBaseId"] == "non-existent-kb"
|
||||
assert metadata["knowledge"]["enabled"] is True
|
||||
assert metadata["knowledge"]["kbId"] == "non-existent-kb"
|
||||
|
||||
def test_assistant_with_model_references(self, client, sample_assistant_data):
|
||||
"""Test creating assistant with model references"""
|
||||
sample_assistant_data.update({
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -146,3 +146,46 @@ async def finalize_history_call_record(
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
|
||||
return False
|
||||
|
||||
|
||||
async def search_knowledge_context(
    *,
    kb_id: str,
    query: str,
    n_results: int = 5,
) -> List[Dict[str, Any]]:
    """Query the backend knowledge-search endpoint and return its hits.

    POSTs ``kb_id``, ``query`` and ``nResults`` as JSON to
    ``<backend base url>/api/knowledge/search`` and returns the dict entries
    found under the ``results`` key of the JSON response.

    Args:
        kb_id: Identifier of the knowledge base to search.
        query: Text to search for.
        n_results: Requested number of hits. Values below 1 are raised to 1;
            a value that ``int()`` rejects with ``TypeError``/``ValueError``
            is replaced by 5.

    Returns:
        The ``results`` entries that are dicts. An empty list is returned when
        no backend base URL is available, ``kb_id`` is empty, ``query`` is
        blank, the backend answers 404, the response body does not have the
        expected shape, or the request raises (failures are logged as
        warnings instead of being propagated).
    """
    backend = _backend_base_url()
    # Short-circuits in the same order as separate guards would: base URL,
    # then kb id, then a non-blank query.
    if not (backend and kb_id and query.strip()):
        return []

    try:
        limit = max(1, int(n_results))
    except (TypeError, ValueError):
        limit = 5

    endpoint = f"{backend}/api/knowledge/search"
    body: Dict[str, Any] = {"kb_id": kb_id, "query": query, "nResults": limit}

    try:
        async with aiohttp.ClientSession(timeout=_timeout()) as session:
            async with session.post(endpoint, json=body) as resp:
                if resp.status == 404:
                    logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
                    return []
                resp.raise_for_status()
                parsed = await resp.json()
                # Anything other than {"results": [...]} is treated as "no hits".
                hits = parsed.get("results", []) if isinstance(parsed, dict) else None
                if not isinstance(hits, list):
                    return []
                return [hit for hit in hits if isinstance(hit, dict)]
    except Exception as exc:
        logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
        return []
|
||||
|
||||
@@ -13,11 +13,12 @@ event-driven design.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import search_knowledge_context
|
||||
from app.config import settings
|
||||
from core.conversation import ConversationManager, ConversationState
|
||||
from core.events import get_event_bus
|
||||
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
|
||||
from processors.eou import EouDetector
|
||||
from processors.vad import SileroVAD, VADProcessor
|
||||
from services.asr import BufferedASRService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage
|
||||
from services.llm import MockLLMService, OpenAILLMService
|
||||
from services.siliconflow_asr import SiliconFlowASRService
|
||||
from services.siliconflow_tts import SiliconFlowTTSService
|
||||
@@ -55,6 +56,9 @@ class DuplexPipeline:
|
||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||
_RAG_DEFAULT_RESULTS = 5
|
||||
_RAG_MAX_RESULTS = 8
|
||||
_RAG_MAX_CONTEXT_CHARS = 4000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -156,6 +160,8 @@ class DuplexPipeline:
|
||||
self._runtime_tts: Dict[str, Any] = {}
|
||||
self._runtime_system_prompt: Optional[str] = None
|
||||
self._runtime_greeting: Optional[str] = None
|
||||
self._runtime_knowledge: Dict[str, Any] = {}
|
||||
self._runtime_knowledge_base_id: Optional[str] = None
|
||||
|
||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||
|
||||
@@ -194,6 +200,18 @@ class DuplexPipeline:
|
||||
if isinstance(services.get("tts"), dict):
|
||||
self._runtime_tts = services["tts"]
|
||||
|
||||
knowledge_base_id = metadata.get("knowledgeBaseId")
|
||||
if knowledge_base_id is not None:
|
||||
kb_id = str(knowledge_base_id).strip()
|
||||
self._runtime_knowledge_base_id = kb_id or None
|
||||
|
||||
knowledge = metadata.get("knowledge")
|
||||
if isinstance(knowledge, dict):
|
||||
self._runtime_knowledge = knowledge
|
||||
kb_id = str(knowledge.get("kbId") or knowledge.get("knowledgeBaseId") or "").strip()
|
||||
if kb_id:
|
||||
self._runtime_knowledge_base_id = kb_id
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the pipeline and connect services."""
|
||||
try:
|
||||
@@ -552,6 +570,103 @@ class DuplexPipeline:
|
||||
await self.conversation.end_user_turn(user_text)
|
||||
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||||
|
||||
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
    """Convert ``value`` to ``int``, returning ``default`` when that fails.

    Args:
        value: Any object; it is passed to ``int()`` unchanged.
        default: Value returned when ``int(value)`` raises.

    Returns:
        ``int(value)``, or ``default`` if the conversion raises
        ``TypeError``, ``ValueError`` or ``OverflowError``.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers infinite floats: int(float("inf")) raises it
        # rather than ValueError, and would otherwise escape this helper.
        return default
|
||||
|
||||
def _resolve_runtime_kb_id(self) -> Optional[str]:
    """Return the knowledge base id configured for this session, if any.

    A truthy ``_runtime_knowledge_base_id`` is returned as-is. Otherwise the
    ``kbId`` and then ``knowledgeBaseId`` entries of ``_runtime_knowledge``
    are tried; the first truthy one is stringified and stripped.

    Returns:
        The resolved id, or ``None`` when no non-blank id is found.
    """
    explicit_id = self._runtime_knowledge_base_id
    if explicit_id:
        return explicit_id

    cfg = self._runtime_knowledge
    raw_id = cfg.get("kbId") or cfg.get("knowledgeBaseId") or ""
    cleaned = str(raw_id).strip()
    return cleaned if cleaned else None
|
||||
|
||||
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
    """Render retrieval hits as a prompt block for the LLM.

    Each hit with non-blank ``content`` becomes one numbered line of the form
    ``[n (doc=..., chunk=...), distance=0.1234] snippet``; the source and
    distance parts are omitted when the hit does not provide them. Snippet
    text is capped so that the total snippet length does not exceed
    ``_RAG_MAX_CONTEXT_CHARS`` (labels are not counted against the cap).

    Args:
        results: Retrieval hits; each is read via ``content``, ``metadata``
            (``document_id`` / ``chunk_index``) and ``distance``.

    Returns:
        The instruction header followed by the snippet lines, joined with
        newlines, or ``None`` when no hit yields a usable snippet.
    """
    if not results:
        return None

    budget = self._RAG_MAX_CONTEXT_CHARS
    header = [
        "You have retrieved the following knowledge base snippets.",
        "Use them only when relevant to the latest user request.",
        "If snippets are insufficient, say you are not sure instead of guessing.",
        "",
    ]

    entries: List[str] = []
    consumed = 0
    for hit in results:
        text = str(hit.get("content") or "").strip()
        if not text:
            continue
        if consumed >= budget:
            break

        meta = hit.get("metadata")
        if not isinstance(meta, dict):
            meta = {}

        origin = []
        document = meta.get("document_id")
        if document:
            origin.append(f"doc={document}")
        chunk = meta.get("chunk_index")
        if chunk is not None:
            origin.append(f"chunk={chunk}")
        origin_label = f" ({', '.join(origin)})" if origin else ""

        # A distance that cannot be read as a float is simply left out.
        raw_distance = hit.get("distance")
        score_label = ""
        if raw_distance is not None:
            try:
                score_label = f", distance={float(raw_distance):.4f}"
            except (TypeError, ValueError):
                score_label = ""

        # Truncate to whatever is left of the character budget.
        excerpt = text[: budget - consumed].strip()
        if not excerpt:
            continue

        entries.append(f"[{len(entries) + 1}{origin_label}{score_label}] {excerpt}")
        consumed += len(excerpt)

    if not entries:
        return None
    return "\n".join(header + entries)
|
||||
|
||||
async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]:
    """Build the LLM message list for a turn, adding retrieved knowledge.

    Starts from ``self.conversation.get_messages()``. When a knowledge base
    id resolves and retrieval is not disabled, the knowledge base is searched
    with ``user_text`` and the rendered snippets are inserted as an extra
    system message.

    Args:
        user_text: The user's utterance, used as the search query.

    Returns:
        The conversation messages unchanged when there is no knowledge base
        id, the ``enabled`` flag is falsy, or no prompt could be built from
        the search results. Otherwise a new list with the knowledge message
        placed directly after a leading system message, or at the front when
        the first message is not a system message.
    """
    history = self.conversation.get_messages()

    kb_id = self._resolve_runtime_kb_id()
    if not kb_id:
        return history

    cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {}
    flag = cfg.get("enabled", True)
    if isinstance(flag, str):
        # A string flag disables retrieval only for these spellings.
        flag = flag.strip().lower() not in {"false", "0", "off", "no"}
    if not flag:
        return history

    # Clamp the requested hit count into [1, _RAG_MAX_RESULTS].
    requested = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
    limit = max(1, min(requested, self._RAG_MAX_RESULTS))

    hits = await search_knowledge_context(
        kb_id=kb_id,
        query=user_text,
        n_results=limit,
    )
    knowledge_prompt = self._build_knowledge_prompt(hits)
    if not knowledge_prompt:
        return history

    logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(hits)})")
    rag_message = LLMMessage(role="system", content=knowledge_prompt)

    starts_with_system = bool(history) and history[0].role == "system"
    if not starts_with_system:
        return [rag_message, *history]
    # Keep the existing system message first; the knowledge message follows it.
    return [history[0], rag_message, *history[1:]]
|
||||
|
||||
async def _handle_turn(self, user_text: str) -> None:
|
||||
"""
|
||||
Handle a complete conversation turn.
|
||||
@@ -567,7 +682,7 @@ class DuplexPipeline:
|
||||
self._first_audio_sent = False
|
||||
|
||||
# Get AI response (streaming)
|
||||
messages = self.conversation.get_messages()
|
||||
messages = await self._build_turn_messages(user_text)
|
||||
full_response = ""
|
||||
|
||||
await self.conversation.start_assistant_turn()
|
||||
|
||||
@@ -988,10 +988,21 @@ export const DebugDrawer: React.FC<{
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
assistant: Assistant;
|
||||
voices: Voice[];
|
||||
llmModels: LLMModel[];
|
||||
asrModels: ASRModel[];
|
||||
}> = ({ isOpen, onClose, assistant, voices, llmModels, asrModels }) => {
|
||||
voices?: Voice[];
|
||||
llmModels?: LLMModel[];
|
||||
asrModels?: ASRModel[];
|
||||
sessionMetadataExtras?: Record<string, any>;
|
||||
onProtocolEvent?: (event: Record<string, any>) => void;
|
||||
}> = ({
|
||||
isOpen,
|
||||
onClose,
|
||||
assistant,
|
||||
voices = [],
|
||||
llmModels = [],
|
||||
asrModels = [],
|
||||
sessionMetadataExtras,
|
||||
onProtocolEvent,
|
||||
}) => {
|
||||
const TARGET_SAMPLE_RATE = 16000;
|
||||
const downsampleTo16k = (input: Float32Array, inputSampleRate: number): Float32Array => {
|
||||
if (inputSampleRate === TARGET_SAMPLE_RATE) return input;
|
||||
@@ -1474,6 +1485,10 @@ export const DebugDrawer: React.FC<{
|
||||
const warnings: string[] = [];
|
||||
const services: Record<string, any> = {};
|
||||
const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt';
|
||||
const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim();
|
||||
const knowledge = knowledgeBaseId
|
||||
? { enabled: true, kbId: knowledgeBaseId, nResults: 5 }
|
||||
: { enabled: false };
|
||||
|
||||
if (isExternalLlm) {
|
||||
services.llm = {
|
||||
@@ -1541,6 +1556,8 @@ export const DebugDrawer: React.FC<{
|
||||
sessionStartMetadata: {
|
||||
systemPrompt: assistant.prompt || '',
|
||||
greeting: assistant.opener || '',
|
||||
knowledgeBaseId,
|
||||
knowledge,
|
||||
services,
|
||||
history: {
|
||||
assistantId: assistant.id,
|
||||
@@ -1556,7 +1573,10 @@ export const DebugDrawer: React.FC<{
|
||||
// Resolves the runtime config locally, shows it in the resolved-config view,
// and returns the session-start metadata with any caller-supplied extras merged in.
const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => {
  const localResolved = buildLocalResolvedRuntime();
  setResolvedConfigView(JSON.stringify(localResolved, null, 2));
  // Single return: an earlier bare `return localResolved.sessionStartMetadata`
  // would make this merge unreachable. Extras are spread last so their keys
  // override the locally resolved ones.
  return {
    ...localResolved.sessionStartMetadata,
    ...(sessionMetadataExtras || {}),
  };
};
|
||||
|
||||
const closeWs = () => {
|
||||
@@ -1622,6 +1642,9 @@ export const DebugDrawer: React.FC<{
|
||||
}
|
||||
|
||||
const type = payload?.type;
|
||||
if (onProtocolEvent) {
|
||||
onProtocolEvent(payload);
|
||||
}
|
||||
if (type === 'hello.ack') {
|
||||
ws.send(
|
||||
JSON.stringify({
|
||||
|
||||
Reference in New Issue
Block a user