From a2e76bcad877f68314a2e005f0b2df3ac96b1529 Mon Sep 17 00:00:00 2001 From: marcus-daily <111281783+marcus-daily@users.noreply.github.com> Date: Mon, 18 Aug 2025 14:54:52 +0100 Subject: [PATCH] Smart Turn V3 support --- examples/foundational/38b-smart-turn-local.py | 23 +--- pyproject.toml | 1 + .../turn/smart_turn/local_smart_turn_v3.py | 102 ++++++++++++++++++ 3 files changed, 108 insertions(+), 18 deletions(-) create mode 100644 src/pipecat/audio/turn/smart_turn/local_smart_turn_v3.py diff --git a/examples/foundational/38b-smart-turn-local.py b/examples/foundational/38b-smart-turn-local.py index 372811c39..f736d5440 100644 --- a/examples/foundational/38b-smart-turn-local.py +++ b/examples/foundational/38b-smart-turn-local.py @@ -11,7 +11,7 @@ from dotenv import load_dotenv from loguru import logger from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams -from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2 +from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3 from pipecat.audio.vad.silero import SileroVADAnalyzer from pipecat.audio.vad.vad_analyzer import VADParams from pipecat.frames.frames import LLMRunFrame @@ -31,20 +31,7 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams load_dotenv(override=True) # To use this locally, set the environment variable LOCAL_SMART_TURN_MODEL_PATH -# to the path where the smart-turn repo is cloned. -# -# Example setup: -# -# # Git LFS (Large File Storage) -# brew install git-lfs -# # Hugging Face uses LFS to store large model files, including .mlpackage -# git lfs install -# # Clone the repo with the smart_turn_classifier.mlpackage -# git clone https://huggingface.co/pipecat-ai/smart-turn-v2 -# -# Then set the env variable: -# export LOCAL_SMART_TURN_MODEL_PATH=./smart-turn -# or add it to your .env file +# to the Smart Turn v3 ONNX model file. 
smart_turn_model_path = os.getenv("LOCAL_SMART_TURN_MODEL_PATH") # We store functions so objects (e.g. SileroVADAnalyzer) don't get @@ -55,7 +42,7 @@ transport_params = { audio_in_enabled=True, audio_out_enabled=True, vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), - turn_analyzer=LocalSmartTurnAnalyzerV2( + turn_analyzer=LocalSmartTurnAnalyzerV3( smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams() ), ), @@ -63,7 +50,7 @@ transport_params = { audio_in_enabled=True, audio_out_enabled=True, vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), - turn_analyzer=LocalSmartTurnAnalyzerV2( + turn_analyzer=LocalSmartTurnAnalyzerV3( smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams() ), ), @@ -71,7 +58,7 @@ transport_params = { audio_in_enabled=True, audio_out_enabled=True, vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), - turn_analyzer=LocalSmartTurnAnalyzerV2( + turn_analyzer=LocalSmartTurnAnalyzerV3( smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams() ), ), diff --git a/pyproject.toml b/pyproject.toml index e9a32b2ac..4942c4916 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ sambanova = [] sarvam = [ "websockets>=13.1,<15.0" ] sentry = [ "sentry-sdk~=2.23.1" ] local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ] +local-smart-turn-v3 = [ "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3", "onnxruntime" ] remote-smart-turn = [] silero = [ "onnxruntime~=1.20.1" ] simli = [ "simli-ai~=0.1.10"] diff --git a/src/pipecat/audio/turn/smart_turn/local_smart_turn_v3.py b/src/pipecat/audio/turn/smart_turn/local_smart_turn_v3.py new file mode 100644 index 000000000..eeab4eb87 --- /dev/null +++ b/src/pipecat/audio/turn/smart_turn/local_smart_turn_v3.py @@ -0,0 +1,102 @@ +# +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Local ONNX turn analyzer for on-device ML inference 
using the smart-turn-v3 model.
+
+This module provides a smart turn analyzer that uses an ONNX model for
+local end-of-turn detection without requiring network connectivity.
+"""
+
+from typing import Any, Dict
+
+import numpy as np
+from loguru import logger
+
+from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
+
+try:
+    import onnxruntime as ort
+    from transformers import WhisperFeatureExtractor
+except ModuleNotFoundError as e:
+    logger.error(f"Exception: {e}")
+    logger.error(
+        "In order to use LocalSmartTurnAnalyzerV3, you need to `pip install pipecat-ai[local-smart-turn-v3]`."
+    )
+    raise Exception(f"Missing module: {e}") from e
+
+
+class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
+    """Local turn analyzer using the smart-turn-v3 ONNX model.
+
+    Provides end-of-turn detection using a locally-stored ONNX model,
+    enabling offline operation without network dependencies. Audio is
+    converted to Whisper input features and scored with ONNX Runtime.
+    """
+
+    def __init__(self, *, smart_turn_model_path: str, **kwargs):
+        """Initialize the local ONNX smart-turn-v3 analyzer.
+
+        Args:
+            smart_turn_model_path: Path to the ONNX model file.
+            **kwargs: Additional arguments passed to BaseSmartTurn.
+
+        Raises:
+            ValueError: If smart_turn_model_path is empty or None.
+        """
+        super().__init__(**kwargs)
+
+        if not smart_turn_model_path:
+            raise ValueError("smart_turn_model_path must be provided")
+
+        logger.debug("Loading Local Smart Turn v3 model...")
+
+        # chunk_length=8 lines up with the 8-second window used in _predict_endpoint
+        self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
+        self._session = ort.InferenceSession(smart_turn_model_path)
+
+        logger.debug("Loaded Local Smart Turn v3")
+
+    async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
+        """Predict end-of-turn using the local ONNX model.
+
+        Args:
+            audio_array: Audio samples; they are processed as 16 kHz audio.
+
+        Returns:
+            Dict with "prediction" (1 for complete, 0 for incomplete) and
+            "probability" (the model output compared against 0.5).
+        """
+        # Keep only the last 8 seconds; left-pad shorter audio with zeros
+        max_samples = 8 * 16000
+        audio_array = audio_array[-max_samples:]
+        if len(audio_array) < max_samples:
+            audio_array = np.pad(audio_array, (max_samples - len(audio_array), 0))
+
+        # Process audio using Whisper's feature extractor. NumPy tensors are
+        # requested directly so no torch round-trip is needed for ONNX input.
+        inputs = self._feature_extractor(
+            audio_array,
+            sampling_rate=16000,
+            return_tensors="np",
+            padding="max_length",
+            max_length=max_samples,
+            truncation=True,
+            do_normalize=True,
+        )
+        input_features = inputs.input_features.astype(np.float32)
+
+        # NOTE(review): run() is synchronous, so this coroutine blocks the
+        # event loop while the model executes.
+        outputs = self._session.run(None, {"input_features": input_features})
+
+        # Extract probability (ONNX model returns sigmoid probabilities)
+        probability = outputs[0][0].item()
+
+        return {
+            # 1 for Complete, 0 for Incomplete
+            "prediction": 1 if probability > 0.5 else 0,
+            "probability": probability,
+        }