From d94da4b4d1bcead20487105022c25e4e5af45682 Mon Sep 17 00:00:00 2001 From: Xin Wang Date: Thu, 28 May 2026 09:36:44 +0800 Subject: [PATCH] Add support for prefix state code --- README.md | 6 ++ config.json | 8 ++- docs/product-ws.md | 1 + engine/config.py | 18 +++++- engine/pipeline.py | 22 ++++--- engine/response_state.py | 136 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 181 insertions(+), 10 deletions(-) create mode 100644 engine/response_state.py diff --git a/README.md b/README.md index cd72d31..e51943e 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,7 @@ Returned transcripts and assistant text: ```json {"type": "input.transcript.interim", "text": "What's the"} {"type": "input.transcript.final", "text": "What's the weather?", "user_id": "...", "timestamp": "..."} +{"type": "response.state", "state": "speaking"} {"type": "response.text.started"} {"type": "response.text.delta", "text": "It's "} {"type": "response.text.delta", "text": "sunny in "} @@ -163,6 +164,11 @@ the TTS in the pipeline. `response.text.final` fires when the turn ends, carrying the full concatenated assistant text and an `interrupted` flag (true when an `input.text` or barge-in cut the turn short). +When `agent.response_state.enabled` is true, an LLM response that starts with +`...` emits the tag body as `response.state` before the +remaining assistant text is streamed and spoken. If the tag is missing or +malformed, the original response text is streamed unchanged. + ### Turn detection User-turn segmentation (VAD thresholds + how long to wait after silence diff --git a/config.json b/config.json index 2f2a8fb..aab655b 100644 --- a/config.json +++ b/config.json @@ -57,7 +57,13 @@ "agent": { "system_prompt": "You are a helpful, friendly voice assistant. Keep responses concise and natural for spoken conversation.", "greeting": "Please introduce yourself briefly.", - "greeting_mode": "generated" + "greeting_mode": "generated", + "response_state": { + "enabled": false, + "tag": "state", + "event_type": "response.state", + "max_prefix_chars": 256 + } }, "services": { "stt": { diff --git a/docs/product-ws.md b/docs/product-ws.md index 501c5db..7c25c54 100644 --- a/docs/product-ws.md +++ b/docs/product-ws.md @@ -119,6 +119,7 @@ ws://:/ws-product |----------|------| | `input.transcript.interim` | 用户语音中间转写(流式 ASR 时) | | `input.transcript.final` | 用户语音最终转写 | +| `response.state` | 可选;当 `agent.response_state.enabled` 开启且 LLM 回复以 `...` 开头时发送 | | `response.text.started` | 助手开始回复 | | `response.text.delta` | 助手文本流式片段(通常早于 TTS 音频) | | `response.text.final` | 助手本轮回复结束,见下方 `interrupted` 说明 | diff --git a/engine/config.py b/engine/config.py index 3008e16..1d32c4b 100644 --- a/engine/config.py +++ b/engine/config.py @@ -107,11 +107,20 @@ class TurnConfig: ) +@dataclass(frozen=True) +class ResponseStateConfig: + enabled: bool = False + tag: str = "state" + event_type: str = "response.state" + max_prefix_chars: int = 256 + + @dataclass(frozen=True) class AgentConfig: system_prompt: str = "You are a helpful, friendly voice assistant." greeting: str | None = None greeting_mode: str = "generated" + response_state: ResponseStateConfig = field(default_factory=ResponseStateConfig) @dataclass(frozen=True) @@ -220,6 +229,13 @@ def config_from_dict(data: dict) -> EngineConfig: agent["greeting"] = None if agent.get("greeting_mode") not in (None, "generated", "fixed", "off"): raise ValueError("agent.greeting_mode must be one of: generated, fixed, off") + response_state = ResponseStateConfig(**_dict(agent.pop("response_state"))) + if response_state.max_prefix_chars < 1: + raise ValueError("agent.response_state.max_prefix_chars must be greater than 0") + if not response_state.tag: + raise ValueError("agent.response_state.tag must not be empty") + if not response_state.event_type: + raise ValueError("agent.response_state.event_type must not be empty") stt = _dict(services.get("stt") or services.get("asr")) if stt.get("language") == "": @@ -260,7 +276,7 @@ def config_from_dict(data: dict) -> EngineConfig: ) ), ), - agent=AgentConfig(**agent), + agent=AgentConfig(**agent, response_state=response_state), services=ServicesConfig( llm=LLMConfig(**llm), stt=STTConfig(**stt), diff --git a/engine/pipeline.py b/engine/pipeline.py index 97dc646..ec2519f 100644 --- a/engine/pipeline.py +++ b/engine/pipeline.py @@ -37,6 +37,7 @@ from .config import EngineConfig from .context_sync import AssistantContextSyncProcessor from .fastgpt_llm import FastGPTLLMService from .product_protocol import ProductWebsocketSerializer +from .response_state import StateTagResponseProcessor from .services import create_llm_service, create_stt_service, create_tts_service from .text_input import ProductTextInputProcessor from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context @@ -153,21 +154,26 @@ async def run_pipeline_with_serializer( assistant_aggregator=assistant_aggregator, ) - pipeline = Pipeline( + processors = [ + transport.input(), + ProductTextInputProcessor(), + stt, + ProductTranscriptStreamProcessor(), + context_sync, + user_aggregator, + llm, + ] + if config.agent.response_state.enabled: + processors.append(StateTagResponseProcessor(config.agent.response_state)) + processors.extend( [ - transport.input(), - ProductTextInputProcessor(), - stt, - ProductTranscriptStreamProcessor(), - context_sync, - user_aggregator, - llm, text_stream, tts, transport.output(), assistant_aggregator, ] ) + pipeline = Pipeline(processors) task = PipelineTask( pipeline, diff --git a/engine/response_state.py b/engine/response_state.py new file mode 100644 index 0000000..5983061 --- /dev/null +++ b/engine/response_state.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from pipecat.frames.frames import ( + CancelFrame, + Frame, + InterruptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMTextFrame, + OutputTransportMessageUrgentFrame, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +from .config import ResponseStateConfig + + +class StateTagResponseProcessor(FrameProcessor): + """Extract a leading state tag from LLM text before text streaming and TTS. + + Expected model output: + + some statespoken response + + The extracted state is emitted as a product protocol event, while only the + spoken response text is forwarded downstream. If the model does not produce + the tag, the original text is forwarded unchanged. + """ + + def __init__(self, config: ResponseStateConfig) -> None: + super().__init__() + self._tag = config.tag + self._event_type = config.event_type + self._max_prefix_chars = config.max_prefix_chars + self._opening_tag = f"<{self._tag}>" + self._closing_tag = f"" + self._start_frame: LLMFullResponseStartFrame | None = None + self._buffer = "" + self._decided = False + self._in_llm_response = False + + async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: + await super().process_frame(frame, direction) + + if isinstance(frame, LLMFullResponseStartFrame): + self._start_frame = frame + self._buffer = "" + self._decided = False + self._in_llm_response = True + return + + if isinstance(frame, LLMTextFrame) and self._in_llm_response and not self._decided: + await self._process_initial_text(frame.text or "", direction) + return + + if isinstance(frame, LLMFullResponseEndFrame): + if self._in_llm_response: + await self._flush_buffer(direction) + await self.push_frame(frame, direction) + self._reset() + return + + if isinstance(frame, (InterruptionFrame, CancelFrame)): + if self._in_llm_response: + await self._flush_buffer(direction) + self._reset() + await self.push_frame(frame, direction) + return + + await self.push_frame(frame, direction) + + async def _process_initial_text(self, text: str, direction: FrameDirection) -> None: + if not text: + return + + self._buffer += text + decision = self._parse_buffer() + if decision is None: + return + + self._decided = True + state, response_text = decision + if state is not None: + await self._emit_state(state) + await self._push_start(direction) + if response_text: + await self.push_frame(LLMTextFrame(response_text), direction) + self._buffer = "" + + def _parse_buffer(self) -> tuple[str | None, str] | None: + stripped = self._buffer.lstrip() + if not stripped: + return None + + if stripped.startswith(self._opening_tag): + state_start = len(self._opening_tag) + state_end = stripped.find(self._closing_tag, state_start) + if state_end >= 0: + response_start = state_end + len(self._closing_tag) + return stripped[state_start:state_end].strip(), stripped[response_start:] + if len(self._buffer) < self._max_prefix_chars: + return None + return None, self._buffer + + if self._opening_tag.startswith(stripped) and len(self._buffer) < self._max_prefix_chars: + return None + + return None, self._buffer + + async def _flush_buffer(self, direction: FrameDirection) -> None: + await self._push_start(direction) + if self._buffer: + await self.push_frame(LLMTextFrame(self._buffer), direction) + self._buffer = "" + self._decided = True + + async def _push_start(self, direction: FrameDirection) -> None: + if self._start_frame: + await self.push_frame(self._start_frame, direction) + self._start_frame = None + + async def _emit_state(self, state: str) -> None: + await self.push_frame( + OutputTransportMessageUrgentFrame( + message={ + "type": self._event_type, + "state": state, + } + ), + FrameDirection.DOWNSTREAM, + ) + + def _reset(self) -> None: + self._start_frame = None + self._buffer = "" + self._decided = False + self._in_llm_response = False