Merge pull request #1922 from pipecat-ai/aleix/output-transport-cleanup

output transports cleanup
This commit is contained in:
Aleix Conchillo Flaqué
2025-05-29 14:06:17 -07:00
committed by GitHub
21 changed files with 90 additions and 82 deletions

View File

@@ -11,6 +11,7 @@ import tkinter as tk
from dotenv import load_dotenv
from loguru import logger
from pipecat.examples.run import maybe_capture_participant_camera
from pipecat.frames.frames import (
Frame,
InputAudioRawFrame,
@@ -107,6 +108,7 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
await maybe_capture_participant_camera(transport, client, framerate=30)
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):

View File

@@ -12,7 +12,7 @@ from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_video
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_camera
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@@ -105,7 +105,7 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client)
# Set the participant ID in the image requester
client_id = get_transport_client_id(transport, client)

View File

@@ -12,7 +12,7 @@ from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_video
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_camera
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@@ -108,7 +108,7 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client)
# Set the participant ID in the image requester
client_id = get_transport_client_id(transport, client)

View File

@@ -12,7 +12,7 @@ from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_video
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_camera
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@@ -108,7 +108,7 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client)
# Set the participant ID in the image requester
client_id = get_transport_client_id(transport, client)

View File

@@ -12,7 +12,7 @@ from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_video
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_camera
from pipecat.frames.frames import Frame, TextFrame, UserImageRequestFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@@ -108,7 +108,7 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client)
# Set the participant ID in the image requester
client_id = get_transport_client_id(transport, client)

View File

@@ -14,7 +14,7 @@ from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_video
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_camera
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -181,7 +181,7 @@ If you need to use a tool, simply use the tool. Do not tell the user the tool yo
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client)
global client_id
client_id = get_transport_client_id(transport, client)

View File

@@ -14,7 +14,7 @@ from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_video
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_camera
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
@@ -164,7 +164,7 @@ indicate you should use the get_image tool are:
async def on_client_connected(transport, client):
logger.info(f"Client connected")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client)
global client_id
client_id = get_transport_client_id(transport, client)

View File

@@ -14,7 +14,7 @@ from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_video
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_camera
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@@ -174,7 +174,7 @@ indicate you should use the get_image tool are:
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client)
global client_id
client_id = get_transport_client_id(transport, client)

View File

@@ -15,7 +15,7 @@ from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_video
from pipecat.examples.run import get_transport_client_id, maybe_capture_participant_camera
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -286,7 +286,7 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
async def on_client_connected(transport, client):
logger.info(f"Client connected")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client)
global client_id
client_id = get_transport_client_id(transport, client)

View File

@@ -10,10 +10,10 @@ import os
from dotenv import load_dotenv
from loguru import logger
from run import maybe_capture_participant_video
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.examples.run import maybe_capture_participant_camera, maybe_capture_participant_screen
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -92,7 +92,8 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
await maybe_capture_participant_video(transport, client)
await maybe_capture_participant_camera(transport, client, framerate=1)
await maybe_capture_participant_screen(transport, client, framerate=1)
await task.queue_frames([context_aggregator.user().get_context_frame()])
await asyncio.sleep(3)

View File

@@ -43,11 +43,21 @@ def get_transport_client_id(transport: BaseTransport, client: Any) -> str:
return ""
async def maybe_capture_participant_video(transport: BaseTransport, client: Any):
async def maybe_capture_participant_camera(
transport: BaseTransport, client: Any, framerate: int = 0
):
if isinstance(transport, DailyTransport):
await transport.capture_participant_video(client["id"], framerate=0, video_source="camera")
await transport.capture_participant_video(
client["id"], framerate=0, video_source="screenVideo"
client["id"], framerate=framerate, video_source="camera"
)
async def maybe_capture_participant_screen(
transport: BaseTransport, client: Any, framerate: int = 0
):
if isinstance(transport, DailyTransport):
await transport.capture_participant_video(
client["id"], framerate=framerate, video_source="screenVideo"
)

View File

@@ -134,12 +134,10 @@ class BaseOutputTransport(FrameProcessor):
async def register_audio_destination(self, destination: str):
pass
async def write_raw_video_frame(
self, frame: OutputImageRawFrame, destination: Optional[str] = None
):
async def write_video_frame(self, frame: OutputImageRawFrame):
pass
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
pass
async def write_dtmf(self, frame: OutputDTMFFrame | OutputDTMFUrgentFrame):
@@ -507,7 +505,7 @@ class BaseOutputTransport(FrameProcessor):
# Send audio.
if isinstance(frame, OutputAudioRawFrame):
await self._transport.write_raw_audio_frames(frame.audio, self._destination)
await self._transport.write_audio_frame(frame)
#
# Video handling
@@ -590,8 +588,7 @@ class BaseOutputTransport(FrameProcessor):
frame = await self._transport.get_event_loop().run_in_executor(
self._executor, resize_frame, frame
)
await self._transport.write_raw_video_frame(frame, self._destination)
await self._transport.write_video_frame(frame)
#
# Clock handling

View File

@@ -10,7 +10,7 @@ from typing import Optional
from loguru import logger
from pipecat.frames.frames import InputAudioRawFrame, StartFrame
from pipecat.frames.frames import InputAudioRawFrame, OutputAudioRawFrame, StartFrame
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
@@ -122,10 +122,10 @@ class LocalAudioOutputTransport(BaseOutputTransport):
self._out_stream.close()
self._out_stream = None
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
if self._out_stream:
await self.get_event_loop().run_in_executor(
self._executor, self._out_stream.write, frames
self._executor, self._out_stream.write, frame.audio
)

View File

@@ -12,7 +12,12 @@ from typing import Optional
import numpy as np
from loguru import logger
from pipecat.frames.frames import InputAudioRawFrame, OutputImageRawFrame, StartFrame
from pipecat.frames.frames import (
InputAudioRawFrame,
OutputAudioRawFrame,
OutputImageRawFrame,
StartFrame,
)
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
@@ -135,15 +140,13 @@ class TkOutputTransport(BaseOutputTransport):
self._out_stream.close()
self._out_stream = None
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
if self._out_stream:
await self.get_event_loop().run_in_executor(
self._executor, self._out_stream.write, frames
self._executor, self._out_stream.write, frame.audio
)
async def write_raw_video_frame(
self, frame: OutputImageRawFrame, destination: Optional[str] = None
):
async def write_video_frame(self, frame: OutputImageRawFrame):
self.get_event_loop().call_soon(self._write_frame_to_tk, frame)
def _write_frame_to_tk(self, frame: OutputImageRawFrame):

View File

@@ -196,7 +196,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
self._client = client
self._params = params
# write_raw_audio_frames() is called quickly, as soon as we get audio
# write_audio_frame() is called quickly, as soon as we get audio
# (e.g. from the TTS), and since this is just a network connection we
# would be sending it to quickly. Instead, we want to block to emulate
# an audio device, this is what the send interval is. It will be
@@ -236,7 +236,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
await self._write_frame(frame)
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
if self._client.is_closing:
return
@@ -246,7 +246,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
return
frame = OutputAudioRawFrame(
audio=frames,
audio=frame.audio,
sample_rate=self.sample_rate,
num_channels=self._params.audio_out_channels,
)

View File

@@ -283,13 +283,11 @@ class SmallWebRTCClient:
)
yield audio_frame
async def write_raw_audio_frames(self, data: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
if self._can_send() and self._audio_output_track:
await self._audio_output_track.add_audio_bytes(data)
await self._audio_output_track.add_audio_bytes(frame.audio)
async def write_raw_video_frame(
self, frame: OutputImageRawFrame, destination: Optional[str] = None
):
async def write_video_frame(self, frame: OutputImageRawFrame):
if self._can_send() and self._video_output_track:
self._video_output_track.add_video_frame(frame)
@@ -499,13 +497,11 @@ class SmallWebRTCOutputTransport(BaseOutputTransport):
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
await self._client.send_message(frame)
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
await self._client.write_raw_audio_frames(frames)
async def write_audio_frame(self, frame: OutputAudioRawFrame):
await self._client.write_audio_frame(frame)
async def write_raw_video_frame(
self, frame: OutputImageRawFrame, destination: Optional[str] = None
):
await self._client.write_raw_video_frame(frame)
async def write_video_frame(self, frame: OutputImageRawFrame):
await self._client.write_video_frame(frame)
class SmallWebRTCTransport(BaseTransport):

View File

@@ -180,7 +180,7 @@ class WebsocketClientOutputTransport(BaseOutputTransport):
self._session = session
self._params = params
# write_raw_audio_frames() is called quickly, as soon as we get audio
# write_audio_frame() is called quickly, as soon as we get audio
# (e.g. from the TTS), and since this is just a network connection we
# would be sending it to quickly. Instead, we want to block to emulate
# an audio device, this is what the send interval is. It will be
@@ -215,9 +215,9 @@ class WebsocketClientOutputTransport(BaseOutputTransport):
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
await self._write_frame(frame)
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
frame = OutputAudioRawFrame(
audio=frames,
audio=frame.audio,
sample_rate=self.sample_rate,
num_channels=self._params.audio_out_channels,
)

View File

@@ -182,7 +182,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
# write_raw_audio_frames() is called quickly, as soon as we get audio
# write_audio_frame() is called quickly, as soon as we get audio
# (e.g. from the TTS), and since this is just a network connection we
# would be sending it to quickly. Instead, we want to block to emulate
# an audio device, this is what the send interval is. It will be
@@ -225,14 +225,14 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
await self._write_frame(frame)
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
if not self._websocket:
# Simulate audio playback with a sleep.
await self._write_audio_sleep()
return
frame = OutputAudioRawFrame(
audio=frames,
audio=frame.audio,
sample_rate=self.sample_rate,
num_channels=self._params.audio_out_channels,
)

View File

@@ -372,9 +372,10 @@ class DailyTransportClient(EventHandler):
self._custom_audio_tracks[destination] = await self.add_custom_audio_track(destination)
self._client.update_publishing({"customAudio": {destination: True}})
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
future = self._get_event_loop().create_future()
destination = frame.transport_destination
audio_source: Optional[CustomAudioSource] = None
if not destination and self._microphone_track:
audio_source = self._microphone_track.source
@@ -383,17 +384,15 @@ class DailyTransportClient(EventHandler):
audio_source = track.source
if audio_source:
audio_source.write_frames(frames, completion=completion_callback(future))
audio_source.write_frames(frame.audio, completion=completion_callback(future))
else:
logger.warning(f"{self} unable to write audio frames to destination [{destination}]")
future.set_result(None)
await future
async def write_raw_video_frame(
self, frame: OutputImageRawFrame, destination: Optional[str] = None
):
if not destination and self._camera:
async def write_video_frame(self, frame: OutputImageRawFrame):
if not frame.transport_destination and self._camera:
self._camera.write_frame(frame.image)
async def setup(self, setup: FrameProcessorSetup):
@@ -1230,13 +1229,11 @@ class DailyOutputTransport(BaseOutputTransport):
}
)
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
await self._client.write_raw_audio_frames(frames, destination)
async def write_audio_frame(self, frame: OutputAudioRawFrame):
await self._client.write_audio_frame(frame)
async def write_raw_video_frame(
self, frame: OutputImageRawFrame, destination: Optional[str] = None
):
await self._client.write_raw_video_frame(frame, destination)
async def write_video_frame(self, frame: OutputImageRawFrame):
await self._client.write_video_frame(frame)
class DailyTransport(BaseTransport):

View File

@@ -477,8 +477,8 @@ class LiveKitOutputTransport(BaseOutputTransport):
else:
await self._client.send_data(frame.message.encode())
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
livekit_audio = self._convert_pipecat_audio_to_livekit(frames)
async def write_audio_frame(self, frame: OutputAudioRawFrame):
livekit_audio = self._convert_pipecat_audio_to_livekit(frame.audio)
await self._client.publish_audio(livekit_audio)
def _convert_pipecat_audio_to_livekit(self, pipecat_audio: bytes) -> rtc.AudioFrame:

View File

@@ -17,6 +17,7 @@ from pipecat.frames.frames import (
EndFrame,
Frame,
InputAudioRawFrame,
OutputAudioRawFrame,
OutputImageRawFrame,
StartFrame,
StartInterruptionFrame,
@@ -290,12 +291,18 @@ class TavusTransportClient:
await self.send_message(transport_frame)
async def update_subscriptions(self, participant_settings=None, profile_settings=None):
if not self._client:
return
await self._client.update_subscriptions(
participant_settings=participant_settings, profile_settings=profile_settings
)
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
await self._client.write_raw_audio_frames(frames, destination)
async def write_audio_frame(self, frame: OutputAudioRawFrame):
if not self._client:
return
await self._client.write_audio_frame(frame)
class TavusInputTransport(BaseInputTransport):
@@ -418,26 +425,21 @@ class TavusOutputTransport(BaseOutputTransport):
async def _handle_interruptions(self):
await self._client.send_interrupt_message()
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
async def write_audio_frame(self, frame: OutputAudioRawFrame):
# Compute wait time for synchronization
wait = self._start_time + (self._samples_sent / self.sample_rate) - time.time()
if wait > 0:
logger.trace(f"TavusOutputTransport write_raw_audio_frames wait: {wait}")
logger.trace(f"TavusOutputTransport write_audio_frame wait: {wait}")
await asyncio.sleep(wait)
if self._current_idx_str is None:
logger.warning("TavusOutputTransport self._current_idx_str not defined yet!")
return
await self._client.encode_audio_and_send(frames, False, self._current_idx_str)
await self._client.encode_audio_and_send(frame.audio, False, self._current_idx_str)
# Update timestamp based on number of samples sent
self._samples_sent += len(frames) // 2 # 2 bytes per sample (16-bit)
async def write_raw_video_frame(
self, frame: OutputImageRawFrame, destination: Optional[str] = None
):
pass
self._samples_sent += len(frame.audio) // 2 # 2 bytes per sample (16-bit)
class TavusTransport(BaseTransport):