Add support for prefix state code
This commit is contained in:
@@ -148,6 +148,7 @@ Returned transcripts and assistant text:
|
||||
```json
|
||||
{"type": "input.transcript.interim", "text": "What's the"}
|
||||
{"type": "input.transcript.final", "text": "What's the weather?", "user_id": "...", "timestamp": "..."}
|
||||
{"type": "response.state", "state": "speaking"}
|
||||
{"type": "response.text.started"}
|
||||
{"type": "response.text.delta", "text": "It's "}
|
||||
{"type": "response.text.delta", "text": "sunny in "}
|
||||
@@ -163,6 +164,11 @@ the TTS in the pipeline. `response.text.final` fires when the turn ends,
|
||||
carrying the full concatenated assistant text and an `interrupted` flag
|
||||
(true when an `input.text` or barge-in cut the turn short).
|
||||
|
||||
When `agent.response_state.enabled` is true, an LLM response that starts with
|
||||
`<state>...</state>` emits the tag body as `response.state` before the
|
||||
remaining assistant text is streamed and spoken. If the tag is missing or
|
||||
malformed, the original response text is streamed unchanged.
|
||||
|
||||
### Turn detection
|
||||
|
||||
User-turn segmentation (VAD thresholds + how long to wait after silence
|
||||
|
||||
@@ -57,7 +57,13 @@
|
||||
"agent": {
|
||||
"system_prompt": "You are a helpful, friendly voice assistant. Keep responses concise and natural for spoken conversation.",
|
||||
"greeting": "Please introduce yourself briefly.",
|
||||
"greeting_mode": "generated"
|
||||
"greeting_mode": "generated",
|
||||
"response_state": {
|
||||
"enabled": false,
|
||||
"tag": "state",
|
||||
"event_type": "response.state",
|
||||
"max_prefix_chars": 256
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"stt": {
|
||||
|
||||
@@ -119,6 +119,7 @@ ws://<host>:<port>/ws-product
|
||||
|----------|------|
|
||||
| `input.transcript.interim` | 用户语音中间转写(流式 ASR 时) |
|
||||
| `input.transcript.final` | 用户语音最终转写 |
|
||||
| `response.state` | 可选;当 `agent.response_state.enabled` 开启且 LLM 回复以 `<state>...</state>` 开头时发送 |
|
||||
| `response.text.started` | 助手开始回复 |
|
||||
| `response.text.delta` | 助手文本流式片段(通常早于 TTS 音频) |
|
||||
| `response.text.final` | 助手本轮回复结束,见下方 `interrupted` 说明 |
|
||||
|
||||
@@ -107,11 +107,20 @@ class TurnConfig:
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResponseStateConfig:
|
||||
enabled: bool = False
|
||||
tag: str = "state"
|
||||
event_type: str = "response.state"
|
||||
max_prefix_chars: int = 256
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentConfig:
|
||||
system_prompt: str = "You are a helpful, friendly voice assistant."
|
||||
greeting: str | None = None
|
||||
greeting_mode: str = "generated"
|
||||
response_state: ResponseStateConfig = field(default_factory=ResponseStateConfig)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -220,6 +229,13 @@ def config_from_dict(data: dict) -> EngineConfig:
|
||||
agent["greeting"] = None
|
||||
if agent.get("greeting_mode") not in (None, "generated", "fixed", "off"):
|
||||
raise ValueError("agent.greeting_mode must be one of: generated, fixed, off")
|
||||
response_state = ResponseStateConfig(**_dict(agent.pop("response_state")))
|
||||
if response_state.max_prefix_chars < 1:
|
||||
raise ValueError("agent.response_state.max_prefix_chars must be greater than 0")
|
||||
if not response_state.tag:
|
||||
raise ValueError("agent.response_state.tag must not be empty")
|
||||
if not response_state.event_type:
|
||||
raise ValueError("agent.response_state.event_type must not be empty")
|
||||
|
||||
stt = _dict(services.get("stt") or services.get("asr"))
|
||||
if stt.get("language") == "":
|
||||
@@ -260,7 +276,7 @@ def config_from_dict(data: dict) -> EngineConfig:
|
||||
)
|
||||
),
|
||||
),
|
||||
agent=AgentConfig(**agent),
|
||||
agent=AgentConfig(**agent, response_state=response_state),
|
||||
services=ServicesConfig(
|
||||
llm=LLMConfig(**llm),
|
||||
stt=STTConfig(**stt),
|
||||
|
||||
@@ -37,6 +37,7 @@ from .config import EngineConfig
|
||||
from .context_sync import AssistantContextSyncProcessor
|
||||
from .fastgpt_llm import FastGPTLLMService
|
||||
from .product_protocol import ProductWebsocketSerializer
|
||||
from .response_state import StateTagResponseProcessor
|
||||
from .services import create_llm_service, create_stt_service, create_tts_service
|
||||
from .text_input import ProductTextInputProcessor
|
||||
from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context
|
||||
@@ -153,21 +154,26 @@ async def run_pipeline_with_serializer(
|
||||
assistant_aggregator=assistant_aggregator,
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
processors = [
|
||||
transport.input(),
|
||||
ProductTextInputProcessor(),
|
||||
stt,
|
||||
ProductTranscriptStreamProcessor(),
|
||||
context_sync,
|
||||
user_aggregator,
|
||||
llm,
|
||||
]
|
||||
if config.agent.response_state.enabled:
|
||||
processors.append(StateTagResponseProcessor(config.agent.response_state))
|
||||
processors.extend(
|
||||
[
|
||||
transport.input(),
|
||||
ProductTextInputProcessor(),
|
||||
stt,
|
||||
ProductTranscriptStreamProcessor(),
|
||||
context_sync,
|
||||
user_aggregator,
|
||||
llm,
|
||||
text_stream,
|
||||
tts,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
pipeline = Pipeline(processors)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
|
||||
136
engine/response_state.py
Normal file
136
engine/response_state.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
from .config import ResponseStateConfig
|
||||
|
||||
|
||||
class StateTagResponseProcessor(FrameProcessor):
|
||||
"""Extract a leading state tag from LLM text before text streaming and TTS.
|
||||
|
||||
Expected model output:
|
||||
|
||||
<state>some state</state>spoken response
|
||||
|
||||
The extracted state is emitted as a product protocol event, while only the
|
||||
spoken response text is forwarded downstream. If the model does not produce
|
||||
the tag, the original text is forwarded unchanged.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ResponseStateConfig) -> None:
|
||||
super().__init__()
|
||||
self._tag = config.tag
|
||||
self._event_type = config.event_type
|
||||
self._max_prefix_chars = config.max_prefix_chars
|
||||
self._opening_tag = f"<{self._tag}>"
|
||||
self._closing_tag = f"</{self._tag}>"
|
||||
self._start_frame: LLMFullResponseStartFrame | None = None
|
||||
self._buffer = ""
|
||||
self._decided = False
|
||||
self._in_llm_response = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
self._start_frame = frame
|
||||
self._buffer = ""
|
||||
self._decided = False
|
||||
self._in_llm_response = True
|
||||
return
|
||||
|
||||
if isinstance(frame, LLMTextFrame) and self._in_llm_response and not self._decided:
|
||||
await self._process_initial_text(frame.text or "", direction)
|
||||
return
|
||||
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._in_llm_response:
|
||||
await self._flush_buffer(direction)
|
||||
await self.push_frame(frame, direction)
|
||||
self._reset()
|
||||
return
|
||||
|
||||
if isinstance(frame, (InterruptionFrame, CancelFrame)):
|
||||
if self._in_llm_response:
|
||||
await self._flush_buffer(direction)
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_initial_text(self, text: str, direction: FrameDirection) -> None:
|
||||
if not text:
|
||||
return
|
||||
|
||||
self._buffer += text
|
||||
decision = self._parse_buffer()
|
||||
if decision is None:
|
||||
return
|
||||
|
||||
self._decided = True
|
||||
state, response_text = decision
|
||||
if state is not None:
|
||||
await self._emit_state(state)
|
||||
await self._push_start(direction)
|
||||
if response_text:
|
||||
await self.push_frame(LLMTextFrame(response_text), direction)
|
||||
self._buffer = ""
|
||||
|
||||
def _parse_buffer(self) -> tuple[str | None, str] | None:
|
||||
stripped = self._buffer.lstrip()
|
||||
if not stripped:
|
||||
return None
|
||||
|
||||
if stripped.startswith(self._opening_tag):
|
||||
state_start = len(self._opening_tag)
|
||||
state_end = stripped.find(self._closing_tag, state_start)
|
||||
if state_end >= 0:
|
||||
response_start = state_end + len(self._closing_tag)
|
||||
return stripped[state_start:state_end].strip(), stripped[response_start:]
|
||||
if len(self._buffer) < self._max_prefix_chars:
|
||||
return None
|
||||
return None, self._buffer
|
||||
|
||||
if self._opening_tag.startswith(stripped) and len(self._buffer) < self._max_prefix_chars:
|
||||
return None
|
||||
|
||||
return None, self._buffer
|
||||
|
||||
async def _flush_buffer(self, direction: FrameDirection) -> None:
|
||||
await self._push_start(direction)
|
||||
if self._buffer:
|
||||
await self.push_frame(LLMTextFrame(self._buffer), direction)
|
||||
self._buffer = ""
|
||||
self._decided = True
|
||||
|
||||
async def _push_start(self, direction: FrameDirection) -> None:
|
||||
if self._start_frame:
|
||||
await self.push_frame(self._start_frame, direction)
|
||||
self._start_frame = None
|
||||
|
||||
async def _emit_state(self, state: str) -> None:
|
||||
await self.push_frame(
|
||||
OutputTransportMessageUrgentFrame(
|
||||
message={
|
||||
"type": self._event_type,
|
||||
"state": state,
|
||||
}
|
||||
),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
def _reset(self) -> None:
|
||||
self._start_frame = None
|
||||
self._buffer = ""
|
||||
self._decided = False
|
||||
self._in_llm_response = False
|
||||
Reference in New Issue
Block a user