DtmfAggregator: cancel interruption task to avoid a dangling task

This commit is contained in:
Aleix Conchillo Flaqué
2025-07-03 08:18:48 -07:00
parent e9d358ed17
commit af8b4901d4
2 changed files with 11 additions and 8 deletions

View File

@@ -64,6 +64,7 @@ class DTMFAggregator(FrameProcessor):
self._digit_event = asyncio.Event()
self._aggregation_task: Optional[asyncio.Task] = None
self._interruption_task: Optional[asyncio.Task] = None
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
"""Process incoming frames and handle DTMF aggregation.
@@ -81,6 +82,7 @@ class DTMFAggregator(FrameProcessor):
if self._aggregation:
await self._flush_aggregation()
await self._stop_aggregation_task()
await self._stop_interruption_task()
await self.push_frame(frame, direction)
elif isinstance(frame, InputDTMFFrame):
# Push the DTMF frame downstream first
@@ -100,7 +102,7 @@ class DTMFAggregator(FrameProcessor):
# For first digit, schedule interruption in separate task
if is_first_digit:
asyncio.create_task(self._send_interruption_task())
self._interruption_task = self.create_task(self._send_interruption_task())
# Check for immediate flush conditions
if frame.button == self._termination_digit:
@@ -111,12 +113,13 @@ class DTMFAggregator(FrameProcessor):
async def _send_interruption_task(self):
"""Send interruption frame safely in a separate task."""
try:
# Send the interruption frame
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
except Exception as e:
# Log error but don't propagate
print(f"Error sending interruption: {e}")
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
async def _stop_interruption_task(self) -> None:
"""Stops the interruption task."""
if self._interruption_task:
await self.cancel_task(self._interruption_task)
self._interruption_task = None
def _create_aggregation_task(self) -> None:
"""Creates the aggregation task if it hasn't been created yet."""

View File

@@ -214,7 +214,7 @@ class TestDTMFAggregator(unittest.IsolatedAsyncioTestCase):
]
# All the InputDTMFFrames plus one TranscriptionFrame
expected_down_frames = [InputDTMFFrame] * 12 + [TranscriptionFrame]
expected_down_frames = [InputDTMFFrame] * len(frames_to_send) + [TranscriptionFrame]
received_down_frames, _ = await run_test(
aggregator,