diff --git a/changelog/3855.added.2.md b/changelog/3855.added.2.md new file mode 100644 index 000000000..01cd23efe --- /dev/null +++ b/changelog/3855.added.2.md @@ -0,0 +1 @@ +- Added optional `llm` field to `LLMContextSummarizationConfig` for routing summarization to a dedicated LLM service (e.g., a cheaper/faster model) instead of the pipeline's primary model. diff --git a/changelog/3855.added.3.md b/changelog/3855.added.3.md new file mode 100644 index 000000000..b93fdec60 --- /dev/null +++ b/changelog/3855.added.3.md @@ -0,0 +1 @@ +- Added `summarization_timeout` to `LLMContextSummarizationConfig` (default 120s) to prevent hung LLM calls from permanently blocking future summarizations. diff --git a/changelog/3855.added.4.md b/changelog/3855.added.4.md new file mode 100644 index 000000000..b712b4ac9 --- /dev/null +++ b/changelog/3855.added.4.md @@ -0,0 +1 @@ +- Added `on_summary_applied` event to `LLMContextSummarizer` for observability, providing message counts before and after context summarization. diff --git a/changelog/3855.added.md b/changelog/3855.added.md new file mode 100644 index 000000000..79d37eeba --- /dev/null +++ b/changelog/3855.added.md @@ -0,0 +1 @@ +- Added `summary_message_template` to `LLMContextSummarizationConfig` for customizing how summaries are formatted when injected into context (e.g., wrapping in XML tags). diff --git a/changelog/3855.changed.md b/changelog/3855.changed.md new file mode 100644 index 000000000..2eac6785a --- /dev/null +++ b/changelog/3855.changed.md @@ -0,0 +1 @@ +- Updated context summarization to use `user` role instead of `assistant` for summary messages. diff --git a/examples/foundational/54-context-summarization-openai.py b/examples/foundational/54-context-summarization-openai.py index 652a3af13..45f27854f 100644 --- a/examples/foundational/54-context-summarization-openai.py +++ b/examples/foundational/54-context-summarization-openai.py @@ -20,14 +20,13 @@ from loguru import logger from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.schemas.tools_schema import ToolsSchema -from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3 from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.audio.vad.vad_analyzer import VADParams from pipecat.frames.frames import LLMRunFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent from pipecat.processors.aggregators.llm_response_universal import ( LLMAssistantAggregatorParams, LLMContextAggregatorPair, @@ -42,8 +41,6 @@ from pipecat.services.openai.llm import OpenAILLMService from pipecat.transports.base_transport import BaseTransport, TransportParams from pipecat.transports.daily.transport import DailyParams from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams -from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy -from pipecat.turns.user_turn_strategies import UserTurnStrategies from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig load_dotenv(override=True) @@ -120,10 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): user_aggregator, assistant_aggregator = LLMContextAggregatorPair( context, user_params=LLMUserAggregatorParams( - user_turn_strategies=UserTurnStrategies( - stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())] - ), - vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + vad_analyzer=SileroVADAnalyzer(), ), assistant_params=LLMAssistantAggregatorParams( enable_context_summarization=True, @@ -138,6 +132,19 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): ), ) + # Listen for summarization events + summarizer = assistant_aggregator._summarizer + if summarizer: + + @summarizer.event_handler("on_summary_applied") + async def on_summary_applied(summarizer, event: SummaryAppliedEvent): + logger.info( + f"Context summarized: {event.original_message_count} messages -> " + f"{event.new_message_count} messages " + f"({event.summarized_message_count} summarized, " + f"{event.preserved_message_count} preserved)" + ) + pipeline = Pipeline( [ transport.input(), # Transport user input diff --git a/examples/foundational/54a-context-summarization-google.py b/examples/foundational/54a-context-summarization-google.py index a7fe4ba5e..2ce29e959 100644 --- a/examples/foundational/54a-context-summarization-google.py +++ b/examples/foundational/54a-context-summarization-google.py @@ -20,14 +20,13 @@ from loguru import logger from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.schemas.tools_schema import ToolsSchema -from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3 from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.audio.vad.vad_analyzer import VADParams from pipecat.frames.frames import LLMRunFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent from pipecat.processors.aggregators.llm_response_universal import ( LLMAssistantAggregatorParams, LLMContextAggregatorPair, @@ -42,8 +41,6 @@ from pipecat.services.llm_service import FunctionCallParams from pipecat.transports.base_transport import BaseTransport, TransportParams from pipecat.transports.daily.transport import DailyParams from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams -from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy -from pipecat.turns.user_turn_strategies import UserTurnStrategies from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig load_dotenv(override=True) @@ -120,10 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): user_aggregator, assistant_aggregator = LLMContextAggregatorPair( context, user_params=LLMUserAggregatorParams( - user_turn_strategies=UserTurnStrategies( - stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())] - ), - vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + vad_analyzer=SileroVADAnalyzer(), ), assistant_params=LLMAssistantAggregatorParams( enable_context_summarization=True, @@ -138,6 +132,19 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): ), ) + # Listen for summarization events + summarizer = assistant_aggregator._summarizer + if summarizer: + + @summarizer.event_handler("on_summary_applied") + async def on_summary_applied(summarizer, event: SummaryAppliedEvent): + logger.info( + f"Context summarized: {event.original_message_count} messages -> " + f"{event.new_message_count} messages " + f"({event.summarized_message_count} summarized, " + f"{event.preserved_message_count} preserved)" + ) + pipeline = Pipeline( [ transport.input(), # Transport user input diff --git a/examples/foundational/54c-context-summarization-dedicated-llm.py b/examples/foundational/54c-context-summarization-dedicated-llm.py new file mode 100644 index 000000000..3b2195e80 --- /dev/null +++ b/examples/foundational/54c-context-summarization-dedicated-llm.py @@ -0,0 +1,231 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Example demonstrating advanced context summarization configuration. + +This example shows how to customize context summarization with: +- A dedicated cheap/fast LLM for generating summaries (Gemini Flash) +- A custom summary message template (XML tags) +- A custom summarization prompt +- A summarization timeout +- The on_summary_applied event for observability +""" + +import asyncio +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import LLMRunFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_context_summarizer import SummaryAppliedEvent +from pipecat.processors.aggregators.llm_response_universal import ( + LLMAssistantAggregatorParams, + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.google import GoogleLLMService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.services.openai.llm import OpenAILLMService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams +from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig + +load_dotenv(override=True) + +# We use lambdas to defer transport parameter creation until the transport +# type is selected at runtime. +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), +} + +# Custom summarization prompt tailored to the application +CUSTOM_SUMMARIZATION_PROMPT = """Summarize this conversation, preserving: +- Key decisions and agreements +- Important facts and user preferences +- Any pending action items or unresolved questions + +Be concise. Use clear, factual statements grouped by topic. +Omit greetings, small talk, and resolved tangents.""" + + +# Tool functions for the LLM +async def get_current_weather(params: FunctionCallParams): + """Get the current weather.""" + logger.info("Tool called: get_current_weather") + await asyncio.sleep(1) # Simulate some processing + await params.result_callback({"conditions": "nice", "temperature": "75"}) + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info("Starting bot") + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ) + + # Primary LLM for conversation (could be any provider) + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY")) + + # Dedicated cheap/fast LLM for summarization only + summarization_llm = GoogleLLMService( + api_key=os.getenv("GOOGLE_API_KEY"), + model="gemini-2.5-flash", + ) + + # Register tool functions + llm.register_function("get_current_weather", get_current_weather) + + weather_function = FunctionSchema( + name="get_current_weather", + description="Get the current weather", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the user's location.", + }, + }, + required=["location", "format"], + ) + tools = ToolsSchema(standard_tools=[weather_function]) + + messages = [ + { + "role": "system", + "content": ( + "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate " + "your capabilities in a succinct way. Your output will be spoken aloud, " + "so avoid special characters that can't easily be spoken. Respond to what " + "the user said in a creative and helpful way. You have access to tools to " + "get the current weather - use them when relevant.\n\n" + "When you see a block, it contains a compressed summary " + "of earlier conversation. Use it as reference but don't mention it to the user." + ), + }, + ] + + context = LLMContext(messages, tools=tools) + + # Create aggregators with custom summarization + user_aggregator, assistant_aggregator = LLMContextAggregatorPair( + context, + user_params=LLMUserAggregatorParams( + vad_analyzer=SileroVADAnalyzer(), + ), + assistant_params=LLMAssistantAggregatorParams( + enable_context_summarization=True, + context_summarization_config=LLMContextSummarizationConfig( + # Trigger thresholds (low values to demonstrate quickly) + max_context_tokens=1000, + max_unsummarized_messages=10, + # Summary generation + target_context_tokens=800, + min_messages_after_summary=2, + summarization_prompt=CUSTOM_SUMMARIZATION_PROMPT, + # Custom summary format - wrap in XML tags so the system + # prompt can identify summaries vs. live conversation + summary_message_template="\n{summary}\n", + # Use a dedicated cheap LLM for summarization instead of + # the primary conversation model + llm=summarization_llm, + # Cancel summarization if it takes longer than 60 seconds + summarization_timeout=60.0, + ), + ), + ) + + # Listen for summarization events + summarizer = assistant_aggregator._summarizer + if summarizer: + + @summarizer.event_handler("on_summary_applied") + async def on_summary_applied(summarizer, event: SummaryAppliedEvent): + logger.info( + f"Context summarized: {event.original_message_count} messages -> " + f"{event.new_message_count} messages " + f"({event.summarized_message_count} summarized, " + f"{event.preserved_message_count} preserved)" + ) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + user_aggregator, # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + assistant_aggregator, # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info("Client connected") + # Kick off the conversation. + messages.append({"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([LLMRunFrame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info("Client disconnected") + await task.cancel() + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index bbc065969..e1d2c37ff 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -2019,6 +2019,8 @@ class LLMContextSummaryRequestFrame(ControlFrame): the summary text. summarization_prompt: System prompt instructing the LLM how to generate the summary. + summarization_timeout: Maximum time in seconds for the LLM to generate a + summary. When None, a default timeout of 120s is applied. """ request_id: str @@ -2026,6 +2028,7 @@ class LLMContextSummaryRequestFrame(ControlFrame): min_messages_to_keep: int target_context_tokens: int summarization_prompt: str + summarization_timeout: Optional[float] = None @dataclass diff --git a/src/pipecat/processors/aggregators/llm_context_summarizer.py b/src/pipecat/processors/aggregators/llm_context_summarizer.py index 7886fcf12..bfdbbceb0 100644 --- a/src/pipecat/processors/aggregators/llm_context_summarizer.py +++ b/src/pipecat/processors/aggregators/llm_context_summarizer.py @@ -6,8 +6,10 @@ """This module defines a summarizer for managing LLM context summarization.""" +import asyncio import uuid -from typing import Optional +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional from loguru import logger @@ -22,10 +24,33 @@ from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMe from pipecat.utils.asyncio.task_manager import BaseTaskManager from pipecat.utils.base_object import BaseObject from pipecat.utils.context.llm_context_summarization import ( + DEFAULT_SUMMARIZATION_TIMEOUT, LLMContextSummarizationConfig, LLMContextSummarizationUtil, ) +if TYPE_CHECKING: + from pipecat.services.llm_service import LLMService + + +@dataclass +class SummaryAppliedEvent: + """Event data emitted when context summarization completes successfully. + + Parameters: + original_message_count: Number of messages before summarization. + new_message_count: Number of messages after summarization. + summarized_message_count: Number of messages that were compressed + into the summary. + preserved_message_count: Number of recent messages preserved + uncompressed. + """ + + original_message_count: int + new_message_count: int + summarized_message_count: int + preserved_message_count: int + class LLMContextSummarizer(BaseObject): """Summarizer for managing LLM context summarization. @@ -39,6 +64,10 @@ class LLMContextSummarizer(BaseObject): - on_request_summarization: Emitted when summarization should be triggered. The aggregator should broadcast this frame to the LLM service. + - on_summary_applied: Emitted after a summary has been successfully applied + to the context. Receives a SummaryAppliedEvent with metrics about the + compression. + Example:: @summarizer.event_handler("on_request_summarization") @@ -49,6 +78,10 @@ class LLMContextSummarizer(BaseObject): context=frame.context, ... ) + + @summarizer.event_handler("on_summary_applied") + async def on_summary_applied(summarizer, event: SummaryAppliedEvent): + logger.info(f"Compressed {event.original_message_count} -> {event.new_message_count} messages") """ def __init__( @@ -74,6 +107,7 @@ class LLMContextSummarizer(BaseObject): self._pending_summary_request_id: Optional[str] = None self._register_event_handler("on_request_summarization", sync=True) + self._register_event_handler("on_summary_applied") @property def task_manager(self) -> BaseTaskManager: @@ -198,8 +232,10 @@ class LLMContextSummarizer(BaseObject): async def _request_summarization(self): """Request context summarization from LLM service. - Creates a summarization request frame and emits it via event handler. - Tracks the request ID to match async responses and prevent race conditions. + Creates a summarization request frame and either handles it directly + using a dedicated LLM (if configured) or emits it via event handler + for the pipeline's primary LLM. Tracks the request ID to match async + responses and prevent race conditions. """ # Generate unique request ID request_id = str(uuid.uuid4()) @@ -218,10 +254,63 @@ class LLMContextSummarizer(BaseObject): min_messages_to_keep=min_keep, target_context_tokens=self._config.target_context_tokens, summarization_prompt=self._config.summary_prompt, + summarization_timeout=self._config.summarization_timeout, ) - # Emit event for aggregator to broadcast - await self._call_event_handler("on_request_summarization", request_frame) + if self._config.llm: + # Use dedicated LLM directly — no need to involve the pipeline + self.task_manager.create_task( + self._generate_summary_with_dedicated_llm(self._config.llm, request_frame), + f"{self}-dedicated-llm-summary", + ) + else: + # Emit event for aggregator to broadcast to the pipeline LLM + await self._call_event_handler("on_request_summarization", request_frame) + + async def _generate_summary_with_dedicated_llm( + self, llm: "LLMService", frame: LLMContextSummaryRequestFrame + ): + """Generate summary using a dedicated LLM service. + + Calls the dedicated LLM's _generate_summary directly and feeds the + result back through _handle_summary_result, bypassing the pipeline. + + Args: + llm: The dedicated LLM service to use for summarization. + frame: The summarization request frame. + """ + timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT + + try: + summary, last_index = await asyncio.wait_for( + llm._generate_summary(frame), + timeout=timeout, + ) + result_frame = LLMContextSummaryResultFrame( + request_id=frame.request_id, + summary=summary, + last_summarized_index=last_index, + ) + except asyncio.TimeoutError: + error = f"Context summarization timed out after {timeout}s" + logger.error(f"{self}: {error}") + result_frame = LLMContextSummaryResultFrame( + request_id=frame.request_id, + summary="", + last_summarized_index=-1, + error=error, + ) + except Exception as e: + error = f"Error generating context summary: {e}" + logger.error(f"{self}: {error}") + result_frame = LLMContextSummaryResultFrame( + request_id=frame.request_id, + summary="", + last_summarized_index=-1, + error=error, + ) + + await self._handle_summary_result(result_frame) async def _handle_summary_result(self, frame: LLMContextSummaryResultFrame): """Handle context summarization result from LLM service. @@ -306,8 +395,10 @@ class LLMContextSummarizer(BaseObject): # Get recent messages to keep recent_messages = messages[last_summarized_index + 1 :] - # Create summary message as an assistant message - summary_message = {"role": "assistant", "content": f"Conversation summary: {summary}"} + # Create summary message as a user message (the summary is context + # provided *to* the assistant, not something the assistant said) + summary_content = self._config.summary_message_template.format(summary=summary) + summary_message = {"role": "user", "content": summary_content} # Reconstruct context new_messages = [] @@ -317,9 +408,23 @@ class LLMContextSummarizer(BaseObject): new_messages.extend(recent_messages) # Update context + original_message_count = len(messages) + num_system_preserved = 1 if first_system_msg else 0 self._context.set_messages(new_messages) + # Messages actually summarized = index range minus the preserved system message + summarized_count = last_summarized_index + 1 - num_system_preserved + logger.info( - f"{self}: Applied context summary, compressed {last_summarized_index + 1} messages " - f"into summary. Context now has {len(new_messages)} messages (was {len(messages)})" + f"{self}: Applied context summary, compressed {summarized_count} messages " + f"into summary. Context now has {len(new_messages)} messages (was {original_message_count})" ) + + # Emit event for observability + event = SummaryAppliedEvent( + original_message_count=original_message_count, + new_message_count=len(new_messages), + summarized_message_count=summarized_count, + preserved_message_count=len(recent_messages) + num_system_preserved, + ) + await self._call_event_handler("on_summary_applied", event) diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index a06423754..da0d57d66 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -62,6 +62,7 @@ from pipecat.services.ai_service import AIService from pipecat.services.settings import LLMSettings from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin from pipecat.utils.context.llm_context_summarization import ( + DEFAULT_SUMMARIZATION_TIMEOUT, LLMContextSummarizationUtil, ) @@ -436,8 +437,15 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): last_index = -1 error = None + timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT + try: - summary, last_index = await self._generate_summary(frame) + summary, last_index = await asyncio.wait_for( + self._generate_summary(frame), + timeout=timeout, + ) + except asyncio.TimeoutError: + await self.push_error(error_msg=f"Context summarization timed out after {timeout}s") except Exception as e: error = f"Error generating context summary: {e}" await self.push_error(error, exception=e) diff --git a/src/pipecat/utils/context/llm_context_summarization.py b/src/pipecat/utils/context/llm_context_summarization.py index 537cc91ab..0bdebb3a2 100644 --- a/src/pipecat/utils/context/llm_context_summarization.py +++ b/src/pipecat/utils/context/llm_context_summarization.py @@ -11,12 +11,18 @@ context when token limits are reached, enabling efficient long-running conversat """ from dataclasses import dataclass -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional + +if TYPE_CHECKING: + from pipecat.services.llm_service import LLMService from loguru import logger from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage +# Fallback timeout (seconds) used when summarization_timeout is None. +DEFAULT_SUMMARIZATION_TIMEOUT = 120.0 + # Token estimation constants CHARS_PER_TOKEN = 4 # Industry-standard heuristic: 1 token ≈ 4 characters TOKEN_OVERHEAD_PER_MESSAGE = 10 # Estimated structural overhead per message @@ -73,6 +79,19 @@ class LLMContextSummarizationConfig: immediate conversational context. summarization_prompt: Custom prompt for the LLM to use when generating summaries. If None, uses DEFAULT_SUMMARIZATION_PROMPT. + summary_message_template: Template for formatting the summary when + injected into context. Must contain ``{summary}`` as a placeholder + for the generated summary text. Allows applications to wrap the + summary in custom delimiters (e.g., XML tags) so that system + prompts can distinguish summaries from live conversation. + llm: Optional separate LLM service for generating summaries. When set, + summarization requests are sent to this service instead of the + pipeline's primary LLM. Useful for routing summarization to a + cheaper/faster model (e.g., Gemini Flash) while keeping an + expensive model for conversation. If None, uses the pipeline LLM. + summarization_timeout: Maximum time in seconds to wait for the LLM to + generate a summary. If the call exceeds this timeout, summarization + is aborted with an error and future summarizations are unblocked. """ max_context_tokens: int = 8000 @@ -80,6 +99,9 @@ class LLMContextSummarizationConfig: max_unsummarized_messages: int = 20 min_messages_after_summary: int = 4 summarization_prompt: Optional[str] = None + summary_message_template: str = "Conversation summary: {summary}" + llm: Optional["LLMService"] = None + summarization_timeout: float = DEFAULT_SUMMARIZATION_TIMEOUT def __post_init__(self): """Validate configuration parameters.""" diff --git a/tests/test_context_summarization.py b/tests/test_context_summarization.py index 3bb1246e9..ca56e7a32 100644 --- a/tests/test_context_summarization.py +++ b/tests/test_context_summarization.py @@ -6,10 +6,11 @@ """Tests for context summarization feature.""" +import asyncio import unittest -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock -from pipecat.frames.frames import LLMContextSummaryRequestFrame +from pipecat.frames.frames import LLMContextSummaryRequestFrame, LLMContextSummaryResultFrame from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage from pipecat.services.llm_service import LLMService from pipecat.utils.context.llm_context_summarization import ( @@ -601,6 +602,250 @@ class TestSummaryGenerationExceptions(unittest.IsolatedAsyncioTestCase): self.assertGreater(last_index, -1) self.assertEqual(last_index, 1) # Should be the index of the last summarized message + async def test_generate_summary_task_timeout(self): + """Test that _generate_summary_task handles timeout correctly.""" + llm_service = LLMService() + + # Mock _generate_summary to hang + async def slow_summary(frame): + await asyncio.sleep(10) + return ("summary", 1) + + llm_service._generate_summary = slow_summary + + broadcast_calls = [] + + async def mock_broadcast(frame_class, **kwargs): + broadcast_calls.append((frame_class, kwargs)) + + llm_service.broadcast_frame = mock_broadcast + llm_service.push_error = AsyncMock() + + context = LLMContext() + context.add_message({"role": "user", "content": "Message 1"}) + context.add_message({"role": "assistant", "content": "Response 1"}) + context.add_message({"role": "user", "content": "Message 2"}) + + frame = LLMContextSummaryRequestFrame( + request_id="timeout_test", + context=context, + min_messages_to_keep=1, + target_context_tokens=1000, + summarization_prompt="Summarize this", + summarization_timeout=0.1, # Very short timeout + ) + + await llm_service._generate_summary_task(frame) + + # Should have broadcast an error result + self.assertEqual(len(broadcast_calls), 1) + _, kwargs = broadcast_calls[0] + self.assertEqual(kwargs["request_id"], "timeout_test") + self.assertEqual(kwargs["summary"], "") + self.assertEqual(kwargs["last_summarized_index"], -1) + # error is None for timeout path (push_error is called instead) + self.assertIsNone(kwargs["error"]) + + # push_error should have been called with timeout message + llm_service.push_error.assert_called_once() + call_args = llm_service.push_error.call_args + error_msg = call_args.kwargs.get("error_msg") or call_args.args[0] + self.assertIn("timed out", error_msg) + + +class TestDedicatedLLMSummarization(unittest.IsolatedAsyncioTestCase): + """Tests for dedicated LLM summarization in LLMContextSummarizer.""" + + async def asyncSetUp(self): + from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams + + self.task_manager = TaskManager() + self.task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop())) + + def _create_context_and_config(self, dedicated_llm): + """Create a context with enough messages and a config with a dedicated LLM.""" + context = LLMContext() + for i in range(10): + context.add_message( + {"role": "user", "content": f"Test message {i} that adds tokens to context."} + ) + + config = LLMContextSummarizationConfig( + max_context_tokens=50, # Very low to trigger easily + llm=dedicated_llm, + summarization_timeout=5.0, + ) + return context, config + + async def test_dedicated_llm_success(self): + """Test that dedicated LLM generates summary and applies result.""" + from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer + + dedicated_llm = LLMService() + dedicated_llm._generate_summary = AsyncMock(return_value=("Dedicated summary", 5)) + + context, config = self._create_context_and_config(dedicated_llm) + original_message_count = len(context.messages) + summarizer = LLMContextSummarizer(context=context, config=config) + await summarizer.setup(self.task_manager) + + # Track whether on_request_summarization event fires (it should NOT) + event_fired = False + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal event_fired + event_fired = True + + # Trigger summarization via LLM response start + from pipecat.frames.frames import LLMFullResponseStartFrame + + await summarizer.process_frame(LLMFullResponseStartFrame()) + + # Wait for the background task to complete + await asyncio.sleep(0.1) + + # The event should NOT have fired (dedicated LLM handles it internally) + self.assertFalse(event_fired) + + # Verify the dedicated LLM was called + dedicated_llm._generate_summary.assert_called_once() + + # Verify summary was applied to context (message count should decrease) + self.assertLess(len(context.messages), original_message_count) + + # Verify summary message is present + summary_messages = [ + msg for msg in context.messages if "Conversation summary:" in msg.get("content", "") + ] + self.assertEqual(len(summary_messages), 1) + self.assertIn("Dedicated summary", summary_messages[0]["content"]) + + await summarizer.cleanup() + + async def test_dedicated_llm_timeout(self): + """Test that dedicated LLM timeout produces error and clears state.""" + from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer + + dedicated_llm = LLMService() + + async def slow_summary(frame): + await asyncio.sleep(10) + return ("summary", 1) + + dedicated_llm._generate_summary = slow_summary + + context, config = self._create_context_and_config(dedicated_llm) + config.summarization_timeout = 0.1 # Very short timeout + summarizer = LLMContextSummarizer(context=context, config=config) + await summarizer.setup(self.task_manager) + + original_message_count = len(context.messages) + + # Trigger summarization + from pipecat.frames.frames import LLMFullResponseStartFrame + + await summarizer.process_frame(LLMFullResponseStartFrame()) + + # Wait for the background task to complete (timeout + some buffer) + await asyncio.sleep(0.3) + + # Context should be unchanged (timeout = error = no summary applied) + self.assertEqual(len(context.messages), original_message_count) + + # Summarization state should be cleared so new requests can be made + self.assertFalse(summarizer._summarization_in_progress) + + await summarizer.cleanup() + + async def test_dedicated_llm_exception(self): + """Test that dedicated LLM exceptions produce error and clear state.""" + from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer + + dedicated_llm = LLMService() + dedicated_llm._generate_summary = AsyncMock( + side_effect=RuntimeError("LLM connection failed") + ) + + context, config = self._create_context_and_config(dedicated_llm) + summarizer = LLMContextSummarizer(context=context, config=config) + await summarizer.setup(self.task_manager) + + original_message_count = len(context.messages) + + # Trigger summarization + from pipecat.frames.frames import LLMFullResponseStartFrame + + await summarizer.process_frame(LLMFullResponseStartFrame()) + + # Wait for the background task to complete + await asyncio.sleep(0.1) + + # Context should be unchanged (exception = error = no summary applied) + self.assertEqual(len(context.messages), original_message_count) + + # Summarization state should be cleared + self.assertFalse(summarizer._summarization_in_progress) + + await summarizer.cleanup() + + async def test_dedicated_llm_does_not_emit_event(self): + """Test that summarizer does NOT emit on_request_summarization when dedicated LLM is set.""" + from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer + + dedicated_llm = LLMService() + dedicated_llm._generate_summary = AsyncMock(return_value=("Summary", 1)) + + context, config = self._create_context_and_config(dedicated_llm) + summarizer = LLMContextSummarizer(context=context, config=config) + await summarizer.setup(self.task_manager) + + event_fired = False + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal event_fired + event_fired = True + + from pipecat.frames.frames import LLMFullResponseStartFrame + + await summarizer.process_frame(LLMFullResponseStartFrame()) + await asyncio.sleep(0.1) + + self.assertFalse(event_fired) + + await summarizer.cleanup() + + async def test_no_dedicated_llm_emits_event(self): + """Test that summarizer emits on_request_summarization when no dedicated LLM.""" + from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer + + context = LLMContext() + for i in range(10): + context.add_message( + {"role": "user", "content": f"Test message {i} that adds tokens to context."} + ) + + config = LLMContextSummarizationConfig(max_context_tokens=50) + summarizer = LLMContextSummarizer(context=context, config=config) + await summarizer.setup(self.task_manager) + + request_frame = None + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal request_frame + request_frame = frame + + from pipecat.frames.frames import LLMFullResponseStartFrame + + await summarizer.process_frame(LLMFullResponseStartFrame()) + + self.assertIsNotNone(request_frame) + self.assertIsInstance(request_frame, LLMContextSummaryRequestFrame) + + await summarizer.cleanup() + class TestLLMSpecificMessageHandling(unittest.TestCase): """Tests that LLMSpecificMessage objects are correctly skipped in summarization.""" diff --git a/tests/test_llm_context_summarizer.py b/tests/test_llm_context_summarizer.py index 7555a8762..0439d403d 100644 --- a/tests/test_llm_context_summarizer.py +++ b/tests/test_llm_context_summarizer.py @@ -14,7 +14,10 @@ from pipecat.frames.frames import ( LLMFullResponseStartFrame, ) from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer +from pipecat.processors.aggregators.llm_context_summarizer import ( + LLMContextSummarizer, + SummaryAppliedEvent, +) from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig @@ -291,6 +294,252 @@ class TestLLMContextSummarizer(unittest.IsolatedAsyncioTestCase): await summarizer.cleanup() + async def test_summary_message_role_is_user(self): + """Test that the summary message uses the user role.""" + config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2) + + summarizer = LLMContextSummarizer(context=self.context, config=config) + await summarizer.setup(self.task_manager) + + # Add messages and trigger summarization + for i in range(10): + self.context.add_message({"role": "user", "content": "Test message."}) + + request_frame = None + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal request_frame + request_frame = frame + + await summarizer.process_frame(LLMFullResponseStartFrame()) + self.assertIsNotNone(request_frame) + + # Simulate receiving a summary result + summary_result = LLMContextSummaryResultFrame( + request_id=request_frame.request_id, + summary="This is a test summary.", + last_summarized_index=5, + ) + await summarizer.process_frame(summary_result) + + # Find the summary message and verify its role is "user" + summary_msg = next( + (msg for msg in self.context.messages if "summary" in msg.get("content", "").lower()), + None, + ) + self.assertIsNotNone(summary_msg) + self.assertEqual(summary_msg["role"], "user") + + await summarizer.cleanup() + + async def test_summary_message_default_template(self): + """Test that the default summary_message_template is used.""" + config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2) + + summarizer = LLMContextSummarizer(context=self.context, config=config) + await summarizer.setup(self.task_manager) + + for i in range(10): + self.context.add_message({"role": "user", "content": "Test message."}) + + request_frame = None + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal request_frame + request_frame = frame + + await summarizer.process_frame(LLMFullResponseStartFrame()) + + summary_result = LLMContextSummaryResultFrame( + request_id=request_frame.request_id, + summary="Key facts from conversation.", + last_summarized_index=5, + ) + await summarizer.process_frame(summary_result) + + # Default template wraps with "Conversation summary: {summary}" + summary_msg = next( + ( + msg + for msg in self.context.messages + if "Conversation summary:" in msg.get("content", "") + ), + None, + ) + self.assertIsNotNone(summary_msg) + self.assertEqual( + summary_msg["content"], "Conversation summary: Key facts from conversation." + ) + + await summarizer.cleanup() + + async def test_summary_message_custom_template(self): + """Test that a custom summary_message_template is applied.""" + config = LLMContextSummarizationConfig( + max_context_tokens=50, + min_messages_after_summary=2, + summary_message_template="\n{summary}\n", + ) + + summarizer = LLMContextSummarizer(context=self.context, config=config) + await summarizer.setup(self.task_manager) + + for i in range(10): + self.context.add_message({"role": "user", "content": "Test message."}) + + request_frame = None + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal request_frame + request_frame = frame + + await summarizer.process_frame(LLMFullResponseStartFrame()) + + summary_result = LLMContextSummaryResultFrame( + request_id=request_frame.request_id, + summary="Key facts from conversation.", + last_summarized_index=5, + ) + await summarizer.process_frame(summary_result) + + # Custom template wraps with XML tags + summary_msg = next( + (msg for msg in self.context.messages if "" in msg.get("content", "")), + None, + ) + self.assertIsNotNone(summary_msg) + self.assertEqual( + summary_msg["content"], + "\nKey facts from conversation.\n", + ) + + await summarizer.cleanup() + + async def test_on_summary_applied_event(self): + """Test that on_summary_applied event fires with correct data.""" + config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2) + + summarizer = LLMContextSummarizer(context=self.context, config=config) + await summarizer.setup(self.task_manager) + + # Add messages (1 system + 10 user = 11 total) + for i in range(10): + self.context.add_message({"role": "user", "content": "Test message."}) + + request_frame = None + applied_event = None + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal request_frame + request_frame = frame + + @summarizer.event_handler("on_summary_applied") + async def on_summary_applied(summarizer, event): + nonlocal applied_event + applied_event = event + + original_count = len(self.context.messages) # 11 + await summarizer.process_frame(LLMFullResponseStartFrame()) + + # Summarize up to index 7 (system=0, user1..user7), keep last 3 (user8, user9, user10) + summary_result = LLMContextSummaryResultFrame( + request_id=request_frame.request_id, + summary="Test summary.", + last_summarized_index=7, + ) + await summarizer.process_frame(summary_result) + + # Allow async event handler to complete + await asyncio.sleep(0.05) + + # Verify event was fired + self.assertIsNotNone(applied_event) + self.assertIsInstance(applied_event, SummaryAppliedEvent) + self.assertEqual(applied_event.original_message_count, original_count) + + # After summarization: system + summary + 3 recent = 5 + self.assertEqual(applied_event.new_message_count, 5) + + # Summarized messages: indices 1-7 = 7 messages (excluding system at index 0) + self.assertEqual(applied_event.summarized_message_count, 7) + + # Preserved: system (1) + recent messages after index 7 (3) = 4 + self.assertEqual(applied_event.preserved_message_count, 4) + + await summarizer.cleanup() + + async def test_on_summary_applied_not_fired_on_error(self): + """Test that on_summary_applied event is NOT fired when summarization fails.""" + config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2) + + summarizer = LLMContextSummarizer(context=self.context, config=config) + await summarizer.setup(self.task_manager) + + for i in range(10): + self.context.add_message({"role": "user", "content": "Test message."}) + + request_frame = None + applied_event = None + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal request_frame + request_frame = frame + + @summarizer.event_handler("on_summary_applied") + async def on_summary_applied(summarizer, event): + nonlocal applied_event + applied_event = event + + await summarizer.process_frame(LLMFullResponseStartFrame()) + + # Send a result with an error + error_result = LLMContextSummaryResultFrame( + request_id=request_frame.request_id, + summary="", + last_summarized_index=-1, + error="Summarization timed out", + ) + await summarizer.process_frame(error_result) + + await asyncio.sleep(0.05) + + # Event should NOT have fired + self.assertIsNone(applied_event) + + await summarizer.cleanup() + + async def test_request_frame_includes_timeout(self): + """Test that the request frame includes the configured summarization_timeout.""" + config = LLMContextSummarizationConfig( + max_context_tokens=50, + summarization_timeout=60.0, + ) + + summarizer = LLMContextSummarizer(context=self.context, config=config) + await summarizer.setup(self.task_manager) + + request_frame = None + + @summarizer.event_handler("on_request_summarization") + async def on_request_summarization(summarizer, frame): + nonlocal request_frame + request_frame = frame + + for i in range(10): + self.context.add_message({"role": "user", "content": "Test message to add tokens."}) + + await summarizer.process_frame(LLMFullResponseStartFrame()) + + self.assertIsNotNone(request_frame) + self.assertEqual(request_frame.summarization_timeout, 60.0) + + await summarizer.cleanup() + if __name__ == "__main__": unittest.main()