Merge pull request #2604 from alexyzhou/feature/livekit_video_and_bug_fix
Feature: Add support for livekit video stream and minor bug fixes
This commit is contained in:
@@ -12,6 +12,7 @@ event handling for conversational AI applications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ImageRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputDTMFFrame,
|
||||
OutputDTMFUrgentFrame,
|
||||
@@ -31,6 +33,7 @@ from pipecat.frames.frames import (
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
UserAudioRawFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
@@ -40,6 +43,7 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
try:
|
||||
from livekit import rtc
|
||||
from livekit.rtc._proto import video_frame_pb2 as proto_video_frame
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -114,6 +118,8 @@ class LiveKitCallbacks(BaseModel):
|
||||
on_participant_disconnected: Callable[[str], Awaitable[None]]
|
||||
on_audio_track_subscribed: Callable[[str], Awaitable[None]]
|
||||
on_audio_track_unsubscribed: Callable[[str], Awaitable[None]]
|
||||
on_video_track_subscribed: Callable[[str], Awaitable[None]]
|
||||
on_video_track_unsubscribed: Callable[[str], Awaitable[None]]
|
||||
on_data_received: Callable[[bytes, str], Awaitable[None]]
|
||||
on_first_participant_joined: Callable[[str], Awaitable[None]]
|
||||
|
||||
@@ -158,8 +164,11 @@ class LiveKitTransportClient:
|
||||
self._audio_track: Optional[rtc.LocalAudioTrack] = None
|
||||
self._audio_tracks = {}
|
||||
self._audio_queue = asyncio.Queue()
|
||||
self._video_tracks = {}
|
||||
self._video_queue = asyncio.Queue()
|
||||
self._other_participant_has_joined = False
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
self._async_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def participant_id(self) -> str:
|
||||
@@ -220,61 +229,63 @@ class LiveKitTransportClient:
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def connect(self):
    """Connect to the LiveKit room with retry logic.

    The whole operation runs under ``self._async_lock`` so that overlapping
    ``connect()`` / ``disconnect()`` calls cannot interleave their updates
    to ``_connected`` and ``_disconnect_counter``.

    If the client is already connected, only the disconnect counter is
    incremented (one increment per caller that must later call
    ``disconnect()``) and the method returns without touching the room.

    Otherwise the room is joined, the local audio track is created and
    published, ``on_connected`` is invoked, and, if participants are
    already present, ``on_first_participant_joined`` is invoked once with
    the first of them.

    Raises:
        Exception: Any error raised while connecting or publishing is
            logged and re-raised, which lets the ``@retry`` decorator
            (3 attempts, exponential wait) try again.
    """
    async with self._async_lock:
        if self._connected:
            # Increment disconnect counter if already connected.
            self._disconnect_counter += 1
            return

        logger.info(f"Connecting to {self._room_name}")

        try:
            await self.room.connect(
                self._url,
                self._token,
                options=rtc.RoomOptions(auto_subscribe=True),
            )
            self._connected = True
            # Increment disconnect counter if we successfully connected.
            self._disconnect_counter += 1

            self._participant_id = self.room.local_participant.sid
            logger.info(f"Connected to {self._room_name}")

            # Set up audio source and track
            self._audio_source = rtc.AudioSource(
                self._out_sample_rate, self._params.audio_out_channels
            )
            self._audio_track = rtc.LocalAudioTrack.create_audio_track(
                "pipecat-audio", self._audio_source
            )
            options = rtc.TrackPublishOptions()
            options.source = rtc.TrackSource.SOURCE_MICROPHONE
            await self.room.local_participant.publish_track(self._audio_track, options)

            await self._callbacks.on_connected()

            # Check if there are already participants in the room
            participants = self.get_participants()
            if participants and not self._other_participant_has_joined:
                self._other_participant_has_joined = True
                await self._callbacks.on_first_participant_joined(participants[0])
        except Exception as e:
            logger.error(f"Error connecting to {self._room_name}: {e}")
            raise
|
||||
|
||||
async def disconnect(self):
    """Disconnect from the LiveKit room.

    Runs under ``self._async_lock`` so the counter update and the
    connected-state check happen atomically with respect to ``connect()``
    and other ``disconnect()`` calls.

    Each call decrements ``_disconnect_counter`` exactly once. The room is
    only left when the client is currently connected and the counter has
    dropped to zero or below; otherwise the call returns after the
    decrement. On an actual disconnect ``_connected`` is cleared and the
    ``on_disconnected`` callback is awaited.
    """
    async with self._async_lock:
        # Decrement leave counter when leaving.
        self._disconnect_counter -= 1

        # Nothing to tear down, or another user of this client still needs
        # the connection.
        if not self._connected or self._disconnect_counter > 0:
            return

        logger.info(f"Disconnecting from {self._room_name}")
        await self.room.disconnect()
        self._connected = False
        logger.info(f"Disconnected from {self._room_name}")
        await self._callbacks.on_disconnected()
|
||||
|
||||
async def send_data(self, data: bytes, participant_id: Optional[str] = None):
|
||||
"""Send data to participants in the room.
|
||||
@@ -477,6 +488,15 @@ class LiveKitTransportClient:
|
||||
f"{self}::_process_audio_stream",
|
||||
)
|
||||
await self._callbacks.on_audio_track_subscribed(participant.sid)
|
||||
elif track.kind == rtc.TrackKind.KIND_VIDEO:
|
||||
logger.info(f"Video track subscribed: {track.sid} from participant {participant.sid}")
|
||||
self._video_tracks[participant.sid] = track
|
||||
video_stream = rtc.VideoStream(track)
|
||||
self._task_manager.create_task(
|
||||
self._process_video_stream(video_stream, participant.sid),
|
||||
f"{self}::_process_video_stream",
|
||||
)
|
||||
await self._callbacks.on_video_track_subscribed(participant.sid)
|
||||
|
||||
async def _async_on_track_unsubscribed(
|
||||
self,
|
||||
@@ -488,6 +508,8 @@ class LiveKitTransportClient:
|
||||
logger.info(f"Track unsubscribed: {publication.sid} from {participant.identity}")
|
||||
if track.kind == rtc.TrackKind.KIND_AUDIO:
|
||||
await self._callbacks.on_audio_track_unsubscribed(participant.sid)
|
||||
elif track.kind == rtc.TrackKind.KIND_VIDEO:
|
||||
await self._callbacks.on_video_track_unsubscribed(participant.sid)
|
||||
|
||||
async def _async_on_data_received(self, data: rtc.DataPacket):
|
||||
"""Handle data received events."""
|
||||
@@ -518,6 +540,21 @@ class LiveKitTransportClient:
|
||||
frame, participant_id = await self._audio_queue.get()
|
||||
yield frame, participant_id
|
||||
|
||||
async def _process_video_stream(self, video_stream: rtc.VideoStream, participant_id: str):
    """Forward every frame event of one participant's video stream to the queue.

    Args:
        video_stream: The LiveKit video stream to iterate over.
        participant_id: Identifier stored alongside each queued event.

    Each ``rtc.VideoFrameEvent`` is enqueued on ``self._video_queue`` as an
    ``(event, participant_id)`` tuple; anything else is logged as a warning
    and dropped.
    """
    logger.info(f"Started processing video stream for participant {participant_id}")
    async for event in video_stream:
        if not isinstance(event, rtc.VideoFrameEvent):
            logger.warning(f"Received unexpected event type: {type(event)}")
            continue
        queue_item = (event, participant_id)
        await self._video_queue.put(queue_item)
|
||||
|
||||
async def get_next_video_frame(self):
    """Yield queued video frames, one ``(frame, participant_id)`` pair at a time.

    This is an endless async generator: it waits on ``self._video_queue``
    for the next entry, unpacks it, yields it, and repeats.
    """
    while True:
        queued_entry = await self._video_queue.get()
        video_event, sender_id = queued_entry
        yield video_event, sender_id
|
||||
|
||||
def __str__(self):
    """Return the transport name followed by ``::LiveKitTransportClient``."""
    owner = self._transport_name
    return f"{owner}::LiveKitTransportClient"
|
||||
@@ -550,6 +587,7 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
self._client = client
|
||||
|
||||
self._audio_in_task = None
|
||||
self._video_in_task = None
|
||||
self._vad_analyzer: Optional[VADAnalyzer] = params.vad_analyzer
|
||||
self._resampler = create_stream_resampler()
|
||||
|
||||
@@ -582,6 +620,8 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
await self._client.connect()
|
||||
if not self._audio_in_task and self._params.audio_in_enabled:
|
||||
self._audio_in_task = self.create_task(self._audio_in_task_handler())
|
||||
if not self._video_in_task and self._params.video_in_enabled:
|
||||
self._video_in_task = self.create_task(self._video_in_task_handler())
|
||||
await self.set_transport_ready(frame)
|
||||
logger.info("LiveKitInputTransport started")
|
||||
|
||||
@@ -595,6 +635,8 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
await self._client.disconnect()
|
||||
if self._audio_in_task:
|
||||
await self.cancel_task(self._audio_in_task)
|
||||
if self._video_in_task:
|
||||
await self.cancel_task(self._video_in_task)
|
||||
logger.info("LiveKitInputTransport stopped")
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
@@ -607,6 +649,8 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
await self._client.disconnect()
|
||||
if self._audio_in_task and self._params.audio_in_enabled:
|
||||
await self.cancel_task(self._audio_in_task)
|
||||
if self._video_in_task and self._params.video_in_enabled:
|
||||
await self.cancel_task(self._video_in_task)
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
"""Setup the input transport with shared client setup.
|
||||
@@ -655,6 +699,29 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
)
|
||||
await self.push_audio_frame(input_audio_frame)
|
||||
|
||||
async def _video_in_task_handler(self):
    """Pump video frames from the client into the pipeline.

    Iterates the client's ``get_next_video_frame()`` generator, converts
    each LiveKit frame event with ``_convert_livekit_video_to_pipecat``,
    and pushes the result as a ``UserImageRawFrame`` tagged with the
    participant id. Falsy queue entries and frames whose image buffer is
    empty are skipped.
    """
    logger.info("Video input task started")
    async for video_data in self._client.get_next_video_frame():
        if not video_data:
            continue

        frame_event, sender_id = video_data
        converted = await self._convert_livekit_video_to_pipecat(
            video_frame_event=frame_event
        )

        # Skip frames with no video data
        if len(converted.image) == 0:
            continue

        await self.push_video_frame(
            UserImageRawFrame(
                user_id=sender_id,
                image=converted.image,
                size=converted.size,
                format=converted.format,
            )
        )
|
||||
|
||||
async def _convert_livekit_audio_to_pipecat(
|
||||
self, audio_frame_event: rtc.AudioFrameEvent
|
||||
) -> AudioRawFrame:
|
||||
@@ -671,6 +738,19 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
num_channels=audio_frame.num_channels,
|
||||
)
|
||||
|
||||
async def _convert_livekit_video_to_pipecat(
    self,
    video_frame_event: rtc.VideoFrameEvent,
) -> ImageRawFrame:
    """Build a Pipecat image frame from a LiveKit video frame event.

    Args:
        video_frame_event: Event whose ``frame`` is converted.

    Returns:
        An ``ImageRawFrame`` holding the frame converted to the RGB24
        buffer type, with ``size`` set to ``(width, height)`` and
        ``format`` set to ``"RGB"``.
    """
    converted = video_frame_event.frame.convert(proto_video_frame.VideoBufferType.RGB24)
    pixels = converted.data
    dimensions = (converted.width, converted.height)
    return ImageRawFrame(image=pixels, size=dimensions, format="RGB")
|
||||
|
||||
|
||||
class LiveKitOutputTransport(BaseOutputTransport):
|
||||
"""Handles outgoing media streams and events to LiveKit rooms.
|
||||
@@ -758,10 +838,14 @@ class LiveKitOutputTransport(BaseOutputTransport):
|
||||
Args:
|
||||
frame: The transport message frame to send.
|
||||
"""
|
||||
message = frame.message
|
||||
if isinstance(message, dict):
|
||||
# fix message encoding for dict-like messages, e.g. RTVI messages.
|
||||
message = json.dumps(message, ensure_ascii=False)
|
||||
if isinstance(frame, (LiveKitTransportMessageFrame, LiveKitTransportMessageUrgentFrame)):
|
||||
await self._client.send_data(frame.message.encode(), frame.participant_id)
|
||||
await self._client.send_data(message.encode(), frame.participant_id)
|
||||
else:
|
||||
await self._client.send_data(frame.message.encode())
|
||||
await self._client.send_data(message.encode())
|
||||
|
||||
async def write_audio_frame(self, frame: OutputAudioRawFrame):
|
||||
"""Write an audio frame to the LiveKit room.
|
||||
@@ -838,6 +922,8 @@ class LiveKitTransport(BaseTransport):
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
on_audio_track_subscribed=self._on_audio_track_subscribed,
|
||||
on_audio_track_unsubscribed=self._on_audio_track_unsubscribed,
|
||||
on_video_track_subscribed=self._on_video_track_subscribed,
|
||||
on_video_track_unsubscribed=self._on_video_track_unsubscribed,
|
||||
on_data_received=self._on_data_received,
|
||||
on_first_participant_joined=self._on_first_participant_joined,
|
||||
)
|
||||
@@ -855,6 +941,8 @@ class LiveKitTransport(BaseTransport):
|
||||
self._register_event_handler("on_participant_disconnected")
|
||||
self._register_event_handler("on_audio_track_subscribed")
|
||||
self._register_event_handler("on_audio_track_unsubscribed")
|
||||
self._register_event_handler("on_video_track_subscribed")
|
||||
self._register_event_handler("on_video_track_unsubscribed")
|
||||
self._register_event_handler("on_data_received")
|
||||
self._register_event_handler("on_first_participant_joined")
|
||||
self._register_event_handler("on_participant_left")
|
||||
@@ -976,6 +1064,20 @@ class LiveKitTransport(BaseTransport):
|
||||
"""Handle audio track unsubscribed events."""
|
||||
await self._call_event_handler("on_audio_track_unsubscribed", participant_id)
|
||||
|
||||
async def _on_video_track_subscribed(self, participant_id: str):
    """Notify handlers of a new video track, then re-dispatch the participant's tracks.

    Args:
        participant_id: Identifier of the participant whose video track was
            subscribed.

    First invokes the ``on_video_track_subscribed`` event handler. Then the
    participant is looked up in the room's ``remote_participants``; if it is
    found, every publication in its ``video_tracks`` is passed to the
    client's ``_on_track_subscribed_wrapper``.
    """
    await self._call_event_handler("on_video_track_subscribed", participant_id)

    # NOTE(review): the client invokes this callback with ``participant.sid``;
    # confirm that ``remote_participants`` is keyed by the same value,
    # otherwise this lookup never matches.
    remote = self._client.room.remote_participants.get(participant_id)
    if not remote:
        return

    # NOTE(review): re-dispatching through the subscription wrapper may
    # trigger this handler again — confirm this cannot loop.
    for video_publication in remote.video_tracks.values():
        self._client._on_track_subscribed_wrapper(
            video_publication.track, video_publication, remote
        )
|
||||
|
||||
async def _on_video_track_unsubscribed(self, participant_id: str):
    """Invoke the ``on_video_track_unsubscribed`` event handler.

    Args:
        participant_id: Identifier forwarded unchanged to the handler.
    """
    event_name = "on_video_track_unsubscribed"
    await self._call_event_handler(event_name, participant_id)
|
||||
|
||||
async def _on_data_received(self, data: bytes, participant_id: str):
|
||||
"""Handle data received events."""
|
||||
if self._input:
|
||||
|
||||
Reference in New Issue
Block a user