Merge pull request #1950 from pipecat-ai/filipi/tavus_custom_tracks

Sending audio to Tavus using custom tracks
This commit is contained in:
Filipi da Silva Fuchter
2025-06-18 07:57:19 -03:00
committed by GitHub
6 changed files with 256 additions and 137 deletions

View File

@@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- `TavusTransport` and `TavusVideoService` now send audio to Tavus using WebRTC
audio tracks instead of `app-messages` over WebSocket. This should improve the
overall audio quality.
- Upgraded `daily-python` to 0.19.3.
### Fixed

View File

@@ -0,0 +1,177 @@
import asyncio
import os
import signal
from daily import *
from dotenv import load_dotenv
from loguru import logger
load_dotenv(override=True)
def completion_callback(future):
    """Build a completion callback that resolves ``future`` on its own loop.

    The returned callable can be handed to APIs that report completion by
    invoking a callback (it is used below for ``update_subscriptions`` and
    ``write_frames``). The result is delivered with ``call_soon_threadsafe``
    so the callback may be invoked from a thread other than the one running
    the future's event loop.

    Args:
        future: The ``asyncio.Future`` to resolve.

    Returns:
        A callable accepting any number of positional arguments. The future
        is resolved with ``None`` for no arguments, the single argument for
        one argument, or a tuple of all arguments otherwise.
    """

    def _callback(*args):
        def set_result(future, *args):
            try:
                if not args:
                    # Calling set_result() with no value raises TypeError and
                    # would leave the awaiting coroutine hanging forever.
                    future.set_result(None)
                elif len(args) == 1:
                    future.set_result(args[0])
                else:
                    future.set_result(args)
            except asyncio.InvalidStateError:
                # The future was already resolved or cancelled; nothing to do.
                pass

        future.get_loop().call_soon_threadsafe(set_result, future, *args)

    return _callback
class DailyProxyApp(EventHandler):
    """Test participant that echoes Pipecat's custom audio track back to the room.

    The app joins a Daily room as "TestTavusTransport", subscribes to the
    "stream" custom audio track of the participant named "Pipecat", and writes
    every received audio buffer to its own custom audio track, which is
    published as this participant's microphone.
    """

    # This is necessary to override EventHandler's __new__ method.
    def __new__(cls, *args, **kwargs):
        return super().__new__(cls)

    def __init__(self, sample_rate: int):
        """Create the Daily client and the outgoing custom audio track.

        Args:
            sample_rate: Sample rate used both for the outgoing audio source
                and for the renderer that receives Pipecat's audio.
        """
        super().__init__()
        self._sample_rate = sample_rate
        self._loop = None
        self._audio_queue: asyncio.Queue | None = None
        self._audio_task: asyncio.Task | None = None

        self._client: CallClient = CallClient(event_handler=self)
        self._client.update_subscription_profiles(
            {"base": {"camera": "unsubscribed", "microphone": "subscribed"}}
        )

        # NOTE(review): the second argument is presumably the channel count
        # (mono) - confirm against the daily-python API.
        self._audio_source = CustomAudioSource(self._sample_rate, 1)
        self._audio_track = CustomAudioTrack(self._audio_source)

    def on_joined(self, data, error):
        """Handle the completion of ``CallClient.join``.

        Stops the event loop when the join failed so that ``run`` returns.
        """
        if error:
            logger.error(f"Unable to join meeting: {error}")
            self._loop.call_soon_threadsafe(self._loop.stop)
            return
        logger.debug("Local participant Joined!")

    def run(self, meeting_url: str):
        """Join ``meeting_url`` and block until the event loop is stopped.

        The loop is stopped by SIGINT/SIGTERM or by a failed join; the meeting
        is always left afterwards.
        """
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._create_audio_task()

        def handle_exit():
            logger.info("Ctrl+C pressed. Leaving the meeting...")
            self._loop.call_soon_threadsafe(self._loop.stop)

        for sig in (signal.SIGINT, signal.SIGTERM):
            self._loop.add_signal_handler(sig, handle_exit)

        self._client.set_user_name("TestTavusTransport")
        self._client.join(
            meeting_url,
            completion=self.on_joined,
            client_settings={
                "inputs": {
                    "microphone": {
                        "isEnabled": True,
                        "settings": {"customTrack": {"id": self._audio_track.id}},
                    },
                }
            },
        )

        try:
            self._loop.run_forever()
        finally:
            self.leave()

    def leave(self):
        """Cancel the audio forwarding task, leave the meeting and release the client."""
        if self._audio_task:
            self._loop.run_until_complete(self._cancel_audio_task())

        self._client.leave()
        self._client.release()

    async def update_subscriptions(self, participant_settings=None, profile_settings=None):
        """Update track subscriptions and wait for Daily to report completion."""
        logger.info(f"Updating subscriptions participant_settings: {participant_settings}")
        future = asyncio.get_running_loop().create_future()
        self._client.update_subscriptions(
            participant_settings=participant_settings,
            profile_settings=profile_settings,
            completion=completion_callback(future),
        )
        await future

    def _create_audio_task(self):
        """Create the audio queue and the task that drains it, if not already running."""
        if not self._audio_task:
            self._audio_queue = asyncio.Queue()
            self._audio_task = self._loop.create_task(self._audio_task_handler())

    async def _cancel_audio_task(self):
        """Cancel the audio forwarding task and drop the queue."""
        if self._audio_task:
            self._audio_task.cancel()
            try:
                # Waits for it to finish
                await self._audio_task
            except asyncio.CancelledError:
                pass
            self._audio_task = None
            self._audio_queue = None

    async def capture_participant_audio(self, participant_id: str):
        """Subscribe to a participant's custom audio track and start rendering it.

        Args:
            participant_id: ID of the participant whose audio is captured.
        """
        logger.info(f"Capturing participant audio: {participant_id}")
        # The audio arrives on the custom track named "stream" rather than on
        # the participant's regular "microphone" track.
        audio_source: str = "stream"

        media = {"media": {"customAudio": {audio_source: "subscribed"}}}
        await self.update_subscriptions(participant_settings={participant_id: media})

        self._client.set_audio_renderer(
            participant_id,
            self._audio_data_received,
            audio_source=audio_source,
            sample_rate=self._sample_rate,
            callback_interval_ms=20,
        )

    async def send_audio(self, audio: AudioData):
        """Write one audio buffer to the outgoing track and wait for completion."""
        future = asyncio.get_running_loop().create_future()
        self._audio_source.write_frames(audio.audio_frames, completion=completion_callback(future))
        await future

    async def queue_audio(self, audio: AudioData):
        """Queue a received audio buffer for the forwarding task."""
        if self._audio_queue is None:
            # The forwarding task was cancelled; drop audio that arrives late
            # instead of failing on the missing queue.
            return
        await self._audio_queue.put(audio)

    def _audio_data_received(self, participant_id: str, audio_data: AudioData, audio_source: str):
        """Audio renderer callback: hand the buffer over to the event loop.

        Scheduled with ``run_coroutine_threadsafe`` because this is presumably
        invoked from a Daily thread rather than from the loop's thread.
        """
        asyncio.run_coroutine_threadsafe(self.queue_audio(audio_data), self._loop)

    async def _audio_task_handler(self):
        """Forward queued audio buffers to the outgoing track, one at a time."""
        while True:
            audio = await self._audio_queue.get()
            await self.send_audio(audio)

    #
    # Daily (EventHandler)
    #

    def on_participant_joined(self, participant):
        """Start capturing audio when the participant named "Pipecat" joins."""
        # Tolerate participants without an "info"/"userName" entry instead of
        # raising KeyError inside the Daily event handler.
        participant_name = participant.get("info", {}).get("userName", "")
        logger.info(f"Participant {participant_name} joined")
        if participant_name != "Pipecat":
            # We only subscribe to audio coming from Pipecat.
            return

        asyncio.run_coroutine_threadsafe(
            self.capture_participant_audio(participant_id=participant["id"]), self._loop
        )

    def on_participant_left(self, participant, reason):
        """Log that a participant left the meeting."""
        logger.info(f"Participant {participant['id']} left {reason}")
def main():
    """Run the proxy app against the room given by ``TAVUS_SAMPLE_ROOM_URL``.

    Raises:
        SystemExit: If ``TAVUS_SAMPLE_ROOM_URL`` is unset or empty.
    """
    room_url = os.getenv("TAVUS_SAMPLE_ROOM_URL")
    if not room_url:
        # Fail early with a clear message instead of passing None to join().
        raise SystemExit("TAVUS_SAMPLE_ROOM_URL environment variable is not set")

    Daily.init()
    app = DailyProxyApp(sample_rate=24000)
    app.run(room_url)


if __name__ == "__main__":
    main()

View File

@@ -1,4 +1,5 @@
ruff format src
ruff format examples
ruff format tests
ruff format scripts
ruff check --select I --fix

View File

@@ -7,7 +7,6 @@
"""This module implements Tavus as a sink transport layer"""
import asyncio
import time
from typing import Optional
import aiohttp
@@ -29,9 +28,6 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSet
from pipecat.services.ai_service import AIService
from pipecat.transports.services.tavus import TavusCallbacks, TavusParams, TavusTransportClient
# Using the same values that we do in the BaseOutputTransport
BOT_VAD_STOP_SECS = 0.35
class TavusVideoService(AIService):
"""
@@ -48,7 +44,7 @@ class TavusVideoService(AIService):
Args:
api_key (str): Tavus API key used for authentication.
replica_id (str): ID of the Tavus voice replica to use for speech synthesis.
persona_id (str): ID of the Tavus persona. Defaults to "pipecat0" to use the Pipecat TTS voice.
persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream" to use the Pipecat TTS voice.
session (aiohttp.ClientSession): Async HTTP session used for communication with Tavus.
**kwargs: Additional arguments passed to the parent `AIService` class.
"""
@@ -58,7 +54,7 @@ class TavusVideoService(AIService):
*,
api_key: str,
replica_id: str,
persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
persona_id: str = "pipecat-stream",
session: aiohttp.ClientSession,
**kwargs,
) -> None:
@@ -77,6 +73,8 @@ class TavusVideoService(AIService):
self._audio_buffer = bytearray()
self._queue = asyncio.Queue()
self._send_task: Optional[asyncio.Task] = None
# This is the custom track destination expected by Tavus
self._transport_destination: Optional[str] = "stream"
async def setup(self, setup: FrameProcessorSetup):
await super().setup(setup)
@@ -94,6 +92,8 @@ class TavusVideoService(AIService):
params=TavusParams(
audio_in_enabled=True,
video_in_enabled=True,
audio_out_enabled=True,
microphone_out_enabled=False,
),
)
await self._client.setup(setup)
@@ -152,6 +152,8 @@ class TavusVideoService(AIService):
async def start(self, frame: StartFrame):
await super().start(frame)
await self._client.start(frame)
if self._transport_destination:
await self._client.register_audio_destination(self._transport_destination)
await self._create_send_task()
async def stop(self, frame: EndFrame):
@@ -171,7 +173,7 @@ class TavusVideoService(AIService):
await self._handle_interruptions()
await self.push_frame(frame, direction)
elif isinstance(frame, TTSAudioRawFrame):
await self._queue.put(frame)
await self._handle_audio_frame(frame)
else:
await self.push_frame(frame, direction)
@@ -194,60 +196,26 @@ class TavusVideoService(AIService):
await self.cancel_task(self._send_task)
self._send_task = None
async def _send_task_handler(self):
# Daily app-messages have a 4kb limit and also a rate limit of 20
# messages per second. Below, we only consider the rate limit because 1
# second of a 24000 sample rate would be 48000 bytes (16-bit samples and
# 1 channel). So, that is 48000 / 20 = 2400, which is below the 4kb
# limit (even including base64 encoding). For a sample rate of 16000,
# that would be 32000 / 20 = 1600.
async def _handle_audio_frame(self, frame: OutputAudioRawFrame):
sample_rate = self._client.out_sample_rate
# 50 ms of audio
MAX_CHUNK_SIZE = int((sample_rate * 2) / 20)
audio_buffer = bytearray()
current_idx_str = None
silence = b"\x00" * MAX_CHUNK_SIZE
samples_sent = 0
start_time = None
# 40 ms of audio
chunk_size = int((sample_rate * 2) / 25)
# We might need to resample if incoming audio doesn't match the
# transport sample rate.
resampled = await self._resampler.resample(frame.audio, frame.sample_rate, sample_rate)
self._audio_buffer.extend(resampled)
while len(self._audio_buffer) >= chunk_size:
chunk = OutputAudioRawFrame(
bytes(self._audio_buffer[:chunk_size]),
sample_rate=sample_rate,
num_channels=frame.num_channels,
)
chunk.transport_destination = self._transport_destination
await self._queue.put(chunk)
self._audio_buffer = self._audio_buffer[chunk_size:]
async def _send_task_handler(self):
while True:
try:
frame = await asyncio.wait_for(self._queue.get(), timeout=BOT_VAD_STOP_SECS)
if isinstance(frame, TTSAudioRawFrame):
# starting the new inference
if current_idx_str is None:
current_idx_str = str(frame.id)
samples_sent = 0
start_time = time.time()
audio = await self._resampler.resample(
frame.audio, frame.sample_rate, sample_rate
)
audio_buffer.extend(audio)
while len(audio_buffer) >= MAX_CHUNK_SIZE:
chunk = audio_buffer[:MAX_CHUNK_SIZE]
audio_buffer = audio_buffer[MAX_CHUNK_SIZE:]
# Compute wait time for synchronization
wait = start_time + (samples_sent / sample_rate) - time.time()
if wait > 0:
logger.trace(f"TavusVideoService _send_task_handler wait: {wait}")
await asyncio.sleep(wait)
await self._client.encode_audio_and_send(
bytes(chunk), False, current_idx_str
)
# Update timestamp based on number of samples sent
samples_sent += len(chunk) // 2 # 2 bytes per sample (16-bit)
except asyncio.TimeoutError:
# Bot has stopped speaking
# Send any remaining audio.
if len(audio_buffer) > 0:
await self._client.encode_audio_and_send(
bytes(audio_buffer), False, current_idx_str
)
await self._client.encode_audio_and_send(silence, True, current_idx_str)
audio_buffer.clear()
current_idx_str = None
frame = await self._queue.get()
if isinstance(frame, OutputAudioRawFrame):
await self._client.write_audio_frame(frame)

View File

@@ -767,6 +767,7 @@ class DailyTransportClient(EventHandler):
self._client.add_custom_audio_track(
track_name=track_name,
audio_track=audio_track,
ignore_audio_level=True,
completion=completion_callback(future),
)

View File

@@ -1,6 +1,4 @@
import asyncio
import base64
import time
import os
from functools import partial
from typing import Any, Awaitable, Callable, Mapping, Optional
@@ -11,8 +9,6 @@ from pydantic import BaseModel
from pipecat.audio.utils import create_default_resampler
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
Frame,
@@ -40,6 +36,8 @@ class TavusApi:
"""
BASE_URL = "https://tavusapi.com/v2"
MOCK_CONVERSATION_ID = "dev-conversation"
MOCK_PERSONA_NAME = "TestTavusTransport"
def __init__(self, api_key: str, session: aiohttp.ClientSession):
"""
@@ -52,8 +50,16 @@ class TavusApi:
self._api_key = api_key
self._session = session
self._headers = {"Content-Type": "application/json", "x-api-key": self._api_key}
# Only for development
self._dev_room_url = os.getenv("TAVUS_SAMPLE_ROOM_URL")
async def create_conversation(self, replica_id: str, persona_id: str) -> dict:
if self._dev_room_url:
return {
"conversation_id": self.MOCK_CONVERSATION_ID,
"conversation_url": self._dev_room_url,
}
logger.debug(f"Creating Tavus conversation: replica={replica_id}, persona={persona_id}")
url = f"{self.BASE_URL}/conversations"
payload = {
@@ -67,7 +73,7 @@ class TavusApi:
return response
async def end_conversation(self, conversation_id: str):
if conversation_id is None:
if conversation_id is None or conversation_id == self.MOCK_CONVERSATION_ID:
return
url = f"{self.BASE_URL}/conversations/{conversation_id}/end"
@@ -76,6 +82,9 @@ class TavusApi:
logger.debug(f"Ended Tavus conversation {conversation_id}")
async def get_persona_name(self, persona_id: str) -> str:
if self._dev_room_url is not None:
return self.MOCK_PERSONA_NAME
url = f"{self.BASE_URL}/personas/{persona_id}"
async with self._session.get(url, headers=self._headers) as r:
r.raise_for_status()
@@ -119,7 +128,7 @@ class TavusTransportClient:
callbacks (TavusCallbacks): Callback handlers for Tavus-related events.
api_key (str): API key for authenticating with Tavus API.
replica_id (str): ID of the replica to use in the Tavus conversation.
persona_id (str): ID of the Tavus persona. Defaults to "pipecat0", which signals Tavus to use
persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream", which signals Tavus to use
the TTS voice of the Pipecat bot instead of a Tavus persona voice.
session (aiohttp.ClientSession): The aiohttp session for making async HTTP requests.
sample_rate: Audio sample rate to be used by the client.
@@ -133,7 +142,7 @@ class TavusTransportClient:
callbacks: TavusCallbacks,
api_key: str,
replica_id: str,
persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
persona_id: str = "pipecat-stream",
session: aiohttp.ClientSession,
) -> None:
self._bot_name = bot_name
@@ -141,7 +150,6 @@ class TavusTransportClient:
self._replica_id = replica_id
self._persona_id = persona_id
self._conversation_id: Optional[str] = None
self._other_participant_has_joined = False
self._client: Optional[DailyTransportClient] = None
self._callbacks = callbacks
self._params = params
@@ -153,6 +161,7 @@ class TavusTransportClient:
async def setup(self, setup: FrameProcessorSetup):
if self._conversation_id is not None:
logger.debug(f"Conversation ID already defined: {self._conversation_id}")
return
try:
room_url = await self._initialize()
@@ -194,12 +203,13 @@ class TavusTransportClient:
except Exception as e:
logger.error(f"Failed to setup TavusTransportClient: {e}")
await self._api.end_conversation(self._conversation_id)
self._conversation_id = None
async def cleanup(self):
if self._client is None:
return
await self._client.cleanup()
self._client = None
try:
await self._client.cleanup()
except Exception as e:
logger.exception(f"Exception during cleanup: {e}")
async def _on_joined(self, data):
logger.debug("TavusTransportClient joined!")
@@ -221,6 +231,7 @@ class TavusTransportClient:
async def stop(self):
await self._client.leave()
await self._api.end_conversation(self._conversation_id)
self._conversation_id = None
async def capture_participant_video(
self,
@@ -257,11 +268,6 @@ class TavusTransportClient:
def in_sample_rate(self) -> int:
return self._client.in_sample_rate
async def encode_audio_and_send(self, audio: bytes, done: bool, inference_id: str):
"""Encodes audio to base64 and sends it to Tavus"""
audio_base64 = base64.b64encode(audio).decode("utf-8")
await self._send_audio_message(audio_base64, done=done, inference_id=inference_id)
async def send_interrupt_message(self) -> None:
transport_frame = TransportMessageUrgentFrame(
message={
@@ -272,23 +278,6 @@ class TavusTransportClient:
)
await self.send_message(transport_frame)
async def _send_audio_message(self, audio_base64: str, done: bool, inference_id: str):
transport_frame = TransportMessageUrgentFrame(
message={
"message_type": "conversation",
"event_type": "conversation.echo",
"conversation_id": self._conversation_id,
"properties": {
"modality": "audio",
"inference_id": inference_id,
"audio": audio_base64,
"done": done,
"sample_rate": self.out_sample_rate,
},
}
)
await self.send_message(transport_frame)
async def update_subscriptions(self, participant_settings=None, profile_settings=None):
if not self._client:
return
@@ -300,9 +289,14 @@ class TavusTransportClient:
async def write_audio_frame(self, frame: OutputAudioRawFrame):
if not self._client:
return
await self._client.write_audio_frame(frame)
async def register_audio_destination(self, destination: str):
if not self._client:
return
await self._client.register_audio_destination(destination)
class TavusInputTransport(BaseInputTransport):
def __init__(
@@ -379,12 +373,11 @@ class TavusOutputTransport(BaseOutputTransport):
super().__init__(params, **kwargs)
self._client = client
self._params = params
self._samples_sent = 0
self._start_time = None
self._current_idx_str: Optional[str] = None
# Whether we have seen a StartFrame already.
self._initialized = False
# This is the custom track destination expected by Tavus
self._transport_destination: Optional[str] = "stream"
async def setup(self, setup: FrameProcessorSetup):
await super().setup(setup)
@@ -403,6 +396,10 @@ class TavusOutputTransport(BaseOutputTransport):
self._initialized = True
await self._client.start(frame)
if self._transport_destination:
await self._client.register_audio_destination(self._transport_destination)
await self.set_transport_ready(frame)
async def stop(self, frame: EndFrame):
@@ -417,23 +414,6 @@ class TavusOutputTransport(BaseOutputTransport):
logger.info(f"TavusOutputTransport sending message {frame}")
await self._client.send_message(frame)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
# The BotStartedSpeakingFrame and BotStoppedSpeakingFrame are created inside BaseOutputTransport
# so TavusOutputTransport never receives these frames.
# This is a workaround, so we can more reliably be aware when the bot has started or stopped speaking
if direction == FrameDirection.DOWNSTREAM:
if isinstance(frame, BotStartedSpeakingFrame):
if self._current_idx_str is not None:
logger.warning("TavusOutputTransport self._current_idx_str is already defined!")
self._current_idx_str = str(frame.id)
self._start_time = time.time()
self._samples_sent = 0
elif isinstance(frame, BotStoppedSpeakingFrame):
silence = b"\x00" * self.audio_chunk_size
await self._client.encode_audio_and_send(silence, True, self._current_idx_str)
self._current_idx_str = None
await super().push_frame(frame, direction)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
@@ -443,20 +423,12 @@ class TavusOutputTransport(BaseOutputTransport):
await self._client.send_interrupt_message()
async def write_audio_frame(self, frame: OutputAudioRawFrame):
# Compute wait time for synchronization
wait = self._start_time + (self._samples_sent / self.sample_rate) - time.time()
if wait > 0:
logger.trace(f"TavusOutputTransport write_audio_frame wait: {wait}")
await asyncio.sleep(wait)
# This is the custom track destination expected by Tavus
frame.transport_destination = self._transport_destination
await self._client.write_audio_frame(frame)
if self._current_idx_str is None:
logger.warning("TavusOutputTransport self._current_idx_str not defined yet!")
return
await self._client.encode_audio_and_send(frame.audio, False, self._current_idx_str)
# Update timestamp based on number of samples sent
self._samples_sent += len(frame.audio) // 2 # 2 bytes per sample (16-bit)
async def register_audio_destination(self, destination: str):
await self._client.register_audio_destination(destination)
class TavusTransport(BaseTransport):
@@ -472,7 +444,7 @@ class TavusTransport(BaseTransport):
session (aiohttp.ClientSession): aiohttp session used for async HTTP requests.
api_key (str): Tavus API key for authentication.
replica_id (str): ID of the replica model used for voice generation.
persona_id (str): ID of the Tavus persona. Defaults to "pipecat0" to use the Pipecat TTS voice.
persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream" to use the Pipecat TTS voice.
params (TavusParams): Optional Tavus-specific configuration parameters.
input_name (Optional[str]): Optional name for the input transport.
output_name (Optional[str]): Optional name for the output transport.
@@ -484,7 +456,7 @@ class TavusTransport(BaseTransport):
session: aiohttp.ClientSession,
api_key: str,
replica_id: str,
persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
persona_id: str = "pipecat-stream",
params: TavusParams = TavusParams(),
input_name: Optional[str] = None,
output_name: Optional[str] = None,
@@ -492,11 +464,6 @@ class TavusTransport(BaseTransport):
super().__init__(input_name=input_name, output_name=output_name)
self._params = params
# TODO: Filipi - We can remove this if we stop sending the audio through app messages
# Limiting this so we don't go over 20 messages per second
# each message is going to have 50ms of audio
self._params.audio_out_10ms_chunks = 5
callbacks = TavusCallbacks(
on_participant_joined=self._on_participant_joined,
on_participant_left=self._on_participant_left,
@@ -527,6 +494,7 @@ class TavusTransport(BaseTransport):
async def _on_participant_joined(self, participant):
# get persona, look up persona_name, set this as the bot name to ignore
persona_name = await self._client.get_persona_name()
# Ignore the Tavus replica's microphone
if participant.get("info", {}).get("userName", "") == persona_name:
self._tavus_participant_id = participant["id"]