169 lines
4.8 KiB
Python
169 lines
4.8 KiB
Python
"""Audio track processing for WebRTC."""
|
|
|
|
import asyncio
|
|
import fractions
|
|
from typing import Optional
|
|
from loguru import logger
|
|
|
|
# aiortc is an optional dependency that is only needed for WebRTC functionality.
try:
    from aiortc import AudioStreamTrack
except ImportError:
    # Record that aiortc is missing and bind a placeholder base class so the
    # name AudioStreamTrack still resolves when this module is imported.
    AIORTC_AVAILABLE = False
    AudioStreamTrack = object  # Dummy class for type hints
else:
    AIORTC_AVAILABLE = True
|
|
|
|
# PyAV is an optional dependency that is only needed for audio resampling.
try:
    from av import AudioFrame, AudioResampler
except ImportError:
    AV_AVAILABLE = False

    # Empty stand-ins so the names AudioFrame / AudioResampler still exist
    # (e.g. for type hints) when PyAV is not installed.
    class AudioFrame:
        pass

    class AudioResampler:
        pass
else:
    AV_AVAILABLE = True
|
|
|
|
import numpy as np
|
|
|
|
|
|
class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Audio track that resamples input to mono 16-bit PCM (16kHz by default).

    Wraps an existing MediaStreamTrack and converts the frames it yields
    to mono signed 16-bit PCM at ``target_sample_rate`` for the pipeline.
    When PyAV is not installed, frames from the source track are returned
    unchanged.
    """

    def __init__(self, track, target_sample_rate: int = 16000):
        """
        Initialize resampled track.

        Args:
            track: Source MediaStreamTrack; its ``recv()`` is awaited for input
            target_sample_rate: Target sample rate (default: 16000)

        Raises:
            RuntimeError: If aiortc is not installed.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used")

        super().__init__()
        self.track = track
        self.target_sample_rate = target_sample_rate

        if AV_AVAILABLE:
            self.resampler = AudioResampler(
                format="s16",
                layout="mono",
                rate=target_sample_rate,
            )
        else:
            logger.warning("PyAV not available, audio resampling disabled")
            self.resampler = None

        # Resampled frames that have been produced but not yet returned by
        # recv(). Needed because one input frame may yield several output
        # frames. Only ever holds a handful of entries, so a list is enough.
        self._pending = []
        self._closed = False

    @staticmethod
    def _as_frame_list(result):
        """
        Normalize the return value of ``AudioResampler.resample()`` to a list.

        Newer PyAV releases return a list of frames (possibly empty, possibly
        with several entries); older releases return a single frame or None.
        """
        if result is None:
            return []
        if isinstance(result, (list, tuple)):
            return list(result)
        return [result]

    async def recv(self):
        """
        Receive and resample next audio frame.

        Returns:
            The next resampled AudioFrame (mono, s16, ``target_sample_rate``),
            or the untouched source frame when no resampler is available.

        Raises:
            RuntimeError: If the track has been stopped.
        """
        if self._closed:
            raise RuntimeError("Track is closed")

        # Hold a local reference so a stop() that runs while we are awaiting
        # the source track cannot remove the resampler mid-call.
        resampler = self.resampler if AV_AVAILABLE else None
        if resampler is None:
            # Return frame as-is if AV is not available
            return await self.track.recv()

        # The resampler may buffer and emit nothing for a given input frame,
        # so keep feeding it source frames until at least one frame comes out.
        while not self._pending:
            frame = await self.track.recv()
            for resampled_frame in self._as_frame_list(resampler.resample(frame)):
                # Ensure the frame has the correct format
                resampled_frame.sample_rate = self.target_sample_rate
                self._pending.append(resampled_frame)

        return self._pending.pop(0)

    async def stop(self) -> None:
        """
        Stop the track and cleanup resources.

        This override is a coroutine, so it only takes effect when awaited.
        After it has run, ``recv()`` raises RuntimeError. The wrapped source
        track is not stopped here.
        """
        self._closed = True
        self._pending = []
        # Drop the reference rather than deleting the attribute so that later
        # reads of self.resampler keep working.
        self.resampler = None

        # Run the base track's own stop logic too (in aiortc this marks the
        # track as ended and emits the "ended" event). Guarded because the
        # base class is plain ``object`` when aiortc is not installed.
        parent_stop = getattr(super(), "stop", None)
        if callable(parent_stop):
            parent_stop()

        logger.debug("Resampled track stopped")
|
|
|
|
|
|
class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Synthetic audio track that generates a sine wave.

    Useful for testing without requiring real audio input. Each ``recv()``
    call produces the next 20ms chunk of a continuous half-amplitude tone.
    """

    def __init__(self, sample_rate: int = 16000, frequency: int = 440):
        """
        Initialize sine wave track.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            frequency: Sine wave frequency in Hz (default: 440)

        Raises:
            RuntimeError: If aiortc is not installed.
            ValueError: If ``sample_rate`` is not positive.
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - SineWaveTrack cannot be used")
        if sample_rate <= 0:
            # Fail here instead of later inside recv(), where a zero or
            # negative rate breaks the time base / sample grid computation.
            raise ValueError(f"sample_rate must be positive, got {sample_rate!r}")

        super().__init__()
        self.sample_rate = sample_rate
        self.frequency = frequency
        # Number of samples generated so far; doubles as the next frame's pts.
        self.counter = 0
        self._stopped = False

    async def recv(self):
        """
        Generate next audio frame with sine wave.

        Returns:
            AudioFrame with sine wave data (mono, s16) when PyAV is available,
            otherwise a dict with the keys ``data``, ``sample_rate``, ``pts``
            and ``time_base``.

        Raises:
            RuntimeError: If the track has been stopped.
        """
        if self._stopped:
            raise RuntimeError("Track is stopped")

        # NOTE(review): this method never sleeps, so frames are produced as
        # fast as the caller asks for them - confirm the consumer does not
        # rely on real-time pacing.

        # Generate 20ms of audio
        samples = int(self.sample_rate * 0.02)
        pts = self.counter
        time_base = fractions.Fraction(1, self.sample_rate)

        # Sample instants continue from where the previous frame ended, which
        # keeps the waveform phase-continuous across frames.
        t = np.linspace(
            pts / self.sample_rate,
            (pts + samples) / self.sample_rate,
            samples,
            endpoint=False,
        )

        # Generate sine wave (Int16 PCM) at half of full scale
        data = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16)

        # Update counter
        self.counter += samples

        if not AV_AVAILABLE:
            # Return simple data structure if AV is not available
            return {
                'data': data,
                'sample_rate': self.sample_rate,
                'pts': pts,
                'time_base': time_base,
            }

        # The array is reshaped to (1, samples) before being handed to PyAV.
        frame = AudioFrame.from_ndarray(data.reshape(1, -1), format='s16', layout='mono')
        frame.pts = pts
        frame.time_base = time_base
        frame.sample_rate = self.sample_rate
        return frame

    def stop(self) -> None:
        """
        Stop the track.

        After this call ``recv()`` raises RuntimeError.
        """
        self._stopped = True

        # Run the base track's own stop logic too (in aiortc this marks the
        # track as ended and emits the "ended" event). Guarded because the
        # base class is plain ``object`` when aiortc is not installed.
        parent_stop = getattr(super(), "stop", None)
        if callable(parent_stop):
            parent_stop()
|