Merge pull request #3855 from pipecat-ai/mb/context-summarization-improvements

Improve context summarization with dedicated LLM, timeout, and observability
This commit is contained in:
Mark Backman
2026-02-27 15:24:38 -05:00
committed by GitHub
14 changed files with 912 additions and 30 deletions

View File

@@ -0,0 +1 @@
- Added optional `llm` field to `LLMContextSummarizationConfig` for routing summarization to a dedicated LLM service (e.g., a cheaper/faster model) instead of the pipeline's primary model.

View File

@@ -0,0 +1 @@
- Added `summarization_timeout` to `LLMContextSummarizationConfig` (default 120s) to prevent hung LLM calls from permanently blocking future summarizations.

View File

@@ -0,0 +1 @@
- Added `on_summary_applied` event to `LLMContextSummarizer` for observability, providing message counts before and after context summarization.

1
changelog/3855.added.md Normal file
View File

@@ -0,0 +1 @@
- Added `summary_message_template` to `LLMContextSummarizationConfig` for customizing how summaries are formatted when injected into context (e.g., wrapping in XML tags).

View File

@@ -0,0 +1 @@
- Updated context summarization to use `user` role instead of `assistant` for summary messages.

View File

@@ -20,14 +20,13 @@ from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
@@ -42,8 +41,6 @@ from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
from pipecat.turns.user_turn_strategies import UserTurnStrategies
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
load_dotenv(override=True)
@@ -120,10 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
context,
user_params=LLMUserAggregatorParams(
user_turn_strategies=UserTurnStrategies(
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
),
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
vad_analyzer=SileroVADAnalyzer(),
),
assistant_params=LLMAssistantAggregatorParams(
enable_context_summarization=True,
@@ -138,6 +132,19 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
),
)
# Listen for summarization events
summarizer = assistant_aggregator._summarizer
if summarizer:
@summarizer.event_handler("on_summary_applied")
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
logger.info(
f"Context summarized: {event.original_message_count} messages -> "
f"{event.new_message_count} messages "
f"({event.summarized_message_count} summarized, "
f"{event.preserved_message_count} preserved)"
)
pipeline = Pipeline(
[
transport.input(), # Transport user input

View File

@@ -20,14 +20,13 @@ from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
@@ -42,8 +41,6 @@ from pipecat.services.llm_service import FunctionCallParams
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
from pipecat.turns.user_turn_strategies import UserTurnStrategies
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
load_dotenv(override=True)
@@ -120,10 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
context,
user_params=LLMUserAggregatorParams(
user_turn_strategies=UserTurnStrategies(
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
),
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
vad_analyzer=SileroVADAnalyzer(),
),
assistant_params=LLMAssistantAggregatorParams(
enable_context_summarization=True,
@@ -138,6 +132,19 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
),
)
# Listen for summarization events
summarizer = assistant_aggregator._summarizer
if summarizer:
@summarizer.event_handler("on_summary_applied")
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
logger.info(
f"Context summarized: {event.original_message_count} messages -> "
f"{event.new_message_count} messages "
f"({event.summarized_message_count} summarized, "
f"{event.preserved_message_count} preserved)"
)
pipeline = Pipeline(
[
transport.input(), # Transport user input

View File

@@ -0,0 +1,231 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Example demonstrating advanced context summarization configuration.
This example shows how to customize context summarization with:
- A dedicated cheap/fast LLM for generating summaries (Gemini Flash)
- A custom summary message template (XML tags)
- A custom summarization prompt
- A summarization timeout
- The on_summary_applied event for observability
"""
import asyncio
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.google import GoogleLLMService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
load_dotenv(override=True)
# We use lambdas to defer transport parameter creation until the transport
# type is selected at runtime.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
),
}
# Custom summarization prompt tailored to the application
CUSTOM_SUMMARIZATION_PROMPT = """Summarize this conversation, preserving:
- Key decisions and agreements
- Important facts and user preferences
- Any pending action items or unresolved questions
Be concise. Use clear, factual statements grouped by topic.
Omit greetings, small talk, and resolved tangents."""
# Tool functions for the LLM
async def get_current_weather(params: FunctionCallParams):
"""Get the current weather."""
logger.info("Tool called: get_current_weather")
await asyncio.sleep(1) # Simulate some processing
await params.result_callback({"conditions": "nice", "temperature": "75"})
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info("Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
# Primary LLM for conversation (could be any provider)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
# Dedicated cheap/fast LLM for summarization only
summarization_llm = GoogleLLMService(
api_key=os.getenv("GOOGLE_API_KEY"),
model="gemini-2.5-flash",
)
# Register tool functions
llm.register_function("get_current_weather", get_current_weather)
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
required=["location", "format"],
)
tools = ToolsSchema(standard_tools=[weather_function])
messages = [
{
"role": "system",
"content": (
"You are a helpful LLM in a WebRTC call. Your goal is to demonstrate "
"your capabilities in a succinct way. Your output will be spoken aloud, "
"so avoid special characters that can't easily be spoken. Respond to what "
"the user said in a creative and helpful way. You have access to tools to "
"get the current weather - use them when relevant.\n\n"
"When you see a <context_summary> block, it contains a compressed summary "
"of earlier conversation. Use it as reference but don't mention it to the user."
),
},
]
context = LLMContext(messages, tools=tools)
# Create aggregators with custom summarization
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
context,
user_params=LLMUserAggregatorParams(
vad_analyzer=SileroVADAnalyzer(),
),
assistant_params=LLMAssistantAggregatorParams(
enable_context_summarization=True,
context_summarization_config=LLMContextSummarizationConfig(
# Trigger thresholds (low values to demonstrate quickly)
max_context_tokens=1000,
max_unsummarized_messages=10,
# Summary generation
target_context_tokens=800,
min_messages_after_summary=2,
summarization_prompt=CUSTOM_SUMMARIZATION_PROMPT,
# Custom summary format - wrap in XML tags so the system
# prompt can identify summaries vs. live conversation
summary_message_template="<context_summary>\n{summary}\n</context_summary>",
# Use a dedicated cheap LLM for summarization instead of
# the primary conversation model
llm=summarization_llm,
# Cancel summarization if it takes longer than 60 seconds
summarization_timeout=60.0,
),
),
)
# Listen for summarization events
summarizer = assistant_aggregator._summarizer
if summarizer:
@summarizer.event_handler("on_summary_applied")
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
logger.info(
f"Context summarized: {event.original_message_count} messages -> "
f"{event.new_message_count} messages "
f"({event.summarized_message_count} summarized, "
f"{event.preserved_message_count} preserved)"
)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
user_aggregator, # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
assistant_aggregator, # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info("Client connected")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info("Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -2019,6 +2019,8 @@ class LLMContextSummaryRequestFrame(ControlFrame):
the summary text.
summarization_prompt: System prompt instructing the LLM how to generate
the summary.
summarization_timeout: Maximum time in seconds for the LLM to generate a
summary. When None, a default timeout of 120s is applied.
"""
request_id: str
@@ -2026,6 +2028,7 @@ class LLMContextSummaryRequestFrame(ControlFrame):
min_messages_to_keep: int
target_context_tokens: int
summarization_prompt: str
summarization_timeout: Optional[float] = None
@dataclass

View File

@@ -6,8 +6,10 @@
"""This module defines a summarizer for managing LLM context summarization."""
import asyncio
import uuid
from typing import Optional
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from loguru import logger
@@ -22,10 +24,33 @@ from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMe
from pipecat.utils.asyncio.task_manager import BaseTaskManager
from pipecat.utils.base_object import BaseObject
from pipecat.utils.context.llm_context_summarization import (
DEFAULT_SUMMARIZATION_TIMEOUT,
LLMContextSummarizationConfig,
LLMContextSummarizationUtil,
)
if TYPE_CHECKING:
from pipecat.services.llm_service import LLMService
@dataclass
class SummaryAppliedEvent:
"""Event data emitted when context summarization completes successfully.
Parameters:
original_message_count: Number of messages before summarization.
new_message_count: Number of messages after summarization.
summarized_message_count: Number of messages that were compressed
into the summary.
preserved_message_count: Number of recent messages preserved
uncompressed.
"""
original_message_count: int
new_message_count: int
summarized_message_count: int
preserved_message_count: int
class LLMContextSummarizer(BaseObject):
"""Summarizer for managing LLM context summarization.
@@ -39,6 +64,10 @@ class LLMContextSummarizer(BaseObject):
- on_request_summarization: Emitted when summarization should be triggered.
The aggregator should broadcast this frame to the LLM service.
- on_summary_applied: Emitted after a summary has been successfully applied
to the context. Receives a SummaryAppliedEvent with metrics about the
compression.
Example::
@summarizer.event_handler("on_request_summarization")
@@ -49,6 +78,10 @@ class LLMContextSummarizer(BaseObject):
context=frame.context,
...
)
@summarizer.event_handler("on_summary_applied")
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
logger.info(f"Compressed {event.original_message_count} -> {event.new_message_count} messages")
"""
def __init__(
@@ -74,6 +107,7 @@ class LLMContextSummarizer(BaseObject):
self._pending_summary_request_id: Optional[str] = None
self._register_event_handler("on_request_summarization", sync=True)
self._register_event_handler("on_summary_applied")
@property
def task_manager(self) -> BaseTaskManager:
@@ -198,8 +232,10 @@ class LLMContextSummarizer(BaseObject):
async def _request_summarization(self):
"""Request context summarization from LLM service.
Creates a summarization request frame and emits it via event handler.
Tracks the request ID to match async responses and prevent race conditions.
Creates a summarization request frame and either handles it directly
using a dedicated LLM (if configured) or emits it via event handler
for the pipeline's primary LLM. Tracks the request ID to match async
responses and prevent race conditions.
"""
# Generate unique request ID
request_id = str(uuid.uuid4())
@@ -218,10 +254,63 @@ class LLMContextSummarizer(BaseObject):
min_messages_to_keep=min_keep,
target_context_tokens=self._config.target_context_tokens,
summarization_prompt=self._config.summary_prompt,
summarization_timeout=self._config.summarization_timeout,
)
# Emit event for aggregator to broadcast
await self._call_event_handler("on_request_summarization", request_frame)
if self._config.llm:
# Use dedicated LLM directly — no need to involve the pipeline
self.task_manager.create_task(
self._generate_summary_with_dedicated_llm(self._config.llm, request_frame),
f"{self}-dedicated-llm-summary",
)
else:
# Emit event for aggregator to broadcast to the pipeline LLM
await self._call_event_handler("on_request_summarization", request_frame)
async def _generate_summary_with_dedicated_llm(
self, llm: "LLMService", frame: LLMContextSummaryRequestFrame
):
"""Generate summary using a dedicated LLM service.
Calls the dedicated LLM's _generate_summary directly and feeds the
result back through _handle_summary_result, bypassing the pipeline.
Args:
llm: The dedicated LLM service to use for summarization.
frame: The summarization request frame.
"""
timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT
try:
summary, last_index = await asyncio.wait_for(
llm._generate_summary(frame),
timeout=timeout,
)
result_frame = LLMContextSummaryResultFrame(
request_id=frame.request_id,
summary=summary,
last_summarized_index=last_index,
)
except asyncio.TimeoutError:
error = f"Context summarization timed out after {timeout}s"
logger.error(f"{self}: {error}")
result_frame = LLMContextSummaryResultFrame(
request_id=frame.request_id,
summary="",
last_summarized_index=-1,
error=error,
)
except Exception as e:
error = f"Error generating context summary: {e}"
logger.error(f"{self}: {error}")
result_frame = LLMContextSummaryResultFrame(
request_id=frame.request_id,
summary="",
last_summarized_index=-1,
error=error,
)
await self._handle_summary_result(result_frame)
async def _handle_summary_result(self, frame: LLMContextSummaryResultFrame):
"""Handle context summarization result from LLM service.
@@ -306,8 +395,10 @@ class LLMContextSummarizer(BaseObject):
# Get recent messages to keep
recent_messages = messages[last_summarized_index + 1 :]
# Create summary message as an assistant message
summary_message = {"role": "assistant", "content": f"Conversation summary: {summary}"}
# Create summary message as a user message (the summary is context
# provided *to* the assistant, not something the assistant said)
summary_content = self._config.summary_message_template.format(summary=summary)
summary_message = {"role": "user", "content": summary_content}
# Reconstruct context
new_messages = []
@@ -317,9 +408,23 @@ class LLMContextSummarizer(BaseObject):
new_messages.extend(recent_messages)
# Update context
original_message_count = len(messages)
num_system_preserved = 1 if first_system_msg else 0
self._context.set_messages(new_messages)
# Messages actually summarized = index range minus the preserved system message
summarized_count = last_summarized_index + 1 - num_system_preserved
logger.info(
f"{self}: Applied context summary, compressed {last_summarized_index + 1} messages "
f"into summary. Context now has {len(new_messages)} messages (was {len(messages)})"
f"{self}: Applied context summary, compressed {summarized_count} messages "
f"into summary. Context now has {len(new_messages)} messages (was {original_message_count})"
)
# Emit event for observability
event = SummaryAppliedEvent(
original_message_count=original_message_count,
new_message_count=len(new_messages),
summarized_message_count=summarized_count,
preserved_message_count=len(recent_messages) + num_system_preserved,
)
await self._call_event_handler("on_summary_applied", event)

View File

@@ -62,6 +62,7 @@ from pipecat.services.ai_service import AIService
from pipecat.services.settings import LLMSettings
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin
from pipecat.utils.context.llm_context_summarization import (
DEFAULT_SUMMARIZATION_TIMEOUT,
LLMContextSummarizationUtil,
)
@@ -436,8 +437,15 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
last_index = -1
error = None
timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT
try:
summary, last_index = await self._generate_summary(frame)
summary, last_index = await asyncio.wait_for(
self._generate_summary(frame),
timeout=timeout,
)
except asyncio.TimeoutError:
await self.push_error(error_msg=f"Context summarization timed out after {timeout}s")
except Exception as e:
error = f"Error generating context summary: {e}"
await self.push_error(error, exception=e)

View File

@@ -11,12 +11,18 @@ context when token limits are reached, enabling efficient long-running conversat
"""
from dataclasses import dataclass
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from pipecat.services.llm_service import LLMService
from loguru import logger
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
# Fallback timeout (seconds) used when summarization_timeout is None.
DEFAULT_SUMMARIZATION_TIMEOUT = 120.0
# Token estimation constants
CHARS_PER_TOKEN = 4 # Industry-standard heuristic: 1 token ≈ 4 characters
TOKEN_OVERHEAD_PER_MESSAGE = 10 # Estimated structural overhead per message
@@ -73,6 +79,19 @@ class LLMContextSummarizationConfig:
immediate conversational context.
summarization_prompt: Custom prompt for the LLM to use when generating
summaries. If None, uses DEFAULT_SUMMARIZATION_PROMPT.
summary_message_template: Template for formatting the summary when
injected into context. Must contain ``{summary}`` as a placeholder
for the generated summary text. Allows applications to wrap the
summary in custom delimiters (e.g., XML tags) so that system
prompts can distinguish summaries from live conversation.
llm: Optional separate LLM service for generating summaries. When set,
summarization requests are sent to this service instead of the
pipeline's primary LLM. Useful for routing summarization to a
cheaper/faster model (e.g., Gemini Flash) while keeping an
expensive model for conversation. If None, uses the pipeline LLM.
summarization_timeout: Maximum time in seconds to wait for the LLM to
generate a summary. If the call exceeds this timeout, summarization
is aborted with an error and future summarizations are unblocked.
"""
max_context_tokens: int = 8000
@@ -80,6 +99,9 @@ class LLMContextSummarizationConfig:
max_unsummarized_messages: int = 20
min_messages_after_summary: int = 4
summarization_prompt: Optional[str] = None
summary_message_template: str = "Conversation summary: {summary}"
llm: Optional["LLMService"] = None
summarization_timeout: float = DEFAULT_SUMMARIZATION_TIMEOUT
def __post_init__(self):
"""Validate configuration parameters."""

View File

@@ -6,10 +6,11 @@
"""Tests for context summarization feature."""
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock
from pipecat.frames.frames import LLMContextSummaryRequestFrame
from pipecat.frames.frames import LLMContextSummaryRequestFrame, LLMContextSummaryResultFrame
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
from pipecat.services.llm_service import LLMService
from pipecat.utils.context.llm_context_summarization import (
@@ -601,6 +602,250 @@ class TestSummaryGenerationExceptions(unittest.IsolatedAsyncioTestCase):
self.assertGreater(last_index, -1)
self.assertEqual(last_index, 1) # Should be the index of the last summarized message
async def test_generate_summary_task_timeout(self):
"""Test that _generate_summary_task handles timeout correctly."""
llm_service = LLMService()
# Mock _generate_summary to hang
async def slow_summary(frame):
await asyncio.sleep(10)
return ("summary", 1)
llm_service._generate_summary = slow_summary
broadcast_calls = []
async def mock_broadcast(frame_class, **kwargs):
broadcast_calls.append((frame_class, kwargs))
llm_service.broadcast_frame = mock_broadcast
llm_service.push_error = AsyncMock()
context = LLMContext()
context.add_message({"role": "user", "content": "Message 1"})
context.add_message({"role": "assistant", "content": "Response 1"})
context.add_message({"role": "user", "content": "Message 2"})
frame = LLMContextSummaryRequestFrame(
request_id="timeout_test",
context=context,
min_messages_to_keep=1,
target_context_tokens=1000,
summarization_prompt="Summarize this",
summarization_timeout=0.1, # Very short timeout
)
await llm_service._generate_summary_task(frame)
# Should have broadcast an error result
self.assertEqual(len(broadcast_calls), 1)
_, kwargs = broadcast_calls[0]
self.assertEqual(kwargs["request_id"], "timeout_test")
self.assertEqual(kwargs["summary"], "")
self.assertEqual(kwargs["last_summarized_index"], -1)
# error is None for timeout path (push_error is called instead)
self.assertIsNone(kwargs["error"])
# push_error should have been called with timeout message
llm_service.push_error.assert_called_once()
call_args = llm_service.push_error.call_args
error_msg = call_args.kwargs.get("error_msg") or call_args.args[0]
self.assertIn("timed out", error_msg)
class TestDedicatedLLMSummarization(unittest.IsolatedAsyncioTestCase):
"""Tests for dedicated LLM summarization in LLMContextSummarizer."""
async def asyncSetUp(self):
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
self.task_manager = TaskManager()
self.task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
def _create_context_and_config(self, dedicated_llm):
"""Create a context with enough messages and a config with a dedicated LLM."""
context = LLMContext()
for i in range(10):
context.add_message(
{"role": "user", "content": f"Test message {i} that adds tokens to context."}
)
config = LLMContextSummarizationConfig(
max_context_tokens=50, # Very low to trigger easily
llm=dedicated_llm,
summarization_timeout=5.0,
)
return context, config
async def test_dedicated_llm_success(self):
"""Test that dedicated LLM generates summary and applies result."""
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
dedicated_llm = LLMService()
dedicated_llm._generate_summary = AsyncMock(return_value=("Dedicated summary", 5))
context, config = self._create_context_and_config(dedicated_llm)
original_message_count = len(context.messages)
summarizer = LLMContextSummarizer(context=context, config=config)
await summarizer.setup(self.task_manager)
# Track whether on_request_summarization event fires (it should NOT)
event_fired = False
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal event_fired
event_fired = True
# Trigger summarization via LLM response start
from pipecat.frames.frames import LLMFullResponseStartFrame
await summarizer.process_frame(LLMFullResponseStartFrame())
# Wait for the background task to complete
await asyncio.sleep(0.1)
# The event should NOT have fired (dedicated LLM handles it internally)
self.assertFalse(event_fired)
# Verify the dedicated LLM was called
dedicated_llm._generate_summary.assert_called_once()
# Verify summary was applied to context (message count should decrease)
self.assertLess(len(context.messages), original_message_count)
# Verify summary message is present
summary_messages = [
msg for msg in context.messages if "Conversation summary:" in msg.get("content", "")
]
self.assertEqual(len(summary_messages), 1)
self.assertIn("Dedicated summary", summary_messages[0]["content"])
await summarizer.cleanup()
async def test_dedicated_llm_timeout(self):
"""Test that dedicated LLM timeout produces error and clears state."""
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
dedicated_llm = LLMService()
async def slow_summary(frame):
await asyncio.sleep(10)
return ("summary", 1)
dedicated_llm._generate_summary = slow_summary
context, config = self._create_context_and_config(dedicated_llm)
config.summarization_timeout = 0.1 # Very short timeout
summarizer = LLMContextSummarizer(context=context, config=config)
await summarizer.setup(self.task_manager)
original_message_count = len(context.messages)
# Trigger summarization
from pipecat.frames.frames import LLMFullResponseStartFrame
await summarizer.process_frame(LLMFullResponseStartFrame())
# Wait for the background task to complete (timeout + some buffer)
await asyncio.sleep(0.3)
# Context should be unchanged (timeout = error = no summary applied)
self.assertEqual(len(context.messages), original_message_count)
# Summarization state should be cleared so new requests can be made
self.assertFalse(summarizer._summarization_in_progress)
await summarizer.cleanup()
async def test_dedicated_llm_exception(self):
"""Test that dedicated LLM exceptions produce error and clear state."""
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
dedicated_llm = LLMService()
dedicated_llm._generate_summary = AsyncMock(
side_effect=RuntimeError("LLM connection failed")
)
context, config = self._create_context_and_config(dedicated_llm)
summarizer = LLMContextSummarizer(context=context, config=config)
await summarizer.setup(self.task_manager)
original_message_count = len(context.messages)
# Trigger summarization
from pipecat.frames.frames import LLMFullResponseStartFrame
await summarizer.process_frame(LLMFullResponseStartFrame())
# Wait for the background task to complete
await asyncio.sleep(0.1)
# Context should be unchanged (exception = error = no summary applied)
self.assertEqual(len(context.messages), original_message_count)
# Summarization state should be cleared
self.assertFalse(summarizer._summarization_in_progress)
await summarizer.cleanup()
async def test_dedicated_llm_does_not_emit_event(self):
"""Test that summarizer does NOT emit on_request_summarization when dedicated LLM is set."""
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
dedicated_llm = LLMService()
dedicated_llm._generate_summary = AsyncMock(return_value=("Summary", 1))
context, config = self._create_context_and_config(dedicated_llm)
summarizer = LLMContextSummarizer(context=context, config=config)
await summarizer.setup(self.task_manager)
event_fired = False
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal event_fired
event_fired = True
from pipecat.frames.frames import LLMFullResponseStartFrame
await summarizer.process_frame(LLMFullResponseStartFrame())
await asyncio.sleep(0.1)
self.assertFalse(event_fired)
await summarizer.cleanup()
async def test_no_dedicated_llm_emits_event(self):
"""Test that summarizer emits on_request_summarization when no dedicated LLM."""
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
context = LLMContext()
for i in range(10):
context.add_message(
{"role": "user", "content": f"Test message {i} that adds tokens to context."}
)
config = LLMContextSummarizationConfig(max_context_tokens=50)
summarizer = LLMContextSummarizer(context=context, config=config)
await summarizer.setup(self.task_manager)
request_frame = None
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal request_frame
request_frame = frame
from pipecat.frames.frames import LLMFullResponseStartFrame
await summarizer.process_frame(LLMFullResponseStartFrame())
self.assertIsNotNone(request_frame)
self.assertIsInstance(request_frame, LLMContextSummaryRequestFrame)
await summarizer.cleanup()
class TestLLMSpecificMessageHandling(unittest.TestCase):
"""Tests that LLMSpecificMessage objects are correctly skipped in summarization."""

View File

@@ -14,7 +14,10 @@ from pipecat.frames.frames import (
LLMFullResponseStartFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
from pipecat.processors.aggregators.llm_context_summarizer import (
LLMContextSummarizer,
SummaryAppliedEvent,
)
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
@@ -291,6 +294,252 @@ class TestLLMContextSummarizer(unittest.IsolatedAsyncioTestCase):
await summarizer.cleanup()
async def test_summary_message_role_is_user(self):
"""Test that the summary message uses the user role."""
config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
summarizer = LLMContextSummarizer(context=self.context, config=config)
await summarizer.setup(self.task_manager)
# Add messages and trigger summarization
for i in range(10):
self.context.add_message({"role": "user", "content": "Test message."})
request_frame = None
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal request_frame
request_frame = frame
await summarizer.process_frame(LLMFullResponseStartFrame())
self.assertIsNotNone(request_frame)
# Simulate receiving a summary result
summary_result = LLMContextSummaryResultFrame(
request_id=request_frame.request_id,
summary="This is a test summary.",
last_summarized_index=5,
)
await summarizer.process_frame(summary_result)
# Find the summary message and verify its role is "user"
summary_msg = next(
(msg for msg in self.context.messages if "summary" in msg.get("content", "").lower()),
None,
)
self.assertIsNotNone(summary_msg)
self.assertEqual(summary_msg["role"], "user")
await summarizer.cleanup()
async def test_summary_message_default_template(self):
"""Test that the default summary_message_template is used."""
config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
summarizer = LLMContextSummarizer(context=self.context, config=config)
await summarizer.setup(self.task_manager)
for i in range(10):
self.context.add_message({"role": "user", "content": "Test message."})
request_frame = None
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal request_frame
request_frame = frame
await summarizer.process_frame(LLMFullResponseStartFrame())
summary_result = LLMContextSummaryResultFrame(
request_id=request_frame.request_id,
summary="Key facts from conversation.",
last_summarized_index=5,
)
await summarizer.process_frame(summary_result)
# Default template wraps with "Conversation summary: {summary}"
summary_msg = next(
(
msg
for msg in self.context.messages
if "Conversation summary:" in msg.get("content", "")
),
None,
)
self.assertIsNotNone(summary_msg)
self.assertEqual(
summary_msg["content"], "Conversation summary: Key facts from conversation."
)
await summarizer.cleanup()
async def test_summary_message_custom_template(self):
"""Test that a custom summary_message_template is applied."""
config = LLMContextSummarizationConfig(
max_context_tokens=50,
min_messages_after_summary=2,
summary_message_template="<context_summary>\n{summary}\n</context_summary>",
)
summarizer = LLMContextSummarizer(context=self.context, config=config)
await summarizer.setup(self.task_manager)
for i in range(10):
self.context.add_message({"role": "user", "content": "Test message."})
request_frame = None
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal request_frame
request_frame = frame
await summarizer.process_frame(LLMFullResponseStartFrame())
summary_result = LLMContextSummaryResultFrame(
request_id=request_frame.request_id,
summary="Key facts from conversation.",
last_summarized_index=5,
)
await summarizer.process_frame(summary_result)
# Custom template wraps with XML tags
summary_msg = next(
(msg for msg in self.context.messages if "<context_summary>" in msg.get("content", "")),
None,
)
self.assertIsNotNone(summary_msg)
self.assertEqual(
summary_msg["content"],
"<context_summary>\nKey facts from conversation.\n</context_summary>",
)
await summarizer.cleanup()
async def test_on_summary_applied_event(self):
"""Test that on_summary_applied event fires with correct data."""
config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
summarizer = LLMContextSummarizer(context=self.context, config=config)
await summarizer.setup(self.task_manager)
# Add messages (1 system + 10 user = 11 total)
for i in range(10):
self.context.add_message({"role": "user", "content": "Test message."})
request_frame = None
applied_event = None
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal request_frame
request_frame = frame
@summarizer.event_handler("on_summary_applied")
async def on_summary_applied(summarizer, event):
nonlocal applied_event
applied_event = event
original_count = len(self.context.messages) # 11
await summarizer.process_frame(LLMFullResponseStartFrame())
# Summarize up to index 7 (system=0, user1..user7), keep last 3 (user8, user9, user10)
summary_result = LLMContextSummaryResultFrame(
request_id=request_frame.request_id,
summary="Test summary.",
last_summarized_index=7,
)
await summarizer.process_frame(summary_result)
# Allow async event handler to complete
await asyncio.sleep(0.05)
# Verify event was fired
self.assertIsNotNone(applied_event)
self.assertIsInstance(applied_event, SummaryAppliedEvent)
self.assertEqual(applied_event.original_message_count, original_count)
# After summarization: system + summary + 3 recent = 5
self.assertEqual(applied_event.new_message_count, 5)
# Summarized messages: indices 1-7 = 7 messages (excluding system at index 0)
self.assertEqual(applied_event.summarized_message_count, 7)
# Preserved: system (1) + recent messages after index 7 (3) = 4
self.assertEqual(applied_event.preserved_message_count, 4)
await summarizer.cleanup()
async def test_on_summary_applied_not_fired_on_error(self):
"""Test that on_summary_applied event is NOT fired when summarization fails."""
config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
summarizer = LLMContextSummarizer(context=self.context, config=config)
await summarizer.setup(self.task_manager)
for i in range(10):
self.context.add_message({"role": "user", "content": "Test message."})
request_frame = None
applied_event = None
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal request_frame
request_frame = frame
@summarizer.event_handler("on_summary_applied")
async def on_summary_applied(summarizer, event):
nonlocal applied_event
applied_event = event
await summarizer.process_frame(LLMFullResponseStartFrame())
# Send a result with an error
error_result = LLMContextSummaryResultFrame(
request_id=request_frame.request_id,
summary="",
last_summarized_index=-1,
error="Summarization timed out",
)
await summarizer.process_frame(error_result)
await asyncio.sleep(0.05)
# Event should NOT have fired
self.assertIsNone(applied_event)
await summarizer.cleanup()
async def test_request_frame_includes_timeout(self):
"""Test that the request frame includes the configured summarization_timeout."""
config = LLMContextSummarizationConfig(
max_context_tokens=50,
summarization_timeout=60.0,
)
summarizer = LLMContextSummarizer(context=self.context, config=config)
await summarizer.setup(self.task_manager)
request_frame = None
@summarizer.event_handler("on_request_summarization")
async def on_request_summarization(summarizer, frame):
nonlocal request_frame
request_frame = frame
for i in range(10):
self.context.add_message({"role": "user", "content": "Test message to add tokens."})
await summarizer.process_frame(LLMFullResponseStartFrame())
self.assertIsNotNone(request_frame)
self.assertEqual(request_frame.summarization_timeout, 60.0)
await summarizer.cleanup()
if __name__ == "__main__":
unittest.main()