260 lines
8.5 KiB
Python
260 lines
8.5 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy.orm import Session
|
|
from typing import Optional
|
|
import uuid
|
|
import os
|
|
from datetime import datetime
|
|
|
|
from ..db import get_db
|
|
from ..models import KnowledgeBase, KnowledgeDocument
|
|
from ..schemas import (
|
|
KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut,
|
|
KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats,
|
|
DocumentIndexRequest,
|
|
KnowledgeDocumentCreate,
|
|
)
|
|
from ..vector_store import (
|
|
vector_store, search_knowledge, index_document, delete_document_from_vector
|
|
)
|
|
|
|
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
|
|
|
|
|
|
def kb_to_dict(kb: KnowledgeBase) -> dict:
    """Serialize a KnowledgeBase ORM row into the camelCase API payload.

    ``createdAt`` / ``updatedAt`` are rendered as ISO-8601 strings, or None
    when the timestamp is unset.
    """
    def _iso(ts):
        # Falsy (unset) timestamps serialize as None instead of raising.
        return ts.isoformat() if ts else None

    return dict(
        id=kb.id,
        user_id=kb.user_id,
        name=kb.name,
        description=kb.description,
        embeddingModel=kb.embedding_model,
        chunkSize=kb.chunk_size,
        chunkOverlap=kb.chunk_overlap,
        docCount=kb.doc_count,
        chunkCount=kb.chunk_count,
        status=kb.status,
        createdAt=_iso(kb.created_at),
        updatedAt=_iso(kb.updated_at),
    )
|
|
|
|
|
|
def doc_to_dict(d: KnowledgeDocument) -> dict:
    """Serialize a KnowledgeDocument ORM row into the camelCase API payload.

    ``uploadDate`` is passed through exactly as stored; ``createdAt`` and
    ``processedAt`` become ISO-8601 strings, or None when unset.
    """
    created, processed = d.created_at, d.processed_at
    return {
        "id": d.id,
        "kb_id": d.kb_id,
        "name": d.name,
        "size": d.size,
        "fileType": d.file_type,
        "storageUrl": d.storage_url,
        "status": d.status,
        "chunkCount": d.chunk_count,
        "errorMessage": d.error_message,
        "uploadDate": d.upload_date,
        "createdAt": None if not created else created.isoformat(),
        "processedAt": None if not processed else processed.isoformat(),
    }
|
|
|
|
|
|
# ============ Knowledge Bases ============
|
|
@router.get("/bases")
def list_knowledge_bases(
    user_id: int = 1,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    """List a user's knowledge bases, newest first, each with its documents.

    Args:
        user_id: Owner whose knowledge bases are listed.
        page: 1-based page number.
        limit: Page size; passed to the SQL LIMIT clause unchanged.
        db: Database session injected by ``get_db``.

    Returns:
        ``{"total", "page", "limit", "list"}`` where ``total`` counts all of the
        user's knowledge bases and ``list`` holds ``kb_to_dict`` payloads, each
        extended with a ``documents`` list of ``doc_to_dict`` payloads.

    Raises:
        HTTPException(400): if ``page`` < 1, which would otherwise produce a
            negative SQL OFFSET.
    """
    if page < 1:
        raise HTTPException(status_code=400, detail="page must be >= 1")

    query = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id)
    total = query.count()
    kbs = (
        query.order_by(KnowledgeBase.created_at.desc())
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )

    # Fetch documents for every KB on this page with ONE query instead of one
    # query per KB (the previous N+1 pattern).
    docs_by_kb = {kb.id: [] for kb in kbs}
    if docs_by_kb:
        page_docs = (
            db.query(KnowledgeDocument)
            .filter(KnowledgeDocument.kb_id.in_(list(docs_by_kb)))
            .all()
        )
        for doc in page_docs:
            docs_by_kb[doc.kb_id].append(doc)

    result = []
    for kb in kbs:
        kb_data = kb_to_dict(kb)
        kb_data["documents"] = [doc_to_dict(d) for d in docs_by_kb[kb.id]]
        result.append(kb_data)

    return {"total": total, "page": page, "limit": limit, "list": result}
|
|
|
|
|
|
@router.get("/bases/{kb_id}")
def get_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
    """Return one knowledge base together with all of its documents.

    Raises:
        HTTPException(404): if no knowledge base has id ``kb_id``.
    """
    record = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    documents = (
        db.query(KnowledgeDocument)
        .filter(KnowledgeDocument.kb_id == kb_id)
        .all()
    )
    payload = kb_to_dict(record)
    payload["documents"] = list(map(doc_to_dict, documents))
    return payload
|
|
|
|
|
|
@router.post("/bases")
def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Session = Depends(get_db)):
    """Create a knowledge base row and, best-effort, its vector collection.

    Args:
        data: Name, description, embedding model and chunking settings.
        user_id: Owner of the new knowledge base.
        db: Database session injected by ``get_db``.

    Returns:
        The ``kb_to_dict`` payload of the committed knowledge base.

    The DB row is committed before the vector collection is created. A
    vector-store failure is logged and does not fail the request, so the
    returned KB may lack a collection.
    """
    import logging

    kb = KnowledgeBase(
        # NOTE(review): only the first 8 hex chars (32 random bits) of a UUID4
        # are kept, so id collisions become likely after tens of thousands of
        # rows. Widening it requires checking the id column's declared length.
        id=str(uuid.uuid4())[:8],
        user_id=user_id,
        name=data.name,
        description=data.description,
        embedding_model=data.embeddingModel,
        chunk_size=data.chunkSize,
        chunk_overlap=data.chunkOverlap,
    )
    db.add(kb)
    db.commit()
    db.refresh(kb)

    try:
        vector_store.create_collection(kb.id, data.embeddingModel)
    except Exception:
        # Deliberately best-effort: the KB row is already committed. Logging
        # (instead of the former silent `pass`) keeps a missing collection
        # diagnosable.
        logging.getLogger(__name__).exception(
            "Failed to create vector collection for knowledge base %s", kb.id
        )

    return kb_to_dict(kb)
|
|
|
|
|
|
@router.put("/bases/{kb_id}")
def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to a knowledge base.

    Only fields explicitly set in the request body are written. The camelCase
    schema names ``embeddingModel``/``chunkSize``/``chunkOverlap`` are mapped to
    their snake_case model attributes; any other field name is used as-is.
    ``updated_at`` is always refreshed.

    NOTE(review): this handler does not touch the vector store, so changing the
    embedding model or chunk settings does not rebuild the collection or
    re-chunk documents that are already indexed.

    Raises:
        HTTPException(404): if no knowledge base has id ``kb_id``.
    """
    target = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    changes = data.model_dump(exclude_unset=True)
    # Schema field name -> model attribute name (unlisted names map 1:1).
    attribute_for = {
        "embeddingModel": "embedding_model",
        "chunkSize": "chunk_size",
        "chunkOverlap": "chunk_overlap",
    }.get
    for key in changes:
        setattr(target, attribute_for(key, key), changes[key])

    target.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(target)
    return kb_to_dict(target)
|
|
|
|
|
|
@router.delete("/bases/{kb_id}")
def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
    """Delete a knowledge base, its document rows, and its vector collection.

    Database rows are removed and committed first. The vector collection is
    dropped afterwards, best-effort. If the commit fails, the knowledge base
    therefore survives with its vectors intact, rather than remaining in the DB
    with its vectors already gone.

    Raises:
        HTTPException(404): if no knowledge base has id ``kb_id``.
    """
    import logging

    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Delete through the ORM (not a bulk DELETE) so any ORM-level cascades
    # configured on the models still apply.
    docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
    for doc in docs:
        db.delete(doc)
    db.delete(kb)
    db.commit()

    try:
        vector_store.delete_collection(kb_id)
    except Exception:
        # Best-effort cleanup: the DB side is already committed. Log the
        # possibly orphaned collection instead of silently ignoring it.
        logging.getLogger(__name__).exception(
            "Failed to delete vector collection for knowledge base %s", kb_id
        )

    return {"message": "Deleted successfully"}
|
|
|
|
|
|
# ============ Documents ============
|
|
@router.post("/bases/{kb_id}/documents")
def upload_document(
    kb_id: str,
    data: KnowledgeDocumentCreate,
    db: Session = Depends(get_db)
):
    """Register document metadata under a knowledge base with status "pending".

    Only metadata is recorded here; this handler does not call the vector
    store.

    Returns:
        A summary of the new document plus ``"message": "Document created"``.

    Raises:
        HTTPException(404): if no knowledge base has id ``kb_id``.
    """
    if not db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first():
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    new_doc = KnowledgeDocument(
        id=str(uuid.uuid4())[:8],
        kb_id=kb_id,
        name=data.name,
        size=data.size,
        file_type=data.fileType,
        storage_url=data.storageUrl,
        status="pending",
        upload_date=datetime.utcnow().isoformat(),
    )
    db.add(new_doc)
    db.commit()
    db.refresh(new_doc)

    # (response key, model attribute) in response order.
    exposed = (
        ("id", "id"),
        ("name", "name"),
        ("size", "size"),
        ("fileType", "file_type"),
        ("storageUrl", "storage_url"),
        ("status", "status"),
    )
    response = {key: getattr(new_doc, attr) for key, attr in exposed}
    response["message"] = "Document created"
    return response
|
|
|
|
|
|
def _count_completed_documents(db: Session, kb_id: str) -> int:
    """Return how many documents in ``kb_id`` currently have status "completed"."""
    return (
        db.query(KnowledgeDocument)
        .filter(
            KnowledgeDocument.kb_id == kb_id,
            KnowledgeDocument.status == "completed",
        )
        .count()
    )


@router.post("/bases/{kb_id}/documents/{doc_id}/index")
def index_document_content(kb_id: str, doc_id: str, request: DocumentIndexRequest, db: Session = Depends(get_db)):
    """Index raw text for a document, creating the document row if needed.

    The document is first committed with status "pending". Its content is then
    passed to ``index_document``. On success, the document becomes "completed"
    and the knowledge base's counters are updated. On failure, it becomes
    "failed" with the error message.

    Args:
        kb_id: Knowledge base that owns the document.
        doc_id: Document id. A placeholder ``doc-<id>.txt`` row is created when
            this knowledge base has no such document.
        request: Carries the text ``content`` to index.
        db: Database session injected by ``get_db``.

    Returns:
        ``{"message": "Document indexed", "chunkCount": <chunks indexed>}``.

    Raises:
        HTTPException(404): if the knowledge base does not exist.
        HTTPException(500): if indexing or the follow-up DB update fails.
    """
    # Validate the KB up front. Previously a missing KB only surfaced after the
    # content had already been sent to the vector store.
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Check whether the document exists; create it if it does not.
    doc = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.id == doc_id,
        KnowledgeDocument.kb_id == kb_id,
    ).first()

    # Size is recorded as the character count of the content, stored as a string.
    content_size = str(len(request.content))
    if not doc:
        doc = KnowledgeDocument(
            id=doc_id,
            kb_id=kb_id,
            name=f"doc-{doc_id}.txt",
            size=content_size,
            file_type="txt",
            status="pending",
            upload_date=datetime.utcnow().isoformat(),
        )
        db.add(doc)
    else:
        # Update the existing document for re-indexing.
        doc.size = content_size
        doc.status = "pending"
    db.commit()
    db.refresh(doc)

    # Chunks this document previously contributed to kb.chunk_count. In this
    # module, chunk_count is only set on the success path below.
    previous_chunks = doc.chunk_count or 0

    try:
        # NOTE(review): assumes index_document replaces any vectors already
        # stored for doc_id. If it appends, stale chunks stay searchable — confirm.
        chunk_count = index_document(kb_id, doc_id, request.content)
        doc.status = "completed"
        doc.chunk_count = chunk_count
        doc.error_message = None  # clear any error left by an earlier failed attempt
        doc.processed_at = datetime.utcnow()

        # Replace (rather than add to) this document's previous contribution,
        # so re-indexing does not double-count chunks.
        kb.chunk_count = max(0, (kb.chunk_count or 0) - previous_chunks + chunk_count)
        db.flush()  # make the status change visible to the recount below
        kb.doc_count = _count_completed_documents(db, kb_id)
        db.commit()
        return {"message": "Document indexed", "chunkCount": chunk_count}
    except Exception as e:
        # Discard the half-applied success-path changes. This also resets the
        # session if the commit itself failed, which otherwise leaves it unusable.
        db.rollback()
        doc.status = "failed"
        doc.error_message = str(e)
        db.flush()
        # A previously completed document that now failed must stop being counted.
        kb.doc_count = _count_completed_documents(db, kb_id)
        db.commit()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.delete("/bases/{kb_id}/documents/{doc_id}")
def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)):
    """Delete one document row, update its KB's counters, then drop its vectors.

    The database change is committed before the vectors are removed
    (best-effort), so a failed commit does not leave a "completed" document
    whose vectors are already gone.

    Raises:
        HTTPException(404): if the document does not exist in ``kb_id``.
    """
    import logging

    doc = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.id == doc_id,
        KnowledgeDocument.kb_id == kb_id,
    ).first()
    if not doc:
        raise HTTPException(status_code=404, detail="Document not found")

    removed_chunks = doc.chunk_count or 0
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    db.delete(doc)
    if kb:
        kb.chunk_count = max(0, (kb.chunk_count or 0) - removed_chunks)
        # doc_count counts only documents with status "completed" (the indexing
        # endpoint computes it that way). Recount instead of blindly
        # decrementing, which undercounted when a pending or failed document
        # was deleted.
        db.flush()
        kb.doc_count = db.query(KnowledgeDocument).filter(
            KnowledgeDocument.kb_id == kb_id,
            KnowledgeDocument.status == "completed",
        ).count()
    db.commit()

    try:
        delete_document_from_vector(kb_id, doc_id)
    except Exception:
        # Best-effort cleanup: the DB side is already committed.
        logging.getLogger(__name__).exception(
            "Failed to delete vectors for document %s in knowledge base %s",
            doc_id,
            kb_id,
        )

    return {"message": "Deleted successfully"}
|
|
|
|
|
|
# ============ Search ============
|
|
@router.post("/search")
def search_knowledge_base(query: KnowledgeSearchQuery):
    """Run a search against one knowledge base via ``search_knowledge``.

    This is a thin pass-through. It does no database lookup, so the knowledge
    base id is not validated here, and the result is returned unchanged.
    """
    search_args = {
        "kb_id": query.kb_id,
        "query": query.query,
        "n_results": query.nResults,
    }
    return search_knowledge(**search_args)
|
|
|
|
|
|
# ============ Stats ============
|
|
@router.get("/bases/{kb_id}/stats")
def get_knowledge_stats(kb_id: str, db: Session = Depends(get_db)):
    """Return the stored document and chunk counters for a knowledge base.

    Raises:
        HTTPException(404): if no knowledge base has id ``kb_id``.
    """
    record = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    stats = {"kb_id": kb_id}
    stats["docCount"] = record.doc_count
    stats["chunkCount"] = record.chunk_count
    return stats
|