Remove redundant code in init_db

This commit is contained in:
Xin Wang
2026-02-12 15:59:21 +08:00
parent 399c9c97b1
commit bbfb5570cc

View File

@@ -3,6 +3,7 @@
import argparse import argparse
import os import os
import sys import sys
from contextlib import contextmanager
from sqlalchemy import inspect, text from sqlalchemy import inspect, text
# 添加路径 # 添加路径
@@ -12,11 +13,43 @@ from app.db import Base, engine, DATABASE_URL
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel, KnowledgeDocument from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel, KnowledgeDocument
def ensure_db_dir():
    """Create the parent directory of the SQLite database file if needed.

    Nothing happens when ``DATABASE_URL`` does not start with ``sqlite:///``,
    or when the path taken from the URL has no directory component.
    """
    sqlite_prefix = "sqlite:///"
    if DATABASE_URL.startswith(sqlite_prefix):
        parent_dir = os.path.dirname(DATABASE_URL.replace(sqlite_prefix, ""))
        if parent_dir:
            # exist_ok=True makes repeated calls harmless.
            os.makedirs(parent_dir, exist_ok=True)
@contextmanager
def db_session():
    """Yield a new ``SessionLocal()`` session and always close it on exit.

    The session is closed whether the ``with`` body completes or raises.
    No commit is issued here; callers commit their own work.
    """
    # Imported inside the function, exactly as the original code does.
    from app.db import SessionLocal

    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def seed_if_empty(db, model_cls, records, success_msg: str):
    """Insert default records only when the target table has no rows.

    Args:
        db: Session-like object providing ``query``, ``add``, ``add_all``,
            ``commit`` and ``rollback``.
        model_cls: Mapped class whose table is checked for existing rows.
        records: A single record, or a list/tuple of records, to insert.
        success_msg: Message printed once the records have been committed.

    Raises:
        Exception: Whatever ``db.commit()`` raises; the session is rolled
            back before the error is re-raised.
    """
    if db.query(model_cls).count() != 0:
        return

    # Tuples are treated like lists; previously a tuple was handed to
    # db.add() as if it were a single record.
    if isinstance(records, (list, tuple)):
        db.add_all(records)
    else:
        db.add(records)

    try:
        db.commit()
    except Exception:
        # Leave the session usable for the caller after a failed commit.
        db.rollback()
        raise
    print(success_msg)
def init_db(): def init_db():
"""创建所有表""" """创建所有表"""
# 确保 data 目录存在 ensure_db_dir()
data_dir = os.path.dirname(DATABASE_URL.replace("sqlite:///", ""))
os.makedirs(data_dir, exist_ok=True)
print("📦 创建数据库表...") print("📦 创建数据库表...")
Base.metadata.drop_all(bind=engine) # 删除旧表 Base.metadata.drop_all(bind=engine) # 删除旧表
@@ -59,11 +92,9 @@ def migrate_db_schema():
def rebuild_vector_store(reset_doc_status: bool = True): def rebuild_vector_store(reset_doc_status: bool = True):
"""重建知识库向量集合(按 DB 中的 KB 列表重建 collection 壳)。""" """重建知识库向量集合(按 DB 中的 KB 列表重建 collection 壳)。"""
from app.db import SessionLocal
from app.vector_store import vector_store from app.vector_store import vector_store
db = SessionLocal() with db_session() as db:
try:
print("🧹 重建向量库集合...") print("🧹 重建向量库集合...")
kb_list = db.query(KnowledgeBase).all() kb_list = db.query(KnowledgeBase).all()
@@ -99,360 +130,295 @@ def rebuild_vector_store(reset_doc_status: bool = True):
db.commit() db.commit()
print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)") print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)")
finally:
db.close()
def init_default_data():
    """Seed the default SiliconFlow CosyVoice 2.0 voices if the Voice table is empty."""
    # Reference: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
    # Columns: voice id, display name, gender, description.
    # Male voices come first, then female voices.
    # NOTE(review): the original comment mentions 8 preset voices, but only
    # these four are seeded here.
    presets = (
        ("alex", "Alex", "Male", "Steady male voice."),
        ("david", "David", "Male", "Cheerful male voice."),
        ("bella", "Bella", "Female", "Passionate female voice."),
        ("claire", "Claire", "Female", "Gentle female voice."),
    )
    with db_session() as db:
        voices = [
            Voice(id=voice_id, name=display_name, vendor="SiliconFlow", gender=gender,
                  language="en", description=description, is_system=True)
            for voice_id, display_name, gender, description in presets
        ]
        seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
def init_default_tools(recreate: bool = False):
    """Seed the default tools, or rebuild the tool data when ``recreate`` is true."""
    from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources

    with db_session() as session:
        if not recreate:
            # Default path: delegate seeding to the tools router helper.
            _seed_default_tools_if_empty(session)
            print("✅ 默认工具已初始化")
            return
        recreate_tool_resources(session)
        print("✅ 工具库已重建")
def init_default_assistants():
    """Seed three default assistants if the Assistant table is empty."""
    # Settings that are identical for every default assistant.
    shared = {
        "user_id": 1,
        "call_count": 0,
        "voice_output_enabled": True,
        "speed": 1.0,
        "config_mode": "platform",
    }
    with db_session() as db:
        # NOTE(review): voice "anna" is not among the voices seeded by
        # init_default_data (alex, david, bella, claire) — confirm it exists.
        general = Assistant(
            id="default",
            name="AI 助手",
            opener="你好我是AI助手有什么可以帮你的吗",
            prompt="你是一个友好的AI助手请用简洁清晰的语言回答用户的问题。",
            language="zh",
            voice="anna",
            hotwords=[],
            tools=["calculator", "current_time"],
            interruption_sensitivity=500,
            llm_model_id="deepseek-chat",
            asr_model_id="paraformer-v2",
            **shared,
        )
        support = Assistant(
            id="customer_service",
            name="客服助手",
            opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
            prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
            language="zh",
            voice="bella",
            hotwords=["客服", "投诉", "咨询"],
            tools=["current_time"],
            interruption_sensitivity=600,
            **shared,
        )
        tutor = Assistant(
            id="english_tutor",
            name="英语导师",
            opener="Hello! I'm your English learning companion. How can I help you today?",
            prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.",
            language="en",
            voice="alex",
            hotwords=["grammar", "vocabulary", "practice"],
            tools=["calculator"],
            interruption_sensitivity=400,
            **shared,
        )
        seed_if_empty(db, Assistant, [general, support, tutor], "✅ 默认助手数据已初始化")
def init_default_workflows():
    """Seed two default workflows if the Workflow table is empty.

    Both workflows get the same ``created_at``/``updated_at`` text: the
    current UTC time formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    from datetime import datetime, timezone

    with db_session() as db:
        # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
        # datetime formats to exactly the same string.
        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        workflows = [
            Workflow(
                id="simple_conversation",
                user_id=1,
                name="简单对话",
                node_count=2,
                created_at=now,
                updated_at=now,
                global_prompt="处理简单的对话流程,用户问什么答什么。",
                nodes=[
                    {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                    {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
                ],
                edges=[{"source": "1", "target": "2", "id": "e1-2"}],
            ),
            Workflow(
                id="voice_input_flow",
                user_id=1,
                name="语音输入流程",
                node_count=4,
                created_at=now,
                updated_at=now,
                global_prompt="处理语音输入的完整流程。",
                nodes=[
                    {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                    {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
                    {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
                    {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
                ],
                edges=[
                    {"source": "1", "target": "2", "id": "e1-2"},
                    {"source": "2", "target": "3", "id": "e2-3"},
                    {"source": "3", "target": "4", "id": "e3-4"},
                ],
            ),
        ]
        seed_if_empty(db, Workflow, workflows, "✅ 默认工作流数据已初始化")
def init_default_knowledge_bases():
    """Seed a single default knowledge base if the KnowledgeBase table is empty."""
    kb_fields = {
        "id": "default_kb",
        "user_id": 1,
        "name": "默认知识库",
        "description": "系统默认知识库,用于存储常见问题解答。",
        "embedding_model": "text-embedding-3-small",
        "chunk_size": 500,
        "chunk_overlap": 50,
        "doc_count": 0,
        "chunk_count": 0,
        "status": "active",
    }
    with db_session() as db:
        # A single record (not a list) is passed, so seed_if_empty uses db.add().
        default_kb = KnowledgeBase(**kb_fields)
        seed_if_empty(db, KnowledgeBase, default_kb, "✅ 默认知识库已初始化")
def init_default_llm_models():
    """Seed four chat models and one embedding model if the LLMModel table is empty."""
    placeholder_key = "YOUR_API_KEY"  # users must replace this with a real key
    deepseek_url = "https://api.deepseek.com"
    openai_url = "https://api.openai.com/v1"

    def chat_model(model_id, display_name, vendor, base_url, context_length):
        # Every default chat model uses its id as model_name and temperature 0.7.
        return LLMModel(
            id=model_id,
            user_id=1,
            name=display_name,
            vendor=vendor,
            type="text",
            base_url=base_url,
            api_key=placeholder_key,
            model_name=model_id,
            temperature=0.7,
            context_length=context_length,
            enabled=True,
        )

    with db_session() as db:
        llm_models = [
            chat_model("deepseek-chat", "DeepSeek Chat", "SiliconFlow", deepseek_url, 4096),
            chat_model("deepseek-reasoner", "DeepSeek Reasoner", "SiliconFlow", deepseek_url, 4096),
            chat_model("gpt-4o", "GPT-4o", "OpenAI", openai_url, 16384),
            chat_model("glm-4", "GLM-4", "ZhipuAI", "https://open.bigmodel.cn/api/paas/v4", 8192),
            # The embedding model sets neither temperature nor context_length.
            LLMModel(
                id="text-embedding-3-small",
                user_id=1,
                name="Embedding 3 Small",
                vendor="OpenAI",
                type="embedding",
                base_url=openai_url,
                api_key=placeholder_key,
                model_name="text-embedding-3-small",
                enabled=True,
            ),
        ]
        seed_if_empty(db, LLMModel, llm_models, "✅ 默认LLM模型已初始化")
def init_default_asr_models():
    """Seed four default ASR models if the ASRModel table is empty."""
    siliconflow_url = "https://api.siliconflow.cn/v1"
    # Columns: id (also used as model_name), display name, vendor, language,
    # base URL, hotwords. Each row carries its own hotwords list object.
    presets = (
        ("paraformer-v2", "Paraformer V2", "SiliconFlow", "zh", siliconflow_url, ["人工智能", "机器学习"]),
        ("paraformer-en", "Paraformer English", "SiliconFlow", "en", siliconflow_url, []),
        ("whisper-1", "Whisper", "OpenAI", "Multi-lingual", "https://api.openai.com/v1", []),
        ("sensevoice", "SenseVoice", "SiliconFlow", "Multi-lingual", siliconflow_url, []),
    )
    with db_session() as db:
        asr_models = [
            ASRModel(
                id=model_id,
                user_id=1,
                name=display_name,
                vendor=vendor,
                language=language,
                base_url=base_url,
                api_key="YOUR_API_KEY",
                model_name=model_id,
                hotwords=hotwords,
                enable_punctuation=True,
                enable_normalization=True,
                enabled=True,
            )
            for model_id, display_name, vendor, language, base_url, hotwords in presets
        ]
        seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化")
if __name__ == "__main__": if __name__ == "__main__":
@@ -483,9 +449,7 @@ if __name__ == "__main__":
if not args.rebuild_db and not args.rebuild_vector_store and not args.skip_seed: if not args.rebuild_db and not args.rebuild_vector_store and not args.skip_seed:
args.rebuild_db = True args.rebuild_db = True
# 确保 data 目录存在 ensure_db_dir()
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
os.makedirs(data_dir, exist_ok=True)
if args.rebuild_db: if args.rebuild_db:
init_db() init_db()