Add tests for context summarization improvements

Cover summary message role, template, on_summary_applied event,
summarization timeout, and dedicated LLM routing/error handling.
This commit is contained in:
Mark Backman
2026-02-26 22:56:10 -05:00
parent ec9ddb3199
commit 98e737b4e9
2 changed files with 548 additions and 3 deletions

View File

@@ -6,10 +6,11 @@
"""Tests for context summarization feature."""
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock
from pipecat.frames.frames import LLMContextSummaryRequestFrame
from pipecat.frames.frames import LLMContextSummaryRequestFrame, LLMContextSummaryResultFrame
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
from pipecat.services.llm_service import LLMService
from pipecat.utils.context.llm_context_summarization import (
@@ -601,6 +602,301 @@ class TestSummaryGenerationExceptions(unittest.IsolatedAsyncioTestCase):
self.assertGreater(last_index, -1)
self.assertEqual(last_index, 1) # Should be the index of the last summarized message
async def test_generate_summary_task_timeout(self):
    """Exercise the timeout path of ``LLMService._generate_summary_task``.

    ``_generate_summary`` is replaced with a coroutine that sleeps far longer
    than the ``summarization_timeout`` set on the request frame. The test then
    expects exactly one broadcast carrying an empty summary, a
    ``last_summarized_index`` of -1 and ``error=None``, plus a single
    ``push_error`` call whose message mentions "timed out".
    """
    service = LLMService()

    async def never_finishes(frame):
        # Sleeps for 10 while the request below only allows 0.1.
        await asyncio.sleep(10)
        return ("summary", 1)

    recorded_broadcasts = []

    async def record_broadcast(frame_class, **kwargs):
        # Keep the frame class and keyword arguments of each broadcast for later inspection.
        recorded_broadcasts.append((frame_class, kwargs))

    service._generate_summary = never_finishes
    service.broadcast_frame = record_broadcast
    service.push_error = AsyncMock()

    conversation = LLMContext()
    for role, content in (
        ("user", "Message 1"),
        ("assistant", "Response 1"),
        ("user", "Message 2"),
    ):
        conversation.add_message({"role": role, "content": content})

    request = LLMContextSummaryRequestFrame(
        request_id="timeout_test",
        context=conversation,
        min_messages_to_keep=1,
        target_context_tokens=1000,
        summarization_prompt="Summarize this",
        summarization_timeout=0.1,  # Very short timeout
    )

    await service._generate_summary_task(request)

    # Exactly one result should have been broadcast, describing the failure.
    self.assertEqual(len(recorded_broadcasts), 1)
    broadcast_kwargs = recorded_broadcasts[0][1]
    for field, expected in (
        ("request_id", "timeout_test"),
        ("summary", ""),
        ("last_summarized_index", -1),
    ):
        self.assertEqual(broadcast_kwargs[field], expected)
    # On the timeout path the broadcast carries no error text; push_error reports it instead.
    self.assertIsNone(broadcast_kwargs["error"])

    # The timeout must have been reported exactly once through push_error.
    service.push_error.assert_called_once()
    push_call = service.push_error.call_args
    # Accept the message either as the ``error_msg`` keyword or as the first positional argument.
    error_msg = push_call.kwargs.get("error_msg")
    if not error_msg:
        error_msg = push_call.args[0]
    self.assertIn("timed out", error_msg)
class TestDedicatedLLMSummarization(unittest.IsolatedAsyncioTestCase):
    """Tests for dedicated LLM summarization in LLMAssistantAggregator.

    Each test builds an ``LLMAssistantAggregator`` with context summarization
    enabled and drives its private ``_generate_summary_with_dedicated_llm`` or
    ``_on_request_summarization`` coroutine directly, stubbing the LLM's
    ``_generate_summary`` so no real model is contacted.
    """

    def _create_context_and_frame(self, request_id="dedicated_test", summarization_timeout=5.0):
        """Create a context with enough messages and a matching request frame.

        Args:
            request_id: Identifier stored on the request frame.
            summarization_timeout: Timeout value stored on the request frame
                (presumably seconds; confirm against the frame definition).

        Returns:
            A ``(context, frame)`` tuple where ``frame.context`` is ``context``.
        """
        context = LLMContext()
        context.add_message({"role": "user", "content": "Message 1"})
        context.add_message({"role": "assistant", "content": "Response 1"})
        context.add_message({"role": "user", "content": "Message 2"})
        frame = LLMContextSummaryRequestFrame(
            request_id=request_id,
            context=context,
            min_messages_to_keep=1,
            target_context_tokens=1000,
            summarization_prompt="Summarize this",
            summarization_timeout=summarization_timeout,
        )
        return context, frame

    def _create_aggregator(self, context, dedicated_llm=None):
        """Build an assistant aggregator with context summarization enabled.

        Args:
            context: The ``LLMContext`` the aggregator operates on.
            dedicated_llm: Optional LLM passed as the summarization config's
                ``llm``. When ``None`` the ``llm`` argument is omitted entirely.

        Returns:
            The configured ``LLMAssistantAggregator``.
        """
        # Imported locally, as the original tests did, rather than at module level.
        from pipecat.processors.aggregators.llm_response_universal import (
            LLMAssistantAggregator,
            LLMAssistantAggregatorParams,
        )

        config_kwargs = {"max_context_tokens": 50}
        if dedicated_llm is not None:
            config_kwargs["llm"] = dedicated_llm
        params = LLMAssistantAggregatorParams(
            enable_context_summarization=True,
            context_summarization_config=LLMContextSummarizationConfig(**config_kwargs),
        )
        return LLMAssistantAggregator(context, params=params)

    def _capture_summarizer_frames(self, aggregator, *, forward=False):
        """Replace the summarizer's ``process_frame`` with a recording stub.

        Args:
            aggregator: Aggregator whose ``_summarizer.process_frame`` is replaced.
            forward: When true, each recorded frame is also passed on to the
                original ``process_frame``.

        Returns:
            The list that receives every frame handed to the summarizer.
        """
        captured = []
        original_process = aggregator._summarizer.process_frame

        async def capture_process(frame):
            captured.append(frame)
            if forward:
                await original_process(frame)

        aggregator._summarizer.process_frame = capture_process
        return captured

    async def test_dedicated_llm_success(self):
        """Test that dedicated LLM generates summary and feeds result to summarizer."""
        context, frame = self._create_context_and_frame()

        # Dedicated LLM whose summary generation is stubbed to succeed immediately.
        dedicated_llm = LLMService()
        dedicated_llm._generate_summary = AsyncMock(return_value=("Dedicated summary", 1))

        aggregator = self._create_aggregator(context, dedicated_llm)
        # Record the result frame while still letting the real summarizer process it.
        result_frames = self._capture_summarizer_frames(aggregator, forward=True)

        # Call the method directly
        await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)

        # Verify the dedicated LLM was called
        dedicated_llm._generate_summary.assert_called_once_with(frame)

        # Verify result was fed to the summarizer
        self.assertEqual(len(result_frames), 1)
        result = result_frames[0]
        self.assertIsInstance(result, LLMContextSummaryResultFrame)
        self.assertEqual(result.request_id, "dedicated_test")
        self.assertEqual(result.summary, "Dedicated summary")
        self.assertEqual(result.last_summarized_index, 1)
        self.assertIsNone(result.error)

    async def test_dedicated_llm_timeout(self):
        """Test that dedicated LLM timeout produces error result."""
        # The frame carries a very short timeout so the hanging LLM below trips it quickly.
        context, frame = self._create_context_and_frame(
            request_id="timeout_test", summarization_timeout=0.1
        )

        # Dedicated LLM whose summary generation sleeps well past the frame's timeout.
        dedicated_llm = LLMService()

        async def slow_summary(frame):
            await asyncio.sleep(10)
            return ("summary", 1)

        dedicated_llm._generate_summary = slow_summary

        aggregator = self._create_aggregator(context, dedicated_llm)
        result_frames = self._capture_summarizer_frames(aggregator)

        await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)

        # Verify error result was fed to summarizer
        self.assertEqual(len(result_frames), 1)
        result = result_frames[0]
        self.assertIsInstance(result, LLMContextSummaryResultFrame)
        self.assertEqual(result.request_id, "timeout_test")
        self.assertEqual(result.summary, "")
        self.assertEqual(result.last_summarized_index, -1)
        self.assertIn("timed out", result.error)

    async def test_dedicated_llm_exception(self):
        """Test that dedicated LLM exceptions produce error result."""
        context, frame = self._create_context_and_frame()

        # Dedicated LLM whose summary generation raises.
        dedicated_llm = LLMService()
        dedicated_llm._generate_summary = AsyncMock(
            side_effect=RuntimeError("LLM connection failed")
        )

        aggregator = self._create_aggregator(context, dedicated_llm)
        aggregator.push_error = AsyncMock()
        result_frames = self._capture_summarizer_frames(aggregator)

        await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)

        # Verify error result was fed to summarizer
        self.assertEqual(len(result_frames), 1)
        result = result_frames[0]
        self.assertIsInstance(result, LLMContextSummaryResultFrame)
        self.assertEqual(result.request_id, "dedicated_test")
        self.assertEqual(result.summary, "")
        self.assertEqual(result.last_summarized_index, -1)
        self.assertIn("LLM connection failed", result.error)

        # push_error should have been called
        aggregator.push_error.assert_called_once()

    async def test_on_request_summarization_routes_to_dedicated_llm(self):
        """Test that _on_request_summarization routes to dedicated LLM when configured."""
        context, frame = self._create_context_and_frame()

        dedicated_llm = LLMService()
        dedicated_llm._generate_summary = AsyncMock(return_value=("Summary", 1))

        aggregator = self._create_aggregator(context, dedicated_llm)
        aggregator.push_frame = AsyncMock()

        # Schedule whatever coroutine the aggregator hands to create_task and keep the
        # resulting task so the test can await it deterministically afterwards.
        spawned_tasks = []

        def mock_create_task(coro, *args, **kwargs):
            task = asyncio.ensure_future(coro)
            spawned_tasks.append(task)
            return task

        aggregator.create_task = mock_create_task

        await aggregator._on_request_summarization(aggregator._summarizer, frame)

        # Should NOT push frame upstream
        aggregator.push_frame.assert_not_called()

        # Should have created a task for the dedicated LLM
        self.assertEqual(len(spawned_tasks), 1)

        # Await the task itself instead of sleeping: this removes the timing dependency
        # and lets any exception raised inside the task fail the test.
        await asyncio.gather(*spawned_tasks)

        # The spawned task should have asked the dedicated LLM for the summary.
        dedicated_llm._generate_summary.assert_called_once_with(frame)

    async def test_on_request_summarization_pushes_upstream_without_dedicated_llm(self):
        """Test that _on_request_summarization pushes upstream when no dedicated LLM."""
        from pipecat.processors.frame_processor import FrameDirection

        context, frame = self._create_context_and_frame()

        # No ``llm`` is given to the summarization config here.
        aggregator = self._create_aggregator(context)
        aggregator.push_frame = AsyncMock()

        await aggregator._on_request_summarization(aggregator._summarizer, frame)

        # Should push frame upstream
        aggregator.push_frame.assert_called_once_with(frame, FrameDirection.UPSTREAM)
class TestLLMSpecificMessageHandling(unittest.TestCase):
"""Tests that LLMSpecificMessage objects are correctly skipped in summarization."""

View File

@@ -14,7 +14,10 @@ from pipecat.frames.frames import (
LLMFullResponseStartFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
from pipecat.processors.aggregators.llm_context_summarizer import (
LLMContextSummarizer,
SummaryAppliedEvent,
)
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
@@ -291,6 +294,252 @@ class TestLLMContextSummarizer(unittest.IsolatedAsyncioTestCase):
await summarizer.cleanup()
async def test_summary_message_role_is_user(self):
    """Check that the message inserted for an applied summary has role ``"user"``.

    Ten user messages are added, a request is captured from the
    ``on_request_summarization`` event, and a matching result frame is fed
    back. The first context message whose content mentions "summary"
    (case-insensitive) is then expected to carry the ``"user"`` role.
    """
    summarizer = LLMContextSummarizer(
        context=self.context,
        config=LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2),
    )
    await summarizer.setup(self.task_manager)

    # Grow the context so the next response start triggers a summarization request.
    for _ in range(10):
        self.context.add_message({"role": "user", "content": "Test message."})

    seen_requests = []

    @summarizer.event_handler("on_request_summarization")
    async def on_request_summarization(summarizer, frame):
        seen_requests.append(frame)

    await summarizer.process_frame(LLMFullResponseStartFrame())

    # Use the most recent request, or None when the event never fired.
    request_frame = seen_requests[-1] if seen_requests else None
    self.assertIsNotNone(request_frame)

    # Answer the request as an LLM service would.
    await summarizer.process_frame(
        LLMContextSummaryResultFrame(
            request_id=request_frame.request_id,
            summary="This is a test summary.",
            last_summarized_index=5,
        )
    )

    # Locate the first message mentioning "summary" and check its role.
    summary_msg = None
    for message in self.context.messages:
        if "summary" in message.get("content", "").lower():
            summary_msg = message
            break

    self.assertIsNotNone(summary_msg)
    self.assertEqual(summary_msg["role"], "user")

    await summarizer.cleanup()
async def test_summary_message_default_template(self):
    """Test that the default summary_message_template is used.

    No template is configured, so the applied summary message is expected to
    read exactly ``"Conversation summary: <summary text>"``.
    """
    config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
    summarizer = LLMContextSummarizer(context=self.context, config=config)
    await summarizer.setup(self.task_manager)

    # Grow the context so the next response start triggers a summarization request.
    for _ in range(10):
        self.context.add_message({"role": "user", "content": "Test message."})

    request_frame = None

    @summarizer.event_handler("on_request_summarization")
    async def on_request_summarization(summarizer, frame):
        nonlocal request_frame
        request_frame = frame

    await summarizer.process_frame(LLMFullResponseStartFrame())

    # Fail with a clear assertion, not an AttributeError, if no request was issued.
    self.assertIsNotNone(request_frame)

    summary_result = LLMContextSummaryResultFrame(
        request_id=request_frame.request_id,
        summary="Key facts from conversation.",
        last_summarized_index=5,
    )
    await summarizer.process_frame(summary_result)

    # Default template wraps with "Conversation summary: {summary}"
    summary_msg = next(
        (
            msg
            for msg in self.context.messages
            if "Conversation summary:" in msg.get("content", "")
        ),
        None,
    )
    self.assertIsNotNone(summary_msg)
    self.assertEqual(
        summary_msg["content"], "Conversation summary: Key facts from conversation."
    )

    await summarizer.cleanup()
async def test_summary_message_custom_template(self):
    """Test that a custom summary_message_template is applied.

    The config supplies an XML-style template containing a ``{summary}``
    placeholder; the applied summary message is expected to be that template
    with the summary text substituted.
    """
    config = LLMContextSummarizationConfig(
        max_context_tokens=50,
        min_messages_after_summary=2,
        summary_message_template="<context_summary>\n{summary}\n</context_summary>",
    )
    summarizer = LLMContextSummarizer(context=self.context, config=config)
    await summarizer.setup(self.task_manager)

    # Grow the context so the next response start triggers a summarization request.
    for _ in range(10):
        self.context.add_message({"role": "user", "content": "Test message."})

    request_frame = None

    @summarizer.event_handler("on_request_summarization")
    async def on_request_summarization(summarizer, frame):
        nonlocal request_frame
        request_frame = frame

    await summarizer.process_frame(LLMFullResponseStartFrame())

    # Fail with a clear assertion, not an AttributeError, if no request was issued.
    self.assertIsNotNone(request_frame)

    summary_result = LLMContextSummaryResultFrame(
        request_id=request_frame.request_id,
        summary="Key facts from conversation.",
        last_summarized_index=5,
    )
    await summarizer.process_frame(summary_result)

    # Custom template wraps with XML tags
    summary_msg = next(
        (msg for msg in self.context.messages if "<context_summary>" in msg.get("content", "")),
        None,
    )
    self.assertIsNotNone(summary_msg)
    self.assertEqual(
        summary_msg["content"],
        "<context_summary>\nKey facts from conversation.\n</context_summary>",
    )

    await summarizer.cleanup()
async def test_on_summary_applied_event(self):
    """Test that on_summary_applied event fires with correct data.

    After a successful summary result is processed, the registered
    ``on_summary_applied`` handler is expected to receive a
    ``SummaryAppliedEvent`` whose message counts match the context before
    and after summarization.
    """
    config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
    summarizer = LLMContextSummarizer(context=self.context, config=config)
    await summarizer.setup(self.task_manager)

    # Add messages (1 system + 10 user = 11 total)
    # NOTE(review): the system message comes from the test fixture, which is outside this view.
    for _ in range(10):
        self.context.add_message({"role": "user", "content": "Test message."})

    request_frame = None
    applied_event = None

    @summarizer.event_handler("on_request_summarization")
    async def on_request_summarization(summarizer, frame):
        nonlocal request_frame
        request_frame = frame

    @summarizer.event_handler("on_summary_applied")
    async def on_summary_applied(summarizer, event):
        nonlocal applied_event
        applied_event = event

    original_count = len(self.context.messages)  # 11

    await summarizer.process_frame(LLMFullResponseStartFrame())

    # Fail with a clear assertion, not an AttributeError, if no request was issued.
    self.assertIsNotNone(request_frame)

    # Summarize up to index 7 (system=0, user1..user7), keep last 3 (user8, user9, user10)
    summary_result = LLMContextSummaryResultFrame(
        request_id=request_frame.request_id,
        summary="Test summary.",
        last_summarized_index=7,
    )
    await summarizer.process_frame(summary_result)

    # Allow async event handler to complete
    await asyncio.sleep(0.05)

    # Verify event was fired
    self.assertIsNotNone(applied_event)
    self.assertIsInstance(applied_event, SummaryAppliedEvent)
    self.assertEqual(applied_event.original_message_count, original_count)
    # After summarization: system + summary + 3 recent = 5
    self.assertEqual(applied_event.new_message_count, 5)
    # Summarized messages: indices 1-7 = 7 messages (excluding system at index 0)
    self.assertEqual(applied_event.summarized_message_count, 7)
    # Preserved: system (1) + recent messages after index 7 (3) = 4
    self.assertEqual(applied_event.preserved_message_count, 4)

    await summarizer.cleanup()
async def test_on_summary_applied_not_fired_on_error(self):
    """Test that on_summary_applied event is NOT fired when summarization fails.

    A result frame carrying an ``error`` is fed to the summarizer; the
    ``on_summary_applied`` handler is expected never to run.
    """
    config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
    summarizer = LLMContextSummarizer(context=self.context, config=config)
    await summarizer.setup(self.task_manager)

    # Grow the context so the next response start triggers a summarization request.
    for _ in range(10):
        self.context.add_message({"role": "user", "content": "Test message."})

    request_frame = None
    applied_event = None

    @summarizer.event_handler("on_request_summarization")
    async def on_request_summarization(summarizer, frame):
        nonlocal request_frame
        request_frame = frame

    @summarizer.event_handler("on_summary_applied")
    async def on_summary_applied(summarizer, event):
        nonlocal applied_event
        applied_event = event

    await summarizer.process_frame(LLMFullResponseStartFrame())

    # Fail with a clear assertion, not an AttributeError, if no request was issued.
    self.assertIsNotNone(request_frame)

    # Send a result with an error
    error_result = LLMContextSummaryResultFrame(
        request_id=request_frame.request_id,
        summary="",
        last_summarized_index=-1,
        error="Summarization timed out",
    )
    await summarizer.process_frame(error_result)

    # Give any (unexpected) asynchronous handler a chance to run before checking.
    await asyncio.sleep(0.05)

    # Event should NOT have fired
    self.assertIsNone(applied_event)

    await summarizer.cleanup()
async def test_request_frame_includes_timeout(self):
    """Check that the configured ``summarization_timeout`` reaches the request frame.

    The summarizer is configured with a timeout of 60.0; the frame delivered
    to the ``on_request_summarization`` handler is expected to carry that
    same value.
    """
    summarizer = LLMContextSummarizer(
        context=self.context,
        config=LLMContextSummarizationConfig(
            max_context_tokens=50,
            summarization_timeout=60.0,
        ),
    )
    await summarizer.setup(self.task_manager)

    seen_requests = []

    @summarizer.event_handler("on_request_summarization")
    async def on_request_summarization(summarizer, frame):
        seen_requests.append(frame)

    # Grow the context so the next response start triggers a summarization request.
    for _ in range(10):
        self.context.add_message({"role": "user", "content": "Test message to add tokens."})

    await summarizer.process_frame(LLMFullResponseStartFrame())

    # Use the most recent request, or None when the event never fired.
    request_frame = seen_requests[-1] if seen_requests else None
    self.assertIsNotNone(request_frame)
    self.assertEqual(request_frame.summarization_timeout, 60.0)

    await summarizer.cleanup()
# Run this module's tests with unittest's runner when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()