Add AssistantContextSyncProcessor to manage LLM context synchronization during assistant interruptions. Update ProductTextStreamProcessor to track last streamed text and modify pipeline to integrate new context sync functionality. Enhance tests to verify context sync behavior for interrupted assistant turns.

This commit is contained in:
Xin Wang
2026-05-26 10:44:12 +08:00
parent e7a8cb1faa
commit 3dfff0c937
3 changed files with 134 additions and 20 deletions

40
engine/context_sync.py Normal file
View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from typing import Any
from pipecat.frames.frames import Frame, InterruptionFrame, LLMMessagesAppendFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context
class AssistantContextSyncProcessor(FrameProcessor):
"""Sync LLM context to urgent-streamed assistant text before text-input turns.
``input.text`` with ``interrupt: true`` queues ``InterruptionFrame`` before
``LLMMessagesAppendFrame``. This processor runs context repair after the
interrupt has propagated (including TTS-phase interrupts) and before the new
user message is appended.
"""
def __init__(
self,
*,
text_stream: ProductTextStreamProcessor,
assistant_aggregator: Any,
) -> None:
super().__init__()
self._text_stream = text_stream
self._assistant_aggregator = assistant_aggregator
self._sync_on_next_append = False
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
self._sync_on_next_append = True
elif isinstance(frame, LLMMessagesAppendFrame) and self._sync_on_next_append:
self._sync_on_next_append = False
maybe_sync_assistant_context(self._assistant_aggregator, self._text_stream)
await self.push_frame(frame, direction)

View File

@@ -33,10 +33,11 @@ from pipecat.turns.user_stop.speech_timeout_user_turn_stop_strategy import (
from pipecat.turns.user_turn_strategies import UserTurnStrategies
from .config import EngineConfig
from .context_sync import AssistantContextSyncProcessor
from .product_protocol import ProductWebsocketSerializer
from .services import create_llm_service, create_stt_service, create_tts_service
from .text_input import ProductTextInputProcessor
from .text_stream import ProductTextStreamProcessor, sync_streamed_assistant_context
from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context
from .transcript_stream import ProductTranscriptStreamProcessor
from .turn_start import InterruptionGateUserTurnStartStrategy
@@ -143,6 +144,10 @@ async def run_pipeline_with_serializer(
)
text_stream = ProductTextStreamProcessor()
context_sync = AssistantContextSyncProcessor(
text_stream=text_stream,
assistant_aggregator=assistant_aggregator,
)
pipeline = Pipeline(
[
@@ -150,6 +155,7 @@ async def run_pipeline_with_serializer(
ProductTextInputProcessor(),
stt,
ProductTranscriptStreamProcessor(),
context_sync,
user_aggregator,
llm,
text_stream,
@@ -212,14 +218,12 @@ async def run_pipeline_with_serializer(
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
async def on_assistant_turn_stopped(_aggregator, message: AssistantTurnStoppedMessage):
logger.info(f"Assistant: {message.content}")
if message.interrupted:
streamed = text_stream.take_interrupted_stream_text()
if streamed:
sync_streamed_assistant_context(
_aggregator,
streamed_text=streamed,
committed_text=message.content or "",
)
maybe_sync_assistant_context(
_aggregator,
text_stream,
committed_text=message.content or "",
)
text_stream.take_interrupted_stream_text()
runner = PipelineRunner(handle_sigint=False)
await runner.run(task)

View File

@@ -20,16 +20,31 @@ class _AssistantContextSync(Protocol):
def context(self) -> Any: ...
def _committed_assistant_content(context: Any) -> str:
"""Return trailing assistant text only when the last context message is assistant."""
messages = context.get_messages()
if not messages:
return ""
last = messages[-1]
if not isinstance(last, dict) or last.get("role") != "assistant":
return ""
content = last.get("content")
if isinstance(content, str):
return content.strip()
return ""
def sync_streamed_assistant_context(
aggregator: _AssistantContextSync,
*,
streamed_text: str,
committed_text: str,
) -> None:
"""Align LLM context with UI text after an interrupted assistant turn.
"""Align LLM context with urgent-streamed UI text.
The assistant aggregator only commits TTS-spoken text on interrupt. Replace
or append the streamed LLM text so the next turn sees what the user saw.
The assistant aggregator commits TTS-spoken text; ``ProductTextStreamProcessor``
mirrors the LLM stream to the client. Replace or insert the streamed text so
the next turn sees what the user read on screen.
"""
streamed = streamed_text.strip()
if not streamed or streamed == committed_text.strip():
@@ -39,19 +54,58 @@ def sync_streamed_assistant_context(
def _apply(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
updated = list(messages)
if committed and updated:
last = updated[-1]
if isinstance(last, dict) and last.get("role") == "assistant":
content = last.get("content")
if isinstance(content, str) and content.strip() == committed:
updated[-1] = {"role": "assistant", "content": streamed}
return updated
if not updated:
updated.append({"role": "assistant", "content": streamed})
return updated
last = updated[-1]
if isinstance(last, dict) and last.get("role") == "assistant":
content = last.get("content")
if isinstance(content, str) and content.strip() != streamed:
updated[-1] = {"role": "assistant", "content": streamed}
return updated
if (
len(updated) >= 2
and isinstance(last, dict)
and last.get("role") == "user"
):
prev = updated[-2]
if isinstance(prev, dict) and prev.get("role") == "user":
updated.insert(len(updated) - 1, {"role": "assistant", "content": streamed})
return updated
if isinstance(last, dict) and last.get("role") == "user":
updated.append({"role": "assistant", "content": streamed})
return updated
updated.append({"role": "assistant", "content": streamed})
return updated
aggregator.context.transform_messages(_apply)
def maybe_sync_assistant_context(
aggregator: _AssistantContextSync,
text_stream: "ProductTextStreamProcessor",
*,
committed_text: str | None = None,
) -> None:
committed = (
committed_text.strip()
if committed_text is not None
else _committed_assistant_content(aggregator.context)
)
streamed = text_stream.last_assistant_stream_text()
if not streamed:
return
sync_streamed_assistant_context(
aggregator,
streamed_text=streamed,
committed_text=committed,
)
class ProductTextStreamProcessor(FrameProcessor):
"""Mirrors LLM text frames as streaming protocol events.
@@ -72,8 +126,12 @@ class ProductTextStreamProcessor(FrameProcessor):
super().__init__()
self._aggregation: list[str] = []
self._turn_active = False
self._last_assistant_stream_text = ""
self._interrupted_stream_text: str | None = None
def last_assistant_stream_text(self) -> str:
return self._last_assistant_stream_text
def take_interrupted_stream_text(self) -> str | None:
text = self._interrupted_stream_text
self._interrupted_stream_text = None
@@ -94,7 +152,7 @@ class ProductTextStreamProcessor(FrameProcessor):
await self._end_turn(interrupted=False)
elif isinstance(frame, (InterruptionFrame, CancelFrame)):
await self.push_frame(frame, direction)
await self._end_turn(interrupted=True)
await self._handle_interrupt()
elif isinstance(frame, TTSSpeakFrame):
# Fixed-text / direct-speech path: there's no LLM cycle, so
# synthesize one started/delta/final sequence for the spoken text.
@@ -122,12 +180,24 @@ class ProductTextStreamProcessor(FrameProcessor):
self._aggregation.append(text)
await self._emit("response.text.delta", text=text)
async def _handle_interrupt(self) -> None:
if self._turn_active:
await self._end_turn(interrupted=True)
return
if self._last_assistant_stream_text:
self._interrupted_stream_text = self._last_assistant_stream_text
async def _end_turn(self, *, interrupted: bool) -> None:
if not self._turn_active:
return
full_text = "".join(self._aggregation)
if full_text:
self._last_assistant_stream_text = full_text
if interrupted and full_text:
self._interrupted_stream_text = full_text
self._turn_active = False
self._aggregation = []
await self._emit(