from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Request
from sqlalchemy.orm import Session
from typing import Optional
import uuid
import os
import json
from io import BytesIO
from datetime import datetime

from ..db import get_db
from ..models import KnowledgeBase, KnowledgeDocument
from ..schemas import (
    KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseOut,
    KnowledgeSearchQuery, KnowledgeSearchResult, KnowledgeStats,
    DocumentIndexRequest,
    KnowledgeDocumentCreate,
)
from ..vector_store import (
    vector_store, search_knowledge, index_document, delete_document_from_vector
)

router = APIRouter(prefix="/knowledge", tags=["knowledge"])
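
# Recompute the cached doc_count/chunk_count columns on a knowledge base from
# its completed documents. Callers commit the session afterwards.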
def _refresh_kb_stats(db: Session, kb_id: str) -> None:
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        return
    docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
    completed_docs = [d for d in docs if d.status == "completed"]
    kb.doc_count = len(completed_docs)
    kb.chunk_count = sum(max(0, d.chunk_count or 0) for d in completed_docs)
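
# Decode uploaded bytes on a best-effort basis, trying common encodings
# (including the Chinese codecs gb18030/gbk) before a lossy UTF-8 fallback.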
def _decode_text_bytes(raw: bytes) -> str:
    # "utf-8-sig" must be tried before "utf-8" so a BOM is stripped rather
    # than surviving as a stray "\ufeff" character. "latin-1" never raises,
    # so it is the effective last resort; the final fallback is defensive.
    for encoding in ("utf-8-sig", "utf-8", "gb18030", "gbk", "latin-1"):
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return raw.decode("utf-8", errors="ignore")
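
# Pick an extraction strategy from the file extension (with the MIME type as
# a last resort) and return plain text for indexing. PDF and DOCX support
# depend on the optional pypdf and python-docx packages.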
def _extract_text_from_upload(filename: str, content_type: Optional[str], raw: bytes) -> str:
    ext = os.path.splitext((filename or "").lower())[1]
    if ext in {".txt", ".md", ".csv"}:
        return _decode_text_bytes(raw)
    if ext == ".json":
        try:
            parsed = json.loads(_decode_text_bytes(raw))
            return json.dumps(parsed, ensure_ascii=False, indent=2)
        except Exception:
            return _decode_text_bytes(raw)
    if ext == ".pdf":
        try:
            from pypdf import PdfReader  # type: ignore
        except Exception as exc:
            raise ValueError("PDF parsing requires installing pypdf") from exc
        reader = PdfReader(BytesIO(raw))
        return "\n".join((page.extract_text() or "") for page in reader.pages).strip()
    if ext == ".docx":
        try:
            from docx import Document  # type: ignore
        except Exception as exc:
            raise ValueError("DOCX parsing requires installing python-docx") from exc
        doc = Document(BytesIO(raw))
        return "\n".join(p.text for p in doc.paragraphs).strip()
    if ext == ".doc":
        raise ValueError("DOC format is not supported for auto indexing. Please convert to DOCX/TXT.")
    # fallback: attempt plain text decode
    if (content_type or "").startswith("text/"):
        return _decode_text_bytes(raw)
    raise ValueError(f"Unsupported file type for auto indexing: {ext or content_type or 'unknown'}")
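
# Serialization helpers: expose the snake_case ORM columns under the
# camelCase names used by the API responses.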
def kb_to_dict(kb: KnowledgeBase) -> dict:
    return {
        "id": kb.id,
        "user_id": kb.user_id,
        "name": kb.name,
        "description": kb.description,
        "embeddingModel": kb.embedding_model,
        "chunkSize": kb.chunk_size,
        "chunkOverlap": kb.chunk_overlap,
        "docCount": kb.doc_count,
        "chunkCount": kb.chunk_count,
        "status": kb.status,
        "createdAt": kb.created_at.isoformat() if kb.created_at else None,
        "updatedAt": kb.updated_at.isoformat() if kb.updated_at else None,
    }


def doc_to_dict(d: KnowledgeDocument) -> dict:
    return {
        "id": d.id,
        "kb_id": d.kb_id,
        "name": d.name,
        "size": d.size,
        "fileType": d.file_type,
        "storageUrl": d.storage_url,
        "status": d.status,
        "chunkCount": d.chunk_count,
        "errorMessage": d.error_message,
        "uploadDate": d.upload_date,
        "createdAt": d.created_at.isoformat() if d.created_at else None,
        "processedAt": d.processed_at.isoformat() if d.processed_at else None,
    }

# ============ Knowledge Bases ============
@router.get("/bases")
def list_knowledge_bases(
    user_id: int = 1,
    page: int = 1,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    query = db.query(KnowledgeBase).filter(KnowledgeBase.user_id == user_id)
    total = query.count()
    kbs = query.order_by(KnowledgeBase.created_at.desc()).offset((page - 1) * limit).limit(limit).all()
    result = []
    for kb in kbs:
        docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb.id).all()
        kb_data = kb_to_dict(kb)
        kb_data["documents"] = [doc_to_dict(d) for d in docs]
        result.append(kb_data)
    return {"total": total, "page": page, "limit": limit, "list": result}

@router.get("/bases/{kb_id}")
def get_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
    kb_data = kb_to_dict(kb)
    kb_data["documents"] = [doc_to_dict(d) for d in docs]
    return kb_data

@router.post("/bases")
def create_knowledge_base(data: KnowledgeBaseCreate, user_id: int = 1, db: Session = Depends(get_db)):
    name = (data.name or "").strip()
    if not name:
        raise HTTPException(status_code=400, detail="Knowledge base name cannot be empty")

    exists = db.query(KnowledgeBase).filter(
        KnowledgeBase.user_id == user_id,
        KnowledgeBase.name == name
    ).first()
    if exists:
        raise HTTPException(status_code=400, detail=f"Knowledge base name already exists: {name}")

    kb = KnowledgeBase(
        id=str(uuid.uuid4())[:8],
        user_id=user_id,
        name=name,
        description=data.description,
        embedding_model=data.embeddingModel,
        chunk_size=data.chunkSize,
        chunk_overlap=data.chunkOverlap,
    )
    db.add(kb)
    db.commit()
    db.refresh(kb)
    try:
        vector_store.create_collection(kb.id, data.embeddingModel)
    except Exception as exc:
        # Keep DB and vector store consistent on create failure
        db.delete(kb)
        db.commit()
        raise HTTPException(status_code=502, detail=f"Failed to create ChromaDB collection: {exc}") from exc
    return kb_to_dict(kb)
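
# Updates map camelCase request fields onto snake_case columns. The embedding
# model may only change while the KB has no indexed chunks, since vectors
# produced by the old model would be meaningless in the new embedding space.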
@router.put("/bases/{kb_id}")
def update_knowledge_base(kb_id: str, data: KnowledgeBaseUpdate, db: Session = Depends(get_db)):
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    update_data = data.model_dump(exclude_unset=True)
    field_map = {
        "embeddingModel": "embedding_model",
        "chunkSize": "chunk_size",
        "chunkOverlap": "chunk_overlap",
    }
    if "name" in update_data:
        update_data["name"] = (update_data["name"] or "").strip()
        if not update_data["name"]:
            raise HTTPException(status_code=400, detail="Knowledge base name cannot be empty")
        name_exists = db.query(KnowledgeBase).filter(
            KnowledgeBase.user_id == kb.user_id,
            KnowledgeBase.name == update_data["name"],
            KnowledgeBase.id != kb.id
        ).first()
        if name_exists:
            raise HTTPException(status_code=400, detail=f"Knowledge base name already exists: {update_data['name']}")

    embedding_changed = "embeddingModel" in update_data and update_data["embeddingModel"] != kb.embedding_model
    if embedding_changed and kb.chunk_count > 0:
        raise HTTPException(
            status_code=400,
            detail="Cannot change embedding model when knowledge base has indexed chunks. Remove documents first."
        )

    for field, value in update_data.items():
        setattr(kb, field_map.get(field, field), value)

    if embedding_changed:
        try:
            vector_store.delete_collection(kb_id)
        except Exception:
            pass
        try:
            vector_store.create_collection(kb_id, kb.embedding_model)
        except Exception as exc:
            raise HTTPException(status_code=502, detail=f"Failed to update ChromaDB collection: {exc}") from exc

    kb.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(kb)
    return kb_to_dict(kb)
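
# Deleting a knowledge base removes its documents and the row itself even if
# the vector-store delete fails; in that case the response carries a warning.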
@router.delete("/bases/{kb_id}")
def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    vector_deleted = True
    try:
        vector_store.delete_collection(kb_id)
    except Exception:
        vector_deleted = False
    docs = db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
    for doc in docs:
        db.delete(doc)
    db.delete(kb)
    db.commit()
    if not vector_deleted:
        return {"message": "Deleted successfully", "warning": "Knowledge base deleted but failed to remove ChromaDB collection"}
    return {"message": "Deleted successfully"}

# ============ Documents ============
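# This endpoint supports two request modes: a multipart file upload that is
# extracted and indexed immediately, and a legacy JSON body that only records
# document metadata (status "pending") without touching the vector store.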
@router.post("/bases/{kb_id}/documents")
async def upload_document(
    kb_id: str,
    file: Optional[UploadFile] = File(default=None),
    name: Optional[str] = Form(default=None),
    size: Optional[str] = Form(default=None),
    file_type: Optional[str] = Form(default=None),
    storage_url: Optional[str] = Form(default=None),
    data: Optional[KnowledgeDocumentCreate] = None,
    request: Request = None,
    db: Session = Depends(get_db)
):
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # New mode: multipart file upload with automatic indexing
    if file is not None:
        filename = file.filename or "uploaded.txt"
        file_type_value = file.content_type or file_type or "application/octet-stream"
        # Read asynchronously instead of blocking the event loop with
        # the synchronous file.file.read()
        raw = await file.read()
        if not raw:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        doc = KnowledgeDocument(
            id=str(uuid.uuid4())[:8],
            kb_id=kb_id,
            name=filename,
            size=f"{len(raw)} bytes",
            file_type=file_type_value,
            storage_url=storage_url,
            status="processing",
            upload_date=datetime.utcnow().isoformat()
        )
        db.add(doc)
        db.commit()
        db.refresh(doc)
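
        # Extract text and index it; any failure below marks the document
        # "failed" and records the error message on the row.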
        try:
            if vector_store.get_collection(kb_id) is None:
                vector_store.create_collection(kb_id, kb.embedding_model)

            text = _extract_text_from_upload(filename, file.content_type, raw)
            if not text.strip():
                raise ValueError("No textual content extracted from file")

            chunk_count = index_document(kb_id, doc.id, text)
            doc.status = "completed"
            doc.chunk_count = chunk_count
            doc.processed_at = datetime.utcnow()
            doc.error_message = None
            _refresh_kb_stats(db, kb_id)
            db.commit()
            return {
                "id": doc.id,
                "name": doc.name,
                "size": doc.size,
                "fileType": doc.file_type,
                "storageUrl": doc.storage_url,
                "status": doc.status,
                "chunkCount": doc.chunk_count,
                "message": "Document uploaded and indexed",
            }
        except ValueError as exc:
            doc.status = "failed"
            doc.error_message = str(exc)
            _refresh_kb_stats(db, kb_id)
            db.commit()
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except Exception as exc:
            doc.status = "failed"
            doc.error_message = str(exc)
            _refresh_kb_stats(db, kb_id)
            db.commit()
            raise HTTPException(status_code=500, detail=f"Failed to index uploaded file: {exc}") from exc

    # Backward-compatible mode: metadata-only document creation
    if data is None:
        if not name and not size and request is not None:
            try:
                raw_payload = await request.json()
                if isinstance(raw_payload, dict):
                    name = raw_payload.get("name")
                    size = raw_payload.get("size")
                    file_type = raw_payload.get("fileType") or raw_payload.get("file_type") or file_type
                    storage_url = raw_payload.get("storageUrl") or raw_payload.get("storage_url") or storage_url
            except Exception:
                pass
        if not name or not size:
            raise HTTPException(status_code=422, detail="name and size are required")
        data = KnowledgeDocumentCreate(
            name=name,
            size=size,
            fileType=file_type or "txt",
            storageUrl=storage_url,
        )

    doc = KnowledgeDocument(
        id=str(uuid.uuid4())[:8],
        kb_id=kb_id,
        name=data.name,
        size=data.size,
        file_type=data.fileType,
        storage_url=data.storageUrl,
        status="pending",
        upload_date=datetime.utcnow().isoformat()
    )
    db.add(doc)
    db.commit()
    db.refresh(doc)
    return {
        "id": doc.id,
        "name": doc.name,
        "size": doc.size,
        "fileType": doc.file_type,
        "storageUrl": doc.storage_url,
        "status": doc.status,
        "message": "Document created",
    }
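
# Index raw text content for a document. The document row is created on the
# fly if it does not exist yet, so clients may push content without first
# registering metadata.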
@router.post("/bases/{kb_id}/documents/{doc_id}/index")
def index_document_content(kb_id: str, doc_id: str, request: DocumentIndexRequest, db: Session = Depends(get_db)):
    # Check whether the document exists; create it if it does not
    doc = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.id == doc_id,
        KnowledgeDocument.kb_id == kb_id
    ).first()

    if not doc:
        doc = KnowledgeDocument(
            id=doc_id,
            kb_id=kb_id,
            name=f"doc-{doc_id}.txt",
            size=str(len(request.content)),
            file_type="txt",
            status="pending",
            upload_date=datetime.utcnow().isoformat()
        )
        db.add(doc)
        db.commit()
        db.refresh(doc)
    else:
        # Update the existing document
        doc.size = str(len(request.content))
        doc.status = "pending"
        db.commit()

    try:
        if vector_store.get_collection(kb_id) is None:
            kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
            vector_store.create_collection(kb_id, kb.embedding_model if kb else "text-embedding-3-small")
        chunk_count = index_document(kb_id, doc_id, request.content)
        doc.status = "completed"
        doc.chunk_count = chunk_count
        doc.processed_at = datetime.utcnow()
        doc.error_message = None
        _refresh_kb_stats(db, kb_id)
        db.commit()
        return {"message": "Document indexed", "chunkCount": chunk_count}
    except Exception as e:
        doc.status = "failed"
        doc.error_message = str(e)
        _refresh_kb_stats(db, kb_id)
        db.commit()
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.delete("/bases/{kb_id}/documents/{doc_id}")
def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)):
    doc = db.query(KnowledgeDocument).filter(
        KnowledgeDocument.id == doc_id,
        KnowledgeDocument.kb_id == kb_id
    ).first()
    if not doc:
        raise HTTPException(status_code=404, detail="Document not found")
    try:
        delete_document_from_vector(kb_id, doc_id)
    except Exception:
        pass
    db.delete(doc)
    _refresh_kb_stats(db, kb_id)
    db.commit()
    return {"message": "Deleted successfully"}

# ============ Search ============
@router.post("/search")
def search_knowledge_base(query: KnowledgeSearchQuery):
    try:
        return search_knowledge(kb_id=query.kb_id, query=query.query, n_results=query.nResults)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
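
# Example request (hypothetical values; field names follow KnowledgeSearchQuery):
#   POST /knowledge/search
#   {"kb_id": "abc12345", "query": "refund policy", "nResults": 5}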

# ============ Stats ============
@router.get("/bases/{kb_id}/stats")
def get_knowledge_stats(kb_id: str, db: Session = Depends(get_db)):
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    return {"kb_id": kb_id, "docCount": kb.doc_count, "chunkCount": kb.chunk_count}
|