235 lines
7.9 KiB
Python
235 lines
7.9 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy.orm import Session
|
|
from typing import Optional
|
|
import uuid
|
|
import os
|
|
from datetime import datetime
|
|
|
|
from ..db import get_db
|
|
from ..models import KnowledgeBase, KnowledgeDocument
|
|
from ..schemas import (
|
|
KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut,
|
|
KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats,
|
|
DocumentIndexRequest,
|
|
)
|
|
from ..vector_store import (
|
|
vector_store, search_knowledge, index_document, delete_document_from_vector
|
|
)
|
|
|
|
# Router collecting the endpoints defined in this module; they are all
# registered under the "/knowledge" prefix and tagged "knowledge".
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
|
|
|
|
|
|
def kb_to_dict(kb: KnowledgeBase) -> dict:
    """Convert a knowledge base record into a plain dictionary.

    Scalar attributes are copied under their own names. ``created_at`` and
    ``updated_at`` are rendered with ``isoformat()``, or ``None`` when the
    attribute is unset.
    """
    copied_fields = (
        "id",
        "user_id",
        "name",
        "description",
        "embedding_model",
        "chunk_size",
        "chunk_overlap",
        "doc_count",
        "chunk_count",
        "status",
    )
    payload = {field: getattr(kb, field) for field in copied_fields}

    # Timestamps are appended last so the key order matches the field order
    # listed above followed by created_at / updated_at.
    for stamp_name in ("created_at", "updated_at"):
        moment = getattr(kb, stamp_name)
        payload[stamp_name] = moment.isoformat() if moment else None
    return payload
|
|
|
|
|
|
def doc_to_dict(d: KnowledgeDocument) -> dict:
    """Convert a knowledge document record into a plain dictionary.

    ``upload_date`` is passed through unchanged, while ``created_at`` and
    ``processed_at`` are rendered with ``isoformat()`` (``None`` when unset).
    """
    copied_fields = (
        "id",
        "kb_id",
        "name",
        "size",
        "file_type",
        "storage_url",
        "status",
        "chunk_count",
        "error_message",
        "upload_date",
    )
    payload = {field: getattr(d, field) for field in copied_fields}

    for stamp_name in ("created_at", "processed_at"):
        moment = getattr(d, stamp_name)
        payload[stamp_name] = moment.isoformat() if moment else None
    return payload
|
|
|
|
|
|
# ============ Knowledge Bases ============
|
|
@router.get("/bases")
def list_knowledge_bases(user_id: int = 1, db: Session = Depends(get_db)):
    """List the knowledge bases of ``user_id`` together with their documents.

    Returns ``{"total": <count>, "list": [...]}`` where every entry is the
    ``kb_to_dict`` payload extended with a ``"documents"`` list of
    ``doc_to_dict`` payloads.
    """
    kbs = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id).all()

    # Fetch the documents of all listed knowledge bases in one query and
    # group them in memory, instead of issuing one query per knowledge base.
    docs_by_kb = {}
    if kbs:
        kb_ids = [kb.id for kb in kbs]
        docs = (
            db.query(KnowledgeDocument)
            .filter(KnowledgeDocument.kb_id.in_(kb_ids))
            .all()
        )
        for doc in docs:
            docs_by_kb.setdefault(doc.kb_id, []).append(doc_to_dict(doc))

    result = []
    for kb in kbs:
        kb_data = kb_to_dict(kb)
        kb_data["documents"] = docs_by_kb.get(kb.id, [])
        result.append(kb_data)
    return {"total": len(result), "list": result}
|
|
|
|
|
|
@router.get("/bases/{kb_id}")
def get_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
    """Return one knowledge base and its documents.

    Raises:
        HTTPException: 404 when no knowledge base has the given ``kb_id``.
    """
    found = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not found:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    documents = (
        db.query(KnowledgeDocument)
        .filter(KnowledgeDocument.kb_id == kb_id)
        .all()
    )
    response = kb_to_dict(found)
    response["documents"] = list(map(doc_to_dict, documents))
    return response
|
|
|
|
|
|
@router.post("/bases")
def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Session = Depends(get_db)):
    """Create a knowledge base row and its vector-store collection.

    The row is stored first, then ``vector_store.create_collection`` is
    called with the new id and the requested embedding model. If creating
    the collection fails, the freshly stored row is removed again and the
    original error is re-raised, so no knowledge base is left behind without
    a collection.

    Returns:
        The ``kb_to_dict`` payload of the new knowledge base.
    """
    kb = KnowledgeBase(
        # Short id: the first 8 characters of a random UUID string.
        id=str(uuid.uuid4())[:8],
        user_id=user_id,
        name=data.name,
        description=data.description,
        embedding_model=data.embeddingModel,
        chunk_size=data.chunkSize,
        chunk_overlap=data.chunkOverlap,
    )
    db.add(kb)
    db.commit()
    db.refresh(kb)

    try:
        vector_store.create_collection(kb.id, data.embeddingModel)
    except Exception:
        # Compensate for the already-committed insert before propagating.
        db.delete(kb)
        db.commit()
        raise
    return kb_to_dict(kb)
|
|
|
|
|
|
@router.put("/bases/{kb_id}")
def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
    """Apply the explicitly provided fields of ``data`` to a knowledge base.

    Only fields set in the request are written (``exclude_unset=True``), and
    ``updated_at`` is refreshed on every call.

    Raises:
        HTTPException: 404 when no knowledge base has the given ``kb_id``.

    Returns:
        The ``kb_to_dict`` payload of the updated knowledge base.
    """
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # The create endpoint reads these settings from camelCase schema fields
    # while the model columns are snake_case. NOTE(review): KnowledgeBaseUpdate
    # is not visible here; if it uses the same camelCase names, a plain
    # setattr would set a non-column attribute and silently drop the change,
    # so translate the known names. Other field names pass through unchanged.
    column_for = {
        "embeddingModel": "embedding_model",
        "chunkSize": "chunk_size",
        "chunkOverlap": "chunk_overlap",
    }
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(kb, column_for.get(field, field), value)

    kb.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(kb)
    return kb_to_dict(kb)
|
|
|
|
|
|
@router.delete("/bases/{kb_id}")
def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
    """Delete a knowledge base, its vector collection and its documents.

    The vector-store collection is removed first; the document rows and the
    knowledge base row are then deleted in a single commit.

    Raises:
        HTTPException: 404 when no knowledge base has the given ``kb_id``.
    """
    target = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    vector_store.delete_collection(kb_id)

    owned_documents = (
        db.query(KnowledgeDocument)
        .filter(KnowledgeDocument.kb_id == kb_id)
        .all()
    )
    for owned in owned_documents:
        db.delete(owned)
    db.delete(target)
    db.commit()
    return {"message": "Deleted successfully"}
|
|
|
|
|
|
# ============ Documents ============
|
|
@router.post("/bases/{kb_id}/documents")
def upload_document(
    kb_id: str,
    name: str = Query(...),
    size: str = Query(...),
    file_type: str = Query("txt"),
    storage_url: Optional[str] = Query(None),
    db: Session = Depends(get_db)
):
    """Register a document row, in ``"pending"`` status, for a knowledge base.

    Only metadata is stored here; the document receives a short random id and
    the current UTC time (ISO string) as its upload date.

    Raises:
        HTTPException: 404 when no knowledge base has the given ``kb_id``.
    """
    parent = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not parent:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    new_id = str(uuid.uuid4())[:8]
    uploaded_at = datetime.utcnow().isoformat()
    record = KnowledgeDocument(
        id=new_id,
        kb_id=kb_id,
        name=name,
        size=size,
        file_type=file_type,
        storage_url=storage_url,
        status="pending",
        upload_date=uploaded_at,
    )
    db.add(record)
    db.commit()
    db.refresh(record)

    return {
        "id": record.id,
        "name": record.name,
        "status": record.status,
        "message": "Document created",
    }
|
|
|
|
|
|
@router.post("/bases/{kb_id}/documents/{doc_id}/index")
def index_document_content(kb_id: str, doc_id: str, request: DocumentIndexRequest, db: Session = Depends(get_db)):
    """Index ``request.content`` for a document and update the counters.

    The document row is created when it does not exist yet, otherwise it is
    reset to ``"pending"``. On success the document is marked
    ``"completed"`` and the knowledge base counters are refreshed; on failure
    the document is marked ``"failed"`` with the error text.

    Raises:
        HTTPException: 404 when the knowledge base does not exist, 500 when
            indexing or the counter update fails.

    Returns:
        ``{"message": "Document indexed", "chunkCount": <chunks>}``.
    """
    # Resolve the knowledge base up front so a bad kb_id is reported as 404
    # before any document row is created for it.
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Check whether the document exists; create it if it does not.
    doc = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.id == doc_id,
        KnowledgeDocument.kb_id == kb_id,
    ).first()

    content_size = str(len(request.content))
    if not doc:
        previous_chunks = 0
        doc = KnowledgeDocument(
            id=doc_id,
            kb_id=kb_id,
            name=f"doc-{doc_id}.txt",
            size=content_size,
            file_type="txt",
            status="pending",
            upload_date=datetime.utcnow().isoformat(),
        )
        db.add(doc)
        db.commit()
        db.refresh(doc)
    else:
        # Update the existing document. Remember how many chunks it already
        # contributed so a re-index replaces that amount instead of adding
        # to it.
        previous_chunks = doc.chunk_count or 0
        doc.size = content_size
        doc.status = "pending"
        db.commit()

    try:
        chunk_count = index_document(kb_id, doc_id, request.content)
        doc.status = "completed"
        doc.chunk_count = chunk_count
        doc.error_message = None  # clear the error of an earlier failed run
        doc.processed_at = datetime.utcnow()
        kb.doc_count = db.query(KnowledgeDocument).filter(
            KnowledgeDocument.kb_id == kb_id,
            KnowledgeDocument.status == "completed",
        ).count()
        # Keep the total equal to the sum of the per-document chunk counts,
        # which is also what the delete endpoint subtracts.
        kb.chunk_count = max(
            0, (kb.chunk_count or 0) - previous_chunks + chunk_count
        )
        db.commit()
    except Exception as e:
        # Discard partial changes so the session is usable for recording
        # the failure.
        db.rollback()
        doc.status = "failed"
        doc.error_message = str(e)
        db.commit()
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Document indexed", "chunkCount": chunk_count}
|
|
|
|
|
|
@router.delete("/bases/{kb_id}/documents/{doc_id}")
def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)):
    """Delete a document and refresh its knowledge base counters.

    Removing the document's vectors is best effort: a failure there is logged
    and the database row is deleted regardless.

    Raises:
        HTTPException: 404 when the document does not exist in ``kb_id``.
    """
    import logging  # local import: only needed for the best-effort warning

    doc = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.id == doc_id,
        KnowledgeDocument.kb_id == kb_id,
    ).first()
    if not doc:
        raise HTTPException(status_code=404, detail="Document not found")

    try:
        delete_document_from_vector(kb_id, doc_id)
    except Exception:
        # Deliberately non-fatal, but no longer silent.
        logging.getLogger(__name__).warning(
            "Failed to delete vectors for document %s in knowledge base %s",
            doc_id,
            kb_id,
            exc_info=True,
        )

    # A document that was never indexed may have no chunk count yet.
    removed_chunks = doc.chunk_count or 0
    db.delete(doc)
    db.flush()  # make the deletion visible to the count query below

    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if kb:
        kb.chunk_count = max(0, (kb.chunk_count or 0) - removed_chunks)
        # Recount instead of decrementing: the index endpoint counts only
        # "completed" documents, so a pending or failed document must not
        # lower the total.
        kb.doc_count = db.query(KnowledgeDocument).filter(
            KnowledgeDocument.kb_id == kb_id,
            KnowledgeDocument.status == "completed",
        ).count()

    db.commit()
    return {"message": "Deleted successfully"}
|
|
|
|
|
|
# ============ Search ============
|
|
@router.post("/search")
def search_knowledge_base(query: KnowledgeSearchQuery):
    """Forward a search request to ``search_knowledge`` and return its result.

    The request's ``kb_id``, ``query`` and ``nResults`` fields are passed on
    as ``kb_id``, ``query`` and ``n_results`` respectively.
    """
    search_args = {
        "kb_id": query.kb_id,
        "query": query.query,
        "n_results": query.nResults,
    }
    return search_knowledge(**search_args)
|
|
|
|
|
|
# ============ Stats ============
|
|
@router.get("/bases/{kb_id}/stats")
def get_knowledge_stats(kb_id: str, db: Session = Depends(get_db)):
    """Return the stored document and chunk counters of a knowledge base.

    Raises:
        HTTPException: 404 when no knowledge base has the given ``kb_id``.
    """
    record = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    stats = {"kb_id": kb_id}
    stats["docCount"] = record.doc_count
    stats["chunkCount"] = record.chunk_count
    return stats
|