Files
AI-VideoAssistant/api/init_db.py
2026-02-12 15:59:21 +08:00

477 lines
17 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""初始化数据库"""
import argparse
import os
import sys
from contextlib import contextmanager
from sqlalchemy import inspect, text
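# Put this script's own directory first on sys.path so the `app` package imports below resolve when the script is run directly.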
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.db import Base, engine, DATABASE_URL
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel, KnowledgeDocument
def ensure_db_dir():
    """Create the directory that holds the SQLite database file, if any.

    Nothing happens unless ``DATABASE_URL`` starts with ``sqlite:///``, or
    when the path left after removing that marker has no directory part.
    An already existing directory is not an error.
    """
    sqlite_marker = "sqlite:///"
    if DATABASE_URL.startswith(sqlite_marker):
        parent_dir = os.path.dirname(DATABASE_URL.replace(sqlite_marker, ""))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
@contextmanager
def db_session():
    """Yield a new ``SessionLocal()`` session and close it on exit.

    The session is closed in a ``finally`` clause, so it is released whether
    the ``with`` body finishes normally or raises. Committing is left to the
    caller.
    """
    # Imported here rather than at module level, as in the original code.
    from app.db import SessionLocal

    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def seed_if_empty(db, model_cls, records, success_msg: str):
    """Insert default records only when the model's table has no rows.

    Args:
        db: Session object; ``query``, ``add``/``add_all``, ``commit`` and
            (on failure) ``rollback`` are called on it.
        model_cls: Mapped class whose row count decides whether to seed.
        records: A single instance, or a list/tuple of instances.
        success_msg: Message printed after the commit succeeds.

    Raises:
        Exception: Whatever ``db.commit()`` raises is re-raised after the
            session has been rolled back.
    """
    if db.query(model_cls).count() != 0:
        return

    # Tuples are batches too; previously only ``list`` took the batch path.
    if isinstance(records, (list, tuple)):
        db.add_all(records)
    else:
        db.add(records)

    try:
        db.commit()
    except Exception:
        # Leave the session in a clean state before propagating the error.
        db.rollback()
        raise
    print(success_msg)
def init_db():
    """Drop and recreate every table registered on ``Base.metadata``.

    Destructive: ``drop_all`` runs before ``create_all``, so existing tables
    known to the metadata are removed first.
    """
    ensure_db_dir()
    print("📦 创建数据库表...")
    metadata = Base.metadata
    metadata.drop_all(bind=engine)  # remove the old tables first
    metadata.create_all(bind=engine)
    print("✅ 数据库表创建完成")
def migrate_db_schema():
    """Apply additive, non-destructive schema changes to an existing database.

    Only the ``assistants`` table is handled. Each column in the table below
    is added with ``ALTER TABLE ... ADD COLUMN`` when the inspector does not
    report it. If the table itself is missing the migration is skipped.
    All pending statements run inside one ``engine.begin()`` block.
    """
    inspector = inspect(engine)
    if "assistants" not in set(inspector.get_table_names()):
        print(" assistants 表不存在,跳过增量迁移")
        return

    # (column name, column definition) for every column that may be missing.
    expected_columns = (
        ("generated_opener_enabled", "BOOLEAN DEFAULT 0"),
        ("first_turn_mode", "VARCHAR(32) DEFAULT 'bot_first'"),
        ("bot_cannot_be_interrupted", "BOOLEAN DEFAULT 0"),
    )
    present = {col["name"] for col in inspector.get_columns("assistants")}
    alter_statements = [
        f"ALTER TABLE assistants ADD COLUMN {name} {definition}"
        for name, definition in expected_columns
        if name not in present
    ]

    if not alter_statements:
        print("✅ Schema 迁移检查完成(无需变更)")
        return

    with engine.begin() as conn:
        for stmt in alter_statements:
            conn.execute(text(stmt))
    # The closing parenthesis was missing from this message.
    print(f"✅ Schema 迁移完成(应用 {len(alter_statements)} 条 ALTER)")
def rebuild_vector_store(reset_doc_status: bool = True):
    """Recreate the knowledge-base collections in the vector store.

    Every existing collection whose name starts with ``kb_`` is deleted (a
    failed deletion is reported and skipped), then one collection is created
    per ``KnowledgeBase`` row. Only empty collection shells are created; as
    the final message says, documents have to be re-indexed afterwards.

    Args:
        reset_doc_status: When true, zero ``chunk_count`` and ``doc_count`` on
            each knowledge base and reset each of its documents to
            ``status="pending"`` with ``chunk_count=0`` and cleared
            ``error_message`` / ``processed_at``.

    Raises:
        RuntimeError: If the list of collections cannot be read.
    """
    from app.vector_store import vector_store

    with db_session() as db:
        print("🧹 重建向量库集合...")
        knowledge_bases = db.query(KnowledgeBase).all()

        # Step 1: drop the KB collections that currently exist.
        try:
            existing = vector_store.client.list_collections()
        except Exception as exc:
            raise RuntimeError(f"无法读取向量集合列表: {exc}") from exc

        for entry in existing:
            # Entries may be objects with a ``name`` attribute or plain values.
            name = getattr(entry, "name", None) or str(entry)
            if not name.startswith("kb_"):
                continue
            try:
                vector_store.client.delete_collection(name=name)
                print(f" - removed {name}")
            except Exception as exc:
                # Best effort: report and carry on with the next collection.
                print(f" - skip remove {name}: {exc}")

        # Step 2: create one collection per knowledge base stored in the DB.
        for kb in knowledge_bases:
            vector_store.create_collection(kb.id, kb.embedding_model)
            print(f" + created kb_{kb.id} ({kb.embedding_model})")
            if not reset_doc_status:
                continue
            kb.chunk_count = 0
            kb_docs = (
                db.query(KnowledgeDocument)
                .filter(KnowledgeDocument.kb_id == kb.id)
                .all()
            )
            # NOTE(review): doc_count is zeroed although the document rows are
            # kept (only reset to pending) - confirm this is intended.
            kb.doc_count = 0
            for doc in kb_docs:
                doc.chunk_count = 0
                doc.status = "pending"
                doc.error_message = None
                doc.processed_at = None

        db.commit()
        print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)")
def init_default_data():
    """Seed the SiliconFlow CosyVoice 2.0 preset voices when none exist.

    All eight presets are seeded. Previously only four were, although the
    default assistant in this file uses ``voice="anna"``, which was missing.
    The emptiness check is done by ``seed_if_empty``.

    Reference: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
    """
    # (voice id, display name, gender, description)
    preset_voices = (
        # Male voices
        ("alex", "Alex", "Male", "Steady male voice."),
        ("benjamin", "Benjamin", "Male", "Deep male voice."),
        ("charles", "Charles", "Male", "Magnetic male voice."),
        ("david", "David", "Male", "Cheerful male voice."),
        # Female voices
        ("anna", "Anna", "Female", "Steady female voice."),
        ("bella", "Bella", "Female", "Passionate female voice."),
        ("claire", "Claire", "Female", "Gentle female voice."),
        ("diana", "Diana", "Female", "Cheerful female voice."),
    )
    with db_session() as db:
        voices = [
            Voice(
                id=voice_id,
                name=display_name,
                vendor="SiliconFlow",
                gender=gender,
                language="en",  # same language tag the original four used
                description=description,
                is_system=True,
            )
            for voice_id, display_name, gender, description in preset_voices
        ]
        seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
def init_default_tools(recreate: bool = False):
    """Seed the default tools, or rebuild the tool data on request.

    Args:
        recreate: When true, call ``recreate_tool_resources``; otherwise call
            ``_seed_default_tools_if_empty``. Both receive the open session.
    """
    from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources

    # Pick the action and its confirmation message together.
    if recreate:
        action, done_msg = recreate_tool_resources, "✅ 工具库已重建"
    else:
        action, done_msg = _seed_default_tools_if_empty, "✅ 默认工具已初始化"

    with db_session() as db:
        action(db)
        print(done_msg)
def init_default_assistants():
    """Seed three built-in assistants when the assistants table is empty.

    The assistants are ``default``, ``customer_service`` and
    ``english_tutor``. Only ``default`` is given explicit LLM and ASR model
    ids.
    """
    with db_session() as db:
        # Field values that are identical for every default assistant.
        shared = dict(
            user_id=1,
            call_count=0,
            voice_output_enabled=True,
            speed=1.0,
            config_mode="platform",
        )
        general = Assistant(
            id="default",
            name="AI 助手",
            opener="你好我是AI助手有什么可以帮你的吗",
            prompt="你是一个友好的AI助手请用简洁清晰的语言回答用户的问题。",
            language="zh",
            voice="anna",
            hotwords=[],
            tools=["calculator", "current_time"],
            interruption_sensitivity=500,
            llm_model_id="deepseek-chat",
            asr_model_id="paraformer-v2",
            **shared,
        )
        support = Assistant(
            id="customer_service",
            name="客服助手",
            opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
            prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
            language="zh",
            voice="bella",
            hotwords=["客服", "投诉", "咨询"],
            tools=["current_time"],
            interruption_sensitivity=600,
            **shared,
        )
        tutor = Assistant(
            id="english_tutor",
            name="英语导师",
            opener="Hello! I'm your English learning companion. How can I help you today?",
            prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.",
            language="en",
            voice="alex",
            hotwords=["grammar", "vocabulary", "practice"],
            tools=["calculator"],
            interruption_sensitivity=400,
            **shared,
        )
        seed_if_empty(db, Assistant, [general, support, tutor], "✅ 默认助手数据已初始化")
def init_default_workflows():
    """Seed two example workflows when the workflows table is empty.

    ``created_at`` and ``updated_at`` are set to the current UTC time,
    formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    from datetime import datetime, timezone

    with db_session() as db:
        # datetime.utcnow() is deprecated since Python 3.12; an aware "now"
        # in UTC yields the same formatted string.
        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        workflows = [
            # Minimal flow: start node -> AI reply node.
            Workflow(
                id="simple_conversation",
                user_id=1,
                name="简单对话",
                node_count=2,
                created_at=now,
                updated_at=now,
                global_prompt="处理简单的对话流程,用户问什么答什么。",
                nodes=[
                    {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                    {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
                ],
                edges=[{"source": "1", "target": "2", "id": "e1-2"}],
            ),
            # Voice pipeline: start -> ASR -> LLM -> TTS.
            Workflow(
                id="voice_input_flow",
                user_id=1,
                name="语音输入流程",
                node_count=4,
                created_at=now,
                updated_at=now,
                global_prompt="处理语音输入的完整流程。",
                nodes=[
                    {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                    {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
                    {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
                    {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
                ],
                edges=[
                    {"source": "1", "target": "2", "id": "e1-2"},
                    {"source": "2", "target": "3", "id": "e2-3"},
                    {"source": "3", "target": "4", "id": "e3-4"},
                ],
            ),
        ]
        seed_if_empty(db, Workflow, workflows, "✅ 默认工作流数据已初始化")
def init_default_knowledge_bases():
    """Seed a single ``default_kb`` knowledge base when the table is empty.

    The record starts with zero documents and chunks and uses the
    ``text-embedding-3-small`` embedding model.
    """
    with db_session() as db:
        default_fields = {
            "id": "default_kb",
            "user_id": 1,
            "name": "默认知识库",
            "description": "系统默认知识库,用于存储常见问题解答。",
            "embedding_model": "text-embedding-3-small",
            "chunk_size": 500,
            "chunk_overlap": 50,
            "doc_count": 0,
            "chunk_count": 0,
            "status": "active",
        }
        seed_if_empty(
            db,
            KnowledgeBase,
            KnowledgeBase(**default_fields),
            "✅ 默认知识库已初始化",
        )
def init_default_llm_models():
    """Seed four chat models and one embedding model when the table is empty.

    Every record is enabled, owned by ``user_id=1`` and carries a placeholder
    API key. Each record's ``id`` equals its ``model_name``.
    """
    with db_session() as db:
        api_key_placeholder = "YOUR_API_KEY"  # placeholder; the user must replace it
        deepseek_url = "https://api.deepseek.com"
        openai_url = "https://api.openai.com/v1"

        # (id and model_name, display name, vendor, base_url, context_length)
        # NOTE(review): the DeepSeek entries are tagged vendor "SiliconFlow"
        # but point at api.deepseek.com - confirm which is intended.
        chat_specs = [
            ("deepseek-chat", "DeepSeek Chat", "SiliconFlow", deepseek_url, 4096),
            ("deepseek-reasoner", "DeepSeek Reasoner", "SiliconFlow", deepseek_url, 4096),
            ("gpt-4o", "GPT-4o", "OpenAI", openai_url, 16384),
            ("glm-4", "GLM-4", "ZhipuAI", "https://open.bigmodel.cn/api/paas/v4", 8192),
        ]
        llm_models = [
            LLMModel(
                id=model_id,
                user_id=1,
                name=display_name,
                vendor=vendor,
                type="text",
                base_url=base_url,
                api_key=api_key_placeholder,
                model_name=model_id,
                temperature=0.7,
                context_length=context_length,
                enabled=True,
            )
            for model_id, display_name, vendor, base_url, context_length in chat_specs
        ]
        # The embedding model sets neither temperature nor context_length.
        llm_models.append(
            LLMModel(
                id="text-embedding-3-small",
                user_id=1,
                name="Embedding 3 Small",
                vendor="OpenAI",
                type="embedding",
                base_url=openai_url,
                api_key=api_key_placeholder,
                model_name="text-embedding-3-small",
                enabled=True,
            )
        )
        seed_if_empty(db, LLMModel, llm_models, "✅ 默认LLM模型已初始化")
def init_default_asr_models():
    """Seed four speech-recognition models when the ASR table is empty.

    Every record is enabled, owned by ``user_id=1``, has punctuation and
    normalization switched on, and carries a placeholder API key. Each
    record's ``id`` equals its ``model_name``.
    """
    with db_session() as db:
        siliconflow_url = "https://api.siliconflow.cn/v1"

        # (id and model_name, display name, vendor, language, base_url, hotwords)
        asr_specs = [
            ("paraformer-v2", "Paraformer V2", "SiliconFlow", "zh", siliconflow_url, ["人工智能", "机器学习"]),
            ("paraformer-en", "Paraformer English", "SiliconFlow", "en", siliconflow_url, []),
            ("whisper-1", "Whisper", "OpenAI", "Multi-lingual", "https://api.openai.com/v1", []),
            ("sensevoice", "SenseVoice", "SiliconFlow", "Multi-lingual", siliconflow_url, []),
        ]
        asr_models = [
            ASRModel(
                id=model_id,
                user_id=1,
                name=display_name,
                vendor=vendor,
                language=language,
                base_url=base_url,
                api_key="YOUR_API_KEY",
                model_name=model_id,
                hotwords=hotwords,
                enable_punctuation=True,
                enable_normalization=True,
                enabled=True,
            )
            for model_id, display_name, vendor, language, base_url, hotwords in asr_specs
        ]
        seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化")
if __name__ == "__main__":
    # Command-line entry point. Steps run in this order:
    #   1. rebuild the schema (drop + create) or run the additive migration
    #   2. optionally recreate the tool data
    #   3. seed default data unless --skip-seed is given
    #   4. optionally rebuild the vector-store collections
    parser = argparse.ArgumentParser(description="初始化/重建 AI VideoAssistant 数据与向量库")
    parser.add_argument(
        "--rebuild-db",
        action="store_true",
        help="重建数据库drop + create tables",
    )
    parser.add_argument(
        "--rebuild-vector-store",
        action="store_true",
        help="重建向量库 KB 集合(清空后按 DB 的 knowledge_bases 重建 collection)",
    )
    parser.add_argument(
        "--skip-seed",
        action="store_true",
        help="跳过默认数据初始化",
    )
    parser.add_argument(
        "--recreate-tool-db",
        action="store_true",
        help="重建工具库数据(清空 tool_resources 后按内置默认工具重建)",
    )
    args = parser.parse_args()

    # With no flags at all, keep the legacy behavior: rebuild the DB and seed
    # the defaults. --recreate-tool-db counts as an explicit flag too;
    # without that, passing it alone silently dropped every table.
    any_flag_given = (
        args.rebuild_db
        or args.rebuild_vector_store
        or args.skip_seed
        or args.recreate_tool_db
    )
    if not any_flag_given:
        args.rebuild_db = True

    ensure_db_dir()

    if args.rebuild_db:
        init_db()
    else:
        migrate_db_schema()

    if args.recreate_tool_db:
        init_default_tools(recreate=True)

    if not args.skip_seed:
        init_default_data()
        # Tools were already rebuilt above when --recreate-tool-db was given.
        if not args.recreate_tool_db:
            init_default_tools(recreate=False)
        init_default_assistants()
        init_default_workflows()
        init_default_knowledge_bases()
        init_default_llm_models()
        init_default_asr_models()
        print("✅ 默认数据初始化完成")

    if args.rebuild_vector_store:
        rebuild_vector_store(reset_doc_status=True)

    print("🎉 初始化脚本执行完成!")