Add backend api and engine
This commit is contained in:
311
api/app/vector_store.py
Normal file
311
api/app/vector_store.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
向量数据库服务 (ChromaDB)
|
||||
"""
|
||||
import os
|
||||
from typing import List, Dict, Optional
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
# Configuration
# On-disk directory where the persistent ChromaDB database lives;
# overridable via the VECTOR_STORE_PATH environment variable.
VECTOR_STORE_PATH = os.getenv("VECTOR_STORE_PATH", "./data/vector_store")
# Each knowledge base gets its own collection, named f"kb_{kb_id}".
COLLECTION_NAME_PREFIX = "kb_"
|
||||
|
||||
|
||||
class VectorStore:
    """Vector store service backed by a persistent ChromaDB client.

    One Chroma collection per knowledge base, named
    f"{COLLECTION_NAME_PREFIX}{kb_id}".
    """

    def __init__(self):
        # Ensure the on-disk directory exists before Chroma opens it.
        os.makedirs(VECTOR_STORE_PATH, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=VECTOR_STORE_PATH,
            settings=Settings(anonymized_telemetry=False)
        )

    def _collection_name(self, kb_id: str) -> str:
        """Build the Chroma collection name for a knowledge base id."""
        return f"{COLLECTION_NAME_PREFIX}{kb_id}"

    def get_collection(self, kb_id: str):
        """Return the collection for *kb_id*, or None if it does not exist."""
        try:
            return self.client.get_collection(name=self._collection_name(kb_id))
        except (ValueError, chromadb.errors.NotFoundError):
            return None

    def _require_collection(self, kb_id: str):
        """Return the collection for *kb_id* or raise ValueError.

        Previously a missing collection surfaced later as an opaque
        AttributeError on None; fail early with a clear message instead.
        """
        collection = self.get_collection(kb_id)
        if collection is None:
            raise ValueError(f"knowledge base collection not found: {kb_id}")
        return collection

    def create_collection(self, kb_id: str, embedding_model: str = "text-embedding-3-small"):
        """Create the vector collection for *kb_id* (idempotent).

        Returns the collection name whether it was created or already existed.
        """
        collection_name = self._collection_name(kb_id)
        try:
            self.client.get_collection(name=collection_name)
        except (ValueError, chromadb.errors.NotFoundError):
            self.client.create_collection(
                name=collection_name,
                metadata={
                    "kb_id": kb_id,
                    "embedding_model": embedding_model
                }
            )
        return collection_name

    def delete_collection(self, kb_id: str) -> bool:
        """Delete the collection for *kb_id*; True if it existed."""
        try:
            self.client.delete_collection(name=self._collection_name(kb_id))
            return True
        except (ValueError, chromadb.errors.NotFoundError):
            return False

    def add_documents(
        self,
        kb_id: str,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None,
        ids: Optional[List[str]] = None,
        metadatas: Optional[List[Dict]] = None
    ) -> int:
        """Add document chunks to the knowledge base; returns the count added.

        When *embeddings* is None, Chroma computes embeddings itself.
        Raises ValueError if the collection does not exist.
        """
        collection = self._require_collection(kb_id)

        if ids is None:
            # NOTE(review): these default ids ("chunk-0", "chunk-1", ...)
            # collide across calls and would overwrite earlier chunks;
            # callers should pass stable ids (index_document does).
            ids = [f"chunk-{i}" for i in range(len(documents))]

        kwargs = {"documents": documents, "ids": ids, "metadatas": metadatas}
        if embeddings is not None:
            kwargs["embeddings"] = embeddings
        collection.add(**kwargs)

        return len(documents)

    def search(
        self,
        kb_id: str,
        query: str,
        n_results: int = 5,
        where: Optional[Dict] = None
    ) -> Dict:
        """Retrieve the *n_results* chunks most similar to *query*.

        Returns the raw Chroma query result dict. Raises ValueError if the
        collection does not exist.
        """
        collection = self._require_collection(kb_id)

        # embedding_service is the module-level singleton defined below;
        # it is available by the time this method runs.
        query_embedding = embedding_service.embed_query(query)

        return collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where
        )

    def get_stats(self, kb_id: str) -> Dict:
        """Return chunk count and id for the knowledge base."""
        collection = self._require_collection(kb_id)
        return {
            "count": collection.count(),
            "kb_id": kb_id
        }

    def delete_documents(self, kb_id: str, ids: List[str]):
        """Delete the chunks with the given ids."""
        collection = self._require_collection(kb_id)
        collection.delete(ids=ids)

    def delete_by_metadata(self, kb_id: str, document_id: str):
        """Delete every chunk whose metadata document_id matches."""
        collection = self._require_collection(kb_id)
        results = collection.get(where={"document_id": document_id})
        if results["ids"]:
            collection.delete(ids=results["ids"])
|
||||
|
||||
|
||||
class EmbeddingService:
    """Embedding service (OpenAI models, with an offline test fallback)."""

    def __init__(self, model: str = "text-embedding-3-small"):
        self.model = model
        # Lazily-created OpenAI client; None means "not available".
        self._client = None

    def _get_client(self):
        """Return a cached OpenAI client, or None when openai or the key is missing."""
        if self._client is None:
            try:
                from openai import OpenAI
                api_key = os.getenv("OPENAI_API_KEY")
                if api_key:
                    self._client = OpenAI(api_key=api_key)
            except ImportError:
                # openai not installed: stay in offline/test mode.
                pass
        return self._client

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts*, returning one vector per input string.

        Without a client (no openai package or no OPENAI_API_KEY) this
        returns random vectors — for testing only, results are meaningless.
        """
        client = self._get_client()

        if client is None:
            import random
            # 1536 matches text-embedding-3-small; 1024 is assumed for
            # other models — TODO confirm per model.
            dim = 1536 if "3-small" in self.model else 1024
            return [[random.uniform(-1, 1) for _ in range(dim)] for _ in texts]

        response = client.embeddings.create(
            model=self.model,
            input=texts
        )
        return [data.embedding for data in response.data]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string."""
        return self.embed([query])[0]
|
||||
|
||||
|
||||
class DocumentProcessor:
    """Split raw text into (optionally overlapping) chunks for indexing."""

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        # Target maximum characters per chunk.
        self.chunk_size = chunk_size
        # Characters from the end of a chunk carried into the next one.
        self.chunk_overlap = chunk_overlap

    def chunk_text(self, text: str, document_id: str = "") -> List[Dict]:
        """Split *text* into chunk dicts.

        Each dict has "content", "document_id", "chunk_index" (position in
        the returned list) and "metadata". Sentences are delimited by CJK
        sentence enders and newlines; the delimiters are discarded.
        """
        import re

        sentences = re.split(r'[。!?\n]', text)

        chunks: List[Dict] = []
        current_chunk = ""
        current_size = 0

        def _emit(content: str) -> None:
            # Append a chunk record; chunk_index is its list position.
            chunks.append({
                "content": content,
                "document_id": document_id,
                "chunk_index": len(chunks),
                "metadata": {
                    "source": "text"
                }
            })

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            sentence_len = len(sentence)

            if current_size + sentence_len > self.chunk_size and current_chunk:
                # Current chunk is full: flush it.
                _emit(current_chunk.strip())

                if self.chunk_overlap > 0:
                    # Carry the tail of the previous chunk into the next one
                    # so context straddling a boundary is not lost.
                    overlap_chars = current_chunk[-self.chunk_overlap:]
                    current_chunk = overlap_chars + " " + sentence
                    current_size = len(overlap_chars) + sentence_len + 1
                else:
                    current_chunk = sentence
                    current_size = sentence_len
            else:
                if current_chunk:
                    current_chunk += " "
                current_chunk += sentence
                # NOTE(review): the +1 over-counts by one for the first
                # sentence of a chunk (no separator was added); kept as-is
                # to preserve the existing chunk boundaries.
                current_size += sentence_len + 1

        # Flush the trailing partial chunk, if any.
        if current_chunk.strip():
            _emit(current_chunk.strip())

        return chunks

    def process_document(self, text: str, document_id: str = "") -> List[Dict]:
        """Full document processing pipeline (currently just chunking)."""
        return self.chunk_text(text, document_id)
|
||||
|
||||
|
||||
# Module-level singleton instances shared by the helper functions below
# (and by VectorStore.search, which uses embedding_service).
vector_store = VectorStore()
embedding_service = EmbeddingService()
|
||||
|
||||
|
||||
def search_knowledge(kb_id: str, query: str, n_results: int = 5) -> Dict:
    """Search a knowledge base and return a simplified result payload.

    Returns {"query": ..., "results": [{"content", "metadata", "distance"}]}.

    Fix: the previous version called embedding_service.embed_query(query)
    here and never used the result — vector_store.search() embeds the query
    itself — so every search paid for the embedding twice.
    """
    results = vector_store.search(
        kb_id=kb_id,
        query=query,
        n_results=n_results
    )

    # Chroma returns one inner list per query; we sent a single query, so
    # take index 0, defaulting to [] when the key is missing or empty.
    documents = (results.get("documents") or [[]])[0]
    metadatas = (results.get("metadatas") or [[]])[0]
    distances = (results.get("distances") or [[]])[0]

    return {
        "query": query,
        "results": [
            {
                "content": doc,
                "metadata": meta,
                "distance": dist
            }
            for doc, meta, dist in zip(documents, metadatas, distances)
        ]
    }
|
||||
|
||||
|
||||
def index_document(kb_id: str, document_id: str, text: str) -> int:
    """Chunk *text*, embed the chunks, and store them in the knowledge base.

    Returns the number of chunks written (0 when the text yields none).
    Chunk ids are "{document_id}-{chunk_index}", so re-indexing the same
    document reuses the same ids.
    """
    chunks = DocumentProcessor().process_document(text, document_id)
    if not chunks:
        return 0

    contents = [chunk["content"] for chunk in chunks]
    vectors = embedding_service.embed(contents)

    chunk_ids = [f"{document_id}-{chunk['chunk_index']}" for chunk in chunks]
    chunk_metadatas = [
        {
            "document_id": chunk["document_id"],
            "chunk_index": chunk["chunk_index"],
            "kb_id": kb_id
        }
        for chunk in chunks
    ]

    vector_store.add_documents(
        kb_id=kb_id,
        documents=contents,
        embeddings=vectors,
        ids=chunk_ids,
        metadatas=chunk_metadatas
    )

    return len(chunks)
|
||||
|
||||
|
||||
def delete_document_from_vector(kb_id: str, document_id: str):
    """Remove all stored chunks of a document from the knowledge base.

    Delegates to the module-level vector_store singleton, which matches
    chunks on their document_id metadata field.
    """
    vector_store.delete_by_metadata(kb_id, document_id)
|
||||
Reference in New Issue
Block a user