Implement knowledge base management and enhance assistant configuration

Add CRUD functionality for knowledge bases, including routes for listing, creating, updating, and deleting knowledge bases. Update the assistant model to include foreign key references to knowledge bases and modify the assistant configuration to handle external API keys securely. Refactor related services and routes to accommodate these changes, ensuring proper handling of credential resolution and configuration normalization.
This commit is contained in:
Xin Wang
2026-06-09 08:31:39 +08:00
parent 34fba494a3
commit b444ea777c
6 changed files with 304 additions and 56 deletions

View File

@@ -18,7 +18,14 @@ from db.session import init_db
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routes import assistants, credentials, health, voice_webrtc, voice_ws
from routes import (
assistants,
credentials,
health,
knowledge_bases,
voice_webrtc,
voice_ws,
)
@asynccontextmanager
@@ -40,6 +47,7 @@ app.add_middleware(
app.include_router(health.router)
app.include_router(assistants.router)
app.include_router(credentials.router)
app.include_router(knowledge_bases.router)
app.include_router(voice_webrtc.router)
app.include_router(voice_ws.router)

View File

@@ -9,7 +9,7 @@
from datetime import datetime
from sqlalchemy import Boolean, DateTime, String, func
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
@@ -37,19 +37,67 @@ class ProviderCredential(Base):
)
class Assistant(Base):
__tablename__ = "assistants"
class KnowledgeBase(Base):
"""知识库注册表。本身引用一个 Embedding 凭证(用哪个向量模型)。
id: Mapped[str] = mapped_column(String(40), primary_key=True) # asst_xxx
文档/分块(pgvector)是 KB 内部实现,这里先不展开;助手侧只认 knowledge_base_id。
"""
__tablename__ = "knowledge_bases"
id: Mapped[str] = mapped_column(String(40), primary_key=True) # kb_xxx
name: Mapped[str] = mapped_column(String(128))
greeting: Mapped[str] = mapped_column(String(2048), default="")
prompt: Mapped[str] = mapped_column(String(8192), default="")
runtime_mode: Mapped[str] = mapped_column(String(16), default="pipeline")
# 模型/音色的"选项名",不是 key
model: Mapped[str] = mapped_column(String(128), default="")
asr: Mapped[str] = mapped_column(String(128), default="")
voice: Mapped[str] = mapped_column(String(128), default="")
enable_interrupt: Mapped[bool] = mapped_column(Boolean, default=True)
description: Mapped[str] = mapped_column(String(2048), default="")
# 该 KB 用哪个向量模型;凭证被删则置空
embedding_credential_id: Mapped[str | None] = mapped_column(
String(40),
ForeignKey("provider_credentials.id", ondelete="SET NULL"),
nullable=True,
)
status: Mapped[str] = mapped_column(String(16), default="active") # active|archived
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
class Assistant(Base):
"""助手(单表,无版本化)。type 为可变普通列,5 种类型共用此表。
模型/KB 以 FK 引用注册表;类型专属字段塞进 config(JSON)。
"""
__tablename__ = "assistants"
id: Mapped[str] = mapped_column(String(40), primary_key=True) # asst_xxx
name: Mapped[str] = mapped_column(String(128))
# prompt|workflow|dify|fastgpt|opencode;创建后可改
type: Mapped[str] = mapped_column(String(16), index=True, default="prompt")
runtime_mode: Mapped[str] = mapped_column(String(16), default="pipeline")
greeting: Mapped[str] = mapped_column(String(2048), default="")
enable_interrupt: Mapped[bool] = mapped_column(Boolean, default=True)
# ---- 引用"注册好的资源":凭证被删 → SET NULL(resolver 有默认/.env 兜底) ----
llm_credential_id: Mapped[str | None] = mapped_column(
String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True
)
asr_credential_id: Mapped[str | None] = mapped_column(
String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True
)
tts_credential_id: Mapped[str | None] = mapped_column(
String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True
)
realtime_credential_id: Mapped[str | None] = mapped_column(
String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True
)
# KB 引用:被引用时禁止删 KB(RESTRICT),无默认兜底
knowledge_base_id: Mapped[str | None] = mapped_column(
String(40), ForeignKey("knowledge_bases.id", ondelete="RESTRICT"), nullable=True
)
# 类型专属字段(形态各异):prompt / graph / dify|fastgpt|opencode 端点+key(打码)
config: Mapped[dict] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()

View File

@@ -1,6 +1,7 @@
"""助手 CRUD。前端「助手列表 / 创建 / 编辑」对接这里。
助手配置不含 key,所以无需打码。
模型/KB 以 FK 引用注册表;外部类型(dify/fastgpt/opencode)的 config.apiKey 是私有密钥,
读时打码、写时哨兵(复用 services/masking)。
"""
import uuid
@@ -8,24 +9,47 @@ import uuid
from db.models import Assistant
from db.session import get_session
from fastapi import APIRouter, Depends, HTTPException
from schemas import AssistantOut, AssistantUpsert
from schemas import EXTERNAL_TYPES, AssistantOut, AssistantUpsert
from services.masking import mask, resolve_incoming_key
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter(prefix="/api/assistants", tags=["assistants"])
def _mask_config(type_: str, config: dict) -> dict:
    """Return the form of ``config`` that is sent back to the client.

    For external assistant types (members of ``EXTERNAL_TYPES``) that hold a
    non-empty ``apiKey``, a shallow copy with the key passed through ``mask``
    is returned. In every other case the same ``config`` object is returned
    unchanged.
    """
    secret = config.get("apiKey")
    if type_ not in EXTERNAL_TYPES or not secret:
        return config
    return {**config, "apiKey": mask(secret)}
def _merge_config(type_: str, incoming: dict, stored: dict) -> dict:
    """Prepare an incoming ``config`` for persistence.

    For external types whose payload contains an ``apiKey`` entry, the value is
    resolved through ``resolve_incoming_key`` against the stored key (``""``
    when nothing is stored), so a masked placeholder or empty key sent back by
    the client keeps the previously stored key. Any other payload is returned
    as-is.
    """
    if type_ not in EXTERNAL_TYPES or "apiKey" not in incoming:
        return incoming
    previous_key = stored.get("apiKey", "")
    effective_key = resolve_incoming_key(incoming.get("apiKey"), previous_key)
    return {**incoming, "apiKey": effective_key}
def _to_out(a: Assistant) -> AssistantOut:
return AssistantOut(
id=a.id,
name=a.name,
greeting=a.greeting,
prompt=a.prompt,
type=a.type, # type: ignore[arg-type]
runtime_mode=a.runtime_mode, # type: ignore[arg-type]
model=a.model,
asr=a.asr,
voice=a.voice,
greeting=a.greeting,
enable_interrupt=a.enable_interrupt,
llm_credential_id=a.llm_credential_id,
asr_credential_id=a.asr_credential_id,
tts_credential_id=a.tts_credential_id,
realtime_credential_id=a.realtime_credential_id,
knowledge_base_id=a.knowledge_base_id,
config=_mask_config(a.type, a.config or {}),
updated_at=a.updated_at.isoformat() if a.updated_at else None,
)
@@ -68,7 +92,10 @@ async def update_assistant(
a = await session.get(Assistant, assistant_id)
if not a:
raise HTTPException(404, "助手不存在")
for k, v in body.model_dump().items():
data = body.model_dump()
# 外部类型 apiKey 写时哨兵:打码占位符 → 保留旧 key(在改 a.config 前用旧值)
data["config"] = _merge_config(body.type, data["config"], a.config or {})
for k, v in data.items():
setattr(a, k, v)
await session.commit()
await session.refresh(a)

View File

@@ -0,0 +1,90 @@
"""知识库 CRUD。前端助手编辑页的"知识库"下拉对接这里。
KB 自身引用一个 Embedding 凭证(embeddingCredentialId)。被助手引用时禁止删除
(DB 层 ON DELETE RESTRICT),这里把外键冲突翻译成 409。
"""
import uuid
from db.models import KnowledgeBase
from db.session import get_session
from fastapi import APIRouter, Depends, HTTPException
from schemas import KnowledgeBaseOut, KnowledgeBaseUpsert
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter(prefix="/api/knowledge-bases", tags=["knowledge-bases"])
def _to_out(kb: KnowledgeBase) -> KnowledgeBaseOut:
    """Convert a ``KnowledgeBase`` ORM row into its API response schema.

    ``updated_at`` is serialized with ``isoformat()``; a missing timestamp is
    reported as ``None``.
    """
    stamp = kb.updated_at
    return KnowledgeBaseOut(
        id=kb.id,
        name=kb.name,
        description=kb.description,
        embedding_credential_id=kb.embedding_credential_id,
        status=kb.status,
        updated_at=stamp.isoformat() if stamp else None,
    )
@router.get("", response_model=list[KnowledgeBaseOut])
async def list_knowledge_bases(session: AsyncSession = Depends(get_session)):
    """Return every knowledge base, ordered by name."""
    query = select(KnowledgeBase).order_by(KnowledgeBase.name)
    result = await session.execute(query)
    return [_to_out(row) for row in result.scalars().all()]
@router.post("", response_model=KnowledgeBaseOut)
async def create_knowledge_base(
    body: KnowledgeBaseUpsert, session: AsyncSession = Depends(get_session)
):
    """Create a knowledge base under a generated ``kb_``-prefixed id.

    Args:
        body: Validated payload; its dumped fields become the row's columns.
        session: Request-scoped async database session.

    Returns:
        The newly created knowledge base as ``KnowledgeBaseOut``.

    Raises:
        HTTPException: 400 when the insert violates a database constraint,
            e.g. ``embeddingCredentialId`` names a credential that does not
            exist.
    """
    kb = KnowledgeBase(id=f"kb_{uuid.uuid4().hex[:12]}", **body.model_dump())
    session.add(kb)
    try:
        await session.commit()
    except IntegrityError as exc:
        # Roll back so the session is usable again, and report a client error
        # instead of letting the constraint violation surface as a bare 500.
        await session.rollback()
        raise HTTPException(
            400, "知识库保存失败:违反数据约束(请检查 Embedding 凭证是否存在)"
        ) from exc
    await session.refresh(kb)
    return _to_out(kb)
@router.get("/{kb_id}", response_model=KnowledgeBaseOut)
async def get_knowledge_base(
    kb_id: str, session: AsyncSession = Depends(get_session)
):
    """Fetch a single knowledge base by id; respond with 404 when it is missing."""
    found = await session.get(KnowledgeBase, kb_id)
    if found:
        return _to_out(found)
    raise HTTPException(404, "知识库不存在")
@router.put("/{kb_id}", response_model=KnowledgeBaseOut)
async def update_knowledge_base(
    kb_id: str,
    body: KnowledgeBaseUpsert,
    session: AsyncSession = Depends(get_session),
):
    """Overwrite the editable fields of an existing knowledge base.

    Args:
        kb_id: Id of the knowledge base to update.
        body: Validated payload; every dumped field is assigned onto the row.
        session: Request-scoped async database session.

    Returns:
        The updated knowledge base as ``KnowledgeBaseOut``.

    Raises:
        HTTPException: 404 when no knowledge base has this id; 400 when the
            update violates a database constraint, e.g.
            ``embeddingCredentialId`` names a credential that does not exist.
    """
    kb = await session.get(KnowledgeBase, kb_id)
    if not kb:
        raise HTTPException(404, "知识库不存在")
    for field, value in body.model_dump().items():
        setattr(kb, field, value)
    try:
        await session.commit()
    except IntegrityError as exc:
        # Roll back so the session is usable again, and report a client error
        # instead of letting the constraint violation surface as a bare 500.
        await session.rollback()
        raise HTTPException(
            400, "知识库保存失败:违反数据约束(请检查 Embedding 凭证是否存在)"
        ) from exc
    await session.refresh(kb)
    return _to_out(kb)
@router.delete("/{kb_id}")
async def delete_knowledge_base(
    kb_id: str, session: AsyncSession = Depends(get_session)
) -> dict[str, bool]:
    """Delete a knowledge base by id.

    Args:
        kb_id: Id of the knowledge base to delete.
        session: Request-scoped async database session.

    Returns:
        ``{"ok": True}`` once the row has been deleted and committed.

    Raises:
        HTTPException: 404 when no knowledge base has this id; 409 when the
            delete fails with an integrity error, which is how a knowledge base
            still referenced by an assistant is reported.
    """
    kb = await session.get(KnowledgeBase, kb_id)
    if not kb:
        raise HTTPException(404, "知识库不存在")
    try:
        await session.delete(kb)
        await session.commit()
    except IntegrityError as exc:
        # Still referenced by an assistant (ON DELETE RESTRICT): the caller has
        # to unbind it first. Roll back so the session stays usable, and chain
        # the original error so the root cause is kept in tracebacks.
        await session.rollback()
        raise HTTPException(409, "知识库正被助手引用,无法删除") from exc
    return {"ok": True}

View File

@@ -7,14 +7,18 @@ JSON 用 camelCase(modelId/interfaceType/apiUrl/apiKey),Python 内部用 snake_c
from __future__ import annotations
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic.alias_generators import to_camel
# Closed sets of string values accepted by the request/response schemas.
RuntimeMode = Literal["pipeline", "realtime"]
ModelType = Literal["LLM", "ASR", "TTS", "Realtime", "Embedding"]
InterfaceType = Literal["openai", "xfyun", "dashscope", "gemini"]
AssistantType = Literal["prompt", "workflow", "dify", "fastgpt", "opencode"]
# External application types: their config.apiKey is a secret private to that
# assistant — masked on read, resolved against the stored key on write.
EXTERNAL_TYPES = {"dify", "fastgpt", "opencode"}
class CamelModel(BaseModel):
@@ -27,23 +31,86 @@ class CamelModel(BaseModel):
)
# ---------- 助手 ----------
# ---------- 各类型的 config 形态(JSON 内嵌,按 type 校验) ----------
class PromptConfig(CamelModel):
    """``config`` shape for assistants of type ``prompt``."""

    prompt: str = ""
    realtime_model: str = ""
class WorkflowConfig(CamelModel):
    """``config`` shape for assistants of type ``workflow``."""

    # {nodes, edges, viewport}; node data may carry *CredentialId overrides
    graph: dict[str, Any] = {}
class DifyConfig(CamelModel):
    """``config`` shape for assistants of type ``dify``."""

    api_url: str = ""
    api_key: str = ""  # on write: placeholder/empty value → keep the stored key
class FastgptConfig(CamelModel):
    """``config`` shape for assistants of type ``fastgpt``."""

    app_id: str = ""
    api_url: str = ""
    api_key: str = ""
class OpencodeConfig(CamelModel):
    """``config`` shape for assistants of type ``opencode``."""

    prompt: str = ""
    api_url: str = ""
    api_key: str = ""
# Lookup table from assistant type name to the model describing that type's
# ``config`` payload.
CONFIG_BY_TYPE: dict[str, type[CamelModel]] = {
    "prompt": PromptConfig,
    "workflow": WorkflowConfig,
    "dify": DifyConfig,
    "fastgpt": FastgptConfig,
    "opencode": OpencodeConfig,
}
# ---------- 助手(单表,无版本化;type 可变) ----------
class AssistantUpsert(CamelModel):
name: str
greeting: str = ""
prompt: str = ""
type: AssistantType = "prompt"
runtime_mode: RuntimeMode = "pipeline"
model: str = ""
asr: str = ""
voice: str = ""
greeting: str = ""
enable_interrupt: bool = True
# 引用注册资源(FK id;None=未选)
llm_credential_id: str | None = None
asr_credential_id: str | None = None
tts_credential_id: str | None = None
realtime_credential_id: str | None = None
knowledge_base_id: str | None = None
# 类型专属字段;校验后归一为 camelCase 存库
config: dict[str, Any] = {}
@model_validator(mode="after")
def _normalize_config(self):
model = CONFIG_BY_TYPE[self.type]
# 按当前 type 校验并裁掉无关字段,统一回 camelCase 落库
self.config = model.model_validate(self.config).model_dump(by_alias=True)
return self
class AssistantOut(AssistantUpsert):
    """Assistant response schema: the upsert fields plus ``id`` and ``updated_at``."""

    id: str
    updated_at: str | None = None  # the assistants route fills this with an ISO-8601 string, or None
# ---------- 知识库 ----------
class KnowledgeBaseUpsert(CamelModel):
    """Request body for creating or updating a knowledge base."""

    name: str
    description: str = ""
    # Id of the embedding credential this knowledge base uses; None = not selected.
    embedding_credential_id: str | None = None
class KnowledgeBaseOut(KnowledgeBaseUpsert):
    """Knowledge base response schema: the upsert fields plus ``id``, ``status`` and ``updated_at``."""

    id: str
    status: str = "active"
    updated_at: str | None = None  # the knowledge-base route fills this with an ISO-8601 string, or None
# ---------- 模型凭证(对齐前端 ModelResource) ----------
class CredentialUpsert(CamelModel):
name: str = "" # 资源名称

View File

@@ -1,7 +1,7 @@
"""assistant_id → 运行时配置(把真 key 在服务端组装好)。
浏览器只传 assistant_id;真 key 在这里从 provider_credentials 取出注入。
取不到凭证记录时,降级用 .env 默认值(开发期零配置仍能跑)
助手按 FK(*_credential_id)引用凭证;取不到则回退该 type 默认凭证,再回退 .env
"""
import config
@@ -11,49 +11,57 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
async def _get_credential(
session: AsyncSession, type_: str, name: str = ""
async def _default_credential(
session: AsyncSession, type_: str
) -> ProviderCredential | None:
"""取某类(LLM/ASR/TTS)凭证:优先按资源名匹配,否则取该类默认"""
stmt = select(ProviderCredential).where(ProviderCredential.type == type_)
if name:
# 助手按资源名引用(如 model="DeepSeek-V3");命中则用它
named = (
await session.execute(stmt.where(ProviderCredential.name == name).limit(1))
).scalar_one_or_none()
if named:
return named
stmt = stmt.order_by(
ProviderCredential.is_default.desc(), ProviderCredential.id.asc()
).limit(1)
"""该 type 的默认凭证(is_default 优先,否则按 id 取第一条)"""
stmt = (
select(ProviderCredential)
.where(ProviderCredential.type == type_)
.order_by(ProviderCredential.is_default.desc(), ProviderCredential.id.asc())
.limit(1)
)
return (await session.execute(stmt)).scalar_one_or_none()
async def _resolve(
    session: AsyncSession, cred_id: str | None, type_: str
) -> ProviderCredential | None:
    """Look up a credential by its foreign-key id, with a per-type fallback.

    When ``cred_id`` is empty, or no credential row exists for it, the default
    credential for ``type_`` is returned instead (which may itself be ``None``).
    """
    explicit = await session.get(ProviderCredential, cred_id) if cred_id else None
    return explicit or await _default_credential(session, type_)
async def resolve_runtime_config(
session: AsyncSession, assistant_id: str
) -> AssistantConfig:
"""加载助手 + 解析凭证,产出可直接交给管线的运行时配置(含真 key)。
type 映射:LLM→大模型, ASR→语音识别, TTS→语音合成。
"""
"""加载助手 + 解析凭证,产出可直接交给管线的运行时配置(含真 key)。"""
assistant = await session.get(Assistant, assistant_id)
if assistant is None:
raise ValueError(f"助手不存在: {assistant_id}")
llm = await _get_credential(session, "LLM", assistant.model)
stt = await _get_credential(session, "ASR", assistant.asr)
tts = await _get_credential(session, "TTS")
llm = await _resolve(session, assistant.llm_credential_id, "LLM")
stt = await _resolve(session, assistant.asr_credential_id, "ASR")
tts = await _resolve(session, assistant.tts_credential_id, "TTS")
realtime = await _resolve(session, assistant.realtime_credential_id, "Realtime")
cfg = assistant.config or {}
return AssistantConfig(
name=assistant.name,
greeting=assistant.greeting,
prompt=assistant.prompt,
# 提示词/工作流类型把 prompt 放 config;外部类型由其平台编排,这里给个兜底
prompt=cfg.get("prompt") or "你是一个有帮助的助手。",
runtimeMode=assistant.runtime_mode, # type: ignore[arg-type]
enableInterrupt=assistant.enable_interrupt,
# 模型/音色:凭证的模型ID优先,否则助手里填的
model=(llm.model_id if llm else assistant.model),
asr=(stt.model_id if stt else assistant.asr),
voice=assistant.voice,
# 模型/音色:凭证的模型ID优先
model=(llm.model_id if llm else ""),
asr=(stt.model_id if stt else ""),
voice=cfg.get("voice", ""), # 音色不再是独立列,若 config 带则用,否则 .env 兜底
realtimeModel=(realtime.model_id if realtime else cfg.get("realtimeModel", "")),
# 运行时连接信息(真 key + url):凭证优先,否则 .env 兜底
llm_api_key=(llm.api_key if llm else config.LLM_API_KEY),
llm_base_url=(llm.api_url if llm else config.LLM_BASE_URL),