Organize tool scheme
This commit is contained in:
@@ -15,17 +15,6 @@ router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
|
||||
|
||||
# ============ Available Tools ============
|
||||
TOOL_REGISTRY = {
|
||||
"search": {
|
||||
"name": "网络搜索",
|
||||
"description": "搜索互联网获取最新信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索关键词"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"calculator": {
|
||||
"name": "计算器",
|
||||
"description": "执行数学计算",
|
||||
@@ -37,50 +26,6 @@ TOOL_REGISTRY = {
|
||||
"required": ["expression"]
|
||||
}
|
||||
},
|
||||
"weather": {
|
||||
"name": "天气查询",
|
||||
"description": "查询指定城市的天气",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "城市名称"}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
},
|
||||
"translate": {
|
||||
"name": "翻译",
|
||||
"description": "翻译文本到指定语言",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "要翻译的文本"},
|
||||
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
|
||||
},
|
||||
"required": ["text", "target_lang"]
|
||||
}
|
||||
},
|
||||
"knowledge": {
|
||||
"name": "知识库查询",
|
||||
"description": "从知识库中检索相关信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "查询内容"},
|
||||
"kb_id": {"type": "string", "description": "知识库ID"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
"current_time": {
|
||||
"name": "当前时间",
|
||||
"description": "获取当前本地时间",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
"code_interpreter": {
|
||||
"name": "代码执行",
|
||||
"description": "安全地执行Python代码",
|
||||
@@ -92,9 +37,27 @@ TOOL_REGISTRY = {
|
||||
"required": ["code"]
|
||||
}
|
||||
},
|
||||
"take_phone": {
|
||||
"name": "接听电话",
|
||||
"description": "执行接听电话命令",
|
||||
"current_time": {
|
||||
"name": "当前时间",
|
||||
"description": "获取当前本地时间",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
"turn_on_camera": {
|
||||
"name": "打开摄像头",
|
||||
"description": "执行打开摄像头命令",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
"turn_off_camera": {
|
||||
"name": "关闭摄像头",
|
||||
"description": "执行关闭摄像头命令",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@@ -126,68 +89,48 @@ TOOL_REGISTRY = {
|
||||
}
|
||||
|
||||
TOOL_CATEGORY_MAP = {
|
||||
"search": "query",
|
||||
"weather": "query",
|
||||
"translate": "query",
|
||||
"knowledge": "query",
|
||||
"calculator": "query",
|
||||
"current_time": "query",
|
||||
"code_interpreter": "query",
|
||||
"take_phone": "system",
|
||||
"turn_on_camera": "system",
|
||||
"turn_off_camera": "system",
|
||||
"increase_volume": "system",
|
||||
"decrease_volume": "system",
|
||||
}
|
||||
|
||||
TOOL_ICON_MAP = {
|
||||
"search": "Globe",
|
||||
"weather": "CloudSun",
|
||||
"translate": "Globe",
|
||||
"knowledge": "Box",
|
||||
"current_time": "Calendar",
|
||||
"calculator": "Terminal",
|
||||
"current_time": "Calendar",
|
||||
"code_interpreter": "Terminal",
|
||||
"take_phone": "Phone",
|
||||
"turn_on_camera": "Camera",
|
||||
"turn_off_camera": "CameraOff",
|
||||
"increase_volume": "Volume2",
|
||||
"decrease_volume": "Volume2",
|
||||
}
|
||||
|
||||
def _sync_default_tools(db: Session) -> None:
|
||||
"""Ensure built-in tools exist and keep system tool metadata aligned."""
|
||||
changed = False
|
||||
def _seed_default_tools_if_empty(db: Session) -> None:
|
||||
"""Seed built-in tools only when tool_resources is empty."""
|
||||
if db.query(ToolResource).count() > 0:
|
||||
return
|
||||
for tool_id, payload in TOOL_REGISTRY.items():
|
||||
row = db.query(ToolResource).filter(ToolResource.id == tool_id).first()
|
||||
category = TOOL_CATEGORY_MAP.get(tool_id, "system")
|
||||
icon = TOOL_ICON_MAP.get(tool_id, "Wrench")
|
||||
if not row:
|
||||
db.add(ToolResource(
|
||||
id=tool_id,
|
||||
user_id=1,
|
||||
name=payload.get("name", tool_id),
|
||||
description=payload.get("description", ""),
|
||||
category=category,
|
||||
icon=icon,
|
||||
enabled=True,
|
||||
is_system=True,
|
||||
))
|
||||
changed = True
|
||||
continue
|
||||
if row.is_system:
|
||||
new_name = payload.get("name", row.name)
|
||||
new_description = payload.get("description", row.description)
|
||||
if row.name != new_name:
|
||||
row.name = new_name
|
||||
changed = True
|
||||
if row.description != new_description:
|
||||
row.description = new_description
|
||||
changed = True
|
||||
if row.category != category:
|
||||
row.category = category
|
||||
changed = True
|
||||
if row.icon != icon:
|
||||
row.icon = icon
|
||||
changed = True
|
||||
if changed:
|
||||
db.commit()
|
||||
db.add(ToolResource(
|
||||
id=tool_id,
|
||||
user_id=1,
|
||||
name=payload.get("name", tool_id),
|
||||
description=payload.get("description", ""),
|
||||
category=TOOL_CATEGORY_MAP.get(tool_id, "system"),
|
||||
icon=TOOL_ICON_MAP.get(tool_id, "Wrench"),
|
||||
enabled=True,
|
||||
is_system=True,
|
||||
))
|
||||
db.commit()
|
||||
|
||||
|
||||
def recreate_tool_resources(db: Session) -> None:
|
||||
"""Recreate tool resources table content with current built-in defaults."""
|
||||
db.query(ToolResource).delete()
|
||||
db.commit()
|
||||
_seed_default_tools_if_empty(db)
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
@@ -215,7 +158,7 @@ def list_tool_resources(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取工具资源列表。system/query 仅表示工具执行类型,不代表权限。"""
|
||||
_sync_default_tools(db)
|
||||
_seed_default_tools_if_empty(db)
|
||||
query = db.query(ToolResource)
|
||||
if not include_system:
|
||||
query = query.filter(ToolResource.is_system == False)
|
||||
@@ -231,7 +174,7 @@ def list_tool_resources(
|
||||
@router.get("/resources/{id}", response_model=ToolResourceOut)
|
||||
def get_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个工具资源详情。"""
|
||||
_sync_default_tools(db)
|
||||
_seed_default_tools_if_empty(db)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
@@ -241,7 +184,7 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
@router.post("/resources", response_model=ToolResourceOut)
|
||||
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
|
||||
"""创建自定义工具资源。"""
|
||||
_sync_default_tools(db)
|
||||
_seed_default_tools_if_empty(db)
|
||||
candidate_id = (data.id or "").strip()
|
||||
if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first():
|
||||
raise HTTPException(status_code=400, detail="Tool ID already exists")
|
||||
@@ -265,7 +208,7 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
|
||||
@router.put("/resources/{id}", response_model=ToolResourceOut)
|
||||
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
|
||||
"""更新工具资源。"""
|
||||
_sync_default_tools(db)
|
||||
_seed_default_tools_if_empty(db)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
@@ -283,7 +226,7 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
|
||||
@router.delete("/resources/{id}")
|
||||
def delete_tool_resource(id: str, db: Session = Depends(get_db)):
|
||||
"""删除工具资源。"""
|
||||
_sync_default_tools(db)
|
||||
_seed_default_tools_if_empty(db)
|
||||
item = db.query(ToolResource).filter(ToolResource.id == id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Tool resource not found")
|
||||
|
||||
@@ -100,6 +100,23 @@ def init_default_data():
|
||||
db.close()
|
||||
|
||||
|
||||
def init_default_tools(recreate: bool = False):
|
||||
"""初始化默认工具,或按需重建工具表数据。"""
|
||||
from app.db import SessionLocal
|
||||
from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if recreate:
|
||||
recreate_tool_resources(db)
|
||||
print("✅ 工具库已重建")
|
||||
else:
|
||||
_seed_default_tools_if_empty(db)
|
||||
print("✅ 默认工具已初始化")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def init_default_assistants():
|
||||
"""初始化默认助手"""
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -121,7 +138,7 @@ def init_default_assistants():
|
||||
voice="anna",
|
||||
speed=1.0,
|
||||
hotwords=[],
|
||||
tools=["search", "calculator"],
|
||||
tools=["calculator", "current_time"],
|
||||
interruption_sensitivity=500,
|
||||
config_mode="platform",
|
||||
llm_model_id="deepseek-chat",
|
||||
@@ -139,7 +156,7 @@ def init_default_assistants():
|
||||
voice="bella",
|
||||
speed=1.0,
|
||||
hotwords=["客服", "投诉", "咨询"],
|
||||
tools=["search"],
|
||||
tools=["current_time"],
|
||||
interruption_sensitivity=600,
|
||||
config_mode="platform",
|
||||
),
|
||||
@@ -155,7 +172,7 @@ def init_default_assistants():
|
||||
voice="alex",
|
||||
speed=1.0,
|
||||
hotwords=["grammar", "vocabulary", "practice"],
|
||||
tools=[],
|
||||
tools=["calculator"],
|
||||
interruption_sensitivity=400,
|
||||
config_mode="platform",
|
||||
),
|
||||
@@ -421,6 +438,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="跳过默认数据初始化",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recreate-tool-db",
|
||||
action="store_true",
|
||||
help="重建工具库数据(清空 tool_resources 后按内置默认工具重建)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# 无参数时保持旧行为:重建 DB + 初始化默认数据
|
||||
@@ -434,8 +456,13 @@ if __name__ == "__main__":
|
||||
if args.rebuild_db:
|
||||
init_db()
|
||||
|
||||
if args.recreate_tool_db:
|
||||
init_default_tools(recreate=True)
|
||||
|
||||
if not args.skip_seed:
|
||||
init_default_data()
|
||||
if not args.recreate_tool_db:
|
||||
init_default_tools(recreate=False)
|
||||
init_default_assistants()
|
||||
init_default_workflows()
|
||||
init_default_knowledge_bases()
|
||||
|
||||
@@ -14,16 +14,21 @@ class TestToolsAPI:
|
||||
assert "tools" in data
|
||||
# Check for expected tools
|
||||
tools = data["tools"]
|
||||
assert "search" in tools
|
||||
assert "calculator" in tools
|
||||
assert "weather" in tools
|
||||
assert "code_interpreter" in tools
|
||||
assert "current_time" in tools
|
||||
assert "turn_on_camera" in tools
|
||||
assert "turn_off_camera" in tools
|
||||
assert "increase_volume" in tools
|
||||
assert "decrease_volume" in tools
|
||||
assert "calculator" in tools
|
||||
|
||||
def test_get_tool_detail(self, client):
|
||||
"""Test getting a specific tool's details"""
|
||||
response = client.get("/api/tools/list/search")
|
||||
response = client.get("/api/tools/list/calculator")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "网络搜索"
|
||||
assert data["name"] == "计算器"
|
||||
assert "parameters" in data
|
||||
|
||||
def test_get_tool_detail_not_found(self, client):
|
||||
@@ -256,15 +261,14 @@ class TestAutotestAPI:
|
||||
assert "required" in data["parameters"]
|
||||
assert "expression" in data["parameters"]["required"]
|
||||
|
||||
def test_translate_tool_parameters(self, client):
|
||||
"""Test translate tool has correct parameters"""
|
||||
response = client.get("/api/tools/list/translate")
|
||||
def test_code_interpreter_tool_parameters(self, client):
|
||||
"""Test code_interpreter tool has correct parameters"""
|
||||
response = client.get("/api/tools/list/code_interpreter")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["name"] == "翻译"
|
||||
assert "text" in data["parameters"]["properties"]
|
||||
assert "target_lang" in data["parameters"]["properties"]
|
||||
assert data["name"] == "代码执行"
|
||||
assert "code" in data["parameters"]["properties"]
|
||||
|
||||
|
||||
class TestToolResourceCRUD:
|
||||
@@ -276,7 +280,7 @@ class TestToolResourceCRUD:
|
||||
payload = response.json()
|
||||
assert payload["total"] >= 1
|
||||
ids = [item["id"] for item in payload["list"]]
|
||||
assert "search" in ids
|
||||
assert "calculator" in ids
|
||||
|
||||
def test_create_update_delete_tool_resource(self, client):
|
||||
create_resp = client.post("/api/tools/resources", json={
|
||||
@@ -314,14 +318,14 @@ class TestToolResourceCRUD:
|
||||
def test_system_tool_can_be_updated_and_deleted(self, client):
|
||||
list_resp = client.get("/api/tools/resources")
|
||||
assert list_resp.status_code == 200
|
||||
assert any(item["id"] == "search" for item in list_resp.json()["list"])
|
||||
assert any(item["id"] == "turn_on_camera" for item in list_resp.json()["list"])
|
||||
|
||||
update_resp = client.put("/api/tools/resources/search", json={"name": "更新后的搜索工具", "category": "query"})
|
||||
update_resp = client.put("/api/tools/resources/turn_on_camera", json={"name": "更新后的打开摄像头", "category": "system"})
|
||||
assert update_resp.status_code == 200
|
||||
assert update_resp.json()["name"] == "更新后的搜索工具"
|
||||
assert update_resp.json()["name"] == "更新后的打开摄像头"
|
||||
|
||||
delete_resp = client.delete("/api/tools/resources/search")
|
||||
delete_resp = client.delete("/api/tools/resources/turn_on_camera")
|
||||
assert delete_resp.status_code == 200
|
||||
|
||||
get_resp = client.get("/api/tools/resources/search")
|
||||
get_resp = client.get("/api/tools/resources/turn_on_camera")
|
||||
assert get_resp.status_code == 404
|
||||
|
||||
Reference in New Issue
Block a user