diff --git a/CHANGELOG.md b/CHANGELOG.md index 59efa9e3a..9f0b763d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- `TavusTransport` and `TavusVideoService` now send audio to Tavus using WebRTC + audio tracks instead of `app-messages` over WebSocket. This should improve the + overall audio quality. + - Upgraded `daily-python` to 0.19.3. ### Fixed diff --git a/scripts/daily/test_tavus_transport.py b/scripts/daily/test_tavus_transport.py new file mode 100644 index 000000000..fa8afe835 --- /dev/null +++ b/scripts/daily/test_tavus_transport.py @@ -0,0 +1,177 @@ +import asyncio +import os +import signal + +from daily import * +from dotenv import load_dotenv +from loguru import logger + +load_dotenv(override=True) + + +def completion_callback(future): + def _callback(*args): + def set_result(future, *args): + try: + if len(args) > 1: + future.set_result(args) + else: + future.set_result(*args) + except asyncio.InvalidStateError: + pass + + future.get_loop().call_soon_threadsafe(set_result, future, *args) + + return _callback + + +class DailyProxyApp(EventHandler): + # This is necessary to override EventHandler's __new__ method. + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + def __init__(self, sample_rate: int): + super().__init__() + self._sample_rate = sample_rate + self._loop = None + self._audio_queue: asyncio.Queue | None = None + self._audio_task: asyncio.Task | None = None + + self._client: CallClient = CallClient(event_handler=self) + self._client.update_subscription_profiles( + {"base": {"camera": "unsubscribed", "microphone": "subscribed"}} + ) + + self._audio_source = CustomAudioSource(self._sample_rate, 1) + self._audio_track = CustomAudioTrack(self._audio_source) + + def on_joined(self, data, error): + logger.debug("Local participant Joined!") + if error: + print(f"Unable to join meeting: {error}") + self._loop.call_soon_threadsafe(self._loop.stop) + + def run(self, meeting_url: str): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._create_audio_task() + + def handle_exit(): + logger.info("Ctrl+C pressed. Leaving the meeting...") + self._loop.call_soon_threadsafe(self._loop.stop) + + for sig in (signal.SIGINT, signal.SIGTERM): + self._loop.add_signal_handler(sig, handle_exit) + + self._client.set_user_name("TestTavusTransport") + self._client.join( + meeting_url, + completion=self.on_joined, + client_settings={ + "inputs": { + "microphone": { + "isEnabled": True, + "settings": {"customTrack": {"id": self._audio_track.id}}, + }, + } + }, + ) + + try: + self._loop.run_forever() + finally: + self.leave() + + def leave(self): + if self._audio_task: + self._loop.run_until_complete(self._cancel_audio_task()) + + self._client.leave() + self._client.release() + + async def update_subscriptions(self, participant_settings=None, profile_settings=None): + logger.info(f"Updating subscriptions participant_settings: {participant_settings}") + future = asyncio.get_running_loop().create_future() + self._client.update_subscriptions( + participant_settings=participant_settings, + profile_settings=profile_settings, + completion=completion_callback(future), + ) + await future + + def _create_audio_task(self): + if not self._audio_task: + self._audio_queue = asyncio.Queue() + self._audio_task = self._loop.create_task(self._audio_task_handler()) + + async def _cancel_audio_task(self): + if self._audio_task: + self._audio_task.cancel() + try: + # Waits for it to finish + await self._audio_task + except asyncio.CancelledError: + pass + self._audio_task = None + self._audio_queue = None + + async def capture_participant_audio(self, participant_id: str): + logger.info(f"Capturing participant audio: {participant_id}") + # Receiving from this custom track + # audio_source: str = "microphone" + audio_source: str = "stream" + media = {"media": {"customAudio": {audio_source: "subscribed"}}} + await self.update_subscriptions(participant_settings={participant_id: media}) + + self._client.set_audio_renderer( + participant_id, + self._audio_data_received, + audio_source=audio_source, + sample_rate=self._sample_rate, + callback_interval_ms=20, + ) + + async def send_audio(self, audio: AudioData): + future = asyncio.get_running_loop().create_future() + self._audio_source.write_frames(audio.audio_frames, completion=completion_callback(future)) + await future + + async def queue_audio(self, audio: AudioData): + await self._audio_queue.put(audio) + + def _audio_data_received(self, participant_id: str, audio_data: AudioData, audio_source: str): + # logger.info(f"Received audio data for {participant_id}, audio_source: {audio_source}") + asyncio.run_coroutine_threadsafe(self.queue_audio(audio_data), self._loop) + + async def _audio_task_handler(self): + while True: + audio = await self._audio_queue.get() + await self.send_audio(audio) + + # + # Daily (EventHandler) + # + + def on_participant_joined(self, participant): + participant_name = participant["info"]["userName"] + logger.info(f"Participant {participant_name} joined") + if participant_name != "Pipecat": + # We are only subscribing for audios from Pipecat. + return + asyncio.run_coroutine_threadsafe( + self.capture_participant_audio(participant_id=participant["id"]), self._loop + ) + + def on_participant_left(self, participant, reason): + logger.info(f"Participant {participant['id']} left {reason}") + + +def main(): + Daily.init() + room_url = os.getenv("TAVUS_SAMPLE_ROOM_URL") + app = DailyProxyApp(sample_rate=24000) + app.run(room_url) + + +if __name__ == "__main__": + main() diff --git a/scripts/fix-ruff.sh b/scripts/fix-ruff.sh index 892f6d405..6bd24300d 100755 --- a/scripts/fix-ruff.sh +++ b/scripts/fix-ruff.sh @@ -1,4 +1,5 @@ ruff format src ruff format examples ruff format tests +ruff format scripts ruff check --select I --fix \ No newline at end of file diff --git a/src/pipecat/services/tavus/video.py b/src/pipecat/services/tavus/video.py index 0b59514e7..e6c78813d 100644 --- a/src/pipecat/services/tavus/video.py +++ b/src/pipecat/services/tavus/video.py @@ -7,7 +7,6 @@ """This module implements Tavus as a sink transport layer""" import asyncio -import time from typing import Optional import aiohttp @@ -29,9 +28,6 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSet from pipecat.services.ai_service import AIService from pipecat.transports.services.tavus import TavusCallbacks, TavusParams, TavusTransportClient -# Using the same values that we do in the BaseOutputTransport -BOT_VAD_STOP_SECS = 0.35 - class TavusVideoService(AIService): """ @@ -48,7 +44,7 @@ class TavusVideoService(AIService): Args: api_key (str): Tavus API key used for authentication. replica_id (str): ID of the Tavus voice replica to use for speech synthesis. - persona_id (str): ID of the Tavus persona. Defaults to "pipecat0" to use the Pipecat TTS voice. + persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream" to use the Pipecat TTS voice. session (aiohttp.ClientSession): Async HTTP session used for communication with Tavus. **kwargs: Additional arguments passed to the parent `AIService` class. """ @@ -58,7 +54,7 @@ class TavusVideoService(AIService): *, api_key: str, replica_id: str, - persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona + persona_id: str = "pipecat-stream", session: aiohttp.ClientSession, **kwargs, ) -> None: @@ -77,6 +73,8 @@ class TavusVideoService(AIService): self._audio_buffer = bytearray() self._queue = asyncio.Queue() self._send_task: Optional[asyncio.Task] = None + # This is the custom track destination expected by Tavus + self._transport_destination: Optional[str] = "stream" async def setup(self, setup: FrameProcessorSetup): await super().setup(setup) @@ -94,6 +92,8 @@ class TavusVideoService(AIService): params=TavusParams( audio_in_enabled=True, video_in_enabled=True, + audio_out_enabled=True, + microphone_out_enabled=False, ), ) await self._client.setup(setup) @@ -152,6 +152,8 @@ class TavusVideoService(AIService): async def start(self, frame: StartFrame): await super().start(frame) await self._client.start(frame) + if self._transport_destination: + await self._client.register_audio_destination(self._transport_destination) await self._create_send_task() async def stop(self, frame: EndFrame): @@ -171,7 +173,7 @@ class TavusVideoService(AIService): await self._handle_interruptions() await self.push_frame(frame, direction) elif isinstance(frame, TTSAudioRawFrame): - await self._queue.put(frame) + await self._handle_audio_frame(frame) else: await self.push_frame(frame, direction) @@ -194,60 +196,26 @@ class TavusVideoService(AIService): await self.cancel_task(self._send_task) self._send_task = None - async def _send_task_handler(self): - # Daily app-messages have a 4kb limit and also a rate limit of 20 - # messages per second. Below, we only consider the rate limit because 1 - # second of a 24000 sample rate would be 48000 bytes (16-bit samples and - # 1 channel). So, that is 48000 / 20 = 2400, which is below the 4kb - # limit (even including base64 encoding). For a sample rate of 16000, - # that would be 32000 / 20 = 1600. + async def _handle_audio_frame(self, frame: OutputAudioRawFrame): sample_rate = self._client.out_sample_rate - # 50 ms of audio - MAX_CHUNK_SIZE = int((sample_rate * 2) / 20) - - audio_buffer = bytearray() - current_idx_str = None - silence = b"\x00" * MAX_CHUNK_SIZE - samples_sent = 0 - start_time = None + # 40 ms of audio + chunk_size = int((sample_rate * 2) / 25) + # We might need to resample if incoming audio doesn't match the + # transport sample rate. + resampled = await self._resampler.resample(frame.audio, frame.sample_rate, sample_rate) + self._audio_buffer.extend(resampled) + while len(self._audio_buffer) >= chunk_size: + chunk = OutputAudioRawFrame( + bytes(self._audio_buffer[:chunk_size]), + sample_rate=sample_rate, + num_channels=frame.num_channels, + ) + chunk.transport_destination = self._transport_destination + await self._queue.put(chunk) + self._audio_buffer = self._audio_buffer[chunk_size:] + async def _send_task_handler(self): while True: - try: - frame = await asyncio.wait_for(self._queue.get(), timeout=BOT_VAD_STOP_SECS) - if isinstance(frame, TTSAudioRawFrame): - # starting the new inference - if current_idx_str is None: - current_idx_str = str(frame.id) - samples_sent = 0 - start_time = time.time() - - audio = await self._resampler.resample( - frame.audio, frame.sample_rate, sample_rate - ) - audio_buffer.extend(audio) - while len(audio_buffer) >= MAX_CHUNK_SIZE: - chunk = audio_buffer[:MAX_CHUNK_SIZE] - audio_buffer = audio_buffer[MAX_CHUNK_SIZE:] - - # Compute wait time for synchronization - wait = start_time + (samples_sent / sample_rate) - time.time() - if wait > 0: - logger.trace(f"TavusVideoService _send_task_handler wait: {wait}") - await asyncio.sleep(wait) - - await self._client.encode_audio_and_send( - bytes(chunk), False, current_idx_str - ) - - # Update timestamp based on number of samples sent - samples_sent += len(chunk) // 2 # 2 bytes per sample (16-bit) - except asyncio.TimeoutError: - # Bot has stopped speaking - # Send any remaining audio. - if len(audio_buffer) > 0: - await self._client.encode_audio_and_send( - bytes(audio_buffer), False, current_idx_str - ) - await self._client.encode_audio_and_send(silence, True, current_idx_str) - audio_buffer.clear() - current_idx_str = None + frame = await self._queue.get() + if isinstance(frame, OutputAudioRawFrame): + await self._client.write_audio_frame(frame) diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 78bf62062..776da1693 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -767,6 +767,7 @@ class DailyTransportClient(EventHandler): self._client.add_custom_audio_track( track_name=track_name, audio_track=audio_track, + ignore_audio_level=True, completion=completion_callback(future), ) diff --git a/src/pipecat/transports/services/tavus.py b/src/pipecat/transports/services/tavus.py index 2c76b9b33..ff70416d2 100644 --- a/src/pipecat/transports/services/tavus.py +++ b/src/pipecat/transports/services/tavus.py @@ -1,6 +1,4 @@ -import asyncio -import base64 -import time +import os from functools import partial from typing import Any, Awaitable, Callable, Mapping, Optional @@ -11,8 +9,6 @@ from pydantic import BaseModel from pipecat.audio.utils import create_default_resampler from pipecat.frames.frames import ( - BotStartedSpeakingFrame, - BotStoppedSpeakingFrame, CancelFrame, EndFrame, Frame, @@ -40,6 +36,8 @@ class TavusApi: """ BASE_URL = "https://tavusapi.com/v2" + MOCK_CONVERSATION_ID = "dev-conversation" + MOCK_PERSONA_NAME = "TestTavusTransport" def __init__(self, api_key: str, session: aiohttp.ClientSession): """ @@ -52,8 +50,16 @@ class TavusApi: self._api_key = api_key self._session = session self._headers = {"Content-Type": "application/json", "x-api-key": self._api_key} + # Only for development + self._dev_room_url = os.getenv("TAVUS_SAMPLE_ROOM_URL") async def create_conversation(self, replica_id: str, persona_id: str) -> dict: + if self._dev_room_url: + return { + "conversation_id": self.MOCK_CONVERSATION_ID, + "conversation_url": self._dev_room_url, + } + logger.debug(f"Creating Tavus conversation: replica={replica_id}, persona={persona_id}") url = f"{self.BASE_URL}/conversations" payload = { @@ -67,7 +73,7 @@ class TavusApi: return response async def end_conversation(self, conversation_id: str): - if conversation_id is None: + if conversation_id is None or conversation_id == self.MOCK_CONVERSATION_ID: return url = f"{self.BASE_URL}/conversations/{conversation_id}/end" @@ -76,6 +82,9 @@ class TavusApi: logger.debug(f"Ended Tavus conversation {conversation_id}") async def get_persona_name(self, persona_id: str) -> str: + if self._dev_room_url is not None: + return self.MOCK_PERSONA_NAME + url = f"{self.BASE_URL}/personas/{persona_id}" async with self._session.get(url, headers=self._headers) as r: r.raise_for_status() @@ -119,7 +128,7 @@ class TavusTransportClient: callbacks (TavusCallbacks): Callback handlers for Tavus-related events. api_key (str): API key for authenticating with Tavus API. replica_id (str): ID of the replica to use in the Tavus conversation. - persona_id (str): ID of the Tavus persona. Defaults to "pipecat0", which signals Tavus to use + persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream", which signals Tavus to use the TTS voice of the Pipecat bot instead of a Tavus persona voice. session (aiohttp.ClientSession): The aiohttp session for making async HTTP requests. sample_rate: Audio sample rate to be used by the client. @@ -133,7 +142,7 @@ class TavusTransportClient: callbacks: TavusCallbacks, api_key: str, replica_id: str, - persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona + persona_id: str = "pipecat-stream", session: aiohttp.ClientSession, ) -> None: self._bot_name = bot_name @@ -141,7 +150,6 @@ class TavusTransportClient: self._replica_id = replica_id self._persona_id = persona_id self._conversation_id: Optional[str] = None - self._other_participant_has_joined = False self._client: Optional[DailyTransportClient] = None self._callbacks = callbacks self._params = params @@ -153,6 +161,7 @@ class TavusTransportClient: async def setup(self, setup: FrameProcessorSetup): if self._conversation_id is not None: + logger.debug(f"Conversation ID already defined: {self._conversation_id}") return try: room_url = await self._initialize() @@ -194,12 +203,13 @@ class TavusTransportClient: except Exception as e: logger.error(f"Failed to setup TavusTransportClient: {e}") await self._api.end_conversation(self._conversation_id) + self._conversation_id = None async def cleanup(self): - if self._client is None: - return - await self._client.cleanup() - self._client = None + try: + await self._client.cleanup() + except Exception as e: + logger.exception(f"Exception during cleanup: {e}") async def _on_joined(self, data): logger.debug("TavusTransportClient joined!") @@ -221,6 +231,7 @@ class TavusTransportClient: async def stop(self): await self._client.leave() await self._api.end_conversation(self._conversation_id) + self._conversation_id = None async def capture_participant_video( self, @@ -257,11 +268,6 @@ class TavusTransportClient: def in_sample_rate(self) -> int: return self._client.in_sample_rate - async def encode_audio_and_send(self, audio: bytes, done: bool, inference_id: str): - """Encodes audio to base64 and sends it to Tavus""" - audio_base64 = base64.b64encode(audio).decode("utf-8") - await self._send_audio_message(audio_base64, done=done, inference_id=inference_id) - async def send_interrupt_message(self) -> None: transport_frame = TransportMessageUrgentFrame( message={ @@ -272,23 +278,6 @@ class TavusTransportClient: ) await self.send_message(transport_frame) - async def _send_audio_message(self, audio_base64: str, done: bool, inference_id: str): - transport_frame = TransportMessageUrgentFrame( - message={ - "message_type": "conversation", - "event_type": "conversation.echo", - "conversation_id": self._conversation_id, - "properties": { - "modality": "audio", - "inference_id": inference_id, - "audio": audio_base64, - "done": done, - "sample_rate": self.out_sample_rate, - }, - } - ) - await self.send_message(transport_frame) - async def update_subscriptions(self, participant_settings=None, profile_settings=None): if not self._client: return @@ -300,9 +289,14 @@ class TavusTransportClient: async def write_audio_frame(self, frame: OutputAudioRawFrame): if not self._client: return - await self._client.write_audio_frame(frame) + async def register_audio_destination(self, destination: str): + if not self._client: + return + + await self._client.register_audio_destination(destination) + class TavusInputTransport(BaseInputTransport): def __init__( @@ -379,12 +373,11 @@ class TavusOutputTransport(BaseOutputTransport): super().__init__(params, **kwargs) self._client = client self._params = params - self._samples_sent = 0 - self._start_time = None - self._current_idx_str: Optional[str] = None # Whether we have seen a StartFrame already. self._initialized = False + # This is the custom track destination expected by Tavus + self._transport_destination: Optional[str] = "stream" async def setup(self, setup: FrameProcessorSetup): await super().setup(setup) @@ -403,6 +396,10 @@ class TavusOutputTransport(BaseOutputTransport): self._initialized = True await self._client.start(frame) + + if self._transport_destination: + await self._client.register_audio_destination(self._transport_destination) + await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): @@ -417,23 +414,6 @@ class TavusOutputTransport(BaseOutputTransport): logger.info(f"TavusOutputTransport sending message {frame}") await self._client.send_message(frame) - async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): - # The BotStartedSpeakingFrame and BotStoppedSpeakingFrame are created inside BaseOutputTransport - # so TavusOutputTransport never receives these frames. - # This is a workaround, so we can more reliably be aware when the bot has started or stopped speaking - if direction == FrameDirection.DOWNSTREAM: - if isinstance(frame, BotStartedSpeakingFrame): - if self._current_idx_str is not None: - logger.warning("TavusOutputTransport self._current_idx_str is already defined!") - self._current_idx_str = str(frame.id) - self._start_time = time.time() - self._samples_sent = 0 - elif isinstance(frame, BotStoppedSpeakingFrame): - silence = b"\x00" * self.audio_chunk_size - await self._client.encode_audio_and_send(silence, True, self._current_idx_str) - self._current_idx_str = None - await super().push_frame(frame, direction) - async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) if isinstance(frame, StartInterruptionFrame): @@ -443,20 +423,12 @@ class TavusOutputTransport(BaseOutputTransport): await self._client.send_interrupt_message() async def write_audio_frame(self, frame: OutputAudioRawFrame): - # Compute wait time for synchronization - wait = self._start_time + (self._samples_sent / self.sample_rate) - time.time() - if wait > 0: - logger.trace(f"TavusOutputTransport write_audio_frame wait: {wait}") - await asyncio.sleep(wait) + # This is the custom track destination expected by Tavus + frame.transport_destination = self._transport_destination + await self._client.write_audio_frame(frame) - if self._current_idx_str is None: - logger.warning("TavusOutputTransport self._current_idx_str not defined yet!") - return - - await self._client.encode_audio_and_send(frame.audio, False, self._current_idx_str) - - # Update timestamp based on number of samples sent - self._samples_sent += len(frame.audio) // 2 # 2 bytes per sample (16-bit) + async def register_audio_destination(self, destination: str): + await self._client.register_audio_destination(destination) class TavusTransport(BaseTransport): @@ -472,7 +444,7 @@ class TavusTransport(BaseTransport): session (aiohttp.ClientSession): aiohttp session used for async HTTP requests. api_key (str): Tavus API key for authentication. replica_id (str): ID of the replica model used for voice generation. - persona_id (str): ID of the Tavus persona. Defaults to "pipecat0" to use the Pipecat TTS voice. + persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream" to use the Pipecat TTS voice. params (TavusParams): Optional Tavus-specific configuration parameters. input_name (Optional[str]): Optional name for the input transport. output_name (Optional[str]): Optional name for the output transport. @@ -484,7 +456,7 @@ class TavusTransport(BaseTransport): session: aiohttp.ClientSession, api_key: str, replica_id: str, - persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona + persona_id: str = "pipecat-stream", params: TavusParams = TavusParams(), input_name: Optional[str] = None, output_name: Optional[str] = None, @@ -492,11 +464,6 @@ class TavusTransport(BaseTransport): super().__init__(input_name=input_name, output_name=output_name) self._params = params - # TODO: Filipi - We can remove this if we stop sending the audio through app messages - # Limiting this so we don't go over 20 messages per second - # each message is going to have 50ms of audio - self._params.audio_out_10ms_chunks = 5 - callbacks = TavusCallbacks( on_participant_joined=self._on_participant_joined, on_participant_left=self._on_participant_left, @@ -527,6 +494,7 @@ class TavusTransport(BaseTransport): async def _on_participant_joined(self, participant): # get persona, look up persona_name, set this as the bot name to ignore persona_name = await self._client.get_persona_name() + # Ignore the Tavus replica's microphone if participant.get("info", {}).get("userName", "") == persona_name: self._tavus_participant_id = participant["id"]