use new InputAudioRawFrame and OutputAudioRawFrame

This commit is contained in:
joachimchauvet
2024-10-05 14:16:44 +03:00
parent b373bc82b5
commit 86143f79a1

View File

@@ -1,25 +1,17 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, List
from pydantic import BaseModel
import numpy as np
from scipy import signal
from loguru import logger
from pipecat.frames.frames import (
AudioRawFrame,
CancelFrame,
EndFrame,
Frame,
InputAudioRawFrame,
MetricsFrame,
OutputAudioRawFrame,
StartFrame,
TransportMessageFrame,
)
@@ -29,13 +21,13 @@ from pipecat.metrics.metrics import (
TTFBMetricsData,
TTSUsageMetricsData,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frame_processor import FrameDirection
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.vad.vad_analyzer import VADAnalyzer
from loguru import logger
from pydantic import BaseModel
from scipy import signal
try:
from livekit import rtc
@@ -364,10 +356,15 @@ class LiveKitInputTransport(BaseInputTransport):
if audio_data:
audio_frame_event, participant_id = audio_data
pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event)
await self.push_audio_frame(pipecat_audio_frame)
input_audio_frame = InputAudioRawFrame(
audio=pipecat_audio_frame.audio,
sample_rate=pipecat_audio_frame.sample_rate,
num_channels=pipecat_audio_frame.num_channels,
)
await self.push_frame(
pipecat_audio_frame
) # TODO: ensure audio frames are pushed with the default BaseInputTransport.push_audio_frame()
await self.push_audio_frame(input_audio_frame)
except asyncio.CancelledError:
logger.info("Audio input task cancelled")
break
@@ -390,9 +387,11 @@ class LiveKitInputTransport(BaseInputTransport):
if sample_rate != self._current_sample_rate:
self._current_sample_rate = sample_rate
self._vad_analyzer = VADAnalyzer(
sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels
)
if self._params.vad_enabled:
self._vad_analyzer = VADAnalyzer(
sample_rate=self._current_sample_rate,
num_channels=self._params.audio_in_channels,
)
return AudioRawFrame(
audio=audio_data.tobytes(),
@@ -445,11 +444,11 @@ class LiveKitOutputTransport(BaseOutputTransport):
if isinstance(d, TTFBMetricsData):
if "ttfb" not in metrics:
metrics["ttfb"] = []
metrics["ttfb"].append(d.model_dump())
metrics["ttfb"].append(d.model_dump(exclude_none=True))
elif isinstance(d, ProcessingMetricsData):
if "processing" not in metrics:
metrics["processing"] = []
metrics["processing"].append(d.model_dump())
metrics["processing"].append(d.model_dump(exclude_none=True))
elif isinstance(d, LLMUsageMetricsData):
if "tokens" not in metrics:
metrics["tokens"] = []
@@ -457,7 +456,7 @@ class LiveKitOutputTransport(BaseOutputTransport):
elif isinstance(d, TTSUsageMetricsData):
if "characters" not in metrics:
metrics["characters"] = []
metrics["characters"].append(d.model_dump())
metrics["characters"].append(d.model_dump(exclude_none=True))
message = LiveKitTransportMessageFrame(
message={"type": "pipecat-metrics", "metrics": metrics}
@@ -494,13 +493,20 @@ class LiveKitTransport(BaseTransport):
):
super().__init__(input_name=input_name, output_name=output_name, loop=loop)
self._url = url
self._token = token
self._room_name = room_name
callbacks = LiveKitCallbacks(
on_connected=self._on_connected,
on_disconnected=self._on_disconnected,
on_participant_connected=self._on_participant_connected,
on_participant_disconnected=self._on_participant_disconnected,
on_audio_track_subscribed=self._on_audio_track_subscribed,
on_audio_track_unsubscribed=self._on_audio_track_unsubscribed,
on_data_received=self._on_data_received,
on_first_participant_joined=self._on_first_participant_joined,
)
self._params = params
self._client = LiveKitTransportClient(
url, token, room_name, self._params, self._create_callbacks(), self._loop
url, token, room_name, self._params, callbacks, self._loop
)
self._input: LiveKitInputTransport | None = None
self._output: LiveKitOutputTransport | None = None
@@ -516,24 +522,12 @@ class LiveKitTransport(BaseTransport):
self._register_event_handler("on_participant_left")
self._register_event_handler("on_call_state_updated")
def _create_callbacks(self) -> LiveKitCallbacks:
return LiveKitCallbacks(
on_connected=self._on_connected,
on_disconnected=self._on_disconnected,
on_participant_connected=self._on_participant_connected,
on_participant_disconnected=self._on_participant_disconnected,
on_audio_track_subscribed=self._on_audio_track_subscribed,
on_audio_track_unsubscribed=self._on_audio_track_unsubscribed,
on_data_received=self._on_data_received,
on_first_participant_joined=self._on_first_participant_joined,
)
def input(self) -> FrameProcessor:
def input(self) -> LiveKitInputTransport:
if not self._input:
self._input = LiveKitInputTransport(self._client, self._params, name=self._input_name)
return self._input
def output(self) -> FrameProcessor:
def output(self) -> LiveKitOutputTransport:
if not self._output:
self._output = LiveKitOutputTransport(
self._client, self._params, name=self._output_name
@@ -544,7 +538,7 @@ class LiveKitTransport(BaseTransport):
def participant_id(self) -> str:
return self._client.participant_id
async def send_audio(self, frame: AudioRawFrame):
async def send_audio(self, frame: OutputAudioRawFrame):
if self._output:
await self._output.process_frame(frame, FrameDirection.DOWNSTREAM)