diff --git a/src/pipecat/processors/aggregators/dtmf_aggregator.py b/src/pipecat/processors/aggregators/dtmf_aggregator.py index f3485245c..43c59b661 100644 --- a/src/pipecat/processors/aggregators/dtmf_aggregator.py +++ b/src/pipecat/processors/aggregators/dtmf_aggregator.py @@ -64,6 +64,7 @@ class DTMFAggregator(FrameProcessor): self._digit_event = asyncio.Event() self._aggregation_task: Optional[asyncio.Task] = None + self._interruption_task: Optional[asyncio.Task] = None async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: """Process incoming frames and handle DTMF aggregation. @@ -81,6 +82,7 @@ class DTMFAggregator(FrameProcessor): if self._aggregation: await self._flush_aggregation() await self._stop_aggregation_task() + await self._stop_interruption_task() await self.push_frame(frame, direction) elif isinstance(frame, InputDTMFFrame): # Push the DTMF frame downstream first @@ -100,7 +102,7 @@ class DTMFAggregator(FrameProcessor): # For first digit, schedule interruption in separate task if is_first_digit: - asyncio.create_task(self._send_interruption_task()) + self._interruption_task = self.create_task(self._send_interruption_task()) # Check for immediate flush conditions if frame.button == self._termination_digit: @@ -111,12 +113,13 @@ class DTMFAggregator(FrameProcessor): async def _send_interruption_task(self): """Send interruption frame safely in a separate task.""" - try: - # Send the interruption frame - await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) - except Exception as e: - # Log error but don't propagate - print(f"Error sending interruption: {e}") + await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) + + async def _stop_interruption_task(self) -> None: + """Stops the interruption task.""" + if self._interruption_task: + await self.cancel_task(self._interruption_task) + self._interruption_task = None def _create_aggregation_task(self) -> None: """Creates the aggregation task if it hasn't been created yet.""" diff --git a/tests/test_dtmf_aggregator.py b/tests/test_dtmf_aggregator.py index d2e1cc9aa..40d3ece13 100644 --- a/tests/test_dtmf_aggregator.py +++ b/tests/test_dtmf_aggregator.py @@ -214,7 +214,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase): ] # All the InputDTMFFrames plus one TranscriptionFrame - expected_down_frames = [InputDTMFFrame] * 12 + [TranscriptionFrame] + expected_down_frames = [InputDTMFFrame] * len(frames_to_send) + [TranscriptionFrame] received_down_frames, _ = await run_test( aggregator,