From f3612a710d79a1278b700e5f05e8a67a71bc854c Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Wed, 11 Mar 2026 08:37:34 +0800 Subject: [PATCH] Add fastgpt as seperate assistant mode --- api/app/models.py | 1 + api/app/routers/assistants.py | 21 +- api/app/schemas.py | 2 + api/tests/test_assistants.py | 19 + engine/Dockerfile | 13 +- engine/adapters/control_plane/backend.py | 2 + engine/app/config.py | 2 +- engine/config/agents/dashscope.yaml | 2 +- engine/config/agents/dashscope_ontest.yaml | 47 ++ engine/config/agents/example.yaml | 3 +- engine/config/agents/volcengine_ontest.yaml | 67 +++ engine/examples/wav_client.py | 415 ++++++++------- engine/providers/factory/default.py | 15 +- engine/providers/llm/__init__.py | 13 + engine/providers/llm/fastgpt.py | 553 ++++++++++++++++++++ engine/providers/llm/fastgpt_types.py | 95 ++++ engine/requirements.txt | 3 + engine/runtime/pipeline/duplex.py | 112 +++- engine/runtime/ports/__init__.py | 9 +- engine/runtime/ports/llm.py | 15 + engine/tests/test_backend_adapters.py | 24 + engine/tests/test_fastgpt_provider.py | 411 +++++++++++++++ engine/tests/test_tool_call_flow.py | 167 ++++++ web/pages/Assistants.tsx | 528 ++++++++++++++++++- web/services/backendApi.ts | 3 + web/types.ts | 1 + 26 files changed, 2333 insertions(+), 210 deletions(-) create mode 100644 engine/config/agents/dashscope_ontest.yaml create mode 100644 engine/config/agents/volcengine_ontest.yaml create mode 100644 engine/providers/llm/fastgpt.py create mode 100644 engine/providers/llm/fastgpt_types.py create mode 100644 engine/tests/test_fastgpt_provider.py diff --git a/api/app/models.py b/api/app/models.py index 265f83d..aaad83c 100644 --- a/api/app/models.py +++ b/api/app/models.py @@ -133,6 +133,7 @@ class Assistant(Base): config_mode: Mapped[str] = mapped_column(String(32), default="platform") api_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) api_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + app_id: 
Mapped[Optional[str]] = mapped_column(String(255), nullable=True) # 模型关联 llm_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) asr_model_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) diff --git a/api/app/routers/assistants.py b/api/app/routers/assistants.py index f517458..c398cc0 100644 --- a/api/app/routers/assistants.py +++ b/api/app/routers/assistants.py @@ -129,6 +129,9 @@ def _ensure_assistant_schema(db: Session) -> None: if "asr_interim_enabled" not in columns: db.execute(text("ALTER TABLE assistants ADD COLUMN asr_interim_enabled BOOLEAN DEFAULT 0")) altered = True + if "app_id" not in columns: + db.execute(text("ALTER TABLE assistants ADD COLUMN app_id VARCHAR(255)")) + altered = True if altered: db.commit() @@ -297,7 +300,7 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s config_mode = str(assistant.config_mode or "platform").strip().lower() - if config_mode in {"dify", "fastgpt"}: + if config_mode == "dify": metadata["services"]["llm"] = { "provider": "openai", "model": "", @@ -308,6 +311,19 @@ def _resolve_runtime_metadata(db: Session, assistant: Assistant) -> tuple[Dict[s warnings.append(f"External LLM API URL is empty for mode: {assistant.config_mode}") if not (assistant.api_key or "").strip(): warnings.append(f"External LLM API key is empty for mode: {assistant.config_mode}") + elif config_mode == "fastgpt": + metadata["services"]["llm"] = { + "provider": "fastgpt", + "model": "fastgpt", + "apiKey": assistant.api_key, + "baseUrl": assistant.api_url, + } + if (assistant.app_id or "").strip(): + metadata["services"]["llm"]["appId"] = assistant.app_id + if not (assistant.api_url or "").strip(): + warnings.append(f"FastGPT API URL is empty for mode: {assistant.config_mode}") + if not (assistant.api_key or "").strip(): + warnings.append(f"FastGPT API key is empty for mode: {assistant.config_mode}") elif assistant.llm_model_id: llm = db.query(LLMModel).filter(LLMModel.id == 
assistant.llm_model_id).first() if llm: @@ -450,6 +466,7 @@ def assistant_to_dict(assistant: Assistant) -> dict: "configMode": assistant.config_mode, "apiUrl": assistant.api_url, "apiKey": assistant.api_key, + "appId": assistant.app_id, "llmModelId": assistant.llm_model_id, "asrModelId": assistant.asr_model_id, "embeddingModelId": assistant.embedding_model_id, @@ -472,6 +489,7 @@ def _apply_assistant_update(assistant: Assistant, update_data: dict) -> None: "generatedOpenerEnabled": "generated_opener_enabled", "apiUrl": "api_url", "apiKey": "api_key", + "appId": "app_id", "llmModelId": "llm_model_id", "asrModelId": "asr_model_id", "embeddingModelId": "embedding_model_id", @@ -666,6 +684,7 @@ def create_assistant(data: AssistantCreate, db: Session = Depends(get_db)): config_mode=data.configMode, api_url=data.apiUrl, api_key=data.apiKey, + app_id=data.appId, llm_model_id=data.llmModelId, asr_model_id=data.asrModelId, embedding_model_id=data.embeddingModelId, diff --git a/api/app/schemas.py b/api/app/schemas.py index cbce453..5778982 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -298,6 +298,7 @@ class AssistantBase(BaseModel): configMode: str = "platform" apiUrl: Optional[str] = None apiKey: Optional[str] = None + appId: Optional[str] = None # 模型关联 llmModelId: Optional[str] = None asrModelId: Optional[str] = None @@ -330,6 +331,7 @@ class AssistantUpdate(BaseModel): configMode: Optional[str] = None apiUrl: Optional[str] = None apiKey: Optional[str] = None + appId: Optional[str] = None llmModelId: Optional[str] = None asrModelId: Optional[str] = None embeddingModelId: Optional[str] = None diff --git a/api/tests/test_assistants.py b/api/tests/test_assistants.py index 7acbd30..eaab5b5 100644 --- a/api/tests/test_assistants.py +++ b/api/tests/test_assistants.py @@ -29,6 +29,7 @@ class TestAssistantAPI: assert data["generatedOpenerEnabled"] is False assert data["asrInterimEnabled"] is False assert data["botCannotBeInterrupted"] is False + assert data["appId"] 
is None assert "id" in data assert data["callCount"] == 0 @@ -419,3 +420,21 @@ class TestAssistantAPI: assert metadata["greeting"] == "" assert metadata["bargeIn"]["enabled"] is False assert metadata["bargeIn"]["minDurationMs"] == 900 + + def test_fastgpt_app_id_persists_and_flows_to_runtime(self, client, sample_assistant_data): + sample_assistant_data.update({ + "configMode": "fastgpt", + "apiUrl": "https://cloud.fastgpt.cn/api", + "apiKey": "fastgpt-key", + "appId": "app-fastgpt-123", + }) + assistant_resp = client.post("/api/assistants", json=sample_assistant_data) + assert assistant_resp.status_code == 200 + assistant_id = assistant_resp.json()["id"] + assert assistant_resp.json()["appId"] == "app-fastgpt-123" + + runtime_resp = client.get(f"/api/assistants/{assistant_id}/runtime-config") + assert runtime_resp.status_code == 200 + metadata = runtime_resp.json()["sessionStartMetadata"] + assert metadata["services"]["llm"]["provider"] == "fastgpt" + assert metadata["services"]["llm"]["appId"] == "app-fastgpt-123" diff --git a/engine/Dockerfile b/engine/Dockerfile index e6e5806..ab7b3d3 100644 --- a/engine/Dockerfile +++ b/engine/Dockerfile @@ -2,6 +2,11 @@ FROM python:3.12-slim WORKDIR /app +# Build this image from the project parent directory so both +# engine-v3/engine and fastgpt-python-sdk are available in the context. +# Example: +# docker build -f engine-v3/engine/Dockerfile -t engine-v3 . + # Install system dependencies for audio processing RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential \ @@ -12,11 +17,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && rm -rf /var/lib/apt/lists/* # Install Python dependencies -COPY requirements.txt . 
-RUN pip install --no-cache-dir -r requirements.txt +COPY engine-v3/engine/requirements.txt /tmp/requirements.txt +COPY fastgpt-python-sdk /deps/fastgpt-python-sdk +RUN pip install --no-cache-dir -r /tmp/requirements.txt \ + && pip install --no-cache-dir /deps/fastgpt-python-sdk # Copy application code -COPY . . +COPY engine-v3/engine /app # Create necessary directories RUN mkdir -p /app/logs /app/data/vad diff --git a/engine/adapters/control_plane/backend.py b/engine/adapters/control_plane/backend.py index 9f8914d..bc32cf1 100644 --- a/engine/adapters/control_plane/backend.py +++ b/engine/adapters/control_plane/backend.py @@ -214,6 +214,8 @@ class LocalYamlAssistantConfigAdapter(NullBackendAdapter): llm_runtime["apiKey"] = cls._as_str(llm.get("api_key")) if cls._as_str(llm.get("api_url")): llm_runtime["baseUrl"] = cls._as_str(llm.get("api_url")) + if cls._as_str(llm.get("app_id")): + llm_runtime["appId"] = cls._as_str(llm.get("app_id")) if llm_runtime: runtime["services"]["llm"] = llm_runtime diff --git a/engine/app/config.py b/engine/app/config.py index 8edf7ce..c5f8902 100644 --- a/engine/app/config.py +++ b/engine/app/config.py @@ -62,7 +62,7 @@ class Settings(BaseSettings): # LLM Configuration llm_provider: str = Field( default="openai", - description="LLM provider (openai, openai_compatible, siliconflow)" + description="LLM provider (openai, openai_compatible, siliconflow, fastgpt)" ) llm_api_url: Optional[str] = Field(default=None, description="LLM provider API base URL") llm_model: str = Field(default="gpt-4o-mini", description="LLM model name") diff --git a/engine/config/agents/dashscope.yaml b/engine/config/agents/dashscope.yaml index 3491d68..6cc77e9 100644 --- a/engine/config/agents/dashscope.yaml +++ b/engine/config/agents/dashscope.yaml @@ -40,7 +40,7 @@ agent: duplex: enabled: true - system_prompt: You are a helpful, friendly voice assistant. Keep your responses concise and conversational. 
+ system_prompt: 你是一个人工智能助手,你用简答语句回答,避免使用标点符号和emoji。 barge_in: min_duration_ms: 200 diff --git a/engine/config/agents/dashscope_ontest.yaml b/engine/config/agents/dashscope_ontest.yaml new file mode 100644 index 0000000..55db902 --- /dev/null +++ b/engine/config/agents/dashscope_ontest.yaml @@ -0,0 +1,47 @@ +# Agent behavior configuration for DashScope realtime ASR/TTS. +# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers). +# Infra/server/network settings should stay in .env. + +agent: + vad: + type: silero + model_path: data/vad/silero_vad.onnx + threshold: 0.5 + min_speech_duration_ms: 100 + eou_threshold_ms: 800 + + llm: + # provider: openai | openai_compatible | siliconflow + provider: openai_compatible + model: deepseek-v3 + temperature: 0.7 + api_key: sk-fc4d59b360475f53401a864db8ce0985010acc4e696723d20a90d6569f38d80a + api_url: https://api.qnaigc.com/v1 + + tts: + provider: dashscope + api_key: sk-391f5126d18345d497c6e8717c8c9ad7 + api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + model: qwen3-tts-flash-realtime + voice: Cherry + dashscope_mode: commit + speed: 1.0 + + asr: + provider: dashscope + api_key: sk-391f5126d18345d497c6e8717c8c9ad7 + api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + model: qwen3-asr-flash-realtime + interim_interval_ms: 500 + min_audio_ms: 300 + start_min_speech_ms: 160 + pre_speech_ms: 240 + final_tail_ms: 120 + + duplex: + enabled: true + system_prompt: 你是一个人工智能助手,你用简答语句回答,避免使用标点符号和emoji。 + + barge_in: + min_duration_ms: 200 + silence_tolerance_ms: 60 diff --git a/engine/config/agents/example.yaml b/engine/config/agents/example.yaml index e68b6f3..2aa750b 100644 --- a/engine/config/agents/example.yaml +++ b/engine/config/agents/example.yaml @@ -11,7 +11,7 @@ agent: eou_threshold_ms: 800 llm: - # provider: openai | openai_compatible | siliconflow + # provider: openai | openai_compatible | siliconflow | fastgpt provider: openai_compatible model: deepseek-v3 temperature: 0.7 @@ -73,3 +73,4 @@ 
agent: barge_in: min_duration_ms: 200 silence_tolerance_ms: 60 + diff --git a/engine/config/agents/volcengine_ontest.yaml b/engine/config/agents/volcengine_ontest.yaml new file mode 100644 index 0000000..181fa79 --- /dev/null +++ b/engine/config/agents/volcengine_ontest.yaml @@ -0,0 +1,67 @@ +# Agent behavior configuration (safe to edit per profile) +# This file only controls agent-side behavior (VAD/LLM/TTS/ASR providers). +# Infra/server/network settings should stay in .env. + +agent: + vad: + type: silero + model_path: data/vad/silero_vad.onnx + threshold: 0.5 + min_speech_duration_ms: 100 + eou_threshold_ms: 800 + + llm: + # provider: openai | openai_compatible | siliconflow + provider: openai_compatible + model: deepseek-v3 + temperature: 0.7 + # Required: no fallback. You can still reference env explicitly. + api_key: sk-fc4d59b360475f53401a864db8ce0985010acc4e696723d20a90d6569f38d80a + # Optional for OpenAI-compatible endpoints: + api_url: https://api.qnaigc.com/v1 + + tts: + # provider: edge | openai_compatible | siliconflow | dashscope + # dashscope defaults (if omitted): + # api_url: wss://dashscope.aliyuncs.com/api-ws/v1/realtime + # model: qwen3-tts-flash-realtime + # dashscope_mode: commit (engine splits) | server_commit (dashscope splits) + # note: dashscope_mode/mode is ONLY used when provider=dashscope. 
+ # volcengine defaults (if omitted): + provider: volcengine + api_url: https://openspeech.bytedance.com/api/v3/tts/unidirectional + resource_id: seed-tts-2.0 + app_id: 2931820332 + api_key: 4ustCTIpdCq8dE_msFrZvFn4nDpioIVo + speed: 1.1 + voice: zh_female_vv_uranus_bigtts + + asr: + provider: volcengine + api_url: wss://openspeech.bytedance.com/api/v3/sauc/bigmodel + app_id: 8607675070 + api_key: QiO0AptfmU0GLTSitwn7t5-zeo4gJ6K1 + resource_id: volc.bigasr.sauc.duration + uid: caller-1 + model: bigmodel + request_params: + end_window_size: 800 + force_to_speech_time: 1000 + enable_punc: true + enable_itn: false + enable_ddc: false + show_utterance: true + result_type: single + interim_interval_ms: 500 + min_audio_ms: 300 + start_min_speech_ms: 160 + pre_speech_ms: 240 + final_tail_ms: 120 + + duplex: + enabled: true + system_prompt: 你是一个人工智能助手,你用简答语句回答,避免使用标点符号和emoji。 + + barge_in: + min_duration_ms: 200 + silence_tolerance_ms: 60 diff --git a/engine/examples/wav_client.py b/engine/examples/wav_client.py index 7e4aef1..14b2587 100644 --- a/engine/examples/wav_client.py +++ b/engine/examples/wav_client.py @@ -3,13 +3,15 @@ WAV file client for testing duplex voice conversation. This client reads audio from a WAV file, sends it to the server, -and saves the AI's voice response to an output WAV file. +and saves a stereo WAV file with the input audio on the left channel +and the AI's voice response on the right channel. Usage: python examples/wav_client.py --input input.wav --output response.wav python examples/wav_client.py --input input.wav --output response.wav --url ws://localhost:8000/ws python examples/wav_client.py --input input.wav --output response.wav --wait-time 10 python wav_client.py --input ../data/audio_examples/two_utterances.wav -o response.wav + Requirements: pip install soundfile websockets numpy """ @@ -45,14 +47,14 @@ except ImportError: class WavFileClient: """ WAV file client for voice conversation testing. 
- + Features: - Read audio from WAV file - Send audio to WebSocket server - - Receive and save response audio + - Receive and save stereo conversation audio - Event logging """ - + def __init__( self, url: str, @@ -69,7 +71,7 @@ class WavFileClient: ): """ Initialize WAV file client. - + Args: url: WebSocket server URL input_file: Input WAV file path @@ -92,48 +94,51 @@ class WavFileClient: self.track_debug = track_debug self.tail_silence_ms = max(0, int(tail_silence_ms)) self.frame_bytes = 640 # 16k mono pcm_s16le, 20ms - + # WebSocket connection self.ws = None self.running = False - + # Audio buffers + self.input_audio = np.array([], dtype=np.int16) self.received_audio = bytearray() - + self.output_segments: list[dict[str, object]] = [] + self.current_output_segment: bytearray | None = None + # Statistics self.bytes_sent = 0 self.bytes_received = 0 - + # TTFB tracking (per response) self.send_start_time = None - self.response_start_time = None # set on each trackStart + self.response_start_time = None # set on each output.audio.start self.waiting_for_first_audio = False self.ttfb_ms = None # last TTFB for summary self.ttfb_list = [] # TTFB for each response - + # State tracking self.track_started = False self.track_ended = False self.send_completed = False self.session_ready = False - + # Events log self.events_log = [] - - def log_event(self, direction: str, message: str): + + def log_event(self, direction: str, message: str) -> None: """Log an event with timestamp.""" timestamp = time.time() - self.events_log.append({ - "timestamp": timestamp, - "direction": direction, - "message": message - }) - # Handle encoding errors on Windows + self.events_log.append( + { + "timestamp": timestamp, + "direction": direction, + "message": message, + } + ) try: print(f"{direction} {message}") except UnicodeEncodeError: - # Replace problematic characters for console output - safe_message = message.encode('ascii', errors='replace').decode('ascii') + safe_message = 
message.encode("ascii", errors="replace").decode("ascii") print(f"{direction} {safe_message}") @staticmethod @@ -152,119 +157,160 @@ class WavFileClient: query = dict(parse_qsl(parts.query, keep_blank_values=True)) query["assistant_id"] = self.assistant_id return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(query), parts.fragment)) - + + def _current_timeline_sample(self) -> int: + """Return current sample position relative to input send start.""" + if self.send_start_time is None: + return 0 + elapsed_seconds = max(0.0, time.time() - self.send_start_time) + return int(round(elapsed_seconds * self.sample_rate)) + + def _start_output_segment(self) -> None: + """Create a new assistant-audio segment if one is not active.""" + if self.current_output_segment is not None: + return + self.current_output_segment = bytearray() + self.output_segments.append( + { + "start_sample": self._current_timeline_sample(), + "audio": self.current_output_segment, + } + ) + + def _close_output_segment(self) -> None: + """Close the active assistant-audio segment, if any.""" + self.current_output_segment = None + + def _build_input_track(self) -> np.ndarray: + """Build the saved left channel using the streamed input audio.""" + input_track = self.input_audio.astype(np.int16, copy=True) + tail_samples = int(round(self.sample_rate * self.tail_silence_ms / 1000.0)) + if tail_samples <= 0: + return input_track + if input_track.size == 0: + return np.zeros(tail_samples, dtype=np.int16) + return np.concatenate((input_track, np.zeros(tail_samples, dtype=np.int16))) + + def _build_output_track(self) -> np.ndarray: + """Build the saved right channel using received assistant audio.""" + if not self.output_segments: + return np.zeros(0, dtype=np.int16) + + total_samples = max( + int(segment["start_sample"]) + (len(segment["audio"]) // 2) + for segment in self.output_segments + ) + mixed_track = np.zeros(total_samples, dtype=np.int32) + + for segment in self.output_segments: + 
start_sample = int(segment["start_sample"]) + segment_audio = np.frombuffer(bytes(segment["audio"]), dtype=np.int16).astype(np.int32) + if segment_audio.size == 0: + continue + end_sample = start_sample + segment_audio.size + mixed_track[start_sample:end_sample] += segment_audio + + np.clip(mixed_track, -32768, 32767, out=mixed_track) + return mixed_track.astype(np.int16) + async def connect(self) -> None: """Connect to WebSocket server.""" session_url = self._session_url() - self.log_event("→", f"Connecting to {session_url}...") + self.log_event("->", f"Connecting to {session_url}...") self.ws = await websockets.connect(session_url) self.running = True - self.log_event("←", "Connected!") + self.log_event("->", "Connected!") + + await self.send_command( + { + "type": "session.start", + "audio": { + "encoding": "pcm_s16le", + "sample_rate_hz": self.sample_rate, + "channels": 1, + }, + "metadata": { + "channel": self.channel, + "source": "wav_client", + }, + } + ) - await self.send_command({ - "type": "session.start", - "audio": { - "encoding": "pcm_s16le", - "sample_rate_hz": self.sample_rate, - "channels": 1 - }, - "metadata": { - "channel": self.channel, - "source": "wav_client", - }, - }) - async def send_command(self, cmd: dict) -> None: """Send JSON command to server.""" if self.ws: await self.ws.send(json.dumps(cmd)) - self.log_event("→", f"Command: {cmd.get('type', 'unknown')}") - + self.log_event("->", f"Command: {cmd.get('type', 'unknown')}") + async def send_hangup(self, reason: str = "Session complete") -> None: """Send hangup command.""" - await self.send_command({ - "type": "session.stop", - "reason": reason - }) - + await self.send_command({"type": "session.stop", "reason": reason}) + def load_wav_file(self) -> tuple[np.ndarray, int]: """ Load and prepare WAV file for sending. 
- + Returns: Tuple of (audio_data as int16 numpy array, original sample rate) """ if not self.input_file.exists(): raise FileNotFoundError(f"Input file not found: {self.input_file}") - - # Load audio file + audio_data, file_sample_rate = sf.read(self.input_file) - self.log_event("→", f"Loaded: {self.input_file}") - self.log_event("→", f" Original sample rate: {file_sample_rate} Hz") - self.log_event("→", f" Duration: {len(audio_data) / file_sample_rate:.2f}s") - - # Convert stereo to mono if needed + self.log_event("->", f"Loaded: {self.input_file}") + self.log_event("->", f" Original sample rate: {file_sample_rate} Hz") + self.log_event("->", f" Duration: {len(audio_data) / file_sample_rate:.2f}s") + if len(audio_data.shape) > 1: audio_data = audio_data.mean(axis=1) - self.log_event("→", " Converted stereo to mono") - - # Resample if needed + self.log_event("->", " Converted stereo to mono") + if file_sample_rate != self.sample_rate: - # Simple resampling using numpy duration = len(audio_data) / file_sample_rate num_samples = int(duration * self.sample_rate) indices = np.linspace(0, len(audio_data) - 1, num_samples) audio_data = np.interp(indices, np.arange(len(audio_data)), audio_data) - self.log_event("→", f" Resampled to {self.sample_rate} Hz") - - # Convert to int16 + self.log_event("->", f" Resampled to {self.sample_rate} Hz") + if audio_data.dtype != np.int16: - # Normalize to [-1, 1] if needed max_val = np.max(np.abs(audio_data)) if max_val > 1.0: audio_data = audio_data / max_val audio_data = (audio_data * 32767).astype(np.int16) - - self.log_event("→", f" Prepared: {len(audio_data)} samples ({len(audio_data)/self.sample_rate:.2f}s)") - + + self.log_event("->", f" Prepared: {len(audio_data)} samples ({len(audio_data) / self.sample_rate:.2f}s)") + self.input_audio = audio_data.copy() return audio_data, file_sample_rate - + async def audio_sender(self, audio_data: np.ndarray) -> None: """Send audio data to server in chunks.""" total_samples = len(audio_data) 
chunk_size = self.chunk_samples sent_samples = 0 - + self.send_start_time = time.time() - self.log_event("→", f"Starting audio transmission ({total_samples} samples)...") - + self.log_event("->", f"Starting audio transmission ({total_samples} samples)...") + while sent_samples < total_samples and self.running: - # Get next chunk end_sample = min(sent_samples + chunk_size, total_samples) chunk = audio_data[sent_samples:end_sample] chunk_bytes = chunk.tobytes() if len(chunk_bytes) % self.frame_bytes != 0: - # v1 audio framing requires 640-byte (20ms) PCM units. pad = self.frame_bytes - (len(chunk_bytes) % self.frame_bytes) chunk_bytes += b"\x00" * pad - - # Send to server + if self.ws: await self.ws.send(chunk_bytes) self.bytes_sent += len(chunk_bytes) - + sent_samples = end_sample - - # Progress logging (every 500ms worth of audio) + if self.verbose and sent_samples % (self.sample_rate // 2) == 0: progress = (sent_samples / total_samples) * 100 print(f" Sending: {progress:.0f}%", end="\r") - - # Delay to simulate real-time streaming - # Server expects audio at real-time pace for VAD/ASR to work properly + await asyncio.sleep(self.chunk_duration_ms / 1000) - # Add a short silence tail to help VAD/EOU close the final utterance. 
if self.tail_silence_ms > 0 and self.ws: tail_frames = max(1, self.tail_silence_ms // 20) silence = b"\x00" * self.frame_bytes @@ -272,56 +318,53 @@ class WavFileClient: await self.ws.send(silence) self.bytes_sent += len(silence) await asyncio.sleep(0.02) - self.log_event("→", f"Sent trailing silence: {self.tail_silence_ms}ms") - + self.log_event("->", f"Sent trailing silence: {self.tail_silence_ms}ms") + self.send_completed = True elapsed = time.time() - self.send_start_time - self.log_event("→", f"Audio transmission complete ({elapsed:.2f}s, {self.bytes_sent/1024:.1f} KB)") - + self.log_event("->", f"Audio transmission complete ({elapsed:.2f}s, {self.bytes_sent / 1024:.1f} KB)") + async def receiver(self) -> None: """Receive messages from server.""" try: while self.running: try: message = await asyncio.wait_for(self.ws.recv(), timeout=0.1) - + if isinstance(message, bytes): - # Audio data received self.bytes_received += len(message) self.received_audio.extend(message) - - # Calculate TTFB on first audio of each response + self._start_output_segment() + self.current_output_segment.extend(message) + if self.waiting_for_first_audio and self.response_start_time is not None: ttfb_ms = (time.time() - self.response_start_time) * 1000 self.ttfb_ms = ttfb_ms self.ttfb_list.append(ttfb_ms) self.waiting_for_first_audio = False - self.log_event("←", f"[TTFB] First audio latency: {ttfb_ms:.0f}ms") - - # Log progress + self.log_event("<-", f"[TTFB] First audio latency: {ttfb_ms:.0f}ms") + duration_ms = len(message) / (self.sample_rate * 2) * 1000 total_ms = len(self.received_audio) / (self.sample_rate * 2) * 1000 if self.verbose: - print(f"← Audio: +{duration_ms:.0f}ms (total: {total_ms:.0f}ms)", end="\r") - + print(f"<- Audio: +{duration_ms:.0f}ms (total: {total_ms:.0f}ms)", end="\r") else: - # JSON event event = json.loads(message) await self._handle_event(event) - + except asyncio.TimeoutError: continue except websockets.ConnectionClosed: - self.log_event("←", "Connection 
closed") + self.log_event("<-", "Connection closed") self.running = False break - + except asyncio.CancelledError: pass - except Exception as e: - self.log_event("!", f"Receiver error: {e}") + except Exception as exc: + self.log_event("!", f"Receiver error: {exc}") self.running = False - + async def _handle_event(self, event: dict) -> None: """Handle incoming event.""" event_type = event.get("type", "unknown") @@ -331,14 +374,14 @@ class WavFileClient: if event_type == "session.started": self.session_ready = True - self.log_event("←", f"Session ready!{ids}") + self.log_event("<-", f"Session ready!{ids}") elif event_type == "config.resolved": config = event.get("config", {}) - self.log_event("←", f"Config resolved (output={config.get('output', {})}){ids}") + self.log_event("<-", f"Config resolved (output={config.get('output', {})}){ids}") elif event_type == "input.speech_started": - self.log_event("←", f"Speech detected{ids}") + self.log_event("<-", f"Speech detected{ids}") elif event_type == "input.speech_stopped": - self.log_event("←", f"Silence detected{ids}") + self.log_event("<-", f"Silence detected{ids}") elif event_type == "transcript.delta": text = event.get("text", "") display_text = text[:60] + "..." 
if len(text) > 60 else text @@ -346,125 +389,128 @@ class WavFileClient: elif event_type == "transcript.final": text = event.get("text", "") print(" " * 80, end="\r") - self.log_event("←", f"→ You: {text}{ids}") + self.log_event("<-", f"You: {text}{ids}") elif event_type == "metrics.ttfb": latency_ms = event.get("latencyMs", 0) - self.log_event("←", f"[TTFB] Server latency: {latency_ms}ms") + self.log_event("<-", f"[TTFB] Server latency: {latency_ms}ms") elif event_type == "assistant.response.delta": text = event.get("text", "") if self.verbose and text: - self.log_event("←", f"LLM: {text}{ids}") + self.log_event("<-", f"LLM: {text}{ids}") elif event_type == "assistant.response.final": text = event.get("text", "") if text: - self.log_event("←", f"LLM Response (final): {text[:100]}{'...' if len(text) > 100 else ''}{ids}") + summary = text[:100] + ("..." if len(text) > 100 else "") + self.log_event("<-", f"LLM Response (final): {summary}{ids}") elif event_type == "output.audio.start": self.track_started = True self.response_start_time = time.time() self.waiting_for_first_audio = True - self.log_event("←", f"Bot started speaking{ids}") + self._close_output_segment() + self.log_event("<-", f"Bot started speaking{ids}") elif event_type == "output.audio.end": self.track_ended = True - self.log_event("←", f"Bot finished speaking{ids}") + self._close_output_segment() + self.log_event("<-", f"Bot finished speaking{ids}") elif event_type == "response.interrupted": - self.log_event("←", f"Bot interrupted!{ids}") + self._close_output_segment() + self.log_event("<-", f"Bot interrupted!{ids}") elif event_type == "error": self.log_event("!", f"Error: {event.get('message')}{ids}") elif event_type == "session.stopped": - self.log_event("←", f"Session stopped: {event.get('reason')}{ids}") + self.log_event("<-", f"Session stopped: {event.get('reason')}{ids}") self.running = False else: - self.log_event("←", f"Event: {event_type}{ids}") - + self.log_event("<-", f"Event: 
{event_type}{ids}") + def save_output_wav(self) -> None: - """Save received audio to output WAV file.""" - if not self.received_audio: - self.log_event("!", "No audio received to save") + """Save the conversation to a stereo WAV file.""" + input_track = self._build_input_track() + output_track = self._build_output_track() + + if input_track.size == 0 and output_track.size == 0: + self.log_event("!", "No audio available to save") return - - # Convert bytes to numpy array - audio_data = np.frombuffer(bytes(self.received_audio), dtype=np.int16) - - # Ensure output directory exists + + if not self.received_audio: + self.log_event("!", "No assistant audio received; saving silent right channel") + + total_samples = max(input_track.size, output_track.size) + if input_track.size < total_samples: + input_track = np.pad(input_track, (0, total_samples - input_track.size)) + if output_track.size < total_samples: + output_track = np.pad(output_track, (0, total_samples - output_track.size)) + + stereo_audio = np.column_stack((input_track, output_track)).astype(np.int16, copy=False) + self.output_file.parent.mkdir(parents=True, exist_ok=True) - - # Save using wave module for compatibility - with wave.open(str(self.output_file), 'wb') as wav_file: - wav_file.setnchannels(1) + + with wave.open(str(self.output_file), "wb") as wav_file: + wav_file.setnchannels(2) wav_file.setsampwidth(2) # 16-bit wav_file.setframerate(self.sample_rate) - wav_file.writeframes(audio_data.tobytes()) - - duration = len(audio_data) / self.sample_rate - self.log_event("→", f"Saved output: {self.output_file}") - self.log_event("→", f" Duration: {duration:.2f}s ({len(audio_data)} samples)") - self.log_event("→", f" Size: {len(self.received_audio)/1024:.1f} KB") - + wav_file.writeframes(stereo_audio.tobytes()) + + duration = total_samples / self.sample_rate + self.log_event("->", f"Saved stereo output: {self.output_file}") + self.log_event("->", f" Duration: {duration:.2f}s ({total_samples} samples/channel)") 
+ self.log_event("->", " Channels: left=input, right=assistant") + self.log_event("->", f" Size: {stereo_audio.nbytes / 1024:.1f} KB") + async def run(self) -> None: """Run the WAV file test.""" try: - # Load input WAV file audio_data, _ = self.load_wav_file() - - # Connect to server + await self.connect() - - # Start receiver task + receiver_task = asyncio.create_task(self.receiver()) - # Wait for session.started before streaming audio ready_start = time.time() while self.running and not self.session_ready: if time.time() - ready_start > 8.0: raise TimeoutError("Timeout waiting for session.started") await asyncio.sleep(0.05) - - # Send audio + await self.audio_sender(audio_data) - - # Wait for response - self.log_event("→", f"Waiting {self.wait_time}s for response...") - + + self.log_event("->", f"Waiting {self.wait_time}s for response...") + wait_start = time.time() while self.running and (time.time() - wait_start) < self.wait_time: - # Check if track has ended (response complete) if self.track_ended and self.send_completed: - # Give a little extra time for any remaining audio await asyncio.sleep(1.0) break await asyncio.sleep(0.1) - - # Cleanup + self.running = False receiver_task.cancel() - + try: await receiver_task except asyncio.CancelledError: pass - - # Save output + self.save_output_wav() - - # Print summary self._print_summary() - - except FileNotFoundError as e: - print(f"Error: {e}") + + except FileNotFoundError as exc: + print(f"Error: {exc}") sys.exit(1) except ConnectionRefusedError: print(f"Error: Could not connect to {self.url}") print("Make sure the server is running.") sys.exit(1) - except Exception as e: - print(f"Error: {e}") + except Exception as exc: + print(f"Error: {exc}") import traceback + traceback.print_exc() sys.exit(1) finally: await self.close() - - def _print_summary(self): + + def _print_summary(self) -> None: """Print session summary.""" print("\n" + "=" * 50) print("Session Summary") @@ -477,19 +523,20 @@ class WavFileClient: if 
len(self.ttfb_list) == 1: print(f" TTFB: {self.ttfb_list[0]:.0f} ms") else: - print(f" TTFB (per response): {', '.join(f'{t:.0f}ms' for t in self.ttfb_list)}") + values = ", ".join(f"{ttfb:.0f}ms" for ttfb in self.ttfb_list) + print(f" TTFB (per response): {values}") if self.received_audio: duration = len(self.received_audio) / (self.sample_rate * 2) print(f" Response duration: {duration:.2f}s") print("=" * 50) - + async def close(self) -> None: """Close the connection.""" self.running = False if self.ws: try: await self.ws.close() - except: + except Exception: pass @@ -498,67 +545,71 @@ async def main(): description="WAV file client for testing duplex voice conversation" ) parser.add_argument( - "--input", "-i", + "--input", + "-i", required=True, - help="Input WAV file path" + help="Input WAV file path", ) parser.add_argument( - "--output", "-o", + "--output", + "-o", required=True, - help="Output WAV file path for response" + help="Output WAV file path for stereo conversation audio", ) parser.add_argument( "--url", default="ws://localhost:8000/ws", - help="WebSocket server URL (default: ws://localhost:8000/ws)" + help="WebSocket server URL (default: ws://localhost:8000/ws)", ) parser.add_argument( "--sample-rate", type=int, default=16000, - help="Target sample rate for audio (default: 16000)" + help="Target sample rate for audio (default: 16000)", ) parser.add_argument( "--assistant-id", default="default", - help="Assistant identifier used in websocket query parameter" + help="Assistant identifier used in websocket query parameter", ) parser.add_argument( "--channel", default="wav_client", - help="Client channel name" + help="Client channel name", ) parser.add_argument( "--chunk-duration", type=int, default=20, - help="Chunk duration in ms for sending (default: 20)" + help="Chunk duration in ms for sending (default: 20)", ) parser.add_argument( - "--wait-time", "-w", + "--wait-time", + "-w", type=float, default=15.0, - help="Time to wait for response after 
sending (default: 15.0)" + help="Time to wait for response after sending (default: 15.0)", ) parser.add_argument( - "--verbose", "-v", + "--verbose", + "-v", action="store_true", - help="Enable verbose output" + help="Enable verbose output", ) parser.add_argument( "--track-debug", action="store_true", - help="Print event trackId for protocol debugging" + help="Print event trackId for protocol debugging", ) parser.add_argument( "--tail-silence-ms", type=int, default=800, - help="Trailing silence to send after WAV playback for EOU detection (default: 800)" + help="Trailing silence to send after WAV playback for EOU detection (default: 800)", ) - + args = parser.parse_args() - + client = WavFileClient( url=args.url, input_file=args.input, @@ -572,7 +623,7 @@ async def main(): track_debug=args.track_debug, tail_silence_ms=args.tail_silence_ms, ) - + await client.run() @@ -580,4 +631,4 @@ if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: - print("\nInterrupted by user") + print("\nInterrupted by user") \ No newline at end of file diff --git a/engine/providers/factory/default.py b/engine/providers/factory/default.py index de72af6..478d290 100644 --- a/engine/providers/factory/default.py +++ b/engine/providers/factory/default.py @@ -28,7 +28,7 @@ from providers.tts.volcengine import VolcengineTTSService _OPENAI_COMPATIBLE_PROVIDERS = {"openai_compatible", "openai-compatible", "siliconflow"} _DASHSCOPE_PROVIDERS = {"dashscope"} _VOLCENGINE_PROVIDERS = {"volcengine"} -_SUPPORTED_LLM_PROVIDERS = {"openai", *_OPENAI_COMPATIBLE_PROVIDERS} +_SUPPORTED_LLM_PROVIDERS = {"openai", "fastgpt", *_OPENAI_COMPATIBLE_PROVIDERS} class DefaultRealtimeServiceFactory(RealtimeServiceFactory): @@ -58,7 +58,18 @@ class DefaultRealtimeServiceFactory(RealtimeServiceFactory): def create_llm_service(self, spec: LLMServiceSpec) -> LLMPort: provider = self._normalize_provider(spec.provider) - if provider in _SUPPORTED_LLM_PROVIDERS and spec.api_key: + if provider == 
"fastgpt" and spec.api_key and spec.base_url: + from providers.llm.fastgpt import FastGPTLLMService + + return FastGPTLLMService( + api_key=spec.api_key, + base_url=spec.base_url, + app_id=spec.app_id, + model=spec.model, + system_prompt=spec.system_prompt, + ) + + if provider in _SUPPORTED_LLM_PROVIDERS and provider != "fastgpt" and spec.api_key: return OpenAILLMService( api_key=spec.api_key, base_url=spec.base_url, diff --git a/engine/providers/llm/__init__.py b/engine/providers/llm/__init__.py index 3258a10..528d1e1 100644 --- a/engine/providers/llm/__init__.py +++ b/engine/providers/llm/__init__.py @@ -1 +1,14 @@ """LLM providers.""" + +from providers.llm.openai import MockLLMService, OpenAILLMService + +try: # pragma: no cover - import depends on optional sibling SDK + from providers.llm.fastgpt import FastGPTLLMService +except Exception: # pragma: no cover - provider remains lazily available via factory + FastGPTLLMService = None # type: ignore[assignment] + +__all__ = [ + "FastGPTLLMService", + "MockLLMService", + "OpenAILLMService", +] diff --git a/engine/providers/llm/fastgpt.py b/engine/providers/llm/fastgpt.py new file mode 100644 index 0000000..a48814b --- /dev/null +++ b/engine/providers/llm/fastgpt.py @@ -0,0 +1,553 @@ +"""FastGPT-backed LLM provider.""" + +from __future__ import annotations + +import asyncio +import json +import uuid +from typing import Any, AsyncIterator, Dict, List, Optional + +from loguru import logger + +from providers.common.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState +from providers.llm.fastgpt_types import ( + FastGPTConversationState, + FastGPTField, + FastGPTInteractivePrompt, + FastGPTOption, + FastGPTPendingInteraction, +) + +try: + from fastgpt_client import AsyncChatClient, aiter_stream_events +except Exception as exc: # pragma: no cover - exercised indirectly via connect() + AsyncChatClient = None # type: ignore[assignment] + aiter_stream_events = None # type: ignore[assignment] + 
_FASTGPT_IMPORT_ERROR: Optional[Exception] = exc +else: # pragma: no cover - import success depends on local environment + _FASTGPT_IMPORT_ERROR = None + + +class FastGPTLLMService(BaseLLMService): + """LLM provider that delegates orchestration to FastGPT.""" + + INTERACTIVE_TOOL_NAME = "fastgpt.interactive" + INTERACTIVE_TIMEOUT_MS = 300000 + + def __init__( + self, + *, + api_key: str, + base_url: str, + app_id: Optional[str] = None, + model: str = "fastgpt", + system_prompt: Optional[str] = None, + ): + super().__init__(model=model or "fastgpt") + self.api_key = api_key + self.base_url = str(base_url or "").rstrip("/") + self.app_id = str(app_id or "").strip() + self.system_prompt = system_prompt or "" + self.client: Any = None + self._cancel_event = asyncio.Event() + self._state = FastGPTConversationState() + self._knowledge_config: Dict[str, Any] = {} + self._tool_schemas: List[Dict[str, Any]] = [] + + async def connect(self) -> None: + if AsyncChatClient is None or aiter_stream_events is None: + raise RuntimeError( + "fastgpt_client package is not available. " + "Install the sibling fastgpt-python-sdk package first." 
+ ) from _FASTGPT_IMPORT_ERROR + if not self.api_key: + raise ValueError("FastGPT API key not provided") + if not self.base_url: + raise ValueError("FastGPT base URL not provided") + self.client = AsyncChatClient(api_key=self.api_key, base_url=self.base_url) + self.state = ServiceState.CONNECTED + logger.info("FastGPT LLM service connected: base_url={}", self.base_url) + + async def disconnect(self) -> None: + if self.client and hasattr(self.client, "close"): + await self.client.close() + self.client = None + self._state.pending_interaction = None + self.state = ServiceState.DISCONNECTED + logger.info("FastGPT LLM service disconnected") + + def cancel(self) -> None: + self._cancel_event.set() + self._state.pending_interaction = None + + def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None: + # FastGPT owns KB orchestration in this provider mode. + self._knowledge_config = dict(config or {}) + + def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None: + # FastGPT owns workflow and tool orchestration in this provider mode. 
+ self._tool_schemas = list(schemas or []) + + def handles_client_tool(self, tool_name: str) -> bool: + return str(tool_name or "").strip() == self.INTERACTIVE_TOOL_NAME + + async def get_initial_greeting(self) -> Optional[str]: + if not self.client or not self.app_id: + return None + + response = await self.client.get_chat_init( + appId=self.app_id, + chatId=self._ensure_chat_id(), + ) + raise_for_status = getattr(response, "raise_for_status", None) + if callable(raise_for_status): + raise_for_status() + elif int(getattr(response, "status_code", 200) or 200) >= 400: + raise RuntimeError(f"FastGPT chat init failed: HTTP {getattr(response, 'status_code', 'unknown')}") + + payload = response.json() if hasattr(response, "json") else {} + return self._extract_initial_greeting(payload) + + async def generate( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None, + ) -> str: + parts: List[str] = [] + async for event in self.generate_stream(messages, temperature=temperature, max_tokens=max_tokens): + if event.type == "text_delta" and event.text: + parts.append(event.text) + if event.type == "tool_call": + break + return "".join(parts) + + async def generate_stream( + self, + messages: List[LLMMessage], + temperature: float = 0.7, + max_tokens: Optional[int] = None, + ) -> AsyncIterator[LLMStreamEvent]: + del temperature, max_tokens + if not self.client: + raise RuntimeError("LLM service not connected") + + self._cancel_event.clear() + request_messages = self._build_request_messages(messages) + response = await self.client.create_chat_completion( + messages=request_messages, + chatId=self._ensure_chat_id(), + detail=True, + stream=True, + ) + try: + async for event in aiter_stream_events(response): + if self._cancel_event.is_set(): + logger.info("FastGPT stream cancelled") + break + + stop_after_event = False + for mapped in self._map_stream_event(event): + if mapped.type == "tool_call": + stop_after_event = True + yield 
mapped + if stop_after_event: + break + finally: + await self._close_stream_response(response) + + async def resume_after_client_tool_result( + self, + tool_call_id: str, + result: Dict[str, Any], + ) -> AsyncIterator[LLMStreamEvent]: + if not self.client: + raise RuntimeError("LLM service not connected") + + pending = self._require_pending_interaction(tool_call_id) + follow_up_text = self._build_resume_text(pending, result) + self._state.pending_interaction = None + + if not follow_up_text: + yield LLMStreamEvent(type="done") + return + + self._cancel_event.clear() + response = await self.client.create_chat_completion( + messages=[{"role": "user", "content": follow_up_text}], + chatId=pending.chat_id, + detail=True, + stream=True, + ) + try: + async for event in aiter_stream_events(response): + if self._cancel_event.is_set(): + logger.info("FastGPT resume stream cancelled") + break + + stop_after_event = False + for mapped in self._map_stream_event(event): + if mapped.type == "tool_call": + stop_after_event = True + yield mapped + if stop_after_event: + break + finally: + await self._close_stream_response(response) + + async def _close_stream_response(self, response: Any) -> None: + if response is None: + return + + # httpx async streaming responses must use `aclose()`. 
+ aclose = getattr(response, "aclose", None) + if callable(aclose): + await aclose() + return + + close = getattr(response, "close", None) + if callable(close): + maybe_awaitable = close() + if hasattr(maybe_awaitable, "__await__"): + await maybe_awaitable + + def _ensure_chat_id(self) -> str: + chat_id = str(self._state.chat_id or "").strip() + if not chat_id: + chat_id = f"fastgpt_{uuid.uuid4().hex}" + self._state.chat_id = chat_id + return chat_id + + def _build_request_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]: + non_empty = [msg for msg in messages if str(msg.content or "").strip()] + if not non_empty: + return [{"role": "user", "content": ""}] + + latest_user = next((msg for msg in reversed(non_empty) if msg.role == "user"), None) + trailing_system = non_empty[-1] if non_empty and non_empty[-1].role == "system" else None + + request: List[Dict[str, Any]] = [] + if trailing_system and trailing_system is not latest_user: + request.append({"role": "system", "content": trailing_system.content.strip()}) + if latest_user and str(latest_user.content or "").strip(): + request.append({"role": "user", "content": latest_user.content.strip()}) + return request + + last_message = non_empty[-1] + payload = last_message.to_dict() + payload["content"] = str(payload.get("content") or "").strip() + return [payload] + + def _extract_initial_greeting(self, payload: Any) -> Optional[str]: + if not isinstance(payload, dict): + return None + + candidates: List[Any] = [ + payload.get("app"), + payload.get("data"), + ] + for container in candidates: + if not isinstance(container, dict): + continue + nested_app = container.get("app") if isinstance(container.get("app"), dict) else None + if nested_app: + text = self._welcome_text_from_app(nested_app) + if text: + return text + text = self._welcome_text_from_app(container) + if text: + return text + + return None + + @staticmethod + def _welcome_text_from_app(app_payload: Dict[str, Any]) -> Optional[str]: + 
chat_config = app_payload.get("chatConfig") if isinstance(app_payload.get("chatConfig"), dict) else {} + text = str( + chat_config.get("welcomeText") + or app_payload.get("welcomeText") + or "" + ).strip() + return text or None + + def _map_stream_event(self, event: Any) -> List[LLMStreamEvent]: + kind = str(getattr(event, "kind", "") or "") + data = getattr(event, "data", {}) + if not isinstance(data, dict): + data = {} + + if kind in {"data", "answer", "fastAnswer"}: + chunks = self._extract_text_chunks(kind, data) + return [LLMStreamEvent(type="text_delta", text=chunk) for chunk in chunks if chunk] + + if kind == "interactive": + return [self._build_interactive_tool_event(data)] + + if kind == "error": + message = str(data.get("message") or data.get("error") or "FastGPT streaming error") + raise RuntimeError(message) + + if kind == "done": + return [LLMStreamEvent(type="done")] + + return [] + + @staticmethod + def _normalize_interactive_payload(payload: Dict[str, Any]) -> Dict[str, Any]: + normalized = payload + wrapped = normalized.get("interactive") + if isinstance(wrapped, dict): + normalized = wrapped + + interaction_type = str(normalized.get("type") or "").strip() + if interaction_type == "toolChildrenInteractive": + params = normalized.get("params") if isinstance(normalized.get("params"), dict) else {} + children_response = params.get("childrenResponse") + if isinstance(children_response, dict): + normalized = children_response + + return normalized + + def _extract_text_chunks(self, kind: str, data: Dict[str, Any]) -> List[str]: + if kind in {"answer", "fastAnswer"}: + text = str(data.get("text") or "") + if text: + return [text] + + choices = data.get("choices") if isinstance(data.get("choices"), list) else [] + if not choices: + text = str(data.get("text") or "") + return [text] if text else [] + + first = choices[0] if isinstance(choices[0], dict) else {} + delta = first.get("delta") if isinstance(first.get("delta"), dict) else {} + if 
isinstance(delta.get("content"), str) and delta.get("content"): + return [str(delta.get("content"))] + message = first.get("message") if isinstance(first.get("message"), dict) else {} + if isinstance(message.get("content"), str) and message.get("content"): + return [str(message.get("content"))] + return [] + + def _build_interactive_tool_event(self, payload: Dict[str, Any]) -> LLMStreamEvent: + normalized_payload = self._normalize_interactive_payload(payload) + prompt = self._parse_interactive_prompt(normalized_payload) + call_id = f"fgi_{uuid.uuid4().hex[:12]}" + pending = FastGPTPendingInteraction( + tool_call_id=call_id, + chat_id=self._ensure_chat_id(), + prompt=prompt, + timeout_ms=self.INTERACTIVE_TIMEOUT_MS, + fastgpt_event=dict(normalized_payload), + ) + self._state.pending_interaction = pending + arguments = prompt.to_ws_arguments(chat_id=pending.chat_id) + tool_call = { + "id": call_id, + "type": "function", + "executor": "client", + "wait_for_response": True, + "timeout_ms": pending.timeout_ms, + "display_name": prompt.title or prompt.description or prompt.prompt or "FastGPT Interactive", + "function": { + "name": self.INTERACTIVE_TOOL_NAME, + "arguments": json.dumps(arguments, ensure_ascii=False), + }, + } + return LLMStreamEvent(type="tool_call", tool_call=tool_call) + + def _parse_interactive_prompt(self, payload: Dict[str, Any]) -> FastGPTInteractivePrompt: + params = payload.get("params") if isinstance(payload.get("params"), dict) else {} + kind = str(payload.get("type") or "userSelect").strip() or "userSelect" + title = str( + payload.get("title") + or params.get("title") + or payload.get("nodeName") + or payload.get("label") + or "" + ).strip() + description = str( + payload.get("description") + or payload.get("desc") + or params.get("description") + or params.get("desc") + or "" + ).strip() + prompt_text = str( + payload.get("opener") + or params.get("opener") + or payload.get("intro") + or params.get("intro") + or payload.get("prompt") + or 
params.get("prompt") + or payload.get("text") + or params.get("text") + or title + or description + ).strip() + required = self._coerce_bool(payload.get("required"), default=True) + multiple = self._coerce_bool(params.get("multiple") or payload.get("multiple"), default=False) + submit_label = str(params.get("submitText") or payload.get("submitText") or "Continue").strip() or "Continue" + cancel_label = str(params.get("cancelText") or payload.get("cancelText") or "Cancel").strip() or "Cancel" + + options: List[FastGPTOption] = [] + raw_options = params.get("userSelectOptions") if isinstance(params.get("userSelectOptions"), list) else [] + for index, raw_option in enumerate(raw_options): + if isinstance(raw_option, str): + value = raw_option.strip() + if not value: + continue + options.append(FastGPTOption(id=f"option_{index}", label=value, value=value)) + continue + if not isinstance(raw_option, dict): + continue + label = str(raw_option.get("label") or raw_option.get("value") or raw_option.get("id") or "").strip() + value = str(raw_option.get("value") or raw_option.get("label") or raw_option.get("id") or "").strip() + option_id = str(raw_option.get("id") or value or f"option_{index}").strip() + if not label and not value: + continue + options.append( + FastGPTOption( + id=option_id or f"option_{index}", + label=label or value, + value=value or label, + description=str( + raw_option.get("description") + or raw_option.get("desc") + or raw_option.get("intro") + or raw_option.get("summary") + or "" + ).strip(), + ) + ) + + form: List[FastGPTField] = [] + raw_form = params.get("inputForm") if isinstance(params.get("inputForm"), list) else [] + for index, raw_field in enumerate(raw_form): + if not isinstance(raw_field, dict): + continue + field_options: List[FastGPTOption] = [] + nested_options = raw_field.get("options") if isinstance(raw_field.get("options"), list) else [] + for opt_index, option in enumerate(nested_options): + if isinstance(option, str): + value = 
option.strip() + if not value: + continue + field_options.append(FastGPTOption(id=f"field_{index}_opt_{opt_index}", label=value, value=value)) + continue + if not isinstance(option, dict): + continue + label = str(option.get("label") or option.get("value") or option.get("id") or "").strip() + value = str(option.get("value") or option.get("label") or option.get("id") or "").strip() + option_id = str(option.get("id") or value or f"field_{index}_opt_{opt_index}").strip() + if not label and not value: + continue + field_options.append( + FastGPTOption( + id=option_id or f"field_{index}_opt_{opt_index}", + label=label or value, + value=value or label, + description=str( + option.get("description") + or option.get("desc") + or option.get("intro") + or option.get("summary") + or "" + ).strip(), + ) + ) + name = str(raw_field.get("key") or raw_field.get("name") or raw_field.get("label") or f"field_{index}").strip() + label = str(raw_field.get("label") or raw_field.get("name") or name).strip() + form.append( + FastGPTField( + name=name or f"field_{index}", + label=label or name or f"field_{index}", + input_type=str(raw_field.get("type") or raw_field.get("inputType") or "text").strip() or "text", + required=self._coerce_bool(raw_field.get("required"), default=False), + placeholder=str( + raw_field.get("placeholder") + or raw_field.get("description") + or raw_field.get("desc") + or "" + ).strip(), + default=raw_field.get("defaultValue", raw_field.get("default")), + options=field_options, + ) + ) + + return FastGPTInteractivePrompt( + kind="userInput" if kind == "userInput" else "userSelect", + title=title, + description=description, + prompt=prompt_text, + required=required, + multiple=multiple, + submit_label=submit_label, + cancel_label=cancel_label, + options=options, + form=form, + raw=dict(payload), + ) + + def _require_pending_interaction(self, tool_call_id: str) -> FastGPTPendingInteraction: + pending = self._state.pending_interaction + if pending is None or 
pending.tool_call_id != tool_call_id: + raise ValueError(f"FastGPT interaction not pending for tool call: {tool_call_id}") + return pending + + def _build_resume_text(self, pending: FastGPTPendingInteraction, result: Dict[str, Any]) -> str: + status = result.get("status") if isinstance(result.get("status"), dict) else {} + status_code = self._safe_int(status.get("code"), default=0) + output = result.get("output") if isinstance(result.get("output"), dict) else {} + action = str(output.get("action") or "").strip().lower() + + if action == "cancel" or status_code == 499: + return "" + if status_code == 422: + raise ValueError("Invalid FastGPT interactive payload from client") + if status_code and not 200 <= status_code < 300: + raise ValueError(f"FastGPT interactive result rejected with status {status_code}") + if action and action != "submit": + raise ValueError(f"Unsupported FastGPT interactive action: {action}") + + payload = output.get("result") if isinstance(output.get("result"), dict) else output + if not isinstance(payload, dict): + raise ValueError("FastGPT interactive client result must be an object") + + if pending.prompt.kind == "userSelect": + selected = str(payload.get("selected") or "").strip() + if selected: + return selected + selected_values = payload.get("selected_values") if isinstance(payload.get("selected_values"), list) else [] + values = [str(item).strip() for item in selected_values if str(item).strip()] + if values: + return ", ".join(values) + text_value = str(payload.get("text") or "").strip() + return text_value + + text_value = str(payload.get("text") or "").strip() + if text_value: + return text_value + fields = payload.get("fields") if isinstance(payload.get("fields"), dict) else {} + compact_fields = {str(key): value for key, value in fields.items()} + if compact_fields: + return json.dumps(compact_fields, ensure_ascii=False) + return "" + + @staticmethod + def _coerce_bool(value: Any, *, default: bool) -> bool: + if isinstance(value, 
bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "1", "yes", "on"}: + return True + if normalized in {"false", "0", "no", "off"}: + return False + return default + + @staticmethod + def _safe_int(value: Any, *, default: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default diff --git a/engine/providers/llm/fastgpt_types.py b/engine/providers/llm/fastgpt_types.py new file mode 100644 index 0000000..71766e3 --- /dev/null +++ b/engine/providers/llm/fastgpt_types.py @@ -0,0 +1,95 @@ +"""FastGPT-specific provider types.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +InteractiveKind = Literal["userSelect", "userInput"] + + +@dataclass(frozen=True) +class FastGPTOption: + id: str + label: str + value: str + description: str = "" + + +@dataclass(frozen=True) +class FastGPTField: + name: str + label: str + input_type: str = "text" + required: bool = False + placeholder: str = "" + default: Any = None + options: List[FastGPTOption] = field(default_factory=list) + + +@dataclass(frozen=True) +class FastGPTInteractivePrompt: + kind: InteractiveKind + title: str = "" + description: str = "" + prompt: str = "" + required: bool = True + multiple: bool = False + submit_label: str = "Continue" + cancel_label: str = "Cancel" + options: List[FastGPTOption] = field(default_factory=list) + form: List[FastGPTField] = field(default_factory=list) + raw: Dict[str, Any] = field(default_factory=dict) + + def to_ws_arguments( + self, + *, + turn_id: Optional[str] = None, + response_id: Optional[str] = None, + chat_id: Optional[str] = None, + ) -> Dict[str, Any]: + context: Dict[str, Any] = {} + if turn_id: + context["turn_id"] = turn_id + if response_id: + context["response_id"] = response_id + if chat_id: + context["chat_id"] = chat_id + return { + "provider": "fastgpt", + "version": 
"fastgpt_interactive_v1", + "interaction": { + "type": self.kind, + "title": self.title, + "description": self.description, + "prompt": self.prompt, + "required": self.required, + "multiple": self.multiple, + "submit_label": self.submit_label, + "cancel_label": self.cancel_label, + "options": [vars(item) for item in self.options], + "form": [ + { + **vars(item), + "options": [vars(option) for option in item.options], + } + for item in self.form + ], + }, + "context": context, + } + + +@dataclass +class FastGPTPendingInteraction: + tool_call_id: str + chat_id: str + prompt: FastGPTInteractivePrompt + timeout_ms: int + fastgpt_event: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class FastGPTConversationState: + chat_id: Optional[str] = None + pending_interaction: Optional[FastGPTPendingInteraction] = None diff --git a/engine/requirements.txt b/engine/requirements.txt index 2818030..0d8b90d 100644 --- a/engine/requirements.txt +++ b/engine/requirements.txt @@ -33,3 +33,6 @@ dashscope>=1.25.11 sounddevice>=0.4.6 soundfile>=0.12.1 pyaudio>=0.2.13 # More reliable audio on Windows + +# FastGPT runtime support is installed from the sibling fastgpt-python-sdk package. 
+# Local dev: pip install -e ..\\fastgpt-python-sdk diff --git a/engine/runtime/pipeline/duplex.py b/engine/runtime/pipeline/duplex.py index dcf198f..743c945 100644 --- a/engine/runtime/pipeline/duplex.py +++ b/engine/runtime/pipeline/duplex.py @@ -594,6 +594,7 @@ class DuplexPipeline: "provider": llm_provider, "model": str(self._runtime_llm.get("model") or settings.llm_model), "baseUrl": llm_base_url, + "appId": str(self._runtime_llm.get("appId") or ""), }, "asr": { "provider": asr_provider, @@ -937,6 +938,19 @@ class DuplexPipeline: return None return text.strip().strip('"').strip("'") + async def _resolve_provider_initial_greeting(self) -> Optional[str]: + if not self.llm_service or not hasattr(self.llm_service, "get_initial_greeting"): + return None + + try: + greeting = await self.llm_service.get_initial_greeting() + except Exception as exc: + logger.warning("Failed to load provider initial greeting: {}", exc) + return None + + text = str(greeting or "").strip() + return text or None + async def start(self) -> None: """Start the pipeline and connect services.""" try: @@ -956,6 +970,7 @@ class DuplexPipeline: model=str(llm_model), api_key=str(llm_api_key).strip() if llm_api_key else None, base_url=str(llm_base_url).strip() if llm_base_url else None, + app_id=str(self._runtime_llm.get("appId")).strip() if self._runtime_llm.get("appId") else None, system_prompt=self.conversation.system_prompt, temperature=settings.llm_temperature, knowledge_config=self._resolved_knowledge_config(), @@ -1096,7 +1111,11 @@ class DuplexPipeline: if not self._bot_starts_first(): return - if self._generated_opener_enabled() and self._resolved_tool_schemas(): + provider_greeting = await self._resolve_provider_initial_greeting() + if provider_greeting: + self.conversation.greeting = provider_greeting + + if not provider_greeting and self._generated_opener_enabled() and self._resolved_tool_schemas(): # Run generated opener as a normal tool-capable assistant turn. 
# Use an empty user input so the opener can be driven by system prompt policy. if self._current_turn_task and not self._current_turn_task.done(): @@ -1107,13 +1126,13 @@ class DuplexPipeline: return manual_opener_execution: Dict[str, List[Dict[str, Any]]] = {"toolCalls": [], "toolResults": []} - if not self._generated_opener_enabled() and self._resolved_manual_opener_tool_calls(): + if not provider_greeting and not self._generated_opener_enabled() and self._resolved_manual_opener_tool_calls(): self._start_turn() self._start_response() manual_opener_execution = await self._execute_manual_opener_tool_calls() greeting_to_speak = self.conversation.greeting - if self._generated_opener_enabled(): + if not provider_greeting and self._generated_opener_enabled(): generated_greeting = await self._generate_runtime_greeting() if generated_greeting: greeting_to_speak = generated_greeting @@ -1954,12 +1973,35 @@ class DuplexPipeline: return bool(self._runtime_tool_wait_for_response.get(normalized, False)) def _tool_executor(self, tool_call: Dict[str, Any]) -> str: + explicit_executor = str(tool_call.get("executor") or "").strip().lower() + if explicit_executor in {"client", "server"}: + return explicit_executor name = self._tool_name(tool_call) if name and name in self._runtime_tool_executor: return self._runtime_tool_executor[name] # Default to server execution unless explicitly marked as client. 
return "server" + def _tool_wait_for_response_for_call(self, tool_name: str, tool_call: Dict[str, Any]) -> bool: + explicit_wait = tool_call.get("wait_for_response") + if explicit_wait is None: + explicit_wait = tool_call.get("waitForResponse") + if isinstance(explicit_wait, bool): + return explicit_wait + return self._tool_wait_for_response(tool_name) + + def _tool_timeout_ms(self, tool_call: Dict[str, Any]) -> int: + raw_timeout = tool_call.get("timeout_ms") + if raw_timeout is None: + raw_timeout = tool_call.get("timeoutMs") + try: + timeout_ms = int(raw_timeout) + except (TypeError, ValueError): + timeout_ms = 0 + if timeout_ms > 0: + return timeout_ms + return int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000) + def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]: fn = tool_call.get("function") if not isinstance(fn, dict): @@ -2179,7 +2221,7 @@ class DuplexPipeline: self._early_tool_results[call_id] = item self._completed_tool_call_ids.add(call_id) - async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]: + async def _wait_for_single_tool_result(self, call_id: str, timeout_seconds: Optional[float] = None) -> Dict[str, Any]: if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results: return { "tool_call_id": call_id, @@ -2193,8 +2235,9 @@ class DuplexPipeline: loop = asyncio.get_running_loop() future = loop.create_future() self._pending_tool_waiters[call_id] = future + timeout = timeout_seconds if isinstance(timeout_seconds, (int, float)) and timeout_seconds > 0 else self._TOOL_WAIT_TIMEOUT_SECONDS try: - return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS) + return await asyncio.wait_for(future, timeout=timeout) except asyncio.TimeoutError: self._completed_tool_call_ids.add(call_id) return { @@ -2256,6 +2299,7 @@ class DuplexPipeline: first_audio_sent = False self._pending_llm_delta = "" self._last_llm_delta_emit_ms = 0.0 + pending_provider_stream = None for _ in 
range(max_rounds): if self._interrupt_event.is_set(): break @@ -2267,7 +2311,10 @@ class DuplexPipeline: allow_text_output = True use_engine_sentence_split = self._use_engine_sentence_split_for_tts() - async for raw_event in self.llm_service.generate_stream(messages): + stream_iter = pending_provider_stream if pending_provider_stream is not None else self.llm_service.generate_stream(messages) + pending_provider_stream = None + + async for raw_event in stream_iter: if self._interrupt_event.is_set(): break @@ -2282,14 +2329,21 @@ class DuplexPipeline: if not tool_call: continue allow_text_output = False + tool_name = self._tool_name(tool_call) or "unknown_tool" executor = self._tool_executor(tool_call) enriched_tool_call = dict(tool_call) enriched_tool_call["executor"] = executor - tool_name = self._tool_name(enriched_tool_call) or "unknown_tool" tool_id = self._tool_id_for_name(tool_name) - tool_display_name = self._tool_display_name(tool_name) or tool_name - wait_for_response = self._tool_wait_for_response(tool_name) + tool_display_name = str( + enriched_tool_call.get("displayName") + or enriched_tool_call.get("display_name") + or self._tool_display_name(tool_name) + or tool_name + ).strip() + wait_for_response = self._tool_wait_for_response_for_call(tool_name, enriched_tool_call) enriched_tool_call["wait_for_response"] = wait_for_response + timeout_ms = self._tool_timeout_ms(enriched_tool_call) + enriched_tool_call["timeout_ms"] = timeout_ms call_id = str(enriched_tool_call.get("id") or "").strip() fn_payload = ( dict(enriched_tool_call.get("function")) @@ -2298,6 +2352,15 @@ class DuplexPipeline: ) raw_args = str(fn_payload.get("arguments") or "") if isinstance(fn_payload, dict) else "" tool_arguments = self._tool_arguments(enriched_tool_call) + if tool_name == "fastgpt.interactive": + context_payload = ( + dict(tool_arguments.get("context")) + if isinstance(tool_arguments.get("context"), dict) + else {} + ) + context_payload.setdefault("turn_id", turn_id) + 
context_payload.setdefault("response_id", response_id) + tool_arguments["context"] = context_payload merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments) try: merged_args_text = json.dumps(merged_tool_arguments, ensure_ascii=False) @@ -2324,9 +2387,9 @@ class DuplexPipeline: tool_id=tool_id, tool_display_name=tool_display_name, wait_for_response=wait_for_response, - arguments=tool_arguments, + arguments=merged_tool_arguments, executor=executor, - timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000), + timeout_ms=timeout_ms, tool_call=enriched_tool_call, ) }, @@ -2457,6 +2520,8 @@ class DuplexPipeline: break tool_results: List[Dict[str, Any]] = [] + provider_managed_tool = False + provider_resumed = False for call in tool_calls: call_id = str(call.get("id") or "").strip() if not call_id: @@ -2466,9 +2531,27 @@ class DuplexPipeline: tool_id = self._tool_id_for_name(tool_name) logger.info(f"[Tool] execute start name={tool_name} call_id={call_id} executor={executor}") if executor == "client": - result = await self._wait_for_single_tool_result(call_id) + timeout_ms = self._tool_timeout_ms(call) + result = await self._wait_for_single_tool_result( + call_id, + timeout_seconds=(timeout_ms / 1000.0), + ) await self._emit_tool_result(result, source="client") tool_results.append(result) + if ( + hasattr(self.llm_service, "handles_client_tool") + and hasattr(self.llm_service, "resume_after_client_tool_result") + and self.llm_service.handles_client_tool(tool_name) + ): + provider_managed_tool = True + status = result.get("status") if isinstance(result.get("status"), dict) else {} + status_code = int(status.get("code") or 0) if status else 0 + output = result.get("output") if isinstance(result.get("output"), dict) else {} + action = str(output.get("action") or "").strip().lower() + if 200 <= status_code < 300 and action != "cancel": + pending_provider_stream = self.llm_service.resume_after_client_tool_result(call_id, result) + provider_resumed = 
True + break continue call_for_executor = dict(call) @@ -2495,6 +2578,11 @@ class DuplexPipeline: await self._emit_tool_result(result, source="server") tool_results.append(result) + if provider_resumed: + continue + if provider_managed_tool: + break + messages = [ *messages, LLMMessage( diff --git a/engine/runtime/ports/__init__.py b/engine/runtime/ports/__init__.py index 26319b2..0ef9fe6 100644 --- a/engine/runtime/ports/__init__.py +++ b/engine/runtime/ports/__init__.py @@ -14,7 +14,13 @@ from runtime.ports.control_plane import ( KnowledgeRetriever, ToolCatalog, ) -from runtime.ports.llm import LLMCancellable, LLMPort, LLMRuntimeConfigurable, LLMServiceSpec +from runtime.ports.llm import ( + LLMCancellable, + LLMClientToolResumable, + LLMPort, + LLMRuntimeConfigurable, + LLMServiceSpec, +) from runtime.ports.service_factory import RealtimeServiceFactory from runtime.ports.tts import TTSPort, TTSServiceSpec @@ -30,6 +36,7 @@ __all__ = [ "KnowledgeRetriever", "ToolCatalog", "LLMCancellable", + "LLMClientToolResumable", "LLMPort", "LLMRuntimeConfigurable", "LLMServiceSpec", diff --git a/engine/runtime/ports/llm.py b/engine/runtime/ports/llm.py index a591985..a5480f4 100644 --- a/engine/runtime/ports/llm.py +++ b/engine/runtime/ports/llm.py @@ -18,6 +18,7 @@ class LLMServiceSpec: model: str api_key: Optional[str] = None base_url: Optional[str] = None + app_id: Optional[str] = None system_prompt: Optional[str] = None temperature: float = 0.7 knowledge_config: Dict[str, Any] = field(default_factory=dict) @@ -65,3 +66,17 @@ class LLMRuntimeConfigurable(Protocol): def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None: """Apply runtime tool schemas used for tool calling.""" + + +class LLMClientToolResumable(Protocol): + """Optional extension for providers that pause on client-side tool results.""" + + def handles_client_tool(self, tool_name: str) -> bool: + """Return True when the provider owns the lifecycle of this client tool.""" + + def 
resume_after_client_tool_result( + self, + tool_call_id: str, + result: Dict[str, Any], + ) -> AsyncIterator[LLMStreamEvent]: + """Resume the provider stream after a correlated client-side tool result.""" diff --git a/engine/tests/test_backend_adapters.py b/engine/tests/test_backend_adapters.py index 9cce105..81f1134 100644 --- a/engine/tests/test_backend_adapters.py +++ b/engine/tests/test_backend_adapters.py @@ -283,6 +283,30 @@ def test_translate_agent_schema_maps_volcengine_fields(): } +def test_translate_agent_schema_maps_llm_app_id(): + payload = { + "agent": { + "llm": { + "provider": "fastgpt", + "model": "fastgpt", + "api_key": "llm-key", + "api_url": "https://cloud.fastgpt.cn/api", + "app_id": "app-fastgpt-123", + }, + } + } + + translated = LocalYamlAssistantConfigAdapter._translate_agent_schema("assistant_demo", payload) + assert translated is not None + assert translated["services"]["llm"] == { + "provider": "fastgpt", + "model": "fastgpt", + "apiKey": "llm-key", + "baseUrl": "https://cloud.fastgpt.cn/api", + "appId": "app-fastgpt-123", + } + + @pytest.mark.asyncio async def test_backend_mode_disabled_uses_local_assistant_config_even_with_url(monkeypatch, tmp_path): class _FailIfCalledClientSession: diff --git a/engine/tests/test_fastgpt_provider.py b/engine/tests/test_fastgpt_provider.py new file mode 100644 index 0000000..9b7d63a --- /dev/null +++ b/engine/tests/test_fastgpt_provider.py @@ -0,0 +1,411 @@ +import json +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from providers.common.base import LLMMessage +from providers.llm.fastgpt import FastGPTLLMService + + +class _FakeResponse: + def __init__(self, events: List[Any]): + self.events = events + self.closed = False + + async def close(self) -> None: + self.closed = True + + +class _FakeJSONResponse: + def __init__(self, payload: Dict[str, Any], status_code: int = 200): + self._payload = payload + self.status_code = status_code + + def json(self) -> 
Dict[str, Any]: + return dict(self._payload) + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + +class _FakeAsyncStreamResponse(_FakeResponse): + def __init__(self, events: List[Any]): + super().__init__(events) + self.aclosed = False + + def close(self) -> None: + raise AssertionError("sync close should not be used for async stream responses") + + async def aclose(self) -> None: + self.aclosed = True + + +class _FakeAsyncChatClient: + responses: List[_FakeResponse] = [] + init_payload: Dict[str, Any] | None = None + + def __init__(self, api_key: str, base_url: str): + self.api_key = api_key + self.base_url = base_url + self.requests: List[Dict[str, Any]] = [] + self.init_requests: List[Dict[str, Any]] = [] + + async def create_chat_completion(self, **kwargs): + self.requests.append(dict(kwargs)) + if not self.responses: + raise AssertionError("No fake FastGPT response queued") + return self.responses.pop(0) + + async def get_chat_init(self, **kwargs): + self.init_requests.append(dict(kwargs)) + return _FakeJSONResponse( + self.init_payload or {"data": {"app": {"chatConfig": {"welcomeText": ""}}}}, + ) + + async def close(self) -> None: + return None + + +async def _fake_aiter_stream_events(response: _FakeResponse): + for event in response.events: + yield event + + +@pytest.mark.asyncio +async def test_fastgpt_provider_streams_text_from_data_event(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + _FakeAsyncChatClient.responses = [ + _FakeResponse( + [ + SimpleNamespace( + kind="data", + data={"choices": [{"delta": {"content": "Hello from FastGPT."}}]}, + ), + SimpleNamespace(kind="done", data={}), + ] + ) + ] + + service = FastGPTLLMService(api_key="key", base_url="https://fastgpt.example") + await service.connect() + + events = [event async for 
event in service.generate_stream([LLMMessage(role="user", content="Hi")])] + + assert [event.type for event in events] == ["text_delta", "done"] + assert events[0].text == "Hello from FastGPT." + assert service.client.requests[0]["messages"] == [{"role": "user", "content": "Hi"}] + assert service.client.requests[0]["chatId"] == service._state.chat_id + + +@pytest.mark.asyncio +async def test_fastgpt_provider_streams_text_from_answer_delta_event(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + _FakeAsyncChatClient.responses = [ + _FakeResponse( + [ + SimpleNamespace( + kind="answer", + data={"choices": [{"delta": {"content": "Hello from answer delta."}}]}, + ), + SimpleNamespace(kind="done", data={}), + ] + ) + ] + + service = FastGPTLLMService(api_key="key", base_url="https://fastgpt.example") + await service.connect() + + events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Hi")])] + + assert [event.type for event in events] == ["text_delta", "done"] + assert events[0].text == "Hello from answer delta." 
+ + +@pytest.mark.asyncio +async def test_fastgpt_provider_uses_async_close_for_stream_responses(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + response = _FakeAsyncStreamResponse( + [ + SimpleNamespace( + kind="data", + data={"choices": [{"delta": {"content": "Hello from FastGPT."}}]}, + ), + SimpleNamespace(kind="done", data={}), + ] + ) + _FakeAsyncChatClient.responses = [response] + + service = FastGPTLLMService(api_key="key", base_url="https://fastgpt.example") + await service.connect() + + events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Hi")])] + + assert [event.type for event in events] == ["text_delta", "done"] + assert response.aclosed is True + + +@pytest.mark.asyncio +async def test_fastgpt_provider_loads_initial_greeting_from_chat_init(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + _FakeAsyncChatClient.init_payload = { + "data": { + "app": { + "chatConfig": { + "welcomeText": "Hello from FastGPT init.", + } + } + } + } + + service = FastGPTLLMService( + api_key="key", + base_url="https://fastgpt.example", + app_id="app-123", + ) + await service.connect() + + greeting = await service.get_initial_greeting() + + assert greeting == "Hello from FastGPT init." 
+ assert service.client.init_requests[0] == { + "appId": "app-123", + "chatId": service._state.chat_id, + } + + +@pytest.mark.asyncio +async def test_fastgpt_provider_maps_interactive_event_to_client_tool(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + _FakeAsyncChatClient.responses = [ + _FakeResponse( + [ + SimpleNamespace( + kind="interactive", + data={ + "type": "userSelect", + "title": "Choose a plan", + "params": { + "description": "Pick the best plan for your team.", + "userSelectOptions": [ + {"id": "basic", "label": "Basic", "value": "basic", "desc": "Starter tier"}, + {"id": "pro", "label": "Pro", "value": "pro", "description": "Advanced tier"}, + ] + }, + }, + ) + ] + ) + ] + + service = FastGPTLLMService(api_key="key", base_url="https://fastgpt.example") + await service.connect() + + events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Start")])] + + assert len(events) == 1 + assert events[0].type == "tool_call" + tool_call = events[0].tool_call + assert tool_call["executor"] == "client" + assert tool_call["wait_for_response"] is True + assert tool_call["timeout_ms"] == 300000 + assert tool_call["function"]["name"] == "fastgpt.interactive" + + arguments = json.loads(tool_call["function"]["arguments"]) + assert arguments["provider"] == "fastgpt" + assert arguments["version"] == "fastgpt_interactive_v1" + assert arguments["interaction"]["type"] == "userSelect" + assert arguments["interaction"]["description"] == "Pick the best plan for your team." 
+ assert arguments["interaction"]["options"][0]["description"] == "Starter tier" + assert arguments["interaction"]["options"][1]["value"] == "pro" + assert arguments["interaction"]["options"][1]["description"] == "Advanced tier" + assert arguments["context"]["chat_id"] == service._state.chat_id + assert service._state.pending_interaction is not None + + +@pytest.mark.asyncio +async def test_fastgpt_provider_unwraps_nested_tool_children_interactive(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + _FakeAsyncChatClient.responses = [ + _FakeResponse( + [ + SimpleNamespace( + kind="interactive", + data={ + "interactive": { + "type": "toolChildrenInteractive", + "params": { + "childrenResponse": { + "type": "userSelect", + "params": { + "description": "Please choose a workflow branch.", + "userSelectOptions": [ + {"value": "A", "description": "Branch A"}, + {"value": "B", "description": "Branch B"}, + ], + }, + } + }, + } + }, + ) + ] + ) + ] + + service = FastGPTLLMService(api_key="key", base_url="https://fastgpt.example") + await service.connect() + + events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Start")])] + + assert len(events) == 1 + arguments = json.loads(events[0].tool_call["function"]["arguments"]) + assert arguments["interaction"]["type"] == "userSelect" + assert arguments["interaction"]["description"] == "Please choose a workflow branch." 
+ assert arguments["interaction"]["options"][0]["description"] == "Branch A" + + +@pytest.mark.asyncio +async def test_fastgpt_provider_uses_opener_for_interactive_prompt_when_prompt_missing(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + _FakeAsyncChatClient.responses = [ + _FakeResponse( + [ + SimpleNamespace( + kind="interactive", + data={ + "type": "userSelect", + "opener": "请确认您是否满意本次服务。", + "params": { + "userSelectOptions": [ + {"value": "是"}, + {"value": "否"}, + ] + }, + }, + ) + ] + ) + ] + + service = FastGPTLLMService(api_key="key", base_url="https://fastgpt.example") + await service.connect() + + events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Start")])] + + assert len(events) == 1 + tool_call = events[0].tool_call + arguments = json.loads(tool_call["function"]["arguments"]) + assert tool_call["display_name"] == "请确认您是否满意本次服务。" + assert arguments["interaction"]["prompt"] == "请确认您是否满意本次服务。" + + +@pytest.mark.asyncio +async def test_fastgpt_provider_resumes_same_chat_after_client_result(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + _FakeAsyncChatClient.responses = [ + _FakeResponse( + [ + SimpleNamespace( + kind="interactive", + data={ + "type": "userSelect", + "params": {"userSelectOptions": [{"label": "Pro", "value": "pro"}]}, + }, + ) + ] + ), + _FakeResponse( + [ + SimpleNamespace(kind="answer", data={"text": "Resumed answer."}), + SimpleNamespace(kind="done", data={}), + ] + ), + ] + + service = FastGPTLLMService(api_key="key", base_url="https://fastgpt.example") + await service.connect() + + initial_events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Start")])] + call_id 
= initial_events[0].tool_call["id"] + + resumed_events = [ + event + async for event in service.resume_after_client_tool_result( + call_id, + { + "tool_call_id": call_id, + "name": "fastgpt.interactive", + "output": { + "action": "submit", + "result": {"type": "userSelect", "selected": "pro"}, + }, + "status": {"code": 200, "message": "ok"}, + }, + ) + ] + + assert [event.type for event in resumed_events] == ["text_delta", "done"] + assert resumed_events[0].text == "Resumed answer." + assert service.client.requests[1]["chatId"] == service.client.requests[0]["chatId"] + assert service.client.requests[1]["messages"] == [{"role": "user", "content": "pro"}] + assert service._state.pending_interaction is None + + +@pytest.mark.asyncio +async def test_fastgpt_provider_cancel_result_clears_pending_interaction(monkeypatch): + monkeypatch.setattr("providers.llm.fastgpt.AsyncChatClient", _FakeAsyncChatClient) + monkeypatch.setattr("providers.llm.fastgpt.aiter_stream_events", _fake_aiter_stream_events) + + _FakeAsyncChatClient.responses = [ + _FakeResponse( + [ + SimpleNamespace( + kind="interactive", + data={ + "type": "userInput", + "params": {"inputForm": [{"name": "name", "label": "Name"}]}, + }, + ) + ] + ) + ] + + service = FastGPTLLMService(api_key="key", base_url="https://fastgpt.example") + await service.connect() + + initial_events = [event async for event in service.generate_stream([LLMMessage(role="user", content="Start")])] + call_id = initial_events[0].tool_call["id"] + + resumed_events = [ + event + async for event in service.resume_after_client_tool_result( + call_id, + { + "tool_call_id": call_id, + "name": "fastgpt.interactive", + "output": {"action": "cancel", "result": {}}, + "status": {"code": 499, "message": "user_cancelled"}, + }, + ) + ] + + assert [event.type for event in resumed_events] == ["done"] + assert service._state.pending_interaction is None diff --git a/engine/tests/test_tool_call_flow.py b/engine/tests/test_tool_call_flow.py index 
717f96a..3550d20 100644 --- a/engine/tests/test_tool_call_flow.py +++ b/engine/tests/test_tool_call_flow.py @@ -109,6 +109,22 @@ class _CaptureGenerateLLM: yield LLMStreamEvent(type="done") +class _InitGreetingLLM: + def __init__(self, greeting: str): + self.greeting = greeting + self.init_calls = 0 + + async def generate(self, _messages, temperature=0.7, max_tokens=None): + return "" + + async def generate_stream(self, _messages, temperature=0.7, max_tokens=None): + yield LLMStreamEvent(type="done") + + async def get_initial_greeting(self): + self.init_calls += 1 + return self.greeting + + def _build_pipeline(monkeypatch, llm_rounds: List[List[LLMStreamEvent]]) -> tuple[DuplexPipeline, List[Dict[str, Any]]]: monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD) monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor) @@ -306,6 +322,21 @@ async def test_generated_opener_uses_tool_capable_turn_when_tools_available(monk assert called.get("user_text") == "" +@pytest.mark.asyncio +async def test_provider_initial_greeting_takes_precedence_over_local_opener(monkeypatch): + llm = _InitGreetingLLM("FastGPT init greeting") + pipeline, events = _build_pipeline_with_custom_llm(monkeypatch, llm) + pipeline.apply_runtime_overrides({"output": {"mode": "text"}}) + pipeline.conversation.greeting = "local fallback greeting" + + await pipeline.emit_initial_greeting() + + finals = [event for event in events if event.get("type") == "assistant.response.final"] + assert finals + assert finals[-1]["text"] == "FastGPT init greeting" + assert llm.init_calls == 1 + + @pytest.mark.asyncio async def test_manual_opener_tool_calls_emit_assistant_tool_call(monkeypatch): pipeline, events = _build_pipeline(monkeypatch, [[LLMStreamEvent(type="done")]]) @@ -736,3 +767,139 @@ async def test_eou_early_return_clears_stale_asr_capture(monkeypatch): assert pipeline._asr_capture_active is False assert pipeline._asr_capture_started_ms == 0.0 assert 
pipeline._pending_speech_audio == b"" + +class _FakeResumableLLM: + def __init__(self, *, timeout_ms: int = 300000): + self.timeout_ms = timeout_ms + self.generate_stream_calls = 0 + self.resumed_results: List[Dict[str, Any]] = [] + + async def generate(self, _messages, temperature=0.7, max_tokens=None): + return "" + + async def generate_stream(self, _messages, temperature=0.7, max_tokens=None): + self.generate_stream_calls += 1 + yield LLMStreamEvent( + type="tool_call", + tool_call={ + "id": "call_fastgpt_1", + "executor": "client", + "wait_for_response": True, + "timeout_ms": self.timeout_ms, + "display_name": "Choose a plan", + "type": "function", + "function": { + "name": "fastgpt.interactive", + "arguments": json.dumps( + { + "provider": "fastgpt", + "version": "fastgpt_interactive_v1", + "interaction": { + "type": "userSelect", + "title": "Choose a plan", + "options": [ + {"id": "basic", "label": "Basic", "value": "basic"}, + {"id": "pro", "label": "Pro", "value": "pro"}, + ], + "form": [], + }, + "context": {"chat_id": "fastgpt_chat_1"}, + }, + ensure_ascii=False, + ), + }, + }, + ) + yield LLMStreamEvent(type="done") + + def handles_client_tool(self, tool_name: str) -> bool: + return tool_name == "fastgpt.interactive" + + async def resume_after_client_tool_result(self, tool_call_id: str, result: Dict[str, Any]): + self.resumed_results.append({"tool_call_id": tool_call_id, "result": dict(result)}) + yield LLMStreamEvent(type="text_delta", text="provider resumed answer.") + yield LLMStreamEvent(type="done") + + +def _build_pipeline_with_custom_llm(monkeypatch, llm_service) -> tuple[DuplexPipeline, List[Dict[str, Any]]]: + monkeypatch.setattr("runtime.pipeline.duplex.SileroVAD", _DummySileroVAD) + monkeypatch.setattr("runtime.pipeline.duplex.VADProcessor", _DummyVADProcessor) + monkeypatch.setattr("runtime.pipeline.duplex.EouDetector", _DummyEouDetector) + + pipeline = DuplexPipeline( + transport=_FakeTransport(), + session_id="s_fastgpt", + 
llm_service=llm_service, + tts_service=_FakeTTS(), + asr_service=_FakeASR(), + ) + events: List[Dict[str, Any]] = [] + + async def _capture_event(event: Dict[str, Any], priority: int = 20): + events.append(event) + + async def _noop_speak(_text: str, *args, **kwargs): + return None + + monkeypatch.setattr(pipeline, "_send_event", _capture_event) + monkeypatch.setattr(pipeline, "_speak_sentence", _noop_speak) + return pipeline, events + + +@pytest.mark.asyncio +async def test_fastgpt_provider_managed_tool_resumes_provider_stream(monkeypatch): + llm = _FakeResumableLLM(timeout_ms=300000) + pipeline, events = _build_pipeline_with_custom_llm(monkeypatch, llm) + pipeline.apply_runtime_overrides({"output": {"mode": "text"}}) + + task = asyncio.create_task(pipeline._handle_turn("start fastgpt")) + for _ in range(200): + if any(event.get("type") == "assistant.tool_call" for event in events): + break + await asyncio.sleep(0.005) + + tool_event = next(event for event in events if event.get("type") == "assistant.tool_call") + assert tool_event.get("executor") == "client" + assert tool_event.get("tool_name") == "fastgpt.interactive" + assert tool_event.get("timeout_ms") == 300000 + assert tool_event.get("arguments", {}).get("context", {}).get("turn_id") + assert tool_event.get("arguments", {}).get("context", {}).get("response_id") + + await pipeline.handle_tool_call_results( + [ + { + "tool_call_id": "call_fastgpt_1", + "name": "fastgpt.interactive", + "output": { + "action": "submit", + "result": {"type": "userSelect", "selected": "pro"}, + }, + "status": {"code": 200, "message": "ok"}, + } + ] + ) + await task + + finals = [event for event in events if event.get("type") == "assistant.response.final"] + assert finals + assert "provider resumed answer" in finals[-1].get("text", "") + assert llm.generate_stream_calls == 1 + assert len(llm.resumed_results) == 1 + assert llm.resumed_results[0]["tool_call_id"] == "call_fastgpt_1" + + +@pytest.mark.asyncio +async def 
test_fastgpt_provider_managed_tool_timeout_stops_without_generic_tool_prompt(monkeypatch): + llm = _FakeResumableLLM(timeout_ms=10) + pipeline, events = _build_pipeline_with_custom_llm(monkeypatch, llm) + pipeline.apply_runtime_overrides({"output": {"mode": "text"}}) + + await pipeline._handle_turn("start fastgpt") + + tool_results = [event for event in events if event.get("type") == "assistant.tool_result"] + assert tool_results + assert tool_results[-1].get("result", {}).get("status", {}).get("code") == 504 + finals = [event for event in events if event.get("type") == "assistant.response.final"] + assert not finals + assert llm.generate_stream_calls == 1 + assert llm.resumed_results == [] diff --git a/web/pages/Assistants.tsx b/web/pages/Assistants.tsx index ec0ba28..11ae5b6 100644 --- a/web/pages/Assistants.tsx +++ b/web/pages/Assistants.tsx @@ -263,6 +263,7 @@ export const AssistantsPage: React.FC = () => { botCannotBeInterrupted: false, interruptionSensitivity: 180, configMode: 'platform', + appId: '', }; try { const created = await createAssistant(newAssistantPayload); @@ -874,6 +875,20 @@ export const AssistantsPage: React.FC = () => { /> + {selectedAssistant.configMode === 'fastgpt' && ( +
+ + updateAssistant('appId', e.target.value)} + placeholder="请输入 FastGPT App ID..." + className="bg-white/5 border-white/10 focus:border-primary/50 font-mono text-xs" + /> + 
+ )} +
)} + {fastgptInteractiveDialog.open && ( +
+
+ {!fastgptInteractiveDialog.required && ( + + )} +
+
+ {fastgptInteractiveHeaderText} +
+ {fastgptInteractiveDialog.prompt + && fastgptInteractiveDialog.prompt !== fastgptInteractiveHeaderText && ( +

+ {fastgptInteractiveDialog.prompt} +

+ )} + {fastgptInteractiveDialog.description + && fastgptInteractiveDialog.description !== fastgptInteractiveHeaderText + && fastgptInteractiveDialog.description !== fastgptInteractiveDialog.prompt && ( +

+ {fastgptInteractiveDialog.description} +

+ )} + {fastgptInteractiveDialog.prompt + && fastgptInteractiveDialog.prompt !== fastgptInteractiveHeaderText + && fastgptInteractiveDialog.prompt !== fastgptInteractiveDialog.description && ( +

+ {fastgptInteractiveDialog.prompt} +

+ )} +
+ {fastgptInteractiveDialog.interactionType === 'userSelect' ? ( +
+ {fastgptInteractiveDialog.options.map((option) => { + const selected = fastgptInteractiveDialog.selectedValues.includes(option.value); + return ( + + ); + })} +
+ ) : ( +
+ {fastgptInteractiveDialog.form.map((field) => { + const value = fastgptInteractiveDialog.fieldValues[field.name] || ''; + const fieldType = field.inputType.toLowerCase(); + const useTextarea = ['textarea', 'multiline', 'longtext'].includes(fieldType); + const useSelect = field.options.length > 0 && ['select', 'dropdown', 'radio'].includes(fieldType); + return ( +