diff --git a/changelog/4294.fixed.md b/changelog/4294.fixed.md new file mode 100644 index 000000000..73bf66e00 --- /dev/null +++ b/changelog/4294.fixed.md @@ -0,0 +1 @@ +- Fixed `on_assistant_turn_stopped` not resetting internal state when the LLM returned no text tokens. Added `interrupted` field to `AssistantTurnStoppedMessage` to indicate whether the assistant turn was interrupted. diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index 94e56a972..bc4129b43 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -209,12 +209,16 @@ class AssistantTurnStoppedMessage: content. This is the aggregated transcript that is then used in the context. Parameters: - content: The message content/text. + content: The message content/text. May be empty if the LLM + returned zero tokens (e.g. the turn was interrupted before any tokens + were received or pushed). + interrupted: Whether the assistant turn was interrupted. timestamp: When the assistant turn started. 
""" content: str + interrupted: bool timestamp: str @@ -1032,11 +1036,11 @@ class LLMAssistantAggregator(LLMContextAggregator): await self.push_context_frame(FrameDirection.UPSTREAM) async def _handle_interruptions(self, frame: InterruptionFrame): - await self._trigger_assistant_turn_stopped() + await self._trigger_assistant_turn_stopped(interrupted=True) await self.reset() async def _handle_end_or_cancel(self, frame: Frame): - await self._trigger_assistant_turn_stopped() + await self._trigger_assistant_turn_stopped(interrupted=isinstance(frame, CancelFrame)) if self._summarizer: await self._summarizer.cleanup() @@ -1394,17 +1398,23 @@ class LLMAssistantAggregator(LLMContextAggregator): await self._call_event_handler("on_assistant_turn_started") - async def _trigger_assistant_turn_stopped(self): + async def _trigger_assistant_turn_stopped(self, *, interrupted: bool = False): + if not self._assistant_turn_start_timestamp: + return + aggregation = await self.push_aggregation() if aggregation: # Strip turn completion markers from the transcript - content = self._maybe_strip_turn_completion_markers(aggregation) - message = AssistantTurnStoppedMessage( - content=content, timestamp=self._assistant_turn_start_timestamp - ) - await self._call_event_handler("on_assistant_turn_stopped", message) + aggregation = self._maybe_strip_turn_completion_markers(aggregation) - self._assistant_turn_start_timestamp = "" + message = AssistantTurnStoppedMessage( + content=aggregation, + interrupted=interrupted, + timestamp=self._assistant_turn_start_timestamp, + ) + await self._call_event_handler("on_assistant_turn_stopped", message) + + self._assistant_turn_start_timestamp = "" def _maybe_strip_turn_completion_markers(self, text: str) -> str: """Strip turn completion markers from assistant transcript. 
diff --git a/tests/test_context_aggregators_universal.py b/tests/test_context_aggregators_universal.py index 8b9a7b743..4cd195bee 100644 --- a/tests/test_context_aggregators_universal.py +++ b/tests/test_context_aggregators_universal.py @@ -580,8 +580,10 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase): frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()] await run_test(aggregator, frames_to_send=frames_to_send) self.assertTrue(should_start) - self.assertIsNone(should_stop) - self.assertIsNone(stop_message) + self.assertTrue(should_stop) + self.assertIsNotNone(stop_message) + self.assertFalse(stop_message.interrupted) + self.assertEqual(stop_message.content, "") async def test_simple(self): context = LLMContext() @@ -616,6 +618,7 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase): ) self.assertTrue(should_start) self.assertTrue(should_stop) + self.assertFalse(stop_message.interrupted) self.assertEqual(stop_message.content, "Hello from Pipecat!") async def test_multiple(self): @@ -653,6 +656,7 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase): ) self.assertTrue(should_start) self.assertTrue(should_stop) + self.assertFalse(stop_message.interrupted) self.assertEqual(stop_message.content, "Hello from Pipecat!") async def test_multiple_text_with_spaces(self): @@ -858,7 +862,9 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase): ) self.assertEqual(should_start, 2) self.assertEqual(should_stop, 2) + self.assertTrue(stop_messages[0].interrupted) self.assertEqual(stop_messages[0].content, "Hello") + self.assertFalse(stop_messages[1].interrupted) self.assertEqual(stop_messages[1].content, "Hello there!") async def test_function_call(self):