Creating TavusTransport and TavusTransportClient.

This commit is contained in:
Filipi Fuchter
2025-05-23 23:02:37 -03:00
parent 28b7a92a00
commit 86e6841569

View File

@@ -0,0 +1,532 @@
import asyncio
import base64
import time
from functools import partial
from typing import Any, Awaitable, Callable, Mapping, Optional
import aiohttp
from daily.daily import AudioData
from loguru import logger
from pydantic import BaseModel
from pipecat.audio.utils import create_default_resampler
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InputAudioRawFrame,
OutputImageRawFrame,
StartFrame,
StartInterruptionFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.services.daily import (
DailyCallbacks,
DailyParams,
DailyTransportClient,
)
class TavusApi:
    """Small async wrapper around the Tavus REST API (v2).

    All requests are issued through the aiohttp session supplied by the caller
    and carry the API key in the ``x-api-key`` header.
    """

    BASE_URL = "https://tavusapi.com/v2"

    def __init__(self, api_key: str, session: aiohttp.ClientSession):
        """Store the credentials and session used for every request.

        Args:
            api_key (str): Tavus API key.
            session (aiohttp.ClientSession): Session used to perform the HTTP calls.
        """
        self._api_key = api_key
        self._session = session
        self._headers = {"Content-Type": "application/json", "x-api-key": self._api_key}

    async def create_conversation(self, replica_id: str, persona_id: str) -> dict:
        """POST to ``/conversations`` and return the decoded JSON response.

        Args:
            replica_id (str): Replica to use for the conversation.
            persona_id (str): Persona to use for the conversation.

        Returns:
            dict: The JSON body returned by Tavus.

        Raises:
            aiohttp.ClientResponseError: If Tavus answers with an error status.
        """
        logger.debug(f"Creating Tavus conversation: replica={replica_id}, persona={persona_id}")
        endpoint = f"{self.BASE_URL}/conversations"
        body = {"replica_id": replica_id, "persona_id": persona_id}
        async with self._session.post(endpoint, headers=self._headers, json=body) as resp:
            resp.raise_for_status()
            conversation = await resp.json()
            logger.debug(f"Created Tavus conversation: {conversation}")
            return conversation

    async def end_conversation(self, conversation_id: str):
        """POST to ``/conversations/{id}/end``; does nothing when the id is None.

        Args:
            conversation_id (str): Conversation to end.

        Raises:
            aiohttp.ClientResponseError: If Tavus answers with an error status.
        """
        if conversation_id is not None:
            endpoint = f"{self.BASE_URL}/conversations/{conversation_id}/end"
            async with self._session.post(endpoint, headers=self._headers) as resp:
                resp.raise_for_status()
                logger.debug(f"Ended Tavus conversation {conversation_id}")

    async def get_persona_name(self, persona_id: str) -> str:
        """GET ``/personas/{id}`` and return its ``persona_name`` field.

        Args:
            persona_id (str): Persona to look up.

        Returns:
            str: Value of the ``persona_name`` key in the JSON response.

        Raises:
            aiohttp.ClientResponseError: If Tavus answers with an error status.
            KeyError: If the response has no ``persona_name`` key.
        """
        endpoint = f"{self.BASE_URL}/personas/{persona_id}"
        async with self._session.get(endpoint, headers=self._headers) as resp:
            resp.raise_for_status()
            persona = await resp.json()
            logger.debug(f"Fetched Tavus persona: {persona}")
            return persona["persona_name"]
class TavusCallbacks(BaseModel):
    """Callback handlers for the Tavus events.

    Both fields are required and must be awaitable callables.

    Attributes:
        on_participant_joined: Called when a participant joins. Receives one
            mapping describing the participant.
        on_participant_left: Called when a participant leaves. Receives the
            participant mapping and a string.
    """

    # Awaited with a single mapping describing the participant.
    on_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
    # Awaited with the participant mapping plus a string; the handler that
    # TavusTransport registers for this field names that string `reason`.
    on_participant_left: Callable[[Mapping[str, Any], str], Awaitable[None]]
class TavusParams(DailyParams):
    """Configuration parameters for the Tavus transport.

    Subclass of `DailyParams` that declares its own defaults for the three
    audio flags below.
    """

    # Audio input is enabled by default.
    audio_in_enabled: bool = True
    # Audio output is enabled by default.
    audio_out_enabled: bool = True
    # Microphone output is disabled by default.
    # NOTE(review): presumably because bot audio reaches Tavus via app messages
    # rather than a microphone track -- confirm.
    microphone_out_enabled: bool = False
class TavusTransportClient:
    """
    A transport client that integrates a Pipecat Bot with the Tavus platform by managing
    conversation sessions using the Tavus API.

    This client uses `TavusApi` to create a conversation. The response's
    `conversation_url` is used as the room URL for a `DailyTransportClient`, so the
    Pipecat Bot joins the same room as the Tavus replica.

    The same client instance is shared by `TavusInputTransport` and
    `TavusOutputTransport`, so `setup()`, `start()`, `stop()` and `cleanup()` are each
    invoked twice per session and must tolerate the repeated call.

    Args:
        bot_name (str): The name of the Pipecat bot instance.
        params (TavusParams): Optional parameters for Tavus operation. A fresh
            `TavusParams()` is created when omitted.
        callbacks (TavusCallbacks): Callback handlers for Tavus-related events.
        api_key (str): API key for authenticating with Tavus API.
        replica_id (str): ID of the replica to use in the Tavus conversation.
        persona_id (str): ID of the Tavus persona. Defaults to "pipecat0", which signals Tavus to use
            the TTS voice of the Pipecat bot instead of a Tavus persona voice.
        session (aiohttp.ClientSession): The aiohttp session for making async HTTP requests.
    """

    # Daily events this client has no behaviour for: they are only trace-logged
    # through `_on_handle_callback`.
    _LOG_ONLY_DAILY_EVENTS = (
        "on_active_speaker_changed",
        "on_error",
        "on_app_message",
        "on_call_state_updated",
        "on_client_connected",
        "on_client_disconnected",
        "on_dialin_connected",
        "on_dialin_ready",
        "on_dialin_stopped",
        "on_dialin_error",
        "on_dialin_warning",
        "on_dialout_answered",
        "on_dialout_connected",
        "on_dialout_stopped",
        "on_dialout_error",
        "on_dialout_warning",
        "on_participant_updated",
        "on_transcription_message",
        "on_recording_started",
        "on_recording_stopped",
        "on_recording_error",
    )

    def __init__(
        self,
        *,
        bot_name: str,
        params: Optional[TavusParams] = None,
        callbacks: TavusCallbacks,
        api_key: str,
        replica_id: str,
        persona_id: str = "pipecat0",  # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
        session: aiohttp.ClientSession,
    ) -> None:
        self._bot_name = bot_name
        self._api = TavusApi(api_key, session)
        self._replica_id = replica_id
        self._persona_id = persona_id
        self._conversation_id: Optional[str] = None
        self._other_participant_has_joined = False
        self._client: Optional[DailyTransportClient] = None
        self._callbacks = callbacks
        # Build the default per instance: a `TavusParams()` in the signature would be
        # evaluated once and shared (and mutated) across every client.
        self._params = params if params is not None else TavusParams()
        # Filled on the first successful `get_persona_name()` call.
        self._persona_name: Optional[str] = None

    async def _initialize(self) -> str:
        """Create the Tavus conversation, remember its id and return its URL."""
        response = await self._api.create_conversation(self._replica_id, self._persona_id)
        self._conversation_id = response["conversation_id"]
        return response["conversation_url"]

    def _build_daily_callbacks(self) -> DailyCallbacks:
        """Build the Daily callback set.

        Join/leave of this client are logged, participant join/leave are forwarded to
        the user-supplied `TavusCallbacks`, and every other event is trace-logged only.
        """
        log_only = {
            event: partial(self._on_handle_callback, event)
            for event in self._LOG_ONLY_DAILY_EVENTS
        }
        return DailyCallbacks(
            on_joined=self._on_joined,
            on_left=self._on_left,
            on_participant_joined=self._callbacks.on_participant_joined,
            on_participant_left=self._callbacks.on_participant_left,
            **log_only,
        )

    async def setup(self, setup: FrameProcessorSetup):
        """Create the conversation and the underlying Daily client.

        A second call while a conversation exists is a no-op. Failures are logged
        and not re-raised; the conversation is ended and the stored id is cleared
        so that a later `setup()` call can try again instead of returning early
        with no usable client.
        """
        if self._conversation_id is not None:
            return
        try:
            room_url = await self._initialize()
            self._client = DailyTransportClient(
                room_url,
                None,
                "Pipecat",
                self._params,
                self._build_daily_callbacks(),
                self._bot_name,
            )
            await self._client.setup(setup)
        except Exception as e:
            logger.error(f"Failed to setup TavusTransportClient: {e}")
            try:
                await self._api.end_conversation(self._conversation_id)
            finally:
                # Without this reset the guard above would short-circuit every
                # subsequent setup() even though nothing usable was created.
                self._conversation_id = None

    async def cleanup(self):
        """Clean up the Daily client once; later calls are no-ops."""
        if self._client is None:
            return
        await self._client.cleanup()
        self._client = None

    async def _on_joined(self, data):
        logger.debug("TavusTransportClient joined!")

    async def _on_left(self):
        logger.debug("TavusTransportClient left!")

    async def _on_handle_callback(self, event_name, *args, **kwargs):
        logger.trace(f"[Callback] {event_name} called with args={args}, kwargs={kwargs}")

    async def get_persona_name(self) -> str:
        """Return the persona name for the configured persona id.

        The name is fetched from the Tavus API on first use and cached, so that
        participant join/leave events do not each trigger an HTTP request.
        """
        if self._persona_name is None:
            self._persona_name = await self._api.get_persona_name(self._persona_id)
        return self._persona_name

    async def start(self, frame: StartFrame):
        """Start the Daily client and join the room."""
        logger.debug("TavusTransportClient start invoked!")
        await self._client.start(frame)
        await self._client.join()

    async def stop(self):
        """Leave the room and end the Tavus conversation.

        The conversation id is cleared once the conversation has been ended, so the
        second `stop()` (input and output transport each call it) does not POST the
        end request again: `TavusApi.end_conversation` returns early on None.
        """
        if self._client is not None:
            await self._client.leave()
        await self._api.end_conversation(self._conversation_id)
        self._conversation_id = None

    async def capture_participant_video(
        self,
        participant_id: str,
        callback: Callable,
        framerate: int = 30,
        video_source: str = "camera",
        color_format: str = "RGB",
    ):
        """Forward a video capture request to the Daily client."""
        await self._client.capture_participant_video(
            participant_id, callback, framerate, video_source, color_format
        )

    async def capture_participant_audio(
        self,
        participant_id: str,
        callback: Callable,
        audio_source: str = "microphone",
        sample_rate: int = 16000,
        callback_interval_ms: int = 20,
    ):
        """Forward an audio capture request to the Daily client."""
        await self._client.capture_participant_audio(
            participant_id, callback, audio_source, sample_rate, callback_interval_ms
        )

    async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
        """Send a transport message frame through the Daily client."""
        await self._client.send_message(frame)

    @property
    def out_sample_rate(self) -> int:
        """Output sample rate reported by the Daily client."""
        return self._client.out_sample_rate

    @property
    def in_sample_rate(self) -> int:
        """Input sample rate reported by the Daily client."""
        return self._client.in_sample_rate

    async def encode_audio_and_send(self, audio: bytes, done: bool, inference_id: str):
        """Encodes audio to base64 and sends it to Tavus"""
        audio_base64 = base64.b64encode(audio).decode("utf-8")
        await self._send_audio_message(audio_base64, done=done, inference_id=inference_id)

    async def send_interrupt_message(self) -> None:
        """Send a `conversation.interrupt` message for the current conversation."""
        transport_frame = TransportMessageUrgentFrame(
            message={
                "message_type": "conversation",
                "event_type": "conversation.interrupt",
                "conversation_id": self._conversation_id,
            }
        )
        await self.send_message(transport_frame)

    async def _send_audio_message(self, audio_base64: str, done: bool, inference_id: str):
        """Send one `conversation.echo` audio message carrying base64 audio."""
        transport_frame = TransportMessageUrgentFrame(
            message={
                "message_type": "conversation",
                "event_type": "conversation.echo",
                "conversation_id": self._conversation_id,
                "properties": {
                    "modality": "audio",
                    "inference_id": inference_id,
                    "audio": audio_base64,
                    "done": done,
                    "sample_rate": self.out_sample_rate,
                },
            }
        )
        await self.send_message(transport_frame)

    async def update_subscriptions(self, participant_settings=None, profile_settings=None):
        """Forward a subscription update to the Daily client."""
        await self._client.update_subscriptions(
            participant_settings=participant_settings, profile_settings=profile_settings
        )

    async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
        """Forward raw audio to the Daily client."""
        await self._client.write_raw_audio_frames(frames, destination)
class TavusInputTransport(BaseInputTransport):
    """Input side of the Tavus transport.

    Delegates lifecycle calls to the shared `TavusTransportClient` and turns the
    audio captured from a participant into `InputAudioRawFrame`s.

    Args:
        client (TavusTransportClient): Client shared with the output transport.
        params (TransportParams): Transport configuration.
        **kwargs: Forwarded to `BaseInputTransport`.
    """

    def __init__(
        self,
        client: TavusTransportClient,
        params: TransportParams,
        **kwargs,
    ):
        super().__init__(params, **kwargs)
        self._client = client
        self._params = params
        self._resampler = create_default_resampler()

    async def setup(self, setup: FrameProcessorSetup):
        """Set up the base transport, then the shared client."""
        await super().setup(setup)
        await self._client.setup(setup)

    async def cleanup(self):
        """Clean up the base transport, then the shared client."""
        await super().cleanup()
        await self._client.cleanup()

    async def start(self, frame: StartFrame):
        """Start the base transport and the client, then mark the transport ready."""
        await super().start(frame)
        await self._client.start(frame)
        await self.set_transport_ready(frame)

    async def stop(self, frame: EndFrame):
        """Stop the base transport, then the shared client."""
        await super().stop(frame)
        await self._client.stop()

    async def cancel(self, frame: CancelFrame):
        """Cancel the base transport, then stop the shared client."""
        await super().cancel(frame)
        await self._client.stop()

    async def start_capturing_audio(self, participant):
        """Start receiving audio from `participant` when audio input is enabled.

        Args:
            participant: Mapping describing the participant; its ``"id"`` key is used.
        """
        if self._params.audio_in_enabled:
            logger.info(
                f"TavusTransportClient start capturing audio for participant {participant['id']}"
            )
            await self._client.capture_participant_audio(
                participant_id=participant["id"],
                callback=self._on_participant_audio_data,
                sample_rate=self._client.in_sample_rate,
            )

    async def _on_participant_audio_data(
        self, participant_id: str, audio: AudioData, audio_source: str
    ):
        """Wrap captured audio in an `InputAudioRawFrame` and push it.

        Args:
            participant_id (str): Id of the participant the audio belongs to (unused).
            audio (AudioData): Captured audio chunk.
            audio_source (str): Source label, stored as the frame's `transport_source`.
        """
        frame = InputAudioRawFrame(
            audio=audio.audio_frames,
            # The rate must come from the AudioData metadata; the original passed
            # the audio payload (`audio.audio_frames`) here by mistake.
            sample_rate=audio.sample_rate,
            num_channels=audio.num_channels,
        )
        frame.transport_source = audio_source
        await self.push_audio_frame(frame)
class TavusOutputTransport(BaseOutputTransport):
    """Output side of the Tavus transport.

    Sends bot audio to Tavus as base64 messages through the shared
    `TavusTransportClient`, paced so that audio is not sent faster than it plays.

    Args:
        client (TavusTransportClient): Client shared with the input transport.
        params (TransportParams): Transport configuration.
        **kwargs: Forwarded to `BaseOutputTransport`.
    """

    def __init__(
        self,
        client: TavusTransportClient,
        params: TransportParams,
        **kwargs,
    ):
        super().__init__(params, **kwargs)
        self._client = client
        self._params = params
        # Pacing state: samples sent since the clock was last anchored, and the
        # monotonic time of that anchor.
        self._samples_sent = 0
        self._start_time = time.monotonic()
        # Inference id attached to outgoing audio; set from TTSStartedFrame.
        # Initialised here so frames arriving before any TTSStartedFrame do not
        # raise AttributeError.
        self._current_idx_str: Optional[str] = None

    async def setup(self, setup: FrameProcessorSetup):
        """Set up the base transport, then the shared client."""
        await super().setup(setup)
        await self._client.setup(setup)

    async def cleanup(self):
        """Clean up the base transport, then the shared client."""
        await super().cleanup()
        await self._client.cleanup()

    async def start(self, frame: StartFrame):
        """Start the base transport and the client, reset pacing, mark ready."""
        await super().start(frame)
        self._samples_sent = 0
        self._start_time = time.monotonic()
        await self._client.start(frame)
        await self.set_transport_ready(frame)

    async def stop(self, frame: EndFrame):
        """Stop the base transport, then the shared client."""
        await super().stop(frame)
        await self._client.stop()

    async def cancel(self, frame: CancelFrame):
        """Cancel the base transport, then stop the shared client."""
        await super().cancel(frame)
        await self._client.stop()

    async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
        """Log and send a transport message through the client."""
        logger.info(f"TavusOutputTransport sending message {frame}")
        await self._client.send_message(frame)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Handle interruption and TTS boundary frames after base processing.

        - `StartInterruptionFrame`: send an interrupt message.
        - `TTSStartedFrame`: remember the frame id as the current inference id.
        - `TTSStoppedFrame`: send a final two-zero-byte chunk flagged `done=True`.
        """
        await super().process_frame(frame, direction)
        if isinstance(frame, StartInterruptionFrame):
            await self._handle_interruptions()
        elif isinstance(frame, TTSStartedFrame):
            self._current_idx_str = str(frame.id)
        elif isinstance(frame, TTSStoppedFrame):
            logger.debug(f"TAVUS: {self}: stopped speaking")
            # Nothing was ever sent under an inference id, so there is no
            # utterance to close.
            if self._current_idx_str is not None:
                await self._client.encode_audio_and_send(b"\x00\x00", True, self._current_idx_str)

    async def _handle_interruptions(self):
        await self._client.send_interrupt_message()

    async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
        """Send one audio chunk to Tavus, waiting until it is due.

        Args:
            frames (bytes): Raw audio; the sample count assumes 2 bytes per sample.
            destination (Optional[str]): Unused.
        """
        # Time at which everything sent so far will have finished playing.
        now = time.monotonic()
        due = self._start_time + (self._samples_sent / self._sample_rate)
        if due > now:
            await asyncio.sleep(due - now)
        else:
            # We are behind the clock (first chunk, or a silent gap between
            # utterances). Re-anchor instead of keeping the old start time,
            # otherwise the following chunks would be sent back-to-back with no
            # throttling until the sample count caught up with wall time.
            self._start_time = now
            self._samples_sent = 0

        if self._current_idx_str is None:
            # Audio arrived without a preceding TTSStartedFrame: use a generated
            # id rather than failing or dropping the audio.
            self._current_idx_str = str(time.monotonic_ns())

        await self._client.encode_audio_and_send(frames, False, self._current_idx_str)
        # Update timestamp based on number of samples sent
        self._samples_sent += len(frames) // 2  # 2 bytes per sample (16-bit)

    async def write_raw_video_frame(
        self, frame: OutputImageRawFrame, destination: Optional[str] = None
    ):
        """Video output is not implemented; frames are discarded."""
        pass
class TavusTransport(BaseTransport):
    """
    Transport implementation for Tavus video calls.

    When used, the Pipecat bot joins the same virtual room as the Tavus Avatar and the user.
    This is achieved by using `TavusTransportClient`, which initiates the conversation via
    `TavusApi` and obtains a room URL that all participants connect to.

    Supported event handlers: ``on_client_connected`` and ``on_client_disconnected``.
    Both are raised only for participants whose ``info.userName`` differs from the
    persona name, i.e. not for the Tavus replica itself.

    Args:
        bot_name (str): The name of the Pipecat bot.
        session (aiohttp.ClientSession): aiohttp session used for async HTTP requests.
        api_key (str): Tavus API key for authentication.
        replica_id (str): ID of the replica model used for voice generation.
        persona_id (str): ID of the Tavus persona. Defaults to "pipecat0" to use the Pipecat TTS voice.
        params (TavusParams): Optional Tavus-specific configuration parameters. A fresh
            `TavusParams()` is created when omitted. `audio_out_10ms_chunks` is
            overwritten with 5 on the object that is used.
        input_name (Optional[str]): Optional name for the input transport.
        output_name (Optional[str]): Optional name for the output transport.
    """

    def __init__(
        self,
        bot_name: str,
        session: aiohttp.ClientSession,
        api_key: str,
        replica_id: str,
        persona_id: str = "pipecat0",  # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
        params: Optional[TavusParams] = None,
        input_name: Optional[str] = None,
        output_name: Optional[str] = None,
    ):
        super().__init__(input_name=input_name, output_name=output_name)
        # Build the default per instance: a `TavusParams()` in the signature is
        # created once, and the assignment below would then mutate the object
        # shared by every transport constructed without explicit params.
        self._params = params if params is not None else TavusParams()
        # TODO: Filipi - We can remove this if we stop sending the audio through app messages
        # Limiting this so we don't go over 20 messages per second
        # each message is going to have 50ms of audio
        self._params.audio_out_10ms_chunks = 5
        callbacks = TavusCallbacks(
            on_participant_joined=self._on_participant_joined,
            on_participant_left=self._on_participant_left,
        )
        # NOTE(review): `bot_name` is accepted but the client is always created
        # with the literal "Pipecat" -- confirm whether it should be passed through.
        self._client = TavusTransportClient(
            bot_name="Pipecat",
            callbacks=callbacks,
            api_key=api_key,
            replica_id=replica_id,
            persona_id=persona_id,
            session=session,
            params=self._params,
        )
        self._input: Optional[TavusInputTransport] = None
        self._output: Optional[TavusOutputTransport] = None
        self._tavus_participant_id = None
        # Register supported handlers. The user will only be able to register
        # these handlers.
        self._register_event_handler("on_client_connected")
        self._register_event_handler("on_client_disconnected")

    @staticmethod
    def _participant_user_name(participant) -> str:
        """Return ``participant["info"]["userName"]``, or "" when absent."""
        return participant.get("info", {}).get("userName", "")

    async def _ignore_tavus_microphone(self):
        """Unsubscribe from the Tavus replica's microphone, if the replica is known."""
        if not self._tavus_participant_id:
            return
        logger.debug(f"Ignoring {self._tavus_participant_id}'s microphone")
        await self.update_subscriptions(
            participant_settings={
                self._tavus_participant_id: {
                    "media": {"microphone": "unsubscribed"},
                }
            }
        )

    async def _on_participant_left(self, participant, reason):
        """Raise ``on_client_disconnected`` unless the Tavus replica left."""
        persona_name = await self._client.get_persona_name()
        if self._participant_user_name(participant) != persona_name:
            await self._on_client_disconnected(participant)

    async def _on_participant_joined(self, participant):
        """Handle a participant joining the room.

        The Tavus replica is recognised by its user name matching the persona
        name; its id is recorded and its microphone is unsubscribed. Any other
        participant raises ``on_client_connected`` and has audio capture started
        on the input transport, if one has been created.
        """
        # get persona, look up persona_name, set this as the bot name to ignore
        persona_name = await self._client.get_persona_name()
        # Ignore the Tavus replica's microphone
        if self._participant_user_name(participant) == persona_name:
            self._tavus_participant_id = participant["id"]
            # Unsubscribe right away: if the user joined before the replica, no
            # later join event would otherwise trigger the unsubscribe.
            await self._ignore_tavus_microphone()
        else:
            await self._on_client_connected(participant)
            await self._ignore_tavus_microphone()
            if self._input:
                await self._input.start_capturing_audio(participant)

    async def update_subscriptions(self, participant_settings=None, profile_settings=None):
        """Forward a subscription update to the client."""
        await self._client.update_subscriptions(
            participant_settings=participant_settings,
            profile_settings=profile_settings,
        )

    def input(self) -> FrameProcessor:
        """Return the input transport, creating it on first use."""
        if not self._input:
            self._input = TavusInputTransport(client=self._client, params=self._params)
        return self._input

    def output(self) -> FrameProcessor:
        """Return the output transport, creating it on first use."""
        if not self._output:
            self._output = TavusOutputTransport(client=self._client, params=self._params)
        return self._output

    async def _on_client_connected(self, participant: Any):
        await self._call_event_handler("on_client_connected", participant)

    async def _on_client_disconnected(self, participant: Any):
        await self._call_event_handler("on_client_disconnected", participant)