Files
AI-VideoAssistant/api/init_db.py
2026-02-26 03:54:52 +08:00

476 lines
16 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""初始化数据库"""
import argparse
import os
import sys
from contextlib import contextmanager
# 添加路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.db import Base, engine, DATABASE_URL
from app.id_generator import short_id
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel, KnowledgeDocument
# Model names, default voice key and realtime endpoint used by the seeded voices.
VOICE_MODEL = "FunAudioLLM/CosyVoice2-0.5B"
DASHSCOPE_VOICE_MODEL = "qwen3-tts-flash-realtime"
DASHSCOPE_DEFAULT_VOICE_KEY = "Cherry"
DASHSCOPE_REALTIME_URL = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime"

# Primary keys of the seed rows. Each key gets its own short_id(...) call,
# evaluated once at import time, so the seed functions below can refer to the
# same ids (e.g. an assistant pointing at a seeded voice).
SEED_VOICE_IDS = {
    seed_key: short_id("tts")
    for seed_key in ("alex", "david", "bella", "claire", "dashscope_cherry")
}
SEED_LLM_IDS = {
    seed_key: short_id("llm")
    for seed_key in ("deepseek_chat", "glm_4", "embedding_3_small")
}
SEED_ASR_IDS = {
    seed_key: short_id("asr")
    for seed_key in ("sensevoice_small", "telespeech_asr")
}
SEED_ASSISTANT_IDS = {
    seed_key: short_id("ast")
    for seed_key in ("default", "customer_service", "english_tutor")
}
def ensure_db_dir(database_url=None):
    """Ensure the directory that holds the SQLite database file exists.

    Args:
        database_url: Database URL to inspect. When omitted (the default),
            the module-level ``DATABASE_URL`` is used.

    Returns:
        None. Nothing happens for URLs that do not start with ``sqlite:///``
        or whose file path has no directory component (for example
        ``sqlite:///app.db`` or ``sqlite:///:memory:``).
    """
    url = DATABASE_URL if database_url is None else database_url
    scheme_prefix = "sqlite:///"
    if not url.startswith(scheme_prefix):
        return
    # Remove only the leading scheme (str.replace would also rewrite any later
    # occurrence inside the path) and ignore an optional "?query" suffix so
    # that it is never mistaken for part of the file path.
    db_path = url[len(scheme_prefix):].partition("?")[0]
    data_dir = os.path.dirname(db_path)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
@contextmanager
def db_session():
    """Yield a ``SessionLocal()`` session and close it when the block exits.

    The session is closed whether the ``with`` body finishes normally or
    raises. ``SessionLocal`` is imported lazily, on first use of this helper.
    """
    from app.db import SessionLocal

    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def seed_if_empty(db, model_cls, records, success_msg: str) -> None:
    """Insert default records only when the target table has no rows.

    Args:
        db: Session-like object providing ``query``, ``add``, ``add_all`` and
            ``commit``.
        model_cls: Mapped class whose table is checked for existing rows.
        records: Either a single record or a ``list``/``tuple`` of records.
        success_msg: Message printed after the records have been committed.

    Returns:
        None. When the table already contains rows nothing is added,
        committed or printed.
    """
    if db.query(model_cls).count() != 0:
        return
    # A tuple is a batch as well; previously it was handed to db.add() as if
    # it were a single record.
    if isinstance(records, (list, tuple)):
        db.add_all(records)
    else:
        db.add(records)
    db.commit()
    print(success_msg)
def init_db():
    """Recreate the schema: drop all tables in ``Base.metadata``, then create them.

    This is destructive: existing tables are dropped before being created
    again. The SQLite data directory is ensured first.
    """
    ensure_db_dir()
    print("📦 创建数据库表...")
    metadata = Base.metadata
    metadata.drop_all(bind=engine)  # remove the old tables first
    metadata.create_all(bind=engine)
    print("✅ 数据库表创建完成")
def rebuild_vector_store(reset_doc_status: bool = True):
    """Recreate the knowledge-base vector collections from the DB's KB rows.

    Every existing collection whose name starts with ``kb_`` is deleted (a
    failed deletion is reported and skipped), then ``create_collection`` is
    called once per ``KnowledgeBase`` row. As the final message states, only
    the collections themselves are recreated; documents must be re-indexed.

    Args:
        reset_doc_status: When true, zero each knowledge base's ``chunk_count``
            and ``doc_count`` and put its documents back to ``"pending"`` with
            cleared chunk count, error message and processed timestamp.

    Raises:
        RuntimeError: If the list of existing collections cannot be read.
    """
    from app.vector_store import vector_store

    with db_session() as db:
        print("🧹 重建向量库集合...")
        kb_list = db.query(KnowledgeBase).all()

        # Step 1: drop the existing KB collections.
        try:
            existing = vector_store.client.list_collections()
        except Exception as exc:
            raise RuntimeError(f"无法读取向量集合列表: {exc}") from exc

        for entry in existing:
            # Entries may expose a ``name`` attribute; otherwise fall back to
            # their string form.
            coll_name = getattr(entry, "name", None) or str(entry)
            if not coll_name.startswith("kb_"):
                continue
            try:
                vector_store.client.delete_collection(name=coll_name)
                print(f" - removed {coll_name}")
            except Exception as exc:
                # Best effort: report the failure and keep going.
                print(f" - skip remove {coll_name}: {exc}")

        # Step 2: recreate one collection per knowledge base stored in the DB.
        for kb in kb_list:
            vector_store.create_collection(kb.id, kb.embedding_model)
            print(f" + created kb_{kb.id} ({kb.embedding_model})")
            if not reset_doc_status:
                continue
            kb.chunk_count = 0
            kb_docs = (
                db.query(KnowledgeDocument)
                .filter(KnowledgeDocument.kb_id == kb.id)
                .all()
            )
            kb.doc_count = 0
            for doc in kb_docs:
                doc.chunk_count = 0
                doc.status = "pending"
                doc.error_message = None
                doc.processed_at = None

        db.commit()
        print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)")
def init_default_data():
    """Seed the default TTS voices when the voice table is empty.

    Four English CosyVoice 2.0 preset voices (vendor "OpenAI Compatible",
    SiliconFlow API) and one Chinese DashScope realtime voice are prepared.
    Reference: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
    """
    # (seed key, display name, gender, description) -- male voices first,
    # then female voices.
    cosyvoice_presets = (
        ("alex", "Alex", "Male", "Steady male voice."),
        ("david", "David", "Male", "Cheerful male voice."),
        ("bella", "Bella", "Female", "Passionate female voice."),
        ("claire", "Claire", "Female", "Gentle female voice."),
    )
    with db_session() as db:
        voices = [
            Voice(
                id=SEED_VOICE_IDS[seed_key],
                name=display_name,
                vendor="OpenAI Compatible",
                gender=gender,
                language="en",
                description=description,
                model=VOICE_MODEL,
                voice_key=f"{VOICE_MODEL}:{seed_key}",
                is_system=True,
            )
            for seed_key, display_name, gender, description in cosyvoice_presets
        ]
        # The DashScope voice additionally carries its realtime endpoint.
        voices.append(
            Voice(
                id=SEED_VOICE_IDS["dashscope_cherry"],
                name="DashScope Cherry",
                vendor="DashScope",
                gender="Female",
                language="zh",
                description="DashScope realtime sample voice.",
                model=DASHSCOPE_VOICE_MODEL,
                voice_key=DASHSCOPE_DEFAULT_VOICE_KEY,
                base_url=DASHSCOPE_REALTIME_URL,
                is_system=True,
            )
        )
        seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (OpenAI Compatible + DashScope)")
def init_default_tools(recreate: bool = False):
    """Seed the default tools, or rebuild the tool data on request.

    Args:
        recreate: When true, call ``recreate_tool_resources``; otherwise call
            ``_seed_default_tools_if_empty``.
    """
    from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources

    with db_session() as db:
        if not recreate:
            _seed_default_tools_if_empty(db)
            print("✅ 默认工具已初始化")
            return
        recreate_tool_resources(db)
        print("✅ 工具库已重建")
def init_default_assistants():
    """Seed the three default assistants when the assistant table is empty."""

    def build_assistant(seed_key, **fields):
        # Fields shared by every seeded assistant; ``fields`` supplies the
        # per-assistant values. A fresh ``tools`` list is built on each call.
        return Assistant(
            id=SEED_ASSISTANT_IDS[seed_key],
            user_id=1,
            call_count=0,
            voice_output_enabled=True,
            speed=1.0,
            tools=["current_time"],
            config_mode="platform",
            **fields,
        )

    with db_session() as db:
        assistants = [
            # Only this assistant is bound to a seeded LLM and ASR model.
            build_assistant(
                "default",
                name="AI 助手",
                opener="你好我是AI助手有什么可以帮你的吗",
                prompt="你是一个友好的AI助手请用简洁清晰的语言回答用户的问题。",
                language="zh",
                voice=SEED_VOICE_IDS["bella"],
                hotwords=[],
                interruption_sensitivity=500,
                llm_model_id=SEED_LLM_IDS["deepseek_chat"],
                asr_model_id=SEED_ASR_IDS["sensevoice_small"],
            ),
            build_assistant(
                "customer_service",
                name="客服助手",
                opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
                prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
                language="zh",
                voice=SEED_VOICE_IDS["claire"],
                hotwords=["客服", "投诉", "咨询"],
                interruption_sensitivity=600,
            ),
            build_assistant(
                "english_tutor",
                name="英语导师",
                opener="Hello! I'm your English learning companion. How can I help you today?",
                prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.",
                language="en",
                voice=SEED_VOICE_IDS["alex"],
                hotwords=["grammar", "vocabulary", "practice"],
                interruption_sensitivity=400,
            ),
        ]
        seed_if_empty(db, Assistant, assistants, "✅ 默认助手数据已初始化")
def init_default_workflows():
    """Seed the two default workflows when the workflow table is empty.

    Both workflows share the same ``created_at``/``updated_at`` value: the
    current UTC time formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    from datetime import datetime, timezone

    with db_session() as db:
        # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
        # "now" produces the same formatted string.
        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        workflows = [
            # Two-node flow: start -> AI reply.
            Workflow(
                id="simple_conversation",
                user_id=1,
                name="简单对话",
                node_count=2,
                created_at=now,
                updated_at=now,
                global_prompt="处理简单的对话流程,用户问什么答什么。",
                nodes=[
                    {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                    {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
                ],
                edges=[{"source": "1", "target": "2", "id": "e1-2"}],
            ),
            # Four-node linear flow: start -> ASR -> LLM -> TTS.
            Workflow(
                id="voice_input_flow",
                user_id=1,
                name="语音输入流程",
                node_count=4,
                created_at=now,
                updated_at=now,
                global_prompt="处理语音输入的完整流程。",
                nodes=[
                    {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                    {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
                    {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
                    {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
                ],
                edges=[
                    {"source": "1", "target": "2", "id": "e1-2"},
                    {"source": "2", "target": "3", "id": "e2-3"},
                    {"source": "3", "target": "4", "id": "e3-4"},
                ],
            ),
        ]
        seed_if_empty(db, Workflow, workflows, "✅ 默认工作流数据已初始化")
def init_default_knowledge_bases():
    """Seed the single default knowledge base when the table is empty."""
    default_kb_fields = {
        "id": "default_kb",
        "user_id": 1,
        "name": "默认知识库",
        "description": "系统默认知识库,用于存储常见问题解答。",
        "embedding_model": "text-embedding-3-small",
        "chunk_size": 500,
        "chunk_overlap": 50,
        "doc_count": 0,
        "chunk_count": 0,
        "status": "active",
    }
    with db_session() as db:
        default_kb = KnowledgeBase(**default_kb_fields)
        seed_if_empty(db, KnowledgeBase, default_kb, "✅ 默认知识库已初始化")
def init_default_llm_models():
    """Seed the default LLM entries (two text models and one embedding model)
    when the LLM model table is empty."""

    def build_model(seed_key, **fields):
        # Fields shared by every seeded model; ``fields`` supplies the rest.
        return LLMModel(
            id=SEED_LLM_IDS[seed_key],
            user_id=1,
            api_key="YOUR_API_KEY",  # placeholder: the user must replace it
            enabled=True,
            **fields,
        )

    with db_session() as db:
        llm_models = [
            build_model(
                "deepseek_chat",
                name="DeepSeek Chat",
                vendor="OpenAI Compatible",
                type="text",
                base_url="https://api.deepseek.com",
                model_name="deepseek-chat",
                temperature=0.7,
                context_length=4096,
            ),
            build_model(
                "glm_4",
                name="GLM-4",
                vendor="ZhipuAI",
                type="text",
                base_url="https://open.bigmodel.cn/api/paas/v4",
                model_name="glm-4",
                temperature=0.7,
                context_length=8192,
            ),
            # The embedding entry sets no temperature or context length.
            build_model(
                "embedding_3_small",
                name="Embedding 3 Small",
                vendor="OpenAI Compatible",
                type="embedding",
                base_url="https://api.openai.com/v1",
                model_name="text-embedding-3-small",
            ),
        ]
        seed_if_empty(db, LLMModel, llm_models, "✅ 默认LLM模型已初始化")
def init_default_asr_models():
    """Seed the two default ASR models when the ASR model table is empty.

    Both entries share every setting except their id and model name; the
    model name is also used as the display name.
    """
    # (seed key, model name)
    asr_presets = (
        ("sensevoice_small", "FunAudioLLM/SenseVoiceSmall"),
        ("telespeech_asr", "TeleAI/TeleSpeechASR"),
    )
    with db_session() as db:
        asr_models = [
            ASRModel(
                id=SEED_ASR_IDS[seed_key],
                user_id=1,
                name=model_name,
                vendor="OpenAI Compatible",
                language="Multi-lingual",
                base_url="https://api.siliconflow.cn/v1",
                api_key="YOUR_API_KEY",
                model_name=model_name,
                hotwords=[],
                enable_punctuation=True,
                enable_normalization=True,
                enabled=True,
            )
            for seed_key, model_name in asr_presets
        ]
        seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化")
if __name__ == "__main__":
    # Command-line entry point: optionally rebuild the schema, seed the
    # default data, and rebuild the vector-store collections.
    parser = argparse.ArgumentParser(description="初始化/重建 AI VideoAssistant 数据与向量库")
    # Every option is a plain boolean switch.
    for flag, help_text in (
        ("--rebuild-db", "重建数据库drop + create tables"),
        ("--rebuild-vector-store", "重建向量库 KB 集合(清空后按 DB 的 knowledge_bases 重建 collection"),
        ("--skip-seed", "跳过默认数据初始化"),
        ("--recreate-tool-db", "重建工具库数据(清空 tool_resources 后按内置默认工具重建)"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)
    args = parser.parse_args()

    # Legacy behaviour: with no option at all, rebuild the DB and seed the
    # defaults. The automatic rebuild is triggered only in that case.
    if not (
        args.rebuild_db
        or args.rebuild_vector_store
        or args.skip_seed
        or args.recreate_tool_db
    ):
        args.rebuild_db = True

    ensure_db_dir()

    if args.rebuild_db:
        init_db()
    else:
        print(" 跳过数据库结构变更(未指定 --rebuild-db")
        if not args.skip_seed or args.recreate_tool_db:
            print(" 当前将执行非破坏性流程(仅工具/默认数据初始化)")

    if args.recreate_tool_db:
        init_default_tools(recreate=True)

    if not args.skip_seed:
        init_default_data()
        # The tools were already rebuilt above when --recreate-tool-db is set.
        if not args.recreate_tool_db:
            init_default_tools(recreate=False)
        for seed_step in (
            init_default_assistants,
            init_default_workflows,
            init_default_knowledge_bases,
            init_default_llm_models,
            init_default_asr_models,
        ):
            seed_step()
        print("✅ 默认数据初始化完成")

    if args.rebuild_vector_store:
        rebuild_vector_store(reset_doc_status=True)

    print("🎉 初始化脚本执行完成!")