diff --git a/api/init_db.py b/api/init_db.py index 8d8c137..fa8ff5d 100644 --- a/api/init_db.py +++ b/api/init_db.py @@ -3,6 +3,7 @@ import argparse import os import sys +from contextlib import contextmanager from sqlalchemy import inspect, text # 添加路径 @@ -12,11 +13,43 @@ from app.db import Base, engine, DATABASE_URL from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel, KnowledgeDocument +def ensure_db_dir(): + """确保 SQLite 数据目录存在。""" + if not DATABASE_URL.startswith("sqlite:///"): + return + db_path = DATABASE_URL.replace("sqlite:///", "") + data_dir = os.path.dirname(db_path) + if data_dir: + os.makedirs(data_dir, exist_ok=True) + + +@contextmanager +def db_session(): + """统一管理 DB session 生命周期。""" + from app.db import SessionLocal + + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def seed_if_empty(db, model_cls, records, success_msg: str): + """当目标表为空时写入默认记录。""" + if db.query(model_cls).count() != 0: + return + if isinstance(records, list): + db.add_all(records) + else: + db.add(records) + db.commit() + print(success_msg) + + def init_db(): """创建所有表""" - # 确保 data 目录存在 - data_dir = os.path.dirname(DATABASE_URL.replace("sqlite:///", "")) - os.makedirs(data_dir, exist_ok=True) + ensure_db_dir() print("📦 创建数据库表...") Base.metadata.drop_all(bind=engine) # 删除旧表 @@ -59,11 +92,9 @@ def migrate_db_schema(): def rebuild_vector_store(reset_doc_status: bool = True): """重建知识库向量集合(按 DB 中的 KB 列表重建 collection 壳)。""" - from app.db import SessionLocal from app.vector_store import vector_store - db = SessionLocal() - try: + with db_session() as db: print("🧹 重建向量库集合...") kb_list = db.query(KnowledgeBase).all() @@ -99,360 +130,295 @@ def rebuild_vector_store(reset_doc_status: bool = True): db.commit() print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)") - finally: - db.close() def init_default_data(): - from sqlalchemy.orm import Session - from app.db import SessionLocal - from app.models import Voice - - db = SessionLocal() - try: + with db_session() as db: # 检查是否已有数据 - if db.query(Voice).count() == 0: - # SiliconFlow CosyVoice 2.0 预设声音 (8个) - # 参考: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech - voices = [ - # 男声 (Male Voices) - Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en", - description="Steady male voice.", is_system=True), - Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en", - description="Cheerful male voice.", is_system=True), - # 女声 (Female Voices) - Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en", - description="Passionate female voice.", is_system=True), - Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en", - description="Gentle female voice.", is_system=True), - ] - for v in voices: - db.add(v) - db.commit() - print("✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)") - finally: - db.close() + # SiliconFlow CosyVoice 2.0 预设声音 (8个) + # 参考: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech + voices = [ + # 男声 (Male Voices) + Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en", + description="Steady male voice.", is_system=True), + Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en", + description="Cheerful male voice.", is_system=True), + # 女声 (Female Voices) + Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en", + description="Passionate female voice.", is_system=True), + Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en", + description="Gentle female voice.", is_system=True), + ] + seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)") def init_default_tools(recreate: bool = False): """初始化默认工具,或按需重建工具表数据。""" - from app.db import SessionLocal from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources - db = SessionLocal() - try: + with db_session() as db: if recreate: recreate_tool_resources(db) print("✅ 工具库已重建") else: _seed_default_tools_if_empty(db) print("✅ 默认工具已初始化") - finally: - db.close() def init_default_assistants(): """初始化默认助手""" - from sqlalchemy.orm import Session - from app.db import SessionLocal - - db = SessionLocal() - try: - if db.query(Assistant).count() == 0: - assistants = [ - Assistant( - id="default", - user_id=1, - name="AI 助手", - call_count=0, - opener="你好!我是AI助手,有什么可以帮你的吗?", - prompt="你是一个友好的AI助手,请用简洁清晰的语言回答用户的问题。", - language="zh", - voice_output_enabled=True, - voice="anna", - speed=1.0, - hotwords=[], - tools=["calculator", "current_time"], - interruption_sensitivity=500, - config_mode="platform", - llm_model_id="deepseek-chat", - asr_model_id="paraformer-v2", - ), - Assistant( - id="customer_service", - user_id=1, - name="客服助手", - call_count=0, - opener="您好,欢迎致电客服中心,请问有什么可以帮您?", - prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。", - language="zh", - voice_output_enabled=True, - voice="bella", - speed=1.0, - hotwords=["客服", "投诉", "咨询"], - tools=["current_time"], - interruption_sensitivity=600, - config_mode="platform", - ), - Assistant( - id="english_tutor", - user_id=1, - name="英语导师", - call_count=0, - opener="Hello! I'm your English learning companion. How can I help you today?", - prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.", - language="en", - voice_output_enabled=True, - voice="alex", - speed=1.0, - hotwords=["grammar", "vocabulary", "practice"], - tools=["calculator"], - interruption_sensitivity=400, - config_mode="platform", - ), - ] - for a in assistants: - db.add(a) - db.commit() - print("✅ 默认助手数据已初始化") - finally: - db.close() + with db_session() as db: + assistants = [ + Assistant( + id="default", + user_id=1, + name="AI 助手", + call_count=0, + opener="你好!我是AI助手,有什么可以帮你的吗?", + prompt="你是一个友好的AI助手,请用简洁清晰的语言回答用户的问题。", + language="zh", + voice_output_enabled=True, + voice="anna", + speed=1.0, + hotwords=[], + tools=["calculator", "current_time"], + interruption_sensitivity=500, + config_mode="platform", + llm_model_id="deepseek-chat", + asr_model_id="paraformer-v2", + ), + Assistant( + id="customer_service", + user_id=1, + name="客服助手", + call_count=0, + opener="您好,欢迎致电客服中心,请问有什么可以帮您?", + prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。", + language="zh", + voice_output_enabled=True, + voice="bella", + speed=1.0, + hotwords=["客服", "投诉", "咨询"], + tools=["current_time"], + interruption_sensitivity=600, + config_mode="platform", + ), + Assistant( + id="english_tutor", + user_id=1, + name="英语导师", + call_count=0, + opener="Hello! I'm your English learning companion. How can I help you today?", + prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.", + language="en", + voice_output_enabled=True, + voice="alex", + speed=1.0, + hotwords=["grammar", "vocabulary", "practice"], + tools=["calculator"], + interruption_sensitivity=400, + config_mode="platform", + ), + ] + seed_if_empty(db, Assistant, assistants, "✅ 默认助手数据已初始化") def init_default_workflows(): """初始化默认工作流""" - from sqlalchemy.orm import Session - from app.db import SessionLocal from datetime import datetime - db = SessionLocal() - try: - if db.query(Workflow).count() == 0: - now = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") - workflows = [ - Workflow( - id="simple_conversation", - user_id=1, - name="简单对话", - node_count=2, - created_at=now, - updated_at=now, - global_prompt="处理简单的对话流程,用户问什么答什么。", - nodes=[ - {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}}, - {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}}, - ], - edges=[{"source": "1", "target": "2", "id": "e1-2"}], - ), - Workflow( - id="voice_input_flow", - user_id=1, - name="语音输入流程", - node_count=4, - created_at=now, - updated_at=now, - global_prompt="处理语音输入的完整流程。", - nodes=[ - {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}}, - {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}}, - {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}}, - {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}}, - ], - edges=[ - {"source": "1", "target": "2", "id": "e1-2"}, - {"source": "2", "target": "3", "id": "e2-3"}, - {"source": "3", "target": "4", "id": "e3-4"}, - ], - ), - ] - for w in workflows: - db.add(w) - db.commit() - print("✅ 默认工作流数据已初始化") - finally: - db.close() + with db_session() as db: + now = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") + workflows = [ + Workflow( + id="simple_conversation", + user_id=1, + name="简单对话", + node_count=2, + created_at=now, + updated_at=now, + global_prompt="处理简单的对话流程,用户问什么答什么。", + nodes=[ + {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}}, + {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}}, + ], + edges=[{"source": "1", "target": "2", "id": "e1-2"}], + ), + Workflow( + id="voice_input_flow", + user_id=1, + name="语音输入流程", + node_count=4, + created_at=now, + updated_at=now, + global_prompt="处理语音输入的完整流程。", + nodes=[ + {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}}, + {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}}, + {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}}, + {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}}, + ], + edges=[ + {"source": "1", "target": "2", "id": "e1-2"}, + {"source": "2", "target": "3", "id": "e2-3"}, + {"source": "3", "target": "4", "id": "e3-4"}, + ], + ), + ] + seed_if_empty(db, Workflow, workflows, "✅ 默认工作流数据已初始化") def init_default_knowledge_bases(): """初始化默认知识库""" - from sqlalchemy.orm import Session - from app.db import SessionLocal - - db = SessionLocal() - try: - if db.query(KnowledgeBase).count() == 0: - kb = KnowledgeBase( - id="default_kb", - user_id=1, - name="默认知识库", - description="系统默认知识库,用于存储常见问题解答。", - embedding_model="text-embedding-3-small", - chunk_size=500, - chunk_overlap=50, - doc_count=0, - chunk_count=0, - status="active", - ) - db.add(kb) - db.commit() - print("✅ 默认知识库已初始化") - finally: - db.close() + with db_session() as db: + kb = KnowledgeBase( + id="default_kb", + user_id=1, + name="默认知识库", + description="系统默认知识库,用于存储常见问题解答。", + embedding_model="text-embedding-3-small", + chunk_size=500, + chunk_overlap=50, + doc_count=0, + chunk_count=0, + status="active", + ) + seed_if_empty(db, KnowledgeBase, kb, "✅ 默认知识库已初始化") def init_default_llm_models(): """初始化默认LLM模型""" - from sqlalchemy.orm import Session - from app.db import SessionLocal - - db = SessionLocal() - try: - if db.query(LLMModel).count() == 0: - llm_models = [ - LLMModel( - id="deepseek-chat", - user_id=1, - name="DeepSeek Chat", - vendor="SiliconFlow", - type="text", - base_url="https://api.deepseek.com", - api_key="YOUR_API_KEY", # 用户需替换 - model_name="deepseek-chat", - temperature=0.7, - context_length=4096, - enabled=True, - ), - LLMModel( - id="deepseek-reasoner", - user_id=1, - name="DeepSeek Reasoner", - vendor="SiliconFlow", - type="text", - base_url="https://api.deepseek.com", - api_key="YOUR_API_KEY", - model_name="deepseek-reasoner", - temperature=0.7, - context_length=4096, - enabled=True, - ), - LLMModel( - id="gpt-4o", - user_id=1, - name="GPT-4o", - vendor="OpenAI", - type="text", - base_url="https://api.openai.com/v1", - api_key="YOUR_API_KEY", - model_name="gpt-4o", - temperature=0.7, - context_length=16384, - enabled=True, - ), - LLMModel( - id="glm-4", - user_id=1, - name="GLM-4", - vendor="ZhipuAI", - type="text", - base_url="https://open.bigmodel.cn/api/paas/v4", - api_key="YOUR_API_KEY", - model_name="glm-4", - temperature=0.7, - context_length=8192, - enabled=True, - ), - LLMModel( - id="text-embedding-3-small", - user_id=1, - name="Embedding 3 Small", - vendor="OpenAI", - type="embedding", - base_url="https://api.openai.com/v1", - api_key="YOUR_API_KEY", - model_name="text-embedding-3-small", - enabled=True, - ), - ] - for m in llm_models: - db.add(m) - db.commit() - print("✅ 默认LLM模型已初始化") - finally: - db.close() + with db_session() as db: + llm_models = [ + LLMModel( + id="deepseek-chat", + user_id=1, + name="DeepSeek Chat", + vendor="SiliconFlow", + type="text", + base_url="https://api.deepseek.com", + api_key="YOUR_API_KEY", # 用户需替换 + model_name="deepseek-chat", + temperature=0.7, + context_length=4096, + enabled=True, + ), + LLMModel( + id="deepseek-reasoner", + user_id=1, + name="DeepSeek Reasoner", + vendor="SiliconFlow", + type="text", + base_url="https://api.deepseek.com", + api_key="YOUR_API_KEY", + model_name="deepseek-reasoner", + temperature=0.7, + context_length=4096, + enabled=True, + ), + LLMModel( + id="gpt-4o", + user_id=1, + name="GPT-4o", + vendor="OpenAI", + type="text", + base_url="https://api.openai.com/v1", + api_key="YOUR_API_KEY", + model_name="gpt-4o", + temperature=0.7, + context_length=16384, + enabled=True, + ), + LLMModel( + id="glm-4", + user_id=1, + name="GLM-4", + vendor="ZhipuAI", + type="text", + base_url="https://open.bigmodel.cn/api/paas/v4", + api_key="YOUR_API_KEY", + model_name="glm-4", + temperature=0.7, + context_length=8192, + enabled=True, + ), + LLMModel( + id="text-embedding-3-small", + user_id=1, + name="Embedding 3 Small", + vendor="OpenAI", + type="embedding", + base_url="https://api.openai.com/v1", + api_key="YOUR_API_KEY", + model_name="text-embedding-3-small", + enabled=True, + ), + ] + seed_if_empty(db, LLMModel, llm_models, "✅ 默认LLM模型已初始化") def init_default_asr_models(): """初始化默认ASR模型""" - from sqlalchemy.orm import Session - from app.db import SessionLocal - - db = SessionLocal() - try: - if db.query(ASRModel).count() == 0: - asr_models = [ - ASRModel( - id="paraformer-v2", - user_id=1, - name="Paraformer V2", - vendor="SiliconFlow", - language="zh", - base_url="https://api.siliconflow.cn/v1", - api_key="YOUR_API_KEY", - model_name="paraformer-v2", - hotwords=["人工智能", "机器学习"], - enable_punctuation=True, - enable_normalization=True, - enabled=True, - ), - ASRModel( - id="paraformer-en", - user_id=1, - name="Paraformer English", - vendor="SiliconFlow", - language="en", - base_url="https://api.siliconflow.cn/v1", - api_key="YOUR_API_KEY", - model_name="paraformer-en", - hotwords=[], - enable_punctuation=True, - enable_normalization=True, - enabled=True, - ), - ASRModel( - id="whisper-1", - user_id=1, - name="Whisper", - vendor="OpenAI", - language="Multi-lingual", - base_url="https://api.openai.com/v1", - api_key="YOUR_API_KEY", - model_name="whisper-1", - hotwords=[], - enable_punctuation=True, - enable_normalization=True, - enabled=True, - ), - ASRModel( - id="sensevoice", - user_id=1, - name="SenseVoice", - vendor="SiliconFlow", - language="Multi-lingual", - base_url="https://api.siliconflow.cn/v1", - api_key="YOUR_API_KEY", - model_name="sensevoice", - hotwords=[], - enable_punctuation=True, - enable_normalization=True, - enabled=True, - ), - ] - for m in asr_models: - db.add(m) - db.commit() - print("✅ 默认ASR模型已初始化") - finally: - db.close() + with db_session() as db: + asr_models = [ + ASRModel( + id="paraformer-v2", + user_id=1, + name="Paraformer V2", + vendor="SiliconFlow", + language="zh", + base_url="https://api.siliconflow.cn/v1", + api_key="YOUR_API_KEY", + model_name="paraformer-v2", + hotwords=["人工智能", "机器学习"], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), + ASRModel( + id="paraformer-en", + user_id=1, + name="Paraformer English", + vendor="SiliconFlow", + language="en", + base_url="https://api.siliconflow.cn/v1", + api_key="YOUR_API_KEY", + model_name="paraformer-en", + hotwords=[], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), + ASRModel( + id="whisper-1", + user_id=1, + name="Whisper", + vendor="OpenAI", + language="Multi-lingual", + base_url="https://api.openai.com/v1", + api_key="YOUR_API_KEY", + model_name="whisper-1", + hotwords=[], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), + ASRModel( + id="sensevoice", + user_id=1, + name="SenseVoice", + vendor="SiliconFlow", + language="Multi-lingual", + base_url="https://api.siliconflow.cn/v1", + api_key="YOUR_API_KEY", + model_name="sensevoice", + hotwords=[], + enable_punctuation=True, + enable_normalization=True, + enabled=True, + ), + ] + seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化") if __name__ == "__main__": @@ -483,9 +449,7 @@ if __name__ == "__main__": if not args.rebuild_db and not args.rebuild_vector_store and not args.skip_seed: args.rebuild_db = True - # 确保 data 目录存在 - data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data") - os.makedirs(data_dir, exist_ok=True) + ensure_db_dir() if args.rebuild_db: init_db()