137 lines
4.6 KiB
Python
137 lines
4.6 KiB
Python
from __future__ import annotations
|
|
|
|
from pipecat.frames.frames import (
|
|
CancelFrame,
|
|
Frame,
|
|
InterruptionFrame,
|
|
LLMFullResponseEndFrame,
|
|
LLMFullResponseStartFrame,
|
|
LLMTextFrame,
|
|
OutputTransportMessageUrgentFrame,
|
|
)
|
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
|
|
|
from .config import ResponseStateConfig
|
|
|
|
|
|
class StateTagResponseProcessor(FrameProcessor):
    """Strip a leading state marker from streamed LLM text.

    The model is expected to begin each response like::

        <state>some state</state>spoken response

    where the tag name is taken from ``config.tag``. Text at the start of a
    response is buffered until the marker has either been read completely or
    been ruled out:

    * Marker found: its content is pushed downstream as an
      ``OutputTransportMessageUrgentFrame`` whose message ``type`` is
      ``config.event_type``, and only the text following the closing tag is
      pushed on as an ``LLMTextFrame``.
    * No marker, or the marker is not completed before the buffer reaches
      ``config.max_prefix_chars`` characters: the buffered text is pushed on
      unchanged.

    The ``LLMFullResponseStartFrame`` is withheld until that decision is made
    (or the response ends, is interrupted or cancelled), so a state event is
    pushed before the start frame of the response it belongs to.
    """

    def __init__(self, config: ResponseStateConfig) -> None:
        """Read tag settings from *config* and start with no response open.

        Args:
            config: Supplies ``tag``, ``event_type`` and ``max_prefix_chars``.
        """
        super().__init__()
        tag_name = config.tag
        self._tag = tag_name
        self._event_type = config.event_type
        self._max_prefix_chars = config.max_prefix_chars
        self._opening_tag = f"<{tag_name}>"
        self._closing_tag = f"</{tag_name}>"
        # Per-response state; cleared again by _reset().
        self._start_frame: LLMFullResponseStartFrame | None = None
        self._buffer = ""
        self._decided = False
        self._in_llm_response = False

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Route one frame through the tag-extraction state machine.

        Args:
            frame: The incoming frame.
            direction: Direction the frame is travelling; reused when pushing.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, LLMFullResponseStartFrame):
            # Withhold the start frame; _push_start() releases it later.
            self._start_frame = frame
            self._buffer = ""
            self._decided = False
            self._in_llm_response = True
        elif (
            isinstance(frame, LLMTextFrame)
            and self._in_llm_response
            and not self._decided
        ):
            # Still looking at the head of the response: buffer and inspect.
            await self._process_initial_text(frame.text or "", direction)
        elif isinstance(frame, LLMFullResponseEndFrame):
            # Release anything still held before closing the response.
            if self._in_llm_response:
                await self._flush_buffer(direction)
            await self.push_frame(frame, direction)
            self._reset()
        elif isinstance(frame, (InterruptionFrame, CancelFrame)):
            # Same release as above, but state is cleared before forwarding.
            if self._in_llm_response:
                await self._flush_buffer(direction)
            self._reset()
            await self.push_frame(frame, direction)
        else:
            # Everything else (including text after the decision) passes through.
            await self.push_frame(frame, direction)

    async def _process_initial_text(self, text: str, direction: FrameDirection) -> None:
        """Buffer *text* and, once the prefix is resolved, release the response.

        Empty text is ignored. While ``_parse_buffer`` cannot decide yet,
        nothing is pushed.
        """
        if not text:
            return

        self._buffer = self._buffer + text
        outcome = self._parse_buffer()
        if outcome is None:
            # Not enough text yet to accept or reject the tag.
            return

        self._decided = True
        tag_state, spoken_text = outcome
        if tag_state is not None:
            await self._emit_state(tag_state)
        await self._push_start(direction)
        if spoken_text:
            await self.push_frame(LLMTextFrame(spoken_text), direction)
        self._buffer = ""

    def _parse_buffer(self) -> tuple[str | None, str] | None:
        """Classify the buffered prefix.

        Returns:
            ``None`` while more text is needed; ``(state, rest)`` when a
            complete tag was read, where ``state`` is the stripped tag content
            and ``rest`` is the text after the closing tag; or
            ``(None, buffer)`` when the text should be forwarded untouched.
        """
        candidate = self._buffer.lstrip()
        if not candidate:
            # Only whitespace (or nothing) so far.
            return None

        if candidate.startswith(self._opening_tag):
            after_open = candidate[len(self._opening_tag):]
            inner, closer, remainder = after_open.partition(self._closing_tag)
            if closer:
                return inner.strip(), remainder
            # Tag opened but not closed yet: wait only while under the limit.
            keep_waiting = len(self._buffer) < self._max_prefix_chars
        else:
            # The text might still be an incomplete opening tag, e.g. "<sta".
            keep_waiting = (
                self._opening_tag.startswith(candidate)
                and len(self._buffer) < self._max_prefix_chars
            )

        return None if keep_waiting else (None, self._buffer)

    async def _flush_buffer(self, direction: FrameDirection) -> None:
        """Push the withheld start frame and any buffered text, then mark decided."""
        await self._push_start(direction)
        pending_text = self._buffer
        if pending_text:
            await self.push_frame(LLMTextFrame(pending_text), direction)
        self._buffer = ""
        self._decided = True

    async def _push_start(self, direction: FrameDirection) -> None:
        """Push the withheld start frame, if there is one, and forget it."""
        held_frame = self._start_frame
        if not held_frame:
            return
        await self.push_frame(held_frame, direction)
        self._start_frame = None

    async def _emit_state(self, state: str) -> None:
        """Push *state* downstream as an urgent transport message."""
        payload = {
            "type": self._event_type,
            "state": state,
        }
        await self.push_frame(
            OutputTransportMessageUrgentFrame(message=payload),
            FrameDirection.DOWNSTREAM,
        )

    def _reset(self) -> None:
        """Clear all per-response state so no response is considered open."""
        self._in_llm_response = False
        self._decided = False
        self._buffer = ""
        self._start_frame = None
|