Files
pipecat/examples/foundational/13e-whisper-mlx.py
2025-04-24 17:14:18 -07:00

98 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import time
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import Frame, TranscriptionFrame, UserStoppedSpeakingFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.whisper.stt import MLXModel, WhisperSTTServiceMLX
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
# Load settings from a .env file (python-dotenv). override=True is passed so
# that values from .env take precedence over variables already set in the
# process environment.
load_dotenv(override=True)
# VAD stop threshold in seconds. It is handed to VADParams(stop_secs=...) in
# run_bot and is also the reference value TranscriptionLogger subtracts from
# when it reports latency. It is intentionally long so the transcription has
# time to finish before the VAD reports end of speech.
STOP_SECS = 2.0
class TranscriptionLogger(FrameProcessor):
    """Measures transcription latency.

    Uses the (intentionally) long STOP_SECS parameter to give the transcription
    time to finish, then outputs the timing between when the VAD first
    classified audio input as not-speech and the delivery of the last
    transcription frame.

    The value is computed when a ``UserStoppedSpeakingFrame`` is seen, as
    ``STOP_SECS`` minus the time elapsed since the most recent
    ``TranscriptionFrame``. Timestamps are taken with ``time.monotonic()`` so
    the measured interval cannot be skewed by wall-clock adjustments.
    """

    def __init__(self):
        super().__init__()
        # Seeded with the construction time so the subtraction below always has
        # a reference point, even before any transcription has been seen.
        self._last_transcription_time = time.monotonic()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Track transcription arrival times and log latency at end of speech.

        Args:
            frame: The frame delivered to this processor.
            direction: The direction argument passed through to the base class.
        """
        await super().process_frame(frame, direction)

        # NOTE(review): this method does not itself forward the frame anywhere;
        # presumably fine because this processor is the last stage of the
        # pipeline built in run_bot — confirm.
        if isinstance(frame, UserStoppedSpeakingFrame):
            since_last_transcription = time.monotonic() - self._last_transcription_time
            logger.debug(
                f"Transcription latency: {(STOP_SECS - since_last_transcription):.2f}"
            )

        if isinstance(frame, TranscriptionFrame):
            # Remember when the latest transcription arrived.
            self._last_transcription_time = time.monotonic()
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
    """Build and run the transcription-latency pipeline for one connection.

    The pipeline consists of the transport input (audio in enabled, Silero VAD
    configured with ``STOP_SECS`` as its stop threshold), the MLX Whisper STT
    service and a ``TranscriptionLogger``. The pipeline task is cancelled when
    the transport emits ``on_client_closed``.

    Args:
        webrtc_connection: Connection the transport is created on.
        _: Parsed command-line arguments; not used here.
    """
    logger.info(f"Starting bot")

    # Transport: audio input only, with VAD using the shared stop threshold.
    vad = SileroVADAnalyzer(params=VADParams(stop_secs=STOP_SECS))
    transport_params = TransportParams(audio_in_enabled=True, vad_analyzer=vad)
    transport = SmallWebRTCTransport(
        webrtc_connection=webrtc_connection,
        params=transport_params,
    )

    # Processing stages, in pipeline order.
    whisper = WhisperSTTServiceMLX(model=MLXModel.LARGE_V3_TURBO)
    latency_logger = TranscriptionLogger()
    stages = [transport.input(), whisper, latency_logger]

    task = PipelineTask(
        Pipeline(stages),
        params=PipelineParams(enable_metrics=True, report_only_initial_ttfb=False),
    )

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        # Only logged; the task keeps running.
        logger.info(f"Client disconnected")

    @transport.event_handler("on_client_closed")
    async def on_client_closed(transport, client):
        # A closed connection ends the session: cancel the pipeline task.
        logger.info(f"Client closed connection")
        await task.cancel()

    await PipelineRunner(handle_sigint=False).run(task)
# Script entry point: when executed directly, delegate to `main` from the
# `run` module. The import is kept inside the guard so that merely importing
# this file does not require `run` to be importable.
# NOTE(review): presumably `run.main` is what ends up invoking `run_bot` —
# confirm against the `run` module.
if __name__ == "__main__":
    from run import main
    main()