From 8ad2ad0e593c955f17f7c607d9e238599f5b5afc Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Wed, 9 Apr 2025 23:01:06 -0400 Subject: [PATCH] Add image capture to SmallWebRTCTransport --- CHANGELOG.md | 6 +- .../transports/network/small_webrtc.py | 63 +++++++++++++++---- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19da1d16b..416a0bddb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added a new iOS client option to the `SmallWebRTCTransport` **video-transform** example. +- Added support for image capture from a video stream to the + `SmallWebRTCTransport`. + +- Added a new iOS client option to the `SmallWebRTCTransport` + **video-transform** example. - Added new processors `ProducerProcessor` and `ConsumerProcessor`. The producer processor processes frames from the pipeline and decides whether the diff --git a/src/pipecat/transports/network/small_webrtc.py b/src/pipecat/transports/network/small_webrtc.py index 7feb0e77d..2eea370b4 100644 --- a/src/pipecat/transports/network/small_webrtc.py +++ b/src/pipecat/transports/network/small_webrtc.py @@ -17,13 +17,17 @@ from pydantic import BaseModel from pipecat.frames.frames import ( CancelFrame, EndFrame, + Frame, InputAudioRawFrame, InputImageRawFrame, OutputImageRawFrame, StartFrame, TransportMessageFrame, TransportMessageUrgentFrame, + UserImageRawFrame, + UserImageRequestFrame, ) +from pipecat.processors.frame_processor import FrameDirection from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams @@ -59,9 +63,7 @@ class RawAudioTrack(AudioStreamTrack): self._chunk_queue = deque() def add_audio_bytes(self, audio_bytes: bytes): - """ - Adds bytes to the audio buffer and returns a Future that completes when the data is processed. - """ + """Adds bytes to the audio buffer and returns a Future that completes when the data is processed.""" if len(audio_bytes) % self._bytes_per_10ms != 0: raise ValueError("Audio bytes must be a multiple of 10ms size.") future = asyncio.get_running_loop().create_future() @@ -76,9 +78,7 @@ class RawAudioTrack(AudioStreamTrack): return future async def recv(self): - """ - Returns the next audio frame, generating silence if needed. - """ + """Returns the next audio frame, generating silence if needed.""" # Compute required wait time for synchronization if self._timestamp > 0: wait = self._start + (self._timestamp / self._sample_rate) - time.time() @@ -179,8 +179,7 @@ class SmallWebRTCClient: await self._handle_app_message(message) def _convert_frame(self, frame_array: np.ndarray, format_name: str) -> np.ndarray: - """ - Convert a given frame to RGB format based on the input format. + """Convert a given frame to RGB format based on the input format. Args: frame_array (np.ndarray): The input frame. @@ -203,8 +202,7 @@ class SmallWebRTCClient: return cv2.cvtColor(frame_array, conversion_code) async def read_video_frame(self): - """ - Reads a video frame from the given MediaStreamTrack, converts it to RGB, + """Reads a video frame from the given MediaStreamTrack, converts it to RGB, and creates an InputImageRawFrame. """ while True: @@ -242,9 +240,7 @@ class SmallWebRTCClient: yield image_frame async def read_audio_frame(self): - """ - Reads 20ms of audio from the given MediaStreamTrack and creates an InputAudioRawFrame. - """ + """Reads 20ms of audio from the given MediaStreamTrack and creates an InputAudioRawFrame.""" while True: if self._audio_input_track is None: await asyncio.sleep(0.01) @@ -379,6 +375,13 @@ class SmallWebRTCInputTransport(BaseInputTransport): self._params = params self._receive_audio_task = None self._receive_video_task = None + self._image_requests = {} + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, UserImageRequestFrame): + await self.request_participant_image(frame) async def start(self, frame: StartFrame): await super().start(frame) @@ -424,6 +427,22 @@ class SmallWebRTCInputTransport(BaseInputTransport): if video_frame: await self.push_frame(video_frame) + # Check if there are any pending image requests and create UserImageRawFrame + if self._image_requests: + for req_id, request_frame in list(self._image_requests.items()): + # Create UserImageRawFrame using the current video frame + image_frame = UserImageRawFrame( + user_id=request_frame.user_id, + request=request_frame, + image=video_frame.image, + size=video_frame.size, + format=video_frame.format, + ) + # Push the frame to the pipeline + await self.push_frame(image_frame) + # Remove from pending requests + del self._image_requests[req_id] + except Exception as e: logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})") @@ -432,6 +451,24 @@ class SmallWebRTCInputTransport(BaseInputTransport): frame = TransportMessageUrgentFrame(message=message) await self.push_frame(frame) + # Add this method similar to DailyInputTransport.request_participant_image + async def request_participant_image(self, frame: UserImageRequestFrame): + """Requests an image frame from the participant's video stream. + + When a UserImageRequestFrame is received, this method will store the request + and the next video frame received will be converted to a UserImageRawFrame. + """ + logger.debug(f"Requesting image from participant: {frame.user_id}") + + # Store the request + request_id = f"{frame.function_name}:{frame.tool_call_id}" + self._image_requests[request_id] = frame + + # If we're not already receiving video, try to get a frame now + if not self._receive_video_task and self._params.camera_in_enabled: + # Start video reception if it's not already running + self._receive_video_task = self.create_task(self._receive_video()) + class SmallWebRTCOutputTransport(BaseOutputTransport): def __init__(