diff --git a/tests/test_context_summarization.py b/tests/test_context_summarization.py
index 3bb1246e9..446bfb8bd 100644
--- a/tests/test_context_summarization.py
+++ b/tests/test_context_summarization.py
@@ -6,10 +6,11 @@
"""Tests for context summarization feature."""
+import asyncio
import unittest
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock
-from pipecat.frames.frames import LLMContextSummaryRequestFrame
+from pipecat.frames.frames import LLMContextSummaryRequestFrame, LLMContextSummaryResultFrame
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
from pipecat.services.llm_service import LLMService
from pipecat.utils.context.llm_context_summarization import (
@@ -601,6 +602,301 @@ class TestSummaryGenerationExceptions(unittest.IsolatedAsyncioTestCase):
self.assertGreater(last_index, -1)
self.assertEqual(last_index, 1) # Should be the index of the last summarized message
+ async def test_generate_summary_task_timeout(self):
+ """Test that _generate_summary_task handles timeout correctly."""
+ llm_service = LLMService()
+
+ # Mock _generate_summary to hang
+ async def slow_summary(frame):
+ await asyncio.sleep(10)
+ return ("summary", 1)
+
+ llm_service._generate_summary = slow_summary
+
+ broadcast_calls = []
+
+ async def mock_broadcast(frame_class, **kwargs):
+ broadcast_calls.append((frame_class, kwargs))
+
+ llm_service.broadcast_frame = mock_broadcast
+ llm_service.push_error = AsyncMock()
+
+ context = LLMContext()
+ context.add_message({"role": "user", "content": "Message 1"})
+ context.add_message({"role": "assistant", "content": "Response 1"})
+ context.add_message({"role": "user", "content": "Message 2"})
+
+ frame = LLMContextSummaryRequestFrame(
+ request_id="timeout_test",
+ context=context,
+ min_messages_to_keep=1,
+ target_context_tokens=1000,
+ summarization_prompt="Summarize this",
+ summarization_timeout=0.1, # Very short timeout
+ )
+
+ await llm_service._generate_summary_task(frame)
+
+ # Should have broadcast an error result
+ self.assertEqual(len(broadcast_calls), 1)
+ _, kwargs = broadcast_calls[0]
+ self.assertEqual(kwargs["request_id"], "timeout_test")
+ self.assertEqual(kwargs["summary"], "")
+ self.assertEqual(kwargs["last_summarized_index"], -1)
+ # error is None for timeout path (push_error is called instead)
+ self.assertIsNone(kwargs["error"])
+
+ # push_error should have been called with timeout message
+ llm_service.push_error.assert_called_once()
+ call_args = llm_service.push_error.call_args
+ error_msg = call_args.kwargs.get("error_msg") or call_args.args[0]
+ self.assertIn("timed out", error_msg)
+
+
+class TestDedicatedLLMSummarization(unittest.IsolatedAsyncioTestCase):
+ """Tests for dedicated LLM summarization in LLMAssistantAggregator."""
+
+ def _create_context_and_frame(self):
+ """Create a context with enough messages and a matching request frame."""
+ context = LLMContext()
+ context.add_message({"role": "user", "content": "Message 1"})
+ context.add_message({"role": "assistant", "content": "Response 1"})
+ context.add_message({"role": "user", "content": "Message 2"})
+
+ frame = LLMContextSummaryRequestFrame(
+ request_id="dedicated_test",
+ context=context,
+ min_messages_to_keep=1,
+ target_context_tokens=1000,
+ summarization_prompt="Summarize this",
+ summarization_timeout=5.0,
+ )
+ return context, frame
+
+ async def test_dedicated_llm_success(self):
+ """Test that dedicated LLM generates summary and feeds result to summarizer."""
+ from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
+ from pipecat.processors.aggregators.llm_response_universal import (
+ LLMAssistantAggregator,
+ LLMAssistantAggregatorParams,
+ )
+ from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
+
+ context, frame = self._create_context_and_frame()
+
+ # Create a mock dedicated LLM
+ dedicated_llm = LLMService()
+ dedicated_llm._generate_summary = AsyncMock(return_value=("Dedicated summary", 1))
+
+ config = LLMContextSummarizationConfig(
+ max_context_tokens=50,
+ llm=dedicated_llm,
+ )
+ params = LLMAssistantAggregatorParams(
+ enable_context_summarization=True,
+ context_summarization_config=config,
+ )
+ aggregator = LLMAssistantAggregator(context, params=params)
+
+ # Mock summarizer.process_frame to capture the result
+ result_frames = []
+ original_process = aggregator._summarizer.process_frame
+
+ async def capture_process(frame):
+ result_frames.append(frame)
+ await original_process(frame)
+
+ aggregator._summarizer.process_frame = capture_process
+
+ # Call the method directly
+ await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)
+
+ # Verify the dedicated LLM was called
+ dedicated_llm._generate_summary.assert_called_once_with(frame)
+
+ # Verify result was fed to the summarizer
+ self.assertEqual(len(result_frames), 1)
+ result = result_frames[0]
+ self.assertIsInstance(result, LLMContextSummaryResultFrame)
+ self.assertEqual(result.request_id, "dedicated_test")
+ self.assertEqual(result.summary, "Dedicated summary")
+ self.assertEqual(result.last_summarized_index, 1)
+ self.assertIsNone(result.error)
+
+ async def test_dedicated_llm_timeout(self):
+ """Test that dedicated LLM timeout produces error result."""
+ from pipecat.processors.aggregators.llm_response_universal import (
+ LLMAssistantAggregator,
+ LLMAssistantAggregatorParams,
+ )
+
+ context, _ = self._create_context_and_frame()
+
+ # Create a mock dedicated LLM that hangs
+ dedicated_llm = LLMService()
+
+ async def slow_summary(frame):
+ await asyncio.sleep(10)
+ return ("summary", 1)
+
+ dedicated_llm._generate_summary = slow_summary
+
+ config = LLMContextSummarizationConfig(
+ max_context_tokens=50,
+ llm=dedicated_llm,
+ )
+ params = LLMAssistantAggregatorParams(
+ enable_context_summarization=True,
+ context_summarization_config=config,
+ )
+ aggregator = LLMAssistantAggregator(context, params=params)
+
+ # Mock summarizer.process_frame to capture the result
+ result_frames = []
+
+ async def capture_process(frame):
+ result_frames.append(frame)
+
+ aggregator._summarizer.process_frame = capture_process
+
+ # Create frame with very short timeout
+ frame = LLMContextSummaryRequestFrame(
+ request_id="timeout_test",
+ context=context,
+ min_messages_to_keep=1,
+ target_context_tokens=1000,
+ summarization_prompt="Summarize this",
+ summarization_timeout=0.1,
+ )
+
+ await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)
+
+ # Verify error result was fed to summarizer
+ self.assertEqual(len(result_frames), 1)
+ result = result_frames[0]
+ self.assertIsInstance(result, LLMContextSummaryResultFrame)
+ self.assertEqual(result.request_id, "timeout_test")
+ self.assertEqual(result.summary, "")
+ self.assertEqual(result.last_summarized_index, -1)
+ self.assertIn("timed out", result.error)
+
+ async def test_dedicated_llm_exception(self):
+ """Test that dedicated LLM exceptions produce error result."""
+ from pipecat.processors.aggregators.llm_response_universal import (
+ LLMAssistantAggregator,
+ LLMAssistantAggregatorParams,
+ )
+
+ context, frame = self._create_context_and_frame()
+
+ # Create a mock dedicated LLM that raises
+ dedicated_llm = LLMService()
+ dedicated_llm._generate_summary = AsyncMock(
+ side_effect=RuntimeError("LLM connection failed")
+ )
+
+ config = LLMContextSummarizationConfig(
+ max_context_tokens=50,
+ llm=dedicated_llm,
+ )
+ params = LLMAssistantAggregatorParams(
+ enable_context_summarization=True,
+ context_summarization_config=config,
+ )
+ aggregator = LLMAssistantAggregator(context, params=params)
+ aggregator.push_error = AsyncMock()
+
+ # Mock summarizer.process_frame to capture the result
+ result_frames = []
+
+ async def capture_process(frame):
+ result_frames.append(frame)
+
+ aggregator._summarizer.process_frame = capture_process
+
+ await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)
+
+ # Verify error result was fed to summarizer
+ self.assertEqual(len(result_frames), 1)
+ result = result_frames[0]
+ self.assertIsInstance(result, LLMContextSummaryResultFrame)
+ self.assertEqual(result.request_id, "dedicated_test")
+ self.assertEqual(result.summary, "")
+ self.assertEqual(result.last_summarized_index, -1)
+ self.assertIn("LLM connection failed", result.error)
+
+ # push_error should have been called
+ aggregator.push_error.assert_called_once()
+
+ async def test_on_request_summarization_routes_to_dedicated_llm(self):
+ """Test that _on_request_summarization routes to dedicated LLM when configured."""
+ from pipecat.processors.aggregators.llm_response_universal import (
+ LLMAssistantAggregator,
+ LLMAssistantAggregatorParams,
+ )
+
+ context, frame = self._create_context_and_frame()
+
+ dedicated_llm = LLMService()
+ dedicated_llm._generate_summary = AsyncMock(return_value=("Summary", 1))
+
+ config = LLMContextSummarizationConfig(
+ max_context_tokens=50,
+ llm=dedicated_llm,
+ )
+ params = LLMAssistantAggregatorParams(
+ enable_context_summarization=True,
+ context_summarization_config=config,
+ )
+ aggregator = LLMAssistantAggregator(context, params=params)
+ aggregator.push_frame = AsyncMock()
+
+ # Track what coroutine is passed to create_task
+ created_coros = []
+ original_create_task = aggregator.create_task
+
+ def mock_create_task(coro, *args, **kwargs):
+ created_coros.append(coro)
+ # Actually run the coroutine to avoid "never awaited" warning
+ task = asyncio.ensure_future(coro)
+ return task
+
+ aggregator.create_task = mock_create_task
+
+ await aggregator._on_request_summarization(aggregator._summarizer, frame)
+
+ # Should NOT push frame upstream
+ aggregator.push_frame.assert_not_called()
+
+ # Should have created a task for the dedicated LLM
+ self.assertEqual(len(created_coros), 1)
+
+ # Wait for the task to complete
+ await asyncio.sleep(0.05)
+
+ async def test_on_request_summarization_pushes_upstream_without_dedicated_llm(self):
+ """Test that _on_request_summarization pushes upstream when no dedicated LLM."""
+ from pipecat.processors.aggregators.llm_response_universal import (
+ LLMAssistantAggregator,
+ LLMAssistantAggregatorParams,
+ )
+ from pipecat.processors.frame_processor import FrameDirection
+
+ context, frame = self._create_context_and_frame()
+
+ config = LLMContextSummarizationConfig(max_context_tokens=50)
+ params = LLMAssistantAggregatorParams(
+ enable_context_summarization=True,
+ context_summarization_config=config,
+ )
+ aggregator = LLMAssistantAggregator(context, params=params)
+ aggregator.push_frame = AsyncMock()
+
+ await aggregator._on_request_summarization(aggregator._summarizer, frame)
+
+ # Should push frame upstream
+ aggregator.push_frame.assert_called_once_with(frame, FrameDirection.UPSTREAM)
+
class TestLLMSpecificMessageHandling(unittest.TestCase):
"""Tests that LLMSpecificMessage objects are correctly skipped in summarization."""
diff --git a/tests/test_llm_context_summarizer.py b/tests/test_llm_context_summarizer.py
index 7555a8762..0439d403d 100644
--- a/tests/test_llm_context_summarizer.py
+++ b/tests/test_llm_context_summarizer.py
@@ -14,7 +14,10 @@ from pipecat.frames.frames import (
LLMFullResponseStartFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
-from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
+from pipecat.processors.aggregators.llm_context_summarizer import (
+ LLMContextSummarizer,
+ SummaryAppliedEvent,
+)
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
@@ -291,6 +294,252 @@ class TestLLMContextSummarizer(unittest.IsolatedAsyncioTestCase):
await summarizer.cleanup()
+ async def test_summary_message_role_is_user(self):
+ """Test that the summary message uses the user role."""
+ config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
+
+ summarizer = LLMContextSummarizer(context=self.context, config=config)
+ await summarizer.setup(self.task_manager)
+
+ # Add messages and trigger summarization
+ for i in range(10):
+ self.context.add_message({"role": "user", "content": "Test message."})
+
+ request_frame = None
+
+ @summarizer.event_handler("on_request_summarization")
+ async def on_request_summarization(summarizer, frame):
+ nonlocal request_frame
+ request_frame = frame
+
+ await summarizer.process_frame(LLMFullResponseStartFrame())
+ self.assertIsNotNone(request_frame)
+
+ # Simulate receiving a summary result
+ summary_result = LLMContextSummaryResultFrame(
+ request_id=request_frame.request_id,
+ summary="This is a test summary.",
+ last_summarized_index=5,
+ )
+ await summarizer.process_frame(summary_result)
+
+ # Find the summary message and verify its role is "user"
+ summary_msg = next(
+ (msg for msg in self.context.messages if "summary" in msg.get("content", "").lower()),
+ None,
+ )
+ self.assertIsNotNone(summary_msg)
+ self.assertEqual(summary_msg["role"], "user")
+
+ await summarizer.cleanup()
+
+ async def test_summary_message_default_template(self):
+ """Test that the default summary_message_template is used."""
+ config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
+
+ summarizer = LLMContextSummarizer(context=self.context, config=config)
+ await summarizer.setup(self.task_manager)
+
+ for i in range(10):
+ self.context.add_message({"role": "user", "content": "Test message."})
+
+ request_frame = None
+
+ @summarizer.event_handler("on_request_summarization")
+ async def on_request_summarization(summarizer, frame):
+ nonlocal request_frame
+ request_frame = frame
+
+ await summarizer.process_frame(LLMFullResponseStartFrame())
+
+ summary_result = LLMContextSummaryResultFrame(
+ request_id=request_frame.request_id,
+ summary="Key facts from conversation.",
+ last_summarized_index=5,
+ )
+ await summarizer.process_frame(summary_result)
+
+ # Default template wraps with "Conversation summary: {summary}"
+ summary_msg = next(
+ (
+ msg
+ for msg in self.context.messages
+ if "Conversation summary:" in msg.get("content", "")
+ ),
+ None,
+ )
+ self.assertIsNotNone(summary_msg)
+ self.assertEqual(
+ summary_msg["content"], "Conversation summary: Key facts from conversation."
+ )
+
+ await summarizer.cleanup()
+
+ async def test_summary_message_custom_template(self):
+ """Test that a custom summary_message_template is applied."""
+ config = LLMContextSummarizationConfig(
+ max_context_tokens=50,
+ min_messages_after_summary=2,
+ summary_message_template="\n{summary}\n",
+ )
+
+ summarizer = LLMContextSummarizer(context=self.context, config=config)
+ await summarizer.setup(self.task_manager)
+
+ for i in range(10):
+ self.context.add_message({"role": "user", "content": "Test message."})
+
+ request_frame = None
+
+ @summarizer.event_handler("on_request_summarization")
+ async def on_request_summarization(summarizer, frame):
+ nonlocal request_frame
+ request_frame = frame
+
+ await summarizer.process_frame(LLMFullResponseStartFrame())
+
+ summary_result = LLMContextSummaryResultFrame(
+ request_id=request_frame.request_id,
+ summary="Key facts from conversation.",
+ last_summarized_index=5,
+ )
+ await summarizer.process_frame(summary_result)
+
+ # Custom template wraps with XML tags
+ summary_msg = next(
+ (msg for msg in self.context.messages if "" in msg.get("content", "")),
+ None,
+ )
+ self.assertIsNotNone(summary_msg)
+ self.assertEqual(
+ summary_msg["content"],
+ "\nKey facts from conversation.\n",
+ )
+
+ await summarizer.cleanup()
+
+ async def test_on_summary_applied_event(self):
+ """Test that on_summary_applied event fires with correct data."""
+ config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
+
+ summarizer = LLMContextSummarizer(context=self.context, config=config)
+ await summarizer.setup(self.task_manager)
+
+ # Add messages (1 system + 10 user = 11 total)
+ for i in range(10):
+ self.context.add_message({"role": "user", "content": "Test message."})
+
+ request_frame = None
+ applied_event = None
+
+ @summarizer.event_handler("on_request_summarization")
+ async def on_request_summarization(summarizer, frame):
+ nonlocal request_frame
+ request_frame = frame
+
+ @summarizer.event_handler("on_summary_applied")
+ async def on_summary_applied(summarizer, event):
+ nonlocal applied_event
+ applied_event = event
+
+ original_count = len(self.context.messages) # 11
+ await summarizer.process_frame(LLMFullResponseStartFrame())
+
+ # Summarize up to index 7 (system=0, user1..user7), keep last 3 (user8, user9, user10)
+ summary_result = LLMContextSummaryResultFrame(
+ request_id=request_frame.request_id,
+ summary="Test summary.",
+ last_summarized_index=7,
+ )
+ await summarizer.process_frame(summary_result)
+
+ # Allow async event handler to complete
+ await asyncio.sleep(0.05)
+
+ # Verify event was fired
+ self.assertIsNotNone(applied_event)
+ self.assertIsInstance(applied_event, SummaryAppliedEvent)
+ self.assertEqual(applied_event.original_message_count, original_count)
+
+ # After summarization: system + summary + 3 recent = 5
+ self.assertEqual(applied_event.new_message_count, 5)
+
+ # Summarized messages: indices 1-7 = 7 messages (excluding system at index 0)
+ self.assertEqual(applied_event.summarized_message_count, 7)
+
+ # Preserved: system (1) + recent messages after index 7 (3) = 4
+ self.assertEqual(applied_event.preserved_message_count, 4)
+
+ await summarizer.cleanup()
+
+ async def test_on_summary_applied_not_fired_on_error(self):
+ """Test that on_summary_applied event is NOT fired when summarization fails."""
+ config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
+
+ summarizer = LLMContextSummarizer(context=self.context, config=config)
+ await summarizer.setup(self.task_manager)
+
+ for i in range(10):
+ self.context.add_message({"role": "user", "content": "Test message."})
+
+ request_frame = None
+ applied_event = None
+
+ @summarizer.event_handler("on_request_summarization")
+ async def on_request_summarization(summarizer, frame):
+ nonlocal request_frame
+ request_frame = frame
+
+ @summarizer.event_handler("on_summary_applied")
+ async def on_summary_applied(summarizer, event):
+ nonlocal applied_event
+ applied_event = event
+
+ await summarizer.process_frame(LLMFullResponseStartFrame())
+
+ # Send a result with an error
+ error_result = LLMContextSummaryResultFrame(
+ request_id=request_frame.request_id,
+ summary="",
+ last_summarized_index=-1,
+ error="Summarization timed out",
+ )
+ await summarizer.process_frame(error_result)
+
+ await asyncio.sleep(0.05)
+
+ # Event should NOT have fired
+ self.assertIsNone(applied_event)
+
+ await summarizer.cleanup()
+
+ async def test_request_frame_includes_timeout(self):
+ """Test that the request frame includes the configured summarization_timeout."""
+ config = LLMContextSummarizationConfig(
+ max_context_tokens=50,
+ summarization_timeout=60.0,
+ )
+
+ summarizer = LLMContextSummarizer(context=self.context, config=config)
+ await summarizer.setup(self.task_manager)
+
+ request_frame = None
+
+ @summarizer.event_handler("on_request_summarization")
+ async def on_request_summarization(summarizer, frame):
+ nonlocal request_frame
+ request_frame = frame
+
+ for i in range(10):
+ self.context.add_message({"role": "user", "content": "Test message to add tokens."})
+
+ await summarizer.process_frame(LLMFullResponseStartFrame())
+
+ self.assertIsNotNone(request_frame)
+ self.assertEqual(request_frame.summarization_timeout, 60.0)
+
+ await summarizer.cleanup()
+
if __name__ == "__main__":
unittest.main()