"""Knowledge-base API routes: bases, documents, indexing, search and stats."""

import logging
import os
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session

from ..db import get_db
from ..models import KnowledgeBase, KnowledgeDocument
from ..schemas import (
    KnowledgeBaseCreate,
    KnowledgeBaseUpdate,
    KnowledgeBaseOut,
    KnowledgeSearchQuery,
    KnowledgeSearchResult,
    KnowledgeStats,
    DocumentIndexRequest,
)
from ..vector_store import (
    vector_store,
    search_knowledge,
    index_document,
    delete_document_from_vector,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/knowledge", tags=["knowledge"])

# Maps camelCase request field names (the style KnowledgeBaseCreate uses) to
# the snake_case attributes of KnowledgeBase. Names that are not listed are
# applied unchanged, so a snake_case update schema keeps working as before.
# NOTE(review): KnowledgeBaseUpdate is not visible here - confirm its field names.
_UPDATE_FIELD_ALIASES = {
    "embeddingModel": "embedding_model",
    "chunkSize": "chunk_size",
    "chunkOverlap": "chunk_overlap",
}


def _iso(value):
    """Return ``value.isoformat()``, or ``None`` when the value is falsy."""
    return value.isoformat() if value else None


def _get_kb_or_404(db: Session, kb_id: str) -> KnowledgeBase:
    """Load a knowledge base by id.

    Raises:
        HTTPException: 404 when no knowledge base has this id.
    """
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if kb is None:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    return kb


def _count_completed_documents(db: Session, kb_id: str) -> int:
    """Count the documents of a knowledge base whose status is "completed"."""
    return db.query(KnowledgeDocument).filter(
        KnowledgeDocument.kb_id == kb_id,
        KnowledgeDocument.status == "completed",
    ).count()


def kb_to_dict(kb: KnowledgeBase) -> dict:
    """Serialize a knowledge base row into a plain dict (timestamps as ISO strings)."""
    return {
        "id": kb.id,
        "user_id": kb.user_id,
        "name": kb.name,
        "description": kb.description,
        "embedding_model": kb.embedding_model,
        "chunk_size": kb.chunk_size,
        "chunk_overlap": kb.chunk_overlap,
        "doc_count": kb.doc_count,
        "chunk_count": kb.chunk_count,
        "status": kb.status,
        "created_at": _iso(kb.created_at),
        "updated_at": _iso(kb.updated_at),
    }


def doc_to_dict(d: KnowledgeDocument) -> dict:
    """Serialize a document row into a plain dict (timestamps as ISO strings)."""
    return {
        "id": d.id,
        "kb_id": d.kb_id,
        "name": d.name,
        "size": d.size,
        "file_type": d.file_type,
        "storage_url": d.storage_url,
        "status": d.status,
        "chunk_count": d.chunk_count,
        "error_message": d.error_message,
        # upload_date is passed through as stored; this module writes it as an
        # ISO string already.
        "upload_date": d.upload_date,
        "created_at": _iso(d.created_at),
        "processed_at": _iso(d.processed_at),
    }


# ============ Knowledge Bases ============

@router.get("/bases")
def list_knowledge_bases(user_id: int = 1, db: Session = Depends(get_db)):
    """List the knowledge bases of a user, each with its documents."""
    kbs = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id).all()

    # Fetch every document of this user's knowledge bases in one query and
    # group them in memory, instead of issuing one query per knowledge base.
    docs = (
        db.query(KnowledgeDocument)
        .join(KnowledgeBase, KnowledgeDocument.kb_id == KnowledgeBase.id)
        .filter(KnowledgeBase.user_id == user_id)
        .all()
    )
    docs_by_kb = defaultdict(list)
    for doc in docs:
        docs_by_kb[doc.kb_id].append(doc_to_dict(doc))

    result = []
    for kb in kbs:
        kb_data = kb_to_dict(kb)
        # .get() so that bases without documents do not grow the mapping.
        kb_data["documents"] = docs_by_kb.get(kb.id, [])
        result.append(kb_data)
    return {"total": len(result), "list": result}


@router.get("/bases/{kb_id}")
def get_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
    """Return one knowledge base with its documents; 404 if it does not exist."""
    kb = _get_kb_or_404(db, kb_id)
    docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
    kb_data = kb_to_dict(kb)
    kb_data["documents"] = [doc_to_dict(d) for d in docs]
    return kb_data


@router.post("/bases")
def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Session = Depends(get_db)):
    """Create a knowledge base row and its vector-store collection."""
    kb = KnowledgeBase(
        # Short id: the first 8 characters of a random UUID4 string.
        id=str(uuid.uuid4())[:8],
        user_id=user_id,
        name=data.name,
        description=data.description,
        embedding_model=data.embeddingModel,
        chunk_size=data.chunkSize,
        chunk_overlap=data.chunkOverlap,
    )
    db.add(kb)
    db.commit()
    db.refresh(kb)
    # NOTE(review): the row is already committed at this point; if creating
    # the collection raises, the row stays behind without a collection.
    vector_store.create_collection(kb.id, data.embeddingModel)
    return kb_to_dict(kb)


@router.put("/bases/{kb_id}")
def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
    """Apply the explicitly provided fields to a knowledge base; 404 if missing."""
    kb = _get_kb_or_404(db, kb_id)
    # exclude_unset: only fields the client actually sent are applied.
    update_data = data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        # Translate camelCase request names to the model's column attributes so
        # the assignment reaches a mapped column rather than a stray attribute.
        setattr(kb, _UPDATE_FIELD_ALIASES.get(field, field), value)
    kb.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(kb)
    return kb_to_dict(kb)


@router.delete("/bases/{kb_id}")
def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
    """Delete a knowledge base, its vector collection and its documents; 404 if missing."""
    kb = _get_kb_or_404(db, kb_id)
    vector_store.delete_collection(kb_id)
    docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
    for doc in docs:
        db.delete(doc)
    db.delete(kb)
    db.commit()
    return {"message": "Deleted successfully"}


# ============ Documents ============

@router.post("/bases/{kb_id}/documents")
def upload_document(
    kb_id: str,
    name: str = Query(...),
    size: str = Query(...),
    file_type: str = Query("txt"),
    storage_url: Optional[str] = Query(None),
    db: Session = Depends(get_db)
):
    """Register a document record in "pending" status; 404 if the knowledge base is missing."""
    _get_kb_or_404(db, kb_id)
    doc = KnowledgeDocument(
        id=str(uuid.uuid4())[:8],
        kb_id=kb_id,
        name=name,
        size=size,
        file_type=file_type,
        storage_url=storage_url,
        status="pending",
        upload_date=datetime.utcnow().isoformat()
    )
    db.add(doc)
    db.commit()
    db.refresh(doc)
    return {"id": doc.id, "name": doc.name, "status": doc.status, "message": "Document created"}


@router.post("/bases/{kb_id}/documents/{doc_id}/index")
def index_document_content(kb_id: str, doc_id: str, request: DocumentIndexRequest, db: Session = Depends(get_db)):
    """Index the given content for a document, creating the document record if needed.

    Raises:
        HTTPException: 404 when the knowledge base does not exist; 500 when
            indexing or the follow-up bookkeeping fails (the document is then
            marked "failed" with the error text).
    """
    # Resolve the knowledge base first, so a bad kb_id is a clean 404 instead
    # of a document row plus indexed content for a base that does not exist.
    kb = _get_kb_or_404(db, kb_id)

    # Check whether the document exists; create it if it does not.
    doc = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.id == doc_id,
        KnowledgeDocument.kb_id == kb_id
    ).first()
    if not doc:
        doc = KnowledgeDocument(
            id=doc_id,
            kb_id=kb_id,
            name=f"doc-{doc_id}.txt",
            size=str(len(request.content)),
            file_type="txt",
            status="pending",
            upload_date=datetime.utcnow().isoformat()
        )
        db.add(doc)
        db.commit()
        db.refresh(doc)
    else:
        # Update the existing document.
        doc.size = str(len(request.content))
        doc.status = "pending"
        db.commit()

    # Chunks this document already contributed to the knowledge-base total;
    # they are replaced, not added to, when the document is indexed again.
    previous_chunks = doc.chunk_count or 0

    try:
        chunk_count = index_document(kb_id, doc_id, request.content)
        doc.status = "completed"
        doc.chunk_count = chunk_count
        doc.processed_at = datetime.utcnow()
        doc.error_message = None  # drop the error left by an earlier failed attempt

        kb.doc_count = _count_completed_documents(db, kb_id)
        kb.chunk_count = max((kb.chunk_count or 0) - previous_chunks + chunk_count, 0)
        db.commit()
    except Exception as e:
        # Discard the half-applied changes (and reset the session if the
        # failure came from the commit) before recording the failure.
        db.rollback()
        doc.status = "failed"
        doc.error_message = str(e)
        db.commit()
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Document indexed", "chunkCount": chunk_count}


@router.delete("/bases/{kb_id}/documents/{doc_id}")
def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)):
    """Delete a document and update its knowledge base's counters; 404 if missing."""
    doc = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.id == doc_id,
        KnowledgeDocument.kb_id == kb_id
    ).first()
    if not doc:
        raise HTTPException(status_code=404, detail="Document not found")

    # Vector cleanup stays best-effort: a failure must not block deleting the
    # record, but it is logged instead of being silently discarded.
    try:
        delete_document_from_vector(kb_id, doc_id)
    except Exception:
        logger.warning(
            "Failed to remove document %s from the vector store of knowledge base %s",
            doc_id,
            kb_id,
            exc_info=True,
        )

    removed_chunks = doc.chunk_count or 0
    db.delete(doc)
    db.flush()  # make the deletion visible to the count query below

    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if kb is not None:
        kb.chunk_count = max((kb.chunk_count or 0) - removed_chunks, 0)
        # Recount rather than decrement: doc_count only counts completed
        # documents, so deleting a pending or failed one must not lower it.
        kb.doc_count = _count_completed_documents(db, kb_id)
    db.commit()
    return {"message": "Deleted successfully"}


# ============ Search ============

@router.post("/search")
def search_knowledge_base(query: KnowledgeSearchQuery):
    """Run a search against a knowledge base and return the search result as-is."""
    return search_knowledge(kb_id=query.kb_id, query=query.query, n_results=query.nResults)


# ============ Stats ============

@router.get("/bases/{kb_id}/stats")
def get_knowledge_stats(kb_id: str, db: Session = Depends(get_db)):
    """Return the stored document and chunk counters of a knowledge base; 404 if missing."""
    kb = _get_kb_or_404(db, kb_id)
    return {"kb_id": kb_id, "docCount": kb.doc_count, "chunkCount": kb.chunk_count}