"""
|
|
向量数据库服务 (ChromaDB)
|
|
"""
|
|
import os
|
|
from typing import List, Dict, Optional
|
|
import chromadb
|
|
from chromadb.config import Settings
|
|
|
|
# 配置
|
|
VECTOR_STORE_PATH = os.getenv("VECTOR_STORE_PATH", "./data/vector_store")
|
|
COLLECTION_NAME_PREFIX = "kb_"
|
|
|
|
|
|
class VectorStore:
    """Vector store service."""

    def __init__(self):
        os.makedirs(VECTOR_STORE_PATH, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=VECTOR_STORE_PATH,
            settings=Settings(anonymized_telemetry=False)
        )

    def get_collection(self, kb_id: str):
        """Get the vector collection for a knowledge base; returns None if it does not exist."""
        collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}"
        try:
            return self.client.get_collection(name=collection_name)
        except (ValueError, chromadb.errors.NotFoundError):
            return None

    def create_collection(self, kb_id: str, embedding_model: str = "text-embedding-3-small"):
        """Create the vector collection for a knowledge base; returns the collection name (existing or newly created)."""
        collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}"
        try:
            self.client.get_collection(name=collection_name)
            return collection_name
        except (ValueError, chromadb.errors.NotFoundError):
            self.client.create_collection(
                name=collection_name,
                metadata={
                    "kb_id": kb_id,
                    "embedding_model": embedding_model
                }
            )
            return collection_name

    def delete_collection(self, kb_id: str):
        """Delete the vector collection for a knowledge base; returns True on success."""
        collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}"
        try:
            self.client.delete_collection(name=collection_name)
            return True
        except (ValueError, chromadb.errors.NotFoundError):
            return False

    def add_documents(
        self,
        kb_id: str,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        ids: Optional[List[str]] = None,
        metadatas: Optional[List[Dict]] = None
    ):
        """Add document chunks to the vector store."""
        collection = self.get_collection(kb_id)
        if collection is None:
            raise ValueError(f"Knowledge base collection not found: {kb_id}")

        if ids is None:
            ids = [f"chunk-{i}" for i in range(len(documents))]

        if embeddings is not None:
            collection.add(
                documents=documents,
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas
            )
        else:
            # No precomputed embeddings: let ChromaDB embed with its default embedding function
            collection.add(
                documents=documents,
                ids=ids,
                metadatas=metadatas
            )

        return len(documents)

    def search(
        self,
        kb_id: str,
        query: str,
        n_results: int = 5,
        where: Optional[Dict] = None
    ) -> Dict:
        """Retrieve documents similar to the query."""
        collection = self.get_collection(kb_id)
        if collection is None:
            raise ValueError(f"Knowledge base collection not found: {kb_id}")

        # Generate the query embedding (uses the module-level embedding_service defined below)
        query_embedding = embedding_service.embed_query(query)

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where
        )

        return results

    def get_stats(self, kb_id: str) -> Dict:
        """Get vector store statistics for a knowledge base."""
        collection = self.get_collection(kb_id)
        return {
            "count": collection.count() if collection is not None else 0,
            "kb_id": kb_id
        }

    def delete_documents(self, kb_id: str, ids: List[str]):
        """Delete the specified document chunks by ID."""
        collection = self.get_collection(kb_id)
        if collection is not None and ids:
            collection.delete(ids=ids)

    def delete_by_metadata(self, kb_id: str, document_id: str):
        """Delete all chunks belonging to a given document ID."""
        collection = self.get_collection(kb_id)
        if collection is None:
            return
        results = collection.get(where={"document_id": document_id})
        if results["ids"]:
            collection.delete(ids=results["ids"])


class EmbeddingService:
    """Embedding service (supports multiple models)."""

    def __init__(self, model: str = "text-embedding-3-small"):
        self.model = model
        self._client = None

    def _get_client(self):
        """Lazily create the OpenAI client; returns None if the SDK or API key is unavailable."""
        if self._client is None:
            try:
                from openai import OpenAI
                api_key = os.getenv("OPENAI_API_KEY")
                if api_key:
                    self._client = OpenAI(api_key=api_key)
            except ImportError:
                pass
        return self._client

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Generate embedding vectors for a batch of texts."""
        client = self._get_client()

        if client is None:
            # No OpenAI client available: return random vectors (for testing only)
            import random
            dim = 1536 if "3-small" in self.model else 1024
            return [[random.uniform(-1, 1) for _ in range(dim)] for _ in texts]

        response = client.embeddings.create(
            model=self.model,
            input=texts
        )
        return [data.embedding for data in response.data]

    def embed_query(self, query: str) -> List[float]:
        """Generate the embedding for a single query string."""
        return self.embed([query])[0]


class DocumentProcessor:
    """Document processing service."""

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def chunk_text(self, text: str, document_id: str = "") -> List[Dict]:
        """Split text into chunks."""
        # Simple chunking: split into sentences on sentence-ending punctuation (。！？) and newlines
        import re

        sentences = re.split(r'[。！？\n]', text)

        chunks = []
        current_chunk = ""
        current_size = 0

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            sentence_len = len(sentence)

            if current_size + sentence_len > self.chunk_size and current_chunk:
                # Save the current chunk
                chunks.append({
                    "content": current_chunk.strip(),
                    "document_id": document_id,
                    "chunk_index": len(chunks),
                    "metadata": {
                        "source": "text"
                    }
                })

                # Handle overlap between consecutive chunks
                if self.chunk_overlap > 0:
                    # Carry over the tail of the previous chunk
                    overlap_chars = current_chunk[-self.chunk_overlap:]
                    current_chunk = overlap_chars + " " + sentence
                    current_size = len(overlap_chars) + sentence_len + 1
                else:
                    current_chunk = sentence
                    current_size = sentence_len
            else:
                if current_chunk:
                    current_chunk += " "
                current_chunk += sentence
                current_size += sentence_len + 1

        # Save the last chunk
        if current_chunk.strip():
            chunks.append({
                "content": current_chunk.strip(),
                "document_id": document_id,
                "chunk_index": len(chunks),
                "metadata": {
                    "source": "text"
                }
            })

        return chunks

    def process_document(self, text: str, document_id: str = "") -> List[Dict]:
        """Process a document end to end (currently just chunking)."""
        return self.chunk_text(text, document_id)

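# Illustrative note (hypothetical input, not part of the service): DocumentProcessor produces
# chunk dicts shaped like this:
#   DocumentProcessor(chunk_size=500, chunk_overlap=50).chunk_text("第一句。第二句。", "doc-1")
#   -> [{"content": "第一句 第二句", "document_id": "doc-1", "chunk_index": 0,
#        "metadata": {"source": "text"}}]
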

# Global instances
vector_store = VectorStore()
embedding_service = EmbeddingService()


def search_knowledge(kb_id: str, query: str, n_results: int = 5) -> Dict:
    """Search a knowledge base and return a flattened result list."""
    # VectorStore.search embeds the query internally via embedding_service
    results = vector_store.search(
        kb_id=kb_id,
        query=query,
        n_results=n_results
    )

    return {
        "query": query,
        "results": [
            {
                "content": doc,
                "metadata": meta,
                "distance": dist
            }
            for doc, meta, dist in zip(
                results.get("documents", [[]])[0] if results.get("documents") else [],
                results.get("metadatas", [[]])[0] if results.get("metadatas") else [],
                results.get("distances", [[]])[0] if results.get("distances") else []
            )
        ]
    }


def index_document(kb_id: str, document_id: str, text: str) -> int:
    """Index a document into the vector store; returns the number of chunks added."""
    # Chunk the text
    processor = DocumentProcessor()
    chunks = processor.process_document(text, document_id)

    if not chunks:
        return 0

    # Generate embeddings
    contents = [c["content"] for c in chunks]
    embeddings = embedding_service.embed(contents)

    # Add to the vector store
    ids = [f"{document_id}-{c['chunk_index']}" for c in chunks]
    metadatas = [
        {
            "document_id": c["document_id"],
            "chunk_index": c["chunk_index"],
            "kb_id": kb_id
        }
        for c in chunks
    ]

    vector_store.add_documents(
        kb_id=kb_id,
        documents=contents,
        embeddings=embeddings,
        ids=ids,
        metadatas=metadatas
    )

    return len(chunks)


def delete_document_from_vector(kb_id: str, document_id: str):
    """Delete a document's chunks from the vector store."""
    vector_store.delete_by_metadata(kb_id, document_id)
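

# Minimal usage sketch (not part of the original service API): indexes a short text and runs a
# search against it. The "demo" knowledge-base id and the sample text are hypothetical. Without
# OPENAI_API_KEY set, EmbeddingService falls back to random vectors, so the results are only
# meaningful as a wiring/smoke test.
if __name__ == "__main__":
    demo_kb_id = "demo"
    vector_store.create_collection(demo_kb_id)

    added = index_document(
        demo_kb_id,
        document_id="doc-1",
        text="An example text for the vector store.\nUsed only as a local smoke test."
    )
    print(f"indexed chunks: {added}")

    hits = search_knowledge(demo_kb_id, query="example", n_results=3)
    for item in hits["results"]:
        print(item["distance"], item["content"])

    # Clean up the demo collection
    vector_store.delete_collection(demo_kb_id)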