Compare commits
16 Commits
main
...
filipi/tav
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c938102ad | ||
|
|
359c9394d0 | ||
|
|
144a1ece7b | ||
|
|
6ef7f6446a | ||
|
|
7c61c36825 | ||
|
|
1338da6831 | ||
|
|
0dc1337a3c | ||
|
|
6a238e0d62 | ||
|
|
e7bad7a007 | ||
|
|
b360fbf7fc | ||
|
|
f568b1d8df | ||
|
|
fd7af7ba9f | ||
|
|
40851696b7 | ||
|
|
b7d272a5be | ||
|
|
996aa461ac | ||
|
|
aef6226a1c |
@@ -196,6 +196,10 @@ SPEECHMATICS_API_KEY=...
|
||||
# Tavus
|
||||
TAVUS_API_KEY=...
|
||||
TAVUS_REPLICA_ID=...
|
||||
# Used by scripts/daily/test_tavus_transport.py, which mimics Tavus behavior
|
||||
# inside a Daily room for local testing. Set this to the Daily room URL where
|
||||
# the Pipecat pipeline is running.
|
||||
TAVUS_SAMPLE_ROOM_URL=https://...
|
||||
|
||||
# Telnyx
|
||||
TELNYX_API_KEY=...
|
||||
|
||||
@@ -5,8 +5,12 @@
|
||||
#
|
||||
|
||||
|
||||
import datetime
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
@@ -21,6 +25,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
@@ -32,6 +37,21 @@ from pipecat.transports.daily.transport import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def save_audio_file(audio: bytes, filename: str, sample_rate: int, num_channels: int):
|
||||
"""Save audio data to a WAV file."""
|
||||
if len(audio) > 0:
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio)
|
||||
async with aiofiles.open(filename, "wb") as file:
|
||||
await file.write(buffer.getvalue())
|
||||
logger.info(f"Audio saved to {filename}")
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
@@ -59,7 +79,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = DeepgramSTTService(api_key=os.environ["DEEPGRAM_API_KEY"])
|
||||
stt = DeepgramSTTService(api_key=os.environ["DEEPGRAM_API_KEY"], audio_passthrough=True)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.environ["CARTESIA_API_KEY"],
|
||||
@@ -87,6 +107,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
audiobuffer = AudioBufferProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
@@ -96,6 +118,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
tts, # TTS
|
||||
tavus, # Tavus output layer
|
||||
transport.output(), # Transport bot output
|
||||
audiobuffer, # Audio recording
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
@@ -114,6 +137,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
await audiobuffer.start_recording()
|
||||
# Kick off the conversation.
|
||||
context.add_message(
|
||||
{
|
||||
@@ -128,6 +152,20 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
@audiobuffer.event_handler("on_audio_data")
|
||||
async def on_audio_data(buffer, audio, sample_rate, num_channels):
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"recordings/merged_{timestamp}.wav"
|
||||
os.makedirs("recordings", exist_ok=True)
|
||||
await save_audio_file(audio, filename, sample_rate, num_channels)
|
||||
|
||||
@audiobuffer.event_handler("on_track_audio_data")
|
||||
async def on_track_audio_data(buffer, user_audio, bot_audio, sample_rate, num_channels):
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
os.makedirs("recordings", exist_ok=True)
|
||||
await save_audio_file(user_audio, f"recordings/user_{timestamp}.wav", sample_rate, 1)
|
||||
await save_audio_file(bot_audio, f"recordings/bot_{timestamp}.wav", sample_rate, 1)
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import array
|
||||
import asyncio
|
||||
import datetime
|
||||
import os
|
||||
import signal
|
||||
import wave
|
||||
|
||||
from daily import (
|
||||
AudioData,
|
||||
@@ -15,6 +18,16 @@ from loguru import logger
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Pipecat sends audio whose true content rate is TRUE_SAMPLE_RATE but declares it as
|
||||
# DECLARED_SAMPLE_RATE to write_frames(), which makes delivery faster than
|
||||
# real-time. We receive at the declared rate (no resampling) and play back at
|
||||
# the true rate so the avatar consumes audio at normal speed.
|
||||
TRUE_SAMPLE_RATE = 24000
|
||||
DECLARED_SAMPLE_RATE = 48000
|
||||
SPEEDUP = DECLARED_SAMPLE_RATE // TRUE_SAMPLE_RATE
|
||||
CHUNK_BYTES = int(TRUE_SAMPLE_RATE * 20 / 1000) * 2 # 20 ms, 16-bit mono
|
||||
MIN_AUDIO_BUFFER = CHUNK_BYTES * 5 # 100 ms pre-buffer
|
||||
|
||||
|
||||
def completion_callback(future):
|
||||
def _callback(*args):
|
||||
@@ -37,19 +50,21 @@ class DailyProxyApp(EventHandler):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, sample_rate: int):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._sample_rate = sample_rate
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._audio_queue: asyncio.Queue = asyncio.Queue()
|
||||
# Raw PCM buffer — filled at DECLARED_SAMPLE_RATE speed, drained at TRUE_SAMPLE_RATE speed.
|
||||
self._buffer = bytearray()
|
||||
self._audio_task: asyncio.Task | None = None
|
||||
self._wav_file: wave.Wave_write | None = None
|
||||
|
||||
self._client: CallClient = CallClient(event_handler=self)
|
||||
self._client.update_subscription_profiles(
|
||||
{"base": {"camera": "unsubscribed", "microphone": "subscribed"}}
|
||||
)
|
||||
|
||||
self._audio_source = CustomAudioSource(self._sample_rate, 1)
|
||||
# Playback source declared at TRUE_SAMPLE_RATE — consumes audio at real-time speed.
|
||||
self._audio_source = CustomAudioSource(TRUE_SAMPLE_RATE, 1, False)
|
||||
self._audio_track = CustomAudioTrack(self._audio_source)
|
||||
|
||||
def on_joined(self, data, error):
|
||||
@@ -58,8 +73,27 @@ class DailyProxyApp(EventHandler):
|
||||
print(f"Unable to join meeting: {error}")
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
|
||||
def _open_wav(self):
|
||||
os.makedirs("recordings", exist_ok=True)
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
path = f"recordings/received_pos_speed_{timestamp}.wav"
|
||||
self._wav_file = wave.open(path, "wb")
|
||||
self._wav_file.setnchannels(1)
|
||||
self._wav_file.setsampwidth(2)
|
||||
# Declare TRUE_SAMPLE_RATE so timestamps match bot_*.wav for comparison.
|
||||
# Bytes arrive at DECLARED_SAMPLE_RATE speed (2x real-time) but each byte
|
||||
# is 24kHz content, so the WAV plays back at normal speed.
|
||||
self._wav_file.setframerate(TRUE_SAMPLE_RATE)
|
||||
logger.info(f"Recording received audio to {path}")
|
||||
|
||||
def _close_wav(self):
|
||||
if self._wav_file:
|
||||
self._wav_file.close()
|
||||
self._wav_file = None
|
||||
|
||||
def run(self, meeting_url: str):
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._open_wav()
|
||||
self._create_audio_task()
|
||||
|
||||
def handle_exit():
|
||||
@@ -92,6 +126,7 @@ class DailyProxyApp(EventHandler):
|
||||
if self._audio_task:
|
||||
self._loop.run_until_complete(self._cancel_audio_task())
|
||||
|
||||
self._close_wav()
|
||||
self._client.leave()
|
||||
self._client.release()
|
||||
|
||||
@@ -113,7 +148,6 @@ class DailyProxyApp(EventHandler):
|
||||
if self._audio_task:
|
||||
self._audio_task.cancel()
|
||||
try:
|
||||
# Waits for it to finish
|
||||
await self._audio_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
@@ -121,46 +155,120 @@ class DailyProxyApp(EventHandler):
|
||||
|
||||
async def capture_participant_audio(self, participant_id: str):
|
||||
logger.info(f"Capturing participant audio: {participant_id}")
|
||||
# Receiving from this custom track
|
||||
# audio_source: str = "microphone"
|
||||
audio_source: str = "stream"
|
||||
media = {"media": {"customAudio": {audio_source: "subscribed"}}}
|
||||
await self.update_subscriptions(participant_settings={participant_id: media})
|
||||
|
||||
# Must match the declared rate Pipecat used so WebRTC skips resampling —
|
||||
# every original byte arrives intact.
|
||||
self._client.set_audio_renderer(
|
||||
participant_id,
|
||||
self._audio_data_received,
|
||||
audio_source=audio_source,
|
||||
sample_rate=self._sample_rate,
|
||||
sample_rate=DECLARED_SAMPLE_RATE,
|
||||
callback_interval_ms=20,
|
||||
)
|
||||
logger.info(
|
||||
f"Receiving at declared_rate={DECLARED_SAMPLE_RATE} Hz "
|
||||
f"(true content: {TRUE_SAMPLE_RATE} Hz, ~{SPEEDUP}x faster than real-time)"
|
||||
)
|
||||
|
||||
async def send_audio(self, audio: AudioData):
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self._audio_source.write_frames(audio.audio_frames, completion=completion_callback(future))
|
||||
await future
|
||||
@staticmethod
|
||||
def _is_silence(data: bytes, threshold: int = 5) -> bool:
|
||||
# Interpret as 16-bit signed PCM samples and check peak amplitude.
|
||||
# WebRTC-injected silence is all zeros; real TTS audio has non-trivial
|
||||
# amplitude. This lets us skip buffering frames that Pipecat never wrote,
|
||||
# so the buffer only grows when actual speech arrives (via the declared-sample-rate trick).
|
||||
samples = array.array("h", data)
|
||||
return max(abs(s) for s in samples) < threshold
|
||||
|
||||
async def queue_audio(self, audio: AudioData):
|
||||
await self._audio_queue.put(audio)
|
||||
async def _buffer_audio(self, audio_data: AudioData):
|
||||
"""Append received bytes to the buffer, skipping WebRTC-injected silence.
|
||||
|
||||
Speech frames arrive at DECLARED_SAMPLE_RATE speed (~2x real-time) so the
|
||||
buffer grows ahead of the drain. WebRTC-injected silence (all-zero PCM) is
|
||||
handled differently based on buffer level: below MIN_AUDIO_BUFFER we keep it
|
||||
so the pre-buffer can fill; above that threshold we discard it so the buffer
|
||||
drains back down between utterances.
|
||||
"""
|
||||
new_bytes = audio_data.audio_frames
|
||||
if self._is_silence(new_bytes):
|
||||
if len(self._buffer) < MIN_AUDIO_BUFFER:
|
||||
# Below pre-buffer threshold: add silence so the buffer fills up.
|
||||
self._buffer.extend(new_bytes)
|
||||
# else: buffer is healthy, discard silence so it can drain.
|
||||
return
|
||||
|
||||
self._buffer.extend(new_bytes)
|
||||
|
||||
def _audio_data_received(self, participant_id: str, audio_data: AudioData, audio_source: str):
|
||||
# logger.info(f"Received audio data for {participant_id}, audio_source: {audio_source}")
|
||||
asyncio.run_coroutine_threadsafe(self.queue_audio(audio_data), self._loop)
|
||||
if self._wav_file:
|
||||
self._wav_file.writeframes(audio_data.audio_frames)
|
||||
asyncio.run_coroutine_threadsafe(self._buffer_audio(audio_data), self._loop)
|
||||
|
||||
async def _audio_task_handler(self):
|
||||
while True:
|
||||
audio = await self._audio_queue.get()
|
||||
await self.send_audio(audio)
|
||||
async def _handle_interrupt(self):
|
||||
"""Clear the audio buffer, mimicking the avatar stopping mid-speech."""
|
||||
dropped = len(self._buffer)
|
||||
self._buffer.clear()
|
||||
logger.info(
|
||||
f"Interrupt received — dropped {dropped}B ({dropped / (TRUE_SAMPLE_RATE * 2):.3f}s) from buffer"
|
||||
)
|
||||
|
||||
#
|
||||
# Daily (EventHandler)
|
||||
#
|
||||
|
||||
def on_app_message(self, message, sender):
|
||||
if not isinstance(message, dict):
|
||||
return
|
||||
if message.get("event_type") == "conversation.interrupt":
|
||||
asyncio.run_coroutine_threadsafe(self._handle_interrupt(), self._loop)
|
||||
|
||||
async def _audio_task_handler(self):
|
||||
"""Drain the buffer at TRUE_SAMPLE_RATE speed (real-time playback).
|
||||
|
||||
Waits until MIN_AUDIO_BUFFER bytes are accumulated before starting
|
||||
playback, then drains freely in CHUNK_BYTES steps. If the buffer runs
|
||||
dry it re-enters the waiting state so the next burst also gets the
|
||||
pre-buffer delay.
|
||||
"""
|
||||
buffering = True
|
||||
last_log_time = self._loop.time()
|
||||
|
||||
while True:
|
||||
if buffering:
|
||||
if len(self._buffer) >= MIN_AUDIO_BUFFER:
|
||||
buffering = False
|
||||
logger.debug(f"Pre-buffer reached ({MIN_AUDIO_BUFFER}B) — starting playback")
|
||||
else:
|
||||
await asyncio.sleep(0.001)
|
||||
continue
|
||||
|
||||
if len(self._buffer) >= CHUNK_BYTES:
|
||||
chunk = bytes(self._buffer[:CHUNK_BYTES])
|
||||
del self._buffer[:CHUNK_BYTES]
|
||||
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self._audio_source.write_frames(chunk, completion=completion_callback(future))
|
||||
await future
|
||||
else:
|
||||
buffering = True
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
now = self._loop.time()
|
||||
if now - last_log_time >= 1.0:
|
||||
buffer_seconds = len(self._buffer) / (TRUE_SAMPLE_RATE * 2)
|
||||
if buffer_seconds > 0:
|
||||
logger.info(
|
||||
f"Buffer status: {len(self._buffer)}B ({buffer_seconds:.3f}s buffered)"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
def on_participant_joined(self, participant):
|
||||
participant_name = participant["info"]["userName"]
|
||||
logger.info(f"Participant {participant_name} joined")
|
||||
if participant_name != "Pipecat":
|
||||
# We are only subscribing for audios from Pipecat.
|
||||
# We are only subscribing for audio from Pipecat.
|
||||
return
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.capture_participant_audio(participant_id=participant["id"]), self._loop
|
||||
@@ -173,7 +281,7 @@ class DailyProxyApp(EventHandler):
|
||||
def main():
|
||||
Daily.init()
|
||||
room_url = os.environ["TAVUS_SAMPLE_ROOM_URL"]
|
||||
app = DailyProxyApp(sample_rate=24000)
|
||||
app = DailyProxyApp()
|
||||
app.run(room_url)
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,9 @@ avatar functionality through Tavus's streaming API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import os
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
|
||||
import aiohttp
|
||||
@@ -101,8 +104,9 @@ class TavusVideoService(AIService):
|
||||
self._audio_buffer = bytearray()
|
||||
self._send_task: asyncio.Task | None = None
|
||||
# This is the custom track destination expected by Tavus
|
||||
self._transport_destination: str | None = "stream"
|
||||
self._transport_destination: str = "stream"
|
||||
self._transport_ready = False
|
||||
self._wav_file: wave.Wave_write | None = None
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
"""Set up the Tavus video service.
|
||||
@@ -204,6 +208,21 @@ class TavusVideoService(AIService):
|
||||
"""
|
||||
return await self._client.get_persona_name()
|
||||
|
||||
def _open_wav(self, sample_rate: int):
|
||||
os.makedirs("recordings", exist_ok=True)
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
path = f"recordings/bot_pre_speed_{timestamp}.wav"
|
||||
self._wav_file = wave.open(path, "wb")
|
||||
self._wav_file.setnchannels(1)
|
||||
self._wav_file.setsampwidth(2)
|
||||
self._wav_file.setframerate(sample_rate)
|
||||
logger.info(f"Recording outgoing audio to {path}")
|
||||
|
||||
def _close_wav(self):
|
||||
if self._wav_file:
|
||||
self._wav_file.close()
|
||||
self._wav_file = None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Tavus video service.
|
||||
|
||||
@@ -212,10 +231,10 @@ class TavusVideoService(AIService):
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._client.start(frame)
|
||||
if self._transport_destination:
|
||||
await self._client.register_audio_destination(
|
||||
self._transport_destination, auto_silence=False
|
||||
)
|
||||
await self._client.register_audio_destination(
|
||||
self._transport_destination, auto_silence=False
|
||||
)
|
||||
self._open_wav(self._client.out_sample_rate)
|
||||
await self._create_send_task()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
@@ -227,6 +246,7 @@ class TavusVideoService(AIService):
|
||||
await super().stop(frame)
|
||||
await self._end_conversation()
|
||||
await self._cancel_send_task()
|
||||
self._close_wav()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Tavus video service.
|
||||
@@ -237,6 +257,7 @@ class TavusVideoService(AIService):
|
||||
await super().cancel(frame)
|
||||
await self._end_conversation()
|
||||
await self._cancel_send_task()
|
||||
self._close_wav()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames through the service.
|
||||
@@ -308,9 +329,30 @@ class TavusVideoService(AIService):
|
||||
self._audio_buffer = self._audio_buffer[chunk_size:]
|
||||
|
||||
async def _send_task_handler(self):
|
||||
"""Handle sending audio frames to the Tavus client."""
|
||||
"""Handle sending audio frames to the Tavus client.
|
||||
|
||||
Accumulates 500 ms of audio before sending anything to WebRTC. This
|
||||
pre-buffer absorbs TTS jitter so the WebRTC jitter buffer sees a steady
|
||||
stream rather than bursts separated by silence, which prevents the drift
|
||||
and silence-injection observed without it. On interruption the task is
|
||||
replaced, so the next utterance gets a fresh 500 ms pre-buffer.
|
||||
"""
|
||||
min_prebuffer_bytes = int(self._client.out_sample_rate * 0.5) * 2
|
||||
prebuffer: list[OutputAudioRawFrame] | None = []
|
||||
|
||||
while True:
|
||||
frame = await self._queue.get()
|
||||
if isinstance(frame, OutputAudioRawFrame) and self._client:
|
||||
await self._client.write_audio_frame(frame)
|
||||
if prebuffer is not None:
|
||||
prebuffer.append(frame)
|
||||
if sum(len(f.audio) for f in prebuffer) >= min_prebuffer_bytes:
|
||||
for f in prebuffer:
|
||||
if self._wav_file:
|
||||
self._wav_file.writeframes(f.audio)
|
||||
await self._client.write_audio_frame(f)
|
||||
prebuffer = None
|
||||
else:
|
||||
if self._wav_file:
|
||||
self._wav_file.writeframes(frame.audio)
|
||||
await self._client.write_audio_frame(frame)
|
||||
self._queue.task_done()
|
||||
|
||||
@@ -823,6 +823,23 @@ class BaseOutputTransport(FrameProcessor):
|
||||
|
||||
async def _audio_task_handler(self):
|
||||
"""Main audio processing task handler."""
|
||||
# Pre-buffer: accumulate audio before sending anything to the transport.
|
||||
#
|
||||
# prebuffer is a list while we are still accumulating, and None once the
|
||||
# threshold has been reached and all held frames have been flushed. Using
|
||||
# None as the sentinel avoids a boolean flag and makes the steady-state
|
||||
# branch a simple identity check.
|
||||
#
|
||||
# The pre-buffer resets automatically on each interruption because the
|
||||
# audio task is cancelled and recreated, giving the next utterance a fresh
|
||||
# local variable.
|
||||
min_prebuffer_bytes = (
|
||||
int(self._sample_rate * self._params.audio_out_prebuffer_secs)
|
||||
* 2
|
||||
* self._params.audio_out_channels
|
||||
)
|
||||
prebuffer: list[OutputAudioRawFrame] | None = [] if min_prebuffer_bytes > 0 else None
|
||||
|
||||
async for frame in self._next_frame():
|
||||
# No need to push EndFrame, it's pushed from process_frame().
|
||||
if isinstance(frame, EndFrame):
|
||||
@@ -840,7 +857,20 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# Try to send audio to the transport.
|
||||
try:
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
push_downstream = await self._transport.write_audio_frame(frame)
|
||||
if prebuffer is not None:
|
||||
# Accumulation phase: hold frames until we have enough audio.
|
||||
prebuffer.append(frame)
|
||||
if sum(len(f.audio) for f in prebuffer) >= min_prebuffer_bytes:
|
||||
# Threshold reached: flush all held frames at once, then
|
||||
# switch to direct-write mode for the rest of the utterance.
|
||||
for f in prebuffer:
|
||||
await self._transport.write_audio_frame(f)
|
||||
prebuffer = None
|
||||
# push_downstream stays True so frames flow through the
|
||||
# pipeline even while we are still accumulating.
|
||||
else:
|
||||
# Steady-state: write directly to the transport.
|
||||
push_downstream = await self._transport.write_audio_frame(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Error writing {frame} to transport: {e}")
|
||||
push_downstream = False
|
||||
|
||||
@@ -34,6 +34,8 @@ class TransportParams(BaseModel):
|
||||
audio_out_mixer: Audio mixer instance or destination mapping.
|
||||
audio_out_destinations: List of audio output destination identifiers.
|
||||
audio_out_end_silence_secs: How much silence to send after an EndFrame (0 for no silence).
|
||||
audio_out_prebuffer_secs: Seconds of audio to accumulate before sending anything to the
|
||||
transport. Resets automatically on each interruption. Defaults to 0.0 (disabled).
|
||||
audio_out_auto_silence: Insert silence frames when the audio output queue is empty.
|
||||
When False, the transport will wait for audio data instead of inserting silence.
|
||||
audio_in_enabled: Enable audio input streaming.
|
||||
@@ -70,6 +72,7 @@ class TransportParams(BaseModel):
|
||||
audio_out_mixer: BaseAudioMixer | Mapping[str | None, BaseAudioMixer] | None = None
|
||||
audio_out_destinations: list[str] = Field(default_factory=list)
|
||||
audio_out_end_silence_secs: int = 2
|
||||
audio_out_prebuffer_secs: float = 0.0
|
||||
audio_out_auto_silence: bool = True
|
||||
audio_in_enabled: bool = False
|
||||
audio_in_sample_rate: int | None = None
|
||||
|
||||
@@ -40,10 +40,16 @@ from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import (
|
||||
DailyCallbacks,
|
||||
DailyCustomAudioTrackParams,
|
||||
DailyParams,
|
||||
DailyTransportClient,
|
||||
)
|
||||
|
||||
# Opus codec maximum. When the Tavus server supports fast audio delivery it
|
||||
# returns this as stream_declared_sample_rate so that write_frames() blocks for
|
||||
# n/48000 s instead of n/true_rate s, delivering audio faster than real-time.
|
||||
_STREAM_DECLARED_SAMPLE_RATE = 48000
|
||||
|
||||
|
||||
class TavusApi:
|
||||
"""Helper class for interacting with the Tavus API (v2).
|
||||
@@ -69,20 +75,28 @@ class TavusApi:
|
||||
# Only for development
|
||||
self._dev_room_url = os.getenv("TAVUS_SAMPLE_ROOM_URL")
|
||||
|
||||
async def create_conversation(self, replica_id: str, persona_id: str) -> dict:
|
||||
async def create_conversation(self, replica_id: str, persona_id: str, sample_rate: int) -> dict:
|
||||
"""Create a new conversation with the specified replica and persona.
|
||||
|
||||
Args:
|
||||
replica_id: ID of the replica to use in the conversation.
|
||||
persona_id: ID of the persona to use in the conversation.
|
||||
sample_rate: True audio sample rate of the pipeline's output. Sent
|
||||
to Tavus so the server can negotiate fast audio delivery. When
|
||||
the server supports it, the response includes
|
||||
``stream_declared_sample_rate`` — the rate Pipecat should
|
||||
declare to the ``CustomAudioSource`` for faster-than-realtime
|
||||
delivery.
|
||||
|
||||
Returns:
|
||||
Dictionary containing conversation_id and conversation_url.
|
||||
Dictionary containing conversation_id, conversation_url, and
|
||||
optionally stream_declared_sample_rate.
|
||||
"""
|
||||
if self._dev_room_url:
|
||||
return {
|
||||
"conversation_id": self.MOCK_CONVERSATION_ID,
|
||||
"conversation_url": self._dev_room_url,
|
||||
"stream_declared_sample_rate": _STREAM_DECLARED_SAMPLE_RATE,
|
||||
}
|
||||
|
||||
logger.debug(f"Creating Tavus conversation: replica={replica_id}, persona={persona_id}")
|
||||
@@ -90,6 +104,8 @@ class TavusApi:
|
||||
payload = {
|
||||
"replica_id": replica_id,
|
||||
"persona_id": persona_id,
|
||||
# TODO: start sending this once Tavus supports it.
|
||||
# "sample_rate": sample_rate,
|
||||
}
|
||||
async with self._session.post(url, headers=self._headers, json=payload) as r:
|
||||
r.raise_for_status()
|
||||
@@ -152,11 +168,15 @@ class TavusParams(DailyParams):
|
||||
audio_in_enabled: Whether to enable audio input from participants.
|
||||
audio_out_enabled: Whether to enable audio output to participants.
|
||||
microphone_out_enabled: Whether to enable microphone output track.
|
||||
audio_out_prebuffer_secs: Seconds of audio to accumulate before sending to WebRTC.
|
||||
Absorbs TTS jitter to prevent the WebRTC jitter buffer from injecting silence.
|
||||
Defaults to 0.5.
|
||||
"""
|
||||
|
||||
audio_in_enabled: bool = True
|
||||
audio_out_enabled: bool = True
|
||||
microphone_out_enabled: bool = False
|
||||
audio_out_prebuffer_secs: float = 0.5
|
||||
|
||||
|
||||
class TavusTransportClient:
|
||||
@@ -202,76 +222,60 @@ class TavusTransportClient:
|
||||
self._client: DailyTransportClient | None = None
|
||||
self._callbacks = callbacks
|
||||
self._params = params
|
||||
self._setup: FrameProcessorSetup | None = None
|
||||
self._initialized = False
|
||||
|
||||
async def _initialize(self) -> str:
|
||||
"""Initialize the conversation and return the room URL."""
|
||||
response = await self._api.create_conversation(self._replica_id, self._persona_id)
|
||||
self._conversation_id = response["conversation_id"]
|
||||
return response["conversation_url"]
|
||||
def _build_daily_callbacks(self) -> DailyCallbacks:
|
||||
"""Build the DailyCallbacks object."""
|
||||
return DailyCallbacks(
|
||||
on_active_speaker_changed=partial(
|
||||
self._on_handle_callback, "on_active_speaker_changed"
|
||||
),
|
||||
on_joined=self._on_joined,
|
||||
on_left=self._on_left,
|
||||
on_before_leave=partial(self._on_handle_callback, "on_before_leave"),
|
||||
on_error=partial(self._on_handle_callback, "on_error"),
|
||||
on_app_message=partial(self._on_handle_callback, "on_app_message"),
|
||||
on_call_state_updated=partial(self._on_handle_callback, "on_call_state_updated"),
|
||||
on_client_connected=partial(self._on_handle_callback, "on_client_connected"),
|
||||
on_client_disconnected=partial(self._on_handle_callback, "on_client_disconnected"),
|
||||
on_dialin_connected=partial(self._on_handle_callback, "on_dialin_connected"),
|
||||
on_dialin_ready=partial(self._on_handle_callback, "on_dialin_ready"),
|
||||
on_dialin_stopped=partial(self._on_handle_callback, "on_dialin_stopped"),
|
||||
on_dialin_error=partial(self._on_handle_callback, "on_dialin_error"),
|
||||
on_dialin_warning=partial(self._on_handle_callback, "on_dialin_warning"),
|
||||
on_dialout_answered=partial(self._on_handle_callback, "on_dialout_answered"),
|
||||
on_dialout_connected=partial(self._on_handle_callback, "on_dialout_connected"),
|
||||
on_dialout_stopped=partial(self._on_handle_callback, "on_dialout_stopped"),
|
||||
on_dialout_error=partial(self._on_handle_callback, "on_dialout_error"),
|
||||
on_dialout_warning=partial(self._on_handle_callback, "on_dialout_warning"),
|
||||
on_dtmf_event=partial(self._on_handle_callback, "on_dtmf_event"),
|
||||
on_participant_joined=self._callbacks.on_participant_joined,
|
||||
on_participant_left=self._callbacks.on_participant_left,
|
||||
on_participant_updated=partial(self._on_handle_callback, "on_participant_updated"),
|
||||
on_transcription_message=partial(self._on_handle_callback, "on_transcription_message"),
|
||||
on_recording_started=partial(self._on_handle_callback, "on_recording_started"),
|
||||
on_recording_stopped=partial(self._on_handle_callback, "on_recording_stopped"),
|
||||
on_recording_error=partial(self._on_handle_callback, "on_recording_error"),
|
||||
on_transcription_stopped=partial(self._on_handle_callback, "on_transcription_stopped"),
|
||||
on_transcription_error=partial(self._on_handle_callback, "on_transcription_error"),
|
||||
)
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
"""Setup the client and initialize the conversation.
|
||||
"""Save setup context for later use in start().
|
||||
|
||||
Args:
|
||||
setup: The frame processor setup configuration.
|
||||
"""
|
||||
if self._conversation_id is not None:
|
||||
logger.debug(f"Conversation ID already defined: {self._conversation_id}")
|
||||
return
|
||||
try:
|
||||
room_url = await self._initialize()
|
||||
daily_callbacks = DailyCallbacks(
|
||||
on_active_speaker_changed=partial(
|
||||
self._on_handle_callback, "on_active_speaker_changed"
|
||||
),
|
||||
on_joined=self._on_joined,
|
||||
on_left=self._on_left,
|
||||
on_before_leave=partial(self._on_handle_callback, "on_before_leave"),
|
||||
on_error=partial(self._on_handle_callback, "on_error"),
|
||||
on_app_message=partial(self._on_handle_callback, "on_app_message"),
|
||||
on_call_state_updated=partial(self._on_handle_callback, "on_call_state_updated"),
|
||||
on_client_connected=partial(self._on_handle_callback, "on_client_connected"),
|
||||
on_client_disconnected=partial(self._on_handle_callback, "on_client_disconnected"),
|
||||
on_dialin_connected=partial(self._on_handle_callback, "on_dialin_connected"),
|
||||
on_dialin_ready=partial(self._on_handle_callback, "on_dialin_ready"),
|
||||
on_dialin_stopped=partial(self._on_handle_callback, "on_dialin_stopped"),
|
||||
on_dialin_error=partial(self._on_handle_callback, "on_dialin_error"),
|
||||
on_dialin_warning=partial(self._on_handle_callback, "on_dialin_warning"),
|
||||
on_dialout_answered=partial(self._on_handle_callback, "on_dialout_answered"),
|
||||
on_dialout_connected=partial(self._on_handle_callback, "on_dialout_connected"),
|
||||
on_dialout_stopped=partial(self._on_handle_callback, "on_dialout_stopped"),
|
||||
on_dialout_error=partial(self._on_handle_callback, "on_dialout_error"),
|
||||
on_dialout_warning=partial(self._on_handle_callback, "on_dialout_warning"),
|
||||
on_dtmf_event=partial(self._on_handle_callback, "on_dtmf_event"),
|
||||
on_participant_joined=self._callbacks.on_participant_joined,
|
||||
on_participant_left=self._callbacks.on_participant_left,
|
||||
on_participant_updated=partial(self._on_handle_callback, "on_participant_updated"),
|
||||
on_transcription_message=partial(
|
||||
self._on_handle_callback, "on_transcription_message"
|
||||
),
|
||||
on_recording_started=partial(self._on_handle_callback, "on_recording_started"),
|
||||
on_recording_stopped=partial(self._on_handle_callback, "on_recording_stopped"),
|
||||
on_recording_error=partial(self._on_handle_callback, "on_recording_error"),
|
||||
on_transcription_stopped=partial(
|
||||
self._on_handle_callback, "on_transcription_stopped"
|
||||
),
|
||||
on_transcription_error=partial(self._on_handle_callback, "on_transcription_error"),
|
||||
)
|
||||
self._client = DailyTransportClient(
|
||||
room_url, None, "Pipecat", self._params, daily_callbacks, self._bot_name
|
||||
)
|
||||
await self._client.setup(setup)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup TavusTransportClient: {e}")
|
||||
await self._api.end_conversation(self._conversation_id)
|
||||
self._conversation_id = None
|
||||
self._setup = setup
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup client resources."""
|
||||
try:
|
||||
await self._client.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during cleanup: {e}")
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during cleanup: {e}")
|
||||
|
||||
async def _on_joined(self, data):
|
||||
"""Handle joined event."""
|
||||
@@ -295,12 +299,56 @@ class TavusTransportClient:
|
||||
return await self._api.get_persona_name(self._persona_id)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the client and join the room.
|
||||
"""Create the conversation, build the Daily client, and join the room.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
logger.debug("TavusTransportClient start invoked!")
|
||||
try:
|
||||
sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate
|
||||
response = await self._api.create_conversation(
|
||||
self._replica_id, self._persona_id, sample_rate
|
||||
)
|
||||
self._conversation_id = response["conversation_id"]
|
||||
room_url = response["conversation_url"]
|
||||
|
||||
params = self._params
|
||||
stream_declared_sample_rate = response.get("stream_declared_sample_rate")
|
||||
if stream_declared_sample_rate:
|
||||
# Tavus supports fast audio delivery: we write true-rate PCM bytes into a
|
||||
# CustomAudioSource declared at stream_declared_sample_rate (e.g. 48 kHz).
|
||||
# write_frames() blocks for n/declared_rate seconds instead of n/true_rate
|
||||
# seconds, so audio is delivered faster than real-time. The receiver must
|
||||
# also request the same declared rate so WebRTC skips resampling and every
|
||||
# original byte arrives intact.
|
||||
# We always override sample_rate here even if the user already provided
|
||||
# "stream" params, because the declared rate must match what the server
|
||||
# negotiated — other fields (channels, send_settings) are preserved.
|
||||
logger.debug(
|
||||
f"Tavus fast audio: true_rate={sample_rate} declared_rate={stream_declared_sample_rate}"
|
||||
)
|
||||
existing = dict(params.custom_audio_track_params or {})
|
||||
existing["stream"] = (
|
||||
existing.get("stream") or DailyCustomAudioTrackParams()
|
||||
).model_copy(update={"sample_rate": stream_declared_sample_rate})
|
||||
params = params.model_copy(update={"custom_audio_track_params": existing})
|
||||
|
||||
self._client = DailyTransportClient(
|
||||
room_url, None, "Pipecat", params, self._build_daily_callbacks(), self._bot_name
|
||||
)
|
||||
await self._client.setup(self._setup)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start TavusTransportClient: {e}")
|
||||
await self._api.end_conversation(self._conversation_id)
|
||||
self._conversation_id = None
|
||||
self._initialized = False
|
||||
return
|
||||
|
||||
await self._client.start(frame)
|
||||
await self._client.join()
|
||||
|
||||
@@ -598,7 +646,11 @@ class TavusOutputTransport(BaseOutputTransport):
|
||||
await self._client.start(frame)
|
||||
|
||||
if self._transport_destination:
|
||||
await self._client.register_audio_destination(self._transport_destination)
|
||||
# auto_silence=False so the CustomAudioSource only writes frames when
|
||||
# there is real TTS audio.
|
||||
await self._client.register_audio_destination(
|
||||
self._transport_destination, auto_silence=False
|
||||
)
|
||||
|
||||
await self.set_transport_ready(frame)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user