From 9e2374f492db9ae66a32c315fff31d1dfd4d2abb Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Thu, 28 May 2026 11:32:20 +0800 Subject: [PATCH] Add voice state tags, SuperTTS configs, and demo WS log groups. Parse leading tags from LLM replies and emit response.state over the product websocket while stripping tags from TTS/text streams. Add FastGPT+Xfyun voice configs (including state-enabled preset), SuperTTS support, and context sync for interrupted turns. Refresh the voice demo with a state indicator and collapsible audio delta websocket log groups. Co-authored-by: Cursor --- config/voice-fastgpt-state-xfyunSuperTTS.json | 101 +++++ config/voice-fastgpt-xfyunSuperTTS.json | 101 +++++ config/voice-fastgpt-xfyunTTS.json | 101 +++++ config/voice-fastgpt.example.json | 58 --- config/voice-xfyun.json | 8 +- config/voice.json | 8 +- src/voice/config.py | 51 ++- src/voice/context_sync.py | 40 ++ src/voice/fastgpt_llm.py | 91 +++- src/voice/pipeline.py | 61 ++- src/voice/protocol.py | 66 ++- src/voice/response_state.py | 136 ++++++ src/voice/services.py | 35 +- src/voice/text_stream.py | 92 ++++- src/voice/xfyun_super_tts.py | 391 ++++++++++++++++++ static/voice-demo/app.js | 351 +++++++++++----- static/voice-demo/index.html | 4 + static/voice-demo/styles.css | 96 ++++- 18 files changed, 1596 insertions(+), 195 deletions(-) create mode 100644 config/voice-fastgpt-state-xfyunSuperTTS.json create mode 100644 config/voice-fastgpt-xfyunSuperTTS.json create mode 100644 config/voice-fastgpt-xfyunTTS.json delete mode 100644 config/voice-fastgpt.example.json create mode 100644 src/voice/context_sync.py create mode 100644 src/voice/response_state.py create mode 100644 src/voice/xfyun_super_tts.py diff --git a/config/voice-fastgpt-state-xfyunSuperTTS.json b/config/voice-fastgpt-state-xfyunSuperTTS.json new file mode 100644 index 0000000..98878a3 --- /dev/null +++ b/config/voice-fastgpt-state-xfyunSuperTTS.json @@ -0,0 +1,101 @@ +{ + "server": { + "host": "0.0.0.0", + "port": 8000, + 
"cors_origins": ["*"] + }, + "audio": { + "sample_rate_hz": 16000, + "channels": 1, + "frame_ms": 20 + }, + "session": { + "inactivity_timeout_sec": 60 + }, + "turn": { + "vad": { + "confidence": 0.8, + "start_secs": 0.4, + "stop_secs": 0.2, + "min_volume": 0.8 + }, + "interruption_min_chars": 3, + "interruption_use_interim": true, + "interruption_short_replies": [ + "是", + "是的", + "对", + "对的", + "嗯", + "好", + "好的", + "行", + "可以", + "没问题", + "不是", + "不", + "不行", + "不用", + "不要", + "没有", + "否", + "你好", + "在吗" + ], + "user_speech_timeout_sec": 0.2 + }, + "agent": { + "system_prompt": "FastGPT app owns the system prompt when send_system_prompt is false.", + "greeting": "您好,这里是无锡交警,我将为您远程处理交通事故。请将人员撤离至路侧安全区域,开启危险报警双闪灯、放置三角警告牌、做好安全防护,谨防二次事故伤害。若您已经准备好了,请点击继续办理,如需人工服务,请说转人工。", + "greeting_mode": "fixed", + "response_state": { + "enabled": true, + "tag": "state", + "event_type": "response.state", + "max_prefix_chars": 256 + } + }, + "services": { + "stt": { + "provider": "xfyun", + "app_id": "416ce125", + "api_key": "c65342fe603126c3610031d8429bb36d", + "api_secret": "MzkyYmI5OWEyODQzN2FiN2VhN2UzYzU4", + "base_url": "wss://iat-api.xfyun.cn/v2/iat", + "language": "zh_cn", + "domain": "iat", + "accent": "mandarin", + "encoding": "raw", + "frame_size": 1280, + "timeout_sec": 10.0 + }, + "llm": { + "provider": "fastgpt", + "api_key": "fastgpt-zlLjYtWZWN0uhQHs3ZOFHG4KLGMIdr2CkbZLCSfqGm5vcdx5xIZbp", + "base_url": "http://localhost:3030", + "model": "my-voice-app", + "app_id": "691eddaa53e3f8d9f25f1370", + "chat_id": null, + "variables": {}, + "detail": false, + "timeout_sec": 60.0, + "send_system_prompt": false + }, + "tts": { + "provider": "xfyun_super", + "app_id": "416ce125", + "api_key": "c65342fe603126c3610031d8429bb36d", + "api_secret": "MzkyYmI5OWEyODQzN2FiN2VhN2UzYzU4", + "base_url": "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6", + "voice": "x5_lingxiaoxuan_flow", + "aue": "raw", + "speed": 50, + "volume": 50, + "pitch": 50, + "oral_level": "mid", + 
"source_sample_rate_hz": 24000, + "text_aggregation_mode": "token", + "timeout_sec": 30.0 + } + } +} diff --git a/config/voice-fastgpt-xfyunSuperTTS.json b/config/voice-fastgpt-xfyunSuperTTS.json new file mode 100644 index 0000000..cdca5f4 --- /dev/null +++ b/config/voice-fastgpt-xfyunSuperTTS.json @@ -0,0 +1,101 @@ +{ + "server": { + "host": "0.0.0.0", + "port": 8000, + "cors_origins": ["*"] + }, + "audio": { + "sample_rate_hz": 16000, + "channels": 1, + "frame_ms": 20 + }, + "session": { + "inactivity_timeout_sec": 60 + }, + "turn": { + "vad": { + "confidence": 0.8, + "start_secs": 0.4, + "stop_secs": 0.2, + "min_volume": 0.8 + }, + "interruption_min_chars": 3, + "interruption_use_interim": true, + "interruption_short_replies": [ + "是", + "是的", + "对", + "对的", + "嗯", + "好", + "好的", + "行", + "可以", + "没问题", + "不是", + "不", + "不行", + "不用", + "不要", + "没有", + "否", + "你好", + "在吗" + ], + "user_speech_timeout_sec": 0.2 + }, + "agent": { + "system_prompt": "FastGPT app owns the system prompt when send_system_prompt is false.", + "greeting": "您好,这里是无锡交警,我将为您远程处理交通事故。请将人员撤离至路侧安全区域,开启危险报警双闪灯、放置三角警告牌、做好安全防护,谨防二次事故伤害。若您已经准备好了,请点击继续办理,如需人工服务,请说转人工。", + "greeting_mode": "fixed", + "response_state": { + "enabled": true, + "tag": "state", + "event_type": "response.state", + "max_prefix_chars": 256 + } + }, + "services": { + "stt": { + "provider": "xfyun", + "app_id": "416ce125", + "api_key": "c65342fe603126c3610031d8429bb36d", + "api_secret": "MzkyYmI5OWEyODQzN2FiN2VhN2UzYzU4", + "base_url": "wss://iat-api.xfyun.cn/v2/iat", + "language": "zh_cn", + "domain": "iat", + "accent": "mandarin", + "encoding": "raw", + "frame_size": 1280, + "timeout_sec": 10.0 + }, + "llm": { + "provider": "fastgpt", + "api_key": "fastgpt-v1FljAxBz3tJeS0bH7HZU4yVGclsTcfiy9yK7V9Zr9126maDHQ97Xlo8n", + "base_url": "http://localhost:3030", + "model": "my-voice-app", + "app_id": "6a153aed53e3f8d9f2744905", + "chat_id": null, + "variables": {}, + "detail": false, + "timeout_sec": 60.0, + "send_system_prompt": 
false + }, + "tts": { + "provider": "xfyun_super", + "app_id": "416ce125", + "api_key": "c65342fe603126c3610031d8429bb36d", + "api_secret": "MzkyYmI5OWEyODQzN2FiN2VhN2UzYzU4", + "base_url": "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6", + "voice": "x5_lingxiaoxuan_flow", + "aue": "raw", + "speed": 50, + "volume": 50, + "pitch": 50, + "oral_level": "mid", + "source_sample_rate_hz": 24000, + "text_aggregation_mode": "token", + "timeout_sec": 30.0 + } + } +} diff --git a/config/voice-fastgpt-xfyunTTS.json b/config/voice-fastgpt-xfyunTTS.json new file mode 100644 index 0000000..4fcf843 --- /dev/null +++ b/config/voice-fastgpt-xfyunTTS.json @@ -0,0 +1,101 @@ +{ + "server": { + "host": "0.0.0.0", + "port": 8000, + "cors_origins": ["*"] + }, + "audio": { + "sample_rate_hz": 16000, + "channels": 1, + "frame_ms": 20 + }, + "session": { + "inactivity_timeout_sec": 60 + }, + "turn": { + "vad": { + "confidence": 0.7, + "start_secs": 0.35, + "stop_secs": 0.2, + "min_volume": 0.65 + }, + "interruption_min_chars": 3, + "interruption_use_interim": true, + "interruption_short_replies": [ + "是", + "是的", + "对", + "对的", + "嗯", + "好", + "好的", + "行", + "可以", + "没问题", + "不是", + "不", + "不行", + "不用", + "不要", + "没有", + "否", + "你好", + "在吗" + ], + "user_speech_timeout_sec": 0.2 + }, + "agent": { + "system_prompt": "FastGPT app owns the system prompt when send_system_prompt is false.", + "greeting": "您好,这里是无锡交警,我将为您远程处理交通事故。请将人员撤离至路侧安全区域,开启危险报警双闪灯、放置三角警告牌、做好安全防护,谨防二次事故伤害。若您已经准备好了,请点击继续办理,如需人工服务,请说转人工。", + "greeting_mode": "fixed", + "response_state": { + "enabled": true, + "tag": "state", + "event_type": "response.state", + "max_prefix_chars": 256 + } + }, + "services": { + "stt": { + "provider": "xfyun", + "app_id": "416ce125", + "api_key": "c65342fe603126c3610031d8429bb36d", + "api_secret": "MzkyYmI5OWEyODQzN2FiN2VhN2UzYzU4", + "base_url": "wss://iat-api.xfyun.cn/v2/iat", + "language": "zh_cn", + "domain": "iat", + "accent": "mandarin", + "encoding": "raw", + "frame_size": 1280, 
+ "timeout_sec": 10.0 + }, + "llm": { + "provider": "fastgpt", + "api_key": "fastgpt-v1FljAxBz3tJeS0bH7HZU4yVGclsTcfiy9yK7V9Zr9126maDHQ97Xlo8n", + "base_url": "http://localhost:3030", + "model": "my-voice-app", + "app_id": "6a153aed53e3f8d9f2744905", + "chat_id": null, + "variables": {}, + "detail": false, + "timeout_sec": 60.0, + "send_system_prompt": false + }, + "tts": { + "provider": "xfyun_super", + "app_id": "416ce125", + "api_key": "c65342fe603126c3610031d8429bb36d", + "api_secret": "MzkyYmI5OWEyODQzN2FiN2VhN2UzYzU4", + "base_url": "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6", + "voice": "x5_lingxiaoxuan_flow", + "aue": "raw", + "speed": 50, + "volume": 50, + "pitch": 50, + "oral_level": "mid", + "source_sample_rate_hz": 24000, + "text_aggregation_mode": "token", + "timeout_sec": 30.0 + } + } +} diff --git a/config/voice-fastgpt.example.json b/config/voice-fastgpt.example.json deleted file mode 100644 index c7063eb..0000000 --- a/config/voice-fastgpt.example.json +++ /dev/null @@ -1,58 +0,0 @@ -{ - "server": { - "host": "0.0.0.0", - "port": 8000, - "cors_origins": ["*"] - }, - "audio": { - "sample_rate_hz": 16000, - "channels": 1, - "frame_ms": 20 - }, - "session": { - "inactivity_timeout_sec": 60 - }, - "turn": { - "vad": { - "confidence": 0.7, - "start_secs": 0.2, - "stop_secs": 0.6, - "min_volume": 0.6 - }, - "interruption_min_chars": 3, - "interruption_use_interim": true, - "user_speech_timeout_sec": 1.0 - }, - "agent": { - "system_prompt": "FastGPT app owns the system prompt when send_system_prompt is false.", - "greeting": "你好", - "greeting_mode": "generated" - }, - "services": { - "stt": { - "provider": "openai", - "api_key": "", - "base_url": null, - "model": "gpt-4o-mini-transcribe", - "language": "zh" - }, - "llm": { - "provider": "fastgpt", - "api_key": "", - "base_url": null, - "model": "my-voice-app", - "chat_id": null, - "variables": {}, - "detail": false, - "timeout_sec": 60.0, - "send_system_prompt": false - }, - "tts": { - 
"provider": "openai", - "api_key": "", - "base_url": null, - "model": "gpt-4o-mini-tts", - "voice": "alloy" - } - } -} diff --git a/config/voice-xfyun.json b/config/voice-xfyun.json index 5c302bd..6143b60 100644 --- a/config/voice-xfyun.json +++ b/config/voice-xfyun.json @@ -45,7 +45,13 @@ "agent": { "system_prompt": "# 角色 你是一个高度集成、安全第一的交警AI接警员。正在收集事故人员伤亡情况,时间,地点,事故原因,事故车辆数量,收集完成之后和用户说再见", "greeting": "您好,这里是无锡交警,我将为您远程处理交通事故。请将人员撤离至路侧安全区域,开启危险报警双闪灯、放置三角警告牌、做好安全防护,谨防二次事故伤害。若您已经准备好了,请点击继续办理,如需人工服务,请说转人工。", - "greeting_mode": "fixed" + "greeting_mode": "fixed", + "response_state": { + "enabled": true, + "tag": "state", + "event_type": "response.state", + "max_prefix_chars": 256 + } }, "services": { "stt": { diff --git a/config/voice.json b/config/voice.json index 64f64af..10891ad 100644 --- a/config/voice.json +++ b/config/voice.json @@ -47,7 +47,13 @@ "agent": { "system_prompt": "You are a helpful, friendly voice assistant. Keep responses concise and natural for spoken conversation.", "greeting": "Please introduce yourself briefly.", - "greeting_mode": "generated" + "greeting_mode": "generated", + "response_state": { + "enabled": false, + "tag": "state", + "event_type": "response.state", + "max_prefix_chars": 256 + } }, "services": { "stt": { diff --git a/src/voice/config.py b/src/voice/config.py index affbce0..3cf3c3e 100644 --- a/src/voice/config.py +++ b/src/voice/config.py @@ -26,6 +26,9 @@ def resolve_voice_config_path() -> Path: DEFAULT_VOICE_CONFIG = resolve_voice_config_path() +SUPPORTED_LLM_PROVIDERS = frozenset({"openai", "fastgpt"}) +_LLM_PROVIDER_ALIASES = {"llm": "openai", "openai": "openai", "fastgpt": "fastgpt"} + @dataclass(frozen=True) class ServerConfig: @@ -93,11 +96,20 @@ class TurnConfig: ) +@dataclass(frozen=True) +class ResponseStateConfig: + enabled: bool = False + tag: str = "state" + event_type: str = "response.state" + max_prefix_chars: int = 256 + + @dataclass(frozen=True) class AgentConfig: system_prompt: str = "You are a helpful, 
friendly voice assistant." greeting: str | None = None greeting_mode: str = "generated" + response_state: ResponseStateConfig = field(default_factory=ResponseStateConfig) @dataclass(frozen=True) @@ -106,6 +118,7 @@ class LLMConfig: api_key: str = "" base_url: str | None = None model: str = "gpt-4o-mini" + app_id: str | None = None temperature: float | None = 0.7 chat_id: str | None = None variables: dict[str, str] = field(default_factory=dict) @@ -113,6 +126,19 @@ class LLMConfig: timeout_sec: float = 60.0 send_system_prompt: bool = False + @property + def is_fastgpt(self) -> bool: + return self.provider == "fastgpt" + + @property + def is_openai(self) -> bool: + return self.provider == "openai" + + @property + def uses_local_context_history(self) -> bool: + """Whether the pipeline should seed and maintain local LLM context history.""" + return not self.is_fastgpt or self.send_system_prompt + @dataclass(frozen=True) class STTConfig: @@ -147,6 +173,8 @@ class TTSConfig: pitch: int = 50 timeout_sec: float = 30.0 source_sample_rate_hz: int | None = None + oral_level: str = "mid" + text_aggregation_mode: str | None = None @dataclass(frozen=True) @@ -183,14 +211,24 @@ def config_from_dict(data: dict) -> EngineConfig: agent["greeting"] = None if agent.get("greeting_mode") not in (None, "generated", "fixed", "off"): raise ValueError("agent.greeting_mode must be one of: generated, fixed, off") + response_state = ResponseStateConfig(**_dict(agent.pop("response_state"))) + if response_state.max_prefix_chars < 1: + raise ValueError("agent.response_state.max_prefix_chars must be greater than 0") + if not response_state.tag: + raise ValueError("agent.response_state.tag must not be empty") + if not response_state.event_type: + raise ValueError("agent.response_state.event_type must not be empty") stt = _dict(services.get("stt") or services.get("asr")) if stt.get("language") == "": stt["language"] = None llm = _dict(services.get("llm")) + llm["provider"] = 
_normalize_llm_provider(llm.get("provider", LLMConfig().provider)) if llm.get("chat_id") == "": llm["chat_id"] = None + if llm.get("app_id") == "": + llm["app_id"] = None if not isinstance(llm.get("variables"), dict): llm["variables"] = {} @@ -219,7 +257,7 @@ def config_from_dict(data: dict) -> EngineConfig: ) ), ), - agent=AgentConfig(**agent), + agent=AgentConfig(**agent, response_state=response_state), services=ServicesConfig( llm=LLMConfig(**llm), stt=STTConfig(**stt), @@ -230,3 +268,14 @@ def config_from_dict(data: dict) -> EngineConfig: def _dict(value: object) -> dict: return dict(value) if isinstance(value, dict) else {} + + +def _normalize_llm_provider(value: object) -> str: + provider = str(value or LLMConfig().provider).strip().lower() + normalized = _LLM_PROVIDER_ALIASES.get(provider) + if normalized is None: + supported = ", ".join(sorted(SUPPORTED_LLM_PROVIDERS | {"llm"})) + raise ValueError( + f"services.llm.provider must be one of: {supported}; got {value!r}" + ) + return normalized diff --git a/src/voice/context_sync.py b/src/voice/context_sync.py new file mode 100644 index 0000000..3dab3c3 --- /dev/null +++ b/src/voice/context_sync.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Any + +from pipecat.frames.frames import Frame, InterruptionFrame, LLMMessagesAppendFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context + + +class AssistantContextSyncProcessor(FrameProcessor): + """Sync LLM context to urgent-streamed assistant text before text-input turns. + + ``input.text`` with ``interrupt: true`` queues ``InterruptionFrame`` before + ``LLMMessagesAppendFrame``. This processor runs context repair after the + interrupt has propagated (including TTS-phase interrupts) and before the new + user message is appended. 
+ """ + + def __init__( + self, + *, + text_stream: ProductTextStreamProcessor, + assistant_aggregator: Any, + ) -> None: + super().__init__() + self._text_stream = text_stream + self._assistant_aggregator = assistant_aggregator + self._sync_on_next_append = False + + async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: + await super().process_frame(frame, direction) + + if isinstance(frame, InterruptionFrame): + self._sync_on_next_append = True + elif isinstance(frame, LLMMessagesAppendFrame) and self._sync_on_next_append: + self._sync_on_next_append = False + maybe_sync_assistant_context(self._assistant_aggregator, self._text_stream) + + await self.push_frame(frame, direction) diff --git a/src/voice/fastgpt_llm.py b/src/voice/fastgpt_llm.py index 4055d0c..b05a0f2 100644 --- a/src/voice/fastgpt_llm.py +++ b/src/voice/fastgpt_llm.py @@ -7,11 +7,13 @@ from typing import Any import httpx from fastgpt_client import AsyncChatClient, FastGPTInteractiveEvent, aiter_stream_events from fastgpt_client.exceptions import FastGPTError +from loguru import logger from pipecat.frames.frames import ( CancelFrame, EndFrame, Frame, + InterruptionFrame, LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, @@ -133,6 +135,24 @@ class FastGPTLLMSettings(LLMSettings): detail: bool = False +def _default_fastgpt_settings(*, model: str = "fastgpt") -> FastGPTLLMSettings: + return FastGPTLLMSettings( + model=model, + system_instruction=None, + temperature=None, + max_tokens=None, + top_p=None, + top_k=None, + frequency_penalty=None, + presence_penalty=None, + seed=None, + filter_incomplete_user_turns=False, + user_turn_completion_config=None, + variables={}, + detail=False, + ) + + class FastGPTLLMService(LLMService): """FastGPT LLM service using chatId server-side memory and workflow variables.""" @@ -144,18 +164,20 @@ class FastGPTLLMService(LLMService): api_key: str, base_url: str, chat_id: str | None = None, + app_id: str | None = None, 
send_system_prompt: bool = False, greeting_prompt: str | None = None, timeout: float = 60.0, settings: FastGPTLLMSettings | None = None, **kwargs, ) -> None: - default_settings = self.Settings(model="fastgpt") + default_settings = _default_fastgpt_settings() if settings is not None: default_settings.apply_update(settings) super().__init__(settings=default_settings, **kwargs) self._chat_id = chat_id or f"voice_{uuid.uuid4().hex[:16]}" + self._app_id = (app_id or "").strip() self._send_system_prompt = send_system_prompt self._greeting_prompt = (greeting_prompt or "你好").strip() or "你好" self._client = AsyncChatClient( @@ -165,6 +187,10 @@ class FastGPTLLMService(LLMService): ) self._active_response = None + @property + def app_id(self) -> str: + return self._app_id + @property def chat_id(self) -> str: return self._chat_id @@ -183,6 +209,63 @@ class FastGPTLLMService(LLMService): await self._close_active_response() await super().cancel(frame) + async def _handle_interruptions(self, _: InterruptionFrame) -> None: + await self._close_active_response() + await super()._handle_interruptions(_) + + @staticmethod + def _welcome_text_from_init_payload(payload: Any) -> str: + if not isinstance(payload, dict): + return "" + + for container in (payload.get("app"), payload.get("data"), payload): + if not isinstance(container, dict): + continue + nested_app = container.get("app") + if isinstance(nested_app, dict): + text = FastGPTLLMService._welcome_text_from_app(nested_app) + if text: + return text + text = FastGPTLLMService._welcome_text_from_app(container) + if text: + return text + return "" + + @staticmethod + def _welcome_text_from_app(app_payload: dict[str, Any]) -> str: + chat_config = ( + app_payload.get("chatConfig") + if isinstance(app_payload.get("chatConfig"), dict) + else {} + ) + return _first_nonempty_text( + chat_config.get("welcomeText"), + app_payload.get("welcomeText"), + ) + + async def fetch_welcome_text(self) -> str | None: + """Return FastGPT app welcome 
text from chat init when ``app_id`` is configured.""" + if not self._app_id: + return None + + try: + response = await self._client.get_chat_init( + appId=self._app_id, + chatId=self._chat_id, + ) + response.raise_for_status() + text = self._welcome_text_from_init_payload(response.json()) + if text: + logger.info(f"FastGPT welcomeText loaded for appId={self._app_id}") + return text or None + except FastGPTError as exc: + logger.warning(f"FastGPT chat init failed: {exc}") + except httpx.HTTPError as exc: + logger.warning(f"FastGPT chat init HTTP error: {exc}") + except Exception as exc: + logger.warning(f"FastGPT chat init error: {exc}") + return None + async def _close_active_response(self) -> None: response = self._active_response self._active_response = None @@ -216,6 +299,12 @@ class FastGPTLLMService(LLMService): messages = self._build_fastgpt_messages(context) variables = self._settings.variables or None + logger.info( + "FastGPT chat completion " + f"chatId={self._chat_id} appId={self._app_id or '-'} " + f"variables={sorted((variables or {}).keys())} messages={messages!r}" + ) + await self.start_ttfb_metrics() try: diff --git a/src/voice/pipeline.py b/src/voice/pipeline.py index 0714016..a49e8ec 100644 --- a/src/voice/pipeline.py +++ b/src/voice/pipeline.py @@ -32,10 +32,13 @@ from pipecat.turns.user_stop.speech_timeout_user_turn_stop_strategy import ( from pipecat.turns.user_turn_strategies import UserTurnStrategies from .config import EngineConfig +from .context_sync import AssistantContextSyncProcessor +from .fastgpt_llm import FastGPTLLMService from .protocol import ProductWebsocketSerializer from .services import create_llm_service, create_stt_service, create_tts_service +from .response_state import StateTagResponseProcessor from .text_input import ProductTextInputProcessor -from .text_stream import ProductTextStreamProcessor, sync_streamed_assistant_context +from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context from 
.transcript_stream import ProductTranscriptStreamProcessor from .turn_start import InterruptionGateUserTurnStartStrategy @@ -83,14 +86,15 @@ async def run_pipeline_with_serializer( session_variables={"session_id": chat_id, "channel": "voice"}, greeting_prompt=config.agent.greeting, ) - if llm_config.provider == "fastgpt": - logger.info(f"FastGPT chatId={chat_id}") + if llm_config.is_fastgpt: + logger.info(f"LLM backend=fastgpt chatId={chat_id} appId={llm_config.app_id or '-'}") + else: + logger.info(f"LLM backend=openai model={llm_config.model}") tts = create_tts_service(config.services.tts, config.audio) - use_fastgpt = llm_config.provider == "fastgpt" and not llm_config.send_system_prompt messages: list[dict[str, str]] = [] - if not use_fastgpt: + if llm_config.uses_local_context_history: messages = [{"role": "system", "content": config.agent.system_prompt}] if config.agent.greeting and config.agent.greeting_mode == "generated": messages.append({"role": "system", "content": config.agent.greeting}) @@ -126,21 +130,31 @@ async def run_pipeline_with_serializer( ) text_stream = ProductTextStreamProcessor() + context_sync = AssistantContextSyncProcessor( + text_stream=text_stream, + assistant_aggregator=assistant_aggregator, + ) - pipeline = Pipeline( + processors = [ + transport.input(), + ProductTextInputProcessor(), + stt, + ProductTranscriptStreamProcessor(), + context_sync, + user_aggregator, + llm, + ] + if config.agent.response_state.enabled: + processors.append(StateTagResponseProcessor(config.agent.response_state)) + processors.extend( [ - transport.input(), - ProductTextInputProcessor(), - stt, - ProductTranscriptStreamProcessor(), - user_aggregator, - llm, text_stream, tts, transport.output(), assistant_aggregator, ] ) + pipeline = Pipeline(processors) task = PipelineTask( pipeline, @@ -160,7 +174,14 @@ async def run_pipeline_with_serializer( if config.agent.greeting_mode == "fixed" and config.agent.greeting: await 
task.queue_frames([TTSSpeakFrame(config.agent.greeting)]) elif config.agent.greeting_mode == "generated": - await task.queue_frames([LLMRunFrame()]) + if isinstance(llm, FastGPTLLMService): + welcome = await llm.fetch_welcome_text() + if welcome: + await task.queue_frames([TTSSpeakFrame(welcome)]) + else: + await task.queue_frames([LLMRunFrame()]) + else: + await task.queue_frames([LLMRunFrame()]) @transport.event_handler("on_client_disconnected") async def on_client_disconnected(_transport, _client): @@ -192,14 +213,12 @@ async def run_pipeline_with_serializer( @assistant_aggregator.event_handler("on_assistant_turn_stopped") async def on_assistant_turn_stopped(_aggregator, message: AssistantTurnStoppedMessage): logger.info(f"Assistant: {message.content}") - if message.interrupted: - streamed = text_stream.take_interrupted_stream_text() - if streamed: - sync_streamed_assistant_context( - _aggregator, - streamed_text=streamed, - committed_text=message.content or "", - ) + maybe_sync_assistant_context( + _aggregator, + text_stream, + committed_text=message.content or "", + ) + text_stream.take_interrupted_stream_text() runner = PipelineRunner(handle_sigint=False) await runner.run(task) diff --git a/src/voice/protocol.py b/src/voice/protocol.py index 79d4473..6ee3633 100644 --- a/src/voice/protocol.py +++ b/src/voice/protocol.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import binascii import json from typing import Any @@ -19,10 +20,15 @@ from pipecat.frames.frames import ( OutputTransportMessageUrgentFrame, TextFrame, TranscriptionFrame, + UserImageRawFrame, ) from pipecat.serializers.base_serializer import FrameSerializer +MAX_INPUT_IMAGE_BYTES = 8 * 1024 * 1024 +SUPPORTED_INPUT_IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/webp"} + + class ProductWebsocketSerializer(FrameSerializer): """Stable app-facing JSON/base64 protocol adapter for Pipecat websocket transport.""" @@ -118,7 +124,7 @@ class 
ProductWebsocketSerializer(FrameSerializer): return None try: pcm = base64.b64decode(audio) - except ValueError as exc: + except (binascii.Error, ValueError) as exc: logger.warning(f"Invalid input.audio base64: {exc}") return None return InputAudioRawFrame( @@ -127,6 +133,9 @@ class ProductWebsocketSerializer(FrameSerializer): num_channels=int(message.get("channels") or self._channels), ) + if message_type == "input.image": + return self._deserialize_input_image(message) + if message_type == "input.text": text = message.get("text") if not isinstance(text, str) or not text.strip(): @@ -147,6 +156,61 @@ class ProductWebsocketSerializer(FrameSerializer): logger.warning(f"Unsupported product websocket message type: {message_type!r}") return None + def _deserialize_input_image(self, message: dict[str, Any]) -> Frame | None: + encoded = message.get("image") or message.get("data") + if not isinstance(encoded, str): + logger.warning("input.image requires base64 'image' or 'data'") + return None + + mime_type = str(message.get("mime_type") or message.get("media_type") or "image/jpeg") + if mime_type not in SUPPORTED_INPUT_IMAGE_MIME_TYPES: + logger.warning( + "input.image unsupported mime_type " + f"{mime_type!r}; expected one of {sorted(SUPPORTED_INPUT_IMAGE_MIME_TYPES)}" + ) + return None + + try: + width = int(message.get("width") or 0) + height = int(message.get("height") or 0) + except (TypeError, ValueError): + logger.warning("input.image width and height must be integers") + return None + + if width <= 0 or height <= 0: + logger.warning("input.image requires positive integer width and height") + return None + + if "," in encoded and encoded.lstrip().startswith("data:"): + encoded = encoded.split(",", 1)[1] + + try: + image = base64.b64decode(encoded, validate=True) + except (binascii.Error, ValueError) as exc: + logger.warning(f"Invalid input.image base64: {exc}") + return None + + if len(image) > MAX_INPUT_IMAGE_BYTES: + logger.warning( + f"input.image too large: 
{len(image)} bytes; " + f"max is {MAX_INPUT_IMAGE_BYTES} bytes" + ) + return None + + text = message.get("text") + if text is not None and not isinstance(text, str): + logger.warning("input.image text must be a string when provided") + return None + + return UserImageRawFrame( + image=image, + size=(width, height), + format=mime_type, + user_id=str(message.get("user_id") or "product-user"), + text=text or "Answer using this camera image.", + append_to_context=bool(message.get("append_to_context", True)), + ) + def _event(self, event_type: str, **payload: Any) -> str: self._sequence += 1 return json.dumps( diff --git a/src/voice/response_state.py b/src/voice/response_state.py new file mode 100644 index 0000000..5983061 --- /dev/null +++ b/src/voice/response_state.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from pipecat.frames.frames import ( + CancelFrame, + Frame, + InterruptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMTextFrame, + OutputTransportMessageUrgentFrame, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +from .config import ResponseStateConfig + + +class StateTagResponseProcessor(FrameProcessor): + """Extract a leading state tag from LLM text before text streaming and TTS. + + Expected model output: + + <state>some state</state>spoken response + + The extracted state is emitted as a product protocol event, while only the + spoken response text is forwarded downstream. If the model does not produce + the tag, the original text is forwarded unchanged. 
+ """ + + def __init__(self, config: ResponseStateConfig) -> None: + super().__init__() + self._tag = config.tag + self._event_type = config.event_type + self._max_prefix_chars = config.max_prefix_chars + self._opening_tag = f"<{self._tag}>" + self._closing_tag = f"</{self._tag}>" + self._start_frame: LLMFullResponseStartFrame | None = None + self._buffer = "" + self._decided = False + self._in_llm_response = False + + async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: + await super().process_frame(frame, direction) + + if isinstance(frame, LLMFullResponseStartFrame): + self._start_frame = frame + self._buffer = "" + self._decided = False + self._in_llm_response = True + return + + if isinstance(frame, LLMTextFrame) and self._in_llm_response and not self._decided: + await self._process_initial_text(frame.text or "", direction) + return + + if isinstance(frame, LLMFullResponseEndFrame): + if self._in_llm_response: + await self._flush_buffer(direction) + await self.push_frame(frame, direction) + self._reset() + return + + if isinstance(frame, (InterruptionFrame, CancelFrame)): + if self._in_llm_response: + await self._flush_buffer(direction) + self._reset() + await self.push_frame(frame, direction) + return + + await self.push_frame(frame, direction) + + async def _process_initial_text(self, text: str, direction: FrameDirection) -> None: + if not text: + return + + self._buffer += text + decision = self._parse_buffer() + if decision is None: + return + + self._decided = True + state, response_text = decision + if state is not None: + await self._emit_state(state) + await self._push_start(direction) + if response_text: + await self.push_frame(LLMTextFrame(response_text), direction) + self._buffer = "" + + def _parse_buffer(self) -> tuple[str | None, str] | None: + stripped = self._buffer.lstrip() + if not stripped: + return None + + if stripped.startswith(self._opening_tag): + state_start = len(self._opening_tag) + state_end = 
stripped.find(self._closing_tag, state_start) + if state_end >= 0: + response_start = state_end + len(self._closing_tag) + return stripped[state_start:state_end].strip(), stripped[response_start:] + if len(self._buffer) < self._max_prefix_chars: + return None + return None, self._buffer + + if self._opening_tag.startswith(stripped) and len(self._buffer) < self._max_prefix_chars: + return None + + return None, self._buffer + + async def _flush_buffer(self, direction: FrameDirection) -> None: + await self._push_start(direction) + if self._buffer: + await self.push_frame(LLMTextFrame(self._buffer), direction) + self._buffer = "" + self._decided = True + + async def _push_start(self, direction: FrameDirection) -> None: + if self._start_frame: + await self.push_frame(self._start_frame, direction) + self._start_frame = None + + async def _emit_state(self, state: str) -> None: + await self.push_frame( + OutputTransportMessageUrgentFrame( + message={ + "type": self._event_type, + "state": state, + } + ), + FrameDirection.DOWNSTREAM, + ) + + def _reset(self) -> None: + self._start_frame = None + self._buffer = "" + self._decided = False + self._in_llm_response = False diff --git a/src/voice/services.py b/src/voice/services.py index 4272003..ae8652c 100644 --- a/src/voice/services.py +++ b/src/voice/services.py @@ -10,11 +10,13 @@ from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE from pipecat.services.openai.llm import OpenAILLMService from pipecat.services.openai.stt import OpenAISTTService from pipecat.services.openai.tts import VALID_VOICES, OpenAITTSService +from pipecat.services.tts_service import TextAggregationMode from pipecat.transcriptions.language import Language from .config import AudioConfig, LLMConfig, STTConfig, TTSConfig from .fastgpt_llm import FastGPTLLMService, FastGPTLLMSettings from .xfyun_asr import DEFAULT_XFYUN_ASR_URL, XfyunASRService +from .xfyun_super_tts import DEFAULT_XFYUN_SUPER_TTS_URL, XfyunSuperTTSService from .xfyun_tts 
import DEFAULT_XFYUN_TTS_URL, XfyunTTSService @@ -54,12 +56,13 @@ def create_llm_service( session_variables: dict | None = None, greeting_prompt: str | None = None, ): - if config.provider == "fastgpt": + if config.is_fastgpt: variables = {**config.variables, **(session_variables or {})} return FastGPTLLMService( api_key=config.api_key, base_url=config.base_url or "http://localhost:3000", chat_id=chat_id or config.chat_id, + app_id=config.app_id, send_system_prompt=config.send_system_prompt, greeting_prompt=greeting_prompt, timeout=config.timeout_sec, @@ -70,7 +73,11 @@ def create_llm_service( ), ) - _require_provider(config.provider, "openai", "llm") + if not config.is_openai: + supported = ", ".join(sorted(("openai", "fastgpt", "llm"))) + raise ValueError( + f"Unsupported llm provider {config.provider!r}; expected one of: {supported}" + ) return OpenAILLMService( api_key=config.api_key or None, base_url=config.base_url, @@ -102,6 +109,30 @@ def create_tts_service(config: TTSConfig, audio: AudioConfig): timeout=config.timeout_sec, ) + if config.provider in ("xfyun_super", "xfyun_super_tts"): + source_sample_rate = config.source_sample_rate_hz or 24000 + if source_sample_rate not in (8000, 16000, 24000): + raise ValueError( + "Xfyun Super TTS source_sample_rate_hz must be 8000, 16000, or 24000" + ) + text_aggregation_mode = config.text_aggregation_mode or TextAggregationMode.TOKEN + return XfyunSuperTTSService( + app_id=config.app_id, + api_key=config.api_key or "", + api_secret=config.api_secret, + voice=config.voice, + url=config.base_url or DEFAULT_XFYUN_SUPER_TTS_URL, + sample_rate=audio.sample_rate_hz, + source_sample_rate=source_sample_rate, + encoding=config.aue, + speed=config.speed, + volume=config.volume, + pitch=config.pitch, + oral_level=config.oral_level, + text_aggregation_mode=text_aggregation_mode, + open_timeout=config.timeout_sec, + ) + _require_provider(config.provider, "openai", "tts") service_class = OpenAITTSService if config.voice in 
def _committed_assistant_content(context: Any) -> str:
    """Return the stripped text of the final message if it is an assistant turn.

    An empty string is returned when there are no messages, when the final
    message is not a dict with ``role == "assistant"``, or when its
    ``content`` is not a plain string.
    """
    history = context.get_messages()
    tail = history[-1] if history else None
    if isinstance(tail, dict) and tail.get("role") == "assistant":
        text = tail.get("content")
        return text.strip() if isinstance(text, str) else ""
    return ""
""" streamed = streamed_text.strip() if not streamed or streamed == committed_text.strip(): @@ -39,19 +54,58 @@ def sync_streamed_assistant_context( def _apply(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: updated = list(messages) - if committed and updated: - last = updated[-1] - if isinstance(last, dict) and last.get("role") == "assistant": - content = last.get("content") - if isinstance(content, str) and content.strip() == committed: - updated[-1] = {"role": "assistant", "content": streamed} - return updated + if not updated: + updated.append({"role": "assistant", "content": streamed}) + return updated + + last = updated[-1] + if isinstance(last, dict) and last.get("role") == "assistant": + content = last.get("content") + if isinstance(content, str) and content.strip() != streamed: + updated[-1] = {"role": "assistant", "content": streamed} + return updated + + if ( + len(updated) >= 2 + and isinstance(last, dict) + and last.get("role") == "user" + ): + prev = updated[-2] + if isinstance(prev, dict) and prev.get("role") == "user": + updated.insert(len(updated) - 1, {"role": "assistant", "content": streamed}) + return updated + + if isinstance(last, dict) and last.get("role") == "user": + updated.append({"role": "assistant", "content": streamed}) + return updated + updated.append({"role": "assistant", "content": streamed}) return updated aggregator.context.transform_messages(_apply) +def maybe_sync_assistant_context( + aggregator: _AssistantContextSync, + text_stream: "ProductTextStreamProcessor", + *, + committed_text: str | None = None, +) -> None: + committed = ( + committed_text.strip() + if committed_text is not None + else _committed_assistant_content(aggregator.context) + ) + streamed = text_stream.last_assistant_stream_text() + if not streamed: + return + sync_streamed_assistant_context( + aggregator, + streamed_text=streamed, + committed_text=committed, + ) + + class ProductTextStreamProcessor(FrameProcessor): """Mirrors LLM text frames as 
streaming protocol events. @@ -72,8 +126,12 @@ class ProductTextStreamProcessor(FrameProcessor): super().__init__() self._aggregation: list[str] = [] self._turn_active = False + self._last_assistant_stream_text = "" self._interrupted_stream_text: str | None = None + def last_assistant_stream_text(self) -> str: + return self._last_assistant_stream_text + def take_interrupted_stream_text(self) -> str | None: text = self._interrupted_stream_text self._interrupted_stream_text = None @@ -94,7 +152,7 @@ class ProductTextStreamProcessor(FrameProcessor): await self._end_turn(interrupted=False) elif isinstance(frame, (InterruptionFrame, CancelFrame)): await self.push_frame(frame, direction) - await self._end_turn(interrupted=True) + await self._handle_interrupt() elif isinstance(frame, TTSSpeakFrame): text = frame.text or "" await self.push_frame(frame, direction) @@ -118,12 +176,24 @@ class ProductTextStreamProcessor(FrameProcessor): self._aggregation.append(text) await self._emit("response.text.delta", text=text) + async def _handle_interrupt(self) -> None: + if self._turn_active: + await self._end_turn(interrupted=True) + return + + if self._last_assistant_stream_text: + self._interrupted_stream_text = self._last_assistant_stream_text + async def _end_turn(self, *, interrupted: bool) -> None: if not self._turn_active: return + full_text = "".join(self._aggregation) + if full_text: + self._last_assistant_stream_text = full_text if interrupted and full_text: self._interrupted_stream_text = full_text + self._turn_active = False self._aggregation = [] await self._emit( diff --git a/src/voice/xfyun_super_tts.py b/src/voice/xfyun_super_tts.py new file mode 100644 index 0000000..59de414 --- /dev/null +++ b/src/voice/xfyun_super_tts.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import hmac +import json +import os +from collections.abc import AsyncGenerator +from datetime import datetime, timezone +from email.utils 
import format_datetime +from typing import Any +from urllib.parse import urlencode, urlparse + +from loguru import logger + +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + StartFrame, + TTSAudioRawFrame, + TTSStoppedFrame, +) +from pipecat.services.settings import TTSSettings +from pipecat.services.tts_service import TextAggregationMode, WebsocketTTSService +from pipecat.utils.tracing.service_decorators import traced_tts + +try: + from websockets.asyncio.client import connect as websocket_connect + from websockets.protocol import State +except ModuleNotFoundError as exc: + logger.error(f"Exception: {exc}") + logger.error("In order to use Xfyun Super TTS, install the websockets package.") + raise Exception(f"Missing module: {exc}") from exc + +from .xfyun_tts import _sanitize_text_for_tts + + +DEFAULT_XFYUN_SUPER_TTS_URL = "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6" +VALID_SAMPLE_RATES = {8000, 16000, 24000} + + +class XfyunSuperTTSService(WebsocketTTSService): + """iFlytek/Xfyun Super Smart TTS using bidirectional WebSocket streaming. + + The service keeps one Xfyun synthesis session open for a Pipecat turn. Each + ``run_tts`` call sends a text segment with status 0/1, while ``flush_audio`` + sends the terminal status 2 frame. Audio arrives on the receive task and is + appended to the Pipecat audio context. 
+ """ + + def __init__( + self, + *, + app_id: str, + api_key: str, + api_secret: str, + voice: str, + url: str | None = None, + sample_rate: int = 16000, + source_sample_rate: int = 24000, + encoding: str = "raw", + speed: int = 50, + volume: int = 50, + pitch: int = 50, + oral_level: str = "mid", + text_aggregation_mode: TextAggregationMode | str | None = TextAggregationMode.TOKEN, + open_timeout: float = 30.0, + **kwargs, + ) -> None: + if isinstance(text_aggregation_mode, str): + text_aggregation_mode = TextAggregationMode(text_aggregation_mode) + + super().__init__( + text_aggregation_mode=text_aggregation_mode, + push_text_frames=True, + push_stop_frames=False, + push_start_frame=True, + pause_frame_processing=False, + sample_rate=sample_rate, + settings=TTSSettings(model=None, voice=voice, language=None), + **kwargs, + ) + self._app_id = app_id or os.environ.get("XFYUN_APP_ID", "") + self._api_key = api_key or os.environ.get("XFYUN_API_KEY", "") + self._api_secret = api_secret or os.environ.get("XFYUN_API_SECRET", "") + self._voice = voice + self._url = url or DEFAULT_XFYUN_SUPER_TTS_URL + self._source_sample_rate = source_sample_rate + self._encoding = encoding + self._speed = speed + self._volume = volume + self._pitch = pitch + self._oral_level = oral_level + self._open_timeout = open_timeout + + self._receive_task: asyncio.Task | None = None + self._active_context_id: str | None = None + self._started_contexts: set[str] = set() + self._seq_by_context: dict[str, int] = {} + self._sent_text_bytes_by_context: dict[str, int] = {} + self._stream_completed = False + + def can_generate_metrics(self) -> bool: + return True + + async def start(self, frame: StartFrame) -> None: + await super().start(frame) + if not self._app_id or not self._api_key or not self._api_secret: + await self.push_error( + error_msg="Xfyun Super TTS requires app_id, api_key, and api_secret" + ) + return + if self._encoding != "raw": + await self.push_error(error_msg="Xfyun Super TTS must 
use raw PCM audio in Pipecat") + return + if self._source_sample_rate not in VALID_SAMPLE_RATES: + await self.push_error( + error_msg=( + "Xfyun Super TTS source_sample_rate must be one of " + f"{sorted(VALID_SAMPLE_RATES)}" + ) + ) + return + await self._connect() + + async def stop(self, frame: EndFrame) -> None: + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame) -> None: + await super().cancel(frame) + await self._disconnect() + + async def flush_audio(self, context_id: str | None = None) -> None: + flush_id = context_id or self.get_active_audio_context_id() + if not flush_id or not self._websocket: + return + if flush_id not in self._started_contexts: + return + + logger.trace(f"{self}: flushing Xfyun Super TTS stream {flush_id}") + await self._send_request_frame(flush_id, "", status=2) + + async def on_audio_context_interrupted(self, context_id: str) -> None: + await self.stop_all_metrics() + await self._reset_context(context_id) + await self._disconnect() + await self._connect() + await super().on_audio_context_interrupted(context_id) + + async def _connect(self) -> None: + await super()._connect() + await self._connect_websocket() + if self._websocket and not self._receive_task: + self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) + + async def _disconnect(self) -> None: + await super()._disconnect() + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + await self._disconnect_websocket() + + async def _connect_websocket(self) -> None: + try: + if self._websocket and self._websocket.state is State.OPEN: + return + logger.debug("Connecting to Xfyun Super TTS") + auth_url = _build_auth_url(self._url, self._api_key, self._api_secret) + self._websocket = await websocket_connect( + auth_url, + max_size=None, + open_timeout=self._open_timeout, + ) + await self._call_event_handler("on_connected") + except Exception as exc: + 
self._websocket = None + await self.push_error( + error_msg=f"Unable to connect to Xfyun Super TTS: {exc}", + exception=exc, + ) + await self._call_event_handler("on_connection_error", f"{exc}") + + async def _disconnect_websocket(self) -> None: + try: + await self.stop_all_metrics() + if self._websocket: + logger.debug("Disconnecting from Xfyun Super TTS") + await self._websocket.close() + except Exception as exc: + await self.push_error( + error_msg=f"Error closing Xfyun Super TTS websocket: {exc}", + exception=exc, + ) + finally: + await self.remove_active_audio_context() + self._websocket = None + self._active_context_id = None + self._started_contexts.clear() + self._seq_by_context.clear() + self._sent_text_bytes_by_context.clear() + self._stream_completed = False + await self._call_event_handler("on_disconnected") + + def _get_websocket(self): + if self._websocket: + return self._websocket + raise Exception("Websocket not connected") + + async def _receive_messages(self) -> None: + async for raw_message in self._get_websocket(): + try: + message = json.loads(raw_message) + except json.JSONDecodeError: + logger.warning(f"{self}: received non-JSON Xfyun Super TTS message: {raw_message!r}") + continue + + header = message.get("header") or {} + code = header.get("code", -1) + sid = header.get("sid") + context_id = self._active_context_id + + if code != 0: + error_message = header.get("message", "unknown error") + await self.push_error( + error_msg=f"Xfyun Super TTS error code={code}, sid={sid}: {error_message}" + ) + if context_id and self.audio_context_available(context_id): + await self.append_to_audio_context( + context_id, TTSStoppedFrame(context_id=context_id) + ) + await self.remove_audio_context(context_id) + if context_id: + await self._reset_context(context_id) + continue + + audio_obj = (message.get("payload") or {}).get("audio") or {} + audio_b64 = audio_obj.get("audio") + if audio_b64 and context_id and self.audio_context_available(context_id): + 
await self.stop_ttfb_metrics() + audio = base64.b64decode(audio_b64) + if self._source_sample_rate != self.sample_rate: + audio = await self._resampler.resample( + audio, self._source_sample_rate, self.sample_rate + ) + frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=context_id) + await self.append_to_audio_context(context_id, frame) + + audio_status = audio_obj.get("status") + header_status = header.get("status") + if audio_status == 2 or header_status == 2: + if context_id and self.audio_context_available(context_id): + await self.append_to_audio_context( + context_id, TTSStoppedFrame(context_id=context_id) + ) + await self.remove_audio_context(context_id) + if context_id: + await self._reset_context(context_id) + self._stream_completed = True + + @traced_tts + async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]: + sanitized = _sanitize_text_for_tts(text) + if not sanitized: + return + + if not self._is_streaming_tokens: + logger.debug(f"{self}: Generating Xfyun Super TTS [{sanitized}]") + else: + logger.trace(f"{self}: Generating Xfyun Super TTS [{sanitized}]") + + if self._stream_completed and self._websocket: + await self._disconnect() + await self._connect() + + if not self._websocket or self._websocket.state is State.CLOSED: + await self._connect() + + if self._active_context_id and self._active_context_id != context_id: + yield ErrorFrame( + error=( + "Xfyun Super TTS supports one active synthesis stream per WebSocket; " + f"active={self._active_context_id}, new={context_id}" + ) + ) + return + + try: + status = 0 if context_id not in self._started_contexts else 1 + await self._send_request_frame(context_id, sanitized, status=status) + await self.start_tts_usage_metrics(sanitized) + except Exception as exc: + yield ErrorFrame(error=f"Xfyun Super TTS request failed: {exc}") + yield TTSStoppedFrame(context_id=context_id) + await self._disconnect() + await self._connect() + return + + yield None + + async 
def _build_auth_url(url: str, api_key: str, api_secret: str) -> str:
    """Return *url* with signed ``authorization``, ``date`` and ``host`` parameters.

    The signature is an HMAC-SHA256 over the ``host``, ``date`` and request
    line of the WebSocket handshake, keyed with *api_secret*. Any query string
    already present on *url* is kept and the auth parameters are appended to
    it, so the result always contains a single ``?``.

    Args:
        url: ``ws://`` or ``wss://`` endpoint to connect to.
        api_key: Key identifier embedded in the authorization payload.
        api_secret: Secret used as the HMAC key.

    Returns:
        The URL to pass to the WebSocket client.

    Raises:
        ValueError: If *url* is not a ``ws``/``wss`` URL with a hostname.
    """
    parsed = urlparse(url)
    if parsed.scheme not in {"ws", "wss"} or not parsed.hostname:
        raise ValueError(f"invalid Xfyun Super TTS WebSocket URL: {url}")

    host = parsed.hostname
    path = parsed.path or "/"
    date = format_datetime(datetime.now(timezone.utc), usegmt=True)
    request_line = f"GET {path} HTTP/1.1"
    signature_origin = f"host: {host}\ndate: {date}\n{request_line}"
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature = base64.b64encode(signature_sha).decode("utf-8")
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
    auth_query = urlencode({"authorization": authorization, "date": date, "host": host})

    # Merge with an existing query instead of blindly appending "?...", which
    # would yield a malformed URL when the configured endpoint has parameters.
    merged_query = f"{parsed.query}&{auth_query}" if parsed.query else auth_query
    return parsed._replace(query=merged_query).geturl()
"wss:" : "ws:"; @@ -34,6 +39,8 @@ const els = { micLabel: document.querySelector(".mic-btn__label"), micIndicator: document.getElementById("mic-indicator"), botIndicator: document.getElementById("bot-indicator"), + stateIndicator: document.getElementById("state-indicator"), + stateLabel: document.getElementById("state-label"), clearBtn: document.getElementById("clear-btn"), clearWsLogBtn: document.getElementById("clear-ws-log-btn"), wsLog: document.getElementById("ws-log"), @@ -66,17 +73,13 @@ const state = { // Chat state. currentAssistantBubble: null, + assistantState: "", // VU meter smoothing. meterLevel: 0, - // Compact websocket logging. - audioDeltaLogCount: 0, - audioDeltaLogBytes: 0, - lastAudioDeltaLogAt: 0, - audioSendLogCount: 0, - audioSendLogBytes: 0, - lastAudioSendLogAt: 0, + // Collapsible websocket log groups for high-frequency audio frames. + wsLogGroup: null, }; /* ------------------------------------------------------------------ UI */ @@ -123,6 +126,15 @@ function setBotIndicator(active) { els.botIndicator.classList.toggle("is-active", active); } +function setAssistantState(value) { + const text = typeof value === "string" ? value.trim() : ""; + const label = text.length > 32 ? `${text.slice(0, 31)}…` : text; + state.assistantState = text; + els.stateIndicator.classList.toggle("is-active", Boolean(text)); + els.stateLabel.textContent = label ? `State ${label}` : "State -"; + els.stateIndicator.title = label ? `Assistant state: ${text}` : "Assistant state"; +} + function addBubble(role, text) { if (els.chatLog.querySelector(".chat__empty")) { els.chatLog.innerHTML = ""; @@ -157,6 +169,7 @@ function scrollChatToBottom() { function clearChat() { els.chatLog.innerHTML = ""; state.currentAssistantBubble = null; + setAssistantState(""); const empty = document.createElement("div"); empty.className = "chat__empty"; empty.innerHTML = "

Chat cleared.

"; @@ -169,6 +182,209 @@ function truncateLogValue(value, maxLength = 160) { return `${text.slice(0, maxLength - 1)}…`; } +function formatLogTime(date = new Date()) { + return date.toLocaleTimeString([], { + hour12: false, + hour: "2-digit", + minute: "2-digit", + second: "2-digit", + }); +} + +function formatLogBytes(byteCount) { + if (byteCount >= 1048576) { + return `${(byteCount / 1048576).toFixed(2)} MB`; + } + if (byteCount >= 1024) { + return `${(byteCount / 1024).toFixed(1)} KB`; + } + return `${byteCount} bytes`; +} + +function wsLogGroupLabel(groupKey) { + if (groupKey === WS_LOG_GROUP_KEYS.AUDIO_DELTA) { + return "response.audio.delta"; + } + if (groupKey === WS_LOG_GROUP_KEYS.AUDIO_SEND) { + return "input.audio binary"; + } + return "grouped events"; +} + +function ensureWsLogReady() { + if (els.wsLog.querySelector(".ws-log__empty")) { + els.wsLog.innerHTML = ""; + } +} + +function scrollWsLogToBottom() { + els.wsLog.scrollTop = els.wsLog.scrollHeight; +} + +function trimWsLog() { + while (els.wsLog.children.length > MAX_WS_LOG_LINES) { + const first = els.wsLog.firstElementChild; + if (state.wsLogGroup?.element === first) { + state.wsLogGroup = null; + } + first.remove(); + } +} + +function finalizeWsLogGroup() { + state.wsLogGroup = null; +} + +function createWsLogEntry(direction, detail, kind, timeText = formatLogTime()) { + const entry = document.createElement("div"); + entry.className = `ws-log__entry ws-log__entry--${kind}`; + + const time = document.createElement("span"); + time.className = "ws-log__time"; + time.textContent = timeText; + + const dir = document.createElement("span"); + dir.className = "ws-log__direction"; + dir.textContent = + direction === "send" + ? "SEND" + : direction === "recv" + ? 
"RECV" + : direction.toUpperCase(); + + const body = document.createElement("span"); + body.className = "ws-log__detail"; + body.textContent = detail; + + entry.append(time, dir, body); + return entry; +} + +function updateWsLogGroupSummary(group) { + group.summaryEl.textContent = `${wsLogGroupLabel(group.key)} ×${group.count} (${formatLogBytes(group.totalBytes)})`; +} + +function appendWsLogGroupChildDom(group, item) { + const entry = createWsLogEntry( + group.direction, + item.detail, + group.kind, + item.time, + ); + entry.classList.add("ws-log__entry--child"); + group.childrenEl.appendChild(entry); + + const childEntries = group.childrenEl.querySelectorAll(".ws-log__entry"); + if (childEntries.length > MAX_GROUP_CHILDREN_RENDER) { + const omit = group.childrenEl.querySelector(".ws-log__group-omit"); + if (!omit) { + const omitted = document.createElement("div"); + omitted.className = "ws-log__group-omit"; + omitted.textContent = "… earlier events omitted"; + group.childrenEl.insertBefore(omitted, group.childrenEl.firstElementChild); + } + childEntries[0].remove(); + } +} + +function renderWsLogGroupChildren(group) { + group.childrenEl.innerHTML = ""; + const items = group.items; + const start = Math.max(0, items.length - MAX_GROUP_CHILDREN_RENDER); + if (start > 0) { + const omitted = document.createElement("div"); + omitted.className = "ws-log__group-omit"; + omitted.textContent = `… ${start} earlier events omitted`; + group.childrenEl.appendChild(omitted); + } + for (let i = start; i < items.length; i += 1) { + appendWsLogGroupChildDom(group, items[i]); + } +} + +function toggleWsLogGroup(group) { + group.collapsed = !group.collapsed; + group.childrenEl.hidden = group.collapsed; + group.chevronEl.textContent = group.collapsed ? "▶" : "▼"; + group.headerEl.setAttribute("aria-expanded", group.collapsed ? 
"false" : "true"); + + if (!group.collapsed && group.childrenEl.childElementCount === 0) { + renderWsLogGroupChildren(group); + } +} + +function appendWsLogGroupItem(groupKey, direction, kind, itemDetail, byteCount = 0) { + ensureWsLogReady(); + + let group = state.wsLogGroup; + if (!group || group.key !== groupKey) { + finalizeWsLogGroup(); + + const groupEl = document.createElement("div"); + groupEl.className = `ws-log__group ws-log__group--${kind}`; + + const header = document.createElement("button"); + header.type = "button"; + header.className = "ws-log__group-header"; + header.setAttribute("aria-expanded", "false"); + + const time = document.createElement("span"); + time.className = "ws-log__time"; + time.textContent = formatLogTime(); + + const dir = document.createElement("span"); + dir.className = "ws-log__direction"; + dir.textContent = direction === "send" ? "SEND" : "RECV"; + + const chevron = document.createElement("span"); + chevron.className = "ws-log__group-chevron"; + chevron.textContent = "▶"; + chevron.setAttribute("aria-hidden", "true"); + + const summary = document.createElement("span"); + summary.className = "ws-log__group-summary"; + + header.append(time, dir, chevron, summary); + + const children = document.createElement("div"); + children.className = "ws-log__group-children"; + children.hidden = true; + + groupEl.append(header, children); + els.wsLog.appendChild(groupEl); + + group = { + key: groupKey, + direction, + kind, + element: groupEl, + headerEl: header, + chevronEl: chevron, + summaryEl: summary, + childrenEl: children, + collapsed: true, + count: 0, + totalBytes: 0, + items: [], + }; + state.wsLogGroup = group; + header.addEventListener("click", () => toggleWsLogGroup(group)); + } + + group.count += 1; + group.totalBytes += byteCount; + const item = { time: formatLogTime(), detail: itemDetail }; + group.items.push(item); + updateWsLogGroupSummary(group); + + if (!group.collapsed) { + appendWsLogGroupChildDom(group, item); + } + + 
trimWsLog(); + scrollWsLogToBottom(); +} + function compactWsPayload(payload) { if (!payload || typeof payload !== "object") return String(payload); const compact = { ...payload }; @@ -191,85 +407,27 @@ function compactWsPayload(payload) { } function addWsLog(direction, detail, kind = direction) { - if (els.wsLog.querySelector(".ws-log__empty")) { - els.wsLog.innerHTML = ""; - } - - const entry = document.createElement("div"); - entry.className = `ws-log__entry ws-log__entry--${kind}`; - - const time = document.createElement("span"); - time.className = "ws-log__time"; - time.textContent = new Date().toLocaleTimeString([], { - hour12: false, - hour: "2-digit", - minute: "2-digit", - second: "2-digit", - }); - - const dir = document.createElement("span"); - dir.className = "ws-log__direction"; - dir.textContent = - direction === "send" - ? "SEND" - : direction === "recv" - ? "RECV" - : direction.toUpperCase(); - - const body = document.createElement("span"); - body.className = "ws-log__detail"; - body.textContent = detail; - - entry.append(time, dir, body); - els.wsLog.appendChild(entry); - - while (els.wsLog.children.length > MAX_WS_LOG_LINES) { - els.wsLog.firstElementChild.remove(); - } - els.wsLog.scrollTop = els.wsLog.scrollHeight; -} - -function flushAudioDeltaLog() { - if (state.audioDeltaLogCount === 0) return; - addWsLog( - "recv", - `response.audio.delta x${state.audioDeltaLogCount} (${state.audioDeltaLogBytes} bytes)`, - ); - state.audioDeltaLogCount = 0; - state.audioDeltaLogBytes = 0; - state.lastAudioDeltaLogAt = performance.now(); -} - -function flushAudioSendLog() { - if (state.audioSendLogCount === 0) return; - addWsLog( - "send", - `input.audio binary x${state.audioSendLogCount} (${state.audioSendLogBytes} bytes)`, - ); - state.audioSendLogCount = 0; - state.audioSendLogBytes = 0; - state.lastAudioSendLogAt = performance.now(); -} - -function flushPendingWsLogs() { - flushAudioDeltaLog(); - flushAudioSendLog(); + finalizeWsLogGroup(); + 
ensureWsLogReady(); + els.wsLog.appendChild(createWsLogEntry(direction, detail, kind)); + trimWsLog(); + scrollWsLogToBottom(); } function logWsPayload(direction, payload) { - if (direction === "send") { - flushAudioSendLog(); - } else { - flushAudioDeltaLog(); - } - if (direction === "recv" && payload?.type === "response.audio.delta") { - state.audioDeltaLogCount += 1; - state.audioDeltaLogBytes += payload.bytes || payload.audio?.length || 0; - const now = performance.now(); - if (now - state.lastAudioDeltaLogAt >= AUDIO_DELTA_LOG_INTERVAL_MS) { - flushAudioDeltaLog(); - } + const bytes = payload.bytes || 0; + const detail = + payload.seq != null + ? `seq=${payload.seq} (${bytes} bytes)` + : `(${bytes} bytes)`; + appendWsLogGroupItem( + WS_LOG_GROUP_KEYS.AUDIO_DELTA, + "recv", + "recv", + detail, + bytes, + ); return; } @@ -277,12 +435,13 @@ function logWsPayload(direction, payload) { } function logBinarySend(byteLength) { - state.audioSendLogCount += 1; - state.audioSendLogBytes += byteLength; - const now = performance.now(); - if (now - state.lastAudioSendLogAt >= AUDIO_DELTA_LOG_INTERVAL_MS) { - flushAudioSendLog(); - } + appendWsLogGroupItem( + WS_LOG_GROUP_KEYS.AUDIO_SEND, + "send", + "send", + `(${byteLength} bytes)`, + byteLength, + ); } function wsSend(data) { @@ -292,8 +451,6 @@ function wsSend(data) { try { logWsPayload("send", JSON.parse(data)); } catch (_) { - flushAudioSendLog(); - flushAudioDeltaLog(); addWsLog("send", truncateLogValue(data)); } } else { @@ -313,10 +470,7 @@ function wsSend(data) { } function clearWsLog() { - state.audioDeltaLogCount = 0; - state.audioDeltaLogBytes = 0; - state.audioSendLogCount = 0; - state.audioSendLogBytes = 0; + state.wsLogGroup = null; els.wsLog.innerHTML = '
<div class="ws-log__empty">No websocket events yet.</div>
'; } @@ -450,7 +604,6 @@ function stopMic() { state.micEnabled = false; updateMeter(0); if (wasEnabled) { - flushAudioSendLog(); addWsLog("system", "mic capture stopped"); } setMicButton(); @@ -629,6 +782,9 @@ function handleEvent(event) { case "response.text.final": handleAssistantFinal(event.text, event.interrupted); break; + case "response.state": + setAssistantState(event.state); + break; case "input.transcript.final": handleUserTranscript(event.text); break; @@ -745,6 +901,7 @@ async function connect() { state.ws = null; state.connected = false; state.connecting = false; + setAssistantState(""); if (state.micEnabled) stopMic(); stopPlaybackQueue(); setConnectButton(); @@ -752,7 +909,7 @@ async function connect() { setMicSelectEnabled(); setComposerEnabled(false); setBotIndicator(false); - flushPendingWsLogs(); + finalizeWsLogGroup(); addWsLog( "system", `websocket close code=${event.code}${ diff --git a/static/voice-demo/index.html b/static/voice-demo/index.html index deef69e..85b1469 100644 --- a/static/voice-demo/index.html +++ b/static/voice-demo/index.html @@ -118,6 +118,10 @@ Bot + + + State - +