Save Smart Turn input data if SMART_TURN_LOG_DATA is set
This commit is contained in:
11
.gitignore
vendored
11
.gitignore
vendored
@@ -4,7 +4,14 @@ __pycache__/
|
|||||||
*~
|
*~
|
||||||
venv
|
venv
|
||||||
.venv
|
.venv
|
||||||
/.idea
|
.idea
|
||||||
|
.gradle
|
||||||
|
.next
|
||||||
|
next-env.d.ts
|
||||||
|
local.properties
|
||||||
|
*.log
|
||||||
|
*.lock
|
||||||
|
smart_turn_audio_log
|
||||||
#*#
|
#*#
|
||||||
|
|
||||||
# Distribution / Packaging
|
# Distribution / Packaging
|
||||||
@@ -27,7 +34,7 @@ share/python-wheels/
|
|||||||
*.egg
|
*.egg
|
||||||
MANIFEST
|
MANIFEST
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.env
|
.env*
|
||||||
fly.toml
|
fly.toml
|
||||||
|
|
||||||
# Examples
|
# Examples
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import numpy as np
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
||||||
|
from pipecat.utils.env import env_truthy
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
@@ -48,6 +49,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
|||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self._log_data = env_truthy("SMART_TURN_LOG_DATA", default=False)
|
||||||
|
|
||||||
if not smart_turn_model_path:
|
if not smart_turn_model_path:
|
||||||
# Load bundled model
|
# Load bundled model
|
||||||
model_name = "smart-turn-v3.2-cpu.onnx"
|
model_name = "smart-turn-v3.2-cpu.onnx"
|
||||||
@@ -81,6 +84,49 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
|||||||
|
|
||||||
logger.debug("Loaded Local Smart Turn v3.x")
|
logger.debug("Loaded Local Smart Turn v3.x")
|
||||||
|
|
||||||
|
def _write_audio_to_wav(
    self, audio_array: np.ndarray, sample_rate: int = 16000, suffix: str = ""
) -> None:
    """Write audio data to a timestamped WAV file in a background thread.

    The file is created under ``./smart_turn_audio_log`` (relative to the
    current working directory) as 16-bit mono PCM. All filesystem work,
    including creating the log directory, runs on a daemon thread, and any
    failure there is logged instead of raised so that this debugging aid
    cannot break the caller.

    Args:
        audio_array: The audio data as a numpy array (float32, normalized to [-1, 1]).
            Samples outside that range are clipped before conversion.
        sample_rate: The sample rate of the audio data.
        suffix: Optional suffix to append to the filename (e.g., "_raw", "_padded").
    """
    import os
    import threading
    import wave
    from datetime import datetime

    # Timestamp with millisecond precision. The time part uses "-" rather
    # than ":" because ":" is not a legal file name character on Windows.
    timestamp = datetime.now().strftime("%Y-%m-%d__%H-%M-%S.%f")[:-3]
    log_dir = "./smart_turn_audio_log"
    filename = os.path.join(log_dir, f"{timestamp}{suffix}.wav")

    # Snapshot the samples now: the caller may reuse or modify its buffer
    # before the background thread gets to run.
    audio_copy = audio_array.copy()

    def write_wav():
        try:
            # Created here (not on the calling thread) so that a filesystem
            # error is handled by the except below instead of propagating.
            os.makedirs(log_dir, exist_ok=True)

            # Clip before scaling: an out-of-range sample would otherwise
            # overflow the int16 cast and wrap around.
            audio_int16 = (np.clip(audio_copy, -1.0, 1.0) * 32767).astype(np.int16)

            with wave.open(filename, "wb") as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(2)  # 2 bytes for int16
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(audio_int16.tobytes())

            logger.debug(f"Wrote audio to {filename}")
        except Exception as e:
            # Best-effort logging aid: report the failure and carry on.
            logger.error(f"Failed to write audio to {filename}: {e}")

    # Daemon thread so a pending write never blocks interpreter shutdown.
    threading.Thread(target=write_wav, daemon=True).start()
|
||||||
|
|
||||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||||
"""Predict end-of-turn using local ONNX model."""
|
"""Predict end-of-turn using local ONNX model."""
|
||||||
|
|
||||||
@@ -95,6 +141,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
|||||||
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
|
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
|
||||||
return audio_array
|
return audio_array
|
||||||
|
|
||||||
|
audio_for_logging = audio_array
|
||||||
|
|
||||||
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
|
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
|
||||||
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
|
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
|
||||||
|
|
||||||
@@ -122,7 +170,11 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
|||||||
# Make prediction (1 for Complete, 0 for Incomplete)
|
# Make prediction (1 for Complete, 0 for Incomplete)
|
||||||
prediction = 1 if probability > 0.5 else 0
|
prediction = 1 if probability > 0.5 else 0
|
||||||
|
|
||||||
|
if self._log_data:
|
||||||
|
suffix = "_complete" if prediction == 1 else "_incomplete"
|
||||||
|
self._write_audio_to_wav(audio_for_logging, sample_rate=16000, suffix=suffix)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"prediction": prediction,
|
"prediction": prediction,
|
||||||
"probability": probability,
|
"probability": probability,
|
||||||
}
|
}
|
||||||
54
src/pipecat/utils/env.py
Normal file
54
src/pipecat/utils/env.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024-2026, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
"""Environment variable helpers.
|
||||||
|
|
||||||
|
This module provides small, centralized parsing helpers for environment variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidEnvVarValueError(ValueError):
    """Raised when an environment variable value cannot be parsed.

    Attributes:
        name: Name of the environment variable.
        value: The raw value that could not be parsed.
        expected: Human-readable description of the accepted values.
    """

    def __init__(self, name: str, value: str, expected: str):
        """Build the error message and keep the individual parts as attributes.

        Args:
            name: Name of the environment variable.
            value: The raw value that could not be parsed.
            expected: Human-readable description of the accepted values.
        """
        super().__init__(f"Invalid value for env var {name!r}: {value!r}. Expected {expected}.")
        self.name = name
        self.value = value
        self.expected = expected

    def __reduce__(self):
        """Describe how to rebuild this exception for ``copy`` and ``pickle``.

        The default implementation re-invokes the class with ``self.args``,
        which holds only the formatted message and therefore does not match
        the three-argument constructor. Returning the original constructor
        arguments (plus the instance state) keeps the exception copyable and
        picklable.
        """
        return (type(self), (self.name, self.value, self.expected), self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
|
def env_truthy(name: str, default: bool = False) -> bool:
    """Read an environment variable and interpret it as a boolean flag.

    The value is compared after stripping leading/trailing whitespace and
    lower-casing it.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        True for "1", "true", "yes", "y" or "on"; False for "0", "false",
        "no", "n", "off" or an empty string; `default` when the variable
        is unset.

    Raises:
        InvalidEnvVarValueError: If the variable is set to any other value.
    """
    setting = os.environ.get(name)
    if setting is None:
        return default

    truthy_tokens = ("1", "true", "yes", "y", "on")
    falsy_tokens = ("0", "false", "no", "n", "off", "")

    token = setting.strip().lower()
    if token in truthy_tokens:
        return True
    if token in falsy_tokens:
        return False

    # Report the raw (un-normalized) value so the message shows exactly what was set.
    raise InvalidEnvVarValueError(name=name, value=setting, expected="true or false")
|
||||||
|
|
||||||
Reference in New Issue
Block a user