Files
AI-VideoAssistant/api/init_db.py
2026-02-09 14:07:28 +08:00

370 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""初始化数据库"""
import os
import sys
# 添加路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.db import Base, engine, DATABASE_URL
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel
def init_db():
    """Drop and recreate every table registered on ``Base.metadata``.

    When ``DATABASE_URL`` is a file-based SQLite URL whose path has a
    directory component, that directory is created first.

    Warning: existing tables are dropped before being recreated.
    """
    sqlite_prefix = "sqlite:///"
    # Only file-based SQLite URLs carry a filesystem path. Other URLs must
    # not be fed to os.makedirs, or directories named after the URL appear.
    if DATABASE_URL.startswith(sqlite_prefix):
        data_dir = os.path.dirname(DATABASE_URL[len(sqlite_prefix):])
        # dirname() is "" for a bare filename such as "sqlite:///app.db";
        # os.makedirs("") would raise FileNotFoundError.
        if data_dir:
            os.makedirs(data_dir, exist_ok=True)
    print("📦 创建数据库表...")
    Base.metadata.drop_all(bind=engine)  # remove the old tables
    Base.metadata.create_all(bind=engine)
    print("✅ 数据库表创建完成")
def init_default_data():
    """Seed the voices table with the SiliconFlow CosyVoice 2.0 presets.

    Nothing is inserted when the table already holds at least one voice.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        # Skip seeding when voices already exist.
        if db.query(Voice).count() == 0:
            # SiliconFlow CosyVoice 2.0 preset voices (8 in total).
            # Reference: https://docs.siliconflow.cn/cn/api-reference/audio/create-speech
            # All eight are seeded; the default assistant created by this
            # script uses voice "anna", so it has to be present here.
            voices = [
                # Male voices
                Voice(id="alex", name="Alex", vendor="SiliconFlow", gender="Male", language="en",
                      description="Steady male voice.", is_system=True),
                Voice(id="benjamin", name="Benjamin", vendor="SiliconFlow", gender="Male", language="en",
                      description="Deep male voice.", is_system=True),
                Voice(id="charles", name="Charles", vendor="SiliconFlow", gender="Male", language="en",
                      description="Magnetic male voice.", is_system=True),
                Voice(id="david", name="David", vendor="SiliconFlow", gender="Male", language="en",
                      description="Cheerful male voice.", is_system=True),
                # Female voices
                Voice(id="anna", name="Anna", vendor="SiliconFlow", gender="Female", language="en",
                      description="Steady female voice.", is_system=True),
                Voice(id="bella", name="Bella", vendor="SiliconFlow", gender="Female", language="en",
                      description="Passionate female voice.", is_system=True),
                Voice(id="claire", name="Claire", vendor="SiliconFlow", gender="Female", language="en",
                      description="Gentle female voice.", is_system=True),
                Voice(id="diana", name="Diana", vendor="SiliconFlow", gender="Female", language="en",
                      description="Cheerful female voice.", is_system=True),
            ]
            db.add_all(voices)
            db.commit()
            print("✅ 默认声音数据已初始化 (SiliconFlow CosyVoice 2.0)")
    finally:
        db.close()
def init_default_assistants():
    """Seed three built-in assistants when the assistants table is empty."""
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        # Leave existing assistants untouched.
        if db.query(Assistant).count() != 0:
            return

        general_assistant = Assistant(
            id="default",
            user_id=1,
            name="AI 助手",
            call_count=0,
            opener="你好我是AI助手有什么可以帮你的吗",
            prompt="你是一个友好的AI助手请用简洁清晰的语言回答用户的问题。",
            language="zh",
            # NOTE(review): presumably this must match a seeded Voice id;
            # confirm "anna" exists in the voices table.
            voice="anna",
            speed=1.0,
            hotwords=[],
            tools=["search", "calculator"],
            interruption_sensitivity=500,
            config_mode="platform",
            llm_model_id="deepseek-chat",
            asr_model_id="paraformer-v2",
        )
        customer_service = Assistant(
            id="customer_service",
            user_id=1,
            name="客服助手",
            call_count=0,
            opener="您好,欢迎致电客服中心,请问有什么可以帮您?",
            prompt="你是一个专业的客服人员,耐心解答客户问题,提供优质的服务体验。",
            language="zh",
            voice="bella",
            speed=1.0,
            hotwords=["客服", "投诉", "咨询"],
            tools=["search"],
            interruption_sensitivity=600,
            config_mode="platform",
        )
        english_tutor = Assistant(
            id="english_tutor",
            user_id=1,
            name="英语导师",
            call_count=0,
            opener="Hello! I'm your English learning companion. How can I help you today?",
            prompt="You are a friendly English tutor. Help users practice English conversation and explain grammar points clearly.",
            language="en",
            voice="alex",
            speed=1.0,
            hotwords=["grammar", "vocabulary", "practice"],
            tools=[],
            interruption_sensitivity=400,
            config_mode="platform",
        )

        db.add_all([general_assistant, customer_service, english_tutor])
        db.commit()
        print("✅ 默认助手数据已初始化")
    finally:
        db.close()
def init_default_workflows():
    """Seed two example workflows when the workflows table is empty.

    ``created_at`` and ``updated_at`` are stored as UTC strings formatted as
    ``YYYY-MM-DD HH:MM:SS``.
    """
    from datetime import datetime, timezone

    from app.db import SessionLocal

    db = SessionLocal()
    try:
        if db.query(Workflow).count() == 0:
            # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
            # "now" formats to exactly the same string.
            now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
            workflows = [
                Workflow(
                    id="simple_conversation",
                    user_id=1,
                    name="简单对话",
                    node_count=2,
                    created_at=now,
                    updated_at=now,
                    global_prompt="处理简单的对话流程,用户问什么答什么。",
                    nodes=[
                        {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                        {"id": "2", "type": "ai_reply", "position": {"x": 300, "y": 100}, "data": {"label": "AI回复"}},
                    ],
                    edges=[{"source": "1", "target": "2", "id": "e1-2"}],
                ),
                Workflow(
                    id="voice_input_flow",
                    user_id=1,
                    name="语音输入流程",
                    node_count=4,
                    created_at=now,
                    updated_at=now,
                    global_prompt="处理语音输入的完整流程。",
                    nodes=[
                        {"id": "1", "type": "start", "position": {"x": 100, "y": 100}, "data": {"label": "开始"}},
                        {"id": "2", "type": "asr", "position": {"x": 250, "y": 100}, "data": {"label": "语音识别"}},
                        {"id": "3", "type": "llm", "position": {"x": 400, "y": 100}, "data": {"label": "LLM处理"}},
                        {"id": "4", "type": "tts", "position": {"x": 550, "y": 100}, "data": {"label": "语音合成"}},
                    ],
                    edges=[
                        {"source": "1", "target": "2", "id": "e1-2"},
                        {"source": "2", "target": "3", "id": "e2-3"},
                        {"source": "3", "target": "4", "id": "e3-4"},
                    ],
                ),
            ]
            db.add_all(workflows)
            db.commit()
            print("✅ 默认工作流数据已初始化")
    finally:
        db.close()
def init_default_knowledge_bases():
    """Seed a single default knowledge base when none exists yet."""
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        # Leave existing knowledge bases untouched.
        if db.query(KnowledgeBase).count() != 0:
            return

        default_kb = KnowledgeBase(
            id="default_kb",
            user_id=1,
            name="默认知识库",
            description="系统默认知识库,用于存储常见问题解答。",
            embedding_model="text-embedding-3-small",
            chunk_size=500,
            chunk_overlap=50,
            doc_count=0,
            chunk_count=0,
            status="active",
        )
        db.add(default_kb)
        db.commit()
        print("✅ 默认知识库已初始化")
    finally:
        db.close()
def init_default_llm_models():
    """Seed the default LLM and embedding model rows when the table is empty.

    Every row is created with the placeholder API key ``"YOUR_API_KEY"``,
    which the user is expected to replace.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        if db.query(LLMModel).count() == 0:
            llm_models = [
                # NOTE(review): the two DeepSeek rows name SiliconFlow as the
                # vendor but point base_url at api.deepseek.com; confirm
                # which one is intended.
                LLMModel(
                    id="deepseek-chat",
                    user_id=1,
                    name="DeepSeek Chat",
                    vendor="SiliconFlow",
                    type="text",
                    base_url="https://api.deepseek.com",
                    api_key="YOUR_API_KEY",  # placeholder; the user must replace it
                    model_name="deepseek-chat",
                    temperature=0.7,
                    context_length=4096,
                    enabled=True,
                ),
                LLMModel(
                    id="deepseek-reasoner",
                    user_id=1,
                    name="DeepSeek Reasoner",
                    vendor="SiliconFlow",
                    type="text",
                    base_url="https://api.deepseek.com",
                    api_key="YOUR_API_KEY",
                    model_name="deepseek-reasoner",
                    temperature=0.7,
                    context_length=4096,
                    enabled=True,
                ),
                LLMModel(
                    id="gpt-4o",
                    user_id=1,
                    name="GPT-4o",
                    vendor="OpenAI",
                    type="text",
                    base_url="https://api.openai.com/v1",
                    api_key="YOUR_API_KEY",
                    model_name="gpt-4o",
                    temperature=0.7,
                    context_length=16384,
                    enabled=True,
                ),
                LLMModel(
                    id="glm-4",
                    user_id=1,
                    name="GLM-4",
                    vendor="ZhipuAI",
                    type="text",
                    base_url="https://open.bigmodel.cn/api/paas/v4",
                    api_key="YOUR_API_KEY",
                    model_name="glm-4",
                    temperature=0.7,
                    context_length=8192,
                    enabled=True,
                ),
                # Embedding model: no temperature or context_length is set.
                LLMModel(
                    id="text-embedding-3-small",
                    user_id=1,
                    name="Embedding 3 Small",
                    vendor="OpenAI",
                    type="embedding",
                    base_url="https://api.openai.com/v1",
                    api_key="YOUR_API_KEY",
                    model_name="text-embedding-3-small",
                    enabled=True,
                ),
            ]
            db.add_all(llm_models)
            db.commit()
            print("✅ 默认LLM模型已初始化")
    finally:
        db.close()
def init_default_asr_models():
    """Seed the default ASR model rows when the table is empty.

    Every row is created with the placeholder API key ``"YOUR_API_KEY"``.
    """
    from app.db import SessionLocal

    db = SessionLocal()
    try:
        if db.query(ASRModel).count() == 0:
            asr_models = [
                ASRModel(
                    id="paraformer-v2",
                    user_id=1,
                    name="Paraformer V2",
                    vendor="SiliconFlow",
                    language="zh",
                    base_url="https://api.siliconflow.cn/v1",
                    api_key="YOUR_API_KEY",
                    model_name="paraformer-v2",
                    hotwords=["人工智能", "机器学习"],
                    enable_punctuation=True,
                    enable_normalization=True,
                    enabled=True,
                ),
                ASRModel(
                    id="paraformer-en",
                    user_id=1,
                    name="Paraformer English",
                    vendor="SiliconFlow",
                    language="en",
                    base_url="https://api.siliconflow.cn/v1",
                    api_key="YOUR_API_KEY",
                    model_name="paraformer-en",
                    hotwords=[],
                    enable_punctuation=True,
                    enable_normalization=True,
                    enabled=True,
                ),
                ASRModel(
                    id="whisper-1",
                    user_id=1,
                    name="Whisper",
                    vendor="OpenAI",
                    language="Multi-lingual",
                    base_url="https://api.openai.com/v1",
                    api_key="YOUR_API_KEY",
                    model_name="whisper-1",
                    hotwords=[],
                    enable_punctuation=True,
                    enable_normalization=True,
                    enabled=True,
                ),
                ASRModel(
                    id="sensevoice",
                    user_id=1,
                    name="SenseVoice",
                    vendor="SiliconFlow",
                    language="Multi-lingual",
                    base_url="https://api.siliconflow.cn/v1",
                    api_key="YOUR_API_KEY",
                    model_name="sensevoice",
                    hotwords=[],
                    enable_punctuation=True,
                    enable_normalization=True,
                    enabled=True,
                ),
            ]
            db.add_all(asr_models)
            db.commit()
            print("✅ 默认ASR模型已初始化")
    finally:
        db.close()
if __name__ == "__main__":
    # Create the "data" directory that sits one level above this script's
    # own directory before any database work starts.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    os.makedirs(os.path.join(project_root, "data"), exist_ok=True)

    # Rebuild the schema first, then run each seeding step in order.
    setup_steps = (
        init_db,
        init_default_data,
        init_default_assistants,
        init_default_workflows,
        init_default_knowledge_bases,
        init_default_llm_models,
        init_default_asr_models,
    )
    for step in setup_steps:
        step()
    print("🎉 数据库初始化完成!")