diff --git a/api/app/models.py b/api/app/models.py index 2b553ee..29579f2 100644 --- a/api/app/models.py +++ b/api/app/models.py @@ -117,6 +117,7 @@ class Assistant(Base): call_count: Mapped[int] = mapped_column(Integer, default=0) first_turn_mode: Mapped[str] = mapped_column(String(32), default="bot_first") opener: Mapped[str] = mapped_column(Text, default="") + manual_opener_tool_calls: Mapped[list] = mapped_column(JSON, default=list) generated_opener_enabled: Mapped[bool] = mapped_column(default=False) prompt: Mapped[str] = mapped_column(Text, default="") knowledge_base_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index d3a28e2..bf43303 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -7,6 +7,7 @@ from pathlib import Path import httpx from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import FileResponse +from sqlalchemy import inspect, text from sqlalchemy.orm import Session from typing import Any, Dict, List, Optional import uuid @@ -27,6 +28,7 @@ from .tools import ( TOOL_CATEGORY_MAP, TOOL_PARAMETER_DEFAULTS, TOOL_WAIT_FOR_RESPONSE_DEFAULTS, + normalize_tool_id, _ensure_tool_resource_schema, ) @@ -111,9 +113,91 @@ def _compose_runtime_system_prompt(base_prompt: Optional[str]) -> str: return f"{raw}\n\n{tool_policy}" if raw else tool_policy +def _ensure_assistant_schema(db: Session) -> None: + """Apply lightweight SQLite migrations for newly added assistants columns.""" + bind = db.get_bind() + inspector = inspect(bind) + try: + columns = {col["name"] for col in inspector.get_columns("assistants")} + except Exception: + return + + altered = False + if "manual_opener_tool_calls" not in columns: + db.execute(text("ALTER TABLE assistants ADD COLUMN manual_opener_tool_calls JSON")) + altered = True + + if altered: + db.commit() + + +def _normalize_manual_opener_tool_calls(raw: Any, warnings: Optional[List[str]] = None) -> List[Dict[str, Any]]: + normalized: List[Dict[str, Any]] = [] + if not isinstance(raw, list): + return normalized + + for idx, item in enumerate(raw): + if not isinstance(item, dict): + if warnings is not None: + warnings.append(f"Ignored invalid manual opener tool call at index {idx}: not an object") + continue + + tool_name = normalize_tool_id(str( + item.get("toolName") + or item.get("tool_name") + or item.get("name") + or "" + ).strip()) + if not tool_name: + if warnings is not None: + warnings.append(f"Ignored invalid manual opener tool call at index {idx}: missing toolName") + continue + + args_raw = item.get("arguments") + args: Dict[str, Any] = {} + if isinstance(args_raw, dict): + args = dict(args_raw) + elif isinstance(args_raw, str): + text_value = args_raw.strip() + if text_value: + try: + parsed = json.loads(text_value) + if isinstance(parsed, dict): + args = parsed + else: + if warnings is not None: + warnings.append( + f"Ignored non-object arguments for manual opener tool call '{tool_name}' at index {idx}" + ) + except Exception: + if warnings is not None: + warnings.append(f"Ignored invalid JSON arguments for manual opener tool call '{tool_name}' at index {idx}") + elif args_raw is not None and warnings is not None: + warnings.append(f"Ignored unsupported arguments type for manual opener tool call '{tool_name}' at index {idx}") + + normalized.append({"toolName": tool_name, "arguments": args}) + + # Keep opener sequence intentionally short to avoid long pre-dialog delays. + return normalized[:8] + + +def _normalize_assistant_tool_ids(raw: Any) -> List[str]: + if not isinstance(raw, list): + return [] + normalized: List[str] = [] + seen: set[str] = set() + for item in raw: + tool_id = normalize_tool_id(item) + if not tool_id or tool_id in seen: + continue + seen.add(tool_id) + normalized.append(tool_id) + return normalized + + def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings: List[str]) -> List[Dict[str, Any]]: _ensure_tool_resource_schema(db) - ids = [str(tool_id).strip() for tool_id in selected_tool_ids if str(tool_id).strip()] + ids = _normalize_assistant_tool_ids(selected_tool_ids) if not ids: return [] @@ -183,12 +267,17 @@ def _resolve_runtime_tools(db: Session, selected_tool_ids: List[str], warnings: def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[str, Any], List[str]]: warnings: List[str] = [] generated_opener_enabled = bool(assistant.generated_opener_enabled) + manual_opener_tool_calls = _normalize_manual_opener_tool_calls( + assistant.manual_opener_tool_calls, + warnings=warnings, + ) metadata: Dict[str, Any] = { "systemPrompt": _compose_runtime_system_prompt(assistant.prompt), "firstTurnMode": assistant.first_turn_mode or "bot_first", # Generated opener should rely on systemPrompt instead of fixed opener text. "greeting": "" if generated_opener_enabled else (assistant.opener or ""), "generatedOpenerEnabled": generated_opener_enabled, + "manualOpenerToolCalls": manual_opener_tool_calls, "output": {"mode": "audio" if assistant.voice_output_enabled else "text"}, "bargeIn": { "enabled": not bool(assistant.bot_cannot_be_interrupted), @@ -329,6 +418,7 @@ def assistant_to_dict(assistant: Assistant) -> dict: "callCount": assistant.call_count, "firstTurnMode": assistant.first_turn_mode or "bot_first", "opener": assistant.opener or "", + "manualOpenerToolCalls": _normalize_manual_opener_tool_calls(assistant.manual_opener_tool_calls), "generatedOpenerEnabled": bool(assistant.generated_opener_enabled), "openerAudioEnabled": bool(opener_audio.enabled) if opener_audio else False, "openerAudioReady": opener_audio_ready, @@ -341,7 +431,7 @@ def assistant_to_dict(assistant: Assistant) -> dict: "voice": assistant.voice, "speed": assistant.speed, "hotwords": assistant.hotwords or [], - "tools": assistant.tools or [], + "tools": _normalize_assistant_tool_ids(assistant.tools), "botCannotBeInterrupted": bool(assistant.bot_cannot_be_interrupted), "interruptionSensitivity": assistant.interruption_sensitivity, "configMode": assistant.config_mode, @@ -360,6 +450,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None: field_map = { "knowledgeBaseId": "knowledge_base_id", "firstTurnMode": "first_turn_mode", + "manualOpenerToolCalls": "manual_opener_tool_calls", "interruptionSensitivity": "interruption_sensitivity", "botCannotBeInterrupted": "bot_cannot_be_interrupted", "configMode": "config_mode", @@ -492,6 +583,7 @@ def list_assistants( db: Session = Depends(get_db) ): """获取助手列表""" + _ensure_assistant_schema(db) query = db.query(Assistant) total = query.count() assistants = query.order_by(Assistant.created_at.desc()) \ @@ -507,6 +599,7 @@ def list_assistants( @router.get("/{id}", response_model=AssistantOut) def get_assistant(id: str, db: Session = Depends(get_db)): """获取单个助手详情""" + _ensure_assistant_schema(db) assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") @@ -516,6 +609,7 @@ def get_assistant(id: str, db: Session = Depends(get_db)): @router.get("/{id}/config", response_model=AssistantEngineConfigResponse) def get_assistant_config(id: str, db: Session = Depends(get_db)): """Canonical engine config endpoint consumed by engine backend adapter.""" + _ensure_assistant_schema(db) assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") @@ -525,6 +619,7 @@ def get_assistant_config(id: str, db: Session = Depends(get_db)): @router.get("/{id}/runtime-config", response_model=AssistantEngineConfigResponse) def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)): """Legacy alias for resolved engine runtime config.""" + _ensure_assistant_schema(db) assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") @@ -534,12 +629,14 @@ def get_assistant_runtime_config(id: str, db: Session = Depends(get_db)): @router.post("", response_model=AssistantOut) def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): """创建新助手""" + _ensure_assistant_schema(db) assistant = Assistant( id=str(uuid.uuid4())[:8], user_id=1, # 默认用户,后续添加认证 name=data.name, first_turn_mode=data.firstTurnMode, opener=data.opener, + manual_opener_tool_calls=_normalize_manual_opener_tool_calls(data.manualOpenerToolCalls), generated_opener_enabled=data.generatedOpenerEnabled, prompt=data.prompt, knowledge_base_id=data.knowledgeBaseId, @@ -548,7 +645,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): voice=data.voice, speed=data.speed, hotwords=data.hotwords, - tools=data.tools, + tools=_normalize_assistant_tool_ids(data.tools), bot_cannot_be_interrupted=data.botCannotBeInterrupted, interruption_sensitivity=data.interruptionSensitivity, config_mode=data.configMode, @@ -572,6 +669,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): @router.get("/{id}/opener-audio", response_model=AssistantOpenerAudioOut) def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)): + _ensure_assistant_schema(db) assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") @@ -580,6 +678,7 @@ def get_assistant_opener_audio(id: str, db: Session = Depends(get_db)): @router.get("/{id}/opener-audio/pcm") def get_assistant_opener_audio_pcm(id: str, db: Session = Depends(get_db)): + _ensure_assistant_schema(db) assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") @@ -602,6 +701,7 @@ def generate_assistant_opener_audio( data: AssistantOpenerAudioGenerateRequest, db: Session = Depends(get_db), ): + _ensure_assistant_schema(db) assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") @@ -691,12 +791,17 @@ def generate_assistant_opener_audio( @router.put("/{id}") def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_db)): """更新助手""" + _ensure_assistant_schema(db) assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") update_data = data.model_dump(exclude_unset=True) opener_audio_enabled = update_data.pop("openerAudioEnabled", None) + if "manualOpenerToolCalls" in update_data: + update_data["manualOpenerToolCalls"] = _normalize_manual_opener_tool_calls(update_data.get("manualOpenerToolCalls")) + if "tools" in update_data: + update_data["tools"] = _normalize_assistant_tool_ids(update_data.get("tools")) _apply_assistant_update(assistant, update_data) if opener_audio_enabled is not None: record = _ensure_assistant_opener_audio(db, assistant) @@ -712,6 +817,7 @@ def update_assistant(id: str, data: AssistantUpdate, db: Session = Depends(get_d @router.delete("/{id}") def delete_assistant(id: str, db: Session = Depends(get_db)): """删除助手""" + _ensure_assistant_schema(db) assistant = db.query(Assistant).filter(Assistant.id == id).first() if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") diff --git a/api/app/routers/tools.py b/api/app/routers/tools.py index 00c9bf3..2b0220a 100644 --- a/api/app/routers/tools.py +++ b/api/app/routers/tools.py @@ -14,6 +14,19 @@ from ..schemas import ToolResourceCreate, ToolResourceOut, ToolResourceUpdate router = APIRouter(prefix="/tools", tags=["Tools & Autotest"]) +TOOL_ID_ALIASES: Dict[str, str] = { + # legacy -> canonical + "voice_message_prompt": "voice_msg_prompt", +} + + +def normalize_tool_id(tool_id: Optional[str]) -> str: + raw = str(tool_id or "").strip() + if not raw: + return "" + return TOOL_ID_ALIASES.get(raw, raw) + + # ============ Available Tools ============ TOOL_REGISTRY = { "calculator": { @@ -87,7 +100,7 @@ TOOL_REGISTRY = { "required": [] } }, - "voice_message_prompt": { + "voice_msg_prompt": { "name": "语音消息提示", "description": "播报一条语音提示消息", "parameters": { @@ -180,7 +193,8 @@ TOOL_CATEGORY_MAP = { "turn_off_camera": "system", "increase_volume": "system", "decrease_volume": "system", - "voice_message_prompt": "system", + "voice_msg_prompt": "system", + "voice_message_prompt": "system", # backward compatibility "text_msg_prompt": "system", "voice_choice_prompt": "system", "text_choice_prompt": "system", @@ -194,7 +208,8 @@ TOOL_ICON_MAP = { "turn_off_camera": "CameraOff", "increase_volume": "Volume2", "decrease_volume": "Volume2", - "voice_message_prompt": "Volume2", + "voice_msg_prompt": "Volume2", + "voice_message_prompt": "Volume2", # backward compatibility "text_msg_prompt": "Terminal", "voice_choice_prompt": "Volume2", "text_choice_prompt": "Terminal", @@ -284,9 +299,49 @@ def _validate_query_http_config(*, category: str, tool_id: Optional[str], http_u raise HTTPException(status_code=400, detail="http_url is required for query tools (except calculator/code_interpreter)") +def _migrate_legacy_system_tool_ids(db: Session) -> None: + """Rename legacy built-in system tool IDs to their canonical IDs.""" + changed = False + for legacy_id, canonical_id in TOOL_ID_ALIASES.items(): + if legacy_id == canonical_id: + continue + legacy_item = ( + db.query(ToolResource) + .filter(ToolResource.id == legacy_id) + .first() + ) + if not legacy_item or not bool(legacy_item.is_system): + continue + + canonical_item = ( + db.query(ToolResource) + .filter(ToolResource.id == canonical_id) + .first() + ) + if canonical_item: + db.delete(legacy_item) + changed = True + continue + + legacy_item.id = canonical_id + legacy_item.updated_at = datetime.utcnow() + changed = True + + if changed: + db.commit() + + def _seed_default_tools_if_empty(db: Session) -> None: """Ensure built-in tools exist in tool_resources without overriding custom edits.""" _ensure_tool_resource_schema(db) + _migrate_legacy_system_tool_ids(db) + existing_system_count = ( + db.query(ToolResource.id) + .filter(ToolResource.is_system.is_(True)) + .count() + ) + if existing_system_count > 0: + return existing_ids = { str(item[0]) for item in db.query(ToolResource.id).all() @@ -335,9 +390,10 @@ def list_available_tools(): @router.get("/list/{tool_id}") def get_tool_detail(tool_id: str): """获取工具详情""" - if tool_id not in TOOL_REGISTRY: + canonical_tool_id = normalize_tool_id(tool_id) + if canonical_tool_id not in TOOL_REGISTRY: raise HTTPException(status_code=404, detail="Tool not found") - return TOOL_REGISTRY[tool_id] + return TOOL_REGISTRY[canonical_tool_id] # ============ Tool Resource CRUD ============ @@ -369,6 +425,10 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)): """获取单个工具资源详情。""" _seed_default_tools_if_empty(db) item = db.query(ToolResource).filter(ToolResource.id == id).first() + if not item: + canonical_id = normalize_tool_id(id) + if canonical_id and canonical_id != id: + item = db.query(ToolResource).filter(ToolResource.id == canonical_id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") return item @@ -378,7 +438,7 @@ def get_tool_resource(id: str, db: Session = Depends(get_db)): def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db)): """创建自定义工具资源。""" _seed_default_tools_if_empty(db) - candidate_id = (data.id or "").strip() + candidate_id = normalize_tool_id((data.id or "").strip()) if candidate_id and db.query(ToolResource).filter(ToolResource.id == candidate_id).first(): raise HTTPException(status_code=400, detail="Tool ID already exists") @@ -413,7 +473,10 @@ def create_tool_resource(data: ToolResourceCreate, db: Session = Depends(get_db) def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depends(get_db)): """更新工具资源。""" _seed_default_tools_if_empty(db) + canonical_id = normalize_tool_id(id) item = db.query(ToolResource).filter(ToolResource.id == id).first() + if not item and canonical_id and canonical_id != id: + item = db.query(ToolResource).filter(ToolResource.id == canonical_id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") @@ -421,14 +484,14 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend new_category = update_data.get("category", item.category) new_http_url = update_data.get("http_url", item.http_url) - _validate_query_http_config(category=new_category, tool_id=id, http_url=new_http_url) + _validate_query_http_config(category=new_category, tool_id=item.id, http_url=new_http_url) if "http_method" in update_data: update_data["http_method"] = _normalize_http_method(update_data.get("http_method")) if "http_timeout_ms" in update_data and update_data.get("http_timeout_ms") is not None: update_data["http_timeout_ms"] = max(1000, int(update_data["http_timeout_ms"])) if "parameter_schema" in update_data: - update_data["parameter_schema"] = _normalize_parameter_schema(update_data.get("parameter_schema"), tool_id=id) + update_data["parameter_schema"] = _normalize_parameter_schema(update_data.get("parameter_schema"), tool_id=item.id) if "parameter_defaults" in update_data: update_data["parameter_defaults"] = _normalize_parameter_defaults(update_data.get("parameter_defaults")) if new_category != "system": @@ -447,7 +510,10 @@ def update_tool_resource(id: str, data: ToolResourceUpdate, db: Session = Depend def delete_tool_resource(id: str, db: Session = Depends(get_db)): """删除工具资源。""" _seed_default_tools_if_empty(db) + canonical_id = normalize_tool_id(id) item = db.query(ToolResource).filter(ToolResource.id == id).first() + if not item and canonical_id and canonical_id != id: + item = db.query(ToolResource).filter(ToolResource.id == canonical_id).first() if not item: raise HTTPException(status_code=404, detail="Tool resource not found") db.delete(item) diff --git a/api/app/schemas.py b/api/app/schemas.py index 91b81a5..9bf2274 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -280,6 +280,7 @@ class AssistantBase(BaseModel): name: str firstTurnMode: str = "bot_first" opener: str = "" + manualOpenerToolCalls: List[Dict[str, Any]] = [] generatedOpenerEnabled: bool = False openerAudioEnabled: bool = False prompt: str = "" @@ -310,6 +311,7 @@ class AssistantUpdate(BaseModel): name: Optional[str] = None firstTurnMode: Optional[str] = None opener: Optional[str] = None + manualOpenerToolCalls: Optional[List[Dict[str, Any]]] = None generatedOpenerEnabled: Optional[bool] = None openerAudioEnabled: Optional[bool] = None prompt: Optional[str] = None @@ -350,6 +352,7 @@ class AssistantRuntimeMetadata(BaseModel): firstTurnMode: str = "bot_first" greeting: str = "" generatedOpenerEnabled: bool = False + manualOpenerToolCalls: List[Dict[str, Any]] = Field(default_factory=list) output: Dict[str, Any] = Field(default_factory=dict) bargeIn: Dict[str, Any] = Field(default_factory=dict) services: Dict[str, Dict[str, Any]] = Field(default_factory=dict) diff --git a/api/docs/tools.md b/api/docs/tools.md index 892b8a2..08cd5a2 100644 --- a/api/docs/tools.md +++ b/api/docs/tools.md @@ -24,7 +24,7 @@ | turn_off_camera | 关闭摄像头 | system | 执行关闭摄像头命令 | | increase_volume | 调高音量 | system | 提升设备音量 | | decrease_volume | 调低音量 | system | 降低设备音量 | -| voice_message_prompt | 语音消息提示 | system | 播报一条语音提示消息 | +| voice_msg_prompt | 语音消息提示 | system | 播报一条语音提示消息 | | text_msg_prompt | 文本消息提示 | system | 显示一条文本弹窗提示 | | voice_choice_prompt | 语音选项提示 | system | 播报问题并展示可选项,等待用户选择 | | text_choice_prompt | 文本选项提示 | system | 显示文本选项弹窗并等待用户选择 | diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py index adb7d7c..0d880ef 100644 --- a/api/tests/test_assistants.py +++ b/api/tests/test_assistants.py @@ -21,6 +21,7 @@ class TestAssistantAPI: data = response.json() assert data["name"] == sample_assistant_data["name"] assert data["opener"] == sample_assistant_data["opener"] + assert data["manualOpenerToolCalls"] == [] assert data["prompt"] == sample_assistant_data["prompt"] assert data["language"] == sample_assistant_data["language"] assert data["voiceOutputEnabled"] is True @@ -67,6 +68,9 @@ class TestAssistantAPI: "prompt": "You are an updated assistant.", "speed": 1.5, "voiceOutputEnabled": False, + "manualOpenerToolCalls": [ + {"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}} + ], } response = client.put(f"/api/assistants/{assistant_id}", json=update_data) assert response.status_code == 200 @@ -75,6 +79,9 @@ class TestAssistantAPI: assert data["prompt"] == "You are an updated assistant." assert data["speed"] == 1.5 assert data["voiceOutputEnabled"] is False + assert data["manualOpenerToolCalls"] == [ + {"toolName": "text_msg_prompt", "arguments": {"msg": "请选择服务类型"}} + ] def test_delete_assistant(self, client, sample_assistant_data): """Test deleting an assistant""" @@ -205,6 +212,7 @@ class TestAssistantAPI: "voice": voice_id, "prompt": "runtime prompt", "opener": "runtime opener", + "manualOpenerToolCalls": [{"toolName": "text_msg_prompt", "arguments": {"msg": "欢迎"}}], "speed": 1.1, }) assistant_resp = client.post("/api/assistants", json=sample_assistant_data) @@ -217,8 +225,10 @@ class TestAssistantAPI: assert payload["assistantId"] == assistant_id metadata = payload["sessionStartMetadata"] - assert metadata["systemPrompt"] == "runtime prompt" + assert metadata["systemPrompt"].startswith("runtime prompt") + assert "Tool usage policy:" in metadata["systemPrompt"] assert metadata["greeting"] == "runtime opener" + assert metadata["manualOpenerToolCalls"] == [{"toolName": "text_msg_prompt", "arguments": {"msg": "欢迎"}}] assert metadata["services"]["llm"]["model"] == sample_llm_model_data["model_name"] assert metadata["services"]["asr"]["model"] == sample_asr_model_data["model_name"] assert metadata["services"]["asr"]["baseUrl"] == sample_asr_model_data["base_url"] @@ -239,8 +249,10 @@ class TestAssistantAPI: assert payload["assistantId"] == assistant_id assert payload["assistant"]["assistantId"] == assistant_id assert payload["assistant"]["configVersionId"].startswith(f"asst_{assistant_id}_") - assert payload["assistant"]["systemPrompt"] == sample_assistant_data["prompt"] - assert payload["sessionStartMetadata"]["systemPrompt"] == sample_assistant_data["prompt"] + assert payload["assistant"]["systemPrompt"].startswith(sample_assistant_data["prompt"]) + assert "Tool usage policy:" in payload["assistant"]["systemPrompt"] + assert payload["sessionStartMetadata"]["systemPrompt"].startswith(sample_assistant_data["prompt"]) + assert "Tool usage policy:" in payload["sessionStartMetadata"]["systemPrompt"] assert payload["sessionStartMetadata"]["history"]["assistantId"] == assistant_id def test_runtime_config_resolves_selected_tools_into_runtime_definitions(self, client, sample_assistant_data): @@ -263,6 +275,30 @@ class TestAssistantAPI: assert by_name["calculator"]["function"]["parameters"]["type"] == "object" assert "expression" in by_name["calculator"]["function"]["parameters"]["properties"] + def test_runtime_config_normalizes_legacy_voice_message_prompt_tool_id(self, client, sample_assistant_data): + sample_assistant_data["tools"] = ["voice_message_prompt"] + sample_assistant_data["manualOpenerToolCalls"] = [ + {"toolName": "voice_message_prompt", "arguments": {"msg": "您好"}} + ] + assistant_resp = client.post("/api/assistants", json=sample_assistant_data) + assert assistant_resp.status_code == 200 + assistant_payload = assistant_resp.json() + assistant_id = assistant_payload["id"] + assert assistant_payload["tools"] == ["voice_msg_prompt"] + assert assistant_payload["manualOpenerToolCalls"] == [ + {"toolName": "voice_msg_prompt", "arguments": {"msg": "您好"}} + ] + + runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config") + assert runtime_resp.status_code == 200 + metadata = runtime_resp.json()["sessionStartMetadata"] + tools = metadata["tools"] + by_name = {item["function"]["name"]: item for item in tools} + assert "voice_msg_prompt" in by_name + assert metadata["manualOpenerToolCalls"] == [ + {"toolName": "voice_msg_prompt", "arguments": {"msg": "您好"}} + ] + def test_runtime_config_text_mode_when_voice_output_disabled(self, client, sample_assistant_data): sample_assistant_data["voiceOutputEnabled"] = False assistant_resp = client.post("/api/assistants", json=sample_assistant_data) diff --git a/api/tests/test_tools.py b/api/tests/test_tools.py index daa1e85..9662e0d 100644 --- a/api/tests/test_tools.py +++ b/api/tests/test_tools.py @@ -21,6 +21,7 @@ class TestToolsAPI: assert "turn_off_camera" in tools assert "increase_volume" in tools assert "decrease_volume" in tools + assert "voice_msg_prompt" in tools assert "calculator" in tools def test_get_tool_detail(self, client): @@ -36,6 +37,14 @@ class TestToolsAPI: response = client.get("/api/tools/list/non-existent-tool") assert response.status_code == 404 + def test_get_tool_detail_legacy_alias(self, client): + """Legacy tool id should resolve to canonical tool detail.""" + response = client.get("/api/tools/list/voice_message_prompt") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "语音消息提示" + assert "msg" in data["parameters"]["properties"] + def test_health_check(self, client): """Test health check endpoint""" response = client.get("/api/tools/health") @@ -281,6 +290,7 @@ class TestToolResourceCRUD: assert payload["total"] >= 1 ids = [item["id"] for item in payload["list"]] assert "calculator" in ids + assert "voice_msg_prompt" in ids calculator = next((item for item in payload["list"] if item["id"] == "calculator"), None) assert calculator is not None assert calculator["parameter_schema"]["type"] == "object" diff --git a/engine/core/duplex_pipeline.py b/engine/core/duplex_pipeline.py index f2d96e2..fa903af 100644 --- a/engine/core/duplex_pipeline.py +++ b/engine/core/duplex_pipeline.py @@ -146,8 +146,8 @@ class DuplexPipeline: "required": [], }, }, - "voice_message_prompt": { - "name": "voice_message_prompt", + "voice_msg_prompt": { + "name": "voice_msg_prompt", "description": "Speak a message prompt on client side", "parameters": { "type": "object", @@ -238,11 +238,21 @@ class DuplexPipeline: "turn_off_camera", "increase_volume", "decrease_volume", - "voice_message_prompt", + "voice_msg_prompt", "text_msg_prompt", "voice_choice_prompt", "text_choice_prompt", }) + _TOOL_NAME_ALIASES = { + "voice_message_prompt": "voice_msg_prompt", + } + + @classmethod + def _normalize_tool_name(cls, raw_name: Any) -> str: + name = str(raw_name or "").strip() + if not name: + return "" + return cls._TOOL_NAME_ALIASES.get(name, name) def __init__( self, @@ -369,6 +379,7 @@ class DuplexPipeline: self._runtime_first_turn_mode: str = "bot_first" self._runtime_greeting: Optional[str] = None self._runtime_generated_opener_enabled: Optional[bool] = None + self._runtime_manual_opener_tool_calls: List[Any] = [] self._runtime_opener_audio: Dict[str, Any] = {} self._runtime_barge_in_enabled: Optional[bool] = None self._runtime_barge_in_min_duration_ms: Optional[int] = None @@ -463,6 +474,9 @@ class DuplexPipeline: generated_opener_flag = self._coerce_bool(metadata.get("generatedOpenerEnabled")) if generated_opener_flag is not None: self._runtime_generated_opener_enabled = generated_opener_flag + if "manualOpenerToolCalls" in metadata: + manual_calls = metadata.get("manualOpenerToolCalls") + self._runtime_manual_opener_tool_calls = manual_calls if isinstance(manual_calls, list) else [] services = metadata.get("services") or {} if isinstance(services, dict): @@ -571,6 +585,10 @@ class DuplexPipeline: "tools": { "allowlist": self._resolved_tool_allowlist(), }, + "opener": { + "generated": self._generated_opener_enabled(), + "manualToolCallCount": len(self._resolved_manual_opener_tool_calls()), + }, "tracks": { "audio_in": self.track_audio_in, "audio_out": self.track_audio_out, @@ -965,6 +983,11 @@ class DuplexPipeline: logger.info("Initial generated opener started with tool-calling path") return + if not self._generated_opener_enabled() and self._resolved_manual_opener_tool_calls(): + self._start_turn() + self._start_response() + await self._execute_manual_opener_tool_calls() + greeting_to_speak = self.conversation.greeting if self._generated_opener_enabled(): generated_greeting = await self._generate_runtime_greeting() @@ -975,8 +998,10 @@ class DuplexPipeline: if not greeting_to_speak: return - self._start_turn() - self._start_response() + if not self._current_turn_id: + self._start_turn() + if not self._current_response_id: + self._start_response() await self._send_event( ev( "assistant.response.final", @@ -1551,7 +1576,7 @@ class DuplexPipeline: seen: set[str] = set() for item in self._runtime_tools: if isinstance(item, str): - tool_name = item.strip() + tool_name = self._normalize_tool_name(item) if not tool_name or tool_name in seen: continue seen.add(tool_name) @@ -1585,7 +1610,7 @@ class DuplexPipeline: fn = item.get("function") if isinstance(fn, dict) and fn.get("name"): - fn_name = str(fn.get("name")).strip() + fn_name = self._normalize_tool_name(fn.get("name")) if not fn_name or fn_name in seen: continue seen.add(fn_name) @@ -1602,7 +1627,7 @@ class DuplexPipeline: continue if item.get("name"): - item_name = str(item.get("name")).strip() + item_name = self._normalize_tool_name(item.get("name")) if not item_name or item_name in seen: continue seen.add(item_name) @@ -1622,7 +1647,7 @@ class DuplexPipeline: result: Dict[str, str] = {} for item in self._runtime_tools: if isinstance(item, str): - name = item.strip() + name = self._normalize_tool_name(item) if name in self._DEFAULT_CLIENT_EXECUTORS: result[name] = "client" continue @@ -1630,9 +1655,9 @@ class DuplexPipeline: continue fn = item.get("function") if isinstance(fn, dict) and fn.get("name"): - name = str(fn.get("name")) + name = self._normalize_tool_name(fn.get("name")) else: - name = str(item.get("name") or "").strip() + name = self._normalize_tool_name(item.get("name")) if not name: continue executor = str(item.get("executor") or item.get("run_on") or "").strip().lower() @@ -1647,9 +1672,9 @@ class DuplexPipeline: continue fn = item.get("function") if isinstance(fn, dict) and fn.get("name"): - name = str(fn.get("name")).strip() + name = self._normalize_tool_name(fn.get("name")) else: - name = str(item.get("name") or "").strip() + name = self._normalize_tool_name(item.get("name")) if not name: continue raw_defaults = item.get("defaultArgs") @@ -1666,9 +1691,9 @@ class DuplexPipeline: continue fn = item.get("function") if isinstance(fn, dict) and fn.get("name"): - name = str(fn.get("name")).strip() + name = self._normalize_tool_name(fn.get("name")) else: - name = str(item.get("name") or "").strip() + name = self._normalize_tool_name(item.get("name")) if not name: continue raw_wait = item.get("waitForResponse") @@ -1685,12 +1710,12 @@ class DuplexPipeline: continue fn = item.get("function") if isinstance(fn, dict) and fn.get("name"): - alias = str(fn.get("name")).strip() + alias = self._normalize_tool_name(fn.get("name")) else: - alias = str(item.get("name") or "").strip() + alias = self._normalize_tool_name(item.get("name")) if not alias: continue - tool_id = str(item.get("toolId") or item.get("tool_id") or alias).strip() + tool_id = self._normalize_tool_name(item.get("toolId") or item.get("tool_id") or alias) if tool_id: result[alias] = tool_id return result @@ -1702,9 +1727,9 @@ class DuplexPipeline: continue fn = item.get("function") if isinstance(fn, dict) and fn.get("name"): - name = str(fn.get("name")).strip() + name = self._normalize_tool_name(fn.get("name")) else: - name = str(item.get("name") or "").strip() + name = self._normalize_tool_name(item.get("name")) if not name: continue display_name = str( @@ -1714,7 +1739,7 @@ class DuplexPipeline: ).strip() if display_name: result[name] = display_name - tool_id = str(item.get("toolId") or item.get("tool_id") or "").strip() + tool_id = self._normalize_tool_name(item.get("toolId") or item.get("tool_id") or "") if tool_id: result[tool_id] = display_name return result @@ -1723,7 +1748,7 @@ class DuplexPipeline: names: set[str] = set() for item in self._runtime_tools: if isinstance(item, str): - name = item.strip() + name = self._normalize_tool_name(item) if name: names.add(name) continue @@ -1731,25 +1756,57 @@ class DuplexPipeline: continue fn = item.get("function") if isinstance(fn, dict) and fn.get("name"): - names.add(str(fn.get("name")).strip()) + names.add(self._normalize_tool_name(fn.get("name"))) elif item.get("name"): - names.add(str(item.get("name")).strip()) + names.add(self._normalize_tool_name(item.get("name"))) return sorted([name for name in names if name]) + def _resolved_manual_opener_tool_calls(self) -> List[Dict[str, Any]]: + result: List[Dict[str, Any]] = [] + for item in self._runtime_manual_opener_tool_calls: + if not isinstance(item, dict): + continue + tool_name = self._normalize_tool_name(str( + item.get("toolName") + or item.get("tool_name") + or item.get("name") + or "" + ).strip()) + if not tool_name: + continue + args_raw = item.get("arguments") + args: Dict[str, Any] = {} + if isinstance(args_raw, dict): + args = dict(args_raw) + elif isinstance(args_raw, str): + text_value = args_raw.strip() + if text_value: + try: + parsed = json.loads(text_value) + if isinstance(parsed, dict): + args = parsed + except Exception: + logger.warning(f"[OpenerTool] ignore invalid JSON args for tool={tool_name}") + result.append({"toolName": tool_name, "arguments": args}) + return result[:8] + def _tool_name(self, tool_call: Dict[str, Any]) -> str: fn = tool_call.get("function") if isinstance(fn, dict): - return str(fn.get("name") or "").strip() + return self._normalize_tool_name(fn.get("name")) return "" def _tool_id_for_name(self, tool_name: str) -> str: - return str(self._runtime_tool_id_map.get(tool_name) or tool_name).strip() + normalized = self._normalize_tool_name(tool_name) + return self._normalize_tool_name(self._runtime_tool_id_map.get(normalized) or normalized) def _tool_display_name(self, tool_name: str) -> str: - return str(self._runtime_tool_display_names.get(tool_name) or tool_name).strip() + normalized = self._normalize_tool_name(tool_name) + return str(self._runtime_tool_display_names.get(normalized) or normalized).strip() def _tool_wait_for_response(self, tool_name: str) -> bool: - return bool(self._runtime_tool_wait_for_response.get(tool_name, False)) + normalized = self._normalize_tool_name(tool_name) + return bool(self._runtime_tool_wait_for_response.get(normalized, False)) def _tool_executor(self, tool_call: Dict[str, Any]) -> str: name = self._tool_name(tool_call) @@ -1774,7 +1831,8 @@ class DuplexPipeline: return {} def _apply_tool_default_args(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]: - defaults = self._runtime_tool_default_args.get(tool_name) + normalized_tool_name = self._normalize_tool_name(tool_name) + defaults = self._runtime_tool_default_args.get(normalized_tool_name) if not isinstance(defaults, dict) or not defaults: return args merged = dict(defaults) @@ -1782,6 +1840,84 @@ class DuplexPipeline: merged.update(args) return merged + async def _execute_manual_opener_tool_calls(self) -> None: + calls = self._resolved_manual_opener_tool_calls() + if not calls: + return + + for call in calls: + tool_name = str(call.get("toolName") or "").strip() + if not tool_name: + continue + tool_id = self._tool_id_for_name(tool_name) + tool_display_name = self._tool_display_name(tool_name) or tool_name + tool_arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {} + merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments) + call_id = f"call_opener_{uuid.uuid4().hex[:10]}" + wait_for_response = self._tool_wait_for_response(tool_name) + tool_call = { + "id": call_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": json.dumps(merged_tool_arguments, ensure_ascii=False), + }, + } + executor = self._tool_executor(tool_call) + + await self._send_event( + { + **ev( + "assistant.tool_call", + trackId=self.track_audio_out, + tool_call_id=call_id, + tool_name=tool_name, + tool_id=tool_id, + tool_display_name=tool_display_name, + wait_for_response=wait_for_response, + arguments=merged_tool_arguments, + executor=executor, + timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000), + tool_call={**tool_call, "executor": executor, "wait_for_response": wait_for_response}, + ) + }, + priority=22, + ) + logger.info( + f"[OpenerTool] execute name={tool_name} call_id={call_id} executor={executor} " + f"wait_for_response={wait_for_response}" + ) + + if executor == "client": + self._pending_client_tool_call_ids.add(call_id) + if wait_for_response: + result = await self._wait_for_single_tool_result(call_id) + await self._emit_tool_result(result, source="client") + continue + + call_for_executor = dict(tool_call) + fn_for_executor = ( + dict(call_for_executor.get("function")) + if isinstance(call_for_executor.get("function"), dict) + else None + ) + if isinstance(fn_for_executor, dict): + fn_for_executor["name"] = tool_id + call_for_executor["function"] = fn_for_executor + try: + result = await asyncio.wait_for( + self._server_tool_executor(call_for_executor), + timeout=self._SERVER_TOOL_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + result = { + "tool_call_id": call_id, + "name": tool_name, + "output": {"message": "server tool timeout"}, + "status": {"code": 504, "message": "server_tool_timeout"}, + } + await self._emit_tool_result(result, source="server") + def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]: status = result.get("status") if isinstance(result.get("status"), dict) else {} status_code = int(status.get("code") or 0) if status else 0 diff --git a/engine/core/session.py b/engine/core/session.py index f7042fc..1603bda 100644 --- a/engine/core/session.py +++ b/engine/core/session.py @@ -66,6 +66,7 @@ class Session: "firstTurnMode", "greeting", "generatedOpenerEnabled", + "manualOpenerToolCalls", "systemPrompt", "output", "bargeIn", @@ -973,6 +974,7 @@ class Session: passthrough_keys = { "firstTurnMode", "generatedOpenerEnabled", + "manualOpenerToolCalls", "output", "bargeIn", "knowledgeBaseId", diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index f3285be..d42962a 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -282,6 +282,74 @@ async def test_generated_opener_uses_tool_capable_turn_when_tools_available(monk assert called.get("user_text") == "" +@pytest.mark.asyncio +async def test_manual_opener_tool_calls_emit_assistant_tool_call(monkeypatch): + pipeline, events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]]) + pipeline.apply_runtime_overrides( + { + "generatedOpenerEnabled": False, + "greeting": "你好,欢迎来电", + "output": {"mode": "text"}, + "tools": [ + { + "type": "function", + "executor": "client", + "waitForResponse": False, + "function": { + "name": "text_msg_prompt", + "description": "Show prompt dialog", + "parameters": {"type": "object", "properties": {"msg": {"type": "string"}}}, + }, + } + ], + "manualOpenerToolCalls": [ + {"toolName": "text_msg_prompt", "arguments": {"msg": "请先选择业务类型"}} + ], + } + ) + + await pipeline.emit_initial_greeting() + + tool_events = [event for event in events if event.get("type") == "assistant.tool_call"] + assert len(tool_events) == 1 + assert tool_events[0].get("tool_name") == "text_msg_prompt" + assert tool_events[0].get("arguments") == {"msg": "请先选择业务类型"} + + +@pytest.mark.asyncio +async def test_manual_opener_legacy_voice_message_prompt_is_normalized(monkeypatch): + pipeline, events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]]) + pipeline.apply_runtime_overrides( + { + "generatedOpenerEnabled": False, + "greeting": "", + "output": {"mode": "text"}, + "tools": [ + { + "type": "function", + "executor": "client", + "waitForResponse": False, + "function": { + "name": "voice_message_prompt", + "description": "Speak prompt", + "parameters": {"type": "object", "properties": {"msg": {"type": "string"}}}, + }, + } + ], + "manualOpenerToolCalls": [ + {"toolName": "voice_message_prompt", "arguments": {"msg": "您好"}} + ], + } + ) + + await pipeline.emit_initial_greeting() + + tool_events = [event for event in events if event.get("type") == "assistant.tool_call"] + assert len(tool_events) == 1 + assert tool_events[0].get("tool_name") == "voice_msg_prompt" + assert tool_events[0].get("arguments") == {"msg": "您好"} + + @pytest.mark.asyncio async def test_ws_message_parses_tool_call_results(): msg = parse_client_message( diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index c85ada5..266bbe6 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -3,7 +3,7 @@ import React, { useState, useEffect, useMemo, useRef } from 'react'; import { createPortal } from 'react-dom'; import { Plus, Search, Play, Square, Copy, Trash2, Mic, MessageSquare, Save, Video, PhoneOff, Camera, ArrowLeftRight, Send, Phone, Rocket, AlertTriangle, PhoneCall, CameraOff, Image, Images, CloudSun, Calendar, TrendingUp, Coins, Wrench, Globe, Terminal, X, ClipboardCheck, Sparkles, Volume2, Timer, ChevronDown, Database, Server, Zap, ExternalLink, Key, BrainCircuit, Ear, Book, Filter } from 'lucide-react'; import { Button, Input, Badge, Drawer, Dialog, Switch } from '../components/UI'; -import { ASRModel, Assistant, KnowledgeBase, LLMModel, TabValue, Tool, Voice } from '../types'; +import { ASRModel, Assistant, AssistantOpenerToolCall, KnowledgeBase, LLMModel, TabValue, Tool, Voice } from '../types'; import { createAssistant, deleteAssistant, fetchASRModels, fetchAssistantOpenerAudioPcmBuffer, fetchAssistants, fetchKnowledgeBases, fetchLLMModels, fetchTools, fetchVoices, generateAssistantOpenerAudio, previewVoice, updateAssistant as updateAssistantApi } from '../services/backendApi'; const isOpenAICompatibleVendor = (vendor?: string) => { @@ -85,6 +85,80 @@ const renderToolIcon = (icon: string) => { return map[icon] || ; }; +const TOOL_ID_ALIASES: Record = { + voice_message_prompt: 'voice_msg_prompt', +}; + +const normalizeToolId = (raw: unknown): string => { + const toolId = String(raw || '').trim(); + if (!toolId) return ''; + return TOOL_ID_ALIASES[toolId] || toolId; +}; + +const OPENER_TOOL_ARGUMENT_TEMPLATES: Record> = { + text_msg_prompt: { + msg: '您好,请先描述您要咨询的问题。', + }, + voice_msg_prompt: { + msg: '您好,请先描述您要咨询的问题。', + }, + text_choice_prompt: { + question: '请选择需要办理的业务', + options: [ + { id: 'billing', label: '账单咨询', value: 'billing' }, + { id: 'repair', label: '故障报修', value: 'repair' }, + { id: 'manual', label: '人工客服', value: 'manual' }, + ], + }, + voice_choice_prompt: { + question: '请选择需要办理的业务', + options: [ + { id: 'billing', label: '账单咨询', value: 'billing' }, + { id: 'repair', label: '故障报修', value: 'repair' }, + { id: 'manual', label: '人工客服', value: 'manual' }, + ], + voice_text: '请从以下选项中选择:账单咨询、故障报修或人工客服。', + }, +}; + +const normalizeManualOpenerToolCallsForRuntime = ( + calls: AssistantOpenerToolCall[] | undefined, + options?: { strictJson?: boolean } +): { calls: Array<{ toolName: string; arguments: Record }>; error?: string } => { + const strictJson = options?.strictJson === true; + const normalized: Array<{ toolName: string; arguments: Record }> = []; + if (!Array.isArray(calls)) return { calls: normalized }; + + for (let i = 0; i < calls.length; i += 1) { + const item = calls[i]; + if (!item || typeof item !== 'object') continue; + const toolName = normalizeToolId(item.toolName || ''); + if (!toolName) continue; + + const argsRaw = item.arguments; + let args: Record = {}; + if (argsRaw && typeof argsRaw === 'object' && !Array.isArray(argsRaw)) { + args = argsRaw as Record; + } else if (typeof argsRaw === 'string' && argsRaw.trim()) { + try { + const parsed = JSON.parse(argsRaw); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + args = parsed as Record; + } else if (strictJson) { + return { calls: normalized, error: `Opener tool call #${i + 1} arguments must be a JSON object.` }; + } + } catch { + if (strictJson) { + return { calls: normalized, error: `Opener tool call #${i + 1} has invalid JSON arguments.` }; + } + } + } + normalized.push({ toolName, arguments: args }); + } + + return { calls: normalized.slice(0, 8) }; +}; + export const AssistantsPage: React.FC = () => { const [assistants, setAssistants] = useState([]); const [voices, setVoices] = useState([]); @@ -173,6 +247,7 @@ export const AssistantsPage: React.FC = () => { name: 'New Assistant', firstTurnMode: 'bot_first', opener: '', + manualOpenerToolCalls: [], generatedOpenerEnabled: false, openerAudioEnabled: false, prompt: '', @@ -201,9 +276,17 @@ export const AssistantsPage: React.FC = () => { const handleSave = async () => { if (!selectedAssistant) return; + const normalizedManualCalls = normalizeManualOpenerToolCallsForRuntime(selectedAssistant.manualOpenerToolCalls, { strictJson: true }); + if (normalizedManualCalls.error) { + alert(normalizedManualCalls.error); + return; + } setSaveLoading(true); try { - const updated = await updateAssistantApi(selectedAssistant.id, selectedAssistant); + const updated = await updateAssistantApi(selectedAssistant.id, { + ...selectedAssistant, + manualOpenerToolCalls: normalizedManualCalls.calls, + }); setAssistants((prev) => prev.map((item) => (item.id === updated.id ? { ...item, ...updated } : item))); setPersistedAssistantSnapshotById((prev) => ({ ...prev, [updated.id]: serializeAssistant(updated) })); } catch (error) { @@ -436,17 +519,19 @@ export const AssistantsPage: React.FC = () => { const toggleTool = (toolId: string) => { if (!selectedAssistant) return; - const currentTools = selectedAssistant.tools || []; - const newTools = currentTools.includes(toolId) - ? currentTools.filter(id => id !== toolId) - : [...currentTools, toolId]; + const canonicalToolId = normalizeToolId(toolId); + const currentTools = (selectedAssistant.tools || []).map((id) => normalizeToolId(id)); + const newTools = currentTools.includes(canonicalToolId) + ? currentTools.filter(id => id !== canonicalToolId) + : [...currentTools, canonicalToolId]; updateAssistant('tools', newTools); }; const removeImportedTool = (e: React.MouseEvent, tool: Tool) => { e.stopPropagation(); if (!selectedAssistant) return; - updateAssistant('tools', (selectedAssistant.tools || []).filter((id) => id !== tool.id)); + const canonicalToolId = normalizeToolId(tool.id); + updateAssistant('tools', (selectedAssistant.tools || []).filter((id) => normalizeToolId(id) !== canonicalToolId)); }; const addHotword = () => { @@ -462,13 +547,76 @@ export const AssistantsPage: React.FC = () => { } }; + const addManualOpenerToolCall = () => { + if (!selectedAssistant) return; + const current = selectedAssistant.manualOpenerToolCalls || []; + if (current.length >= 8) return; + const fallbackTool = normalizeToolId( + (selectedAssistant.tools || []).find((id) => + tools.some((tool) => normalizeToolId(tool.id) === normalizeToolId(id) && tool.enabled !== false) + ) || '' + ); + updateAssistant('manualOpenerToolCalls', [ + ...current, + { + toolName: fallbackTool, + arguments: '{}', + }, + ]); + }; + + const updateManualOpenerToolCall = (index: number, patch: Partial) => { + if (!selectedAssistant) return; + const current = selectedAssistant.manualOpenerToolCalls || []; + if (index < 0 || index >= current.length) return; + const next = [...current]; + const normalizedPatch = { ...patch }; + if (Object.prototype.hasOwnProperty.call(normalizedPatch, 'toolName')) { + normalizedPatch.toolName = normalizeToolId(normalizedPatch.toolName || ''); + } + next[index] = { ...next[index], ...normalizedPatch }; + updateAssistant('manualOpenerToolCalls', next); + }; + + const removeManualOpenerToolCall = (index: number) => { + if (!selectedAssistant) return; + const current = selectedAssistant.manualOpenerToolCalls || []; + updateAssistant('manualOpenerToolCalls', current.filter((_, idx) => idx !== index)); + }; + + const applyManualOpenerToolTemplate = (index: number) => { + if (!selectedAssistant) return; + const current = selectedAssistant.manualOpenerToolCalls || []; + if (index < 0 || index >= current.length) return; + const toolName = normalizeToolId(current[index]?.toolName || ''); + const template = OPENER_TOOL_ARGUMENT_TEMPLATES[toolName]; + if (!template) return; + updateManualOpenerToolCall(index, { + arguments: JSON.stringify(template, null, 2), + }); + }; + const systemTools = tools.filter((t) => t.enabled !== false && t.category === 'system'); const queryTools = tools.filter((t) => t.enabled !== false && t.category === 'query'); - const selectedToolIds = selectedAssistant?.tools || []; - const activeSystemTools = systemTools.filter((tool) => selectedToolIds.includes(tool.id)); - const activeQueryTools = queryTools.filter((tool) => selectedToolIds.includes(tool.id)); - const availableSystemTools = systemTools.filter((tool) => !selectedToolIds.includes(tool.id)); - const availableQueryTools = queryTools.filter((tool) => !selectedToolIds.includes(tool.id)); + const selectedToolIds = (selectedAssistant?.tools || []).map((id) => normalizeToolId(id)); + const activeSystemTools = systemTools.filter((tool) => selectedToolIds.includes(normalizeToolId(tool.id))); + const activeQueryTools = queryTools.filter((tool) => selectedToolIds.includes(normalizeToolId(tool.id))); + const availableSystemTools = systemTools.filter((tool) => !selectedToolIds.includes(normalizeToolId(tool.id))); + const availableQueryTools = queryTools.filter((tool) => !selectedToolIds.includes(normalizeToolId(tool.id))); + const openerToolOptions = Array.from( + new Map( + tools + .filter( + (tool) => + tool.enabled !== false && + selectedToolIds.some((selectedId) => normalizeToolId(selectedId) === normalizeToolId(tool.id)) + ) + .map((tool) => { + const toolId = normalizeToolId(tool.id); + return [toolId, { id: toolId, label: `${tool.name} (${toolId})` }]; + }) + ).values() + ); const isExternalConfig = selectedAssistant?.configMode === 'dify' || selectedAssistant?.configMode === 'fastgpt'; const isNoneConfig = selectedAssistant?.configMode === 'none' || !selectedAssistant?.configMode; @@ -949,6 +1097,96 @@ export const AssistantsPage: React.FC = () => { )} )} + {selectedAssistant.generatedOpenerEnabled !== true && ( +
+
+ + +
+ {(selectedAssistant.manualOpenerToolCalls || []).length === 0 ? ( +
+ 未配置。可添加 text_msg_prompt / voice_msg_prompt 等工具作为开场动作。 +
+ ) : ( +
+ {(selectedAssistant.manualOpenerToolCalls || []).map((call, idx) => ( +
+
+ + +
+
+

参数 JSON

+ {OPENER_TOOL_ARGUMENT_TEMPLATES[normalizeToolId(call.toolName || '')] && ( + + )} +
+