Add AssistantContextSyncProcessor to manage LLM context synchronization during assistant interruptions. Update ProductTextStreamProcessor to track last streamed text and modify pipeline to integrate new context sync functionality. Enhance tests to verify context sync behavior for interrupted assistant turns.
This commit is contained in:
40
engine/context_sync.py
Normal file
40
engine/context_sync.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pipecat.frames.frames import Frame, InterruptionFrame, LLMMessagesAppendFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context
|
||||
|
||||
|
||||
class AssistantContextSyncProcessor(FrameProcessor):
|
||||
"""Sync LLM context to urgent-streamed assistant text before text-input turns.
|
||||
|
||||
``input.text`` with ``interrupt: true`` queues ``InterruptionFrame`` before
|
||||
``LLMMessagesAppendFrame``. This processor runs context repair after the
|
||||
interrupt has propagated (including TTS-phase interrupts) and before the new
|
||||
user message is appended.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
text_stream: ProductTextStreamProcessor,
|
||||
assistant_aggregator: Any,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._text_stream = text_stream
|
||||
self._assistant_aggregator = assistant_aggregator
|
||||
self._sync_on_next_append = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._sync_on_next_append = True
|
||||
elif isinstance(frame, LLMMessagesAppendFrame) and self._sync_on_next_append:
|
||||
self._sync_on_next_append = False
|
||||
maybe_sync_assistant_context(self._assistant_aggregator, self._text_stream)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -33,10 +33,11 @@ from pipecat.turns.user_stop.speech_timeout_user_turn_stop_strategy import (
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
from .config import EngineConfig
|
||||
from .context_sync import AssistantContextSyncProcessor
|
||||
from .product_protocol import ProductWebsocketSerializer
|
||||
from .services import create_llm_service, create_stt_service, create_tts_service
|
||||
from .text_input import ProductTextInputProcessor
|
||||
from .text_stream import ProductTextStreamProcessor, sync_streamed_assistant_context
|
||||
from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context
|
||||
from .transcript_stream import ProductTranscriptStreamProcessor
|
||||
from .turn_start import InterruptionGateUserTurnStartStrategy
|
||||
|
||||
@@ -143,6 +144,10 @@ async def run_pipeline_with_serializer(
|
||||
)
|
||||
|
||||
text_stream = ProductTextStreamProcessor()
|
||||
context_sync = AssistantContextSyncProcessor(
|
||||
text_stream=text_stream,
|
||||
assistant_aggregator=assistant_aggregator,
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -150,6 +155,7 @@ async def run_pipeline_with_serializer(
|
||||
ProductTextInputProcessor(),
|
||||
stt,
|
||||
ProductTranscriptStreamProcessor(),
|
||||
context_sync,
|
||||
user_aggregator,
|
||||
llm,
|
||||
text_stream,
|
||||
@@ -212,14 +218,12 @@ async def run_pipeline_with_serializer(
|
||||
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
|
||||
async def on_assistant_turn_stopped(_aggregator, message: AssistantTurnStoppedMessage):
|
||||
logger.info(f"Assistant: {message.content}")
|
||||
if message.interrupted:
|
||||
streamed = text_stream.take_interrupted_stream_text()
|
||||
if streamed:
|
||||
sync_streamed_assistant_context(
|
||||
_aggregator,
|
||||
streamed_text=streamed,
|
||||
committed_text=message.content or "",
|
||||
)
|
||||
maybe_sync_assistant_context(
|
||||
_aggregator,
|
||||
text_stream,
|
||||
committed_text=message.content or "",
|
||||
)
|
||||
text_stream.take_interrupted_stream_text()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
await runner.run(task)
|
||||
|
||||
@@ -20,16 +20,31 @@ class _AssistantContextSync(Protocol):
|
||||
def context(self) -> Any: ...
|
||||
|
||||
|
||||
def _committed_assistant_content(context: Any) -> str:
|
||||
"""Return trailing assistant text only when the last context message is assistant."""
|
||||
messages = context.get_messages()
|
||||
if not messages:
|
||||
return ""
|
||||
last = messages[-1]
|
||||
if not isinstance(last, dict) or last.get("role") != "assistant":
|
||||
return ""
|
||||
content = last.get("content")
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
return ""
|
||||
|
||||
|
||||
def sync_streamed_assistant_context(
|
||||
aggregator: _AssistantContextSync,
|
||||
*,
|
||||
streamed_text: str,
|
||||
committed_text: str,
|
||||
) -> None:
|
||||
"""Align LLM context with UI text after an interrupted assistant turn.
|
||||
"""Align LLM context with urgent-streamed UI text.
|
||||
|
||||
The assistant aggregator only commits TTS-spoken text on interrupt. Replace
|
||||
or append the streamed LLM text so the next turn sees what the user saw.
|
||||
The assistant aggregator commits TTS-spoken text; ``ProductTextStreamProcessor``
|
||||
mirrors the LLM stream to the client. Replace or insert the streamed text so
|
||||
the next turn sees what the user read on screen.
|
||||
"""
|
||||
streamed = streamed_text.strip()
|
||||
if not streamed or streamed == committed_text.strip():
|
||||
@@ -39,19 +54,58 @@ def sync_streamed_assistant_context(
|
||||
|
||||
def _apply(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
updated = list(messages)
|
||||
if committed and updated:
|
||||
last = updated[-1]
|
||||
if isinstance(last, dict) and last.get("role") == "assistant":
|
||||
content = last.get("content")
|
||||
if isinstance(content, str) and content.strip() == committed:
|
||||
updated[-1] = {"role": "assistant", "content": streamed}
|
||||
return updated
|
||||
if not updated:
|
||||
updated.append({"role": "assistant", "content": streamed})
|
||||
return updated
|
||||
|
||||
last = updated[-1]
|
||||
if isinstance(last, dict) and last.get("role") == "assistant":
|
||||
content = last.get("content")
|
||||
if isinstance(content, str) and content.strip() != streamed:
|
||||
updated[-1] = {"role": "assistant", "content": streamed}
|
||||
return updated
|
||||
|
||||
if (
|
||||
len(updated) >= 2
|
||||
and isinstance(last, dict)
|
||||
and last.get("role") == "user"
|
||||
):
|
||||
prev = updated[-2]
|
||||
if isinstance(prev, dict) and prev.get("role") == "user":
|
||||
updated.insert(len(updated) - 1, {"role": "assistant", "content": streamed})
|
||||
return updated
|
||||
|
||||
if isinstance(last, dict) and last.get("role") == "user":
|
||||
updated.append({"role": "assistant", "content": streamed})
|
||||
return updated
|
||||
|
||||
updated.append({"role": "assistant", "content": streamed})
|
||||
return updated
|
||||
|
||||
aggregator.context.transform_messages(_apply)
|
||||
|
||||
|
||||
def maybe_sync_assistant_context(
|
||||
aggregator: _AssistantContextSync,
|
||||
text_stream: "ProductTextStreamProcessor",
|
||||
*,
|
||||
committed_text: str | None = None,
|
||||
) -> None:
|
||||
committed = (
|
||||
committed_text.strip()
|
||||
if committed_text is not None
|
||||
else _committed_assistant_content(aggregator.context)
|
||||
)
|
||||
streamed = text_stream.last_assistant_stream_text()
|
||||
if not streamed:
|
||||
return
|
||||
sync_streamed_assistant_context(
|
||||
aggregator,
|
||||
streamed_text=streamed,
|
||||
committed_text=committed,
|
||||
)
|
||||
|
||||
|
||||
class ProductTextStreamProcessor(FrameProcessor):
|
||||
"""Mirrors LLM text frames as streaming protocol events.
|
||||
|
||||
@@ -72,8 +126,12 @@ class ProductTextStreamProcessor(FrameProcessor):
|
||||
super().__init__()
|
||||
self._aggregation: list[str] = []
|
||||
self._turn_active = False
|
||||
self._last_assistant_stream_text = ""
|
||||
self._interrupted_stream_text: str | None = None
|
||||
|
||||
def last_assistant_stream_text(self) -> str:
|
||||
return self._last_assistant_stream_text
|
||||
|
||||
def take_interrupted_stream_text(self) -> str | None:
|
||||
text = self._interrupted_stream_text
|
||||
self._interrupted_stream_text = None
|
||||
@@ -94,7 +152,7 @@ class ProductTextStreamProcessor(FrameProcessor):
|
||||
await self._end_turn(interrupted=False)
|
||||
elif isinstance(frame, (InterruptionFrame, CancelFrame)):
|
||||
await self.push_frame(frame, direction)
|
||||
await self._end_turn(interrupted=True)
|
||||
await self._handle_interrupt()
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
# Fixed-text / direct-speech path: there's no LLM cycle, so
|
||||
# synthesize one started/delta/final sequence for the spoken text.
|
||||
@@ -122,12 +180,24 @@ class ProductTextStreamProcessor(FrameProcessor):
|
||||
self._aggregation.append(text)
|
||||
await self._emit("response.text.delta", text=text)
|
||||
|
||||
async def _handle_interrupt(self) -> None:
|
||||
if self._turn_active:
|
||||
await self._end_turn(interrupted=True)
|
||||
return
|
||||
|
||||
if self._last_assistant_stream_text:
|
||||
self._interrupted_stream_text = self._last_assistant_stream_text
|
||||
|
||||
async def _end_turn(self, *, interrupted: bool) -> None:
|
||||
if not self._turn_active:
|
||||
return
|
||||
|
||||
full_text = "".join(self._aggregation)
|
||||
if full_text:
|
||||
self._last_assistant_stream_text = full_text
|
||||
if interrupted and full_text:
|
||||
self._interrupted_stream_text = full_text
|
||||
|
||||
self._turn_active = False
|
||||
self._aggregation = []
|
||||
await self._emit(
|
||||
|
||||
Reference in New Issue
Block a user