#!/usr/bin/env python3
"""Initialize the database.

Drops/recreates tables, seeds default records and optionally rebuilds the
knowledge-base vector collections. Run with ``--help`` for the options.
"""
import argparse
import os
import sys
from contextlib import contextmanager

# Make the local ``app`` package importable when this script is run directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from app.db import Base, engine, DATABASE_URL
from app.id_generator import short_id
from app.models import (
    Voice,
    Assistant,
    KnowledgeBase,
    Workflow,
    LLMModel,
    ASRModel,
    KnowledgeDocument,
)

VOICE_MODEL = "FunAudioLLM/CosyVoice2-0.5B"

# Seed IDs are generated once at import time so records seeded in the same run
# can reference each other (assistants -> voices / LLM models / ASR models).
# NOTE(review): if short_id() yields a fresh value per call (not verified here),
# these IDs differ on every run; when only some tables are empty, newly seeded
# assistants may reference voice/model IDs absent from already-populated tables.
SEED_VOICE_IDS = {
    "alex": short_id("tts"),
    "david": short_id("tts"),
    "bella": short_id("tts"),
    "claire": short_id("tts"),
}
SEED_LLM_IDS = {
    "deepseek_chat": short_id("llm"),
    "glm_4": short_id("llm"),
    "embedding_3_small": short_id("llm"),
}
SEED_ASR_IDS = {
    "sensevoice_small": short_id("asr"),
    "telespeech_asr": short_id("asr"),
}
SEED_ASSISTANT_IDS = {
    "default": short_id("ast"),
    "customer_service": short_id("ast"),
    "english_tutor": short_id("ast"),
}


def ensure_db_dir() -> None:
    """Create the parent directory of the SQLite database file if it is missing.

    Does nothing for non-SQLite URLs, or when the path has no directory part
    (e.g. ``sqlite:///:memory:`` or a bare file name).
    """
    prefix = "sqlite:///"
    if not DATABASE_URL.startswith(prefix):
        return
    # Strip only the leading scheme; str.replace would rewrite every occurrence.
    db_path = DATABASE_URL[len(prefix):]
    data_dir = os.path.dirname(db_path)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)


@contextmanager
def db_session():
    """Yield a session from ``app.db.SessionLocal`` and always close it afterwards.

    This helper never commits; callers commit explicitly.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def seed_if_empty(db, model_cls, records, success_msg: str) -> None:
    """Insert default records only when the table of ``model_cls`` is empty.

    Args:
        db: An open session (e.g. from ``db_session``).
        model_cls: ORM model class whose table is checked for existing rows.
        records: A single model instance, or a list/tuple of instances.
        success_msg: Message printed after the records are committed.
    """
    if db.query(model_cls).count() != 0:
        return  # Table already has data: never touch existing rows.
    if isinstance(records, (list, tuple)):
        db.add_all(records)
    else:
        db.add(records)
    db.commit()
    print(success_msg)


def init_db() -> None:
    """Drop and recreate every table registered on ``Base.metadata``.

    Destructive: all rows in those tables are lost.
    """
    ensure_db_dir()
    print("📦 创建数据库表...")
    Base.metadata.drop_all(bind=engine)  # Remove the old tables first.
    Base.metadata.create_all(bind=engine)
    print("✅ 数据库表创建完成")


def rebuild_vector_store(reset_doc_status: bool = True) -> None:
    """Rebuild the knowledge-base vector collections from the KBs stored in the DB.

    Deletes every vector collection whose name starts with ``kb_``, then creates
    one empty collection per ``KnowledgeBase`` row. Document contents are NOT
    re-indexed here.

    Args:
        reset_doc_status: When true, reset each KB's counters and mark all of its
            documents ``"pending"`` (clearing chunk count, error and processed time).

    Raises:
        RuntimeError: If the existing collections cannot be listed.
    """
    from app.vector_store import vector_store

    with db_session() as db:
        print("🧹 重建向量库集合...")
        kb_list = db.query(KnowledgeBase).all()

        # Delete existing KB collections.
        try:
            collections = vector_store.client.list_collections()
        except Exception as exc:
            raise RuntimeError(f"无法读取向量集合列表: {exc}") from exc

        for col in collections:
            # Accept both collection objects (with ``.name``) and plain name strings.
            name = getattr(col, "name", None) or str(col)
            if not name.startswith("kb_"):
                continue
            try:
                vector_store.client.delete_collection(name=name)
                print(f" - removed {name}")
            except Exception as exc:
                # Best effort: one failed delete must not abort the whole rebuild.
                print(f" - skip remove {name}: {exc}")

        # Recreate one KB collection per DB row.
        for kb in kb_list:
            vector_store.create_collection(kb.id, kb.embedding_model)
            print(f" + created kb_{kb.id} ({kb.embedding_model})")
            if reset_doc_status:
                kb.chunk_count = 0
                docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all()
                # NOTE(review): doc_count is zeroed while the document rows are kept
                # (only marked pending); confirm doc_count counts indexed documents.
                kb.doc_count = 0
                for doc in docs:
                    doc.chunk_count = 0
                    doc.status = "pending"
                    doc.error_message = None
                    doc.processed_at = None

        db.commit()
        print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)")


def init_default_data() -> None:
    """Seed the preset system voices when the voice table is empty."""
    with db_session() as db:
        # OpenAI-compatible (SiliconFlow API) CosyVoice 2.0 preset voices (4 seeded here).
        # Reference: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
        voices = [
            # Male voices
            Voice(
                id=SEED_VOICE_IDS["alex"],
                name="Alex",
                vendor="OpenAI Compatible",
                gender="Male",
                language="en",
                description="Steady male voice.",
                model=VOICE_MODEL,
                voice_key=f"{VOICE_MODEL}:alex",
                is_system=True,
            ),
            Voice(
                id=SEED_VOICE_IDS["david"],
                name="David",
                vendor="OpenAI Compatible",
                gender="Male",
                language="en",
                description="Cheerful male voice.",
                model=VOICE_MODEL,
                voice_key=f"{VOICE_MODEL}:david",
                is_system=True,
            ),
            # Female voices
            Voice(
                id=SEED_VOICE_IDS["bella"],
                name="Bella",
                vendor="OpenAI Compatible",
                gender="Female",
                language="en",
                description="Passionate female voice.",
                model=VOICE_MODEL,
                voice_key=f"{VOICE_MODEL}:bella",
                is_system=True,
            ),
            Voice(
                id=SEED_VOICE_IDS["claire"],
                name="Claire",
                vendor="OpenAI Compatible",
                gender="Female",
                language="en",
                description="Gentle female voice.",
                model=VOICE_MODEL,
                voice_key=f"{VOICE_MODEL}:claire",
                is_system=True,
            ),
        ]
        seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (OpenAI Compatible CosyVoice 2.0)")


def init_default_tools(recreate: bool = False) -> None:
    """Seed the default tools, or rebuild the tool table when ``recreate`` is true."""
    from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources

    with db_session() as db:
        if recreate:
            recreate_tool_resources(db)
            print("✅ 工具库已重建")
        else:
            _seed_default_tools_if_empty(db)
            print("✅ 默认工具已初始化")


def init_default_assistants() -> None:
    """Seed the default assistants when the assistant table is empty."""
    with db_session() as db:
        assistants = [
            Assistant(
                id=SEED_ASSISTANT_IDS["default"],
                user_id=1,
                name="AI 助手",
                call_count=0,
                opener="你好!我是AI助手,有什么可以帮你的吗?",
                prompt="你是一个友好的AI助手,请用简洁清晰的语言回答用户的问题。",
                language="zh",
                voice_output_enabled=True,
                voice=SEED_VOICE_IDS["bella"],
                speed=1.0,
                hotwords=[],
                tools=["current_time"],
                interruption_sensitivity=500,
                config_mode="platform",
                llm_model_id=SEED_LLM_IDS["deepseek_chat"],
                asr_model_id=SEED_ASR_IDS["sensevoice_small"],
            ),
            Assistant(
                id=SEED_ASSISTANT_IDS["customer_service"],
                user_id=1,
                name="客服助手",
                call_count=0,
                opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
                prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
                language="zh",
                voice_output_enabled=True,
                voice=SEED_VOICE_IDS["claire"],
                speed=1.0,
                hotwords=["客服", "投诉", "咨询"],
                tools=["current_time"],
                interruption_sensitivity=600,
                config_mode="platform",
            ),
            Assistant(
                id=SEED_ASSISTANT_IDS["english_tutor"],
                user_id=1,
                name="英语导师",
                call_count=0,
                opener="Hello! I'm your English learning companion. How can I help you today?",
                prompt=(
                    "You are a friendly English tutor. Help users practice English "
                    "conversation and explain grammar points clearly."
                ),
                language="en",
                voice_output_enabled=True,
                voice=SEED_VOICE_IDS["alex"],
                speed=1.0,
                hotwords=["grammar", "vocabulary", "practice"],
                tools=["current_time"],
                interruption_sensitivity=400,
                config_mode="platform",
            ),
        ]
        seed_if_empty(db, Assistant, assistants, "✅ 默认助手数据已初始化")


def init_default_workflows() -> None:
    """Seed the default workflows when the workflow table is empty."""
    from datetime import datetime, timezone

    with db_session() as db:
        # UTC wall-clock string in the same format as before; datetime.utcnow()
        # is deprecated since Python 3.12.
        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        workflows = [
            Workflow(
                id="simple_conversation",
                user_id=1,
                name="简单对话",
                node_count=2,
                created_at=now,
                updated_at=now,
                global_prompt="处理简单的对话流程,用户问什么答什么。",
                nodes=[
                    {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                    {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
                ],
                edges=[{"source": "1", "target": "2", "id": "e1-2"}],
            ),
            Workflow(
                id="voice_input_flow",
                user_id=1,
                name="语音输入流程",
                node_count=4,
                created_at=now,
                updated_at=now,
                global_prompt="处理语音输入的完整流程。",
                nodes=[
                    {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                    {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
                    {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
                    {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
                ],
                edges=[
                    {"source": "1", "target": "2", "id": "e1-2"},
                    {"source": "2", "target": "3", "id": "e2-3"},
                    {"source": "3", "target": "4", "id": "e3-4"},
                ],
            ),
        ]
        seed_if_empty(db, Workflow, workflows, "✅ 默认工作流数据已初始化")


def init_default_knowledge_bases() -> None:
    """Seed the default knowledge base when the knowledge-base table is empty."""
    with db_session() as db:
        kb = KnowledgeBase(
            id="default_kb",
            user_id=1,
            name="默认知识库",
            description="系统默认知识库,用于存储常见问题解答。",
            embedding_model="text-embedding-3-small",
            chunk_size=500,
            chunk_overlap=50,
            doc_count=0,
            chunk_count=0,
            status="active",
        )
        seed_if_empty(db, KnowledgeBase, kb, "✅ 默认知识库已初始化")


def init_default_llm_models() -> None:
    """Seed the default LLM / embedding model configs when the table is empty.

    API keys are placeholders that the user must replace.
    """
    with db_session() as db:
        llm_models = [
            LLMModel(
                id=SEED_LLM_IDS["deepseek_chat"],
                user_id=1,
                name="DeepSeek Chat",
                vendor="OpenAI Compatible",
                type="text",
                base_url="https://api.deepseek.com",
                api_key="YOUR_API_KEY",  # Placeholder: must be replaced by the user.
                model_name="deepseek-chat",
                temperature=0.7,
                context_length=4096,
                enabled=True,
            ),
            LLMModel(
                id=SEED_LLM_IDS["glm_4"],
                user_id=1,
                name="GLM-4",
                vendor="ZhipuAI",
                type="text",
                base_url="https://open.bigmodel.cn/api/paas/v4",
                api_key="YOUR_API_KEY",
                model_name="glm-4",
                temperature=0.7,
                context_length=8192,
                enabled=True,
            ),
            LLMModel(
                id=SEED_LLM_IDS["embedding_3_small"],
                user_id=1,
                name="Embedding 3 Small",
                vendor="OpenAI Compatible",
                type="embedding",
                base_url="https://api.openai.com/v1",
                api_key="YOUR_API_KEY",
                model_name="text-embedding-3-small",
                enabled=True,
            ),
        ]
        seed_if_empty(db, LLMModel, llm_models, "✅ 默认LLM模型已初始化")


def init_default_asr_models() -> None:
    """Seed the default ASR model configs when the ASR model table is empty."""
    with db_session() as db:
        asr_models = [
            ASRModel(
                id=SEED_ASR_IDS["sensevoice_small"],
                user_id=1,
                name="FunAudioLLM/SenseVoiceSmall",
                vendor="OpenAI Compatible",
                language="Multi-lingual",
                base_url="https://api.siliconflow.cn/v1",
                api_key="YOUR_API_KEY",
                model_name="FunAudioLLM/SenseVoiceSmall",
                hotwords=[],
                enable_punctuation=True,
                enable_normalization=True,
                enabled=True,
            ),
            ASRModel(
                id=SEED_ASR_IDS["telespeech_asr"],
                user_id=1,
                name="TeleAI/TeleSpeechASR",
                vendor="OpenAI Compatible",
                language="Multi-lingual",
                base_url="https://api.siliconflow.cn/v1",
                api_key="YOUR_API_KEY",
                model_name="TeleAI/TeleSpeechASR",
                hotwords=[],
                enable_punctuation=True,
                enable_normalization=True,
                enabled=True,
            ),
        ]
        seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化")


def _parse_args(argv=None) -> argparse.Namespace:
    """Build the CLI parser and parse ``argv`` (defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="初始化/重建 AI VideoAssistant 数据与向量库")
    parser.add_argument(
        "--rebuild-db",
        action="store_true",
        help="重建数据库(drop + create tables)",
    )
    parser.add_argument(
        "--rebuild-vector-store",
        action="store_true",
        help="重建向量库 KB 集合(清空后按 DB 的 knowledge_bases 重建 collection)",
    )
    parser.add_argument(
        "--skip-seed",
        action="store_true",
        help="跳过默认数据初始化",
    )
    parser.add_argument(
        "--recreate-tool-db",
        action="store_true",
        help="重建工具库数据(清空 tool_resources 后按内置默认工具重建)",
    )
    return parser.parse_args(argv)


def main(argv=None) -> None:
    """Script entry point: run the steps selected by the command-line flags."""
    args = _parse_args(argv)

    # With no options at all, keep the legacy behaviour: rebuild the DB and seed defaults.
    # rebuild-db is only switched on automatically when no option was given.
    if not (args.rebuild_db or args.rebuild_vector_store or args.skip_seed or args.recreate_tool_db):
        args.rebuild_db = True

    ensure_db_dir()

    if args.rebuild_db:
        init_db()
    else:
        print("ℹ️ 跳过数据库结构变更(未指定 --rebuild-db)")

    if not args.skip_seed or args.recreate_tool_db:
        # Only announce a non-destructive flow when it really is one: not right
        # after the tables were dropped, and not when tool_resources is about to
        # be cleared by --recreate-tool-db.
        if not (args.rebuild_db or args.recreate_tool_db):
            print("ℹ️ 当前将执行非破坏性流程(仅工具/默认数据初始化)")

        if args.recreate_tool_db:
            init_default_tools(recreate=True)

        if not args.skip_seed:
            init_default_data()
            # Tools were already rebuilt above when --recreate-tool-db was given.
            if not args.recreate_tool_db:
                init_default_tools(recreate=False)
            init_default_assistants()
            init_default_workflows()
            init_default_knowledge_bases()
            init_default_llm_models()
            init_default_asr_models()
            print("✅ 默认数据初始化完成")

    if args.rebuild_vector_store:
        rebuild_vector_store(reset_doc_status=True)

    print("🎉 初始化脚本执行完成!")


if __name__ == "__main__":
    main()