diff --git a/backend/app.py b/backend/app.py index d8212b2..1254eb3 100644 --- a/backend/app.py +++ b/backend/app.py @@ -18,7 +18,14 @@ from db.session import init_db from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from routes import assistants, credentials, health, voice_webrtc, voice_ws +from routes import ( + assistants, + credentials, + health, + knowledge_bases, + voice_webrtc, + voice_ws, +) @asynccontextmanager @@ -40,6 +47,7 @@ app.add_middleware( app.include_router(health.router) app.include_router(assistants.router) app.include_router(credentials.router) +app.include_router(knowledge_bases.router) app.include_router(voice_webrtc.router) app.include_router(voice_ws.router) diff --git a/backend/db/models.py b/backend/db/models.py index f9adf08..506c1ee 100644 --- a/backend/db/models.py +++ b/backend/db/models.py @@ -9,7 +9,7 @@ from datetime import datetime -from sqlalchemy import Boolean, DateTime, String, func +from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, String, func from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -37,19 +37,67 @@ class ProviderCredential(Base): ) -class Assistant(Base): - __tablename__ = "assistants" +class KnowledgeBase(Base): + """知识库注册表。本身引用一个 Embedding 凭证(用哪个向量模型)。 - id: Mapped[str] = mapped_column(String(40), primary_key=True) # asst_xxx + 文档/分块(pgvector)是 KB 内部实现,这里先不展开;助手侧只认 knowledge_base_id。 + """ + + __tablename__ = "knowledge_bases" + + id: Mapped[str] = mapped_column(String(40), primary_key=True) # kb_xxx name: Mapped[str] = mapped_column(String(128)) - greeting: Mapped[str] = mapped_column(String(2048), default="") - prompt: Mapped[str] = mapped_column(String(8192), default="") - runtime_mode: Mapped[str] = mapped_column(String(16), default="pipeline") - # 模型/音色的"选项名",不是 key - model: Mapped[str] = mapped_column(String(128), default="") - asr: Mapped[str] = mapped_column(String(128), default="") - voice: Mapped[str] = mapped_column(String(128), default="") - enable_interrupt: Mapped[bool] = mapped_column(Boolean, default=True) + description: Mapped[str] = mapped_column(String(2048), default="") + # 该 KB 用哪个向量模型;凭证被删则置空 + embedding_credential_id: Mapped[str | None] = mapped_column( + String(40), + ForeignKey("provider_credentials.id", ondelete="SET NULL"), + nullable=True, + ) + status: Mapped[str] = mapped_column(String(16), default="active") # active|archived + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + +class Assistant(Base): + """助手(单表,无版本化)。type 为可变普通列,5 种类型共用此表。 + + 模型/KB 以 FK 引用注册表;类型专属字段塞进 config(JSON)。 + """ + + __tablename__ = "assistants" + + id: Mapped[str] = mapped_column(String(40), primary_key=True) # asst_xxx + name: Mapped[str] = mapped_column(String(128)) + # prompt|workflow|dify|fastgpt|opencode;创建后可改 + type: Mapped[str] = mapped_column(String(16), index=True, default="prompt") + runtime_mode: Mapped[str] = mapped_column(String(16), default="pipeline") + greeting: Mapped[str] = mapped_column(String(2048), default="") + enable_interrupt: Mapped[bool] = mapped_column(Boolean, default=True) + + # ---- 引用"注册好的资源":凭证被删 → SET NULL(resolver 有默认/.env 兜底) ---- + llm_credential_id: Mapped[str | None] = mapped_column( + String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True + ) + asr_credential_id: Mapped[str | None] = mapped_column( + String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True + ) + tts_credential_id: Mapped[str | None] = mapped_column( + String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True + ) + realtime_credential_id: Mapped[str | None] = mapped_column( + String(40), ForeignKey("provider_credentials.id", ondelete="SET NULL"), nullable=True + ) + # KB 引用:被引用时禁止删 KB(RESTRICT),无默认兜底 + knowledge_base_id: Mapped[str | None] = mapped_column( + String(40), ForeignKey("knowledge_bases.id", ondelete="RESTRICT"), nullable=True + ) + + # 类型专属字段(形态各异):prompt / graph / dify|fastgpt|opencode 端点+key(打码) + config: Mapped[dict] = mapped_column(JSON, default=dict) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() diff --git a/backend/routes/assistants.py b/backend/routes/assistants.py index f5db498..a570732 100644 --- a/backend/routes/assistants.py +++ b/backend/routes/assistants.py @@ -1,6 +1,7 @@ """助手 CRUD。前端「助手列表 / 创建 / 编辑」对接这里。 -助手配置不含 key,所以无需打码。 +模型/KB 以 FK 引用注册表;外部类型(dify/fastgpt/opencode)的 config.apiKey 是私有密钥, +读时打码、写时哨兵(复用 services/masking)。 """ import uuid @@ -8,24 +9,47 @@ import uuid from db.models import Assistant from db.session import get_session from fastapi import APIRouter, Depends, HTTPException -from schemas import AssistantOut, AssistantUpsert +from schemas import EXTERNAL_TYPES, AssistantOut, AssistantUpsert +from services.masking import mask, resolve_incoming_key from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession router = APIRouter(prefix="/api/assistants", tags=["assistants"]) +def _mask_config(type_: str, config: dict) -> dict: + """读取返回前:外部类型的 apiKey 打码,其余原样。""" + if type_ in EXTERNAL_TYPES and config.get("apiKey"): + return {**config, "apiKey": mask(config["apiKey"])} + return config + + +def _merge_config(type_: str, incoming: dict, stored: dict) -> dict: + """写入时:外部类型若回传打码占位符/空 apiKey → 保留旧 key。""" + if type_ in EXTERNAL_TYPES and "apiKey" in incoming: + incoming = { + **incoming, + "apiKey": resolve_incoming_key( + incoming.get("apiKey"), stored.get("apiKey", "") + ), + } + return incoming + + def _to_out(a: Assistant) -> AssistantOut: return AssistantOut( id=a.id, name=a.name, - greeting=a.greeting, - prompt=a.prompt, + type=a.type, # type: ignore[arg-type] runtime_mode=a.runtime_mode, # type: ignore[arg-type] - model=a.model, - asr=a.asr, - voice=a.voice, + greeting=a.greeting, enable_interrupt=a.enable_interrupt, + llm_credential_id=a.llm_credential_id, + asr_credential_id=a.asr_credential_id, + tts_credential_id=a.tts_credential_id, + realtime_credential_id=a.realtime_credential_id, + knowledge_base_id=a.knowledge_base_id, + config=_mask_config(a.type, a.config or {}), updated_at=a.updated_at.isoformat() if a.updated_at else None, ) @@ -68,7 +92,10 @@ async def update_assistant( a = await session.get(Assistant, assistant_id) if not a: raise HTTPException(404, "助手不存在") - for k, v in body.model_dump().items(): + data = body.model_dump() + # 外部类型 apiKey 写时哨兵:打码占位符 → 保留旧 key(在改 a.config 前用旧值) + data["config"] = _merge_config(body.type, data["config"], a.config or {}) + for k, v in data.items(): setattr(a, k, v) await session.commit() await session.refresh(a) diff --git a/backend/routes/knowledge_bases.py b/backend/routes/knowledge_bases.py new file mode 100644 index 0000000..4356d2a --- /dev/null +++ b/backend/routes/knowledge_bases.py @@ -0,0 +1,90 @@ +"""知识库 CRUD。前端助手编辑页的"知识库"下拉对接这里。 + +KB 自身引用一个 Embedding 凭证(embeddingCredentialId)。被助手引用时禁止删除 +(DB 层 ON DELETE RESTRICT),这里把外键冲突翻译成 409。 +""" + +import uuid + +from db.models import KnowledgeBase +from db.session import get_session +from fastapi import APIRouter, Depends, HTTPException +from schemas import KnowledgeBaseOut, KnowledgeBaseUpsert +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +router = APIRouter(prefix="/api/knowledge-bases", tags=["knowledge-bases"]) + + +def _to_out(kb: KnowledgeBase) -> KnowledgeBaseOut: + return KnowledgeBaseOut( + id=kb.id, + name=kb.name, + description=kb.description, + embedding_credential_id=kb.embedding_credential_id, + status=kb.status, + updated_at=kb.updated_at.isoformat() if kb.updated_at else None, + ) + + +@router.get("", response_model=list[KnowledgeBaseOut]) +async def list_knowledge_bases(session: AsyncSession = Depends(get_session)): + rows = ( + await session.execute(select(KnowledgeBase).order_by(KnowledgeBase.name)) + ).scalars().all() + return [_to_out(kb) for kb in rows] + + +@router.post("", response_model=KnowledgeBaseOut) +async def create_knowledge_base( + body: KnowledgeBaseUpsert, session: AsyncSession = Depends(get_session) +): + kb = KnowledgeBase(id=f"kb_{uuid.uuid4().hex[:12]}", **body.model_dump()) + session.add(kb) + await session.commit() + await session.refresh(kb) + return _to_out(kb) + + +@router.get("/{kb_id}", response_model=KnowledgeBaseOut) +async def get_knowledge_base( + kb_id: str, session: AsyncSession = Depends(get_session) +): + kb = await session.get(KnowledgeBase, kb_id) + if not kb: + raise HTTPException(404, "知识库不存在") + return _to_out(kb) + + +@router.put("/{kb_id}", response_model=KnowledgeBaseOut) +async def update_knowledge_base( + kb_id: str, + body: KnowledgeBaseUpsert, + session: AsyncSession = Depends(get_session), +): + kb = await session.get(KnowledgeBase, kb_id) + if not kb: + raise HTTPException(404, "知识库不存在") + for k, v in body.model_dump().items(): + setattr(kb, k, v) + await session.commit() + await session.refresh(kb) + return _to_out(kb) + + +@router.delete("/{kb_id}") +async def delete_knowledge_base( + kb_id: str, session: AsyncSession = Depends(get_session) +): + kb = await session.get(KnowledgeBase, kb_id) + if not kb: + raise HTTPException(404, "知识库不存在") + try: + await session.delete(kb) + await session.commit() + except IntegrityError: + # 被助手引用(ON DELETE RESTRICT):先解绑再删 + await session.rollback() + raise HTTPException(409, "知识库正被助手引用,无法删除") + return {"ok": True} diff --git a/backend/schemas.py b/backend/schemas.py index 69235d6..191a5ed 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -7,14 +7,18 @@ JSON 用 camelCase(modelId/interfaceType/apiUrl/apiKey),Python 内部用 snake_c from __future__ import annotations -from typing import Literal +from typing import Any, Literal -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from pydantic.alias_generators import to_camel RuntimeMode = Literal["pipeline", "realtime"] ModelType = Literal["LLM", "ASR", "TTS", "Realtime", "Embedding"] InterfaceType = Literal["openai", "xfyun", "dashscope", "gemini"] +AssistantType = Literal["prompt", "workflow", "dify", "fastgpt", "opencode"] + +# 外部应用类型:其 config.apiKey 是该助手私有密钥,读时打码 / 写时哨兵 +EXTERNAL_TYPES = {"dify", "fastgpt", "opencode"} class CamelModel(BaseModel): @@ -27,23 +31,86 @@ class CamelModel(BaseModel): ) -# ---------- 助手 ---------- +# ---------- 各类型的 config 形态(JSON 内嵌,按 type 校验) ---------- +class PromptConfig(CamelModel): + prompt: str = "" + realtime_model: str = "" + + +class WorkflowConfig(CamelModel): + graph: dict[str, Any] = {} # {nodes, edges, viewport};节点 data 可带 *CredentialId 覆盖 + + +class DifyConfig(CamelModel): + api_url: str = "" + api_key: str = "" # 写时:占位符/空 → 保留旧 + + +class FastgptConfig(CamelModel): + app_id: str = "" + api_url: str = "" + api_key: str = "" + + +class OpencodeConfig(CamelModel): + prompt: str = "" + api_url: str = "" + api_key: str = "" + + +CONFIG_BY_TYPE: dict[str, type[CamelModel]] = { + "prompt": PromptConfig, + "workflow": WorkflowConfig, + "dify": DifyConfig, + "fastgpt": FastgptConfig, + "opencode": OpencodeConfig, +} + + +# ---------- 助手(单表,无版本化;type 可变) ---------- class AssistantUpsert(CamelModel): name: str - greeting: str = "" - prompt: str = "" + type: AssistantType = "prompt" runtime_mode: RuntimeMode = "pipeline" - model: str = "" - asr: str = "" - voice: str = "" + greeting: str = "" enable_interrupt: bool = True + # 引用注册资源(FK id;None=未选) + llm_credential_id: str | None = None + asr_credential_id: str | None = None + tts_credential_id: str | None = None + realtime_credential_id: str | None = None + knowledge_base_id: str | None = None + + # 类型专属字段;校验后归一为 camelCase 存库 + config: dict[str, Any] = {} + + @model_validator(mode="after") + def _normalize_config(self): + model = CONFIG_BY_TYPE[self.type] + # 按当前 type 校验并裁掉无关字段,统一回 camelCase 落库 + self.config = model.model_validate(self.config).model_dump(by_alias=True) + return self + class AssistantOut(AssistantUpsert): id: str updated_at: str | None = None +# ---------- 知识库 ---------- +class KnowledgeBaseUpsert(CamelModel): + name: str + description: str = "" + embedding_credential_id: str | None = None + + +class KnowledgeBaseOut(KnowledgeBaseUpsert): + id: str + status: str = "active" + updated_at: str | None = None + + # ---------- 模型凭证(对齐前端 ModelResource) ---------- class CredentialUpsert(CamelModel): name: str = "" # 资源名称 diff --git a/backend/services/config_resolver.py b/backend/services/config_resolver.py index bce59c0..e905d89 100644 --- a/backend/services/config_resolver.py +++ b/backend/services/config_resolver.py @@ -1,7 +1,7 @@ """assistant_id → 运行时配置(把真 key 在服务端组装好)。 浏览器只传 assistant_id;真 key 在这里从 provider_credentials 取出注入。 -取不到凭证记录时,降级用 .env 默认值(开发期零配置仍能跑)。 +助手按 FK(*_credential_id)引用凭证;取不到则回退该 type 默认凭证,再回退 .env。 """ import config @@ -11,49 +11,57 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -async def _get_credential( - session: AsyncSession, type_: str, name: str = "" +async def _default_credential( + session: AsyncSession, type_: str ) -> ProviderCredential | None: - """取某类(LLM/ASR/TTS)凭证:优先按资源名匹配,否则取该类默认。""" - stmt = select(ProviderCredential).where(ProviderCredential.type == type_) - if name: - # 助手按资源名引用(如 model="DeepSeek-V3");命中则用它 - named = ( - await session.execute(stmt.where(ProviderCredential.name == name).limit(1)) - ).scalar_one_or_none() - if named: - return named - stmt = stmt.order_by( - ProviderCredential.is_default.desc(), ProviderCredential.id.asc() - ).limit(1) + """该 type 的默认凭证(is_default 优先,否则按 id 取第一条)。""" + stmt = ( + select(ProviderCredential) + .where(ProviderCredential.type == type_) + .order_by(ProviderCredential.is_default.desc(), ProviderCredential.id.asc()) + .limit(1) + ) return (await session.execute(stmt)).scalar_one_or_none() +async def _resolve( + session: AsyncSession, cred_id: str | None, type_: str +) -> ProviderCredential | None: + """按 FK id 取凭证;id 为空或失效 → 回退该 type 默认。""" + if cred_id: + cred = await session.get(ProviderCredential, cred_id) + if cred: + return cred + return await _default_credential(session, type_) + + async def resolve_runtime_config( session: AsyncSession, assistant_id: str ) -> AssistantConfig: - """加载助手 + 解析凭证,产出可直接交给管线的运行时配置(含真 key)。 - - type 映射:LLM→大模型, ASR→语音识别, TTS→语音合成。 - """ + """加载助手 + 解析凭证,产出可直接交给管线的运行时配置(含真 key)。""" assistant = await session.get(Assistant, assistant_id) if assistant is None: raise ValueError(f"助手不存在: {assistant_id}") - llm = await _get_credential(session, "LLM", assistant.model) - stt = await _get_credential(session, "ASR", assistant.asr) - tts = await _get_credential(session, "TTS") + llm = await _resolve(session, assistant.llm_credential_id, "LLM") + stt = await _resolve(session, assistant.asr_credential_id, "ASR") + tts = await _resolve(session, assistant.tts_credential_id, "TTS") + realtime = await _resolve(session, assistant.realtime_credential_id, "Realtime") + + cfg = assistant.config or {} return AssistantConfig( name=assistant.name, greeting=assistant.greeting, - prompt=assistant.prompt, + # 提示词/工作流类型把 prompt 放 config;外部类型由其平台编排,这里给个兜底 + prompt=cfg.get("prompt") or "你是一个有帮助的助手。", runtimeMode=assistant.runtime_mode, # type: ignore[arg-type] enableInterrupt=assistant.enable_interrupt, - # 模型/音色:凭证的模型ID优先,否则助手里填的 - model=(llm.model_id if llm else assistant.model), - asr=(stt.model_id if stt else assistant.asr), - voice=assistant.voice, + # 模型/音色:凭证的模型ID优先 + model=(llm.model_id if llm else ""), + asr=(stt.model_id if stt else ""), + voice=cfg.get("voice", ""), # 音色不再是独立列,若 config 带则用,否则 .env 兜底 + realtimeModel=(realtime.model_id if realtime else cfg.get("realtimeModel", "")), # 运行时连接信息(真 key + url):凭证优先,否则 .env 兜底 llm_api_key=(llm.api_key if llm else config.LLM_API_KEY), llm_base_url=(llm.api_url if llm else config.LLM_BASE_URL),