216 lines
7.3 KiB
Python
216 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Protocol
|
|
|
|
from pipecat.frames.frames import (
|
|
CancelFrame,
|
|
Frame,
|
|
InterruptionFrame,
|
|
LLMFullResponseEndFrame,
|
|
LLMFullResponseStartFrame,
|
|
LLMTextFrame,
|
|
OutputTransportMessageUrgentFrame,
|
|
TTSSpeakFrame,
|
|
)
|
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
|
|
|
|
|
class _AssistantContextSync(Protocol):
|
|
@property
|
|
def context(self) -> Any: ...
|
|
|
|
|
|
def _committed_assistant_content(context: Any) -> str:
|
|
"""Return trailing assistant text only when the last context message is assistant."""
|
|
messages = context.get_messages()
|
|
if not messages:
|
|
return ""
|
|
last = messages[-1]
|
|
if not isinstance(last, dict) or last.get("role") != "assistant":
|
|
return ""
|
|
content = last.get("content")
|
|
if isinstance(content, str):
|
|
return content.strip()
|
|
return ""
|
|
|
|
|
|
def sync_streamed_assistant_context(
    aggregator: _AssistantContextSync,
    *,
    streamed_text: str,
    committed_text: str,
) -> None:
    """Align the LLM context with the text that was streamed to the client UI.

    The assistant aggregator commits the text TTS actually spoke, while
    ``ProductTextStreamProcessor`` mirrors the LLM stream to the client.
    When the two differ, the context is rewritten through
    ``aggregator.context.transform_messages``. This lets the next turn see
    what the user read on screen:

    * empty context: append an assistant message;
    * last message is an assistant ``dict`` whose string content differs from
      the streamed text: replace that content, keeping the message's other keys;
    * last message is any other assistant ``dict``: leave it untouched;
    * last two messages are both user ``dict``s: insert the assistant message
      between them;
    * anything else: append an assistant message.

    Both texts are compared after ``strip()``. Nothing happens when the
    streamed text is blank or already equals the committed text.

    Args:
        aggregator: Object whose ``context`` exposes ``transform_messages``.
        streamed_text: Text mirrored to the client for the latest turn.
        committed_text: Assistant text already committed to the context.
    """
    streamed = streamed_text.strip()
    committed = committed_text.strip()
    if not streamed or streamed == committed:
        return

    def _role(message: Any) -> Any:
        # Non-dict entries have no readable role and never match a branch below.
        return message.get("role") if isinstance(message, dict) else None

    def _apply(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # Work on a copy so the list handed in by the context is not mutated.
        updated = list(messages)
        if not updated:
            updated.append({"role": "assistant", "content": streamed})
            return updated

        last = updated[-1]
        if _role(last) == "assistant":
            content = last.get("content")
            # Only plain-string content is rewritten; structured content is left alone.
            if isinstance(content, str) and content.strip() != streamed:
                # Preserve any extra keys (e.g. name/tool_calls) on the message.
                updated[-1] = {**last, "content": streamed}
            return updated

        if _role(last) == "user" and len(updated) >= 2 and _role(updated[-2]) == "user":
            # Two consecutive user turns: the reply to the earlier one was never
            # committed, so slot it in before the latest user message.
            updated.insert(len(updated) - 1, {"role": "assistant", "content": streamed})
            return updated

        updated.append({"role": "assistant", "content": streamed})
        return updated

    aggregator.context.transform_messages(_apply)
|
|
|
|
|
|
def maybe_sync_assistant_context(
    aggregator: _AssistantContextSync,
    text_stream: "ProductTextStreamProcessor",
    *,
    committed_text: str | None = None,
) -> None:
    """Sync the aggregator's context with the processor's last streamed text.

    Args:
        aggregator: Owner of the LLM context that may be rewritten.
        text_stream: Processor whose ``last_assistant_stream_text()`` supplies
            the streamed text; nothing is synced when it returns an empty string.
        committed_text: Assistant text already committed. When ``None``, it is
            read from the trailing assistant message of ``aggregator.context``.
    """
    if committed_text is None:
        committed = _committed_assistant_content(aggregator.context)
    else:
        committed = committed_text.strip()

    streamed = text_stream.last_assistant_stream_text()
    if streamed:
        sync_streamed_assistant_context(
            aggregator,
            streamed_text=streamed,
            committed_text=committed,
        )
|
|
|
|
|
|
class ProductTextStreamProcessor(FrameProcessor):
    """Mirror LLM text output to the client as urgent protocol events.

    Intended to sit between the LLM service and the TTS service. Every frame
    is forwarded unchanged in its original direction. In addition, the
    processor pushes downstream ``OutputTransportMessageUrgentFrame``
    messages whose ``type`` is ``response.text.started``,
    ``response.text.delta`` or ``response.text.final``. The product
    serializer turns these into client events.

    Urgent frames bypass TTS serialization and transport audio queues, so
    text reaches the client at least as quickly as synthesized audio.

    ``TTSSpeakFrame`` (the fixed-greeting path, which never goes through the
    LLM) is mirrored as one complete started/delta/final sequence.
    """

    def __init__(self) -> None:
        super().__init__()
        # Text fragments collected for the currently open turn.
        self._aggregation: list[str] = []
        # True between an emitted "started" event and its matching "final".
        self._turn_active = False
        # Full text of the most recent turn that produced any text.
        self._last_assistant_stream_text = ""
        # Text recorded at the latest interruption, until a caller takes it.
        self._interrupted_stream_text: str | None = None

    def last_assistant_stream_text(self) -> str:
        """Return the full text of the most recent non-empty streamed turn."""
        return self._last_assistant_stream_text

    def take_interrupted_stream_text(self) -> str | None:
        """Return the text recorded at the latest interruption and clear it."""
        taken, self._interrupted_stream_text = self._interrupted_stream_text, None
        return taken

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        await super().process_frame(frame, direction)

        # Read the fixed text before forwarding the frame.
        spoken = (frame.text or "") if isinstance(frame, TTSSpeakFrame) else None

        # The original frame is always forwarded before any mirror event.
        await self.push_frame(frame, direction)

        if isinstance(frame, LLMFullResponseStartFrame):
            await self._start_turn()
        elif isinstance(frame, LLMTextFrame):
            if frame.text:
                await self._delta(frame.text)
        elif isinstance(frame, LLMFullResponseEndFrame):
            await self._end_turn(interrupted=False)
        elif isinstance(frame, (InterruptionFrame, CancelFrame)):
            await self._handle_interrupt()
        elif isinstance(frame, TTSSpeakFrame):
            # No LLM cycle on this path: synthesize a whole started/delta/final run.
            await self._start_turn()
            if spoken:
                await self._delta(spoken)
            await self._end_turn(interrupted=False)

    async def _start_turn(self) -> None:
        """Open a turn and emit ``response.text.started``; no-op if already open."""
        if not self._turn_active:
            self._turn_active = True
            self._aggregation = []
            await self._emit("response.text.started")

    async def _delta(self, text: str) -> None:
        """Record *text* for the open turn and emit it as ``response.text.delta``."""
        # Text outside a turn shouldn't happen; _start_turn is a no-op when a
        # turn is already open, otherwise it emits the missing "started" boundary.
        await self._start_turn()
        self._aggregation.append(text)
        await self._emit("response.text.delta", text=text)

    async def _handle_interrupt(self) -> None:
        """Close an open turn as interrupted, or record the last completed text."""
        if not self._turn_active:
            # No open turn: the last completed (non-empty) text, if any,
            # becomes the interrupted text.
            if self._last_assistant_stream_text:
                self._interrupted_stream_text = self._last_assistant_stream_text
            return
        await self._end_turn(interrupted=True)

    async def _end_turn(self, *, interrupted: bool) -> None:
        """Close the open turn and emit ``response.text.final``; no-op if none is open."""
        if not self._turn_active:
            return

        final_text = "".join(self._aggregation)
        if final_text:
            # Empty turns never overwrite the remembered text.
            self._last_assistant_stream_text = final_text
            if interrupted:
                self._interrupted_stream_text = final_text

        self._turn_active = False
        self._aggregation = []
        await self._emit(
            "response.text.final",
            text=final_text,
            interrupted=interrupted,
        )

    async def _emit(self, event_type: str, **payload: object) -> None:
        """Push a ``{"type": event_type, **payload}`` message downstream as an urgent frame."""
        message = {"type": event_type, **payload}
        urgent = OutputTransportMessageUrgentFrame(message=message)
        await self.push_frame(urgent, FrameDirection.DOWNSTREAM)
|