Merge pull request #2022 from pipecat-ai/mb/turn-tracking-end-cancel-frame
TurnTrackingObserver ends turn upon seeing EndFrame, CancelFrame
This commit is contained in:
@@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Added reconnection logic and audio buffer management to `GladiaSTTService`.
|
||||
|
||||
- The `TurnTrackingObserver` now ends a turn upon observing an `EndFrame` or
|
||||
`CancelFrame`.
|
||||
|
||||
- Added Polish support to `AWSTranscribeSTTService`.
|
||||
|
||||
- Added new frames `FrameProcessorPauseFrame` and `FrameProcessorResumeFrame`
|
||||
|
||||
@@ -12,6 +12,8 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
StartFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
@@ -73,6 +75,8 @@ class TurnTrackingObserver(BaseObserver):
|
||||
# We only want to end the turn if the bot was previously speaking
|
||||
elif isinstance(data.frame, BotStoppedSpeakingFrame) and self._is_bot_speaking:
|
||||
await self._handle_bot_stopped_speaking(data)
|
||||
elif isinstance(data.frame, (EndFrame, CancelFrame)):
|
||||
await self._handle_pipeline_end(data)
|
||||
|
||||
def _schedule_turn_end(self, data: FramePushed):
|
||||
"""Schedule turn end with a timeout."""
|
||||
@@ -134,6 +138,14 @@ class TurnTrackingObserver(BaseObserver):
|
||||
# This can happen with HTTP TTS services or function calls
|
||||
self._schedule_turn_end(data)
|
||||
|
||||
async def _handle_pipeline_end(self, data: FramePushed):
|
||||
"""Handle pipeline end or cancellation by flushing any active turn."""
|
||||
if self._is_turn_active:
|
||||
# Cancel any pending turn end timer
|
||||
self._cancel_turn_end_timer()
|
||||
# End the current turn
|
||||
await self._end_turn(data, was_interrupted=True)
|
||||
|
||||
async def _start_turn(self, data: FramePushed):
|
||||
"""Start a new turn."""
|
||||
self._is_turn_active = True
|
||||
|
||||
@@ -9,6 +9,7 @@ import unittest
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -150,7 +151,10 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(turn_observer._turn_count, 2)
|
||||
|
||||
async def test_user_interrupts_bot(self):
|
||||
"""Test when user interrupts bot speaking, should end current turn and start new one."""
|
||||
"""Test when user interrupts bot speaking, should end current turn and start new one.
|
||||
|
||||
Note: This test also verifies that the EndFrame ends the turn correctly.
|
||||
"""
|
||||
# Create observer with a short timeout
|
||||
turn_observer = TurnTrackingObserver(turn_end_timeout_secs=0.2)
|
||||
|
||||
@@ -197,6 +201,7 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase):
|
||||
"Turn 1 started",
|
||||
"Turn 1 ended (interrupted: True)", # First turn was interrupted
|
||||
"Turn 2 started", # New turn started after interruption
|
||||
"Turn 2 ended (interrupted: True)", # Second turn ends due to EndFrame
|
||||
]
|
||||
self.assertEqual(turn_events, expected_events)
|
||||
self.assertEqual(turn_observer._turn_count, 2)
|
||||
@@ -256,6 +261,109 @@ class TestTurnTrackingObserver(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(turn_events, expected_events)
|
||||
self.assertEqual(turn_observer._turn_count, 1)
|
||||
|
||||
async def test_cancel_frame_flushes_active_turn(self):
|
||||
"""Test that CancelFrame properly flushes an active turn."""
|
||||
# Create observer with a long timeout to ensure CancelFrame is what ends the turn
|
||||
turn_observer = TurnTrackingObserver(turn_end_timeout_secs=5.0)
|
||||
|
||||
# Create identity filter (passes all frames through)
|
||||
processor = IdentityFilter()
|
||||
|
||||
# Record start/end events with turn numbers
|
||||
turn_events = []
|
||||
|
||||
@turn_observer.event_handler("on_turn_started")
|
||||
async def on_turn_started(observer, turn_number):
|
||||
turn_events.append(f"Turn {turn_number} started")
|
||||
|
||||
@turn_observer.event_handler("on_turn_ended")
|
||||
async def on_turn_ended(observer, turn_number, duration, was_interrupted):
|
||||
turn_events.append(f"Turn {turn_number} ended (interrupted: {was_interrupted})")
|
||||
|
||||
frames_to_send = [
|
||||
# Start a turn but don't complete it naturally
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
# Send CancelFrame while bot is still speaking
|
||||
CancelFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
CancelFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=[turn_observer],
|
||||
send_end_frame=False, # Don't send EndFrame since we're testing CancelFrame
|
||||
)
|
||||
|
||||
# Verify that the turn was ended due to CancelFrame (marked as interrupted)
|
||||
expected_events = [
|
||||
"Turn 1 started",
|
||||
"Turn 1 ended (interrupted: True)", # Should be interrupted due to CancelFrame
|
||||
]
|
||||
self.assertEqual(turn_events, expected_events)
|
||||
self.assertEqual(turn_observer._turn_count, 1)
|
||||
|
||||
async def test_end_frame_with_no_active_turn(self):
|
||||
"""Test that EndFrame doesn't cause issues when no turn is active."""
|
||||
# Create observer
|
||||
turn_observer = TurnTrackingObserver(turn_end_timeout_secs=0.2)
|
||||
|
||||
# Create identity filter (passes all frames through)
|
||||
processor = IdentityFilter()
|
||||
|
||||
# Record start/end events with turn numbers
|
||||
turn_events = []
|
||||
|
||||
@turn_observer.event_handler("on_turn_started")
|
||||
async def on_turn_started(observer, turn_number):
|
||||
turn_events.append(f"Turn {turn_number} started")
|
||||
|
||||
@turn_observer.event_handler("on_turn_ended")
|
||||
async def on_turn_ended(observer, turn_number, duration, was_interrupted):
|
||||
turn_events.append(f"Turn {turn_number} ended (interrupted: {was_interrupted})")
|
||||
|
||||
frames_to_send = [
|
||||
# Complete a turn normally
|
||||
UserStartedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.4), # Let turn end naturally due to timeout
|
||||
# EndFrame will be sent by run_test when no turn is active
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
observers=[turn_observer],
|
||||
send_end_frame=True,
|
||||
)
|
||||
|
||||
# Should only see one turn that ends naturally, EndFrame shouldn't create additional events
|
||||
expected_events = [
|
||||
"Turn 1 started",
|
||||
"Turn 1 ended (interrupted: False)", # Ends due to timeout, not EndFrame
|
||||
]
|
||||
self.assertEqual(turn_events, expected_events)
|
||||
self.assertEqual(turn_observer._turn_count, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user