Add support for prefix state code
This commit is contained in:
@@ -148,6 +148,7 @@ Returned transcripts and assistant text:
|
|||||||
```json
|
```json
|
||||||
{"type": "input.transcript.interim", "text": "What's the"}
|
{"type": "input.transcript.interim", "text": "What's the"}
|
||||||
{"type": "input.transcript.final", "text": "What's the weather?", "user_id": "...", "timestamp": "..."}
|
{"type": "input.transcript.final", "text": "What's the weather?", "user_id": "...", "timestamp": "..."}
|
||||||
|
{"type": "response.state", "state": "speaking"}
|
||||||
{"type": "response.text.started"}
|
{"type": "response.text.started"}
|
||||||
{"type": "response.text.delta", "text": "It's "}
|
{"type": "response.text.delta", "text": "It's "}
|
||||||
{"type": "response.text.delta", "text": "sunny in "}
|
{"type": "response.text.delta", "text": "sunny in "}
|
||||||
@@ -163,6 +164,11 @@ the TTS in the pipeline. `response.text.final` fires when the turn ends,
|
|||||||
carrying the full concatenated assistant text and an `interrupted` flag
|
carrying the full concatenated assistant text and an `interrupted` flag
|
||||||
(true when an `input.text` or barge-in cut the turn short).
|
(true when an `input.text` or barge-in cut the turn short).
|
||||||
|
|
||||||
|
When `agent.response_state.enabled` is true, an LLM response that starts with
|
||||||
|
`<state>...</state>` emits the tag body as `response.state` before the
|
||||||
|
remaining assistant text is streamed and spoken. If the tag is missing or
|
||||||
|
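malformed, the original response text is streamed unchanged. The tag name, the emitted event type, and the number of characters buffered while waiting for the tag are configurable via `agent.response_state.tag`, `event_type`, and `max_prefix_chars`.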
|
||||||
|
|
||||||
### Turn detection
|
### Turn detection
|
||||||
|
|
||||||
User-turn segmentation (VAD thresholds + how long to wait after silence
|
User-turn segmentation (VAD thresholds + how long to wait after silence
|
||||||
|
|||||||
@@ -57,7 +57,13 @@
|
|||||||
"agent": {
|
"agent": {
|
||||||
"system_prompt": "You are a helpful, friendly voice assistant. Keep responses concise and natural for spoken conversation.",
|
"system_prompt": "You are a helpful, friendly voice assistant. Keep responses concise and natural for spoken conversation.",
|
||||||
"greeting": "Please introduce yourself briefly.",
|
"greeting": "Please introduce yourself briefly.",
|
||||||
"greeting_mode": "generated"
|
"greeting_mode": "generated",
|
||||||
|
"response_state": {
|
||||||
|
"enabled": false,
|
||||||
|
"tag": "state",
|
||||||
|
"event_type": "response.state",
|
||||||
|
"max_prefix_chars": 256
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"services": {
|
"services": {
|
||||||
"stt": {
|
"stt": {
|
||||||
|
|||||||
@@ -119,6 +119,7 @@ ws://<host>:<port>/ws-product
|
|||||||
|----------|------|
|
|----------|------|
|
||||||
| `input.transcript.interim` | 用户语音中间转写(流式 ASR 时) |
|
| `input.transcript.interim` | 用户语音中间转写(流式 ASR 时) |
|
||||||
| `input.transcript.final` | 用户语音最终转写 |
|
| `input.transcript.final` | 用户语音最终转写 |
|
||||||
|
| `response.state` | 可选;当 `agent.response_state.enabled` 开启且 LLM 回复以 `<state>...</state>` 开头时发送 |
|
||||||
| `response.text.started` | 助手开始回复 |
|
| `response.text.started` | 助手开始回复 |
|
||||||
| `response.text.delta` | 助手文本流式片段(通常早于 TTS 音频) |
|
| `response.text.delta` | 助手文本流式片段(通常早于 TTS 音频) |
|
||||||
| `response.text.final` | 助手本轮回复结束,见下方 `interrupted` 说明 |
|
| `response.text.final` | 助手本轮回复结束,见下方 `interrupted` 说明 |
|
||||||
|
|||||||
@@ -107,11 +107,20 @@ class TurnConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ResponseStateConfig:
    """Settings for extracting a leading ``<tag>...</tag>`` prefix from LLM replies.

    Values are validated on construction so that a config built directly
    (rather than through ``config_from_dict``) cannot carry an empty tag, an
    empty event type, or a non-positive buffering limit.

    Raises:
        ValueError: If ``max_prefix_chars`` is less than 1, or if ``tag`` or
            ``event_type`` is empty.
    """

    # When true, the pipeline inserts the state-tag processor after the LLM.
    enabled: bool = False
    # Tag name; the processor looks for "<tag>" ... "</tag>" at the start of the reply.
    tag: str = "state"
    # Value of the "type" key in the message emitted for an extracted state.
    event_type: str = "response.state"
    # Once this many characters are buffered without a complete tag, the
    # buffered text is forwarded unchanged.
    max_prefix_chars: int = 256

    def __post_init__(self) -> None:
        # Same checks, order and messages as the validation in config_from_dict.
        if self.max_prefix_chars < 1:
            raise ValueError("agent.response_state.max_prefix_chars must be greater than 0")
        if not self.tag:
            raise ValueError("agent.response_state.tag must not be empty")
        if not self.event_type:
            raise ValueError("agent.response_state.event_type must not be empty")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class AgentConfig:
    """Immutable agent-level settings: prompt, greeting, and response-state options."""

    # System prompt text for the assistant.
    system_prompt: str = "You are a helpful, friendly voice assistant."
    # Optional greeting text; None when no greeting is configured.
    greeting: str | None = None
    # config_from_dict accepts "generated", "fixed" or "off" here.
    # NOTE(review): how each mode uses `greeting` is decided elsewhere — confirm.
    greeting_mode: str = "generated"
    # Options for the optional leading state-tag extraction; defaults to a
    # fresh ResponseStateConfig() (feature disabled).
    response_state: ResponseStateConfig = field(default_factory=ResponseStateConfig)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -220,6 +229,13 @@ def config_from_dict(data: dict) -> EngineConfig:
|
|||||||
agent["greeting"] = None
|
agent["greeting"] = None
|
||||||
if agent.get("greeting_mode") not in (None, "generated", "fixed", "off"):
|
if agent.get("greeting_mode") not in (None, "generated", "fixed", "off"):
|
||||||
raise ValueError("agent.greeting_mode must be one of: generated, fixed, off")
|
raise ValueError("agent.greeting_mode must be one of: generated, fixed, off")
|
||||||
|
response_state = ResponseStateConfig(**_dict(agent.pop("response_state")))
|
||||||
|
if response_state.max_prefix_chars < 1:
|
||||||
|
raise ValueError("agent.response_state.max_prefix_chars must be greater than 0")
|
||||||
|
if not response_state.tag:
|
||||||
|
raise ValueError("agent.response_state.tag must not be empty")
|
||||||
|
if not response_state.event_type:
|
||||||
|
raise ValueError("agent.response_state.event_type must not be empty")
|
||||||
|
|
||||||
stt = _dict(services.get("stt") or services.get("asr"))
|
stt = _dict(services.get("stt") or services.get("asr"))
|
||||||
if stt.get("language") == "":
|
if stt.get("language") == "":
|
||||||
@@ -260,7 +276,7 @@ def config_from_dict(data: dict) -> EngineConfig:
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
agent=AgentConfig(**agent),
|
agent=AgentConfig(**agent, response_state=response_state),
|
||||||
services=ServicesConfig(
|
services=ServicesConfig(
|
||||||
llm=LLMConfig(**llm),
|
llm=LLMConfig(**llm),
|
||||||
stt=STTConfig(**stt),
|
stt=STTConfig(**stt),
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from .config import EngineConfig
|
|||||||
from .context_sync import AssistantContextSyncProcessor
|
from .context_sync import AssistantContextSyncProcessor
|
||||||
from .fastgpt_llm import FastGPTLLMService
|
from .fastgpt_llm import FastGPTLLMService
|
||||||
from .product_protocol import ProductWebsocketSerializer
|
from .product_protocol import ProductWebsocketSerializer
|
||||||
|
from .response_state import StateTagResponseProcessor
|
||||||
from .services import create_llm_service, create_stt_service, create_tts_service
|
from .services import create_llm_service, create_stt_service, create_tts_service
|
||||||
from .text_input import ProductTextInputProcessor
|
from .text_input import ProductTextInputProcessor
|
||||||
from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context
|
from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context
|
||||||
@@ -153,21 +154,26 @@ async def run_pipeline_with_serializer(
|
|||||||
assistant_aggregator=assistant_aggregator,
|
assistant_aggregator=assistant_aggregator,
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = Pipeline(
|
processors = [
|
||||||
|
transport.input(),
|
||||||
|
ProductTextInputProcessor(),
|
||||||
|
stt,
|
||||||
|
ProductTranscriptStreamProcessor(),
|
||||||
|
context_sync,
|
||||||
|
user_aggregator,
|
||||||
|
llm,
|
||||||
|
]
|
||||||
|
if config.agent.response_state.enabled:
|
||||||
|
processors.append(StateTagResponseProcessor(config.agent.response_state))
|
||||||
|
processors.extend(
|
||||||
[
|
[
|
||||||
transport.input(),
|
|
||||||
ProductTextInputProcessor(),
|
|
||||||
stt,
|
|
||||||
ProductTranscriptStreamProcessor(),
|
|
||||||
context_sync,
|
|
||||||
user_aggregator,
|
|
||||||
llm,
|
|
||||||
text_stream,
|
text_stream,
|
||||||
tts,
|
tts,
|
||||||
transport.output(),
|
transport.output(),
|
||||||
assistant_aggregator,
|
assistant_aggregator,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
pipeline = Pipeline(processors)
|
||||||
|
|
||||||
task = PipelineTask(
|
task = PipelineTask(
|
||||||
pipeline,
|
pipeline,
|
||||||
|
|||||||
136
engine/response_state.py
Normal file
136
engine/response_state.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pipecat.frames.frames import (
|
||||||
|
CancelFrame,
|
||||||
|
Frame,
|
||||||
|
InterruptionFrame,
|
||||||
|
LLMFullResponseEndFrame,
|
||||||
|
LLMFullResponseStartFrame,
|
||||||
|
LLMTextFrame,
|
||||||
|
OutputTransportMessageUrgentFrame,
|
||||||
|
)
|
||||||
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||||
|
|
||||||
|
from .config import ResponseStateConfig
|
||||||
|
|
||||||
|
|
||||||
|
class StateTagResponseProcessor(FrameProcessor):
    """Peel a leading ``<tag>...</tag>`` prefix off an LLM response.

    The model is expected to answer in the form::

        <state>some state</state>spoken response

    While a response is starting, incoming ``LLMTextFrame`` text is buffered
    and the ``LLMFullResponseStartFrame`` is held back until the prefix can be
    classified. When a complete leading tag is found, its stripped body is
    pushed downstream as an ``OutputTransportMessageUrgentFrame`` carrying
    ``{"type": <event_type>, "state": <body>}``, followed by the held start
    frame and whatever text came after the closing tag. When the buffered
    text cannot be a tagged prefix, or the buffer has reached
    ``max_prefix_chars`` without a complete tag, the buffered text is
    forwarded exactly as received.
    """

    def __init__(self, config: ResponseStateConfig) -> None:
        """Capture tag settings from *config* and start with no response in flight."""
        super().__init__()

        # Static settings derived from the configuration.
        tag = config.tag
        self._tag = tag
        self._opening_tag = f"<{tag}>"
        self._closing_tag = f"</{tag}>"
        self._event_type = config.event_type
        self._max_prefix_chars = config.max_prefix_chars

        # Per-response state, cleared again by _reset().
        self._start_frame: LLMFullResponseStartFrame | None = None
        self._buffer = ""
        self._decided = False
        self._in_llm_response = False

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Route one frame, buffering leading LLM text until the prefix is classified."""
        await super().process_frame(frame, direction)

        if isinstance(frame, LLMFullResponseStartFrame):
            # Hold the start frame back; it is released once the prefix is
            # classified (or the response ends / is interrupted).
            self._reset()
            self._start_frame = frame
            self._in_llm_response = True
        elif (
            isinstance(frame, LLMTextFrame)
            and self._in_llm_response
            and not self._decided
        ):
            # Still undecided: swallow the frame and accumulate its text.
            await self._process_initial_text(frame.text or "", direction)
        elif isinstance(frame, LLMFullResponseEndFrame):
            # Release anything still held before the end marker goes out.
            if self._in_llm_response:
                await self._flush_buffer(direction)
            await self.push_frame(frame, direction)
            self._reset()
        elif isinstance(frame, (InterruptionFrame, CancelFrame)):
            # Same release as above, but state is cleared before forwarding.
            if self._in_llm_response:
                await self._flush_buffer(direction)
            self._reset()
            await self.push_frame(frame, direction)
        else:
            # Everything else (including LLM text after the decision) passes through.
            await self.push_frame(frame, direction)

    async def _process_initial_text(self, text: str, direction: FrameDirection) -> None:
        """Append *text* to the buffer and act as soon as the prefix is classified."""
        if not text:
            return

        self._buffer += text
        outcome = self._parse_buffer()
        if outcome is None:
            # Not enough text yet to tell whether a tag is present.
            return

        self._decided = True
        state, remainder = outcome

        # The state event goes out first, then the held start frame, then text.
        if state is not None:
            await self._emit_state(state)
        await self._push_start(direction)
        if remainder:
            await self.push_frame(LLMTextFrame(remainder), direction)
        self._buffer = ""

    def _parse_buffer(self) -> tuple[str | None, str] | None:
        """Classify the buffered prefix.

        Returns:
            ``None`` while more text is needed; ``(state, rest)`` when a
            complete leading tag was found; ``(None, buffer)`` when the
            buffered text should be forwarded untouched.
        """
        candidate = self._buffer.lstrip()
        if not candidate:
            return None

        may_keep_waiting = len(self._buffer) < self._max_prefix_chars
        passthrough = (None, self._buffer)

        if candidate.startswith(self._opening_tag):
            after_open = candidate[len(self._opening_tag):]
            body, closing, rest = after_open.partition(self._closing_tag)
            if closing:
                return body.strip(), rest
            # Opening tag seen but not closed yet.
            return None if may_keep_waiting else passthrough

        # The buffer may still be a partial opening tag such as "<sta".
        if may_keep_waiting and self._opening_tag.startswith(candidate):
            return None
        return passthrough

    async def _flush_buffer(self, direction: FrameDirection) -> None:
        """Release the held start frame and any buffered text without parsing it."""
        await self._push_start(direction)
        if self._buffer:
            await self.push_frame(LLMTextFrame(self._buffer), direction)
        self._buffer = ""
        self._decided = True

    async def _push_start(self, direction: FrameDirection) -> None:
        """Forward the held start frame, if there is one, and stop holding it."""
        held = self._start_frame
        if not held:
            return
        await self.push_frame(held, direction)
        self._start_frame = None

    async def _emit_state(self, state: str) -> None:
        """Push the extracted state downstream as an urgent transport message."""
        payload = {
            "type": self._event_type,
            "state": state,
        }
        await self.push_frame(
            OutputTransportMessageUrgentFrame(message=payload),
            FrameDirection.DOWNSTREAM,
        )

    def _reset(self) -> None:
        """Clear all per-response state."""
        self._start_frame = None
        self._buffer = ""
        self._decided = self._in_llm_response = False
|
||||||
Reference in New Issue
Block a user