Files
engine-v5-pipecat-core/engine/text_stream.py

216 lines
7.3 KiB
Python

from __future__ import annotations
from typing import Any, Protocol
from pipecat.frames.frames import (
CancelFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
OutputTransportMessageUrgentFrame,
TTSSpeakFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class _AssistantContextSync(Protocol):
    """Structural type for objects that expose an LLM ``context`` attribute.

    The sync helpers in this module only ever read ``aggregator.context``,
    so any object offering that attribute satisfies this protocol.
    """

    @property
    def context(self) -> Any:
        """Return the context object whose messages are read and rewritten."""
        ...
def _committed_assistant_content(context: Any) -> str:
    """Return the stripped text of a trailing assistant message, else ``""``.

    The empty string is returned when the context holds no messages, when
    the final message is not a ``dict`` with ``role == "assistant"``, or when
    that message's ``content`` is not a plain string.

    Args:
        context: Object exposing ``get_messages()``.

    Returns:
        The whitespace-stripped assistant content, or ``""``.
    """
    history = context.get_messages()
    if not history:
        return ""

    tail = history[-1]
    ends_with_assistant = isinstance(tail, dict) and tail.get("role") == "assistant"
    if not ends_with_assistant:
        return ""

    text = tail.get("content")
    return text.strip() if isinstance(text, str) else ""
def sync_streamed_assistant_context(
    aggregator: _AssistantContextSync,
    *,
    streamed_text: str,
    committed_text: str,
) -> None:
    """Align LLM context with urgent-streamed UI text.

    The assistant aggregator commits TTS-spoken text; ``ProductTextStreamProcessor``
    mirrors the LLM stream to the client. Replace or insert the streamed text so
    the next turn sees what the user read on screen.

    Nothing happens when the streamed text is empty or already equals the
    committed text (both compared after stripping whitespace). Otherwise the
    message list is rewritten through ``aggregator.context.transform_messages``.

    Args:
        aggregator: Object whose ``context`` provides ``transform_messages``.
        streamed_text: Text that was streamed to the client for this turn.
        committed_text: Assistant text currently committed to the context.
    """
    streamed = streamed_text.strip()
    if not streamed or streamed == committed_text.strip():
        return

    def _apply(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return a copy of ``messages`` with the streamed text reconciled."""
        # Work on a shallow copy so the caller's list is never mutated.
        updated = list(messages)
        assistant_message = {"role": "assistant", "content": streamed}
        last = updated[-1] if updated else None
        last_role = last.get("role") if isinstance(last, dict) else None

        if last_role == "assistant":
            # Overwrite only plain-string content that differs; non-string
            # content (e.g. structured parts) is deliberately left untouched.
            content = last.get("content")
            if isinstance(content, str) and content.strip() != streamed:
                updated[-1] = assistant_message
            return updated

        if last_role == "user" and len(updated) >= 2:
            prev = updated[-2]
            if isinstance(prev, dict) and prev.get("role") == "user":
                # Two user messages in a row: the assistant reply belongs
                # between them, ahead of the newest user message.
                updated.insert(len(updated) - 1, assistant_message)
                return updated

        # Empty history, a single trailing user message, or any other
        # trailing entry: append the streamed reply at the end.
        updated.append(assistant_message)
        return updated

    aggregator.context.transform_messages(_apply)
def maybe_sync_assistant_context(
    aggregator: _AssistantContextSync,
    text_stream: "ProductTextStreamProcessor",
    *,
    committed_text: str | None = None,
) -> None:
    """Sync the context with the last streamed assistant text, if there is any.

    Args:
        aggregator: Object whose ``context`` holds the LLM messages.
        text_stream: Processor queried via ``last_assistant_stream_text()``.
        committed_text: Text already committed to the context. When ``None``
            it is read from the trailing assistant message of the context.
    """
    # Resolve the committed baseline first, then read the stream text.
    if committed_text is None:
        baseline = _committed_assistant_content(aggregator.context)
    else:
        baseline = committed_text.strip()

    on_screen = text_stream.last_assistant_stream_text()
    if on_screen:
        sync_streamed_assistant_context(
            aggregator,
            streamed_text=on_screen,
            committed_text=baseline,
        )
class ProductTextStreamProcessor(FrameProcessor):
    """Mirror LLM text frames to the client as streaming protocol events.

    Sitting between the LLM service and the TTS service, the processor
    forwards every frame unchanged and additionally pushes
    ``OutputTransportMessageUrgentFrame``s downstream, which the product
    serializer turns into ``response.text.{started,delta,final}`` events.
    Urgent frames bypass TTS serialization and transport audio queues so text
    reaches the client at least as quickly as synthesized audio.

    ``TTSSpeakFrame`` (the fixed-greeting path, which bypasses the LLM
    entirely) is handled as well: a single started/delta/final sequence is
    synthesized for its fixed text.
    """

    def __init__(self) -> None:
        super().__init__()
        # Text chunks collected for the turn currently in progress.
        self._aggregation: list[str] = []
        # True between a "started" and its matching "final" event.
        self._turn_active = False
        # Full text of the most recent turn that produced any text.
        self._last_assistant_stream_text = ""
        # Text recorded on interruption, until it is taken by a caller.
        self._interrupted_stream_text: str | None = None

    def last_assistant_stream_text(self) -> str:
        """Return the full text of the latest non-empty streamed turn."""
        return self._last_assistant_stream_text

    def take_interrupted_stream_text(self) -> str | None:
        """Return the recorded interrupted text and clear it (read-once)."""
        pending, self._interrupted_stream_text = self._interrupted_stream_text, None
        return pending

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Forward ``frame`` and emit the matching text-stream events.

        In every case the original frame is pushed on before any protocol
        event derived from it is emitted.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, LLMFullResponseStartFrame):
            await self.push_frame(frame, direction)
            await self._start_turn()
            return

        if isinstance(frame, LLMTextFrame):
            await self.push_frame(frame, direction)
            if frame.text:
                await self._delta(frame.text)
            return

        if isinstance(frame, LLMFullResponseEndFrame):
            await self.push_frame(frame, direction)
            await self._end_turn(interrupted=False)
            return

        if isinstance(frame, (InterruptionFrame, CancelFrame)):
            await self.push_frame(frame, direction)
            await self._handle_interrupt()
            return

        if isinstance(frame, TTSSpeakFrame):
            # Fixed-text / direct-speech path: no LLM cycle exists, so one
            # started/delta/final sequence is synthesized for the spoken text.
            spoken = frame.text or ""
            await self.push_frame(frame, direction)
            await self._start_turn()
            if spoken:
                await self._delta(spoken)
            await self._end_turn(interrupted=False)
            return

        await self.push_frame(frame, direction)

    async def _start_turn(self) -> None:
        """Open a turn and emit ``response.text.started`` unless one is open."""
        if not self._turn_active:
            self._turn_active = True
            self._aggregation = []
            await self._emit("response.text.started")

    async def _delta(self, text: str) -> None:
        """Record ``text`` for the current turn and emit a delta event."""
        if not self._turn_active:
            # Text outside a turn is unexpected; synthesize the started
            # boundary so the client still renders sensibly.
            await self._start_turn()
        self._aggregation.append(text)
        await self._emit("response.text.delta", text=text)

    async def _handle_interrupt(self) -> None:
        """Close an open turn as interrupted, or flag the last finished text."""
        if self._turn_active:
            await self._end_turn(interrupted=True)
        elif self._last_assistant_stream_text:
            self._interrupted_stream_text = self._last_assistant_stream_text

    async def _end_turn(self, *, interrupted: bool) -> None:
        """Close the open turn and emit ``response.text.final``.

        Does nothing when no turn is open. An empty turn leaves the
        previously stored texts unchanged.
        """
        if not self._turn_active:
            return

        turn_text = "".join(self._aggregation)
        if turn_text:
            self._last_assistant_stream_text = turn_text
            if interrupted:
                self._interrupted_stream_text = turn_text

        self._turn_active = False
        self._aggregation = []
        await self._emit(
            "response.text.final",
            text=turn_text,
            interrupted=interrupted,
        )

    async def _emit(self, event_type: str, **payload: object) -> None:
        """Push one protocol event downstream as an urgent transport message."""
        message = {"type": event_type, **payload}
        urgent_frame = OutputTransportMessageUrgentFrame(message=message)
        await self.push_frame(urgent_frame, FrameDirection.DOWNSTREAM)