Files
pipecat/tests/test_rnnoise_cancellation.py
2026-01-30 10:07:34 -08:00

178 lines
6.8 KiB
Python

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
import numpy as np
try:
import pyrnnoise
except ImportError:
pyrnnoise = None
from pipecat.audio.filters.rnnoise_filter import RNNoiseFilter
from pipecat.frames.frames import FilterEnableFrame
class TestRNNoiseCancellation(unittest.IsolatedAsyncioTestCase):
    """Check that RNNoiseFilter reduces white noise on a synthetic speech-like signal."""

    @staticmethod
    def _estimate_delay(reference, output, search_range, window_len):
        """Return the lag, in samples, of ``output`` relative to ``reference``.

        A ``window_len`` slice taken from the middle of ``reference`` is
        cross-correlated against ``output`` over shifts from ``-search_range``
        to ``+search_range`` samples. A positive result means ``output`` is
        delayed with respect to ``reference``.

        Both arrays must have the same length, and that length must be at
        least ``2 * (window_len + search_range)`` so every slice stays in range.
        """
        mid_point = len(reference) // 2
        ref_sig = reference[mid_point : mid_point + window_len].astype(float)
        target_sig = output[
            mid_point - search_range : mid_point + window_len + search_range
        ].astype(float)
        correlation = np.correlate(target_sig, ref_sig, mode="valid")
        # Index 0 of the "valid" result pairs the reference window with the
        # target slice that starts at (mid_point - search_range), i.e. a shift
        # of -search_range. The lag is therefore the peak index minus
        # search_range.
        return int(np.argmax(correlation)) - search_range

    @staticmethod
    def _align(clean, noisy, output, delay):
        """Trim three equal-length signals so ``output`` lines up with ``clean``.

        ``noisy`` is already time-aligned with ``clean``; only ``output`` is
        shifted by ``delay`` samples.

        Returns:
            Tuple ``(aligned_clean, aligned_noisy, aligned_output)`` of equal length.
        """
        if delay > 0:
            # Output is late: output[delay:] corresponds to clean[0:].
            aligned_output = output[delay:]
            usable = len(aligned_output)
            return clean[:usable], noisy[:usable], aligned_output
        if delay < 0:
            # Output is early (not expected from a causal filter, but handled).
            return clean[-delay:], noisy[-delay:], output[:delay]
        return clean, noisy, output

    async def test_rnnoise_cancellation_functionality(self):
        """Run noisy synthetic speech through RNNoiseFilter and verify noise drops.

        The test passes when the noise power measured in the silent gaps of
        the signal is lower after filtering. If the overall MSE against the
        clean signal does not improve, the silent-gap check is mandatory.
        Skipped when ``pyrnnoise`` is not installed.
        """
        # 1. Check for pyrnnoise
        if pyrnnoise is None:
            self.skipTest("pyrnnoise not installed. Cannot verify actual noise cancellation.")
        print("\nStarting Noise Cancellation Test")

        # 2. Generate clean speech-like audio: 200 Hz fundamental plus two harmonics.
        sample_rate = 48000
        duration = 2.0
        t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
        clean_signal = np.sin(2 * np.pi * 200 * t) * 0.5
        clean_signal += np.sin(2 * np.pi * 400 * t) * 0.3
        clean_signal += np.sin(2 * np.pi * 600 * t) * 0.2

        # Gate the tone into bursts. sin(2*pi*2*t) has a 0.5 s period and the
        # clip zeroes its negative half, so each period is 0.25 s of signal
        # followed by 0.25 s of silence.
        envelope = np.clip(np.sin(2 * np.pi * 2 * t), 0, 1)
        clean_signal *= envelope

        # 3. Add white noise. A fixed seed keeps the test deterministic.
        noise_level = 0.1  # kept low so the speech stays clear enough for alignment
        rng = np.random.default_rng(0)
        noise = rng.normal(0, noise_level, len(t))
        noisy_signal = np.clip(clean_signal + noise, -1, 1)

        # Convert to mono 16-bit PCM.
        noisy_int16 = (noisy_signal * 32767).astype(np.int16)
        noisy_bytes = noisy_int16.tobytes()
        clean_int16 = (clean_signal * 32767).astype(np.int16)
        print(f"Generated 2s of noisy audio at {sample_rate}Hz")

        # 4. Initialize RNNoiseFilter
        rnnoise_filter = RNNoiseFilter()
        await rnnoise_filter.start(sample_rate)

        # 5. Process in chunks. 960 bytes = 480 int16 samples = 10 ms at 48 kHz.
        chunk_size = 960
        processed_chunks = []
        try:
            await rnnoise_filter.process_frame(FilterEnableFrame(enable=True))
            for offset in range(0, len(noisy_bytes), chunk_size):
                chunk = noisy_bytes[offset : offset + chunk_size]
                processed_chunks.append(await rnnoise_filter.filter(chunk))
        finally:
            # Always release the filter, even when processing raised.
            await rnnoise_filter.stop()
        processed_audio = b"".join(processed_chunks)
        print(f"Output audio size: {len(processed_audio)}")

        # 6. Bring all three signals to a common length.
        output_int16 = np.frombuffer(processed_audio, dtype=np.int16)
        min_len = min(len(clean_int16), len(output_int16))
        clean_trunc = clean_int16[:min_len]
        output_trunc = output_int16[:min_len]
        noisy_trunc = noisy_int16[:min_len]

        # 7. Compensate for the filter's processing delay.
        search_range = 2400  # +/- 50 ms at 48 kHz
        window_len = 4800  # 100 ms correlation window taken mid-signal
        # Without enough output the correlation slices would run out of range
        # (negative start index) and yield a meaningless delay.
        self.assertGreaterEqual(
            min_len,
            2 * (window_len + search_range),
            "Filter produced too little audio to estimate the processing delay",
        )
        delay = self._estimate_delay(clean_trunc, output_trunc, search_range, window_len)
        print(f"Detected delay: {delay} samples ({delay / sample_rate * 1000:.2f} ms)")
        aligned_clean, aligned_noisy, aligned_output = self._align(
            clean_trunc, noisy_trunc, output_trunc, delay
        )

        # MSE against the clean reference, before and after filtering.
        clean_f = aligned_clean.astype(float)
        mse_input = np.mean((aligned_noisy.astype(float) - clean_f) ** 2)
        mse_output = np.mean((aligned_output.astype(float) - clean_f) ** 2)
        print(f"MSE (Input vs Clean): {mse_input:.2f}")
        print(f"MSE (Output vs Clean): {mse_output:.2f}")

        # Noise floor check: compare power only where the clean signal is
        # (nearly) silent, so speech distortion cannot mask the result.
        threshold = 100  # amplitude threshold (out of 32767)
        silent_mask = np.abs(aligned_clean) < threshold
        has_enough_silence = np.sum(silent_mask) > 1000
        if has_enough_silence:
            noise_power_input = np.mean(aligned_noisy[silent_mask].astype(float) ** 2)
            noise_power_output = np.mean(aligned_output[silent_mask].astype(float) ** 2)
            print(f"Noise Power in Silence (Input): {noise_power_input:.2f}")
            print(f"Noise Power in Silence (Output): {noise_power_output:.2f}")
            self.assertLess(
                noise_power_output, noise_power_input, "Noise power in silence not reduced"
            )
        else:
            print("Warning: Not enough silent samples found for noise floor check.")

        # The filter may distort speech, so a worse overall MSE is tolerated,
        # but only when the silent-region check above actually ran (and passed).
        if mse_output >= mse_input:
            print(
                "Warning: Overall MSE did not improve (speech distortion?). Relying on Noise Power check."
            )
            self.assertTrue(
                has_enough_silence,
                "MSE did not improve and there were too few silent samples to verify noise reduction",
            )
        print("Test Passed: Noise cancellation verified.")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()