Merge pull request #3855 from pipecat-ai/mb/context-summarization-improvements
Improve context summarization with dedicated LLM, timeout, and observability
This commit is contained in:
1
changelog/3855.added.2.md
Normal file
1
changelog/3855.added.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added optional `llm` field to `LLMContextSummarizationConfig` for routing summarization to a dedicated LLM service (e.g., a cheaper/faster model) instead of the pipeline's primary model.
|
||||
1
changelog/3855.added.3.md
Normal file
1
changelog/3855.added.3.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `summarization_timeout` to `LLMContextSummarizationConfig` (default 120s) to prevent hung LLM calls from permanently blocking future summarizations.
|
||||
1
changelog/3855.added.4.md
Normal file
1
changelog/3855.added.4.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `on_summary_applied` event to `LLMContextSummarizer` for observability, providing message counts before and after context summarization.
|
||||
1
changelog/3855.added.md
Normal file
1
changelog/3855.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `summary_message_template` to `LLMContextSummarizationConfig` for customizing how summaries are formatted when injected into context (e.g., wrapping in XML tags).
|
||||
1
changelog/3855.changed.md
Normal file
1
changelog/3855.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Updated context summarization to use `user` role instead of `assistant` for summary messages.
|
||||
@@ -20,14 +20,13 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
@@ -42,8 +41,6 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
|
||||
|
||||
load_dotenv(override=True)
|
||||
@@ -120,10 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
),
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
assistant_params=LLMAssistantAggregatorParams(
|
||||
enable_context_summarization=True,
|
||||
@@ -138,6 +132,19 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
# Listen for summarization events
|
||||
summarizer = assistant_aggregator._summarizer
|
||||
if summarizer:
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
|
||||
logger.info(
|
||||
f"Context summarized: {event.original_message_count} messages -> "
|
||||
f"{event.new_message_count} messages "
|
||||
f"({event.summarized_message_count} summarized, "
|
||||
f"{event.preserved_message_count} preserved)"
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
|
||||
@@ -20,14 +20,13 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
@@ -42,8 +41,6 @@ from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
|
||||
|
||||
load_dotenv(override=True)
|
||||
@@ -120,10 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
),
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
assistant_params=LLMAssistantAggregatorParams(
|
||||
enable_context_summarization=True,
|
||||
@@ -138,6 +132,19 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
# Listen for summarization events
|
||||
summarizer = assistant_aggregator._summarizer
|
||||
if summarizer:
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
|
||||
logger.info(
|
||||
f"Context summarized: {event.original_message_count} messages -> "
|
||||
f"{event.new_message_count} messages "
|
||||
f"({event.summarized_message_count} summarized, "
|
||||
f"{event.preserved_message_count} preserved)"
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
|
||||
231
examples/foundational/54c-context-summarization-dedicated-llm.py
Normal file
231
examples/foundational/54c-context-summarization-dedicated-llm.py
Normal file
@@ -0,0 +1,231 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example demonstrating advanced context summarization configuration.
|
||||
|
||||
This example shows how to customize context summarization with:
|
||||
- A dedicated cheap/fast LLM for generating summaries (Gemini Flash)
|
||||
- A custom summary message template (XML tags)
|
||||
- A custom summarization prompt
|
||||
- A summarization timeout
|
||||
- The on_summary_applied event for observability
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.google import GoogleLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
# Custom summarization prompt tailored to the application
|
||||
CUSTOM_SUMMARIZATION_PROMPT = """Summarize this conversation, preserving:
|
||||
- Key decisions and agreements
|
||||
- Important facts and user preferences
|
||||
- Any pending action items or unresolved questions
|
||||
|
||||
Be concise. Use clear, factual statements grouped by topic.
|
||||
Omit greetings, small talk, and resolved tangents."""
|
||||
|
||||
|
||||
# Tool functions for the LLM
|
||||
async def get_current_weather(params: FunctionCallParams):
|
||||
"""Get the current weather."""
|
||||
logger.info("Tool called: get_current_weather")
|
||||
await asyncio.sleep(1) # Simulate some processing
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# Primary LLM for conversation (could be any provider)
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# Dedicated cheap/fast LLM for summarization only
|
||||
summarization_llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash",
|
||||
)
|
||||
|
||||
# Register tool functions
|
||||
llm.register_function("get_current_weather", get_current_weather)
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a helpful LLM in a WebRTC call. Your goal is to demonstrate "
|
||||
"your capabilities in a succinct way. Your output will be spoken aloud, "
|
||||
"so avoid special characters that can't easily be spoken. Respond to what "
|
||||
"the user said in a creative and helpful way. You have access to tools to "
|
||||
"get the current weather - use them when relevant.\n\n"
|
||||
"When you see a <context_summary> block, it contains a compressed summary "
|
||||
"of earlier conversation. Use it as reference but don't mention it to the user."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages, tools=tools)
|
||||
|
||||
# Create aggregators with custom summarization
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
assistant_params=LLMAssistantAggregatorParams(
|
||||
enable_context_summarization=True,
|
||||
context_summarization_config=LLMContextSummarizationConfig(
|
||||
# Trigger thresholds (low values to demonstrate quickly)
|
||||
max_context_tokens=1000,
|
||||
max_unsummarized_messages=10,
|
||||
# Summary generation
|
||||
target_context_tokens=800,
|
||||
min_messages_after_summary=2,
|
||||
summarization_prompt=CUSTOM_SUMMARIZATION_PROMPT,
|
||||
# Custom summary format - wrap in XML tags so the system
|
||||
# prompt can identify summaries vs. live conversation
|
||||
summary_message_template="<context_summary>\n{summary}\n</context_summary>",
|
||||
# Use a dedicated cheap LLM for summarization instead of
|
||||
# the primary conversation model
|
||||
llm=summarization_llm,
|
||||
# Cancel summarization if it takes longer than 60 seconds
|
||||
summarization_timeout=60.0,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Listen for summarization events
|
||||
summarizer = assistant_aggregator._summarizer
|
||||
if summarizer:
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
|
||||
logger.info(
|
||||
f"Context summarized: {event.original_message_count} messages -> "
|
||||
f"{event.new_message_count} messages "
|
||||
f"({event.summarized_message_count} summarized, "
|
||||
f"{event.preserved_message_count} preserved)"
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -2019,6 +2019,8 @@ class LLMContextSummaryRequestFrame(ControlFrame):
|
||||
the summary text.
|
||||
summarization_prompt: System prompt instructing the LLM how to generate
|
||||
the summary.
|
||||
summarization_timeout: Maximum time in seconds for the LLM to generate a
|
||||
summary. When None, a default timeout of 120s is applied.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
@@ -2026,6 +2028,7 @@ class LLMContextSummaryRequestFrame(ControlFrame):
|
||||
min_messages_to_keep: int
|
||||
target_context_tokens: int
|
||||
summarization_prompt: str
|
||||
summarization_timeout: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -6,8 +6,10 @@
|
||||
|
||||
"""This module defines a summarizer for managing LLM context summarization."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -22,10 +24,33 @@ from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMe
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
DEFAULT_SUMMARIZATION_TIMEOUT,
|
||||
LLMContextSummarizationConfig,
|
||||
LLMContextSummarizationUtil,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummaryAppliedEvent:
|
||||
"""Event data emitted when context summarization completes successfully.
|
||||
|
||||
Parameters:
|
||||
original_message_count: Number of messages before summarization.
|
||||
new_message_count: Number of messages after summarization.
|
||||
summarized_message_count: Number of messages that were compressed
|
||||
into the summary.
|
||||
preserved_message_count: Number of recent messages preserved
|
||||
uncompressed.
|
||||
"""
|
||||
|
||||
original_message_count: int
|
||||
new_message_count: int
|
||||
summarized_message_count: int
|
||||
preserved_message_count: int
|
||||
|
||||
|
||||
class LLMContextSummarizer(BaseObject):
|
||||
"""Summarizer for managing LLM context summarization.
|
||||
@@ -39,6 +64,10 @@ class LLMContextSummarizer(BaseObject):
|
||||
- on_request_summarization: Emitted when summarization should be triggered.
|
||||
The aggregator should broadcast this frame to the LLM service.
|
||||
|
||||
- on_summary_applied: Emitted after a summary has been successfully applied
|
||||
to the context. Receives a SummaryAppliedEvent with metrics about the
|
||||
compression.
|
||||
|
||||
Example::
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
@@ -49,6 +78,10 @@ class LLMContextSummarizer(BaseObject):
|
||||
context=frame.context,
|
||||
...
|
||||
)
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event: SummaryAppliedEvent):
|
||||
logger.info(f"Compressed {event.original_message_count} -> {event.new_message_count} messages")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -74,6 +107,7 @@ class LLMContextSummarizer(BaseObject):
|
||||
self._pending_summary_request_id: Optional[str] = None
|
||||
|
||||
self._register_event_handler("on_request_summarization", sync=True)
|
||||
self._register_event_handler("on_summary_applied")
|
||||
|
||||
@property
|
||||
def task_manager(self) -> BaseTaskManager:
|
||||
@@ -198,8 +232,10 @@ class LLMContextSummarizer(BaseObject):
|
||||
async def _request_summarization(self):
|
||||
"""Request context summarization from LLM service.
|
||||
|
||||
Creates a summarization request frame and emits it via event handler.
|
||||
Tracks the request ID to match async responses and prevent race conditions.
|
||||
Creates a summarization request frame and either handles it directly
|
||||
using a dedicated LLM (if configured) or emits it via event handler
|
||||
for the pipeline's primary LLM. Tracks the request ID to match async
|
||||
responses and prevent race conditions.
|
||||
"""
|
||||
# Generate unique request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
@@ -218,10 +254,63 @@ class LLMContextSummarizer(BaseObject):
|
||||
min_messages_to_keep=min_keep,
|
||||
target_context_tokens=self._config.target_context_tokens,
|
||||
summarization_prompt=self._config.summary_prompt,
|
||||
summarization_timeout=self._config.summarization_timeout,
|
||||
)
|
||||
|
||||
# Emit event for aggregator to broadcast
|
||||
await self._call_event_handler("on_request_summarization", request_frame)
|
||||
if self._config.llm:
|
||||
# Use dedicated LLM directly — no need to involve the pipeline
|
||||
self.task_manager.create_task(
|
||||
self._generate_summary_with_dedicated_llm(self._config.llm, request_frame),
|
||||
f"{self}-dedicated-llm-summary",
|
||||
)
|
||||
else:
|
||||
# Emit event for aggregator to broadcast to the pipeline LLM
|
||||
await self._call_event_handler("on_request_summarization", request_frame)
|
||||
|
||||
async def _generate_summary_with_dedicated_llm(
|
||||
self, llm: "LLMService", frame: LLMContextSummaryRequestFrame
|
||||
):
|
||||
"""Generate summary using a dedicated LLM service.
|
||||
|
||||
Calls the dedicated LLM's _generate_summary directly and feeds the
|
||||
result back through _handle_summary_result, bypassing the pipeline.
|
||||
|
||||
Args:
|
||||
llm: The dedicated LLM service to use for summarization.
|
||||
frame: The summarization request frame.
|
||||
"""
|
||||
timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT
|
||||
|
||||
try:
|
||||
summary, last_index = await asyncio.wait_for(
|
||||
llm._generate_summary(frame),
|
||||
timeout=timeout,
|
||||
)
|
||||
result_frame = LLMContextSummaryResultFrame(
|
||||
request_id=frame.request_id,
|
||||
summary=summary,
|
||||
last_summarized_index=last_index,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
error = f"Context summarization timed out after {timeout}s"
|
||||
logger.error(f"{self}: {error}")
|
||||
result_frame = LLMContextSummaryResultFrame(
|
||||
request_id=frame.request_id,
|
||||
summary="",
|
||||
last_summarized_index=-1,
|
||||
error=error,
|
||||
)
|
||||
except Exception as e:
|
||||
error = f"Error generating context summary: {e}"
|
||||
logger.error(f"{self}: {error}")
|
||||
result_frame = LLMContextSummaryResultFrame(
|
||||
request_id=frame.request_id,
|
||||
summary="",
|
||||
last_summarized_index=-1,
|
||||
error=error,
|
||||
)
|
||||
|
||||
await self._handle_summary_result(result_frame)
|
||||
|
||||
async def _handle_summary_result(self, frame: LLMContextSummaryResultFrame):
|
||||
"""Handle context summarization result from LLM service.
|
||||
@@ -306,8 +395,10 @@ class LLMContextSummarizer(BaseObject):
|
||||
# Get recent messages to keep
|
||||
recent_messages = messages[last_summarized_index + 1 :]
|
||||
|
||||
# Create summary message as an assistant message
|
||||
summary_message = {"role": "assistant", "content": f"Conversation summary: {summary}"}
|
||||
# Create summary message as a user message (the summary is context
|
||||
# provided *to* the assistant, not something the assistant said)
|
||||
summary_content = self._config.summary_message_template.format(summary=summary)
|
||||
summary_message = {"role": "user", "content": summary_content}
|
||||
|
||||
# Reconstruct context
|
||||
new_messages = []
|
||||
@@ -317,9 +408,23 @@ class LLMContextSummarizer(BaseObject):
|
||||
new_messages.extend(recent_messages)
|
||||
|
||||
# Update context
|
||||
original_message_count = len(messages)
|
||||
num_system_preserved = 1 if first_system_msg else 0
|
||||
self._context.set_messages(new_messages)
|
||||
|
||||
# Messages actually summarized = index range minus the preserved system message
|
||||
summarized_count = last_summarized_index + 1 - num_system_preserved
|
||||
|
||||
logger.info(
|
||||
f"{self}: Applied context summary, compressed {last_summarized_index + 1} messages "
|
||||
f"into summary. Context now has {len(new_messages)} messages (was {len(messages)})"
|
||||
f"{self}: Applied context summary, compressed {summarized_count} messages "
|
||||
f"into summary. Context now has {len(new_messages)} messages (was {original_message_count})"
|
||||
)
|
||||
|
||||
# Emit event for observability
|
||||
event = SummaryAppliedEvent(
|
||||
original_message_count=original_message_count,
|
||||
new_message_count=len(new_messages),
|
||||
summarized_message_count=summarized_count,
|
||||
preserved_message_count=len(recent_messages) + num_system_preserved,
|
||||
)
|
||||
await self._call_event_handler("on_summary_applied", event)
|
||||
|
||||
@@ -62,6 +62,7 @@ from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
DEFAULT_SUMMARIZATION_TIMEOUT,
|
||||
LLMContextSummarizationUtil,
|
||||
)
|
||||
|
||||
@@ -436,8 +437,15 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
last_index = -1
|
||||
error = None
|
||||
|
||||
timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT
|
||||
|
||||
try:
|
||||
summary, last_index = await self._generate_summary(frame)
|
||||
summary, last_index = await asyncio.wait_for(
|
||||
self._generate_summary(frame),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self.push_error(error_msg=f"Context summarization timed out after {timeout}s")
|
||||
except Exception as e:
|
||||
error = f"Error generating context summary: {e}"
|
||||
await self.push_error(error, exception=e)
|
||||
|
||||
@@ -11,12 +11,18 @@ context when token limits are reached, enabling efficient long-running conversat
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
|
||||
# Fallback timeout (seconds) used when summarization_timeout is None.
|
||||
DEFAULT_SUMMARIZATION_TIMEOUT = 120.0
|
||||
|
||||
# Token estimation constants
|
||||
CHARS_PER_TOKEN = 4 # Industry-standard heuristic: 1 token ≈ 4 characters
|
||||
TOKEN_OVERHEAD_PER_MESSAGE = 10 # Estimated structural overhead per message
|
||||
@@ -73,6 +79,19 @@ class LLMContextSummarizationConfig:
|
||||
immediate conversational context.
|
||||
summarization_prompt: Custom prompt for the LLM to use when generating
|
||||
summaries. If None, uses DEFAULT_SUMMARIZATION_PROMPT.
|
||||
summary_message_template: Template for formatting the summary when
|
||||
injected into context. Must contain ``{summary}`` as a placeholder
|
||||
for the generated summary text. Allows applications to wrap the
|
||||
summary in custom delimiters (e.g., XML tags) so that system
|
||||
prompts can distinguish summaries from live conversation.
|
||||
llm: Optional separate LLM service for generating summaries. When set,
|
||||
summarization requests are sent to this service instead of the
|
||||
pipeline's primary LLM. Useful for routing summarization to a
|
||||
cheaper/faster model (e.g., Gemini Flash) while keeping an
|
||||
expensive model for conversation. If None, uses the pipeline LLM.
|
||||
summarization_timeout: Maximum time in seconds to wait for the LLM to
|
||||
generate a summary. If the call exceeds this timeout, summarization
|
||||
is aborted with an error and future summarizations are unblocked.
|
||||
"""
|
||||
|
||||
max_context_tokens: int = 8000
|
||||
@@ -80,6 +99,9 @@ class LLMContextSummarizationConfig:
|
||||
max_unsummarized_messages: int = 20
|
||||
min_messages_after_summary: int = 4
|
||||
summarization_prompt: Optional[str] = None
|
||||
summary_message_template: str = "Conversation summary: {summary}"
|
||||
llm: Optional["LLMService"] = None
|
||||
summarization_timeout: float = DEFAULT_SUMMARIZATION_TIMEOUT
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration parameters."""
|
||||
|
||||
@@ -6,10 +6,11 @@
|
||||
|
||||
"""Tests for context summarization feature."""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.frames.frames import LLMContextSummaryRequestFrame
|
||||
from pipecat.frames.frames import LLMContextSummaryRequestFrame, LLMContextSummaryResultFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
@@ -601,6 +602,250 @@ class TestSummaryGenerationExceptions(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertGreater(last_index, -1)
|
||||
self.assertEqual(last_index, 1) # Should be the index of the last summarized message
|
||||
|
||||
async def test_generate_summary_task_timeout(self):
|
||||
"""Test that _generate_summary_task handles timeout correctly."""
|
||||
llm_service = LLMService()
|
||||
|
||||
# Mock _generate_summary to hang
|
||||
async def slow_summary(frame):
|
||||
await asyncio.sleep(10)
|
||||
return ("summary", 1)
|
||||
|
||||
llm_service._generate_summary = slow_summary
|
||||
|
||||
broadcast_calls = []
|
||||
|
||||
async def mock_broadcast(frame_class, **kwargs):
|
||||
broadcast_calls.append((frame_class, kwargs))
|
||||
|
||||
llm_service.broadcast_frame = mock_broadcast
|
||||
llm_service.push_error = AsyncMock()
|
||||
|
||||
context = LLMContext()
|
||||
context.add_message({"role": "user", "content": "Message 1"})
|
||||
context.add_message({"role": "assistant", "content": "Response 1"})
|
||||
context.add_message({"role": "user", "content": "Message 2"})
|
||||
|
||||
frame = LLMContextSummaryRequestFrame(
|
||||
request_id="timeout_test",
|
||||
context=context,
|
||||
min_messages_to_keep=1,
|
||||
target_context_tokens=1000,
|
||||
summarization_prompt="Summarize this",
|
||||
summarization_timeout=0.1, # Very short timeout
|
||||
)
|
||||
|
||||
await llm_service._generate_summary_task(frame)
|
||||
|
||||
# Should have broadcast an error result
|
||||
self.assertEqual(len(broadcast_calls), 1)
|
||||
_, kwargs = broadcast_calls[0]
|
||||
self.assertEqual(kwargs["request_id"], "timeout_test")
|
||||
self.assertEqual(kwargs["summary"], "")
|
||||
self.assertEqual(kwargs["last_summarized_index"], -1)
|
||||
# error is None for timeout path (push_error is called instead)
|
||||
self.assertIsNone(kwargs["error"])
|
||||
|
||||
# push_error should have been called with timeout message
|
||||
llm_service.push_error.assert_called_once()
|
||||
call_args = llm_service.push_error.call_args
|
||||
error_msg = call_args.kwargs.get("error_msg") or call_args.args[0]
|
||||
self.assertIn("timed out", error_msg)
|
||||
|
||||
|
||||
class TestDedicatedLLMSummarization(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for dedicated LLM summarization in LLMContextSummarizer."""
|
||||
|
||||
async def asyncSetUp(self):
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
self.task_manager = TaskManager()
|
||||
self.task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
|
||||
def _create_context_and_config(self, dedicated_llm):
|
||||
"""Create a context with enough messages and a config with a dedicated LLM."""
|
||||
context = LLMContext()
|
||||
for i in range(10):
|
||||
context.add_message(
|
||||
{"role": "user", "content": f"Test message {i} that adds tokens to context."}
|
||||
)
|
||||
|
||||
config = LLMContextSummarizationConfig(
|
||||
max_context_tokens=50, # Very low to trigger easily
|
||||
llm=dedicated_llm,
|
||||
summarization_timeout=5.0,
|
||||
)
|
||||
return context, config
|
||||
|
||||
async def test_dedicated_llm_success(self):
|
||||
"""Test that dedicated LLM generates summary and applies result."""
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
|
||||
|
||||
dedicated_llm = LLMService()
|
||||
dedicated_llm._generate_summary = AsyncMock(return_value=("Dedicated summary", 5))
|
||||
|
||||
context, config = self._create_context_and_config(dedicated_llm)
|
||||
original_message_count = len(context.messages)
|
||||
summarizer = LLMContextSummarizer(context=context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
# Track whether on_request_summarization event fires (it should NOT)
|
||||
event_fired = False
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal event_fired
|
||||
event_fired = True
|
||||
|
||||
# Trigger summarization via LLM response start
|
||||
from pipecat.frames.frames import LLMFullResponseStartFrame
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
# Wait for the background task to complete
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# The event should NOT have fired (dedicated LLM handles it internally)
|
||||
self.assertFalse(event_fired)
|
||||
|
||||
# Verify the dedicated LLM was called
|
||||
dedicated_llm._generate_summary.assert_called_once()
|
||||
|
||||
# Verify summary was applied to context (message count should decrease)
|
||||
self.assertLess(len(context.messages), original_message_count)
|
||||
|
||||
# Verify summary message is present
|
||||
summary_messages = [
|
||||
msg for msg in context.messages if "Conversation summary:" in msg.get("content", "")
|
||||
]
|
||||
self.assertEqual(len(summary_messages), 1)
|
||||
self.assertIn("Dedicated summary", summary_messages[0]["content"])
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_dedicated_llm_timeout(self):
|
||||
"""Test that dedicated LLM timeout produces error and clears state."""
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
|
||||
|
||||
dedicated_llm = LLMService()
|
||||
|
||||
async def slow_summary(frame):
|
||||
await asyncio.sleep(10)
|
||||
return ("summary", 1)
|
||||
|
||||
dedicated_llm._generate_summary = slow_summary
|
||||
|
||||
context, config = self._create_context_and_config(dedicated_llm)
|
||||
config.summarization_timeout = 0.1 # Very short timeout
|
||||
summarizer = LLMContextSummarizer(context=context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
original_message_count = len(context.messages)
|
||||
|
||||
# Trigger summarization
|
||||
from pipecat.frames.frames import LLMFullResponseStartFrame
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
# Wait for the background task to complete (timeout + some buffer)
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Context should be unchanged (timeout = error = no summary applied)
|
||||
self.assertEqual(len(context.messages), original_message_count)
|
||||
|
||||
# Summarization state should be cleared so new requests can be made
|
||||
self.assertFalse(summarizer._summarization_in_progress)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_dedicated_llm_exception(self):
|
||||
"""Test that dedicated LLM exceptions produce error and clear state."""
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
|
||||
|
||||
dedicated_llm = LLMService()
|
||||
dedicated_llm._generate_summary = AsyncMock(
|
||||
side_effect=RuntimeError("LLM connection failed")
|
||||
)
|
||||
|
||||
context, config = self._create_context_and_config(dedicated_llm)
|
||||
summarizer = LLMContextSummarizer(context=context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
original_message_count = len(context.messages)
|
||||
|
||||
# Trigger summarization
|
||||
from pipecat.frames.frames import LLMFullResponseStartFrame
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
# Wait for the background task to complete
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Context should be unchanged (exception = error = no summary applied)
|
||||
self.assertEqual(len(context.messages), original_message_count)
|
||||
|
||||
# Summarization state should be cleared
|
||||
self.assertFalse(summarizer._summarization_in_progress)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_dedicated_llm_does_not_emit_event(self):
|
||||
"""Test that summarizer does NOT emit on_request_summarization when dedicated LLM is set."""
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
|
||||
|
||||
dedicated_llm = LLMService()
|
||||
dedicated_llm._generate_summary = AsyncMock(return_value=("Summary", 1))
|
||||
|
||||
context, config = self._create_context_and_config(dedicated_llm)
|
||||
summarizer = LLMContextSummarizer(context=context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
event_fired = False
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal event_fired
|
||||
event_fired = True
|
||||
|
||||
from pipecat.frames.frames import LLMFullResponseStartFrame
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(event_fired)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_no_dedicated_llm_emits_event(self):
|
||||
"""Test that summarizer emits on_request_summarization when no dedicated LLM."""
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
|
||||
|
||||
context = LLMContext()
|
||||
for i in range(10):
|
||||
context.add_message(
|
||||
{"role": "user", "content": f"Test message {i} that adds tokens to context."}
|
||||
)
|
||||
|
||||
config = LLMContextSummarizationConfig(max_context_tokens=50)
|
||||
summarizer = LLMContextSummarizer(context=context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
request_frame = None
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal request_frame
|
||||
request_frame = frame
|
||||
|
||||
from pipecat.frames.frames import LLMFullResponseStartFrame
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
self.assertIsNotNone(request_frame)
|
||||
self.assertIsInstance(request_frame, LLMContextSummaryRequestFrame)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
|
||||
class TestLLMSpecificMessageHandling(unittest.TestCase):
|
||||
"""Tests that LLMSpecificMessage objects are correctly skipped in summarization."""
|
||||
|
||||
@@ -14,7 +14,10 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseStartFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
|
||||
from pipecat.processors.aggregators.llm_context_summarizer import (
|
||||
LLMContextSummarizer,
|
||||
SummaryAppliedEvent,
|
||||
)
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
|
||||
|
||||
@@ -291,6 +294,252 @@ class TestLLMContextSummarizer(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_summary_message_role_is_user(self):
|
||||
"""Test that the summary message uses the user role."""
|
||||
config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
|
||||
|
||||
summarizer = LLMContextSummarizer(context=self.context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
# Add messages and trigger summarization
|
||||
for i in range(10):
|
||||
self.context.add_message({"role": "user", "content": "Test message."})
|
||||
|
||||
request_frame = None
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal request_frame
|
||||
request_frame = frame
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
self.assertIsNotNone(request_frame)
|
||||
|
||||
# Simulate receiving a summary result
|
||||
summary_result = LLMContextSummaryResultFrame(
|
||||
request_id=request_frame.request_id,
|
||||
summary="This is a test summary.",
|
||||
last_summarized_index=5,
|
||||
)
|
||||
await summarizer.process_frame(summary_result)
|
||||
|
||||
# Find the summary message and verify its role is "user"
|
||||
summary_msg = next(
|
||||
(msg for msg in self.context.messages if "summary" in msg.get("content", "").lower()),
|
||||
None,
|
||||
)
|
||||
self.assertIsNotNone(summary_msg)
|
||||
self.assertEqual(summary_msg["role"], "user")
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_summary_message_default_template(self):
|
||||
"""Test that the default summary_message_template is used."""
|
||||
config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
|
||||
|
||||
summarizer = LLMContextSummarizer(context=self.context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
for i in range(10):
|
||||
self.context.add_message({"role": "user", "content": "Test message."})
|
||||
|
||||
request_frame = None
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal request_frame
|
||||
request_frame = frame
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
summary_result = LLMContextSummaryResultFrame(
|
||||
request_id=request_frame.request_id,
|
||||
summary="Key facts from conversation.",
|
||||
last_summarized_index=5,
|
||||
)
|
||||
await summarizer.process_frame(summary_result)
|
||||
|
||||
# Default template wraps with "Conversation summary: {summary}"
|
||||
summary_msg = next(
|
||||
(
|
||||
msg
|
||||
for msg in self.context.messages
|
||||
if "Conversation summary:" in msg.get("content", "")
|
||||
),
|
||||
None,
|
||||
)
|
||||
self.assertIsNotNone(summary_msg)
|
||||
self.assertEqual(
|
||||
summary_msg["content"], "Conversation summary: Key facts from conversation."
|
||||
)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_summary_message_custom_template(self):
|
||||
"""Test that a custom summary_message_template is applied."""
|
||||
config = LLMContextSummarizationConfig(
|
||||
max_context_tokens=50,
|
||||
min_messages_after_summary=2,
|
||||
summary_message_template="<context_summary>\n{summary}\n</context_summary>",
|
||||
)
|
||||
|
||||
summarizer = LLMContextSummarizer(context=self.context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
for i in range(10):
|
||||
self.context.add_message({"role": "user", "content": "Test message."})
|
||||
|
||||
request_frame = None
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal request_frame
|
||||
request_frame = frame
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
summary_result = LLMContextSummaryResultFrame(
|
||||
request_id=request_frame.request_id,
|
||||
summary="Key facts from conversation.",
|
||||
last_summarized_index=5,
|
||||
)
|
||||
await summarizer.process_frame(summary_result)
|
||||
|
||||
# Custom template wraps with XML tags
|
||||
summary_msg = next(
|
||||
(msg for msg in self.context.messages if "<context_summary>" in msg.get("content", "")),
|
||||
None,
|
||||
)
|
||||
self.assertIsNotNone(summary_msg)
|
||||
self.assertEqual(
|
||||
summary_msg["content"],
|
||||
"<context_summary>\nKey facts from conversation.\n</context_summary>",
|
||||
)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_on_summary_applied_event(self):
|
||||
"""Test that on_summary_applied event fires with correct data."""
|
||||
config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
|
||||
|
||||
summarizer = LLMContextSummarizer(context=self.context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
# Add messages (1 system + 10 user = 11 total)
|
||||
for i in range(10):
|
||||
self.context.add_message({"role": "user", "content": "Test message."})
|
||||
|
||||
request_frame = None
|
||||
applied_event = None
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal request_frame
|
||||
request_frame = frame
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event):
|
||||
nonlocal applied_event
|
||||
applied_event = event
|
||||
|
||||
original_count = len(self.context.messages) # 11
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
# Summarize up to index 7 (system=0, user1..user7), keep last 3 (user8, user9, user10)
|
||||
summary_result = LLMContextSummaryResultFrame(
|
||||
request_id=request_frame.request_id,
|
||||
summary="Test summary.",
|
||||
last_summarized_index=7,
|
||||
)
|
||||
await summarizer.process_frame(summary_result)
|
||||
|
||||
# Allow async event handler to complete
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Verify event was fired
|
||||
self.assertIsNotNone(applied_event)
|
||||
self.assertIsInstance(applied_event, SummaryAppliedEvent)
|
||||
self.assertEqual(applied_event.original_message_count, original_count)
|
||||
|
||||
# After summarization: system + summary + 3 recent = 5
|
||||
self.assertEqual(applied_event.new_message_count, 5)
|
||||
|
||||
# Summarized messages: indices 1-7 = 7 messages (excluding system at index 0)
|
||||
self.assertEqual(applied_event.summarized_message_count, 7)
|
||||
|
||||
# Preserved: system (1) + recent messages after index 7 (3) = 4
|
||||
self.assertEqual(applied_event.preserved_message_count, 4)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_on_summary_applied_not_fired_on_error(self):
|
||||
"""Test that on_summary_applied event is NOT fired when summarization fails."""
|
||||
config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
|
||||
|
||||
summarizer = LLMContextSummarizer(context=self.context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
for i in range(10):
|
||||
self.context.add_message({"role": "user", "content": "Test message."})
|
||||
|
||||
request_frame = None
|
||||
applied_event = None
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal request_frame
|
||||
request_frame = frame
|
||||
|
||||
@summarizer.event_handler("on_summary_applied")
|
||||
async def on_summary_applied(summarizer, event):
|
||||
nonlocal applied_event
|
||||
applied_event = event
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
# Send a result with an error
|
||||
error_result = LLMContextSummaryResultFrame(
|
||||
request_id=request_frame.request_id,
|
||||
summary="",
|
||||
last_summarized_index=-1,
|
||||
error="Summarization timed out",
|
||||
)
|
||||
await summarizer.process_frame(error_result)
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Event should NOT have fired
|
||||
self.assertIsNone(applied_event)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
async def test_request_frame_includes_timeout(self):
|
||||
"""Test that the request frame includes the configured summarization_timeout."""
|
||||
config = LLMContextSummarizationConfig(
|
||||
max_context_tokens=50,
|
||||
summarization_timeout=60.0,
|
||||
)
|
||||
|
||||
summarizer = LLMContextSummarizer(context=self.context, config=config)
|
||||
await summarizer.setup(self.task_manager)
|
||||
|
||||
request_frame = None
|
||||
|
||||
@summarizer.event_handler("on_request_summarization")
|
||||
async def on_request_summarization(summarizer, frame):
|
||||
nonlocal request_frame
|
||||
request_frame = frame
|
||||
|
||||
for i in range(10):
|
||||
self.context.add_message({"role": "user", "content": "Test message to add tokens."})
|
||||
|
||||
await summarizer.process_frame(LLMFullResponseStartFrame())
|
||||
|
||||
self.assertIsNotNone(request_frame)
|
||||
self.assertEqual(request_frame.summarization_timeout, 60.0)
|
||||
|
||||
await summarizer.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user