Save Smart Turn input data if SMART_TURN_LOG_DATA is set
This commit is contained in:
11
.gitignore
vendored
11
.gitignore
vendored
@@ -4,7 +4,14 @@ __pycache__/
|
||||
*~
|
||||
venv
|
||||
.venv
|
||||
/.idea
|
||||
.idea
|
||||
.gradle
|
||||
.next
|
||||
next-env.d.ts
|
||||
local.properties
|
||||
*.log
|
||||
*.lock
|
||||
smart_turn_audio_log
|
||||
#*#
|
||||
|
||||
# Distribution / Packaging
|
||||
@@ -27,7 +34,7 @@ share/python-wheels/
|
||||
*.egg
|
||||
MANIFEST
|
||||
.DS_Store
|
||||
.env
|
||||
.env*
|
||||
fly.toml
|
||||
|
||||
# Examples
|
||||
|
||||
@@ -16,6 +16,7 @@ import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
||||
from pipecat.utils.env import env_truthy
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
@@ -48,6 +49,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._log_data = env_truthy("SMART_TURN_LOG_DATA", default=False)
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Load bundled model
|
||||
model_name = "smart-turn-v3.2-cpu.onnx"
|
||||
@@ -81,6 +84,49 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
|
||||
logger.debug("Loaded Local Smart Turn v3.x")
|
||||
|
||||
def _write_audio_to_wav(
|
||||
self, audio_array: np.ndarray, sample_rate: int = 16000, suffix: str = ""
|
||||
) -> None:
|
||||
"""Write audio data to a WAV file in a background thread.
|
||||
|
||||
Args:
|
||||
audio_array: The audio data as a numpy array (float32, normalized to [-1, 1]).
|
||||
sample_rate: The sample rate of the audio data.
|
||||
suffix: Optional suffix to append to the filename (e.g., "_raw", "_padded").
|
||||
"""
|
||||
import wave
|
||||
from datetime import datetime
|
||||
import os
|
||||
import threading
|
||||
|
||||
# Generate filename with current timestamp (millisecond precision)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d__%H:%M:%S.%f")[:-3]
|
||||
log_dir = "./smart_turn_audio_log"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
filename = os.path.join(log_dir, f"{timestamp}{suffix}.wav")
|
||||
|
||||
# Make a copy of the audio data to avoid issues with the array being modified
|
||||
audio_copy = audio_array.copy()
|
||||
|
||||
def write_wav():
|
||||
try:
|
||||
# Convert float32 audio to int16 for WAV file
|
||||
audio_int16 = (audio_copy * 32767).astype(np.int16)
|
||||
|
||||
with wave.open(filename, "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # Mono
|
||||
wav_file.setsampwidth(2) # 2 bytes for int16
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(audio_int16.tobytes())
|
||||
|
||||
logger.debug(f"Wrote audio to {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write audio to {filename}: {e}")
|
||||
|
||||
# Start background thread to write the WAV file
|
||||
thread = threading.Thread(target=write_wav, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Predict end-of-turn using local ONNX model."""
|
||||
|
||||
@@ -95,6 +141,8 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
|
||||
return audio_array
|
||||
|
||||
audio_for_logging = audio_array
|
||||
|
||||
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
|
||||
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
|
||||
|
||||
@@ -122,7 +170,11 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
# Make prediction (1 for Complete, 0 for Incomplete)
|
||||
prediction = 1 if probability > 0.5 else 0
|
||||
|
||||
if self._log_data:
|
||||
suffix = "_complete" if prediction == 1 else "_incomplete"
|
||||
self._write_audio_to_wav(audio_for_logging, sample_rate=16000, suffix=suffix)
|
||||
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"probability": probability,
|
||||
}
|
||||
}
|
||||
54
src/pipecat/utils/env.py
Normal file
54
src/pipecat/utils/env.py
Normal file
@@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Environment variable helpers.
|
||||
|
||||
This module provides small, centralized parsing helpers for environment variables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class InvalidEnvVarValueError(ValueError):
|
||||
"""Raised when an environment variable value cannot be parsed."""
|
||||
|
||||
def __init__(self, name: str, value: str, expected: str):
|
||||
super().__init__(f"Invalid value for env var {name!r}: {value!r}. Expected {expected}.")
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.expected = expected
|
||||
|
||||
|
||||
def env_truthy(name: str, default: bool = False) -> bool:
|
||||
"""Interpret an environment variable as a boolean.
|
||||
|
||||
- If the variable is **not set**, returns `default`.
|
||||
- If the variable is set to a recognized boolean string, returns the parsed value.
|
||||
- Otherwise, raises `InvalidEnvVarValueError`.
|
||||
|
||||
Recognized values (case-insensitive, whitespace ignored):
|
||||
- Truthy: "1", "true", "yes", "y", "on"
|
||||
- Falsy: "0", "false", "no", "n", "off", ""
|
||||
"""
|
||||
|
||||
raw = os.getenv(name)
|
||||
if raw is None:
|
||||
return default
|
||||
|
||||
val = raw.strip().lower()
|
||||
if val in {"1", "true", "yes", "y", "on"}:
|
||||
return True
|
||||
if val in {"0", "false", "no", "n", "off", ""}:
|
||||
return False
|
||||
|
||||
raise InvalidEnvVarValueError(
|
||||
name=name,
|
||||
value=raw,
|
||||
expected='true or false',
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user