Integrate eou and vad
This commit is contained in:
131
core/pipeline.py
Normal file
131
core/pipeline.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Audio processing pipeline."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
from core.transports import BaseTransport
|
||||
from core.events import EventBus, get_event_bus
|
||||
from processors.vad import VADProcessor, SileroVAD
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class AudioPipeline:
    """
    Audio processing pipeline.

    Runs incoming audio chunks through the VAD processor, publishes every
    detected event on the event bus, and forwards "speaking", "silence" and
    "eou" events to the client through the transport.
    """

    def __init__(self, transport: BaseTransport, session_id: str):
        """
        Initialize audio pipeline.

        Args:
            transport: Transport instance for sending events/audio
            session_id: Session identifier; also sent as ``trackId`` in events
        """
        self.transport = transport
        self.session_id = session_id
        self.event_bus = get_event_bus()

        # VAD model and processor are configured entirely from app settings.
        self.vad_model = SileroVAD(
            model_path=settings.vad_model_path,
            sample_rate=settings.sample_rate,
        )
        self.vad_processor = VADProcessor(
            vad_model=self.vad_model,
            threshold=settings.vad_threshold,
            silence_threshold_ms=settings.vad_eou_threshold_ms,
            min_speech_duration_ms=settings.vad_min_speech_duration_ms,
        )

        # State
        self.is_bot_speaking = False
        self.interrupt_signal = asyncio.Event()
        self._running = True

        logger.info(f"Audio pipeline initialized for session {session_id}")

    async def process_input(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio chunk.

        The chunk is passed to the VAD processor. When the processor reports
        an event, it is published on the event bus and, for the event types
        "speaking", "silence" and "eou", also sent to the client. Any other
        event type is published but not forwarded. Errors are logged (with
        traceback) and swallowed so one bad chunk does not stop the stream.

        Args:
            pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
        """
        # After cleanup() the pipeline silently drops further audio.
        if not self._running:
            return

        try:
            result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)
            if not result:
                return

            event_type, probability = result

            # Emit event through event bus
            await self.event_bus.publish(event_type, {
                "trackId": self.session_id,
                "probability": probability,
            })

            # Send event to client
            if event_type == "speaking":
                logger.info(f"User speaking started (session {self.session_id})")
                # Read the clock once so "timestamp" and "startTime" agree.
                now_ms = self._get_timestamp_ms()
                await self.transport.send_event({
                    "event": "speaking",
                    "trackId": self.session_id,
                    "timestamp": now_ms,
                    "startTime": now_ms,
                })

            elif event_type == "silence":
                logger.info(f"User speaking stopped (session {self.session_id})")
                # Read the clock once so "timestamp" and "startTime" agree.
                now_ms = self._get_timestamp_ms()
                await self.transport.send_event({
                    "event": "silence",
                    "trackId": self.session_id,
                    "timestamp": now_ms,
                    "startTime": now_ms,
                    "duration": 0,  # TODO: Calculate actual duration
                })

            elif event_type == "eou":
                logger.info(f"EOU detected (session {self.session_id})")
                await self.transport.send_event({
                    "event": "eou",
                    "trackId": self.session_id,
                    "timestamp": self._get_timestamp_ms(),
                })

        except Exception as e:
            # loguru has no `exc_info` option; logger.exception() is what
            # attaches the traceback. The error is passed as a format argument
            # (not pre-interpolated) so braces in its text are never treated
            # as format placeholders.
            logger.exception("Pipeline processing error: {}", e)

    async def process_text_input(self, text: str) -> None:
        """
        Process text input (chat command).

        Currently only logs the first 50 characters of the text.

        Args:
            text: Text input
        """
        logger.info(f"Processing text input: {text[:50]}...")
        # TODO: Implement text processing (LLM integration, etc.)
        # For now, just log it

    async def interrupt(self) -> None:
        """Interrupt current audio playback.

        Sets the interrupt signal only while the bot is marked as speaking;
        otherwise this is a no-op.
        """
        if self.is_bot_speaking:
            self.interrupt_signal.set()
            logger.info(f"Pipeline interrupted for session {self.session_id}")

    async def cleanup(self) -> None:
        """Cleanup pipeline resources.

        Stops further audio processing and sets the interrupt signal.
        """
        logger.info(f"Cleaning up pipeline for session {self.session_id}")
        self._running = False
        self.interrupt_signal.set()

    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds (wall clock, ``time.time()``)."""
        import time

        return int(time.time() * 1000)
|
||||
Reference in New Issue
Block a user