diff --git a/CHANGELOG.md b/CHANGELOG.md index 466b3f28b..a45172e2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added reconnection logic and audio buffer management to `GladiaSTTService`. +- The `TurnTrackingObserver` now ends a turn upon observing an `EndFrame` or + `CancelFrame`. + - Added Polish support to `AWSTranscribeSTTService`. - Added new frames `FrameProcessorPauseFrame` and `FrameProcessorResumeFrame` diff --git a/src/pipecat/observers/turn_tracking_observer.py b/src/pipecat/observers/turn_tracking_observer.py index 956e46b55..04b5ad92b 100644 --- a/src/pipecat/observers/turn_tracking_observer.py +++ b/src/pipecat/observers/turn_tracking_observer.py @@ -12,6 +12,8 @@ from loguru import logger from pipecat.frames.frames import ( BotStartedSpeakingFrame, BotStoppedSpeakingFrame, + CancelFrame, + EndFrame, StartFrame, UserStartedSpeakingFrame, ) @@ -73,6 +75,8 @@ class TurnTrackingObserver(BaseObserver): # We only want to end the turn if the bot was previously speaking elif isinstance(data.frame, BotStoppedSpeakingFrame) and self._is_bot_speaking: await self._handle_bot_stopped_speaking(data) + elif isinstance(data.frame, (EndFrame, CancelFrame)): + await self._handle_pipeline_end(data) def _schedule_turn_end(self, data: FramePushed): """Schedule turn end with a timeout.""" @@ -134,6 +138,14 @@ class TurnTrackingObserver(BaseObserver): # This can happen with HTTP TTS services or function calls self._schedule_turn_end(data) + async def _handle_pipeline_end(self, data: FramePushed): + """Handle pipeline end or cancellation by flushing any active turn.""" + if self._is_turn_active: + # Cancel any pending turn end timer + self._cancel_turn_end_timer() + # End the current turn + await self._end_turn(data, was_interrupted=True) + async def _start_turn(self, data: FramePushed): """Start a new turn.""" self._is_turn_active = True diff --git a/tests/test_turn_tracking_observer.py b/tests/test_turn_tracking_observer.py index 14cfd472f..dd1f39e71 100644 --- a/tests/test_turn_tracking_observer.py +++ b/tests/test_turn_tracking_observer.py @@ -9,6 +9,7 @@ import unittest from pipecat.frames.frames import ( BotStartedSpeakingFrame, BotStoppedSpeakingFrame, + CancelFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) @@ -150,7 +151,10 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase): self.assertEqual(turn_observer._turn_count, 2) async def test_user_interrupts_bot(self): - """Test when user interrupts bot speaking, should end current turn and start new one.""" + """Test when user interrupts bot speaking, should end current turn and start new one. + + Note: This test also verifies that the EndFrame ends the turn correctly. + """ # Create observer with a short timeout turn_observer = TurnTrackingObserver(turn_end_timeout_secs=0.2) @@ -197,6 +201,7 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase): "Turn 1 started", "Turn 1 ended (interrupted: True)", # First turn was interrupted "Turn 2 started", # New turn started after interruption + "Turn 2 ended (interrupted: True)", # Second turn ends due to EndFrame ] self.assertEqual(turn_events, expected_events) self.assertEqual(turn_observer._turn_count, 2) @@ -256,6 +261,109 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase): self.assertEqual(turn_events, expected_events) self.assertEqual(turn_observer._turn_count, 1) + async def test_cancel_frame_flushes_active_turn(self): + """Test that CancelFrame properly flushes an active turn.""" + # Create observer with a long timeout to ensure CancelFrame is what ends the turn + turn_observer = TurnTrackingObserver(turn_end_timeout_secs=5.0) + + # Create identity filter (passes all frames through) + processor = IdentityFilter() + + # Record start/end events with turn numbers + turn_events = [] + + @turn_observer.event_handler("on_turn_started") + async def on_turn_started(observer, turn_number): + turn_events.append(f"Turn {turn_number} started") + + @turn_observer.event_handler("on_turn_ended") + async def on_turn_ended(observer, turn_number, duration, was_interrupted): + turn_events.append(f"Turn {turn_number} ended (interrupted: {was_interrupted})") + + frames_to_send = [ + # Start a turn but don't complete it naturally + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + BotStartedSpeakingFrame(), + # Send CancelFrame while bot is still speaking + CancelFrame(), + ] + + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + BotStartedSpeakingFrame, + CancelFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + observers=[turn_observer], + send_end_frame=False, # Don't send EndFrame since we're testing CancelFrame + ) + + # Verify that the turn was ended due to CancelFrame (marked as interrupted) + expected_events = [ + "Turn 1 started", + "Turn 1 ended (interrupted: True)", # Should be interrupted due to CancelFrame + ] + self.assertEqual(turn_events, expected_events) + self.assertEqual(turn_observer._turn_count, 1) + + async def test_end_frame_with_no_active_turn(self): + """Test that EndFrame doesn't cause issues when no turn is active.""" + # Create observer + turn_observer = TurnTrackingObserver(turn_end_timeout_secs=0.2) + + # Create identity filter (passes all frames through) + processor = IdentityFilter() + + # Record start/end events with turn numbers + turn_events = [] + + @turn_observer.event_handler("on_turn_started") + async def on_turn_started(observer, turn_number): + turn_events.append(f"Turn {turn_number} started") + + @turn_observer.event_handler("on_turn_ended") + async def on_turn_ended(observer, turn_number, duration, was_interrupted): + turn_events.append(f"Turn {turn_number} ended (interrupted: {was_interrupted})") + + frames_to_send = [ + # Complete a turn normally + UserStartedSpeakingFrame(), + UserStoppedSpeakingFrame(), + BotStartedSpeakingFrame(), + BotStoppedSpeakingFrame(), + SleepFrame(sleep=0.4), # Let turn end naturally due to timeout + # EndFrame will be sent by run_test when no turn is active + ] + + expected_down_frames = [ + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + observers=[turn_observer], + send_end_frame=True, + ) + + # Should only see one turn that ends naturally, EndFrame shouldn't create additional events + expected_events = [ + "Turn 1 started", + "Turn 1 ended (interrupted: False)", # Ends due to timeout, not EndFrame + ] + self.assertEqual(turn_events, expected_events) + self.assertEqual(turn_observer._turn_count, 1) + if __name__ == "__main__": unittest.main()