Remove redundant code in init_db
This commit is contained in:
138
api/init_db.py
138
api/init_db.py
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy import inspect, text
|
||||
|
||||
# 添加路径
|
||||
@@ -12,11 +13,43 @@ from app.db import Base, engine, DATABASE_URL
|
||||
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel, KnowledgeDocument
|
||||
|
||||
|
||||
def ensure_db_dir():
|
||||
"""确保 SQLite 数据目录存在。"""
|
||||
if not DATABASE_URL.startswith("sqlite:///"):
|
||||
return
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
data_dir = os.path.dirname(db_path)
|
||||
if data_dir:
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_session():
|
||||
"""统一管理 DB session 生命周期。"""
|
||||
from app.db import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def seed_if_empty(db, model_cls, records, success_msg: str):
|
||||
"""当目标表为空时写入默认记录。"""
|
||||
if db.query(model_cls).count() != 0:
|
||||
return
|
||||
if isinstance(records, list):
|
||||
db.add_all(records)
|
||||
else:
|
||||
db.add(records)
|
||||
db.commit()
|
||||
print(success_msg)
|
||||
|
||||
|
||||
def init_db():
|
||||
"""创建所有表"""
|
||||
# 确保 data 目录存在
|
||||
data_dir = os.path.dirname(DATABASE_URL.replace("sqlite:///", ""))
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
ensure_db_dir()
|
||||
|
||||
print("📦 创建数据库表...")
|
||||
Base.metadata.drop_all(bind=engine) # 删除旧表
|
||||
@@ -59,11 +92,9 @@ def migrate_db_schema():
|
||||
|
||||
def rebuild_vector_store(reset_doc_status: bool = True):
|
||||
"""重建知识库向量集合(按 DB 中的 KB 列表重建 collection 壳)。"""
|
||||
from app.db import SessionLocal
|
||||
from app.vector_store import vector_store
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
with db_session() as db:
|
||||
print("🧹 重建向量库集合...")
|
||||
kb_list = db.query(KnowledgeBase).all()
|
||||
|
||||
@@ -99,19 +130,11 @@ def rebuild_vector_store(reset_doc_status: bool = True):
|
||||
|
||||
db.commit()
|
||||
print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def init_default_data():
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import SessionLocal
|
||||
from app.models import Voice
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
with db_session() as db:
|
||||
# 检查是否已有数据
|
||||
if db.query(Voice).count() == 0:
|
||||
# SiliconFlow CosyVoice 2.0 预设声音 (8个)
|
||||
# 参考: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
|
||||
voices = [
|
||||
@@ -126,39 +149,25 @@ def init_default_data():
|
||||
Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
|
||||
description="Gentle female voice.", is_system=True),
|
||||
]
|
||||
for v in voices:
|
||||
db.add(v)
|
||||
db.commit()
|
||||
print("✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
|
||||
finally:
|
||||
db.close()
|
||||
seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
|
||||
|
||||
|
||||
def init_default_tools(recreate: bool = False):
|
||||
"""初始化默认工具,或按需重建工具表数据。"""
|
||||
from app.db import SessionLocal
|
||||
from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
with db_session() as db:
|
||||
if recreate:
|
||||
recreate_tool_resources(db)
|
||||
print("✅ 工具库已重建")
|
||||
else:
|
||||
_seed_default_tools_if_empty(db)
|
||||
print("✅ 默认工具已初始化")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def init_default_assistants():
|
||||
"""初始化默认助手"""
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if db.query(Assistant).count() == 0:
|
||||
with db_session() as db:
|
||||
assistants = [
|
||||
Assistant(
|
||||
id="default",
|
||||
@@ -211,23 +220,14 @@ def init_default_assistants():
|
||||
config_mode="platform",
|
||||
),
|
||||
]
|
||||
for a in assistants:
|
||||
db.add(a)
|
||||
db.commit()
|
||||
print("✅ 默认助手数据已初始化")
|
||||
finally:
|
||||
db.close()
|
||||
seed_if_empty(db, Assistant, assistants, "✅ 默认助手数据已初始化")
|
||||
|
||||
|
||||
def init_default_workflows():
|
||||
"""初始化默认工作流"""
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import SessionLocal
|
||||
from datetime import datetime
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if db.query(Workflow).count() == 0:
|
||||
with db_session() as db:
|
||||
now = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
|
||||
workflows = [
|
||||
Workflow(
|
||||
@@ -265,22 +265,12 @@ def init_default_workflows():
|
||||
],
|
||||
),
|
||||
]
|
||||
for w in workflows:
|
||||
db.add(w)
|
||||
db.commit()
|
||||
print("✅ 默认工作流数据已初始化")
|
||||
finally:
|
||||
db.close()
|
||||
seed_if_empty(db, Workflow, workflows, "✅ 默认工作流数据已初始化")
|
||||
|
||||
|
||||
def init_default_knowledge_bases():
|
||||
"""初始化默认知识库"""
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if db.query(KnowledgeBase).count() == 0:
|
||||
with db_session() as db:
|
||||
kb = KnowledgeBase(
|
||||
id="default_kb",
|
||||
user_id=1,
|
||||
@@ -293,21 +283,12 @@ def init_default_knowledge_bases():
|
||||
chunk_count=0,
|
||||
status="active",
|
||||
)
|
||||
db.add(kb)
|
||||
db.commit()
|
||||
print("✅ 默认知识库已初始化")
|
||||
finally:
|
||||
db.close()
|
||||
seed_if_empty(db, KnowledgeBase, kb, "✅ 默认知识库已初始化")
|
||||
|
||||
|
||||
def init_default_llm_models():
|
||||
"""初始化默认LLM模型"""
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if db.query(LLMModel).count() == 0:
|
||||
with db_session() as db:
|
||||
llm_models = [
|
||||
LLMModel(
|
||||
id="deepseek-chat",
|
||||
@@ -373,22 +354,12 @@ def init_default_llm_models():
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
for m in llm_models:
|
||||
db.add(m)
|
||||
db.commit()
|
||||
print("✅ 默认LLM模型已初始化")
|
||||
finally:
|
||||
db.close()
|
||||
seed_if_empty(db, LLMModel, llm_models, "✅ 默认LLM模型已初始化")
|
||||
|
||||
|
||||
def init_default_asr_models():
|
||||
"""初始化默认ASR模型"""
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if db.query(ASRModel).count() == 0:
|
||||
with db_session() as db:
|
||||
asr_models = [
|
||||
ASRModel(
|
||||
id="paraformer-v2",
|
||||
@@ -447,12 +418,7 @@ def init_default_asr_models():
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
for m in asr_models:
|
||||
db.add(m)
|
||||
db.commit()
|
||||
print("✅ 默认ASR模型已初始化")
|
||||
finally:
|
||||
db.close()
|
||||
seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -483,9 +449,7 @@ if __name__ == "__main__":
|
||||
if not args.rebuild_db and not args.rebuild_vector_store and not args.skip_seed:
|
||||
args.rebuild_db = True
|
||||
|
||||
# 确保 data 目录存在
|
||||
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
ensure_db_dir()
|
||||
|
||||
if args.rebuild_db:
|
||||
init_db()
|
||||
|
||||
Reference in New Issue
Block a user