From 80e1d24443487ae3b79b1d7ea5a20ca5086bab74 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Wed, 11 Feb 2026 11:22:56 +0800 Subject: [PATCH] Organize tool scheme --- api/app/routers/tools.py | 163 +++++++++++---------------------- api/init_db.py | 33 ++++++- api/tests/test_tools.py | 36 ++++---- engine/core/duplex_pipeline.py | 15 ++- 4 files changed, 115 insertions(+), 132 deletions(-) diff --git a/api/app/routers/tools.py b/api/app/routers/tools.py index 6e5db6a..2d183c5 100644 --- a/api/app/routers/tools.py +++ b/api/app/routers/tools.py @@ -15,17 +15,6 @@ router = APIRouter(prefix="/tools", tags=["Tools & Autotest"]) # ============ Available Tools ============ TOOL_REGISTRY = { - "search": { - "name": "网络搜索", - "description": "搜索互联网获取最新信息", - "parameters": { - "type": "object", - "properties": { - "query": {"type": "string", "description": "搜索关键词"} - }, - "required": ["query"] - } - }, "calculator": { "name": "计算器", "description": "执行数学计算", @@ -37,50 +26,6 @@ TOOL_REGISTRY = { "required": ["expression"] } }, - "weather": { - "name": "天气查询", - "description": "查询指定城市的天气", - "parameters": { - "type": "object", - "properties": { - "city": {"type": "string", "description": "城市名称"} - }, - "required": ["city"] - } - }, - "translate": { - "name": "翻译", - "description": "翻译文本到指定语言", - "parameters": { - "type": "object", - "properties": { - "text": {"type": "string", "description": "要翻译的文本"}, - "target_lang": {"type": "string", "description": "目标语言,如: en, ja, ko"} - }, - "required": ["text", "target_lang"] - } - }, - "knowledge": { - "name": "知识库查询", - "description": "从知识库中检索相关信息", - "parameters": { - "type": "object", - "properties": { - "query": {"type": "string", "description": "查询内容"}, - "kb_id": {"type": "string", "description": "知识库ID"} - }, - "required": ["query"] - } - }, - "current_time": { - "name": "当前时间", - "description": "获取当前本地时间", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } - }, "code_interpreter": { "name": "代码执行", "description": "安全地执行Python代码", @@ -92,9 +37,27 @@ TOOL_REGISTRY = { "required": ["code"] } }, - "take_phone": { - "name": "接听电话", - "description": "执行接听电话命令", + "current_time": { + "name": "当前时间", + "description": "获取当前本地时间", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + }, + "turn_on_camera": { + "name": "打开摄像头", + "description": "执行打开摄像头命令", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + }, + "turn_off_camera": { + "name": "关闭摄像头", + "description": "执行关闭摄像头命令", "parameters": { "type": "object", "properties": {}, @@ -126,68 +89,48 @@ TOOL_REGISTRY = { } TOOL_CATEGORY_MAP = { - "search": "query", - "weather": "query", - "translate": "query", - "knowledge": "query", "calculator": "query", "current_time": "query", "code_interpreter": "query", - "take_phone": "system", + "turn_on_camera": "system", + "turn_off_camera": "system", "increase_volume": "system", "decrease_volume": "system", } TOOL_ICON_MAP = { - "search": "Globe", - "weather": "CloudSun", - "translate": "Globe", - "knowledge": "Box", - "current_time": "Calendar", "calculator": "Terminal", + "current_time": "Calendar", "code_interpreter": "Terminal", - "take_phone": "Phone", + "turn_on_camera": "Camera", + "turn_off_camera": "CameraOff", "increase_volume": "Volume2", "decrease_volume": "Volume2", } -def _sync_default_tools(db: Session) -> None: - """Ensure built-in tools exist and keep system tool metadata aligned.""" - changed = False +def _seed_default_tools_if_empty(db: Session) -> None: + """Seed built-in tools only when tool_resources is empty.""" + if db.query(ToolResource).count() > 0: + return for tool_id, payload in TOOL_REGISTRY.items(): - row = db.query(ToolResource).filter(ToolResource.id == tool_id).first() - category = TOOL_CATEGORY_MAP.get(tool_id, "system") - icon = TOOL_ICON_MAP.get(tool_id, "Wrench") - if not row: - db.add(ToolResource( - id=tool_id, - user_id=1, - name=payload.get("name", tool_id), - description=payload.get("description", ""), - category=category, - icon=icon, - enabled=True, - is_system=True, - )) - changed = True - continue - if row.is_system: - new_name = payload.get("name", row.name) - new_description = payload.get("description", row.description) - if row.name != new_name: - row.name = new_name - changed = True - if row.description != new_description: - row.description = new_description - changed = True - if row.category != category: - row.category = category - changed = True - if row.icon != icon: - row.icon = icon - changed = True - if changed: - db.commit() + db.add(ToolResource( + id=tool_id, + user_id=1, + name=payload.get("name", tool_id), + description=payload.get("description", ""), + category=TOOL_CATEGORY_MAP.get(tool_id, "system"), + icon=TOOL_ICON_MAP.get(tool_id, "Wrench"), + enabled=True, + is_system=True, + )) + db.commit() + + +def recreate_tool_resources(db: Session) -> None: + """Recreate tool resources table content with current built-in defaults.""" + db.query(ToolResource).delete() + db.commit() + _seed_default_tools_if_empty(db) @router.get("/list") @@ -215,7 +158,7 @@ def list_tool_resources( db: Session = Depends(get_db), ): """获取工具资源列表。system/query 仅表示工具执行类型,不代表权限。""" - _sync_default_tools(db) + _seed_default_tools_if_empty(db) query = db.query(ToolResource) if not include_system: query = query.filter(ToolResource.is_system == False) @@ -231,7 +174,7 @@ def list_tool_resources( @router.get("/resources/{id}", response_model=ToolResourceOut) def get_tool_resource(id: str, db: Session = Depends(get_db)): """获取单个工具资源详情。""" - _sync_default_tools(db) + _seed_default_tools_if_empty(db) item = db.query(ToolResource).filter(ToolResource.id == id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") @@ -241,7 +184,7 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)): @router.post("/resources", response_model=ToolResourceOut) def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)): """创建自定义工具资源。""" - _sync_default_tools(db) + _seed_default_tools_if_empty(db) candidate_id = (data.id or "").strip() if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first(): raise HTTPException(status_code=400, detail="Tool ID already exists") @@ -265,7 +208,7 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db) @router.put("/resources/{id}", response_model=ToolResourceOut) def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)): """更新工具资源。""" - _sync_default_tools(db) + _seed_default_tools_if_empty(db) item = db.query(ToolResource).filter(ToolResource.id == id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") @@ -283,7 +226,7 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend @router.delete("/resources/{id}") def delete_tool_resource(id: str, db: Session = Depends(get_db)): """删除工具资源。""" - _sync_default_tools(db) + _seed_default_tools_if_empty(db) item = db.query(ToolResource).filter(ToolResource.id == id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") diff --git a/api/init_db.py b/api/init_db.py index 305f516..21b3dd8 100644 --- a/api/init_db.py +++ b/api/init_db.py @@ -100,6 +100,23 @@ def init_default_data(): db.close() +def init_default_tools(recreate: bool = False): + """初始化默认工具,或按需重建工具表数据。""" + from app.db import SessionLocal + from app.routers.tools import _seed_default_tools_if_empty, recreate_tool_resources + + db = SessionLocal() + try: + if recreate: + recreate_tool_resources(db) + print("✅ 工具库已重建") + else: + _seed_default_tools_if_empty(db) + print("✅ 默认工具已初始化") + finally: + db.close() + + def init_default_assistants(): """初始化默认助手""" from sqlalchemy.orm import Session @@ -121,7 +138,7 @@ def init_default_assistants(): voice="anna", speed=1.0, hotwords=[], - tools=["search", "calculator"], + tools=["calculator", "current_time"], interruption_sensitivity=500, config_mode="platform", llm_model_id="deepseek-chat", @@ -139,7 +156,7 @@ def init_default_assistants(): voice="bella", speed=1.0, hotwords=["客服", "投诉", "咨询"], - tools=["search"], + tools=["current_time"], interruption_sensitivity=600, config_mode="platform", ), @@ -155,7 +172,7 @@ def init_default_assistants(): voice="alex", speed=1.0, hotwords=["grammar", "vocabulary", "practice"], - tools=[], + tools=["calculator"], interruption_sensitivity=400, config_mode="platform", ), @@ -421,6 +438,11 @@ if __name__ == "__main__": action="store_true", help="跳过默认数据初始化", ) + parser.add_argument( + "--recreate-tool-db", + action="store_true", + help="重建工具库数据(清空 tool_resources 后按内置默认工具重建)", + ) args = parser.parse_args() # 无参数时保持旧行为:重建 DB + 初始化默认数据 @@ -434,8 +456,13 @@ if __name__ == "__main__": if args.rebuild_db: init_db() + if args.recreate_tool_db: + init_default_tools(recreate=True) + if not args.skip_seed: init_default_data() + if not args.recreate_tool_db: + init_default_tools(recreate=False) init_default_assistants() init_default_workflows() init_default_knowledge_bases() diff --git a/api/tests/test_tools.py b/api/tests/test_tools.py index 124b7c0..78666c3 100644 --- a/api/tests/test_tools.py +++ b/api/tests/test_tools.py @@ -14,16 +14,21 @@ class TestToolsAPI: assert "tools" in data # Check for expected tools tools = data["tools"] - assert "search" in tools assert "calculator" in tools - assert "weather" in tools + assert "code_interpreter" in tools + assert "current_time" in tools + assert "turn_on_camera" in tools + assert "turn_off_camera" in tools + assert "increase_volume" in tools + assert "decrease_volume" in tools + assert "calculator" in tools def test_get_tool_detail(self, client): """Test getting a specific tool's details""" - response = client.get("/api/tools/list/search") + response = client.get("/api/tools/list/calculator") assert response.status_code == 200 data = response.json() - assert data["name"] == "网络搜索" + assert data["name"] == "计算器" assert "parameters" in data def test_get_tool_detail_not_found(self, client): @@ -256,15 +261,14 @@ class TestAutotestAPI: assert "required" in data["parameters"] assert "expression" in data["parameters"]["required"] - def test_translate_tool_parameters(self, client): - """Test translate tool has correct parameters""" - response = client.get("/api/tools/list/translate") + def test_code_interpreter_tool_parameters(self, client): + """Test code_interpreter tool has correct parameters""" + response = client.get("/api/tools/list/code_interpreter") assert response.status_code == 200 data = response.json() - assert data["name"] == "翻译" - assert "text" in data["parameters"]["properties"] - assert "target_lang" in data["parameters"]["properties"] + assert data["name"] == "代码执行" + assert "code" in data["parameters"]["properties"] class TestToolResourceCRUD: @@ -276,7 +280,7 @@ class TestToolResourceCRUD: payload = response.json() assert payload["total"] >= 1 ids = [item["id"] for item in payload["list"]] - assert "search" in ids + assert "calculator" in ids def test_create_update_delete_tool_resource(self, client): create_resp = client.post("/api/tools/resources", json={ @@ -314,14 +318,14 @@ class TestToolResourceCRUD: def test_system_tool_can_be_updated_and_deleted(self, client): list_resp = client.get("/api/tools/resources") assert list_resp.status_code == 200 - assert any(item["id"] == "search" for item in list_resp.json()["list"]) + assert any(item["id"] == "turn_on_camera" for item in list_resp.json()["list"]) - update_resp = client.put("/api/tools/resources/search", json={"name": "更新后的搜索工具", "category": "query"}) + update_resp = client.put("/api/tools/resources/turn_on_camera", json={"name": "更新后的打开摄像头", "category": "system"}) assert update_resp.status_code == 200 - assert update_resp.json()["name"] == "更新后的搜索工具" + assert update_resp.json()["name"] == "更新后的打开摄像头" - delete_resp = client.delete("/api/tools/resources/search") + delete_resp = client.delete("/api/tools/resources/turn_on_camera") assert delete_resp.status_code == 200 - get_resp = client.get("/api/tools/resources/search") + get_resp = client.get("/api/tools/resources/turn_on_camera") assert get_resp.status_code == 404 diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index 74d45ec..5684401 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -129,9 +129,18 @@ class DuplexPipeline: "required": ["code"], }, }, - "take_phone": { - "name": "take_phone", - "description": "Take or answer a phone call", + "turn_on_camera": { + "name": "turn_on_camera", + "description": "Turn on camera on client device", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + "turn_off_camera": { + "name": "turn_off_camera", + "description": "Turn off camera on client device", "parameters": { "type": "object", "properties": {},