Merge pull request #1950 from pipecat-ai/filipi/tavus_custom_tracks
Sending audio to Tavus using custom tracks
This commit is contained in:
@@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- `TavusTransport` and `TavusVideoService` now send audio to Tavus using WebRTC
|
||||
audio tracks instead of `app-messages` over WebSocket. This should improve the
|
||||
overall audio quality.
|
||||
|
||||
- Upgraded `daily-python` to 0.19.3.
|
||||
|
||||
### Fixed
|
||||
|
||||
177
scripts/daily/test_tavus_transport.py
Normal file
177
scripts/daily/test_tavus_transport.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
|
||||
from daily import *
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
def completion_callback(future):
|
||||
def _callback(*args):
|
||||
def set_result(future, *args):
|
||||
try:
|
||||
if len(args) > 1:
|
||||
future.set_result(args)
|
||||
else:
|
||||
future.set_result(*args)
|
||||
except asyncio.InvalidStateError:
|
||||
pass
|
||||
|
||||
future.get_loop().call_soon_threadsafe(set_result, future, *args)
|
||||
|
||||
return _callback
|
||||
|
||||
|
||||
class DailyProxyApp(EventHandler):
|
||||
# This is necessary to override EventHandler's __new__ method.
|
||||
def __new__(cls, *args, **kwargs):
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, sample_rate: int):
|
||||
super().__init__()
|
||||
self._sample_rate = sample_rate
|
||||
self._loop = None
|
||||
self._audio_queue: asyncio.Queue | None = None
|
||||
self._audio_task: asyncio.Task | None = None
|
||||
|
||||
self._client: CallClient = CallClient(event_handler=self)
|
||||
self._client.update_subscription_profiles(
|
||||
{"base": {"camera": "unsubscribed", "microphone": "subscribed"}}
|
||||
)
|
||||
|
||||
self._audio_source = CustomAudioSource(self._sample_rate, 1)
|
||||
self._audio_track = CustomAudioTrack(self._audio_source)
|
||||
|
||||
def on_joined(self, data, error):
|
||||
logger.debug("Local participant Joined!")
|
||||
if error:
|
||||
print(f"Unable to join meeting: {error}")
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
|
||||
def run(self, meeting_url: str):
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._create_audio_task()
|
||||
|
||||
def handle_exit():
|
||||
logger.info("Ctrl+C pressed. Leaving the meeting...")
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
self._loop.add_signal_handler(sig, handle_exit)
|
||||
|
||||
self._client.set_user_name("TestTavusTransport")
|
||||
self._client.join(
|
||||
meeting_url,
|
||||
completion=self.on_joined,
|
||||
client_settings={
|
||||
"inputs": {
|
||||
"microphone": {
|
||||
"isEnabled": True,
|
||||
"settings": {"customTrack": {"id": self._audio_track.id}},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
self._loop.run_forever()
|
||||
finally:
|
||||
self.leave()
|
||||
|
||||
def leave(self):
|
||||
if self._audio_task:
|
||||
self._loop.run_until_complete(self._cancel_audio_task())
|
||||
|
||||
self._client.leave()
|
||||
self._client.release()
|
||||
|
||||
async def update_subscriptions(self, participant_settings=None, profile_settings=None):
|
||||
logger.info(f"Updating subscriptions participant_settings: {participant_settings}")
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self._client.update_subscriptions(
|
||||
participant_settings=participant_settings,
|
||||
profile_settings=profile_settings,
|
||||
completion=completion_callback(future),
|
||||
)
|
||||
await future
|
||||
|
||||
def _create_audio_task(self):
|
||||
if not self._audio_task:
|
||||
self._audio_queue = asyncio.Queue()
|
||||
self._audio_task = self._loop.create_task(self._audio_task_handler())
|
||||
|
||||
async def _cancel_audio_task(self):
|
||||
if self._audio_task:
|
||||
self._audio_task.cancel()
|
||||
try:
|
||||
# Waits for it to finish
|
||||
await self._audio_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._audio_task = None
|
||||
self._audio_queue = None
|
||||
|
||||
async def capture_participant_audio(self, participant_id: str):
|
||||
logger.info(f"Capturing participant audio: {participant_id}")
|
||||
# Receiving from this custom track
|
||||
# audio_source: str = "microphone"
|
||||
audio_source: str = "stream"
|
||||
media = {"media": {"customAudio": {audio_source: "subscribed"}}}
|
||||
await self.update_subscriptions(participant_settings={participant_id: media})
|
||||
|
||||
self._client.set_audio_renderer(
|
||||
participant_id,
|
||||
self._audio_data_received,
|
||||
audio_source=audio_source,
|
||||
sample_rate=self._sample_rate,
|
||||
callback_interval_ms=20,
|
||||
)
|
||||
|
||||
async def send_audio(self, audio: AudioData):
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self._audio_source.write_frames(audio.audio_frames, completion=completion_callback(future))
|
||||
await future
|
||||
|
||||
async def queue_audio(self, audio: AudioData):
|
||||
await self._audio_queue.put(audio)
|
||||
|
||||
def _audio_data_received(self, participant_id: str, audio_data: AudioData, audio_source: str):
|
||||
# logger.info(f"Received audio data for {participant_id}, audio_source: {audio_source}")
|
||||
asyncio.run_coroutine_threadsafe(self.queue_audio(audio_data), self._loop)
|
||||
|
||||
async def _audio_task_handler(self):
|
||||
while True:
|
||||
audio = await self._audio_queue.get()
|
||||
await self.send_audio(audio)
|
||||
|
||||
#
|
||||
# Daily (EventHandler)
|
||||
#
|
||||
|
||||
def on_participant_joined(self, participant):
|
||||
participant_name = participant["info"]["userName"]
|
||||
logger.info(f"Participant {participant_name} joined")
|
||||
if participant_name != "Pipecat":
|
||||
# We are only subscribing for audios from Pipecat.
|
||||
return
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.capture_participant_audio(participant_id=participant["id"]), self._loop
|
||||
)
|
||||
|
||||
def on_participant_left(self, participant, reason):
|
||||
logger.info(f"Participant {participant['id']} left {reason}")
|
||||
|
||||
|
||||
def main():
|
||||
Daily.init()
|
||||
room_url = os.getenv("TAVUS_SAMPLE_ROOM_URL")
|
||||
app = DailyProxyApp(sample_rate=24000)
|
||||
app.run(room_url)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,5 @@
|
||||
ruff format src
|
||||
ruff format examples
|
||||
ruff format tests
|
||||
ruff format scripts
|
||||
ruff check --select I --fix
|
||||
@@ -7,7 +7,6 @@
|
||||
"""This module implements Tavus as a sink transport layer"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -29,9 +28,6 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSet
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.transports.services.tavus import TavusCallbacks, TavusParams, TavusTransportClient
|
||||
|
||||
# Using the same values that we do in the BaseOutputTransport
|
||||
BOT_VAD_STOP_SECS = 0.35
|
||||
|
||||
|
||||
class TavusVideoService(AIService):
|
||||
"""
|
||||
@@ -48,7 +44,7 @@ class TavusVideoService(AIService):
|
||||
Args:
|
||||
api_key (str): Tavus API key used for authentication.
|
||||
replica_id (str): ID of the Tavus voice replica to use for speech synthesis.
|
||||
persona_id (str): ID of the Tavus persona. Defaults to "pipecat0" to use the Pipecat TTS voice.
|
||||
persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream" to use the Pipecat TTS voice.
|
||||
session (aiohttp.ClientSession): Async HTTP session used for communication with Tavus.
|
||||
**kwargs: Additional arguments passed to the parent `AIService` class.
|
||||
"""
|
||||
@@ -58,7 +54,7 @@ class TavusVideoService(AIService):
|
||||
*,
|
||||
api_key: str,
|
||||
replica_id: str,
|
||||
persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
|
||||
persona_id: str = "pipecat-stream",
|
||||
session: aiohttp.ClientSession,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -77,6 +73,8 @@ class TavusVideoService(AIService):
|
||||
self._audio_buffer = bytearray()
|
||||
self._queue = asyncio.Queue()
|
||||
self._send_task: Optional[asyncio.Task] = None
|
||||
# This is the custom track destination expected by Tavus
|
||||
self._transport_destination: Optional[str] = "stream"
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
await super().setup(setup)
|
||||
@@ -94,6 +92,8 @@ class TavusVideoService(AIService):
|
||||
params=TavusParams(
|
||||
audio_in_enabled=True,
|
||||
video_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
microphone_out_enabled=False,
|
||||
),
|
||||
)
|
||||
await self._client.setup(setup)
|
||||
@@ -152,6 +152,8 @@ class TavusVideoService(AIService):
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._client.start(frame)
|
||||
if self._transport_destination:
|
||||
await self._client.register_audio_destination(self._transport_destination)
|
||||
await self._create_send_task()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
@@ -171,7 +173,7 @@ class TavusVideoService(AIService):
|
||||
await self._handle_interruptions()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSAudioRawFrame):
|
||||
await self._queue.put(frame)
|
||||
await self._handle_audio_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -194,60 +196,26 @@ class TavusVideoService(AIService):
|
||||
await self.cancel_task(self._send_task)
|
||||
self._send_task = None
|
||||
|
||||
async def _send_task_handler(self):
|
||||
# Daily app-messages have a 4kb limit and also a rate limit of 20
|
||||
# messages per second. Below, we only consider the rate limit because 1
|
||||
# second of a 24000 sample rate would be 48000 bytes (16-bit samples and
|
||||
# 1 channel). So, that is 48000 / 20 = 2400, which is below the 4kb
|
||||
# limit (even including base64 encoding). For a sample rate of 16000,
|
||||
# that would be 32000 / 20 = 1600.
|
||||
async def _handle_audio_frame(self, frame: OutputAudioRawFrame):
|
||||
sample_rate = self._client.out_sample_rate
|
||||
# 50 ms of audio
|
||||
MAX_CHUNK_SIZE = int((sample_rate * 2) / 20)
|
||||
|
||||
audio_buffer = bytearray()
|
||||
current_idx_str = None
|
||||
silence = b"\x00" * MAX_CHUNK_SIZE
|
||||
samples_sent = 0
|
||||
start_time = None
|
||||
# 40 ms of audio
|
||||
chunk_size = int((sample_rate * 2) / 25)
|
||||
# We might need to resample if incoming audio doesn't match the
|
||||
# transport sample rate.
|
||||
resampled = await self._resampler.resample(frame.audio, frame.sample_rate, sample_rate)
|
||||
self._audio_buffer.extend(resampled)
|
||||
while len(self._audio_buffer) >= chunk_size:
|
||||
chunk = OutputAudioRawFrame(
|
||||
bytes(self._audio_buffer[:chunk_size]),
|
||||
sample_rate=sample_rate,
|
||||
num_channels=frame.num_channels,
|
||||
)
|
||||
chunk.transport_destination = self._transport_destination
|
||||
await self._queue.put(chunk)
|
||||
self._audio_buffer = self._audio_buffer[chunk_size:]
|
||||
|
||||
async def _send_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
frame = await asyncio.wait_for(self._queue.get(), timeout=BOT_VAD_STOP_SECS)
|
||||
if isinstance(frame, TTSAudioRawFrame):
|
||||
# starting the new inference
|
||||
if current_idx_str is None:
|
||||
current_idx_str = str(frame.id)
|
||||
samples_sent = 0
|
||||
start_time = time.time()
|
||||
|
||||
audio = await self._resampler.resample(
|
||||
frame.audio, frame.sample_rate, sample_rate
|
||||
)
|
||||
audio_buffer.extend(audio)
|
||||
while len(audio_buffer) >= MAX_CHUNK_SIZE:
|
||||
chunk = audio_buffer[:MAX_CHUNK_SIZE]
|
||||
audio_buffer = audio_buffer[MAX_CHUNK_SIZE:]
|
||||
|
||||
# Compute wait time for synchronization
|
||||
wait = start_time + (samples_sent / sample_rate) - time.time()
|
||||
if wait > 0:
|
||||
logger.trace(f"TavusVideoService _send_task_handler wait: {wait}")
|
||||
await asyncio.sleep(wait)
|
||||
|
||||
await self._client.encode_audio_and_send(
|
||||
bytes(chunk), False, current_idx_str
|
||||
)
|
||||
|
||||
# Update timestamp based on number of samples sent
|
||||
samples_sent += len(chunk) // 2 # 2 bytes per sample (16-bit)
|
||||
except asyncio.TimeoutError:
|
||||
# Bot has stopped speaking
|
||||
# Send any remaining audio.
|
||||
if len(audio_buffer) > 0:
|
||||
await self._client.encode_audio_and_send(
|
||||
bytes(audio_buffer), False, current_idx_str
|
||||
)
|
||||
await self._client.encode_audio_and_send(silence, True, current_idx_str)
|
||||
audio_buffer.clear()
|
||||
current_idx_str = None
|
||||
frame = await self._queue.get()
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
await self._client.write_audio_frame(frame)
|
||||
|
||||
@@ -767,6 +767,7 @@ class DailyTransportClient(EventHandler):
|
||||
self._client.add_custom_audio_track(
|
||||
track_name=track_name,
|
||||
audio_track=audio_track,
|
||||
ignore_audio_level=True,
|
||||
completion=completion_callback(future),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable, Mapping, Optional
|
||||
|
||||
@@ -11,8 +9,6 @@ from pydantic import BaseModel
|
||||
|
||||
from pipecat.audio.utils import create_default_resampler
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
@@ -40,6 +36,8 @@ class TavusApi:
|
||||
"""
|
||||
|
||||
BASE_URL = "https://tavusapi.com/v2"
|
||||
MOCK_CONVERSATION_ID = "dev-conversation"
|
||||
MOCK_PERSONA_NAME = "TestTavusTransport"
|
||||
|
||||
def __init__(self, api_key: str, session: aiohttp.ClientSession):
|
||||
"""
|
||||
@@ -52,8 +50,16 @@ class TavusApi:
|
||||
self._api_key = api_key
|
||||
self._session = session
|
||||
self._headers = {"Content-Type": "application/json", "x-api-key": self._api_key}
|
||||
# Only for development
|
||||
self._dev_room_url = os.getenv("TAVUS_SAMPLE_ROOM_URL")
|
||||
|
||||
async def create_conversation(self, replica_id: str, persona_id: str) -> dict:
|
||||
if self._dev_room_url:
|
||||
return {
|
||||
"conversation_id": self.MOCK_CONVERSATION_ID,
|
||||
"conversation_url": self._dev_room_url,
|
||||
}
|
||||
|
||||
logger.debug(f"Creating Tavus conversation: replica={replica_id}, persona={persona_id}")
|
||||
url = f"{self.BASE_URL}/conversations"
|
||||
payload = {
|
||||
@@ -67,7 +73,7 @@ class TavusApi:
|
||||
return response
|
||||
|
||||
async def end_conversation(self, conversation_id: str):
|
||||
if conversation_id is None:
|
||||
if conversation_id is None or conversation_id == self.MOCK_CONVERSATION_ID:
|
||||
return
|
||||
|
||||
url = f"{self.BASE_URL}/conversations/{conversation_id}/end"
|
||||
@@ -76,6 +82,9 @@ class TavusApi:
|
||||
logger.debug(f"Ended Tavus conversation {conversation_id}")
|
||||
|
||||
async def get_persona_name(self, persona_id: str) -> str:
|
||||
if self._dev_room_url is not None:
|
||||
return self.MOCK_PERSONA_NAME
|
||||
|
||||
url = f"{self.BASE_URL}/personas/{persona_id}"
|
||||
async with self._session.get(url, headers=self._headers) as r:
|
||||
r.raise_for_status()
|
||||
@@ -119,7 +128,7 @@ class TavusTransportClient:
|
||||
callbacks (TavusCallbacks): Callback handlers for Tavus-related events.
|
||||
api_key (str): API key for authenticating with Tavus API.
|
||||
replica_id (str): ID of the replica to use in the Tavus conversation.
|
||||
persona_id (str): ID of the Tavus persona. Defaults to "pipecat0", which signals Tavus to use
|
||||
persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream", which signals Tavus to use
|
||||
the TTS voice of the Pipecat bot instead of a Tavus persona voice.
|
||||
session (aiohttp.ClientSession): The aiohttp session for making async HTTP requests.
|
||||
sample_rate: Audio sample rate to be used by the client.
|
||||
@@ -133,7 +142,7 @@ class TavusTransportClient:
|
||||
callbacks: TavusCallbacks,
|
||||
api_key: str,
|
||||
replica_id: str,
|
||||
persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
|
||||
persona_id: str = "pipecat-stream",
|
||||
session: aiohttp.ClientSession,
|
||||
) -> None:
|
||||
self._bot_name = bot_name
|
||||
@@ -141,7 +150,6 @@ class TavusTransportClient:
|
||||
self._replica_id = replica_id
|
||||
self._persona_id = persona_id
|
||||
self._conversation_id: Optional[str] = None
|
||||
self._other_participant_has_joined = False
|
||||
self._client: Optional[DailyTransportClient] = None
|
||||
self._callbacks = callbacks
|
||||
self._params = params
|
||||
@@ -153,6 +161,7 @@ class TavusTransportClient:
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
if self._conversation_id is not None:
|
||||
logger.debug(f"Conversation ID already defined: {self._conversation_id}")
|
||||
return
|
||||
try:
|
||||
room_url = await self._initialize()
|
||||
@@ -194,12 +203,13 @@ class TavusTransportClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup TavusTransportClient: {e}")
|
||||
await self._api.end_conversation(self._conversation_id)
|
||||
self._conversation_id = None
|
||||
|
||||
async def cleanup(self):
|
||||
if self._client is None:
|
||||
return
|
||||
await self._client.cleanup()
|
||||
self._client = None
|
||||
try:
|
||||
await self._client.cleanup()
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during cleanup: {e}")
|
||||
|
||||
async def _on_joined(self, data):
|
||||
logger.debug("TavusTransportClient joined!")
|
||||
@@ -221,6 +231,7 @@ class TavusTransportClient:
|
||||
async def stop(self):
|
||||
await self._client.leave()
|
||||
await self._api.end_conversation(self._conversation_id)
|
||||
self._conversation_id = None
|
||||
|
||||
async def capture_participant_video(
|
||||
self,
|
||||
@@ -257,11 +268,6 @@ class TavusTransportClient:
|
||||
def in_sample_rate(self) -> int:
|
||||
return self._client.in_sample_rate
|
||||
|
||||
async def encode_audio_and_send(self, audio: bytes, done: bool, inference_id: str):
|
||||
"""Encodes audio to base64 and sends it to Tavus"""
|
||||
audio_base64 = base64.b64encode(audio).decode("utf-8")
|
||||
await self._send_audio_message(audio_base64, done=done, inference_id=inference_id)
|
||||
|
||||
async def send_interrupt_message(self) -> None:
|
||||
transport_frame = TransportMessageUrgentFrame(
|
||||
message={
|
||||
@@ -272,23 +278,6 @@ class TavusTransportClient:
|
||||
)
|
||||
await self.send_message(transport_frame)
|
||||
|
||||
async def _send_audio_message(self, audio_base64: str, done: bool, inference_id: str):
|
||||
transport_frame = TransportMessageUrgentFrame(
|
||||
message={
|
||||
"message_type": "conversation",
|
||||
"event_type": "conversation.echo",
|
||||
"conversation_id": self._conversation_id,
|
||||
"properties": {
|
||||
"modality": "audio",
|
||||
"inference_id": inference_id,
|
||||
"audio": audio_base64,
|
||||
"done": done,
|
||||
"sample_rate": self.out_sample_rate,
|
||||
},
|
||||
}
|
||||
)
|
||||
await self.send_message(transport_frame)
|
||||
|
||||
async def update_subscriptions(self, participant_settings=None, profile_settings=None):
|
||||
if not self._client:
|
||||
return
|
||||
@@ -300,9 +289,14 @@ class TavusTransportClient:
|
||||
async def write_audio_frame(self, frame: OutputAudioRawFrame):
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
await self._client.write_audio_frame(frame)
|
||||
|
||||
async def register_audio_destination(self, destination: str):
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
await self._client.register_audio_destination(destination)
|
||||
|
||||
|
||||
class TavusInputTransport(BaseInputTransport):
|
||||
def __init__(
|
||||
@@ -379,12 +373,11 @@ class TavusOutputTransport(BaseOutputTransport):
|
||||
super().__init__(params, **kwargs)
|
||||
self._client = client
|
||||
self._params = params
|
||||
self._samples_sent = 0
|
||||
self._start_time = None
|
||||
self._current_idx_str: Optional[str] = None
|
||||
|
||||
# Whether we have seen a StartFrame already.
|
||||
self._initialized = False
|
||||
# This is the custom track destination expected by Tavus
|
||||
self._transport_destination: Optional[str] = "stream"
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
await super().setup(setup)
|
||||
@@ -403,6 +396,10 @@ class TavusOutputTransport(BaseOutputTransport):
|
||||
self._initialized = True
|
||||
|
||||
await self._client.start(frame)
|
||||
|
||||
if self._transport_destination:
|
||||
await self._client.register_audio_destination(self._transport_destination)
|
||||
|
||||
await self.set_transport_ready(frame)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
@@ -417,23 +414,6 @@ class TavusOutputTransport(BaseOutputTransport):
|
||||
logger.info(f"TavusOutputTransport sending message {frame}")
|
||||
await self._client.send_message(frame)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
# The BotStartedSpeakingFrame and BotStoppedSpeakingFrame are created inside BaseOutputTransport
|
||||
# so TavusOutputTransport never receives these frames.
|
||||
# This is a workaround, so we can more reliably be aware when the bot has started or stopped speaking
|
||||
if direction == FrameDirection.DOWNSTREAM:
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
if self._current_idx_str is not None:
|
||||
logger.warning("TavusOutputTransport self._current_idx_str is already defined!")
|
||||
self._current_idx_str = str(frame.id)
|
||||
self._start_time = time.time()
|
||||
self._samples_sent = 0
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
silence = b"\x00" * self.audio_chunk_size
|
||||
await self._client.encode_audio_and_send(silence, True, self._current_idx_str)
|
||||
self._current_idx_str = None
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
@@ -443,20 +423,12 @@ class TavusOutputTransport(BaseOutputTransport):
|
||||
await self._client.send_interrupt_message()
|
||||
|
||||
async def write_audio_frame(self, frame: OutputAudioRawFrame):
|
||||
# Compute wait time for synchronization
|
||||
wait = self._start_time + (self._samples_sent / self.sample_rate) - time.time()
|
||||
if wait > 0:
|
||||
logger.trace(f"TavusOutputTransport write_audio_frame wait: {wait}")
|
||||
await asyncio.sleep(wait)
|
||||
# This is the custom track destination expected by Tavus
|
||||
frame.transport_destination = self._transport_destination
|
||||
await self._client.write_audio_frame(frame)
|
||||
|
||||
if self._current_idx_str is None:
|
||||
logger.warning("TavusOutputTransport self._current_idx_str not defined yet!")
|
||||
return
|
||||
|
||||
await self._client.encode_audio_and_send(frame.audio, False, self._current_idx_str)
|
||||
|
||||
# Update timestamp based on number of samples sent
|
||||
self._samples_sent += len(frame.audio) // 2 # 2 bytes per sample (16-bit)
|
||||
async def register_audio_destination(self, destination: str):
|
||||
await self._client.register_audio_destination(destination)
|
||||
|
||||
|
||||
class TavusTransport(BaseTransport):
|
||||
@@ -472,7 +444,7 @@ class TavusTransport(BaseTransport):
|
||||
session (aiohttp.ClientSession): aiohttp session used for async HTTP requests.
|
||||
api_key (str): Tavus API key for authentication.
|
||||
replica_id (str): ID of the replica model used for voice generation.
|
||||
persona_id (str): ID of the Tavus persona. Defaults to "pipecat0" to use the Pipecat TTS voice.
|
||||
persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream" to use the Pipecat TTS voice.
|
||||
params (TavusParams): Optional Tavus-specific configuration parameters.
|
||||
input_name (Optional[str]): Optional name for the input transport.
|
||||
output_name (Optional[str]): Optional name for the output transport.
|
||||
@@ -484,7 +456,7 @@ class TavusTransport(BaseTransport):
|
||||
session: aiohttp.ClientSession,
|
||||
api_key: str,
|
||||
replica_id: str,
|
||||
persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
|
||||
persona_id: str = "pipecat-stream",
|
||||
params: TavusParams = TavusParams(),
|
||||
input_name: Optional[str] = None,
|
||||
output_name: Optional[str] = None,
|
||||
@@ -492,11 +464,6 @@ class TavusTransport(BaseTransport):
|
||||
super().__init__(input_name=input_name, output_name=output_name)
|
||||
self._params = params
|
||||
|
||||
# TODO: Filipi - We can remove this if we stop sending the audio through app messages
|
||||
# Limiting this so we don't go over 20 messages per second
|
||||
# each message is going to have 50ms of audio
|
||||
self._params.audio_out_10ms_chunks = 5
|
||||
|
||||
callbacks = TavusCallbacks(
|
||||
on_participant_joined=self._on_participant_joined,
|
||||
on_participant_left=self._on_participant_left,
|
||||
@@ -527,6 +494,7 @@ class TavusTransport(BaseTransport):
|
||||
async def _on_participant_joined(self, participant):
|
||||
# get persona, look up persona_name, set this as the bot name to ignore
|
||||
persona_name = await self._client.get_persona_name()
|
||||
|
||||
# Ignore the Tavus replica's microphone
|
||||
if participant.get("info", {}).get("userName", "") == persona_name:
|
||||
self._tavus_participant_id = participant["id"]
|
||||
|
||||
Reference in New Issue
Block a user