Integrate EOU and VAD
processors/tracks.py | 168 lines | new file
@@ -0,0 +1,168 @@
"""Audio track processing for WebRTC."""

import asyncio
import fractions
from typing import Optional
from loguru import logger

# Try to import aiortc (optional for WebRTC functionality)
try:
    from aiortc import AudioStreamTrack
    AIORTC_AVAILABLE = True
except ImportError:
    AIORTC_AVAILABLE = False
    AudioStreamTrack = object  # Dummy class for type hints

# Try to import PyAV (optional for audio resampling)
try:
    from av import AudioFrame, AudioResampler
    AV_AVAILABLE = True
except ImportError:
    AV_AVAILABLE = False

    # Create dummy classes for type hints
    class AudioFrame:
        pass

    class AudioResampler:
        pass

import numpy as np


class Resampled16kTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Audio track that resamples input to 16kHz mono PCM.

    Wraps an existing MediaStreamTrack and converts its output
    to 16kHz mono 16-bit PCM format for the pipeline.
    """

    def __init__(self, track, target_sample_rate: int = 16000):
        """
        Initialize resampled track.

        Args:
            track: Source MediaStreamTrack
            target_sample_rate: Target sample rate (default: 16000)
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - Resampled16kTrack cannot be used")

        super().__init__()
        self.track = track
        self.target_sample_rate = target_sample_rate

        if AV_AVAILABLE:
            self.resampler = AudioResampler(
                format="s16",
                layout="mono",
                rate=target_sample_rate
            )
        else:
            logger.warning("PyAV not available, audio resampling disabled")
            self.resampler = None

        self._closed = False

    async def recv(self):
        """
        Receive and resample the next audio frame.

        Returns:
            Resampled AudioFrame at 16kHz mono
        """
        if self._closed:
            raise RuntimeError("Track is closed")

        # Get frame from source track
        frame = await self.track.recv()

        # Resample the frame if PyAV is available
        if AV_AVAILABLE and self.resampler:
            resampled = self.resampler.resample(frame)
            # PyAV >= 9 returns a list of frames; older versions return a single frame
            if isinstance(resampled, list):
                if not resampled:
                    return frame
                resampled = resampled[0]
            # Ensure the frame reports the correct sample rate
            resampled.sample_rate = self.target_sample_rate
            return resampled
        else:
            # Return frame as-is if PyAV is not available
            return frame

    async def stop(self) -> None:
        """Stop the track and clean up resources."""
        self._closed = True
        if hasattr(self, 'resampler') and self.resampler:
            del self.resampler
        logger.debug("Resampled track stopped")


class SineWaveTrack(AudioStreamTrack if AIORTC_AVAILABLE else object):
    """
    Synthetic audio track that generates a sine wave.

    Useful for testing without requiring real audio input.
    """

    def __init__(self, sample_rate: int = 16000, frequency: int = 440):
        """
        Initialize sine wave track.

        Args:
            sample_rate: Audio sample rate (default: 16000)
            frequency: Sine wave frequency in Hz (default: 440)
        """
        if not AIORTC_AVAILABLE:
            raise RuntimeError("aiortc not available - SineWaveTrack cannot be used")

        super().__init__()
        self.sample_rate = sample_rate
        self.frequency = frequency
        self.counter = 0
        self._stopped = False

    async def recv(self):
        """
        Generate the next audio frame with sine wave data.

        Returns:
            AudioFrame with sine wave data
        """
        if self._stopped:
            raise RuntimeError("Track is stopped")

        # Generate 20ms of audio
        samples = int(self.sample_rate * 0.02)
        pts = self.counter
        time_base = fractions.Fraction(1, self.sample_rate)

        # Time axis for this frame
        t = np.linspace(
            self.counter / self.sample_rate,
            (self.counter + samples) / self.sample_rate,
            samples,
            endpoint=False
        )

        # Generate sine wave (int16 PCM at half amplitude)
        data = (0.5 * np.sin(2 * np.pi * self.frequency * t) * 32767).astype(np.int16)

        # Update counter
        self.counter += samples

        # Pace frame delivery to roughly real time (20ms per frame)
        await asyncio.sleep(0.02)

        # Create an AudioFrame if PyAV is available
        if AV_AVAILABLE:
            frame = AudioFrame.from_ndarray(data.reshape(1, -1), format='s16', layout='mono')
            frame.pts = pts
            frame.time_base = time_base
            frame.sample_rate = self.sample_rate
            return frame
        else:
            # Return a simple data structure if PyAV is not available
            return {
                'data': data,
                'sample_rate': self.sample_rate,
                'pts': pts,
                'time_base': time_base
            }

    def stop(self) -> None:
        """Stop the track."""
        self._stopped = True
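
For orientation, here is a minimal sketch of how these tracks might be wired into the pipeline: wrapping a remote WebRTC audio track in Resampled16kTrack from an aiortc track handler, and substituting SineWaveTrack when testing without live audio. The `pipeline` object and its `process_track` coroutine are hypothetical placeholders, not part of this commit.

import asyncio

from aiortc import RTCPeerConnection

from processors.tracks import Resampled16kTrack, SineWaveTrack


async def handle_peer(pc: RTCPeerConnection, pipeline) -> None:
    """Wrap incoming audio in a 16kHz mono track and hand it to the pipeline."""

    @pc.on("track")
    def on_track(track):
        if track.kind == "audio":
            # `pipeline.process_track` is a hypothetical consumer of 16kHz mono PCM
            resampled = Resampled16kTrack(track)
            asyncio.ensure_future(pipeline.process_track(resampled))


async def run_synthetic_test(pipeline) -> None:
    """Exercise the pipeline with a 440 Hz tone instead of a live microphone."""
    tone = SineWaveTrack(sample_rate=16000, frequency=440)
    await pipeline.process_track(tone)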