Organize tool scheme

This commit is contained in:
Xin Wang
2026-02-11 11:22:56 +08:00
parent 9304927fe9
commit 80e1d24443
4 changed files with 115 additions and 132 deletions

View File

@@ -15,17 +15,6 @@ router = APIRouter(prefix="/tools", tags=["Tools & Autotest"])
# ============ Available Tools ============ # ============ Available Tools ============
TOOL_REGISTRY = { TOOL_REGISTRY = {
"search": {
"name": "网络搜索",
"description": "搜索互联网获取最新信息",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "搜索关键词"}
},
"required": ["query"]
}
},
"calculator": { "calculator": {
"name": "计算器", "name": "计算器",
"description": "执行数学计算", "description": "执行数学计算",
@@ -37,50 +26,6 @@ TOOL_REGISTRY = {
"required": ["expression"] "required": ["expression"]
} }
}, },
"weather": {
"name": "天气查询",
"description": "查询指定城市的天气",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "城市名称"}
},
"required": ["city"]
}
},
"translate": {
"name": "翻译",
"description": "翻译文本到指定语言",
"parameters": {
"type": "object",
"properties": {
"text": {"type": "string", "description": "要翻译的文本"},
"target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"}
},
"required": ["text", "target_lang"]
}
},
"knowledge": {
"name": "知识库查询",
"description": "从知识库中检索相关信息",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "查询内容"},
"kb_id": {"type": "string", "description": "知识库ID"}
},
"required": ["query"]
}
},
"current_time": {
"name": "当前时间",
"description": "获取当前本地时间",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
},
"code_interpreter": { "code_interpreter": {
"name": "代码执行", "name": "代码执行",
"description": "安全地执行Python代码", "description": "安全地执行Python代码",
@@ -92,9 +37,27 @@ TOOL_REGISTRY = {
"required": ["code"] "required": ["code"]
} }
}, },
"take_phone": { "current_time": {
"name": "接听电话", "name": "当前时间",
"description": "执行接听电话命令", "description": "获取当前本地时间",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
},
"turn_on_camera": {
"name": "打开摄像头",
"description": "执行打开摄像头命令",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
},
"turn_off_camera": {
"name": "关闭摄像头",
"description": "执行关闭摄像头命令",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": {}, "properties": {},
@@ -126,70 +89,50 @@ TOOL_REGISTRY = {
} }
TOOL_CATEGORY_MAP = { TOOL_CATEGORY_MAP = {
"search": "query",
"weather": "query",
"translate": "query",
"knowledge": "query",
"calculator": "query", "calculator": "query",
"current_time": "query", "current_time": "query",
"code_interpreter": "query", "code_interpreter": "query",
"take_phone": "system", "turn_on_camera": "system",
"turn_off_camera": "system",
"increase_volume": "system", "increase_volume": "system",
"decrease_volume": "system", "decrease_volume": "system",
} }
TOOL_ICON_MAP = { TOOL_ICON_MAP = {
"search": "Globe",
"weather": "CloudSun",
"translate": "Globe",
"knowledge": "Box",
"current_time": "Calendar",
"calculator": "Terminal", "calculator": "Terminal",
"current_time": "Calendar",
"code_interpreter": "Terminal", "code_interpreter": "Terminal",
"take_phone": "Phone", "turn_on_camera": "Camera",
"turn_off_camera": "CameraOff",
"increase_volume": "Volume2", "increase_volume": "Volume2",
"decrease_volume": "Volume2", "decrease_volume": "Volume2",
} }
def _sync_default_tools(db: Session) -> None: def _seed_default_tools_if_empty(db: Session) -> None:
"""Ensure built-in tools exist and keep system tool metadata aligned.""" """Seed built-in tools only when tool_resources is empty."""
changed = False if db.query(ToolResource).count() > 0:
return
for tool_id, payload in TOOL_REGISTRY.items(): for tool_id, payload in TOOL_REGISTRY.items():
row = db.query(ToolResource).filter(ToolResource.id == tool_id).first()
category = TOOL_CATEGORY_MAP.get(tool_id, "system")
icon = TOOL_ICON_MAP.get(tool_id, "Wrench")
if not row:
db.add(ToolResource( db.add(ToolResource(
id=tool_id, id=tool_id,
user_id=1, user_id=1,
name=payload.get("name", tool_id), name=payload.get("name", tool_id),
description=payload.get("description", ""), description=payload.get("description", ""),
category=category, category=TOOL_CATEGORY_MAP.get(tool_id, "system"),
icon=icon, icon=TOOL_ICON_MAP.get(tool_id, "Wrench"),
enabled=True, enabled=True,
is_system=True, is_system=True,
)) ))
changed = True
continue
if row.is_system:
new_name = payload.get("name", row.name)
new_description = payload.get("description", row.description)
if row.name != new_name:
row.name = new_name
changed = True
if row.description != new_description:
row.description = new_description
changed = True
if row.category != category:
row.category = category
changed = True
if row.icon != icon:
row.icon = icon
changed = True
if changed:
db.commit() db.commit()
def recreate_tool_resources(db: Session) -> None:
"""Recreate tool resources table content with current built-in defaults."""
db.query(ToolResource).delete()
db.commit()
_seed_default_tools_if_empty(db)
@router.get("/list") @router.get("/list")
def list_available_tools(): def list_available_tools():
"""获取可用的工具列表""" """获取可用的工具列表"""
@@ -215,7 +158,7 @@ def list_tool_resources(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""获取工具资源列表。system/query 仅表示工具执行类型,不代表权限。""" """获取工具资源列表。system/query 仅表示工具执行类型,不代表权限。"""
_sync_default_tools(db) _seed_default_tools_if_empty(db)
query = db.query(ToolResource) query = db.query(ToolResource)
if not include_system: if not include_system:
query = query.filter(ToolResource.is_system == False) query = query.filter(ToolResource.is_system == False)
@@ -231,7 +174,7 @@ def list_tool_resources(
@router.get("/resources/{id}", response_model=ToolResourceOut) @router.get("/resources/{id}", response_model=ToolResourceOut)
def get_tool_resource(id: str, db: Session = Depends(get_db)): def get_tool_resource(id: str, db: Session = Depends(get_db)):
"""获取单个工具资源详情。""" """获取单个工具资源详情。"""
_sync_default_tools(db) _seed_default_tools_if_empty(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first() item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item: if not item:
raise HTTPException(status_code=404, detail="Tool resource not found") raise HTTPException(status_code=404, detail="Tool resource not found")
@@ -241,7 +184,7 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)):
@router.post("/resources", response_model=ToolResourceOut) @router.post("/resources", response_model=ToolResourceOut)
def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)): def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)):
"""创建自定义工具资源。""" """创建自定义工具资源。"""
_sync_default_tools(db) _seed_default_tools_if_empty(db)
candidate_id = (data.id or "").strip() candidate_id = (data.id or "").strip()
if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first(): if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first():
raise HTTPException(status_code=400, detail="Tool ID already exists") raise HTTPException(status_code=400, detail="Tool ID already exists")
@@ -265,7 +208,7 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)
@router.put("/resources/{id}", response_model=ToolResourceOut) @router.put("/resources/{id}", response_model=ToolResourceOut)
def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)): def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)):
"""更新工具资源。""" """更新工具资源。"""
_sync_default_tools(db) _seed_default_tools_if_empty(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first() item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item: if not item:
raise HTTPException(status_code=404, detail="Tool resource not found") raise HTTPException(status_code=404, detail="Tool resource not found")
@@ -283,7 +226,7 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend
@router.delete("/resources/{id}") @router.delete("/resources/{id}")
def delete_tool_resource(id: str, db: Session = Depends(get_db)): def delete_tool_resource(id: str, db: Session = Depends(get_db)):
"""删除工具资源。""" """删除工具资源。"""
_sync_default_tools(db) _seed_default_tools_if_empty(db)
item = db.query(ToolResource).filter(ToolResource.id == id).first() item = db.query(ToolResource).filter(ToolResource.id == id).first()
if not item: if not item:
raise HTTPException(status_code=404, detail="Tool resource not found") raise HTTPException(status_code=404, detail="Tool resource not found")

View File

@@ -100,6 +100,23 @@ def init_default_data():
db.close() db.close()
def init_default_tools(recreate: bool = False):
"""初始化默认工具,或按需重建工具表数据。"""
from app.db import SessionLocal
from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources
db = SessionLocal()
try:
if recreate:
recreate_tool_resources(db)
print("✅ 工具库已重建")
else:
_seed_default_tools_if_empty(db)
print("✅ 默认工具已初始化")
finally:
db.close()
def init_default_assistants(): def init_default_assistants():
"""初始化默认助手""" """初始化默认助手"""
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -121,7 +138,7 @@ def init_default_assistants():
voice="anna", voice="anna",
speed=1.0, speed=1.0,
hotwords=[], hotwords=[],
tools=["search", "calculator"], tools=["calculator", "current_time"],
interruption_sensitivity=500, interruption_sensitivity=500,
config_mode="platform", config_mode="platform",
llm_model_id="deepseek-chat", llm_model_id="deepseek-chat",
@@ -139,7 +156,7 @@ def init_default_assistants():
voice="bella", voice="bella",
speed=1.0, speed=1.0,
hotwords=["客服", "投诉", "咨询"], hotwords=["客服", "投诉", "咨询"],
tools=["search"], tools=["current_time"],
interruption_sensitivity=600, interruption_sensitivity=600,
config_mode="platform", config_mode="platform",
), ),
@@ -155,7 +172,7 @@ def init_default_assistants():
voice="alex", voice="alex",
speed=1.0, speed=1.0,
hotwords=["grammar", "vocabulary", "practice"], hotwords=["grammar", "vocabulary", "practice"],
tools=[], tools=["calculator"],
interruption_sensitivity=400, interruption_sensitivity=400,
config_mode="platform", config_mode="platform",
), ),
@@ -421,6 +438,11 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="跳过默认数据初始化", help="跳过默认数据初始化",
) )
parser.add_argument(
"--recreate-tool-db",
action="store_true",
help="重建工具库数据(清空 tool_resources 后按内置默认工具重建)",
)
args = parser.parse_args() args = parser.parse_args()
# 无参数时保持旧行为:重建 DB + 初始化默认数据 # 无参数时保持旧行为:重建 DB + 初始化默认数据
@@ -434,8 +456,13 @@ if __name__ == "__main__":
if args.rebuild_db: if args.rebuild_db:
init_db() init_db()
if args.recreate_tool_db:
init_default_tools(recreate=True)
if not args.skip_seed: if not args.skip_seed:
init_default_data() init_default_data()
if not args.recreate_tool_db:
init_default_tools(recreate=False)
init_default_assistants() init_default_assistants()
init_default_workflows() init_default_workflows()
init_default_knowledge_bases() init_default_knowledge_bases()

View File

@@ -14,16 +14,21 @@ class TestToolsAPI:
assert "tools" in data assert "tools" in data
# Check for expected tools # Check for expected tools
tools = data["tools"] tools = data["tools"]
assert "search" in tools
assert "calculator" in tools assert "calculator" in tools
assert "weather" in tools assert "code_interpreter" in tools
assert "current_time" in tools
assert "turn_on_camera" in tools
assert "turn_off_camera" in tools
assert "increase_volume" in tools
assert "decrease_volume" in tools
assert "calculator" in tools
def test_get_tool_detail(self, client): def test_get_tool_detail(self, client):
"""Test getting a specific tool's details""" """Test getting a specific tool's details"""
response = client.get("/api/tools/list/search") response = client.get("/api/tools/list/calculator")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["name"] == "网络搜索" assert data["name"] == "计算器"
assert "parameters" in data assert "parameters" in data
def test_get_tool_detail_not_found(self, client): def test_get_tool_detail_not_found(self, client):
@@ -256,15 +261,14 @@ class TestAutotestAPI:
assert "required" in data["parameters"] assert "required" in data["parameters"]
assert "expression" in data["parameters"]["required"] assert "expression" in data["parameters"]["required"]
def test_translate_tool_parameters(self, client): def test_code_interpreter_tool_parameters(self, client):
"""Test translate tool has correct parameters""" """Test code_interpreter tool has correct parameters"""
response = client.get("/api/tools/list/translate") response = client.get("/api/tools/list/code_interpreter")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["name"] == "翻译" assert data["name"] == "代码执行"
assert "text" in data["parameters"]["properties"] assert "code" in data["parameters"]["properties"]
assert "target_lang" in data["parameters"]["properties"]
class TestToolResourceCRUD: class TestToolResourceCRUD:
@@ -276,7 +280,7 @@ class TestToolResourceCRUD:
payload = response.json() payload = response.json()
assert payload["total"] >= 1 assert payload["total"] >= 1
ids = [item["id"] for item in payload["list"]] ids = [item["id"] for item in payload["list"]]
assert "search" in ids assert "calculator" in ids
def test_create_update_delete_tool_resource(self, client): def test_create_update_delete_tool_resource(self, client):
create_resp = client.post("/api/tools/resources", json={ create_resp = client.post("/api/tools/resources", json={
@@ -314,14 +318,14 @@ class TestToolResourceCRUD:
def test_system_tool_can_be_updated_and_deleted(self, client): def test_system_tool_can_be_updated_and_deleted(self, client):
list_resp = client.get("/api/tools/resources") list_resp = client.get("/api/tools/resources")
assert list_resp.status_code == 200 assert list_resp.status_code == 200
assert any(item["id"] == "search" for item in list_resp.json()["list"]) assert any(item["id"] == "turn_on_camera" for item in list_resp.json()["list"])
update_resp = client.put("/api/tools/resources/search", json={"name": "更新后的搜索工具", "category": "query"}) update_resp = client.put("/api/tools/resources/turn_on_camera", json={"name": "更新后的打开摄像头", "category": "system"})
assert update_resp.status_code == 200 assert update_resp.status_code == 200
assert update_resp.json()["name"] == "更新后的搜索工具" assert update_resp.json()["name"] == "更新后的打开摄像头"
delete_resp = client.delete("/api/tools/resources/search") delete_resp = client.delete("/api/tools/resources/turn_on_camera")
assert delete_resp.status_code == 200 assert delete_resp.status_code == 200
get_resp = client.get("/api/tools/resources/search") get_resp = client.get("/api/tools/resources/turn_on_camera")
assert get_resp.status_code == 404 assert get_resp.status_code == 404

View File

@@ -129,9 +129,18 @@ class DuplexPipeline:
"required": ["code"], "required": ["code"],
}, },
}, },
"take_phone": { "turn_on_camera": {
"name": "take_phone", "name": "turn_on_camera",
"description": "Take or answer a phone call", "description": "Turn on camera on client device",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
"turn_off_camera": {
"name": "turn_off_camera",
"description": "Turn off camera on client device",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": {}, "properties": {},