Merge pull request #3334 from obata-kotobasamurai/fix/azure-tts-word-timestamp
Add word-level timestamp support to Azure TTS with race condition fix
This commit is contained in:
2
changelog/3334.added.md
Normal file
2
changelog/3334.added.md
Normal file
@@ -0,0 +1,2 @@
|
||||
- Added word-level timestamp support to `AzureTTSService` for accurate text-to-audio synchronization.
|
||||
|
||||
@@ -13,15 +13,19 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.azure.common import language_to_azure_language
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.services.tts_service import TTSService, WordTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -61,11 +65,12 @@ def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputForma
|
||||
return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm)
|
||||
|
||||
|
||||
class AzureBaseTTSService(TTSService):
|
||||
"""Base class for Azure Cognitive Services text-to-speech implementations.
|
||||
class AzureBaseTTSService:
|
||||
"""Base mixin class for Azure Cognitive Services text-to-speech implementations.
|
||||
|
||||
Provides common functionality for Azure TTS services including SSML
|
||||
construction, voice configuration, and parameter management.
|
||||
This is a mixin class and should be used alongside TTSService or its subclasses.
|
||||
"""
|
||||
|
||||
# Define SSML escape mappings based on SSML reserved characters
|
||||
@@ -101,28 +106,24 @@ class AzureBaseTTSService(TTSService):
|
||||
style_degree: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
def _init_azure_base(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
voice="en-US-SaraNeural",
|
||||
sample_rate: Optional[int] = None,
|
||||
voice: str = "en-US-SaraNeural",
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Azure TTS service with configuration parameters.
|
||||
"""Initialize Azure-specific configuration.
|
||||
|
||||
This method should be called by subclasses after initializing their TTSService parent.
|
||||
|
||||
Args:
|
||||
api_key: Azure Cognitive Services subscription key.
|
||||
region: Azure region identifier (e.g., "eastus", "westus2").
|
||||
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
params: Voice and synthesis parameters configuration.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or AzureBaseTTSService.InputParams()
|
||||
|
||||
self._settings = {
|
||||
@@ -233,24 +234,56 @@ class AzureBaseTTSService(TTSService):
|
||||
return escaped_text
|
||||
|
||||
|
||||
class AzureTTSService(AzureBaseTTSService):
|
||||
"""Azure Cognitive Services streaming TTS service.
|
||||
class AzureTTSService(WordTTSService, AzureBaseTTSService):
|
||||
"""Azure Cognitive Services streaming TTS service with word timestamps.
|
||||
|
||||
Provides real-time text-to-speech synthesis using Azure's WebSocket-based
|
||||
streaming API. Audio chunks are streamed as they become available for
|
||||
lower latency playback.
|
||||
streaming API. Audio chunks and word boundaries are streamed as they become
|
||||
available for lower latency playback and accurate word-level synchronization.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    *,
    api_key: str,
    region: str,
    voice: str = "en-US-SaraNeural",
    sample_rate: Optional[int] = None,
    params: Optional[AzureBaseTTSService.InputParams] = None,
    aggregate_sentences: bool = True,
    **kwargs,
):
    """Set up the streaming Azure TTS service with word timestamp tracking.

    Args:
        api_key: Azure Cognitive Services subscription key.
        region: Azure region identifier (e.g., "eastus", "westus2").
        voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
        sample_rate: Audio sample rate in Hz. If None, uses service default.
        params: Voice and synthesis parameters configuration.
        aggregate_sentences: Whether to aggregate sentences before synthesis.
        **kwargs: Additional arguments forwarded to the parent WordTTSService.
    """
    # The WordTTSService side must be initialized before the Azure mixin
    # state. Text frames are not pushed by the base class because they are
    # derived from word timestamps instead.
    super().__init__(
        sample_rate=sample_rate,
        aggregate_sentences=aggregate_sentences,
        push_text_frames=False,
        push_stop_frames=True,
        pause_frame_processing=True,
        **kwargs,
    )

    # Azure-specific configuration lives in the mixin initializer.
    self._init_azure_base(api_key=api_key, region=region, voice=voice, params=params)

    # SDK handles start out unset; this constructor does not create them.
    self._speech_config = None
    self._speech_synthesizer = None

    # Audio chunks and word boundaries each get their own queue.
    self._audio_queue = asyncio.Queue()
    self._word_boundary_queue = asyncio.Queue()
    self._word_processor_task = None

    # Whether a TTSStartedFrame has already been emitted for the current turn.
    self._started = False
    # Seconds of audio synthesized so far; added to per-sentence word offsets.
    self._cumulative_audio_offset: float = 0.0
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Azure TTS service and initialize speech synthesizer.
|
||||
@@ -286,25 +319,144 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
self._speech_synthesizer.synthesizing.connect(self._handle_synthesizing)
|
||||
self._speech_synthesizer.synthesis_completed.connect(self._handle_completed)
|
||||
self._speech_synthesizer.synthesis_canceled.connect(self._handle_canceled)
|
||||
self._speech_synthesizer.synthesis_word_boundary.connect(self._handle_word_boundary)
|
||||
|
||||
# Start word processor task
|
||||
if not self._word_processor_task:
|
||||
self._word_processor_task = self.create_task(self._word_processor_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
    """Stop the Azure TTS service and shut down the word processor task.

    Args:
        frame: End frame signaling service stop.
    """
    await super().stop(frame)
    # The processor task only exists once it has been created during
    # startup; skip cancellation when there is nothing to cancel instead of
    # handing None to cancel_task().
    if self._word_processor_task:
        await self.cancel_task(self._word_processor_task)
        self._word_processor_task = None
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
    """Cancel the Azure TTS service and shut down the word processor task.

    Args:
        frame: Cancel frame signaling service cancellation.
    """
    await super().cancel(frame)
    # The processor task only exists once it has been created during
    # startup; skip cancellation when there is nothing to cancel instead of
    # handing None to cancel_task().
    if self._word_processor_task:
        await self.cancel_task(self._word_processor_task)
        self._word_processor_task = None
|
||||
|
||||
def _handle_word_boundary(self, evt):
    """Queue a word boundary reported by the Azure SDK for async processing.

    Args:
        evt: SpeechSynthesisWordBoundaryEventArgs from Azure Speech SDK
            containing word text and audio offset timing.
    """
    word = evt.text

    # evt.audio_offset is expressed in ticks (100-nanosecond units), so
    # dividing by 10,000,000 yields seconds relative to the current sentence.
    # Adding the running total makes the timestamp absolute across sentences.
    absolute_seconds = self._cumulative_audio_offset + evt.audio_offset / 10_000_000.0

    # Boundaries without text carry nothing to synchronize.
    if not word:
        return

    logger.trace(f"{self}: Word boundary - '{word}' at {absolute_seconds:.2f}s")
    try:
        # Stored as a (word, timestamp_in_seconds) tuple; the word processor
        # task picks these up and forwards them.
        # NOTE(review): this callback is reported to run on an Azure SDK
        # thread, and asyncio.Queue is documented as not thread-safe -
        # confirm whether this hand-off should go through
        # loop.call_soon_threadsafe().
        self._word_boundary_queue.put_nowait((word, absolute_seconds))
    except Exception as e:
        logger.error(f"{self} error queuing word timestamp: {e}")
|
||||
|
||||
async def _word_processor_task_handler(self):
|
||||
"""Process word timestamps from the queue and call add_word_timestamps."""
|
||||
while True:
|
||||
try:
|
||||
word, timestamp_seconds = await self._word_boundary_queue.get()
|
||||
await self.add_word_timestamps([(word, timestamp_seconds)])
|
||||
self._word_boundary_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error processing word timestamp: {e}")
|
||||
|
||||
def _handle_synthesizing(self, evt):
|
||||
"""Handle audio chunks as they arriv."""
|
||||
"""Handle audio chunks as they arrive.
|
||||
|
||||
Args:
|
||||
evt: Synthesis event containing audio data.
|
||||
"""
|
||||
if evt.result and evt.result.audio_data:
|
||||
self._audio_queue.put_nowait(evt.result.audio_data)
|
||||
|
||||
def _handle_completed(self, evt):
|
||||
"""Handle synthesis completion."""
|
||||
"""Handle synthesis completion.
|
||||
|
||||
Args:
|
||||
evt: Completion event from Azure Speech SDK.
|
||||
"""
|
||||
# Update cumulative audio offset for next sentence
|
||||
if evt.result and evt.result.audio_duration:
|
||||
self._cumulative_audio_offset += evt.result.audio_duration.total_seconds()
|
||||
|
||||
self._audio_queue.put_nowait(None) # Signal completion
|
||||
|
||||
def _handle_canceled(self, evt):
    """Log a synthesis cancellation and signal end of audio.

    Args:
        evt: Cancellation event.
    """
    # Resolve the reason defensively: if the event carries no result or no
    # cancellation details, the failure must not prevent the end-of-stream
    # marker below from being queued.
    try:
        reason = evt.result.cancellation_details.reason
    except AttributeError:
        reason = "unknown"

    logger.error(f"Speech synthesis canceled: {reason}")
    # Same None marker that _handle_completed queues to end the audio stream.
    self._audio_queue.put_nowait(None)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
    """Forward a frame, then update synthesis state based on its type.

    Args:
        frame: The frame to push.
        direction: The direction to push the frame.
    """
    await super().push_frame(frame, direction)

    tts_stopped = isinstance(frame, TTSStoppedFrame)
    if tts_stopped or isinstance(frame, InterruptionFrame):
        # The next run_tts call should announce a fresh TTS start.
        self._started = False
    if tts_stopped:
        await self.add_word_timestamps([("Reset", 0)])
|
||||
|
||||
async def flush_audio(self):
    """Flush pending audio data; this implementation only logs a trace line."""
    message = f"{self}: flushing audio"
    logger.trace(message)
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
    """React to an interruption by discarding all pending synthesis output.

    Args:
        frame: The interruption frame.
        direction: Frame processing direction.
    """
    await super()._handle_interruption(frame, direction)
    await self.stop_all_metrics()

    # Timestamps restart from zero after an interruption.
    self._cumulative_audio_offset = 0.0

    # Drain queued audio first, then queued word boundaries.
    for pending in (self._audio_queue, self._word_boundary_queue):
        while not pending.empty():
            try:
                pending.get_nowait()
                pending.task_done()
            except asyncio.QueueEmpty:
                break
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Azure's streaming synthesis.
|
||||
@@ -327,12 +479,15 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
try:
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error=error_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
if not self._started:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
|
||||
ssml = self._construct_ssml(text)
|
||||
self._speech_synthesizer.speak_ssml_async(ssml)
|
||||
@@ -345,25 +500,27 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
break
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(
|
||||
await self.start_word_timestamps()
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=chunk,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} error during synthesis: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
self._started = False
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
|
||||
class AzureHttpTTSService(AzureBaseTTSService):
|
||||
class AzureHttpTTSService(TTSService, AzureBaseTTSService):
|
||||
"""Azure Cognitive Services HTTP-based TTS service.
|
||||
|
||||
Provides text-to-speech synthesis using Azure's HTTP API for simpler,
|
||||
@@ -371,13 +528,31 @@ class AzureHttpTTSService(AzureBaseTTSService):
|
||||
required and simpler integration is preferred.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    *,
    api_key: str,
    region: str,
    voice: str = "en-US-SaraNeural",
    sample_rate: Optional[int] = None,
    params: Optional[AzureBaseTTSService.InputParams] = None,
    **kwargs,
):
    """Set up the HTTP-based Azure TTS service.

    Args:
        api_key: Azure Cognitive Services subscription key.
        region: Azure region identifier (e.g., "eastus", "westus2").
        voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
        sample_rate: Audio sample rate in Hz. If None, uses service default.
        params: Voice and synthesis parameters configuration.
        **kwargs: Additional arguments forwarded to the parent TTSService.
    """
    # The TTSService side is initialized first, then the Azure mixin state.
    super().__init__(sample_rate=sample_rate, **kwargs)
    self._init_azure_base(api_key=api_key, region=region, voice=voice, params=params)

    # SDK handles start out unset; this constructor does not create them.
    self._speech_config = None
    self._speech_synthesizer = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user