Update tool panel and db

This commit is contained in:
Xin Wang
2026-02-09 00:22:31 +08:00
parent d0b96a3f72
commit 59cda0987f
3 changed files with 51 additions and 89 deletions

View File

@@ -103,17 +103,23 @@ TOOL_ICON_MAP = {
"code_interpreter": "Terminal",
}
def _seed_default_tools_if_empty(db: Session) -> None:
"""Seed default tools into DB when tool_resources is empty."""
if db.query(ToolResource).count() > 0:
return
def _builtin_tool_to_resource(tool_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
return {
"id": tool_id,
"name": payload.get("name", tool_id),
"description": payload.get("description", ""),
"category": TOOL_CATEGORY_MAP.get(tool_id, "system"),
"icon": TOOL_ICON_MAP.get(tool_id, "Wrench"),
"enabled": True,
"is_system": True,
}
for tool_id, payload in TOOL_REGISTRY.items():
db.add(ToolResource(
id=tool_id,
user_id=1,
name=payload.get("name", tool_id),
description=payload.get("description", ""),
category=TOOL_CATEGORY_MAP.get(tool_id, "system"),
icon=TOOL_ICON_MAP.get(tool_id, "Wrench"),
enabled=True,
is_system=True,
))
db.commit()
@router.get("/list")
@@ -140,49 +146,24 @@ def list_tool_resources(
limit: int = 100,
db: Session = Depends(get_db),
):
"""获取工具资源列表(内置工具 + 自定义工具)"""
merged: List[Dict[str, Any]] = []
if include_system:
for tool_id, payload in TOOL_REGISTRY.items():
merged.append(_builtin_tool_to_resource(tool_id, payload))
"""获取工具资源列表。system/query 仅表示工具执行类型,不代表权限"""
_seed_default_tools_if_empty(db)
query = db.query(ToolResource)
if not include_system:
query = query.filter(ToolResource.is_system == False)
if category:
query = query.filter(ToolResource.category == category)
if enabled is not None:
query = query.filter(ToolResource.enabled == enabled)
custom_tools = query.order_by(ToolResource.created_at.desc()).all()
for item in custom_tools:
merged.append({
"id": item.id,
"name": item.name,
"description": item.description,
"category": item.category,
"icon": item.icon,
"enabled": item.enabled,
"is_system": item.is_system,
})
if category:
merged = [item for item in merged if item.get("category") == category]
if enabled is not None:
merged = [item for item in merged if item.get("enabled") == enabled]
total = len(merged)
start = max(page - 1, 0) * limit
end = start + limit
return {"total": total, "page": page, "limit": limit, "list": merged[start:end]}
total = query.count()
rows = query.order_by(ToolResource.created_at.desc()).offset(max(page - 1, 0) * limit).limit(limit).all()
return {"total": total, "page": page, "limit": limit, "list": rows}
@router.get("/resources/{id}", response_model=ToolResourceOut)
def get_tool_resource(id: str, db: Session = Depends(get_db)):
"""获取单个工具资源详情。"""
if id in TOOL_REGISTRY:
tool = _builtin_tool_to_resource(id, TOOL_REGISTRY[id])
return ToolResourceOut(**tool)
_seed_default_tools_if_empty(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item:
raise HTTPException(status_code=404, detail="Tool resource not found")
@@ -192,9 +173,10 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)):
@router.post("/resources", response_model=ToolResourceOut)
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
"""创建自定义工具资源。"""
_seed_default_tools_if_empty(db)
candidate_id = (data.id or "").strip()
if candidate_id and candidate_id in TOOL_REGISTRY:
raise HTTPException(status_code=400, detail="Tool ID conflicts with system tool")
if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first():
raise HTTPException(status_code=400, detail="Tool ID already exists")
item = ToolResource(
id=candidate_id or f"tool_{str(uuid.uuid4())[:8]}",
@@ -214,10 +196,8 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
@router.put("/resources/{id}", response_model=ToolResourceOut)
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
"""更新自定义工具资源。"""
if id in TOOL_REGISTRY:
raise HTTPException(status_code=400, detail="System tools are read-only")
"""更新工具资源。"""
_seed_default_tools_if_empty(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item:
raise HTTPException(status_code=404, detail="Tool resource not found")
@@ -234,10 +214,8 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
@router.delete("/resources/{id}")
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
"""删除自定义工具资源。"""
if id in TOOL_REGISTRY:
raise HTTPException(status_code=400, detail="System tools cannot be deleted")
"""删除工具资源。"""
_seed_default_tools_if_empty(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item:
raise HTTPException(status_code=404, detail="Tool resource not found")

View File

@@ -311,9 +311,17 @@ class TestToolResourceCRUD:
missing_resp = client.get(f"/api/tools/resources/{tool_id}")
assert missing_resp.status_code == 404
def test_system_tool_is_read_only(self, client):
update_resp = client.put("/api/tools/resources/search", json={"name": "new"})
assert update_resp.status_code == 400
def test_system_tool_can_be_updated_and_deleted(self, client):
list_resp = client.get("/api/tools/resources")
assert list_resp.status_code == 200
assert any(item["id"] == "search" for item in list_resp.json()["list"])
update_resp = client.put("/api/tools/resources/search", json={"name": "更新后的搜索工具", "category": "query"})
assert update_resp.status_code == 200
assert update_resp.json()["name"] == "更新后的搜索工具"
delete_resp = client.delete("/api/tools/resources/search")
assert delete_resp.status_code == 400
assert delete_resp.status_code == 200
get_resp = client.get("/api/tools/resources/search")
assert get_resp.status_code == 404