Merge pull request #4294 from pipecat-ai/ac/fix-assistant-turn-stopped-event
Fix on_assistant_turn_stopped not firing for tool-call-only responses
This commit is contained in:
1
changelog/4294.fixed.md
Normal file
1
changelog/4294.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `on_assistant_turn_stopped` not resetting internal state when the LLM returned no text tokens. Added `interrupted` field to `AssistantTurnStoppedMessage` to indicate whether the assistant turn was interrupted.
|
||||
@@ -209,12 +209,16 @@ class AssistantTurnStoppedMessage:
|
||||
content. This is the aggregated transcript that is then used in the context.
|
||||
|
||||
Parameters:
|
||||
content: The message content/text.
|
||||
content: The message content/text. May be empty if the LLM
|
||||
returned zero tokens (e.g. turn was interrupted before any tokens
|
||||
were received or pushed)
|
||||
interrupted: Whether the assistant turn was interrupted.
|
||||
timestamp: When the assistant turn started.
|
||||
|
||||
"""
|
||||
|
||||
content: str
|
||||
interrupted: bool
|
||||
timestamp: str
|
||||
|
||||
|
||||
@@ -1032,11 +1036,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
|
||||
await self._trigger_assistant_turn_stopped()
|
||||
await self._trigger_assistant_turn_stopped(interrupted=True)
|
||||
await self.reset()
|
||||
|
||||
async def _handle_end_or_cancel(self, frame: Frame):
|
||||
await self._trigger_assistant_turn_stopped()
|
||||
await self._trigger_assistant_turn_stopped(interrupted=isinstance(frame, CancelFrame))
|
||||
if self._summarizer:
|
||||
await self._summarizer.cleanup()
|
||||
|
||||
@@ -1394,17 +1398,23 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
await self._call_event_handler("on_assistant_turn_started")
|
||||
|
||||
async def _trigger_assistant_turn_stopped(self):
|
||||
async def _trigger_assistant_turn_stopped(self, *, interrupted: bool = False):
|
||||
if not self._assistant_turn_start_timestamp:
|
||||
return
|
||||
|
||||
aggregation = await self.push_aggregation()
|
||||
if aggregation:
|
||||
# Strip turn completion markers from the transcript
|
||||
content = self._maybe_strip_turn_completion_markers(aggregation)
|
||||
message = AssistantTurnStoppedMessage(
|
||||
content=content, timestamp=self._assistant_turn_start_timestamp
|
||||
)
|
||||
await self._call_event_handler("on_assistant_turn_stopped", message)
|
||||
aggregation = self._maybe_strip_turn_completion_markers(aggregation)
|
||||
|
||||
self._assistant_turn_start_timestamp = ""
|
||||
message = AssistantTurnStoppedMessage(
|
||||
content=aggregation,
|
||||
interrupted=interrupted,
|
||||
timestamp=self._assistant_turn_start_timestamp,
|
||||
)
|
||||
await self._call_event_handler("on_assistant_turn_stopped", message)
|
||||
|
||||
self._assistant_turn_start_timestamp = ""
|
||||
|
||||
def _maybe_strip_turn_completion_markers(self, text: str) -> str:
|
||||
"""Strip turn completion markers from assistant transcript.
|
||||
|
||||
@@ -580,8 +580,10 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()]
|
||||
await run_test(aggregator, frames_to_send=frames_to_send)
|
||||
self.assertTrue(should_start)
|
||||
self.assertIsNone(should_stop)
|
||||
self.assertIsNone(stop_message)
|
||||
self.assertTrue(should_stop)
|
||||
self.assertIsNotNone(stop_message)
|
||||
self.assertFalse(stop_message.interrupted)
|
||||
self.assertEqual(stop_message.content, "")
|
||||
|
||||
async def test_simple(self):
|
||||
context = LLMContext()
|
||||
@@ -616,6 +618,7 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
self.assertFalse(stop_message.interrupted)
|
||||
self.assertEqual(stop_message.content, "Hello from Pipecat!")
|
||||
|
||||
async def test_multiple(self):
|
||||
@@ -653,6 +656,7 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
self.assertFalse(stop_message.interrupted)
|
||||
self.assertEqual(stop_message.content, "Hello from Pipecat!")
|
||||
|
||||
async def test_multiple_text_with_spaces(self):
|
||||
@@ -858,7 +862,9 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
self.assertEqual(should_start, 2)
|
||||
self.assertEqual(should_stop, 2)
|
||||
self.assertTrue(stop_messages[0].interrupted)
|
||||
self.assertEqual(stop_messages[0].content, "Hello")
|
||||
self.assertFalse(stop_messages[1].interrupted)
|
||||
self.assertEqual(stop_messages[1].content, "Hello there!")
|
||||
|
||||
async def test_function_call(self):
|
||||
|
||||
Reference in New Issue
Block a user