Merge pull request #4294 from pipecat-ai/ac/fix-assistant-turn-stopped-event

Fix on_assistant_turn_stopped not firing for tool-call-only responses
This commit is contained in:
Aleix Conchillo Flaqué
2026-04-14 10:09:55 -07:00
committed by GitHub
3 changed files with 29 additions and 12 deletions

1
changelog/4294.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Fixed `on_assistant_turn_stopped` not resetting internal state when the LLM returned no text tokens. Added `interrupted` field to `AssistantTurnStoppedMessage` to indicate whether the assistant turn was interrupted.

View File

@@ -209,12 +209,16 @@ class AssistantTurnStoppedMessage:
content. This is the aggregated transcript that is then used in the context.
Parameters:
content: The message content/text.
content: The message content/text. May be empty if the LLM
returned zero tokens (e.g. turn was interrupted before any tokens
were received or pushed)
interrupted: Whether the assistant turn was interrupted.
timestamp: When the assistant turn started.
"""
content: str
interrupted: bool
timestamp: str
@@ -1032,11 +1036,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
await self.push_context_frame(FrameDirection.UPSTREAM)
async def _handle_interruptions(self, frame: InterruptionFrame):
await self._trigger_assistant_turn_stopped()
await self._trigger_assistant_turn_stopped(interrupted=True)
await self.reset()
async def _handle_end_or_cancel(self, frame: Frame):
await self._trigger_assistant_turn_stopped()
await self._trigger_assistant_turn_stopped(interrupted=isinstance(frame, CancelFrame))
if self._summarizer:
await self._summarizer.cleanup()
@@ -1394,17 +1398,23 @@ class LLMAssistantAggregator(LLMContextAggregator):
await self._call_event_handler("on_assistant_turn_started")
async def _trigger_assistant_turn_stopped(self):
async def _trigger_assistant_turn_stopped(self, *, interrupted: bool = False):
if not self._assistant_turn_start_timestamp:
return
aggregation = await self.push_aggregation()
if aggregation:
# Strip turn completion markers from the transcript
content = self._maybe_strip_turn_completion_markers(aggregation)
message = AssistantTurnStoppedMessage(
content=content, timestamp=self._assistant_turn_start_timestamp
)
await self._call_event_handler("on_assistant_turn_stopped", message)
aggregation = self._maybe_strip_turn_completion_markers(aggregation)
self._assistant_turn_start_timestamp = ""
message = AssistantTurnStoppedMessage(
content=aggregation,
interrupted=interrupted,
timestamp=self._assistant_turn_start_timestamp,
)
await self._call_event_handler("on_assistant_turn_stopped", message)
self._assistant_turn_start_timestamp = ""
def _maybe_strip_turn_completion_markers(self, text: str) -> str:
"""Strip turn completion markers from assistant transcript.

View File

@@ -580,8 +580,10 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
await run_test(aggregator, frames_to_send=frames_to_send)
self.assertTrue(should_start)
self.assertIsNone(should_stop)
self.assertIsNone(stop_message)
self.assertTrue(should_stop)
self.assertIsNotNone(stop_message)
self.assertFalse(stop_message.interrupted)
self.assertEqual(stop_message.content, "")
async def test_simple(self):
context = LLMContext()
@@ -616,6 +618,7 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
)
self.assertTrue(should_start)
self.assertTrue(should_stop)
self.assertFalse(stop_message.interrupted)
self.assertEqual(stop_message.content, "Hello from Pipecat!")
async def test_multiple(self):
@@ -653,6 +656,7 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
)
self.assertTrue(should_start)
self.assertTrue(should_stop)
self.assertFalse(stop_message.interrupted)
self.assertEqual(stop_message.content, "Hello from Pipecat!")
async def test_multiple_text_with_spaces(self):
@@ -858,7 +862,9 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
)
self.assertEqual(should_start, 2)
self.assertEqual(should_stop, 2)
self.assertTrue(stop_messages[0].interrupted)
self.assertEqual(stop_messages[0].content, "Hello")
self.assertFalse(stop_messages[1].interrupted)
self.assertEqual(stop_messages[1].content, "Hello there!")
async def test_function_call(self):