""" 向量数据库服务 (ChromaDB) """ import os from typing import List, Dict, Optional import chromadb from chromadb.config import Settings # 配置 VECTOR_STORE_PATH = os.getenv("VECTOR_STORE_PATH", "./data/vector_store") COLLECTION_NAME_PREFIX = "kb_" class VectorStore: """向量存储服务""" def __init__(self): os.makedirs(VECTOR_STORE_PATH, exist_ok=True) self.client = chromadb.PersistentClient( path=VECTOR_STORE_PATH, settings=Settings(anonymized_telemetry=False) ) def get_collection(self, kb_id: str): """获取知识库集合""" collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}" try: return self.client.get_collection(name=collection_name) except (ValueError, chromadb.errors.NotFoundError): return None def create_collection(self, kb_id: str, embedding_model: str = "text-embedding-3-small"): """创建知识库向量集合""" collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}" try: self.client.get_collection(name=collection_name) return collection_name except (ValueError, chromadb.errors.NotFoundError): self.client.create_collection( name=collection_name, metadata={ "kb_id": kb_id, "embedding_model": embedding_model } ) return collection_name def delete_collection(self, kb_id: str): """删除知识库向量集合""" collection_name = f"{COLLECTION_NAME_PREFIX}{kb_id}" try: self.client.delete_collection(name=collection_name) return True except (ValueError, chromadb.errors.NotFoundError): return False def add_documents( self, kb_id: str, documents: List[str], embeddings: Optional[List[List[float]]] = None, ids: Optional[List[str]] = None, metadatas: Optional[List[Dict]] = None ): """添加文档片段到向量库""" collection = self.get_collection(kb_id) if ids is None: ids = [f"chunk-{i}" for i in range(len(documents))] if embeddings is not None: collection.add( documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas ) else: collection.add( documents=documents, ids=ids, metadatas=metadatas ) return len(documents) def search( self, kb_id: str, query: str, n_results: int = 5, where: Optional[Dict] = None ) -> Dict: """检索相似文档""" collection = self.get_collection(kb_id) # 生成查询向量 query_embedding = embedding_service.embed_query(query) results = collection.query( query_embeddings=[query_embedding], n_results=n_results, where=where ) return results def get_stats(self, kb_id: str) -> Dict: """获取向量库统计""" collection = self.get_collection(kb_id) return { "count": collection.count(), "kb_id": kb_id } def delete_documents(self, kb_id: str, ids: List[str]): """删除指定文档片段""" collection = self.get_collection(kb_id) collection.delete(ids=ids) def delete_by_metadata(self, kb_id: str, document_id: str): """根据文档 ID 删除所有片段""" collection = self.get_collection(kb_id) results = collection.get(where={"document_id": document_id}) if results["ids"]: collection.delete(ids=results["ids"]) class EmbeddingService: """ embedding 服务(支持多种模型)""" def __init__(self, model: str = "text-embedding-3-small"): self.model = model self._client = None def _get_client(self): """获取 OpenAI 客户端""" if self._client is None: try: from openai import OpenAI api_key = os.getenv("OPENAI_API_KEY") if api_key: self._client = OpenAI(api_key=api_key) except ImportError: pass return self._client def embed(self, texts: List[str]) -> List[List[float]]: """生成 embedding 向量""" client = self._get_client() if client is None: # 返回随机向量(仅用于测试) import random import math dim = 1536 if "3-small" in self.model else 1024 return [[random.uniform(-1, 1) for _ in range(dim)] for _ in texts] response = client.embeddings.create( model=self.model, input=texts ) return [data.embedding for data in response.data] def embed_query(self, query: str) 
class DocumentProcessor:
    """Document processing service."""

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def chunk_text(self, text: str, document_id: str = "") -> List[Dict]:
        """Split text into chunks (simple sentence/paragraph splitting)."""
        import re

        # Split on Chinese sentence-ending punctuation and newlines
        sentences = re.split(r'[。!?\n]', text)
        chunks = []
        current_chunk = ""
        current_size = 0

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            sentence_len = len(sentence)

            if current_size + sentence_len > self.chunk_size and current_chunk:
                # Save the current chunk
                chunks.append({
                    "content": current_chunk.strip(),
                    "document_id": document_id,
                    "chunk_index": len(chunks),
                    "metadata": {
                        "source": "text"
                    }
                })
                # Handle overlap
                if self.chunk_overlap > 0:
                    # Carry the tail of the previous chunk into the next one
                    overlap_chars = current_chunk[-self.chunk_overlap:]
                    current_chunk = overlap_chars + " " + sentence
                    current_size = len(overlap_chars) + sentence_len + 1
                else:
                    current_chunk = sentence
                    current_size = sentence_len
            else:
                if current_chunk:
                    current_chunk += " "
                current_chunk += sentence
                current_size += sentence_len + 1

        # Save the last chunk
        if current_chunk.strip():
            chunks.append({
                "content": current_chunk.strip(),
                "document_id": document_id,
                "chunk_index": len(chunks),
                "metadata": {
                    "source": "text"
                }
            })
        return chunks

    def process_document(self, text: str, document_id: str = "") -> List[Dict]:
        """Run the full document processing pipeline (currently just chunking)."""
        return self.chunk_text(text, document_id)


# Module-level singleton instances
vector_store = VectorStore()
embedding_service = EmbeddingService()


def search_knowledge(kb_id: str, query: str, n_results: int = 5) -> Dict:
    """Search a knowledge base and return a flattened result structure."""
    # VectorStore.search embeds the query internally, so no separate
    # embedding call is needed here.
    results = vector_store.search(
        kb_id=kb_id,
        query=query,
        n_results=n_results
    )
    return {
        "query": query,
        "results": [
            {
                "content": doc,
                "metadata": meta,
                "distance": dist
            }
            for doc, meta, dist in zip(
                results.get("documents", [[]])[0] if results.get("documents") else [],
                results.get("metadatas", [[]])[0] if results.get("metadatas") else [],
                results.get("distances", [[]])[0] if results.get("distances") else []
            )
        ]
    }


def index_document(kb_id: str, document_id: str, text: str) -> int:
    """Chunk, embed, and index a document into the vector store."""
    # Split the document into chunks
    processor = DocumentProcessor()
    chunks = processor.process_document(text, document_id)
    if not chunks:
        return 0

    # Generate embeddings for the chunk contents
    contents = [c["content"] for c in chunks]
    embeddings = embedding_service.embed(contents)

    # Add the chunks to the vector store
    ids = [f"{document_id}-{c['chunk_index']}" for c in chunks]
    metadatas = [
        {
            "document_id": c["document_id"],
            "chunk_index": c["chunk_index"],
            "kb_id": kb_id
        }
        for c in chunks
    ]
    vector_store.add_documents(
        kb_id=kb_id,
        documents=contents,
        embeddings=embeddings,
        ids=ids,
        metadatas=metadatas
    )
    return len(chunks)


def delete_document_from_vector(kb_id: str, document_id: str):
    """Remove all of a document's chunks from the vector store."""
    vector_store.delete_by_metadata(kb_id, document_id)
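

# Illustrative usage sketch, not part of the service API. Assumptions: the
# knowledge-base ID and document ID below are made up for the demo, and
# OPENAI_API_KEY may be unset, in which case EmbeddingService falls back to
# random vectors and the search output is only meaningful as a wiring test.
if __name__ == "__main__":
    demo_kb = "demo-kb"
    vector_store.create_collection(demo_kb)

    # Index a short document; the chunker splits on Chinese sentence
    # punctuation and newlines.
    sample_text = "这是一个示例文档。它会被切分成多个片段。每个片段会被向量化并存入 ChromaDB。"
    n_chunks = index_document(demo_kb, document_id="doc-1", text=sample_text)
    print(f"indexed {n_chunks} chunks")

    # Query the knowledge base and print the flattened results
    hits = search_knowledge(demo_kb, query="示例文档", n_results=3)
    for hit in hits["results"]:
        print(hit["distance"], hit["content"])

    # Clean up the demo data
    delete_document_from_vector(demo_kb, "doc-1")
    vector_store.delete_collection(demo_kb)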