diff --git a/tests/test_context_summarization.py b/tests/test_context_summarization.py
index 3bb1246e9..446bfb8bd 100644
--- a/tests/test_context_summarization.py
+++ b/tests/test_context_summarization.py
@@ -6,10 +6,11 @@
 
 """Tests for context summarization feature."""
 
+import asyncio
 import unittest
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock
 
-from pipecat.frames.frames import LLMContextSummaryRequestFrame
+from pipecat.frames.frames import LLMContextSummaryRequestFrame, LLMContextSummaryResultFrame
 from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
 from pipecat.services.llm_service import LLMService
 from pipecat.utils.context.llm_context_summarization import (
@@ -601,6 +602,298 @@ class TestSummaryGenerationExceptions(unittest.IsolatedAsyncioTestCase):
         self.assertGreater(last_index, -1)
         self.assertEqual(last_index, 1)  # Should be the index of the last summarized message
 
+    async def test_generate_summary_task_timeout(self):
+        """Test that _generate_summary_task handles timeout correctly."""
+        llm_service = LLMService()
+
+        # Mock _generate_summary to hang
+        async def slow_summary(frame):
+            await asyncio.sleep(10)
+            return ("summary", 1)
+
+        llm_service._generate_summary = slow_summary
+
+        broadcast_calls = []
+
+        async def mock_broadcast(frame_class, **kwargs):
+            broadcast_calls.append((frame_class, kwargs))
+
+        llm_service.broadcast_frame = mock_broadcast
+        llm_service.push_error = AsyncMock()
+
+        context = LLMContext()
+        context.add_message({"role": "user", "content": "Message 1"})
+        context.add_message({"role": "assistant", "content": "Response 1"})
+        context.add_message({"role": "user", "content": "Message 2"})
+
+        frame = LLMContextSummaryRequestFrame(
+            request_id="timeout_test",
+            context=context,
+            min_messages_to_keep=1,
+            target_context_tokens=1000,
+            summarization_prompt="Summarize this",
+            summarization_timeout=0.1,  # Very short timeout
+        )
+
+        await llm_service._generate_summary_task(frame)
+
+        # Should have broadcast an error result
+        self.assertEqual(len(broadcast_calls), 1)
+        _, kwargs = broadcast_calls[0]
+        self.assertEqual(kwargs["request_id"], "timeout_test")
+        self.assertEqual(kwargs["summary"], "")
+        self.assertEqual(kwargs["last_summarized_index"], -1)
+        # error is None for timeout path (push_error is called instead)
+        self.assertIsNone(kwargs["error"])
+
+        # push_error should have been called with timeout message
+        llm_service.push_error.assert_called_once()
+        call_args = llm_service.push_error.call_args
+        error_msg = call_args.kwargs.get("error_msg") or call_args.args[0]
+        self.assertIn("timed out", error_msg)
+
+
+class TestDedicatedLLMSummarization(unittest.IsolatedAsyncioTestCase):
+    """Tests for dedicated LLM summarization in LLMAssistantAggregator."""
+
+    def _create_context_and_frame(self):
+        """Create a context with enough messages and a matching request frame."""
+        context = LLMContext()
+        context.add_message({"role": "user", "content": "Message 1"})
+        context.add_message({"role": "assistant", "content": "Response 1"})
+        context.add_message({"role": "user", "content": "Message 2"})
+
+        frame = LLMContextSummaryRequestFrame(
+            request_id="dedicated_test",
+            context=context,
+            min_messages_to_keep=1,
+            target_context_tokens=1000,
+            summarization_prompt="Summarize this",
+            summarization_timeout=5.0,
+        )
+        return context, frame
+
+    async def test_dedicated_llm_success(self):
+        """Test that dedicated LLM generates summary and feeds result to summarizer."""
+        from pipecat.processors.aggregators.llm_response_universal import (
+            LLMAssistantAggregator,
+            LLMAssistantAggregatorParams,
+        )
+
+        context, frame = self._create_context_and_frame()
+
+        # Create a mock dedicated LLM
+        dedicated_llm = LLMService()
+        dedicated_llm._generate_summary = AsyncMock(return_value=("Dedicated summary", 1))
+
+        config = LLMContextSummarizationConfig(
+            max_context_tokens=50,
+            llm=dedicated_llm,
+        )
+        params = LLMAssistantAggregatorParams(
+            enable_context_summarization=True,
+            context_summarization_config=config,
+        )
+        aggregator = LLMAssistantAggregator(context, params=params)
+
+        # Mock summarizer.process_frame to capture the result
+        result_frames = []
+        original_process = aggregator._summarizer.process_frame
+
+        async def capture_process(frame):
+            result_frames.append(frame)
+            await original_process(frame)
+
+        aggregator._summarizer.process_frame = capture_process
+
+        # Call the method directly
+        await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)
+
+        # Verify the dedicated LLM was called
+        dedicated_llm._generate_summary.assert_called_once_with(frame)
+
+        # Verify result was fed to the summarizer
+        self.assertEqual(len(result_frames), 1)
+        result = result_frames[0]
+        self.assertIsInstance(result, LLMContextSummaryResultFrame)
+        self.assertEqual(result.request_id, "dedicated_test")
+        self.assertEqual(result.summary, "Dedicated summary")
+        self.assertEqual(result.last_summarized_index, 1)
+        self.assertIsNone(result.error)
+
+    async def test_dedicated_llm_timeout(self):
+        """Test that dedicated LLM timeout produces error result."""
+        from pipecat.processors.aggregators.llm_response_universal import (
+            LLMAssistantAggregator,
+            LLMAssistantAggregatorParams,
+        )
+
+        context, _ = self._create_context_and_frame()
+
+        # Create a mock dedicated LLM that hangs
+        dedicated_llm = LLMService()
+
+        async def slow_summary(frame):
+            await asyncio.sleep(10)
+            return ("summary", 1)
+
+        dedicated_llm._generate_summary = slow_summary
+
+        config = LLMContextSummarizationConfig(
+            max_context_tokens=50,
+            llm=dedicated_llm,
+        )
+        params = LLMAssistantAggregatorParams(
+            enable_context_summarization=True,
+            context_summarization_config=config,
+        )
+        aggregator = LLMAssistantAggregator(context, params=params)
+
+        # Mock summarizer.process_frame to capture the result
+        result_frames = []
+
+        async def capture_process(frame):
+            result_frames.append(frame)
+
+        aggregator._summarizer.process_frame = capture_process
+
+        # Create frame with very short timeout
+        frame = LLMContextSummaryRequestFrame(
+            request_id="timeout_test",
+            context=context,
+            min_messages_to_keep=1,
+            target_context_tokens=1000,
+            summarization_prompt="Summarize this",
+            summarization_timeout=0.1,
+        )
+
+        await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)
+
+        # Verify error result was fed to summarizer
+        self.assertEqual(len(result_frames), 1)
+        result = result_frames[0]
+        self.assertIsInstance(result, LLMContextSummaryResultFrame)
+        self.assertEqual(result.request_id, "timeout_test")
+        self.assertEqual(result.summary, "")
+        self.assertEqual(result.last_summarized_index, -1)
+        self.assertIn("timed out", result.error)
+
+    async def test_dedicated_llm_exception(self):
+        """Test that dedicated LLM exceptions produce error result."""
+        from pipecat.processors.aggregators.llm_response_universal import (
+            LLMAssistantAggregator,
+            LLMAssistantAggregatorParams,
+        )
+
+        context, frame = self._create_context_and_frame()
+
+        # Create a mock dedicated LLM that raises
+        dedicated_llm = LLMService()
+        dedicated_llm._generate_summary = AsyncMock(
+            side_effect=RuntimeError("LLM connection failed")
+        )
+
+        config = LLMContextSummarizationConfig(
+            max_context_tokens=50,
+            llm=dedicated_llm,
+        )
+        params = LLMAssistantAggregatorParams(
+            enable_context_summarization=True,
+            context_summarization_config=config,
+        )
+        aggregator = LLMAssistantAggregator(context, params=params)
+        aggregator.push_error = AsyncMock()
+
+        # Mock summarizer.process_frame to capture the result
+        result_frames = []
+
+        async def capture_process(frame):
+            result_frames.append(frame)
+
+        aggregator._summarizer.process_frame = capture_process
+
+        await aggregator._generate_summary_with_dedicated_llm(dedicated_llm, frame)
+
+        # Verify error result was fed to summarizer
+        self.assertEqual(len(result_frames), 1)
+        result = result_frames[0]
+        self.assertIsInstance(result, LLMContextSummaryResultFrame)
+        self.assertEqual(result.request_id, "dedicated_test")
+        self.assertEqual(result.summary, "")
+        self.assertEqual(result.last_summarized_index, -1)
+        self.assertIn("LLM connection failed", result.error)
+
+        # push_error should have been called
+        aggregator.push_error.assert_called_once()
+
+    async def test_on_request_summarization_routes_to_dedicated_llm(self):
+        """Test that _on_request_summarization routes to dedicated LLM when configured."""
+        from pipecat.processors.aggregators.llm_response_universal import (
+            LLMAssistantAggregator,
+            LLMAssistantAggregatorParams,
+        )
+
+        context, frame = self._create_context_and_frame()
+
+        dedicated_llm = LLMService()
+        dedicated_llm._generate_summary = AsyncMock(return_value=("Summary", 1))
+
+        config = LLMContextSummarizationConfig(
+            max_context_tokens=50,
+            llm=dedicated_llm,
+        )
+        params = LLMAssistantAggregatorParams(
+            enable_context_summarization=True,
+            context_summarization_config=config,
+        )
+        aggregator = LLMAssistantAggregator(context, params=params)
+        aggregator.push_frame = AsyncMock()
+
+        # Track the tasks spawned through create_task
+        created_tasks = []
+
+        def mock_create_task(coro, *args, **kwargs):
+            # Schedule the coroutine for real so it is awaited and can be joined below
+            task = asyncio.ensure_future(coro)
+            created_tasks.append(task)
+            return task
+
+        aggregator.create_task = mock_create_task
+
+        await aggregator._on_request_summarization(aggregator._summarizer, frame)
+
+        # Should NOT push frame upstream
+        aggregator.push_frame.assert_not_called()
+
+        # Should have created a task for the dedicated LLM
+        self.assertEqual(len(created_tasks), 1)
+
+        # Join the task deterministically instead of sleeping for a fixed interval
+        await asyncio.gather(*created_tasks)
+
+    async def test_on_request_summarization_pushes_upstream_without_dedicated_llm(self):
+        """Test that _on_request_summarization pushes upstream when no dedicated LLM."""
+        from pipecat.processors.aggregators.llm_response_universal import (
+            LLMAssistantAggregator,
+            LLMAssistantAggregatorParams,
+        )
+        from pipecat.processors.frame_processor import FrameDirection
+
+        context, frame = self._create_context_and_frame()
+
+        config = LLMContextSummarizationConfig(max_context_tokens=50)
+        params = LLMAssistantAggregatorParams(
+            enable_context_summarization=True,
+            context_summarization_config=config,
+        )
+        aggregator = LLMAssistantAggregator(context, params=params)
+        aggregator.push_frame = AsyncMock()
+
+        await aggregator._on_request_summarization(aggregator._summarizer, frame)
+
+        # Should push frame upstream
+        aggregator.push_frame.assert_called_once_with(frame, FrameDirection.UPSTREAM)
+
 
 class TestLLMSpecificMessageHandling(unittest.TestCase):
     """Tests that LLMSpecificMessage objects are correctly skipped in summarization."""
diff --git a/tests/test_llm_context_summarizer.py b/tests/test_llm_context_summarizer.py
index 7555a8762..0439d403d 100644
--- a/tests/test_llm_context_summarizer.py
+++ b/tests/test_llm_context_summarizer.py
@@ -14,7 +14,10 @@ from pipecat.frames.frames import (
     LLMFullResponseStartFrame,
 )
 from pipecat.processors.aggregators.llm_context import LLMContext
-from pipecat.processors.aggregators.llm_context_summarizer import LLMContextSummarizer
+from pipecat.processors.aggregators.llm_context_summarizer import (
+    LLMContextSummarizer,
+    SummaryAppliedEvent,
+)
 from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
 from pipecat.utils.context.llm_context_summarization import LLMContextSummarizationConfig
 
@@ -291,6 +294,256 @@ class TestLLMContextSummarizer(unittest.IsolatedAsyncioTestCase):
 
         await summarizer.cleanup()
 
+    async def test_summary_message_role_is_user(self):
+        """Test that the summary message uses the user role."""
+        config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
+
+        summarizer = LLMContextSummarizer(context=self.context, config=config)
+        await summarizer.setup(self.task_manager)
+
+        # Add messages and trigger summarization
+        for _ in range(10):
+            self.context.add_message({"role": "user", "content": "Test message."})
+
+        request_frame = None
+
+        @summarizer.event_handler("on_request_summarization")
+        async def on_request_summarization(summarizer, frame):
+            nonlocal request_frame
+            request_frame = frame
+
+        await summarizer.process_frame(LLMFullResponseStartFrame())
+        self.assertIsNotNone(request_frame)
+
+        # Simulate receiving a summary result
+        summary_result = LLMContextSummaryResultFrame(
+            request_id=request_frame.request_id,
+            summary="This is a test summary.",
+            last_summarized_index=5,
+        )
+        await summarizer.process_frame(summary_result)
+
+        # Find the summary message and verify its role is "user"
+        summary_msg = next(
+            (msg for msg in self.context.messages if "summary" in msg.get("content", "").lower()),
+            None,
+        )
+        self.assertIsNotNone(summary_msg)
+        self.assertEqual(summary_msg["role"], "user")
+
+        await summarizer.cleanup()
+
+    async def test_summary_message_default_template(self):
+        """Test that the default summary_message_template is used."""
+        config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
+
+        summarizer = LLMContextSummarizer(context=self.context, config=config)
+        await summarizer.setup(self.task_manager)
+
+        for _ in range(10):
+            self.context.add_message({"role": "user", "content": "Test message."})
+
+        request_frame = None
+
+        @summarizer.event_handler("on_request_summarization")
+        async def on_request_summarization(summarizer, frame):
+            nonlocal request_frame
+            request_frame = frame
+
+        await summarizer.process_frame(LLMFullResponseStartFrame())
+        self.assertIsNotNone(request_frame)
+
+        summary_result = LLMContextSummaryResultFrame(
+            request_id=request_frame.request_id,
+            summary="Key facts from conversation.",
+            last_summarized_index=5,
+        )
+        await summarizer.process_frame(summary_result)
+
+        # Default template wraps with "Conversation summary: {summary}"
+        summary_msg = next(
+            (
+                msg
+                for msg in self.context.messages
+                if "Conversation summary:" in msg.get("content", "")
+            ),
+            None,
+        )
+        self.assertIsNotNone(summary_msg)
+        self.assertEqual(
+            summary_msg["content"], "Conversation summary: Key facts from conversation."
+        )
+
+        await summarizer.cleanup()
+
+    async def test_summary_message_custom_template(self):
+        """Test that a custom summary_message_template is applied."""
+        config = LLMContextSummarizationConfig(
+            max_context_tokens=50,
+            min_messages_after_summary=2,
+            summary_message_template="<summary>\n{summary}\n</summary>",
+        )
+
+        summarizer = LLMContextSummarizer(context=self.context, config=config)
+        await summarizer.setup(self.task_manager)
+
+        for _ in range(10):
+            self.context.add_message({"role": "user", "content": "Test message."})
+
+        request_frame = None
+
+        @summarizer.event_handler("on_request_summarization")
+        async def on_request_summarization(summarizer, frame):
+            nonlocal request_frame
+            request_frame = frame
+
+        await summarizer.process_frame(LLMFullResponseStartFrame())
+        self.assertIsNotNone(request_frame)
+
+        summary_result = LLMContextSummaryResultFrame(
+            request_id=request_frame.request_id,
+            summary="Key facts from conversation.",
+            last_summarized_index=5,
+        )
+        await summarizer.process_frame(summary_result)
+
+        # Custom template wraps the summary in XML tags; search for the opening tag
+        summary_msg = next(
+            (msg for msg in self.context.messages if "<summary>" in msg.get("content", "")),
+            None,
+        )
+        self.assertIsNotNone(summary_msg)
+        self.assertEqual(
+            summary_msg["content"],
+            "<summary>\nKey facts from conversation.\n</summary>",
+        )
+
+        await summarizer.cleanup()
+
+    async def test_on_summary_applied_event(self):
+        """Test that on_summary_applied event fires with correct data."""
+        config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
+
+        summarizer = LLMContextSummarizer(context=self.context, config=config)
+        await summarizer.setup(self.task_manager)
+
+        # Add messages (1 system + 10 user = 11 total)
+        for _ in range(10):
+            self.context.add_message({"role": "user", "content": "Test message."})
+
+        request_frame = None
+        applied_event = None
+
+        @summarizer.event_handler("on_request_summarization")
+        async def on_request_summarization(summarizer, frame):
+            nonlocal request_frame
+            request_frame = frame
+
+        @summarizer.event_handler("on_summary_applied")
+        async def on_summary_applied(summarizer, event):
+            nonlocal applied_event
+            applied_event = event
+
+        original_count = len(self.context.messages)  # 11
+        await summarizer.process_frame(LLMFullResponseStartFrame())
+        self.assertIsNotNone(request_frame)
+
+        # Summarize up to index 7 (system=0, user1..user7), keep last 3 (user8, user9, user10)
+        summary_result = LLMContextSummaryResultFrame(
+            request_id=request_frame.request_id,
+            summary="Test summary.",
+            last_summarized_index=7,
+        )
+        await summarizer.process_frame(summary_result)
+
+        # Allow async event handler to complete
+        await asyncio.sleep(0.05)
+
+        # Verify event was fired
+        self.assertIsNotNone(applied_event)
+        self.assertIsInstance(applied_event, SummaryAppliedEvent)
+        self.assertEqual(applied_event.original_message_count, original_count)
+
+        # After summarization: system + summary + 3 recent = 5
+        self.assertEqual(applied_event.new_message_count, 5)
+
+        # Summarized messages: indices 1-7 = 7 messages (excluding system at index 0)
+        self.assertEqual(applied_event.summarized_message_count, 7)
+
+        # Preserved: system (1) + recent messages after index 7 (3) = 4
+        self.assertEqual(applied_event.preserved_message_count, 4)
+
+        await summarizer.cleanup()
+
+    async def test_on_summary_applied_not_fired_on_error(self):
+        """Test that on_summary_applied event is NOT fired when summarization fails."""
+        config = LLMContextSummarizationConfig(max_context_tokens=50, min_messages_after_summary=2)
+
+        summarizer = LLMContextSummarizer(context=self.context, config=config)
+        await summarizer.setup(self.task_manager)
+
+        for _ in range(10):
+            self.context.add_message({"role": "user", "content": "Test message."})
+
+        request_frame = None
+        applied_event = None
+
+        @summarizer.event_handler("on_request_summarization")
+        async def on_request_summarization(summarizer, frame):
+            nonlocal request_frame
+            request_frame = frame
+
+        @summarizer.event_handler("on_summary_applied")
+        async def on_summary_applied(summarizer, event):
+            nonlocal applied_event
+            applied_event = event
+
+        await summarizer.process_frame(LLMFullResponseStartFrame())
+        self.assertIsNotNone(request_frame)
+
+        # Send a result with an error
+        error_result = LLMContextSummaryResultFrame(
+            request_id=request_frame.request_id,
+            summary="",
+            last_summarized_index=-1,
+            error="Summarization timed out",
+        )
+        await summarizer.process_frame(error_result)
+
+        await asyncio.sleep(0.05)
+
+        # Event should NOT have fired
+        self.assertIsNone(applied_event)
+
+        await summarizer.cleanup()
+
+    async def test_request_frame_includes_timeout(self):
+        """Test that the request frame includes the configured summarization_timeout."""
+        config = LLMContextSummarizationConfig(
+            max_context_tokens=50,
+            summarization_timeout=60.0,
+        )
+
+        summarizer = LLMContextSummarizer(context=self.context, config=config)
+        await summarizer.setup(self.task_manager)
+
+        request_frame = None
+
+        @summarizer.event_handler("on_request_summarization")
+        async def on_request_summarization(summarizer, frame):
+            nonlocal request_frame
+            request_frame = frame
+
+        for _ in range(10):
+            self.context.add_message({"role": "user", "content": "Test message to add tokens."})
+
+        await summarizer.process_frame(LLMFullResponseStartFrame())
+
+        self.assertIsNotNone(request_frame)
+        self.assertEqual(request_frame.summarization_timeout, 60.0)
+
+        await summarizer.cleanup()
+
 
 if __name__ == "__main__":
     unittest.main()