# AI-VideoAssistant/api/app/vector_store.py
"""
向量数据库服务 (ChromaDB)
"""
import os
from typing import List, Dict, Optional
import chromadb
from chromadb.config import Settings
# Configuration
VECTOR_STORE_PATH = os.getenv("VECTOR_STORE_PATH", "./data/vector_store")
COLLECTION_NAME_PREFIX = "kb_"
class VectorStore:
"""向量存储服务"""
def __init__(self):
os.makedirs(VECTOR_STORE_PATH, exist_ok=True)
self.client = chromadb.PersistentClient(
path=VECTOR_STORE_PATH,
settings=Settings(anonymized_telemetry=False)
)
def get_collection(self, kb_id: str):
"""获取知识库集合"""
collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}"
try:
return self.client.get_collection(name=collection_name)
except (ValueError, chromadb.errors.NotFoundError):
return None
def create_collection(self, kb_id: str, embedding_model: str = "text-embedding-3-small"):
"""创建知识库向量集合"""
collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}"
try:
self.client.get_collection(name=collection_name)
return collection_name
except (ValueError, chromadb.errors.NotFoundError):
self.client.create_collection(
name=collection_name,
metadata={
"kb_id": kb_id,
"embedding_model": embedding_model
}
)
return collection_name
def delete_collection(self, kb_id: str):
"""删除知识库向量集合"""
collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}"
try:
self.client.delete_collection(name=collection_name)
return True
except (ValueError, chromadb.errors.NotFoundError):
return False
def add_documents(
self,
kb_id: str,
documents: List[str],
embeddings: Optional[List[List[float]]] = None,
ids: Optional[List[str]] = None,
metadatas: Optional[List[Dict]] = None
):
"""添加文档片段到向量库"""
collection = self.get_collection(kb_id)
if ids is None:
ids = [f"chunk-{i}" for i in range(len(documents))]
        # chromadb computes embeddings with the collection's default embedding
        # function when `embeddings` is None, so a single call covers both cases
        collection.add(
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas
        )
return len(documents)
def search(
self,
kb_id: str,
query: str,
n_results: int = 5,
where: Optional[Dict] = None
) -> Dict:
"""检索相似文档"""
collection = self.get_collection(kb_id)
        # Generate the query embedding
query_embedding = embedding_service.embed_query(query)
results = collection.query(
query_embeddings=[query_embedding],
n_results=n_results,
where=where
)
return results
def get_stats(self, kb_id: str) -> Dict:
"""获取向量库统计"""
collection = self.get_collection(kb_id)
return {
"count": collection.count(),
"kb_id": kb_id
}
def delete_documents(self, kb_id: str, ids: List[str]):
"""删除指定文档片段"""
collection = self.get_collection(kb_id)
collection.delete(ids=ids)
def delete_by_metadata(self, kb_id: str, document_id: str):
"""根据文档 ID 删除所有片段"""
collection = self.get_collection(kb_id)
results = collection.get(where={"document_id": document_id})
if results["ids"]:
collection.delete(ids=results["ids"])
class EmbeddingService:
""" embedding 服务(支持多种模型)"""
def __init__(self, model: str = "text-embedding-3-small"):
self.model = model
self._client = None
def _get_client(self):
"""获取 OpenAI 客户端"""
if self._client is None:
try:
from openai import OpenAI
api_key = os.getenv("OPENAI_API_KEY")
if api_key:
self._client = OpenAI(api_key=api_key)
except ImportError:
pass
return self._client
def embed(self, texts: List[str]) -> List[List[float]]:
"""生成 embedding 向量"""
client = self._get_client()
if client is None:
            # No OpenAI client available: return random vectors (fallback for testing only)
            import random
dim = 1536 if "3-small" in self.model else 1024
return [[random.uniform(-1, 1) for _ in range(dim)] for _ in texts]
response = client.embeddings.create(
model=self.model,
input=texts
)
return [data.embedding for data in response.data]
def embed_query(self, query: str) -> List[float]:
"""生成查询向量"""
return self.embed([query])[0]
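# Illustrative sketch for EmbeddingService: with OPENAI_API_KEY set it calls the
# OpenAI embeddings API; without it, the random-vector fallback above is used.
#
#   svc = EmbeddingService(model="text-embedding-3-small")
#   vectors = svc.embed(["hello world", "bonjour"])   # two embedding vectors
#   query_vec = svc.embed_query("hello")              # one embedding vector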
class DocumentProcessor:
"""文档处理服务"""
def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def chunk_text(self, text: str, document_id: str = "") -> List[Dict]:
"""将文本分块"""
# 简单分块(按句子/段落)
import re
# 按句子分割
sentences = re.split(r'[。!?\n]', text)
chunks = []
current_chunk = ""
current_size = 0
        for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
sentence_len = len(sentence)
if current_size + sentence_len > self.chunk_size and current_chunk:
                # The current chunk is full: flush it
chunks.append({
"content": current_chunk.strip(),
"document_id": document_id,
"chunk_index": len(chunks),
"metadata": {
"source": "text"
}
})
                # Carry trailing characters over as overlap for the next chunk
                if self.chunk_overlap > 0:
                    # Keep the tail of the previous chunk
overlap_chars = current_chunk[-self.chunk_overlap:]
current_chunk = overlap_chars + " " + sentence
current_size = len(overlap_chars) + sentence_len + 1
else:
current_chunk = sentence
current_size = sentence_len
else:
if current_chunk:
current_chunk += " "
current_chunk += sentence
current_size += sentence_len + 1
        # Save the final chunk
if current_chunk.strip():
chunks.append({
"content": current_chunk.strip(),
"document_id": document_id,
"chunk_index": len(chunks),
"metadata": {
"source": "text"
}
})
return chunks
def process_document(self, text: str, document_id: str = "") -> List[Dict]:
"""完整处理文档"""
return self.chunk_text(text, document_id)
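# Illustrative sketch for DocumentProcessor (the sample text and document id are hypothetical):
#
#   processor = DocumentProcessor(chunk_size=500, chunk_overlap=50)
#   chunks = processor.process_document("第一句。第二句。第三句。", document_id="doc1")
#   for chunk in chunks:
#       print(chunk["chunk_index"], chunk["content"])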
# Module-level singletons
vector_store = VectorStore()
embedding_service = EmbeddingService()
def search_knowledge(kb_id: str, query: str, n_results: int = 5) -> Dict:
"""知识库检索"""
# 生成查询向量
query_vector = embedding_service.embed_query(query)
# 检索
results = vector_store.search(
kb_id=kb_id,
query=query,
n_results=n_results
)
return {
"query": query,
"results": [
{
"content": doc,
"metadata": meta,
"distance": dist
}
for doc, meta, dist in zip(
results.get("documents", [[]])[0] if results.get("documents") else [],
results.get("metadatas", [[]])[0] if results.get("metadatas") else [],
results.get("distances", [[]])[0] if results.get("distances") else []
)
]
}
def index_document(kb_id: str, document_id: str, text: str) -> int:
"""索引文档到向量库"""
# 分块
processor = DocumentProcessor()
chunks = processor.process_document(text, document_id)
if not chunks:
return 0
    # Generate an embedding for each chunk
contents = [c["content"] for c in chunks]
embeddings = embedding_service.embed(contents)
    # Add the chunks to the vector store
ids = [f"{document_id}-{c['chunk_index']}" for c in chunks]
metadatas = [
{
"document_id": c["document_id"],
"chunk_index": c["chunk_index"],
"kb_id": kb_id
}
for c in chunks
]
vector_store.add_documents(
kb_id=kb_id,
documents=contents,
embeddings=embeddings,
ids=ids,
metadatas=metadatas
)
return len(chunks)
def delete_document_from_vector(kb_id: str, document_id: str):
"""从向量库删除文档"""
vector_store.delete_by_metadata(kb_id, document_id)
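# Minimal end-to-end sketch of the module-level helpers. The knowledge-base id,
# document id, and sample text below are hypothetical; without OPENAI_API_KEY the
# random-vector fallback is used, so this is only meaningful as a smoke test.
if __name__ == "__main__":
    kb_id = "kb-demo"
    vector_store.create_collection(kb_id)
    n_chunks = index_document(kb_id, "doc-1", "这是第一句。这是第二句。这是第三句。")
    print(f"indexed {n_chunks} chunks")
    hits = search_knowledge(kb_id, "第一句", n_results=1)
    for hit in hits["results"]:
        print(hit["distance"], hit["content"])
    delete_document_from_vector(kb_id, "doc-1")
    vector_store.delete_collection(kb_id)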