Implement KB features with codex
This commit is contained in:
@@ -68,6 +68,14 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
|||||||
}
|
}
|
||||||
warnings.append(f"Voice resource not found: {assistant.voice}")
|
warnings.append(f"Voice resource not found: {assistant.voice}")
|
||||||
|
|
||||||
|
if assistant.knowledge_base_id:
|
||||||
|
metadata["knowledgeBaseId"] = assistant.knowledge_base_id
|
||||||
|
metadata["knowledge"] = {
|
||||||
|
"enabled": True,
|
||||||
|
"kbId": assistant.knowledge_base_id,
|
||||||
|
"nResults": 5,
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"assistantId": assistant.id,
|
"assistantId": assistant.id,
|
||||||
"sessionStartMetadata": metadata,
|
"sessionStartMetadata": metadata,
|
||||||
@@ -75,6 +83,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> dict:
|
|||||||
"llmModelId": assistant.llm_model_id,
|
"llmModelId": assistant.llm_model_id,
|
||||||
"asrModelId": assistant.asr_model_id,
|
"asrModelId": assistant.asr_model_id,
|
||||||
"voiceId": assistant.voice,
|
"voiceId": assistant.voice,
|
||||||
|
"knowledgeBaseId": assistant.knowledge_base_id,
|
||||||
},
|
},
|
||||||
"warnings": warnings,
|
"warnings": warnings,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,6 +119,14 @@ class TestAssistantAPI:
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["knowledgeBaseId"] == "non-existent-kb"
|
assert response.json()["knowledgeBaseId"] == "non-existent-kb"
|
||||||
|
|
||||||
|
assistant_id = response.json()["id"]
|
||||||
|
runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config")
|
||||||
|
assert runtime_resp.status_code == 200
|
||||||
|
metadata = runtime_resp.json()["sessionStartMetadata"]
|
||||||
|
assert metadata["knowledgeBaseId"] == "non-existent-kb"
|
||||||
|
assert metadata["knowledge"]["enabled"] is True
|
||||||
|
assert metadata["knowledge"]["kbId"] == "non-existent-kb"
|
||||||
|
|
||||||
def test_assistant_with_model_references(self, client, sample_assistant_data):
|
def test_assistant_with_model_references(self, client, sample_assistant_data):
|
||||||
"""Test creating assistant with model references"""
|
"""Test creating assistant with model references"""
|
||||||
sample_assistant_data.update({
|
sample_assistant_data.update({
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -146,3 +146,46 @@ async def finalize_history_call_record(
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
|
logger.warning(f"Failed to finalize history call record ({call_id}): {exc}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def search_knowledge_context(
|
||||||
|
*,
|
||||||
|
kb_id: str,
|
||||||
|
query: str,
|
||||||
|
n_results: int = 5,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Search backend knowledge base and return retrieval results."""
|
||||||
|
base_url = _backend_base_url()
|
||||||
|
if not base_url:
|
||||||
|
return []
|
||||||
|
if not kb_id or not query.strip():
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
safe_n_results = max(1, int(n_results))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
safe_n_results = 5
|
||||||
|
|
||||||
|
url = f"{base_url}/api/knowledge/search"
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"kb_id": kb_id,
|
||||||
|
"query": query,
|
||||||
|
"nResults": safe_n_results,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=_timeout()) as session:
|
||||||
|
async with session.post(url, json=payload) as resp:
|
||||||
|
if resp.status == 404:
|
||||||
|
logger.warning(f"Knowledge base not found for retrieval: {kb_id}")
|
||||||
|
return []
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = await resp.json()
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return []
|
||||||
|
results = data.get("results", [])
|
||||||
|
if not isinstance(results, list):
|
||||||
|
return []
|
||||||
|
return [r for r in results if isinstance(r, dict)]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}")
|
||||||
|
return []
|
||||||
|
|||||||
@@ -13,11 +13,12 @@ event-driven design.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from app.backend_client import search_knowledge_context
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from core.conversation import ConversationManager, ConversationState
|
from core.conversation import ConversationManager, ConversationState
|
||||||
from core.events import get_event_bus
|
from core.events import get_event_bus
|
||||||
@@ -26,7 +27,7 @@ from models.ws_v1 import ev
|
|||||||
from processors.eou import EouDetector
|
from processors.eou import EouDetector
|
||||||
from processors.vad import SileroVAD, VADProcessor
|
from processors.vad import SileroVAD, VADProcessor
|
||||||
from services.asr import BufferedASRService
|
from services.asr import BufferedASRService
|
||||||
from services.base import BaseASRService, BaseLLMService, BaseTTSService
|
from services.base import BaseASRService, BaseLLMService, BaseTTSService, LLMMessage
|
||||||
from services.llm import MockLLMService, OpenAILLMService
|
from services.llm import MockLLMService, OpenAILLMService
|
||||||
from services.siliconflow_asr import SiliconFlowASRService
|
from services.siliconflow_asr import SiliconFlowASRService
|
||||||
from services.siliconflow_tts import SiliconFlowTTSService
|
from services.siliconflow_tts import SiliconFlowTTSService
|
||||||
@@ -55,6 +56,9 @@ class DuplexPipeline:
|
|||||||
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
_SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
|
||||||
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
_SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
|
||||||
_MIN_SPLIT_SPOKEN_CHARS = 6
|
_MIN_SPLIT_SPOKEN_CHARS = 6
|
||||||
|
_RAG_DEFAULT_RESULTS = 5
|
||||||
|
_RAG_MAX_RESULTS = 8
|
||||||
|
_RAG_MAX_CONTEXT_CHARS = 4000
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -156,6 +160,8 @@ class DuplexPipeline:
|
|||||||
self._runtime_tts: Dict[str, Any] = {}
|
self._runtime_tts: Dict[str, Any] = {}
|
||||||
self._runtime_system_prompt: Optional[str] = None
|
self._runtime_system_prompt: Optional[str] = None
|
||||||
self._runtime_greeting: Optional[str] = None
|
self._runtime_greeting: Optional[str] = None
|
||||||
|
self._runtime_knowledge: Dict[str, Any] = {}
|
||||||
|
self._runtime_knowledge_base_id: Optional[str] = None
|
||||||
|
|
||||||
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||||||
|
|
||||||
@@ -194,6 +200,18 @@ class DuplexPipeline:
|
|||||||
if isinstance(services.get("tts"), dict):
|
if isinstance(services.get("tts"), dict):
|
||||||
self._runtime_tts = services["tts"]
|
self._runtime_tts = services["tts"]
|
||||||
|
|
||||||
|
knowledge_base_id = metadata.get("knowledgeBaseId")
|
||||||
|
if knowledge_base_id is not None:
|
||||||
|
kb_id = str(knowledge_base_id).strip()
|
||||||
|
self._runtime_knowledge_base_id = kb_id or None
|
||||||
|
|
||||||
|
knowledge = metadata.get("knowledge")
|
||||||
|
if isinstance(knowledge, dict):
|
||||||
|
self._runtime_knowledge = knowledge
|
||||||
|
kb_id = str(knowledge.get("kbId") or knowledge.get("knowledgeBaseId") or "").strip()
|
||||||
|
if kb_id:
|
||||||
|
self._runtime_knowledge_base_id = kb_id
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the pipeline and connect services."""
|
"""Start the pipeline and connect services."""
|
||||||
try:
|
try:
|
||||||
@@ -552,6 +570,103 @@ class DuplexPipeline:
|
|||||||
await self.conversation.end_user_turn(user_text)
|
await self.conversation.end_user_turn(user_text)
|
||||||
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_int(value: Any, default: int) -> int:
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
def _resolve_runtime_kb_id(self) -> Optional[str]:
|
||||||
|
if self._runtime_knowledge_base_id:
|
||||||
|
return self._runtime_knowledge_base_id
|
||||||
|
kb_id = str(self._runtime_knowledge.get("kbId") or self._runtime_knowledge.get("knowledgeBaseId") or "").strip()
|
||||||
|
return kb_id or None
|
||||||
|
|
||||||
|
def _build_knowledge_prompt(self, results: List[Dict[str, Any]]) -> Optional[str]:
|
||||||
|
if not results:
|
||||||
|
return None
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"You have retrieved the following knowledge base snippets.",
|
||||||
|
"Use them only when relevant to the latest user request.",
|
||||||
|
"If snippets are insufficient, say you are not sure instead of guessing.",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
used_chars = 0
|
||||||
|
used_count = 0
|
||||||
|
for item in results:
|
||||||
|
content = str(item.get("content") or "").strip()
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
if used_chars >= self._RAG_MAX_CONTEXT_CHARS:
|
||||||
|
break
|
||||||
|
|
||||||
|
metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
|
||||||
|
doc_id = metadata.get("document_id")
|
||||||
|
chunk_index = metadata.get("chunk_index")
|
||||||
|
distance = item.get("distance")
|
||||||
|
|
||||||
|
source_parts = []
|
||||||
|
if doc_id:
|
||||||
|
source_parts.append(f"doc={doc_id}")
|
||||||
|
if chunk_index is not None:
|
||||||
|
source_parts.append(f"chunk={chunk_index}")
|
||||||
|
source = f" ({', '.join(source_parts)})" if source_parts else ""
|
||||||
|
|
||||||
|
distance_text = ""
|
||||||
|
try:
|
||||||
|
if distance is not None:
|
||||||
|
distance_text = f", distance={float(distance):.4f}"
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
distance_text = ""
|
||||||
|
|
||||||
|
remaining = self._RAG_MAX_CONTEXT_CHARS - used_chars
|
||||||
|
snippet = content[:remaining].strip()
|
||||||
|
if not snippet:
|
||||||
|
continue
|
||||||
|
|
||||||
|
used_count += 1
|
||||||
|
lines.append(f"[{used_count}{source}{distance_text}] {snippet}")
|
||||||
|
used_chars += len(snippet)
|
||||||
|
|
||||||
|
if used_count == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
async def _build_turn_messages(self, user_text: str) -> List[LLMMessage]:
|
||||||
|
messages = self.conversation.get_messages()
|
||||||
|
kb_id = self._resolve_runtime_kb_id()
|
||||||
|
if not kb_id:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
knowledge_cfg = self._runtime_knowledge if isinstance(self._runtime_knowledge, dict) else {}
|
||||||
|
enabled = knowledge_cfg.get("enabled", True)
|
||||||
|
if isinstance(enabled, str):
|
||||||
|
enabled = enabled.strip().lower() not in {"false", "0", "off", "no"}
|
||||||
|
if not enabled:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
n_results = self._coerce_int(knowledge_cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
|
||||||
|
n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
|
||||||
|
|
||||||
|
results = await search_knowledge_context(
|
||||||
|
kb_id=kb_id,
|
||||||
|
query=user_text,
|
||||||
|
n_results=n_results,
|
||||||
|
)
|
||||||
|
prompt = self._build_knowledge_prompt(results)
|
||||||
|
if not prompt:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
logger.debug(f"RAG context injected (kb_id={kb_id}, chunks={len(results)})")
|
||||||
|
rag_system = LLMMessage(role="system", content=prompt)
|
||||||
|
if messages and messages[0].role == "system":
|
||||||
|
return [messages[0], rag_system, *messages[1:]]
|
||||||
|
return [rag_system, *messages]
|
||||||
|
|
||||||
async def _handle_turn(self, user_text: str) -> None:
|
async def _handle_turn(self, user_text: str) -> None:
|
||||||
"""
|
"""
|
||||||
Handle a complete conversation turn.
|
Handle a complete conversation turn.
|
||||||
@@ -567,7 +682,7 @@ class DuplexPipeline:
|
|||||||
self._first_audio_sent = False
|
self._first_audio_sent = False
|
||||||
|
|
||||||
# Get AI response (streaming)
|
# Get AI response (streaming)
|
||||||
messages = self.conversation.get_messages()
|
messages = await self._build_turn_messages(user_text)
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
await self.conversation.start_assistant_turn()
|
await self.conversation.start_assistant_turn()
|
||||||
|
|||||||
@@ -988,10 +988,21 @@ export const DebugDrawer: React.FC<{
|
|||||||
isOpen: boolean;
|
isOpen: boolean;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
assistant: Assistant;
|
assistant: Assistant;
|
||||||
voices: Voice[];
|
voices?: Voice[];
|
||||||
llmModels: LLMModel[];
|
llmModels?: LLMModel[];
|
||||||
asrModels: ASRModel[];
|
asrModels?: ASRModel[];
|
||||||
}> = ({ isOpen, onClose, assistant, voices, llmModels, asrModels }) => {
|
sessionMetadataExtras?: Record<string, any>;
|
||||||
|
onProtocolEvent?: (event: Record<string, any>) => void;
|
||||||
|
}> = ({
|
||||||
|
isOpen,
|
||||||
|
onClose,
|
||||||
|
assistant,
|
||||||
|
voices = [],
|
||||||
|
llmModels = [],
|
||||||
|
asrModels = [],
|
||||||
|
sessionMetadataExtras,
|
||||||
|
onProtocolEvent,
|
||||||
|
}) => {
|
||||||
const TARGET_SAMPLE_RATE = 16000;
|
const TARGET_SAMPLE_RATE = 16000;
|
||||||
const downsampleTo16k = (input: Float32Array, inputSampleRate: number): Float32Array => {
|
const downsampleTo16k = (input: Float32Array, inputSampleRate: number): Float32Array => {
|
||||||
if (inputSampleRate === TARGET_SAMPLE_RATE) return input;
|
if (inputSampleRate === TARGET_SAMPLE_RATE) return input;
|
||||||
@@ -1474,6 +1485,10 @@ export const DebugDrawer: React.FC<{
|
|||||||
const warnings: string[] = [];
|
const warnings: string[] = [];
|
||||||
const services: Record<string, any> = {};
|
const services: Record<string, any> = {};
|
||||||
const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt';
|
const isExternalLlm = assistant.configMode === 'dify' || assistant.configMode === 'fastgpt';
|
||||||
|
const knowledgeBaseId = String(assistant.knowledgeBaseId || '').trim();
|
||||||
|
const knowledge = knowledgeBaseId
|
||||||
|
? { enabled: true, kbId: knowledgeBaseId, nResults: 5 }
|
||||||
|
: { enabled: false };
|
||||||
|
|
||||||
if (isExternalLlm) {
|
if (isExternalLlm) {
|
||||||
services.llm = {
|
services.llm = {
|
||||||
@@ -1541,6 +1556,8 @@ export const DebugDrawer: React.FC<{
|
|||||||
sessionStartMetadata: {
|
sessionStartMetadata: {
|
||||||
systemPrompt: assistant.prompt || '',
|
systemPrompt: assistant.prompt || '',
|
||||||
greeting: assistant.opener || '',
|
greeting: assistant.opener || '',
|
||||||
|
knowledgeBaseId,
|
||||||
|
knowledge,
|
||||||
services,
|
services,
|
||||||
history: {
|
history: {
|
||||||
assistantId: assistant.id,
|
assistantId: assistant.id,
|
||||||
@@ -1556,7 +1573,10 @@ export const DebugDrawer: React.FC<{
|
|||||||
const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => {
|
const fetchRuntimeMetadata = async (): Promise<Record<string, any>> => {
|
||||||
const localResolved = buildLocalResolvedRuntime();
|
const localResolved = buildLocalResolvedRuntime();
|
||||||
setResolvedConfigView(JSON.stringify(localResolved, null, 2));
|
setResolvedConfigView(JSON.stringify(localResolved, null, 2));
|
||||||
return localResolved.sessionStartMetadata;
|
return {
|
||||||
|
...localResolved.sessionStartMetadata,
|
||||||
|
...(sessionMetadataExtras || {}),
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
const closeWs = () => {
|
const closeWs = () => {
|
||||||
@@ -1622,6 +1642,9 @@ export const DebugDrawer: React.FC<{
|
|||||||
}
|
}
|
||||||
|
|
||||||
const type = payload?.type;
|
const type = payload?.type;
|
||||||
|
if (onProtocolEvent) {
|
||||||
|
onProtocolEvent(payload);
|
||||||
|
}
|
||||||
if (type === 'hello.ack') {
|
if (type === 'hello.ack') {
|
||||||
ws.send(
|
ws.send(
|
||||||
JSON.stringify({
|
JSON.stringify({
|
||||||
|
|||||||
Reference in New Issue
Block a user