Improve KB upload
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
import os
|
||||
import json
|
||||
from io import BytesIO
|
||||
from datetime import datetime
|
||||
|
||||
from ..db import get_db
|
||||
@@ -20,6 +22,57 @@ from ..vector_store import (
|
||||
router = APIRouter(prefix="/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
def _refresh_kb_stats(db: Session, kb_id: str) -> None:
    """Recompute the cached doc/chunk counters on a knowledge base.

    Only documents whose status is ``"completed"`` contribute to either
    counter. The session is mutated in place; committing is left to the
    caller.
    """
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if kb is None:
        return
    completed = [
        doc
        for doc in db.query(KnowledgeDocument).filter(KnowledgeDocument.kb_id == kb_id).all()
        if doc.status == "completed"
    ]
    kb.doc_count = len(completed)
    # chunk_count may be NULL (or, defensively, negative) on some rows;
    # treat those as zero so the aggregate never goes negative.
    kb.chunk_count = sum(max(doc.chunk_count or 0, 0) for doc in completed)
|
||||
|
||||
|
||||
def _decode_text_bytes(raw: bytes) -> str:
|
||||
for encoding in ("utf-8", "utf-8-sig", "gb18030", "gbk", "latin-1"):
|
||||
try:
|
||||
return raw.decode(encoding)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
return raw.decode("utf-8", errors="ignore")
|
||||
|
||||
|
||||
def _extract_text_from_upload(filename: str, content_type: Optional[str], raw: bytes) -> str:
    """Extract indexable plain text from an uploaded file.

    Dispatches on the lowercased file extension; for unrecognized
    extensions, anything with a ``text/*`` MIME type is decoded as text.

    Raises:
        ValueError: if the format is unsupported, or a required optional
            parser (pypdf / python-docx) is not installed.
    """
    extension = os.path.splitext((filename or "").lower())[1]

    # Plain-text formats decode directly.
    if extension in {".txt", ".md", ".csv"}:
        return _decode_text_bytes(raw)

    if extension == ".json":
        decoded = _decode_text_bytes(raw)
        # Pretty-print valid JSON; keep the raw text if parsing fails.
        try:
            return json.dumps(json.loads(decoded), ensure_ascii=False, indent=2)
        except Exception:
            return decoded

    if extension == ".pdf":
        try:
            from pypdf import PdfReader  # type: ignore
        except Exception as exc:
            raise ValueError("PDF parsing requires installing pypdf") from exc
        pages = PdfReader(BytesIO(raw)).pages
        return "\n".join((page.extract_text() or "") for page in pages).strip()

    if extension == ".docx":
        try:
            from docx import Document  # type: ignore
        except Exception as exc:
            raise ValueError("DOCX parsing requires installing python-docx") from exc
        paragraphs = Document(BytesIO(raw)).paragraphs
        return "\n".join(paragraph.text for paragraph in paragraphs).strip()

    if extension == ".doc":
        raise ValueError("DOC format is not supported for auto indexing. Please convert to DOCX/TXT.")

    # Fallback: attempt a plain-text decode for declared text MIME types.
    if (content_type or "").startswith("text/"):
        return _decode_text_bytes(raw)

    raise ValueError(f"Unsupported file type for auto indexing: {extension or content_type or 'unknown'}")
|
||||
|
||||
|
||||
def kb_to_dict(kb: KnowledgeBase) -> dict:
|
||||
return {
|
||||
"id": kb.id,
|
||||
@@ -191,20 +244,93 @@ def delete_knowledge_base(kb_id: str, db: Session = Depends(get_db)):
|
||||
|
||||
# ============ Documents ============
|
||||
@router.post("/bases/{kb_id}/documents")
|
||||
def upload_document(
|
||||
async def upload_document(
|
||||
kb_id: str,
|
||||
file: Optional[UploadFile] = File(default=None),
|
||||
name: Optional[str] = Form(default=None),
|
||||
size: Optional[str] = Form(default=None),
|
||||
file_type: Optional[str] = Form(default=None),
|
||||
storage_url: Optional[str] = Form(default=None),
|
||||
data: Optional[KnowledgeDocumentCreate] = None,
|
||||
name: Optional[str] = Query(default=None),
|
||||
size: Optional[str] = Query(default=None),
|
||||
file_type: Optional[str] = Query(default=None),
|
||||
storage_url: Optional[str] = Query(default=None),
|
||||
request: Request = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
# New mode: multipart file upload with automatic indexing
|
||||
if file is not None:
|
||||
filename = file.filename or "uploaded.txt"
|
||||
file_type_value = file.content_type or file_type or "application/octet-stream"
|
||||
raw = file.file.read()
|
||||
if not raw:
|
||||
raise HTTPException(status_code=400, detail="Uploaded file is empty")
|
||||
|
||||
doc = KnowledgeDocument(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
kb_id=kb_id,
|
||||
name=filename,
|
||||
size=f"{len(raw)} bytes",
|
||||
file_type=file_type_value,
|
||||
storage_url=storage_url,
|
||||
status="processing",
|
||||
upload_date=datetime.utcnow().isoformat()
|
||||
)
|
||||
db.add(doc)
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
|
||||
try:
|
||||
if vector_store.get_collection(kb_id) is None:
|
||||
vector_store.create_collection(kb_id, kb.embedding_model)
|
||||
|
||||
text = _extract_text_from_upload(filename, file.content_type, raw)
|
||||
if not text.strip():
|
||||
raise ValueError("No textual content extracted from file")
|
||||
|
||||
chunk_count = index_document(kb_id, doc.id, text)
|
||||
doc.status = "completed"
|
||||
doc.chunk_count = chunk_count
|
||||
doc.processed_at = datetime.utcnow()
|
||||
doc.error_message = None
|
||||
_refresh_kb_stats(db, kb_id)
|
||||
db.commit()
|
||||
return {
|
||||
"id": doc.id,
|
||||
"name": doc.name,
|
||||
"size": doc.size,
|
||||
"fileType": doc.file_type,
|
||||
"storageUrl": doc.storage_url,
|
||||
"status": doc.status,
|
||||
"chunkCount": doc.chunk_count,
|
||||
"message": "Document uploaded and indexed",
|
||||
}
|
||||
except ValueError as exc:
|
||||
doc.status = "failed"
|
||||
doc.error_message = str(exc)
|
||||
_refresh_kb_stats(db, kb_id)
|
||||
db.commit()
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
doc.status = "failed"
|
||||
doc.error_message = str(exc)
|
||||
_refresh_kb_stats(db, kb_id)
|
||||
db.commit()
|
||||
raise HTTPException(status_code=500, detail=f"Failed to index uploaded file: {exc}") from exc
|
||||
|
||||
# Backward-compatible mode: metadata-only document creation
|
||||
if data is None:
|
||||
if not name and not size and request is not None:
|
||||
try:
|
||||
raw_payload = await request.json()
|
||||
if isinstance(raw_payload, dict):
|
||||
name = raw_payload.get("name")
|
||||
size = raw_payload.get("size")
|
||||
file_type = raw_payload.get("fileType") or raw_payload.get("file_type") or file_type
|
||||
storage_url = raw_payload.get("storageUrl") or raw_payload.get("storage_url") or storage_url
|
||||
except Exception:
|
||||
pass
|
||||
if not name or not size:
|
||||
raise HTTPException(status_code=422, detail="name and size are required")
|
||||
data = KnowledgeDocumentCreate(
|
||||
@@ -266,21 +392,21 @@ def index_document_content(kb_id: str, doc_id: str, request: DocumentIndexReques
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
if vector_store.get_collection(kb_id) is None:
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
vector_store.create_collection(kb_id, kb.embedding_model if kb else "text-embedding-3-small")
|
||||
chunk_count = index_document(kb_id, doc_id, request.content)
|
||||
doc.status = "completed"
|
||||
doc.chunk_count = chunk_count
|
||||
doc.processed_at = datetime.utcnow()
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
kb.doc_count = db.query(KnowledgeDocument).filter(
|
||||
KnowledgeDocument.kb_id == kb_id,
|
||||
KnowledgeDocument.status == "completed"
|
||||
).count()
|
||||
kb.chunk_count += chunk_count
|
||||
doc.error_message = None
|
||||
_refresh_kb_stats(db, kb_id)
|
||||
db.commit()
|
||||
return {"message": "Document indexed", "chunkCount": chunk_count}
|
||||
except Exception as e:
|
||||
doc.status = "failed"
|
||||
doc.error_message = str(e)
|
||||
_refresh_kb_stats(db, kb_id)
|
||||
db.commit()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -297,11 +423,8 @@ def delete_document(kb_id: str, doc_id: str, db: Session = Depends(get_db)):
|
||||
delete_document_from_vector(kb_id, doc_id)
|
||||
except Exception:
|
||||
pass
|
||||
kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
|
||||
if kb:
|
||||
kb.chunk_count = max(0, kb.chunk_count - (doc.chunk_count or 0))
|
||||
kb.doc_count = max(0, kb.doc_count - 1)
|
||||
db.delete(doc)
|
||||
_refresh_kb_stats(db, kb_id)
|
||||
db.commit()
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user