diff --git a/src/pipecat/transports/livekit/transport.py b/src/pipecat/transports/livekit/transport.py index f55fce63b..3e7ccfdf7 100644 --- a/src/pipecat/transports/livekit/transport.py +++ b/src/pipecat/transports/livekit/transport.py @@ -12,6 +12,7 @@ event handling for conversational AI applications. """ import asyncio +import json from dataclasses import dataclass from typing import Any, Awaitable, Callable, List, Optional @@ -24,6 +25,7 @@ from pipecat.frames.frames import ( AudioRawFrame, CancelFrame, EndFrame, + ImageRawFrame, OutputAudioRawFrame, OutputDTMFFrame, OutputDTMFUrgentFrame, @@ -31,6 +33,7 @@ from pipecat.frames.frames import ( TransportMessageFrame, TransportMessageUrgentFrame, UserAudioRawFrame, + UserImageRawFrame, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup from pipecat.transports.base_input import BaseInputTransport @@ -40,6 +43,7 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager try: from livekit import rtc + from livekit.rtc._proto import video_frame_pb2 as proto_video_frame from tenacity import retry, stop_after_attempt, wait_exponential except ModuleNotFoundError as e: logger.error(f"Exception: {e}") @@ -114,6 +118,8 @@ class LiveKitCallbacks(BaseModel): on_participant_disconnected: Callable[[str], Awaitable[None]] on_audio_track_subscribed: Callable[[str], Awaitable[None]] on_audio_track_unsubscribed: Callable[[str], Awaitable[None]] + on_video_track_subscribed: Callable[[str], Awaitable[None]] + on_video_track_unsubscribed: Callable[[str], Awaitable[None]] on_data_received: Callable[[bytes, str], Awaitable[None]] on_first_participant_joined: Callable[[str], Awaitable[None]] @@ -158,8 +164,11 @@ class LiveKitTransportClient: self._audio_track: Optional[rtc.LocalAudioTrack] = None self._audio_tracks = {} self._audio_queue = asyncio.Queue() + self._video_tracks = {} + self._video_queue = asyncio.Queue() self._other_participant_has_joined = False self._task_manager: 
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def connect(self):
    """Connect to the LiveKit room with retry logic.

    Calls are serialized with ``self._async_lock``. The first successful call
    connects the room, publishes the outgoing audio track and fires the
    connection callbacks; calls made while already connected only bump
    ``_disconnect_counter`` so that ``disconnect()`` can tell when the last
    user has left.

    Raises:
        Exception: Any error raised while connecting, publishing the track or
            running the callbacks is logged and re-raised, which lets the
            ``retry`` decorator attempt the connection again.
    """
    # NOTE(review): the lock stays held while the callbacks below are awaited.
    # asyncio.Lock is not reentrant, so a callback that awaits connect() or
    # disconnect() on this same client would block forever -- confirm that no
    # registered handler does that.
    async with self._async_lock:
        if self._connected:
            # Increment disconnect counter if already connected.
            self._disconnect_counter += 1
            return

        logger.info(f"Connecting to {self._room_name}")

        try:
            await self.room.connect(
                self._url,
                self._token,
                options=rtc.RoomOptions(auto_subscribe=True),
            )
            self._connected = True
            # Increment disconnect counter if we successfully connected.
            self._disconnect_counter += 1

            self._participant_id = self.room.local_participant.sid
            logger.info(f"Connected to {self._room_name}")

            # Set up audio source and track
            self._audio_source = rtc.AudioSource(
                self._out_sample_rate, self._params.audio_out_channels
            )
            self._audio_track = rtc.LocalAudioTrack.create_audio_track(
                "pipecat-audio", self._audio_source
            )
            options = rtc.TrackPublishOptions()
            options.source = rtc.TrackSource.SOURCE_MICROPHONE
            await self.room.local_participant.publish_track(self._audio_track, options)

            await self._callbacks.on_connected()

            # Check if there are already participants in the room
            participants = self.get_participants()
            if participants and not self._other_participant_has_joined:
                self._other_participant_has_joined = True
                await self._callbacks.on_first_participant_joined(participants[0])
        except Exception as e:
            logger.error(f"Error connecting to {self._room_name}: {e}")
            raise


async def disconnect(self):
    """Disconnect from the LiveKit room.

    Each call releases one reference taken by ``connect()``. The room is only
    left, and ``on_disconnected`` only fired, when the client is connected and
    the last reference has been released. Calls are serialized with
    ``self._async_lock``.
    """
    async with self._async_lock:
        # Decrement leave counter when leaving, but never below zero. An
        # unmatched disconnect() (for example after a connect() that raised
        # before incrementing) would otherwise leave a negative count, and a
        # later disconnect() would then leave the room while another user of
        # this client still expects it to be connected.
        self._disconnect_counter = max(self._disconnect_counter - 1, 0)

        if not self._connected or self._disconnect_counter > 0:
            return

        logger.info(f"Disconnecting from {self._room_name}")
        await self.room.disconnect()
        self._connected = False
        logger.info(f"Disconnected from {self._room_name}")
        await self._callbacks.on_disconnected()
async def _process_video_stream(self, video_stream: rtc.VideoStream, participant_id: str):
    """Process incoming video stream from a participant.

    Forwards every ``rtc.VideoFrameEvent`` produced by the stream to
    ``self._video_queue`` as an ``(event, participant_id)`` tuple and closes
    the stream when iteration stops for any reason.

    Args:
        video_stream: The LiveKit video stream to read from.
        participant_id: Identifier stored alongside each queued frame.
    """
    logger.info(f"Started processing video stream for participant {participant_id}")
    # NOTE(review): the queue is created without a size limit, so frames pile
    # up if the consumer is slower than the incoming frame rate -- confirm
    # whether a bound or a drop policy is wanted.
    try:
        async for event in video_stream:
            if isinstance(event, rtc.VideoFrameEvent):
                await self._video_queue.put((event, participant_id))
            else:
                logger.warning(f"Received unexpected event type: {type(event)}")
    finally:
        # Release the stream whether it ended normally, raised, or this task
        # was cancelled; otherwise it is only cleaned up by garbage collection.
        await video_stream.aclose()


async def get_next_video_frame(self):
    """Get the next video frame from the queue.

    Yields:
        ``(frame_event, participant_id)`` tuples in the order they were
        queued. The generator never finishes on its own; it waits for the
        next queued item indefinitely.
    """
    while True:
        frame, participant_id = await self._video_queue.get()
        yield frame, participant_id
# --- LiveKitInputTransport -------------------------------------------------


async def _video_in_task_handler(self):
    """Handle incoming video frames from participants.

    Reads ``(frame_event, participant_id)`` pairs from the client, converts
    each LiveKit frame to RGB and pushes it downstream as a
    ``UserImageRawFrame`` tagged with the participant id. Frames that fail to
    convert or that carry no pixel data are skipped.
    """
    logger.info("Video input task started")
    async for video_frame_event, participant_id in self._client.get_next_video_frame():
        try:
            pipecat_video_frame = await self._convert_livekit_video_to_pipecat(
                video_frame_event=video_frame_event
            )
        except Exception as e:
            # A single frame that cannot be converted must not end the task;
            # that would silently stop all video input.
            logger.error(f"Error converting video frame from {participant_id}: {e}")
            continue

        # Skip frames with no video data
        if not pipecat_video_frame.image:
            continue

        input_video_frame = UserImageRawFrame(
            user_id=participant_id,
            image=pipecat_video_frame.image,
            size=pipecat_video_frame.size,
            format=pipecat_video_frame.format,
        )
        await self.push_video_frame(input_video_frame)


async def _convert_livekit_video_to_pipecat(
    self,
    video_frame_event: rtc.VideoFrameEvent,
) -> ImageRawFrame:
    """Convert LiveKit video frame to Pipecat video frame.

    Args:
        video_frame_event: The LiveKit event wrapping the frame to convert.

    Returns:
        An ``ImageRawFrame`` holding the frame converted to RGB24, with
        ``size`` given as ``(width, height)`` and ``format`` set to "RGB".
    """
    rgb_frame = video_frame_event.frame.convert(proto_video_frame.VideoBufferType.RGB24)
    # NOTE(review): the SDK appears to expose `data` as a view over the
    # frame's own buffer rather than as `bytes`; copy it so the Pipecat frame
    # owns immutable pixel data -- confirm against the livekit rtc docs.
    return ImageRawFrame(
        image=bytes(rgb_frame.data),
        size=(rgb_frame.width, rgb_frame.height),
        format="RGB",
    )


# --- LiveKitTransport -------------------------------------------------------


async def _on_video_track_subscribed(self, participant_id: str):
    """Handle video track subscribed events.

    Args:
        participant_id: The id the client passed to the callback for the
            participant whose video track was subscribed.
    """
    # Only notify registered handlers. The client calls this callback after it
    # has already created the video stream and its processing task for the
    # track, so feeding the participant's publications back into the client's
    # track-subscribed wrapper here would presumably start duplicate streams
    # and re-enter this callback.
    await self._call_event_handler("on_video_track_subscribed", participant_id)


async def _on_video_track_unsubscribed(self, participant_id: str):
    """Handle video track unsubscribed events.

    Args:
        participant_id: The id the client passed to the callback for the
            participant whose video track was unsubscribed.
    """
    await self._call_event_handler("on_video_track_unsubscribed", participant_id)