125 lines
4.7 KiB
Python
125 lines
4.7 KiB
Python
#
|
|
# Copyright (c) 2024-2026, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
"""Tests for LiveKit transport video stream handling.
|
|
|
|
Regression tests for issue #3116: Memory leak when video_in_enabled=False
|
|
but video tracks are subscribed. The fix ensures video stream processing
|
|
only starts when there is a consumer for the frames.
|
|
"""
|
|
|
|
import unittest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
try:
|
|
from livekit import rtc
|
|
|
|
from pipecat.transports.livekit.transport import (
|
|
LiveKitCallbacks,
|
|
LiveKitParams,
|
|
LiveKitTransportClient,
|
|
)
|
|
|
|
LIVEKIT_AVAILABLE = True
|
|
except ImportError:
|
|
LIVEKIT_AVAILABLE = False
|
|
|
|
|
|
@unittest.skipUnless(LIVEKIT_AVAILABLE, "livekit package not installed")
|
|
class TestLiveKitVideoStreamMemoryLeak(unittest.IsolatedAsyncioTestCase):
|
|
"""Regression tests for video queue memory leak (#3116).
|
|
|
|
The bug: When video_in_enabled=False, subscribing to a video track would
|
|
start a producer that fills _video_queue, but no consumer would drain it,
|
|
causing unbounded memory growth (~3GB/min).
|
|
|
|
The fix: Only start video stream processing when video_in_enabled=True.
|
|
"""
|
|
|
|
def _create_client(self, video_in_enabled: bool) -> LiveKitTransportClient:
|
|
"""Create a client with the specified video input setting."""
|
|
params = LiveKitParams(video_in_enabled=video_in_enabled)
|
|
callbacks = LiveKitCallbacks(
|
|
on_connected=AsyncMock(),
|
|
on_disconnected=AsyncMock(),
|
|
on_before_disconnect=AsyncMock(),
|
|
on_participant_connected=AsyncMock(),
|
|
on_participant_disconnected=AsyncMock(),
|
|
on_audio_track_subscribed=AsyncMock(),
|
|
on_audio_track_unsubscribed=AsyncMock(),
|
|
on_video_track_subscribed=AsyncMock(),
|
|
on_video_track_unsubscribed=AsyncMock(),
|
|
on_data_received=AsyncMock(),
|
|
on_first_participant_joined=AsyncMock(),
|
|
)
|
|
client = LiveKitTransportClient(
|
|
url="wss://test.livekit.cloud",
|
|
token="test-token",
|
|
room_name="test-room",
|
|
params=params,
|
|
callbacks=callbacks,
|
|
transport_name="test-transport",
|
|
)
|
|
client._task_manager = MagicMock()
|
|
return client
|
|
|
|
def _create_mock_video_track(self):
|
|
"""Create a mock video track subscription event."""
|
|
track = MagicMock()
|
|
track.kind = rtc.TrackKind.KIND_VIDEO
|
|
track.sid = "video-track-123"
|
|
publication = MagicMock()
|
|
participant = MagicMock()
|
|
participant.sid = "participant-456"
|
|
return track, publication, participant
|
|
|
|
async def test_disabled_video_input_does_not_start_queue_producer(self):
|
|
"""When video input is disabled, no producer should fill the queue.
|
|
|
|
This prevents the memory leak where frames accumulate with no consumer.
|
|
"""
|
|
client = self._create_client(video_in_enabled=False)
|
|
track, publication, participant = self._create_mock_video_track()
|
|
|
|
await client._async_on_track_subscribed(track, publication, participant)
|
|
|
|
# Verify no video processing task was started
|
|
task_names = [call[0][1] for call in client._task_manager.create_task.call_args_list]
|
|
video_tasks = [name for name in task_names if "video" in name.lower()]
|
|
self.assertEqual(video_tasks, [], "No video processing task should be started")
|
|
|
|
# Queue should remain empty
|
|
self.assertEqual(client._video_queue.qsize(), 0)
|
|
|
|
# Track metadata should still be recorded
|
|
self.assertIn(participant.sid, client._video_tracks)
|
|
|
|
# Callback should still fire for user code
|
|
client._callbacks.on_video_track_subscribed.assert_called_once()
|
|
|
|
async def test_enabled_video_input_starts_queue_producer(self):
|
|
"""When video input is enabled, the producer should start."""
|
|
client = self._create_client(video_in_enabled=True)
|
|
track, publication, participant = self._create_mock_video_track()
|
|
|
|
with patch.object(rtc, "VideoStream"):
|
|
await client._async_on_track_subscribed(track, publication, participant)
|
|
|
|
# Verify video processing task was started
|
|
task_names = [call[0][1] for call in client._task_manager.create_task.call_args_list]
|
|
video_tasks = [name for name in task_names if "video" in name.lower()]
|
|
self.assertEqual(len(video_tasks), 1, "Video processing task should be started")
|
|
|
|
# Track metadata should be recorded
|
|
self.assertIn(participant.sid, client._video_tracks)
|
|
|
|
# Callback should fire
|
|
client._callbacks.on_video_track_subscribed.assert_called_once()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|