Creating TavusTransport and TavusTransportClient.
This commit is contained in:
532
src/pipecat/transports/services/tavus.py
Normal file
532
src/pipecat/transports/services/tavus.py
Normal file
@@ -0,0 +1,532 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable, Mapping, Optional
|
||||
|
||||
import aiohttp
|
||||
from daily.daily import AudioData
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.audio.utils import create_default_resampler
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
OutputImageRawFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.services.daily import (
|
||||
DailyCallbacks,
|
||||
DailyParams,
|
||||
DailyTransportClient,
|
||||
)
|
||||
|
||||
|
||||
class TavusApi:
|
||||
"""
|
||||
A helper class for interacting with the Tavus API (v2).
|
||||
"""
|
||||
|
||||
BASE_URL = "https://tavusapi.com/v2"
|
||||
|
||||
def __init__(self, api_key: str, session: aiohttp.ClientSession):
|
||||
"""
|
||||
Initialize the TavusApi client.
|
||||
|
||||
Args:
|
||||
api_key (str): Tavus API key.
|
||||
session (aiohttp.ClientSession): An aiohttp session for making HTTP requests.
|
||||
"""
|
||||
self._api_key = api_key
|
||||
self._session = session
|
||||
self._headers = {"Content-Type": "application/json", "x-api-key": self._api_key}
|
||||
|
||||
async def create_conversation(self, replica_id: str, persona_id: str) -> dict:
|
||||
logger.debug(f"Creating Tavus conversation: replica={replica_id}, persona={persona_id}")
|
||||
url = f"{self.BASE_URL}/conversations"
|
||||
payload = {
|
||||
"replica_id": replica_id,
|
||||
"persona_id": persona_id,
|
||||
}
|
||||
async with self._session.post(url, headers=self._headers, json=payload) as r:
|
||||
r.raise_for_status()
|
||||
response = await r.json()
|
||||
logger.debug(f"Created Tavus conversation: {response}")
|
||||
return response
|
||||
|
||||
async def end_conversation(self, conversation_id: str):
|
||||
if conversation_id is None:
|
||||
return
|
||||
|
||||
url = f"{self.BASE_URL}/conversations/{conversation_id}/end"
|
||||
async with self._session.post(url, headers=self._headers) as r:
|
||||
r.raise_for_status()
|
||||
logger.debug(f"Ended Tavus conversation {conversation_id}")
|
||||
|
||||
async def get_persona_name(self, persona_id: str) -> str:
|
||||
url = f"{self.BASE_URL}/personas/{persona_id}"
|
||||
async with self._session.get(url, headers=self._headers) as r:
|
||||
r.raise_for_status()
|
||||
response = await r.json()
|
||||
logger.debug(f"Fetched Tavus persona: {response}")
|
||||
return response["persona_name"]
|
||||
|
||||
|
||||
class TavusCallbacks(BaseModel):
|
||||
"""Callback handlers for the Tavus events.
|
||||
|
||||
Attributes:
|
||||
on_participant_joined: Called when a participant joins.
|
||||
on_participant_left: Called when a participant leaves.
|
||||
"""
|
||||
|
||||
on_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_participant_left: Callable[[Mapping[str, Any], str], Awaitable[None]]
|
||||
|
||||
|
||||
class TavusParams(DailyParams):
|
||||
"""Configuration parameters for the Tavus transport."""
|
||||
|
||||
audio_in_enabled: bool = True
|
||||
audio_out_enabled: bool = True
|
||||
microphone_out_enabled: bool = False
|
||||
|
||||
|
||||
class TavusTransportClient:
|
||||
"""
|
||||
A transport client that integrates a Pipecat Bot with the Tavus platform by managing
|
||||
conversation sessions using the Tavus API.
|
||||
|
||||
This client uses `TavusApi` to interact with the Tavus backend services. When a conversation
|
||||
is started via `TavusApi`, Tavus provides a `roomURL` that can be used to connect the Pipecat Bot
|
||||
into the same virtual room where the TavusBot is operating.
|
||||
|
||||
Args:
|
||||
bot_name (str): The name of the Pipecat bot instance.
|
||||
params (TavusParams): Optional parameters for Tavus operation. Defaults to `TavusParams()`.
|
||||
callbacks (TavusCallbacks): Callback handlers for Tavus-related events.
|
||||
api_key (str): API key for authenticating with Tavus API.
|
||||
replica_id (str): ID of the replica to use in the Tavus conversation.
|
||||
persona_id (str): ID of the Tavus persona. Defaults to "pipecat0", which signals Tavus to use
|
||||
the TTS voice of the Pipecat bot instead of a Tavus persona voice.
|
||||
session (aiohttp.ClientSession): The aiohttp session for making async HTTP requests.
|
||||
sample_rate: Audio sample rate to be used by the client.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bot_name: str,
|
||||
params: TavusParams = TavusParams(),
|
||||
callbacks: TavusCallbacks,
|
||||
api_key: str,
|
||||
replica_id: str,
|
||||
persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
|
||||
session: aiohttp.ClientSession,
|
||||
) -> None:
|
||||
self._bot_name = bot_name
|
||||
self._api = TavusApi(api_key, session)
|
||||
self._replica_id = replica_id
|
||||
self._persona_id = persona_id
|
||||
self._conversation_id: Optional[str] = None
|
||||
self._other_participant_has_joined = False
|
||||
self._client: Optional[DailyTransportClient] = None
|
||||
self._callbacks = callbacks
|
||||
self._params = params
|
||||
|
||||
async def _initialize(self) -> str:
|
||||
response = await self._api.create_conversation(self._replica_id, self._persona_id)
|
||||
self._conversation_id = response["conversation_id"]
|
||||
return response["conversation_url"]
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
if self._conversation_id is not None:
|
||||
return
|
||||
try:
|
||||
room_url = await self._initialize()
|
||||
daily_callbacks = DailyCallbacks(
|
||||
on_active_speaker_changed=partial(
|
||||
self._on_handle_callback, "on_active_speaker_changed"
|
||||
),
|
||||
on_joined=self._on_joined,
|
||||
on_left=self._on_left,
|
||||
on_error=partial(self._on_handle_callback, "on_error"),
|
||||
on_app_message=partial(self._on_handle_callback, "on_app_message"),
|
||||
on_call_state_updated=partial(self._on_handle_callback, "on_call_state_updated"),
|
||||
on_client_connected=partial(self._on_handle_callback, "on_client_connected"),
|
||||
on_client_disconnected=partial(self._on_handle_callback, "on_client_disconnected"),
|
||||
on_dialin_connected=partial(self._on_handle_callback, "on_dialin_connected"),
|
||||
on_dialin_ready=partial(self._on_handle_callback, "on_dialin_ready"),
|
||||
on_dialin_stopped=partial(self._on_handle_callback, "on_dialin_stopped"),
|
||||
on_dialin_error=partial(self._on_handle_callback, "on_dialin_error"),
|
||||
on_dialin_warning=partial(self._on_handle_callback, "on_dialin_warning"),
|
||||
on_dialout_answered=partial(self._on_handle_callback, "on_dialout_answered"),
|
||||
on_dialout_connected=partial(self._on_handle_callback, "on_dialout_connected"),
|
||||
on_dialout_stopped=partial(self._on_handle_callback, "on_dialout_stopped"),
|
||||
on_dialout_error=partial(self._on_handle_callback, "on_dialout_error"),
|
||||
on_dialout_warning=partial(self._on_handle_callback, "on_dialout_warning"),
|
||||
on_participant_joined=self._callbacks.on_participant_joined,
|
||||
on_participant_left=self._callbacks.on_participant_left,
|
||||
on_participant_updated=partial(self._on_handle_callback, "on_participant_updated"),
|
||||
on_transcription_message=partial(
|
||||
self._on_handle_callback, "on_transcription_message"
|
||||
),
|
||||
on_recording_started=partial(self._on_handle_callback, "on_recording_started"),
|
||||
on_recording_stopped=partial(self._on_handle_callback, "on_recording_stopped"),
|
||||
on_recording_error=partial(self._on_handle_callback, "on_recording_error"),
|
||||
)
|
||||
self._client = DailyTransportClient(
|
||||
room_url, None, "Pipecat", self._params, daily_callbacks, self._bot_name
|
||||
)
|
||||
await self._client.setup(setup)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup TavusTransportClient: {e}")
|
||||
await self._api.end_conversation(self._conversation_id)
|
||||
|
||||
async def cleanup(self):
|
||||
if self._client is None:
|
||||
return
|
||||
await self._client.cleanup()
|
||||
self._client = None
|
||||
|
||||
async def _on_joined(self, data):
|
||||
logger.debug("TavusTransportClient joined!")
|
||||
|
||||
async def _on_left(self):
|
||||
logger.debug("TavusTransportClient left!")
|
||||
|
||||
async def _on_handle_callback(self, event_name, *args, **kwargs):
|
||||
logger.trace(f"[Callback] {event_name} called with args={args}, kwargs={kwargs}")
|
||||
|
||||
async def get_persona_name(self) -> str:
|
||||
return await self._api.get_persona_name(self._persona_id)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
logger.debug("TavusTransportClient start invoked!")
|
||||
await self._client.start(frame)
|
||||
await self._client.join()
|
||||
|
||||
async def stop(self):
|
||||
await self._client.leave()
|
||||
await self._api.end_conversation(self._conversation_id)
|
||||
|
||||
async def capture_participant_video(
|
||||
self,
|
||||
participant_id: str,
|
||||
callback: Callable,
|
||||
framerate: int = 30,
|
||||
video_source: str = "camera",
|
||||
color_format: str = "RGB",
|
||||
):
|
||||
await self._client.capture_participant_video(
|
||||
participant_id, callback, framerate, video_source, color_format
|
||||
)
|
||||
|
||||
async def capture_participant_audio(
|
||||
self,
|
||||
participant_id: str,
|
||||
callback: Callable,
|
||||
audio_source: str = "microphone",
|
||||
sample_rate: int = 16000,
|
||||
callback_interval_ms: int = 20,
|
||||
):
|
||||
await self._client.capture_participant_audio(
|
||||
participant_id, callback, audio_source, sample_rate, callback_interval_ms
|
||||
)
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
await self._client.send_message(frame)
|
||||
|
||||
@property
|
||||
def out_sample_rate(self) -> int:
|
||||
return self._client.out_sample_rate
|
||||
|
||||
@property
|
||||
def in_sample_rate(self) -> int:
|
||||
return self._client.in_sample_rate
|
||||
|
||||
async def encode_audio_and_send(self, audio: bytes, done: bool, inference_id: str):
|
||||
"""Encodes audio to base64 and sends it to Tavus"""
|
||||
audio_base64 = base64.b64encode(audio).decode("utf-8")
|
||||
await self._send_audio_message(audio_base64, done=done, inference_id=inference_id)
|
||||
|
||||
async def send_interrupt_message(self) -> None:
|
||||
transport_frame = TransportMessageUrgentFrame(
|
||||
message={
|
||||
"message_type": "conversation",
|
||||
"event_type": "conversation.interrupt",
|
||||
"conversation_id": self._conversation_id,
|
||||
}
|
||||
)
|
||||
await self.send_message(transport_frame)
|
||||
|
||||
async def _send_audio_message(self, audio_base64: str, done: bool, inference_id: str):
|
||||
transport_frame = TransportMessageUrgentFrame(
|
||||
message={
|
||||
"message_type": "conversation",
|
||||
"event_type": "conversation.echo",
|
||||
"conversation_id": self._conversation_id,
|
||||
"properties": {
|
||||
"modality": "audio",
|
||||
"inference_id": inference_id,
|
||||
"audio": audio_base64,
|
||||
"done": done,
|
||||
"sample_rate": self.out_sample_rate,
|
||||
},
|
||||
}
|
||||
)
|
||||
await self.send_message(transport_frame)
|
||||
|
||||
async def update_subscriptions(self, participant_settings=None, profile_settings=None):
|
||||
await self._client.update_subscriptions(
|
||||
participant_settings=participant_settings, profile_settings=profile_settings
|
||||
)
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
|
||||
await self._client.write_raw_audio_frames(frames, destination)
|
||||
|
||||
|
||||
class TavusInputTransport(BaseInputTransport):
|
||||
def __init__(
|
||||
self,
|
||||
client: TavusTransportClient,
|
||||
params: TransportParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
self._client = client
|
||||
self._params = params
|
||||
self._resampler = create_default_resampler()
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
await super().setup(setup)
|
||||
await self._client.setup(setup)
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._client.cleanup()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._client.start(frame)
|
||||
await self.set_transport_ready(frame)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._client.stop()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._client.stop()
|
||||
|
||||
async def start_capturing_audio(self, participant):
|
||||
if self._params.audio_in_enabled:
|
||||
logger.info(
|
||||
f"TavusTransportClient start capturing audio for participant {participant['id']}"
|
||||
)
|
||||
await self._client.capture_participant_audio(
|
||||
participant_id=participant["id"],
|
||||
callback=self._on_participant_audio_data,
|
||||
sample_rate=self._client.in_sample_rate,
|
||||
)
|
||||
|
||||
async def _on_participant_audio_data(
|
||||
self, participant_id: str, audio: AudioData, audio_source: str
|
||||
):
|
||||
frame = InputAudioRawFrame(
|
||||
audio=audio.audio_frames,
|
||||
sample_rate=audio.audio_frames,
|
||||
num_channels=audio.num_channels,
|
||||
)
|
||||
frame.transport_source = audio_source
|
||||
await self.push_audio_frame(frame)
|
||||
|
||||
|
||||
class TavusOutputTransport(BaseOutputTransport):
|
||||
def __init__(
|
||||
self,
|
||||
client: TavusTransportClient,
|
||||
params: TransportParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
self._client = client
|
||||
self._params = params
|
||||
self._samples_sent = 0
|
||||
self._start_time = time.time()
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
await super().setup(setup)
|
||||
await self._client.setup(setup)
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._client.cleanup()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._samples_sent = 0
|
||||
self._start_time = time.time()
|
||||
await self._client.start(frame)
|
||||
await self.set_transport_ready(frame)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._client.stop()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._client.stop()
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
logger.info(f"TavusOutputTransport sending message {frame}")
|
||||
await self._client.send_message(frame)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruptions()
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
self._current_idx_str = str(frame.id)
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
logger.debug(f"TAVUS: {self}: stopped speaking")
|
||||
await self._client.encode_audio_and_send(b"\x00\x00", True, self._current_idx_str)
|
||||
|
||||
async def _handle_interruptions(self):
|
||||
await self._client.send_interrupt_message()
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes, destination: Optional[str] = None):
|
||||
# Compute wait time for synchronization
|
||||
wait = self._start_time + (self._samples_sent / self._sample_rate) - time.time()
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
|
||||
await self._client.encode_audio_and_send(frames, False, self._current_idx_str)
|
||||
|
||||
# Update timestamp based on number of samples sent
|
||||
self._samples_sent += len(frames) // 2 # 2 bytes per sample (16-bit)
|
||||
|
||||
async def write_raw_video_frame(
|
||||
self, frame: OutputImageRawFrame, destination: Optional[str] = None
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class TavusTransport(BaseTransport):
|
||||
"""
|
||||
Transport implementation for Tavus video calls.
|
||||
|
||||
When used, the Pipecat bot joins the same virtual room as the Tavus Avatar and the user.
|
||||
This is achieved by using `TavusTransportClient`, which initiates the conversation via
|
||||
`TavusApi` and obtains a room URL that all participants connect to.
|
||||
|
||||
Args:
|
||||
bot_name (str): The name of the Pipecat bot.
|
||||
session (aiohttp.ClientSession): aiohttp session used for async HTTP requests.
|
||||
api_key (str): Tavus API key for authentication.
|
||||
replica_id (str): ID of the replica model used for voice generation.
|
||||
persona_id (str): ID of the Tavus persona. Defaults to "pipecat0" to use the Pipecat TTS voice.
|
||||
params (TavusParams): Optional Tavus-specific configuration parameters.
|
||||
input_name (Optional[str]): Optional name for the input transport.
|
||||
output_name (Optional[str]): Optional name for the output transport.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_name: str,
|
||||
session: aiohttp.ClientSession,
|
||||
api_key: str,
|
||||
replica_id: str,
|
||||
persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona
|
||||
params: TavusParams = TavusParams(),
|
||||
input_name: Optional[str] = None,
|
||||
output_name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(input_name=input_name, output_name=output_name)
|
||||
self._params = params
|
||||
|
||||
# TODO: Filipi - We can remove this if we stop sending the audio through app messages
|
||||
# Limiting this so we don't go over 20 messages per second
|
||||
# each message is going to have 50ms of audio
|
||||
self._params.audio_out_10ms_chunks = 5
|
||||
|
||||
callbacks = TavusCallbacks(
|
||||
on_participant_joined=self._on_participant_joined,
|
||||
on_participant_left=self._on_participant_left,
|
||||
)
|
||||
self._client = TavusTransportClient(
|
||||
bot_name="Pipecat",
|
||||
callbacks=callbacks,
|
||||
api_key=api_key,
|
||||
replica_id=replica_id,
|
||||
persona_id=persona_id,
|
||||
session=session,
|
||||
params=params,
|
||||
)
|
||||
self._input: Optional[TavusInputTransport] = None
|
||||
self._output: Optional[TavusOutputTransport] = None
|
||||
self._tavus_participant_id = None
|
||||
|
||||
# Register supported handlers. The user will only be able to register
|
||||
# these handlers.
|
||||
self._register_event_handler("on_client_connected")
|
||||
self._register_event_handler("on_client_disconnected")
|
||||
|
||||
async def _on_participant_left(self, participant, reason):
|
||||
persona_name = await self._client.get_persona_name()
|
||||
if participant.get("info", {}).get("userName", "") != persona_name:
|
||||
await self._on_client_disconnected(participant)
|
||||
|
||||
async def _on_participant_joined(self, participant):
|
||||
# get persona, look up persona_name, set this as the bot name to ignore
|
||||
persona_name = await self._client.get_persona_name()
|
||||
# Ignore the Tavus replica's microphone
|
||||
if participant.get("info", {}).get("userName", "") == persona_name:
|
||||
self._tavus_participant_id = participant["id"]
|
||||
else:
|
||||
await self._on_client_connected(participant)
|
||||
if self._tavus_participant_id:
|
||||
logger.debug(f"Ignoring {self._tavus_participant_id}'s microphone")
|
||||
await self.update_subscriptions(
|
||||
participant_settings={
|
||||
self._tavus_participant_id: {
|
||||
"media": {"microphone": "unsubscribed"},
|
||||
}
|
||||
}
|
||||
)
|
||||
if self._input:
|
||||
await self._input.start_capturing_audio(participant)
|
||||
|
||||
async def update_subscriptions(self, participant_settings=None, profile_settings=None):
|
||||
await self._client.update_subscriptions(
|
||||
participant_settings=participant_settings,
|
||||
profile_settings=profile_settings,
|
||||
)
|
||||
|
||||
def input(self) -> FrameProcessor:
|
||||
if not self._input:
|
||||
self._input = TavusInputTransport(client=self._client, params=self._params)
|
||||
return self._input
|
||||
|
||||
def output(self) -> FrameProcessor:
|
||||
if not self._output:
|
||||
self._output = TavusOutputTransport(client=self._client, params=self._params)
|
||||
return self._output
|
||||
|
||||
async def _on_client_connected(self, participant: Any):
|
||||
await self._call_event_handler("on_client_connected", participant)
|
||||
|
||||
async def _on_client_disconnected(self, participant: Any):
|
||||
await self._call_event_handler("on_client_disconnected", participant)
|
||||
Reference in New Issue
Block a user