Add init db selectively
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""初始化数据库"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
@@ -7,7 +8,7 @@ import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.db import Base, engine, DATABASE_URL
|
||||
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel
|
||||
from app.models import Voice, Assistant, KnowledgeBase, Workflow, LLMModel, ASRModel, KnowledgeDocument
|
||||
|
||||
|
||||
def init_db():
|
||||
@@ -22,6 +23,52 @@ def init_db():
|
||||
print("✅ 数据库表创建完成")
|
||||
|
||||
|
||||
def rebuild_vector_store(reset_doc_status: bool = True):
|
||||
"""重建知识库向量集合(按 DB 中的 KB 列表重建 collection 壳)。"""
|
||||
from app.db import SessionLocal
|
||||
from app.vector_store import vector_store
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
print("🧹 重建向量库集合...")
|
||||
kb_list = db.query(KnowledgeBase).all()
|
||||
|
||||
# 删除现有 KB 集合
|
||||
try:
|
||||
collections = vector_store.client.list_collections()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"无法读取向量集合列表: {exc}") from exc
|
||||
|
||||
for col in collections:
|
||||
name = getattr(col, "name", None) or str(col)
|
||||
if name.startswith("kb_"):
|
||||
try:
|
||||
vector_store.client.delete_collection(name=name)
|
||||
print(f" - removed {name}")
|
||||
except Exception as exc:
|
||||
print(f" - skip remove {name}: {exc}")
|
||||
|
||||
# 按 DB 重建 KB 集合
|
||||
for kb in kb_list:
|
||||
vector_store.create_collection(kb.id, kb.embedding_model)
|
||||
print(f" + created kb_{kb.id} ({kb.embedding_model})")
|
||||
|
||||
if reset_doc_status:
|
||||
kb.chunk_count = 0
|
||||
docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all()
|
||||
kb.doc_count = 0
|
||||
for doc in docs:
|
||||
doc.chunk_count = 0
|
||||
doc.status = "pending"
|
||||
doc.error_message = None
|
||||
doc.processed_at = None
|
||||
|
||||
db.commit()
|
||||
print("✅ 向量库重建完成(仅重建集合壳,文档需重新索引)")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def init_default_data():
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import SessionLocal
|
||||
@@ -355,15 +402,45 @@ def init_default_asr_models():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="初始化/重建 AI VideoAssistant 数据与向量库")
|
||||
parser.add_argument(
|
||||
"--rebuild-db",
|
||||
action="store_true",
|
||||
help="重建数据库(drop + create tables)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-vector-store",
|
||||
action="store_true",
|
||||
help="重建向量库 KB 集合(清空后按 DB 的 knowledge_bases 重建 collection)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-seed",
|
||||
action="store_true",
|
||||
help="跳过默认数据初始化",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 无参数时保持旧行为:重建 DB + 初始化默认数据
|
||||
if not args.rebuild_db and not args.rebuild_vector_store and not args.skip_seed:
|
||||
args.rebuild_db = True
|
||||
|
||||
# 确保 data 目录存在
|
||||
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
init_db()
|
||||
init_default_data()
|
||||
init_default_assistants()
|
||||
init_default_workflows()
|
||||
init_default_knowledge_bases()
|
||||
init_default_llm_models()
|
||||
init_default_asr_models()
|
||||
print("🎉 数据库初始化完成!")
|
||||
if args.rebuild_db:
|
||||
init_db()
|
||||
|
||||
if not args.skip_seed:
|
||||
init_default_data()
|
||||
init_default_assistants()
|
||||
init_default_workflows()
|
||||
init_default_knowledge_bases()
|
||||
init_default_llm_models()
|
||||
init_default_asr_models()
|
||||
print("✅ 默认数据初始化完成")
|
||||
|
||||
if args.rebuild_vector_store:
|
||||
rebuild_vector_store(reset_doc_status=True)
|
||||
|
||||
print("🎉 初始化脚本执行完成!")
|
||||
|
||||
Reference in New Issue
Block a user