Merge pull request #2022 from pipecat-ai/mb/turn-tracking-end-cancel-frame

TurnTrackingObserver ends turn upon seeing EndFrame, CancelFrame
This commit is contained in:
Mark Backman
2025-06-23 11:24:27 -04:00
committed by GitHub
3 changed files with 124 additions and 1 deletions

View File

@@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added reconnection logic and audio buffer management to `GladiaSTTService`.
- The `TurnTrackingObserver` now ends a turn upon observing an `EndFrame` or
`CancelFrame`.
- Added Polish support to `AWSTranscribeSTTService`.
- Added new frames `FrameProcessorPauseFrame` and `FrameProcessorResumeFrame`

View File

@@ -12,6 +12,8 @@ from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
StartFrame,
UserStartedSpeakingFrame,
)
@@ -73,6 +75,8 @@ class TurnTrackingObserver(BaseObserver):
# We only want to end the turn if the bot was previously speaking
elif isinstance(data.frame, BotStoppedSpeakingFrame) and self._is_bot_speaking:
await self._handle_bot_stopped_speaking(data)
elif isinstance(data.frame, (EndFrame, CancelFrame)):
await self._handle_pipeline_end(data)
def _schedule_turn_end(self, data: FramePushed):
"""Schedule turn end with a timeout."""
@@ -134,6 +138,14 @@ class TurnTrackingObserver(BaseObserver):
# This can happen with HTTP TTS services or function calls
self._schedule_turn_end(data)
async def _handle_pipeline_end(self, data: FramePushed):
"""Handle pipeline end or cancellation by flushing any active turn."""
if self._is_turn_active:
# Cancel any pending turn end timer
self._cancel_turn_end_timer()
# End the current turn
await self._end_turn(data, was_interrupted=True)
async def _start_turn(self, data: FramePushed):
"""Start a new turn."""
self._is_turn_active = True

View File

@@ -9,6 +9,7 @@ import unittest
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -150,7 +151,10 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase):
self.assertEqual(turn_observer._turn_count, 2)
async def test_user_interrupts_bot(self):
"""Test when user interrupts bot speaking, should end current turn and start new one."""
"""Test when user interrupts bot speaking, should end current turn and start new one.
Note: This test also verifies that the EndFrame ends the turn correctly.
"""
# Create observer with a short timeout
turn_observer = TurnTrackingObserver(turn_end_timeout_secs=0.2)
@@ -197,6 +201,7 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase):
"Turn 1 started",
"Turn 1 ended (interrupted: True)", # First turn was interrupted
"Turn 2 started", # New turn started after interruption
"Turn 2 ended (interrupted: True)", # Second turn ends due to EndFrame
]
self.assertEqual(turn_events, expected_events)
self.assertEqual(turn_observer._turn_count, 2)
@@ -256,6 +261,109 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase):
self.assertEqual(turn_events, expected_events)
self.assertEqual(turn_observer._turn_count, 1)
async def test_cancel_frame_flushes_active_turn(self):
"""Test that CancelFrame properly flushes an active turn."""
# Create observer with a long timeout to ensure CancelFrame is what ends the turn
turn_observer = TurnTrackingObserver(turn_end_timeout_secs=5.0)
# Create identity filter (passes all frames through)
processor = IdentityFilter()
# Record start/end events with turn numbers
turn_events = []
@turn_observer.event_handler("on_turn_started")
async def on_turn_started(observer, turn_number):
turn_events.append(f"Turn {turn_number} started")
@turn_observer.event_handler("on_turn_ended")
async def on_turn_ended(observer, turn_number, duration, was_interrupted):
turn_events.append(f"Turn {turn_number} ended (interrupted: {was_interrupted})")
frames_to_send = [
# Start a turn but don't complete it naturally
UserStartedSpeakingFrame(),
UserStoppedSpeakingFrame(),
BotStartedSpeakingFrame(),
# Send CancelFrame while bot is still speaking
CancelFrame(),
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStartedSpeakingFrame,
CancelFrame,
]
await run_test(
processor,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
observers=[turn_observer],
send_end_frame=False, # Don't send EndFrame since we're testing CancelFrame
)
# Verify that the turn was ended due to CancelFrame (marked as interrupted)
expected_events = [
"Turn 1 started",
"Turn 1 ended (interrupted: True)", # Should be interrupted due to CancelFrame
]
self.assertEqual(turn_events, expected_events)
self.assertEqual(turn_observer._turn_count, 1)
async def test_end_frame_with_no_active_turn(self):
"""Test that EndFrame doesn't cause issues when no turn is active."""
# Create observer
turn_observer = TurnTrackingObserver(turn_end_timeout_secs=0.2)
# Create identity filter (passes all frames through)
processor = IdentityFilter()
# Record start/end events with turn numbers
turn_events = []
@turn_observer.event_handler("on_turn_started")
async def on_turn_started(observer, turn_number):
turn_events.append(f"Turn {turn_number} started")
@turn_observer.event_handler("on_turn_ended")
async def on_turn_ended(observer, turn_number, duration, was_interrupted):
turn_events.append(f"Turn {turn_number} ended (interrupted: {was_interrupted})")
frames_to_send = [
# Complete a turn normally
UserStartedSpeakingFrame(),
UserStoppedSpeakingFrame(),
BotStartedSpeakingFrame(),
BotStoppedSpeakingFrame(),
SleepFrame(sleep=0.4), # Let turn end naturally due to timeout
# EndFrame will be sent by run_test when no turn is active
]
expected_down_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
]
await run_test(
processor,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
observers=[turn_observer],
send_end_frame=True,
)
# Should only see one turn that ends naturally, EndFrame shouldn't create additional events
expected_events = [
"Turn 1 started",
"Turn 1 ended (interrupted: False)", # Ends due to timeout, not EndFrame
]
self.assertEqual(turn_events, expected_events)
self.assertEqual(turn_observer._turn_count, 1)
if __name__ == "__main__":
unittest.main()