452 lines
15 KiB
Python
452 lines
15 KiB
Python
#!/usr/bin/env python3
|
||
"""初始化数据库"""
|
||
import argparse
|
||
import os
|
||
import sys
|
||
from contextlib import contextmanager
|
||
|
||
# Put this script's own directory first on sys.path so the `app` package imports below resolve.
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from app.db import Base, engine, DATABASE_URL
|
||
from app.id_generator import short_id
|
||
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel, KnowledgeDocument
|
||
|
||
# Model identifier shared by every seeded voice preset.
VOICE_MODEL = "FunAudioLLM/CosyVoice2-0.5B"

# Seed identifiers. ``short_id`` is called once per key, in the listed order,
# when this module is imported; the seeding functions below look the values up
# by key so that related records can reference each other.
SEED_VOICE_IDS = {key: short_id("tts") for key in ("alex", "david", "bella", "claire")}

SEED_LLM_IDS = {key: short_id("llm") for key in ("deepseek_chat", "glm_4", "embedding_3_small")}

SEED_ASR_IDS = {key: short_id("asr") for key in ("sensevoice_small", "telespeech_asr")}

SEED_ASSISTANT_IDS = {key: short_id("ast") for key in ("default", "customer_service", "english_tutor")}
|
||
|
||
|
||
def ensure_db_dir():
    """Create the directory that holds the SQLite database file, if needed.

    Nothing happens when ``DATABASE_URL`` does not start with ``sqlite:///`` or
    when the path taken from the URL has no directory component.
    """
    sqlite_prefix = "sqlite:///"
    if DATABASE_URL.startswith(sqlite_prefix):
        parent_dir = os.path.dirname(DATABASE_URL.replace(sqlite_prefix, ""))
        if parent_dir:
            # exist_ok=True makes repeated calls harmless.
            os.makedirs(parent_dir, exist_ok=True)
|
||
|
||
|
||
@contextmanager
def db_session():
    """Yield a fresh ``SessionLocal`` session and always close it afterwards.

    The session is closed whether the ``with`` body completes or raises.
    This helper never commits; callers commit explicitly.
    """
    # Imported inside the function rather than at module top.
    from app.db import SessionLocal

    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
||
|
||
|
||
def seed_if_empty(db, model_cls, records, success_msg: str):
    """Insert default rows for ``model_cls`` only when its table has no rows.

    Args:
        db: Session-like object providing ``query``, ``add``, ``add_all`` and
            ``commit``.
        model_cls: Model class whose row count decides whether to seed.
        records: Either a list of instances (inserted with ``add_all``) or a
            single instance (inserted with ``add``).
        success_msg: Message printed after the commit.
    """
    row_total = db.query(model_cls).count()
    if row_total != 0:
        # Table already has data: leave it untouched and stay silent.
        return

    insert = db.add_all if isinstance(records, list) else db.add
    insert(records)
    db.commit()
    print(success_msg)
|
||
|
||
|
||
def init_db():
    """Drop every table registered on ``Base.metadata`` and create them again.

    Destructive: existing tables are dropped before being re-created, so this
    is a rebuild rather than a plain "create if missing".
    """
    ensure_db_dir()

    print("📦 创建数据库表...")
    schema = Base.metadata
    schema.drop_all(bind=engine)  # remove the old tables first
    schema.create_all(bind=engine)
    print("✅ 数据库表创建完成")
|
||
|
||
|
||
def rebuild_vector_store(reset_doc_status: bool = True):
    """Recreate the knowledge-base vector collections from the KBs in the DB.

    Every existing collection whose name starts with ``kb_`` is deleted (a
    failed deletion is reported and skipped). Then one collection is created
    per ``KnowledgeBase`` row. Documents are not re-indexed here.

    Args:
        reset_doc_status: When true, set each KB's ``chunk_count`` and
            ``doc_count`` to 0 and, for each of its documents, set
            ``chunk_count`` to 0, ``status`` to ``"pending"`` and clear
            ``error_message`` and ``processed_at``.

    Raises:
        RuntimeError: If the list of existing collections cannot be read.
    """
    from app.vector_store import vector_store

    with db_session() as db:
        print("🧹 重建向量库集合...")
        knowledge_bases = db.query(KnowledgeBase).all()

        # Remove the existing KB collections.
        try:
            existing = vector_store.client.list_collections()
        except Exception as exc:
            raise RuntimeError(f"无法读取向量集合列表: {exc}") from exc

        for entry in existing:
            # Use the entry's ``name`` attribute when present and truthy,
            # otherwise fall back to its string form.
            name = getattr(entry, "name", None) or str(entry)
            if not name.startswith("kb_"):
                continue
            try:
                vector_store.client.delete_collection(name=name)
                print(f" - removed {name}")
            except Exception as exc:
                # Best effort: report the failure and move on.
                print(f" - skip remove {name}: {exc}")

        # Recreate one collection per KB stored in the DB.
        for kb in knowledge_bases:
            vector_store.create_collection(kb.id, kb.embedding_model)
            print(f" + created kb_{kb.id} ({kb.embedding_model})")

            if not reset_doc_status:
                continue

            kb.chunk_count = 0
            kb_docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all()
            kb.doc_count = 0
            for document in kb_docs:
                document.chunk_count = 0
                document.status = "pending"
                document.error_message = None
                document.processed_at = None

        db.commit()
        print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)")
|
||
|
||
|
||
def init_default_data():
    """Seed the four built-in voice presets when the voice table is empty."""
    # OpenAI Compatible (SiliconFlow API) CosyVoice 2.0 preset voices.
    # Reference: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
    # Each row: (seed key, display name, gender, description).
    presets = (
        # Male voices
        ("alex", "Alex", "Male", "Steady male voice."),
        ("david", "David", "Male", "Cheerful male voice."),
        # Female voices
        ("bella", "Bella", "Female", "Passionate female voice."),
        ("claire", "Claire", "Female", "Gentle female voice."),
    )

    with db_session() as db:
        voices = [
            Voice(
                id=SEED_VOICE_IDS[key],
                name=display_name,
                vendor="OpenAI Compatible",
                gender=gender,
                language="en",
                description=description,
                model=VOICE_MODEL,
                # The voice key is the model identifier plus the seed key.
                voice_key=f"{VOICE_MODEL}:{key}",
                is_system=True,
            )
            for key, display_name, gender, description in presets
        ]
        seed_if_empty(db, Voice, voices, "✅ 默认声音数据已初始化 (OpenAI Compatible CosyVoice 2.0)")
|
||
|
||
|
||
def init_default_tools(recreate: bool = False):
    """Seed the default tools, or rebuild the tool data on request.

    Args:
        recreate: When true, call ``recreate_tool_resources``; otherwise call
            ``_seed_default_tools_if_empty``. Both helpers live in
            ``app.routers.tools`` and receive the open session.
    """
    from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources

    with db_session() as db:
        if not recreate:
            _seed_default_tools_if_empty(db)
            print("✅ 默认工具已初始化")
            return

        recreate_tool_resources(db)
        print("✅ 工具库已重建")
|
||
|
||
|
||
def init_default_assistants():
    """Seed three default assistants when the assistant table is empty."""

    def build(key, name, opener, prompt, language, voice_key, hotwords, sensitivity, **extra):
        # Fill in the fields that are the same for every seeded assistant;
        # ``extra`` carries optional fields such as model references.
        return Assistant(
            id=SEED_ASSISTANT_IDS[key],
            user_id=1,
            name=name,
            call_count=0,
            opener=opener,
            prompt=prompt,
            language=language,
            voice_output_enabled=True,
            voice=SEED_VOICE_IDS[voice_key],
            speed=1.0,
            hotwords=hotwords,
            tools=["current_time"],
            interruption_sensitivity=sensitivity,
            config_mode="platform",
            **extra,
        )

    with db_session() as db:
        assistants = [
            # Only the general-purpose assistant is linked to seeded LLM/ASR models.
            build(
                key="default",
                name="AI 助手",
                opener="你好!我是AI助手,有什么可以帮你的吗?",
                prompt="你是一个友好的AI助手,请用简洁清晰的语言回答用户的问题。",
                language="zh",
                voice_key="bella",
                hotwords=[],
                sensitivity=500,
                llm_model_id=SEED_LLM_IDS["deepseek_chat"],
                asr_model_id=SEED_ASR_IDS["sensevoice_small"],
            ),
            build(
                key="customer_service",
                name="客服助手",
                opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
                prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
                language="zh",
                voice_key="claire",
                hotwords=["客服", "投诉", "咨询"],
                sensitivity=600,
            ),
            build(
                key="english_tutor",
                name="英语导师",
                opener="Hello! I'm your English learning companion. How can I help you today?",
                prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.",
                language="en",
                voice_key="alex",
                hotwords=["grammar", "vocabulary", "practice"],
                sensitivity=400,
            ),
        ]
        seed_if_empty(db, Assistant, assistants, "✅ 默认助手数据已初始化")
|
||
|
||
|
||
def init_default_workflows():
    """Seed two default workflows when the workflow table is empty.

    The workflows are a two-node "simple conversation" and a four-node
    "voice input" chain (start -> ASR -> LLM -> TTS). ``created_at`` and
    ``updated_at`` are both set to the current UTC time formatted as
    ``YYYY-MM-DD HH:MM:SS``.
    """
    # ``datetime.utcnow()`` is deprecated since Python 3.12. An aware UTC "now"
    # formats to the same string because the format has no timezone field.
    from datetime import datetime, timezone

    with db_session() as db:
        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

        simple_nodes = [
            {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
            {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
        ]
        voice_nodes = [
            {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
            {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
            {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
            {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
        ]

        workflows = [
            Workflow(
                id="simple_conversation",
                user_id=1,
                name="简单对话",
                # Derived from the node list so the two cannot drift apart.
                node_count=len(simple_nodes),
                created_at=now,
                updated_at=now,
                global_prompt="处理简单的对话流程,用户问什么答什么。",
                nodes=simple_nodes,
                edges=[{"source": "1", "target": "2", "id": "e1-2"}],
            ),
            Workflow(
                id="voice_input_flow",
                user_id=1,
                name="语音输入流程",
                node_count=len(voice_nodes),
                created_at=now,
                updated_at=now,
                global_prompt="处理语音输入的完整流程。",
                nodes=voice_nodes,
                edges=[
                    {"source": "1", "target": "2", "id": "e1-2"},
                    {"source": "2", "target": "3", "id": "e2-3"},
                    {"source": "3", "target": "4", "id": "e3-4"},
                ],
            ),
        ]
        seed_if_empty(db, Workflow, workflows, "✅ 默认工作流数据已初始化")
|
||
|
||
|
||
def init_default_knowledge_bases():
    """Seed one default knowledge base when the knowledge-base table is empty."""
    with db_session() as db:
        fields = {
            "id": "default_kb",
            "user_id": 1,
            "name": "默认知识库",
            "description": "系统默认知识库,用于存储常见问题解答。",
            "embedding_model": "text-embedding-3-small",
            "chunk_size": 500,
            "chunk_overlap": 50,
            # A freshly seeded KB starts with no documents and no chunks.
            "doc_count": 0,
            "chunk_count": 0,
            "status": "active",
        }
        default_kb = KnowledgeBase(**fields)
        # A single instance (not a list) is passed, so it is inserted with ``add``.
        seed_if_empty(db, KnowledgeBase, default_kb, "✅ 默认知识库已初始化")
|
||
|
||
|
||
def init_default_llm_models():
    """Seed three default LLM model entries when the LLM model table is empty.

    Two entries have type ``"text"`` and one has type ``"embedding"``. All of
    them carry a placeholder API key.
    """
    with db_session() as db:
        placeholder_key = "YOUR_API_KEY"  # users must replace this with a real key
        specs = [
            dict(
                id=SEED_LLM_IDS["deepseek_chat"],
                user_id=1,
                name="DeepSeek Chat",
                vendor="OpenAI Compatible",
                type="text",
                base_url="https://api.deepseek.com",
                api_key=placeholder_key,
                model_name="deepseek-chat",
                temperature=0.7,
                context_length=4096,
                enabled=True,
            ),
            dict(
                id=SEED_LLM_IDS["glm_4"],
                user_id=1,
                name="GLM-4",
                vendor="ZhipuAI",
                type="text",
                base_url="https://open.bigmodel.cn/api/paas/v4",
                api_key=placeholder_key,
                model_name="glm-4",
                temperature=0.7,
                context_length=8192,
                enabled=True,
            ),
            # The embedding entry sets no temperature or context length.
            dict(
                id=SEED_LLM_IDS["embedding_3_small"],
                user_id=1,
                name="Embedding 3 Small",
                vendor="OpenAI Compatible",
                type="embedding",
                base_url="https://api.openai.com/v1",
                api_key=placeholder_key,
                model_name="text-embedding-3-small",
                enabled=True,
            ),
        ]
        llm_models = [LLMModel(**spec) for spec in specs]
        seed_if_empty(db, LLMModel, llm_models, "✅ 默认LLM模型已初始化")
|
||
|
||
|
||
def init_default_asr_models():
    """Seed two default ASR model entries when the ASR model table is empty.

    Both entries differ only in their seed id and model identifier; the model
    identifier is used for both ``name`` and ``model_name``.
    """
    # Each row: (seed key, model identifier).
    seeded = (
        ("sensevoice_small", "FunAudioLLM/SenseVoiceSmall"),
        ("telespeech_asr", "TeleAI/TeleSpeechASR"),
    )

    with db_session() as db:
        asr_models = [
            ASRModel(
                id=SEED_ASR_IDS[key],
                user_id=1,
                name=model_identifier,
                vendor="OpenAI Compatible",
                language="Multi-lingual",
                base_url="https://api.siliconflow.cn/v1",
                api_key="YOUR_API_KEY",  # placeholder value
                model_name=model_identifier,
                hotwords=[],
                enable_punctuation=True,
                enable_normalization=True,
                enabled=True,
            )
            for key, model_identifier in seeded
        ]
        seed_if_empty(db, ASRModel, asr_models, "✅ 默认ASR模型已初始化")
|
||
|
||
|
||
def parse_cli_args(argv=None):
    """Parse the command line and resolve the default run mode.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace with ``rebuild_db``, ``rebuild_vector_store``,
        ``skip_seed`` and ``recreate_tool_db`` booleans. When no flag at all
        was given, ``rebuild_db`` is forced to true.
    """
    parser = argparse.ArgumentParser(description="初始化/重建 AI VideoAssistant 数据与向量库")
    parser.add_argument(
        "--rebuild-db",
        action="store_true",
        help="重建数据库(drop + create tables)",
    )
    parser.add_argument(
        "--rebuild-vector-store",
        action="store_true",
        help="重建向量库 KB 集合(清空后按 DB 的 knowledge_bases 重建 collection)",
    )
    parser.add_argument(
        "--skip-seed",
        action="store_true",
        help="跳过默认数据初始化",
    )
    parser.add_argument(
        "--recreate-tool-db",
        action="store_true",
        help="重建工具库数据(清空 tool_resources 后按内置默认工具重建)",
    )
    args = parser.parse_args(argv)

    # With no flags at all, keep the legacy behaviour: rebuild the DB and seed.
    # --recreate-tool-db counts as an explicit flag. Previously it was ignored
    # here, so passing only that flag dropped and re-created every table.
    explicit_flags = (
        args.rebuild_db,
        args.rebuild_vector_store,
        args.skip_seed,
        args.recreate_tool_db,
    )
    if not any(explicit_flags):
        args.rebuild_db = True
    return args


def main(argv=None):
    """Run the initialisation steps selected on the command line.

    Order: ensure the DB directory exists, optionally rebuild the schema,
    optionally recreate tool data, seed defaults unless skipped, and finally
    optionally rebuild the vector store.

    Args:
        argv: Argument list forwarded to ``parse_cli_args``.
    """
    args = parse_cli_args(argv)

    ensure_db_dir()

    if args.rebuild_db:
        init_db()
    else:
        print("ℹ️ 跳过数据库结构变更(未指定 --rebuild-db)")

    if args.recreate_tool_db:
        init_default_tools(recreate=True)

    if not args.skip_seed:
        init_default_data()
        # Tools were already rebuilt above when --recreate-tool-db was given.
        if not args.recreate_tool_db:
            init_default_tools(recreate=False)
        init_default_assistants()
        init_default_workflows()
        init_default_knowledge_bases()
        init_default_llm_models()
        init_default_asr_models()
        print("✅ 默认数据初始化完成")

    # Runs after seeding so that freshly seeded knowledge bases get a collection.
    if args.rebuild_vector_store:
        rebuild_vector_store(reset_doc_status=True)

    print("🎉 初始化脚本执行完成!")


if __name__ == "__main__":
    main()
|