# Source: pipecat/tests/test_livekit_transport.py (125 lines, 4.7 KiB, Python)
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Tests for LiveKit transport video stream handling.

Regression tests for issue #3116: memory leak when video_in_enabled=False
but video tracks are subscribed. The fix ensures video stream processing
only starts when there is a consumer for the frames.
"""
import inspect
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
# livekit is an optional dependency: record whether it (and the pipecat
# LiveKit transport built on it) can be imported, so the tests that need it
# can be skipped instead of erroring when it is absent.
try:
    from livekit import rtc

    from pipecat.transports.livekit.transport import (
        LiveKitCallbacks,
        LiveKitParams,
        LiveKitTransportClient,
    )
except ImportError:
    LIVEKIT_AVAILABLE = False
else:
    LIVEKIT_AVAILABLE = True
@unittest.skipUnless(LIVEKIT_AVAILABLE, "livekit package not installed")
class TestLiveKitVideoStreamMemoryLeak(unittest.IsolatedAsyncioTestCase):
    """Regression tests for video queue memory leak (#3116).

    The bug: When video_in_enabled=False, subscribing to a video track would
    start a producer that fills _video_queue, but no consumer would drain it,
    causing unbounded memory growth (~3GB/min).

    The fix: Only start video stream processing when video_in_enabled=True.
    """

    # The return annotation is a string on purpose. On Python versions that
    # evaluate annotations eagerly, an unquoted ``LiveKitTransportClient``
    # would be looked up while this class body runs. When livekit is missing
    # that name is undefined, and the module would fail with NameError at
    # import time, before ``skipUnless`` has a chance to skip the class.
    def _create_client(self, video_in_enabled: bool) -> "LiveKitTransportClient":
        """Create a client with the specified video input setting.

        Every callback is an ``AsyncMock``. The task manager is replaced with
        a ``MagicMock``, so tests can inspect which tasks the client asks to
        create without those tasks ever running.
        """
        params = LiveKitParams(video_in_enabled=video_in_enabled)
        callbacks = LiveKitCallbacks(
            on_connected=AsyncMock(),
            on_disconnected=AsyncMock(),
            on_before_disconnect=AsyncMock(),
            on_participant_connected=AsyncMock(),
            on_participant_disconnected=AsyncMock(),
            on_audio_track_subscribed=AsyncMock(),
            on_audio_track_unsubscribed=AsyncMock(),
            on_video_track_subscribed=AsyncMock(),
            on_video_track_unsubscribed=AsyncMock(),
            on_data_received=AsyncMock(),
            on_first_participant_joined=AsyncMock(),
        )
        client = LiveKitTransportClient(
            url="wss://test.livekit.cloud",
            token="test-token",
            room_name="test-room",
            params=params,
            callbacks=callbacks,
            transport_name="test-transport",
        )
        client._task_manager = MagicMock()
        return client

    def _create_mock_video_track(self):
        """Create a mock video track subscription event.

        Returns:
            A ``(track, publication, participant)`` tuple of mocks. The track's
            ``kind`` is ``rtc.TrackKind.KIND_VIDEO``, and both the track and
            the participant have fixed ``sid`` values.
        """
        track = MagicMock()
        track.kind = rtc.TrackKind.KIND_VIDEO
        track.sid = "video-track-123"
        publication = MagicMock()
        participant = MagicMock()
        participant.sid = "participant-456"
        return track, publication, participant

    def _video_task_names(self, client) -> list[str]:
        """Return the names of video-related tasks the client asked to create.

        Reads the calls recorded on the mocked ``_task_manager.create_task``.
        A task's name is taken from the second positional argument, or from a
        ``name=`` keyword when it is passed that way. It is compared
        case-insensitively against "video".

        Any coroutine object handed to the mock is closed here. The mock never
        awaits it, and Python would otherwise warn that the coroutine was
        never awaited when it is garbage-collected.
        """
        names = []
        for call in client._task_manager.create_task.call_args_list:
            for arg in (*call.args, *call.kwargs.values()):
                if inspect.iscoroutine(arg):
                    arg.close()
            if len(call.args) > 1:
                name = call.args[1]
            else:
                name = call.kwargs.get("name", "")
            names.append(str(name))
        return [name for name in names if "video" in name.lower()]

    async def test_disabled_video_input_does_not_start_queue_producer(self):
        """When video input is disabled, no producer should fill the queue.

        This prevents the memory leak where frames accumulate with no consumer.
        """
        client = self._create_client(video_in_enabled=False)
        track, publication, participant = self._create_mock_video_track()

        # Patch VideoStream, as the enabled-input test does, so the test never
        # builds a real livekit VideoStream around a MagicMock track, even if
        # the subscription behaviour regresses.
        with patch.object(rtc, "VideoStream"):
            await client._async_on_track_subscribed(track, publication, participant)

        # Verify no video processing task was started
        self.assertEqual(
            self._video_task_names(client),
            [],
            "No video processing task should be started",
        )
        # Queue should remain empty
        self.assertEqual(client._video_queue.qsize(), 0)
        # Track metadata should still be recorded
        self.assertIn(participant.sid, client._video_tracks)
        # Callback should still fire for user code
        client._callbacks.on_video_track_subscribed.assert_called_once()

    async def test_enabled_video_input_starts_queue_producer(self):
        """When video input is enabled, the producer should start."""
        client = self._create_client(video_in_enabled=True)
        track, publication, participant = self._create_mock_video_track()

        with patch.object(rtc, "VideoStream"):
            await client._async_on_track_subscribed(track, publication, participant)

        # Verify exactly one video processing task was started
        self.assertEqual(
            len(self._video_task_names(client)),
            1,
            "Video processing task should be started",
        )
        # Track metadata should be recorded
        self.assertIn(participant.sid, client._video_tracks)
        # Callback should fire
        client._callbacks.on_video_track_subscribed.assert_called_once()
# Allow running this module directly with the plain unittest runner,
# in addition to discovery via pytest / `python -m unittest`.
if __name__ == "__main__":
    unittest.main()