Implement knowledge base management and enhance assistant configuration
Add CRUD functionality for knowledge bases, including routes for listing, creating, updating, and deleting knowledge bases. Update the assistant model to include foreign key references to knowledge bases and modify the assistant configuration to handle external API keys securely. Refactor related services and routes to accommodate these changes, ensuring proper handling of credential resolution and configuration normalization.
This commit is contained in:
@@ -18,7 +18,14 @@ from db.session import init_db
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from routes import assistants, credentials, health, voice_webrtc, voice_ws
|
||||
from routes import (
|
||||
assistants,
|
||||
credentials,
|
||||
health,
|
||||
knowledge_bases,
|
||||
voice_webrtc,
|
||||
voice_ws,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -40,6 +47,7 @@ app.add_middleware(
|
||||
app.include_router(health.router)
|
||||
app.include_router(assistants.router)
|
||||
app.include_router(credentials.router)
|
||||
app.include_router(knowledge_bases.router)
|
||||
app.include_router(voice_webrtc.router)
|
||||
app.include_router(voice_ws.router)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, String, func
|
||||
from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, String, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
@@ -37,19 +37,67 @@ class ProviderCredential(Base):
|
||||
)
|
||||
|
||||
|
||||
class Assistant(Base):
|
||||
__tablename__ = "assistants"
|
||||
class KnowledgeBase(Base):
|
||||
"""知识库注册表。本身引用一个 Embedding 凭证(用哪个向量模型)。
|
||||
|
||||
id: Mapped[str] = mapped_column(String(40), primary_key=True) # asst_xxx
|
||||
文档/分块(pgvector)是 KB 内部实现,这里先不展开;助手侧只认 knowledge_base_id。
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_bases"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(40), primary_key=True) # kb_xxx
|
||||
name: Mapped[str] = mapped_column(String(128))
|
||||
greeting: Mapped[str] = mapped_column(String(2048), default="")
|
||||
prompt: Mapped[str] = mapped_column(String(8192), default="")
|
||||
runtime_mode: Mapped[str] = mapped_column(String(16), default="pipeline")
|
||||
# 模型/音色的"选项名",不是 key
|
||||
model: Mapped[str] = mapped_column(String(128), default="")
|
||||
asr: Mapped[str] = mapped_column(String(128), default="")
|
||||
voice: Mapped[str] = mapped_column(String(128), default="")
|
||||
enable_interrupt: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
description: Mapped[str] = mapped_column(String(2048), default="")
|
||||
# 该 KB 用哪个向量模型;凭证被删则置空
|
||||
embedding_credential_id: Mapped[str | None] = mapped_column(
|
||||
String(40),
|
||||
ForeignKey("provider_credentials.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(16), default="active") # active|archived
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class Assistant(Base):
|
||||
"""助手(单表,无版本化)。type 为可变普通列,5 种类型共用此表。
|
||||
|
||||
模型/KB 以 FK 引用注册表;类型专属字段塞进 config(JSON)。
|
||||
"""
|
||||
|
||||
__tablename__ = "assistants"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(40), primary_key=True) # asst_xxx
|
||||
name: Mapped[str] = mapped_column(String(128))
|
||||
# prompt|workflow|dify|fastgpt|opencode;创建后可改
|
||||
type: Mapped[str] = mapped_column(String(16), index=True, default="prompt")
|
||||
runtime_mode: Mapped[str] = mapped_column(String(16), default="pipeline")
|
||||
greeting: Mapped[str] = mapped_column(String(2048), default="")
|
||||
enable_interrupt: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# ---- 引用"注册好的资源":凭证被删 → SET NULL(resolver 有默认/.env 兜底) ----
|
||||
llm_credential_id: Mapped[str | None] = mapped_column(
|
||||
String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
asr_credential_id: Mapped[str | None] = mapped_column(
|
||||
String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
tts_credential_id: Mapped[str | None] = mapped_column(
|
||||
String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
realtime_credential_id: Mapped[str | None] = mapped_column(
|
||||
String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
# KB 引用:被引用时禁止删 KB(RESTRICT),无默认兜底
|
||||
knowledge_base_id: Mapped[str | None] = mapped_column(
|
||||
String(40), ForeignKey("knowledge_bases.id", ondelete="RESTRICT"), nullable=True
|
||||
)
|
||||
|
||||
# 类型专属字段(形态各异):prompt / graph / dify|fastgpt|opencode 端点+key(打码)
|
||||
config: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""助手 CRUD。前端「助手列表 / 创建 / 编辑」对接这里。
|
||||
|
||||
助手配置不含 key,所以无需打码。
|
||||
模型/KB 以 FK 引用注册表;外部类型(dify/fastgpt/opencode)的 config.apiKey 是私有密钥,
|
||||
读时打码、写时哨兵(复用 services/masking)。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
@@ -8,24 +9,47 @@ import uuid
|
||||
from db.models import Assistant
|
||||
from db.session import get_session
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from schemas import AssistantOut, AssistantUpsert
|
||||
from schemas import EXTERNAL_TYPES, AssistantOut, AssistantUpsert
|
||||
from services.masking import mask, resolve_incoming_key
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
router = APIRouter(prefix="/api/assistants", tags=["assistants"])
|
||||
|
||||
|
||||
def _mask_config(type_: str, config: dict) -> dict:
|
||||
"""读取返回前:外部类型的 apiKey 打码,其余原样。"""
|
||||
if type_ in EXTERNAL_TYPES and config.get("apiKey"):
|
||||
return {**config, "apiKey": mask(config["apiKey"])}
|
||||
return config
|
||||
|
||||
|
||||
def _merge_config(type_: str, incoming: dict, stored: dict) -> dict:
|
||||
"""写入时:外部类型若回传打码占位符/空 apiKey → 保留旧 key。"""
|
||||
if type_ in EXTERNAL_TYPES and "apiKey" in incoming:
|
||||
incoming = {
|
||||
**incoming,
|
||||
"apiKey": resolve_incoming_key(
|
||||
incoming.get("apiKey"), stored.get("apiKey", "")
|
||||
),
|
||||
}
|
||||
return incoming
|
||||
|
||||
|
||||
def _to_out(a: Assistant) -> AssistantOut:
|
||||
return AssistantOut(
|
||||
id=a.id,
|
||||
name=a.name,
|
||||
greeting=a.greeting,
|
||||
prompt=a.prompt,
|
||||
type=a.type, # type: ignore[arg-type]
|
||||
runtime_mode=a.runtime_mode, # type: ignore[arg-type]
|
||||
model=a.model,
|
||||
asr=a.asr,
|
||||
voice=a.voice,
|
||||
greeting=a.greeting,
|
||||
enable_interrupt=a.enable_interrupt,
|
||||
llm_credential_id=a.llm_credential_id,
|
||||
asr_credential_id=a.asr_credential_id,
|
||||
tts_credential_id=a.tts_credential_id,
|
||||
realtime_credential_id=a.realtime_credential_id,
|
||||
knowledge_base_id=a.knowledge_base_id,
|
||||
config=_mask_config(a.type, a.config or {}),
|
||||
updated_at=a.updated_at.isoformat() if a.updated_at else None,
|
||||
)
|
||||
|
||||
@@ -68,7 +92,10 @@ async def update_assistant(
|
||||
a = await session.get(Assistant, assistant_id)
|
||||
if not a:
|
||||
raise HTTPException(404, "助手不存在")
|
||||
for k, v in body.model_dump().items():
|
||||
data = body.model_dump()
|
||||
# 外部类型 apiKey 写时哨兵:打码占位符 → 保留旧 key(在改 a.config 前用旧值)
|
||||
data["config"] = _merge_config(body.type, data["config"], a.config or {})
|
||||
for k, v in data.items():
|
||||
setattr(a, k, v)
|
||||
await session.commit()
|
||||
await session.refresh(a)
|
||||
|
||||
90
backend/routes/knowledge_bases.py
Normal file
90
backend/routes/knowledge_bases.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""知识库 CRUD。前端助手编辑页的"知识库"下拉对接这里。
|
||||
|
||||
KB 自身引用一个 Embedding 凭证(embeddingCredentialId)。被助手引用时禁止删除
|
||||
(DB 层 ON DELETE RESTRICT),这里把外键冲突翻译成 409。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from db.models import KnowledgeBase
|
||||
from db.session import get_session
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from schemas import KnowledgeBaseOut, KnowledgeBaseUpsert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
router = APIRouter(prefix="/api/knowledge-bases", tags=["knowledge-bases"])
|
||||
|
||||
|
||||
def _to_out(kb: KnowledgeBase) -> KnowledgeBaseOut:
|
||||
return KnowledgeBaseOut(
|
||||
id=kb.id,
|
||||
name=kb.name,
|
||||
description=kb.description,
|
||||
embedding_credential_id=kb.embedding_credential_id,
|
||||
status=kb.status,
|
||||
updated_at=kb.updated_at.isoformat() if kb.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=list[KnowledgeBaseOut])
|
||||
async def list_knowledge_bases(session: AsyncSession = Depends(get_session)):
|
||||
rows = (
|
||||
await session.execute(select(KnowledgeBase).order_by(KnowledgeBase.name))
|
||||
).scalars().all()
|
||||
return [_to_out(kb) for kb in rows]
|
||||
|
||||
|
||||
@router.post("", response_model=KnowledgeBaseOut)
|
||||
async def create_knowledge_base(
|
||||
body: KnowledgeBaseUpsert, session: AsyncSession = Depends(get_session)
|
||||
):
|
||||
kb = KnowledgeBase(id=f"kb_{uuid.uuid4().hex[:12]}", **body.model_dump())
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
return _to_out(kb)
|
||||
|
||||
|
||||
@router.get("/{kb_id}", response_model=KnowledgeBaseOut)
|
||||
async def get_knowledge_base(
|
||||
kb_id: str, session: AsyncSession = Depends(get_session)
|
||||
):
|
||||
kb = await session.get(KnowledgeBase, kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(404, "知识库不存在")
|
||||
return _to_out(kb)
|
||||
|
||||
|
||||
@router.put("/{kb_id}", response_model=KnowledgeBaseOut)
|
||||
async def update_knowledge_base(
|
||||
kb_id: str,
|
||||
body: KnowledgeBaseUpsert,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
kb = await session.get(KnowledgeBase, kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(404, "知识库不存在")
|
||||
for k, v in body.model_dump().items():
|
||||
setattr(kb, k, v)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
return _to_out(kb)
|
||||
|
||||
|
||||
@router.delete("/{kb_id}")
|
||||
async def delete_knowledge_base(
|
||||
kb_id: str, session: AsyncSession = Depends(get_session)
|
||||
):
|
||||
kb = await session.get(KnowledgeBase, kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(404, "知识库不存在")
|
||||
try:
|
||||
await session.delete(kb)
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
# 被助手引用(ON DELETE RESTRICT):先解绑再删
|
||||
await session.rollback()
|
||||
raise HTTPException(409, "知识库正被助手引用,无法删除")
|
||||
return {"ok": True}
|
||||
@@ -7,14 +7,18 @@ JSON 用 camelCase(modelId/interfaceType/apiUrl/apiKey),Python 内部用 snake_c
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
RuntimeMode = Literal["pipeline", "realtime"]
|
||||
ModelType = Literal["LLM", "ASR", "TTS", "Realtime", "Embedding"]
|
||||
InterfaceType = Literal["openai", "xfyun", "dashscope", "gemini"]
|
||||
AssistantType = Literal["prompt", "workflow", "dify", "fastgpt", "opencode"]
|
||||
|
||||
# 外部应用类型:其 config.apiKey 是该助手私有密钥,读时打码 / 写时哨兵
|
||||
EXTERNAL_TYPES = {"dify", "fastgpt", "opencode"}
|
||||
|
||||
|
||||
class CamelModel(BaseModel):
|
||||
@@ -27,23 +31,86 @@ class CamelModel(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# ---------- 助手 ----------
|
||||
# ---------- 各类型的 config 形态(JSON 内嵌,按 type 校验) ----------
|
||||
class PromptConfig(CamelModel):
|
||||
prompt: str = ""
|
||||
realtime_model: str = ""
|
||||
|
||||
|
||||
class WorkflowConfig(CamelModel):
|
||||
graph: dict[str, Any] = {} # {nodes, edges, viewport};节点 data 可带 *CredentialId 覆盖
|
||||
|
||||
|
||||
class DifyConfig(CamelModel):
|
||||
api_url: str = ""
|
||||
api_key: str = "" # 写时:占位符/空 → 保留旧
|
||||
|
||||
|
||||
class FastgptConfig(CamelModel):
|
||||
app_id: str = ""
|
||||
api_url: str = ""
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
class OpencodeConfig(CamelModel):
|
||||
prompt: str = ""
|
||||
api_url: str = ""
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
CONFIG_BY_TYPE: dict[str, type[CamelModel]] = {
|
||||
"prompt": PromptConfig,
|
||||
"workflow": WorkflowConfig,
|
||||
"dify": DifyConfig,
|
||||
"fastgpt": FastgptConfig,
|
||||
"opencode": OpencodeConfig,
|
||||
}
|
||||
|
||||
|
||||
# ---------- 助手(单表,无版本化;type 可变) ----------
|
||||
class AssistantUpsert(CamelModel):
|
||||
name: str
|
||||
greeting: str = ""
|
||||
prompt: str = ""
|
||||
type: AssistantType = "prompt"
|
||||
runtime_mode: RuntimeMode = "pipeline"
|
||||
model: str = ""
|
||||
asr: str = ""
|
||||
voice: str = ""
|
||||
greeting: str = ""
|
||||
enable_interrupt: bool = True
|
||||
|
||||
# 引用注册资源(FK id;None=未选)
|
||||
llm_credential_id: str | None = None
|
||||
asr_credential_id: str | None = None
|
||||
tts_credential_id: str | None = None
|
||||
realtime_credential_id: str | None = None
|
||||
knowledge_base_id: str | None = None
|
||||
|
||||
# 类型专属字段;校验后归一为 camelCase 存库
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _normalize_config(self):
|
||||
model = CONFIG_BY_TYPE[self.type]
|
||||
# 按当前 type 校验并裁掉无关字段,统一回 camelCase 落库
|
||||
self.config = model.model_validate(self.config).model_dump(by_alias=True)
|
||||
return self
|
||||
|
||||
|
||||
class AssistantOut(AssistantUpsert):
|
||||
id: str
|
||||
updated_at: str | None = None
|
||||
|
||||
|
||||
# ---------- 知识库 ----------
|
||||
class KnowledgeBaseUpsert(CamelModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
embedding_credential_id: str | None = None
|
||||
|
||||
|
||||
class KnowledgeBaseOut(KnowledgeBaseUpsert):
|
||||
id: str
|
||||
status: str = "active"
|
||||
updated_at: str | None = None
|
||||
|
||||
|
||||
# ---------- 模型凭证(对齐前端 ModelResource) ----------
|
||||
class CredentialUpsert(CamelModel):
|
||||
name: str = "" # 资源名称
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""assistant_id → 运行时配置(把真 key 在服务端组装好)。
|
||||
|
||||
浏览器只传 assistant_id;真 key 在这里从 provider_credentials 取出注入。
|
||||
取不到凭证记录时,降级用 .env 默认值(开发期零配置仍能跑)。
|
||||
助手按 FK(*_credential_id)引用凭证;取不到则回退该 type 默认凭证,再回退 .env。
|
||||
"""
|
||||
|
||||
import config
|
||||
@@ -11,49 +11,57 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
async def _get_credential(
|
||||
session: AsyncSession, type_: str, name: str = ""
|
||||
async def _default_credential(
|
||||
session: AsyncSession, type_: str
|
||||
) -> ProviderCredential | None:
|
||||
"""取某类(LLM/ASR/TTS)凭证:优先按资源名匹配,否则取该类默认。"""
|
||||
stmt = select(ProviderCredential).where(ProviderCredential.type == type_)
|
||||
if name:
|
||||
# 助手按资源名引用(如 model="DeepSeek-V3");命中则用它
|
||||
named = (
|
||||
await session.execute(stmt.where(ProviderCredential.name == name).limit(1))
|
||||
).scalar_one_or_none()
|
||||
if named:
|
||||
return named
|
||||
stmt = stmt.order_by(
|
||||
ProviderCredential.is_default.desc(), ProviderCredential.id.asc()
|
||||
).limit(1)
|
||||
"""该 type 的默认凭证(is_default 优先,否则按 id 取第一条)。"""
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(ProviderCredential.type == type_)
|
||||
.order_by(ProviderCredential.is_default.desc(), ProviderCredential.id.asc())
|
||||
.limit(1)
|
||||
)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def _resolve(
|
||||
session: AsyncSession, cred_id: str | None, type_: str
|
||||
) -> ProviderCredential | None:
|
||||
"""按 FK id 取凭证;id 为空或失效 → 回退该 type 默认。"""
|
||||
if cred_id:
|
||||
cred = await session.get(ProviderCredential, cred_id)
|
||||
if cred:
|
||||
return cred
|
||||
return await _default_credential(session, type_)
|
||||
|
||||
|
||||
async def resolve_runtime_config(
|
||||
session: AsyncSession, assistant_id: str
|
||||
) -> AssistantConfig:
|
||||
"""加载助手 + 解析凭证,产出可直接交给管线的运行时配置(含真 key)。
|
||||
|
||||
type 映射:LLM→大模型, ASR→语音识别, TTS→语音合成。
|
||||
"""
|
||||
"""加载助手 + 解析凭证,产出可直接交给管线的运行时配置(含真 key)。"""
|
||||
assistant = await session.get(Assistant, assistant_id)
|
||||
if assistant is None:
|
||||
raise ValueError(f"助手不存在: {assistant_id}")
|
||||
|
||||
llm = await _get_credential(session, "LLM", assistant.model)
|
||||
stt = await _get_credential(session, "ASR", assistant.asr)
|
||||
tts = await _get_credential(session, "TTS")
|
||||
llm = await _resolve(session, assistant.llm_credential_id, "LLM")
|
||||
stt = await _resolve(session, assistant.asr_credential_id, "ASR")
|
||||
tts = await _resolve(session, assistant.tts_credential_id, "TTS")
|
||||
realtime = await _resolve(session, assistant.realtime_credential_id, "Realtime")
|
||||
|
||||
cfg = assistant.config or {}
|
||||
|
||||
return AssistantConfig(
|
||||
name=assistant.name,
|
||||
greeting=assistant.greeting,
|
||||
prompt=assistant.prompt,
|
||||
# 提示词/工作流类型把 prompt 放 config;外部类型由其平台编排,这里给个兜底
|
||||
prompt=cfg.get("prompt") or "你是一个有帮助的助手。",
|
||||
runtimeMode=assistant.runtime_mode, # type: ignore[arg-type]
|
||||
enableInterrupt=assistant.enable_interrupt,
|
||||
# 模型/音色:凭证的模型ID优先,否则助手里填的
|
||||
model=(llm.model_id if llm else assistant.model),
|
||||
asr=(stt.model_id if stt else assistant.asr),
|
||||
voice=assistant.voice,
|
||||
# 模型/音色:凭证的模型ID优先
|
||||
model=(llm.model_id if llm else ""),
|
||||
asr=(stt.model_id if stt else ""),
|
||||
voice=cfg.get("voice", ""), # 音色不再是独立列,若 config 带则用,否则 .env 兜底
|
||||
realtimeModel=(realtime.model_id if realtime else cfg.get("realtimeModel", "")),
|
||||
# 运行时连接信息(真 key + url):凭证优先,否则 .env 兜底
|
||||
llm_api_key=(llm.api_key if llm else config.LLM_API_KEY),
|
||||
llm_base_url=(llm.api_url if llm else config.LLM_BASE_URL),
|
||||
|
||||
Reference in New Issue
Block a user