Merge pull request #3334 from obata-kotobasamurai/fix/azure-tts-word-timestamp

Add word-level timestamp support to Azure TTS with race condition fix
This commit is contained in:
Mark Backman
2026-01-07 14:48:02 -05:00
committed by GitHub
2 changed files with 210 additions and 33 deletions

2
changelog/3334.added.md Normal file
View File

@@ -0,0 +1,2 @@
- Added word-level timestamp support to `AzureTTSService` for accurate text-to-audio synchronization.

View File

@@ -13,15 +13,19 @@ from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.azure.common import language_to_azure_language
from pipecat.services.tts_service import TTSService
from pipecat.services.tts_service import TTSService, WordTTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -61,11 +65,12 @@ def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputForma
return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm)
class AzureBaseTTSService(TTSService):
"""Base class for Azure Cognitive Services text-to-speech implementations.
class AzureBaseTTSService:
"""Base mixin class for Azure Cognitive Services text-to-speech implementations.
Provides common functionality for Azure TTS services including SSML
construction, voice configuration, and parameter management.
This is a mixin class and should be used alongside TTSService or its subclasses.
"""
# Define SSML escape mappings based on SSML reserved characters
@@ -101,28 +106,24 @@ class AzureBaseTTSService(TTSService):
style_degree: Optional[str] = None
volume: Optional[str] = None
def __init__(
def _init_azure_base(
self,
*,
api_key: str,
region: str,
voice="en-US-SaraNeural",
sample_rate: Optional[int] = None,
voice: str = "en-US-SaraNeural",
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the Azure TTS service with configuration parameters.
"""Initialize Azure-specific configuration.
This method should be called by subclasses after initializing their TTSService parent.
Args:
api_key: Azure Cognitive Services subscription key.
region: Azure region identifier (e.g., "eastus", "westus2").
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
sample_rate: Audio sample rate in Hz. If None, uses service default.
params: Voice and synthesis parameters configuration.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or AzureBaseTTSService.InputParams()
self._settings = {
@@ -233,24 +234,56 @@ class AzureBaseTTSService(TTSService):
return escaped_text
class AzureTTSService(AzureBaseTTSService):
"""Azure Cognitive Services streaming TTS service.
class AzureTTSService(WordTTSService, AzureBaseTTSService):
"""Azure Cognitive Services streaming TTS service with word timestamps.
Provides real-time text-to-speech synthesis using Azure's WebSocket-based
streaming API. Audio chunks are streamed as they become available for
lower latency playback.
streaming API. Audio chunks and word boundaries are streamed as they become
available for lower latency playback and accurate word-level synchronization.
"""
def __init__(self, **kwargs):
def __init__(
self,
*,
api_key: str,
region: str,
voice: str = "en-US-SaraNeural",
sample_rate: Optional[int] = None,
params: Optional[AzureBaseTTSService.InputParams] = None,
aggregate_sentences: bool = True,
**kwargs,
):
"""Initialize the Azure streaming TTS service.
Args:
**kwargs: All arguments passed to AzureBaseTTSService parent class.
api_key: Azure Cognitive Services subscription key.
region: Azure region identifier (e.g., "eastus", "westus2").
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
sample_rate: Audio sample rate in Hz. If None, uses service default.
params: Voice and synthesis parameters configuration.
aggregate_sentences: Whether to aggregate sentences before synthesis.
**kwargs: Additional arguments passed to parent WordTTSService.
"""
super().__init__(**kwargs)
# Initialize WordTTSService first to set up word timestamp tracking
super().__init__(
aggregate_sentences=aggregate_sentences,
push_text_frames=False, # We'll push text frames based on word timestamps
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=sample_rate,
**kwargs,
)
# Initialize Azure-specific functionality from mixin
self._init_azure_base(api_key=api_key, region=region, voice=voice, params=params)
self._speech_config = None
self._speech_synthesizer = None
self._audio_queue = asyncio.Queue()
self._word_boundary_queue = asyncio.Queue()
self._word_processor_task = None
self._started = False
self._cumulative_audio_offset: float = 0.0 # Cumulative audio duration in seconds
async def start(self, frame: StartFrame):
"""Start the Azure TTS service and initialize speech synthesizer.
@@ -286,25 +319,144 @@ class AzureTTSService(AzureBaseTTSService):
self._speech_synthesizer.synthesizing.connect(self._handle_synthesizing)
self._speech_synthesizer.synthesis_completed.connect(self._handle_completed)
self._speech_synthesizer.synthesis_canceled.connect(self._handle_canceled)
self._speech_synthesizer.synthesis_word_boundary.connect(self._handle_word_boundary)
# Start word processor task
if not self._word_processor_task:
self._word_processor_task = self.create_task(self._word_processor_task_handler())
async def stop(self, frame: EndFrame):
    """Stop the Azure TTS service.

    Delegates to the parent ``stop`` first, then cancels the background task
    that drains the word boundary queue, if one is running.

    Args:
        frame: End frame signaling service stop.
    """
    await super().stop(frame)
    # The task attribute starts out as None and is set back to None after
    # every stop/cancel, so it may legitimately be missing here (stop without
    # a successful start, or a repeated stop). Only cancel a real task.
    if self._word_processor_task:
        await self.cancel_task(self._word_processor_task)
        self._word_processor_task = None
async def cancel(self, frame: CancelFrame):
    """Cancel the Azure TTS service.

    Delegates to the parent ``cancel`` first, then cancels the background task
    that drains the word boundary queue, if one is running.

    Args:
        frame: Cancel frame signaling service cancellation.
    """
    await super().cancel(frame)
    # The task attribute starts out as None and is set back to None after
    # every stop/cancel, so it may legitimately be missing here (cancel
    # without a successful start, or cancel after stop). Only cancel a real
    # task.
    if self._word_processor_task:
        await self.cancel_task(self._word_processor_task)
        self._word_processor_task = None
def _handle_word_boundary(self, evt):
    """Handle word boundary events from Azure SDK.

    Converts the event's audio offset from ticks to seconds, shifts it by the
    cumulative audio offset of previously completed sentences, and hands the
    resulting ``(word, seconds)`` tuple to the word boundary queue for the
    async word processor task.

    Args:
        evt: SpeechSynthesisWordBoundaryEventArgs from Azure Speech SDK
            containing word text and audio offset timing.
    """
    word = evt.text
    if not word:
        # Nothing to report for events without text.
        return

    # evt.audio_offset is expressed in ticks (100-nanosecond units);
    # dividing by 10,000,000 converts it to seconds within the sentence.
    sentence_relative_seconds = evt.audio_offset / 10_000_000.0

    # Add cumulative offset to get absolute timestamp across sentences.
    absolute_seconds = self._cumulative_audio_offset + sentence_relative_seconds

    logger.trace(f"{self}: Word boundary - '{word}' at {absolute_seconds:.2f}s")
    try:
        # This callback runs on an Azure SDK thread, and asyncio.Queue is not
        # thread-safe: calling put_nowait() directly from here may not wake
        # the event loop (and raises in asyncio debug mode). Schedule the put
        # on the loop instead.
        # NOTE(review): get_event_loop() is expected to be provided by the
        # processor base class — confirm it is available on this service.
        self.get_event_loop().call_soon_threadsafe(
            self._word_boundary_queue.put_nowait, (word, absolute_seconds)
        )
    except Exception as e:
        logger.error(f"{self} error queuing word timestamp: {e}")
async def _word_processor_task_handler(self):
    """Forward queued word boundaries to ``add_word_timestamps``.

    Waits on the word boundary queue, passes each ``(word, seconds)`` entry
    on as a single-element list, and marks the entry done. The loop ends
    quietly when the task is cancelled; any other error is logged and the
    loop keeps going.
    """
    keep_running = True
    while keep_running:
        try:
            entry = await self._word_boundary_queue.get()
            word, timestamp_seconds = entry
            await self.add_word_timestamps([(word, timestamp_seconds)])
            self._word_boundary_queue.task_done()
        except asyncio.CancelledError:
            # Cancellation is the normal way this task is shut down.
            keep_running = False
        except Exception as e:
            logger.error(f"{self} error processing word timestamp: {e}")
def _handle_synthesizing(self, evt):
    """Handle audio chunks as they arrive.

    Non-empty audio data carried by the event's result is appended to the
    audio queue; events without a result or without audio are ignored.

    Args:
        evt: Synthesis event containing audio data.
    """
    result = evt.result
    if not result:
        return
    audio = result.audio_data
    if audio:
        self._audio_queue.put_nowait(audio)
def _handle_completed(self, evt):
    """Handle synthesis completion.

    Adds the finished result's audio duration (in seconds) to the cumulative
    audio offset so the next sentence's word timestamps continue from there,
    then enqueues ``None`` on the audio queue to signal completion.

    Args:
        evt: Completion event from Azure Speech SDK.
    """
    result = evt.result
    duration = result.audio_duration if result else None
    if duration:
        self._cumulative_audio_offset += duration.total_seconds()

    # None marks the end of the audio stream for this synthesis request.
    self._audio_queue.put_nowait(None)
def _handle_canceled(self, evt):
    """Handle synthesis cancellation.

    Logs the cancellation reason reported by the SDK and enqueues ``None`` on
    the audio queue, the same value the completion handler uses to signal
    that no more audio will follow.

    Args:
        evt: Cancellation event.
    """
    reason = evt.result.cancellation_details.reason
    logger.error(f"Speech synthesis canceled: {reason}")
    self._audio_queue.put_nowait(None)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
    """Push a frame and handle state changes.

    After forwarding the frame to the parent implementation, a
    ``TTSStoppedFrame`` or ``InterruptionFrame`` clears the "started" flag so
    the next synthesis emits a fresh ``TTSStartedFrame``. A ``TTSStoppedFrame``
    additionally resets the cumulative audio offset and queues a
    ``("Reset", 0)`` word timestamp marker.

    Args:
        frame: The frame to push.
        direction: The direction to push the frame.
    """
    await super().push_frame(frame, direction)

    if not isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
        return

    self._started = False
    if isinstance(frame, TTSStoppedFrame):
        # The "Reset" marker restarts word timing for the next utterance, so
        # the accumulated audio duration of the finished utterance must not
        # carry over; otherwise the next utterance's word timestamps would be
        # shifted by the previous utterance's length.
        self._cumulative_audio_offset = 0.0
        await self.add_word_timestamps([("Reset", 0)])
async def flush_audio(self):
    """Flush any pending audio data.

    This implementation only emits a trace-level log message; it does not
    touch the audio queue or the word boundary queue.
    """
    logger.trace(f"{self}: flushing audio")
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
    """Handle interruption by discarding pending synthesis output.

    Runs the parent interruption handling, stops all metrics, resets the
    cumulative audio offset to zero, and empties both the audio queue and the
    word boundary queue.

    Args:
        frame: The interruption frame.
        direction: Frame processing direction.
    """
    await super()._handle_interruption(frame, direction)
    await self.stop_all_metrics()

    # Word timestamps start from zero again after an interruption.
    self._cumulative_audio_offset = 0.0

    # Drop whatever is still queued: audio chunks first, then word boundaries.
    for pending in (self._audio_queue, self._word_boundary_queue):
        while not pending.empty():
            try:
                pending.get_nowait()
                pending.task_done()
            except asyncio.QueueEmpty:
                # The queue emptied between the check and the get.
                break
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Azure's streaming synthesis.
@@ -327,12 +479,15 @@ class AzureTTSService(AzureBaseTTSService):
try:
if self._speech_synthesizer is None:
error_msg = "Speech synthesizer not initialized."
logger.error(error_msg)
yield ErrorFrame(error=error_msg)
return
try:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._started = True
ssml = self._construct_ssml(text)
self._speech_synthesizer.speak_ssml_async(ssml)
@@ -345,25 +500,27 @@ class AzureTTSService(AzureBaseTTSService):
break
await self.stop_ttfb_metrics()
yield TTSAudioRawFrame(
await self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=chunk,
sample_rate=self.sample_rate,
num_channels=1,
)
yield TTSStoppedFrame()
yield frame
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
logger.error(f"{self} error during synthesis: {e}")
yield TTSStoppedFrame()
self._started = False
# Could add reconnection logic here if needed
return
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
logger.error(f"{self} exception: {e}")
class AzureHttpTTSService(AzureBaseTTSService):
class AzureHttpTTSService(TTSService, AzureBaseTTSService):
"""Azure Cognitive Services HTTP-based TTS service.
Provides text-to-speech synthesis using Azure's HTTP API for simpler,
@@ -371,13 +528,31 @@ class AzureHttpTTSService(AzureBaseTTSService):
required and simpler integration is preferred.
"""
def __init__(self, **kwargs):
def __init__(
self,
*,
api_key: str,
region: str,
voice: str = "en-US-SaraNeural",
sample_rate: Optional[int] = None,
params: Optional[AzureBaseTTSService.InputParams] = None,
**kwargs,
):
"""Initialize the Azure HTTP TTS service.
Args:
**kwargs: All arguments passed to AzureBaseTTSService parent class.
api_key: Azure Cognitive Services subscription key.
region: Azure region identifier (e.g., "eastus", "westus2").
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
sample_rate: Audio sample rate in Hz. If None, uses service default.
params: Voice and synthesis parameters configuration.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(**kwargs)
super().__init__(sample_rate=sample_rate, **kwargs)
# Initialize Azure-specific functionality from mixin
self._init_azure_base(api_key=api_key, region=region, voice=voice, params=params)
self._speech_config = None
self._speech_synthesizer = None