updated tts.py to match mark's version

This commit is contained in:
yukiobata1
2026-01-06 21:16:13 +09:00
parent 4f93d331b7
commit 137bbb3d2c

View File

@@ -13,8 +13,6 @@ from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
@@ -66,11 +64,12 @@ def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputForma
return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm)
class AzureBaseTTSService(TTSService):
"""Base class for Azure Cognitive Services text-to-speech implementations.
class AzureBaseTTSService:
"""Base mixin class for Azure Cognitive Services text-to-speech implementations.
Provides common functionality for Azure TTS services including SSML
construction, voice configuration, and parameter management.
This is a mixin class and should be used alongside TTSService or its subclasses.
"""
# Define SSML escape mappings based on SSML reserved characters
@@ -106,28 +105,24 @@ class AzureBaseTTSService(TTSService):
style_degree: Optional[str] = None
volume: Optional[str] = None
def __init__(
def _init_azure_base(
self,
*,
api_key: str,
region: str,
voice="en-US-SaraNeural",
sample_rate: Optional[int] = None,
voice: str = "en-US-SaraNeural",
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize the Azure TTS service with configuration parameters.
"""Initialize Azure-specific configuration.
This method should be called by subclasses after initializing their TTSService parent.
Args:
api_key: Azure Cognitive Services subscription key.
region: Azure region identifier (e.g., "eastus", "westus2").
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
sample_rate: Audio sample rate in Hz. If None, uses service default.
params: Voice and synthesis parameters configuration.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
params = params or AzureBaseTTSService.InputParams()
self._settings = {
@@ -238,7 +233,7 @@ class AzureBaseTTSService(TTSService):
return escaped_text
class AzureTTSService(WordTTSService):
class AzureTTSService(WordTTSService, AzureBaseTTSService):
"""Azure Cognitive Services streaming TTS service with word timestamps.
Provides real-time text-to-speech synthesis using Azure's WebSocket-based
@@ -246,47 +241,15 @@ class AzureTTSService(WordTTSService):
available for lower latency playback and accurate word-level synchronization.
"""
# SSML reserved characters and the XML entities that replace them.
# "&" must remain the first entry: _escape_text applies these replacements
# one after another in dict order, so escaping "&" after the others would
# re-escape the ampersand of entities that were just inserted.
# See - https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-structure
SSML_ESCAPE_CHARS = {
    "&": "&amp;",
    "<": "&lt;",
    ">": "&gt;",
    '"': "&quot;",
    "'": "&apos;",
}
class InputParams(BaseModel):
"""Input parameters for Azure TTS voice configuration.

Every field is optional. The service copies these values into its settings
and reads them when building SSML; a field left as None (or empty) is
omitted from the generated markup, except ``language``, which falls back
to "en-US" when unset.

Parameters:
emphasis: Emphasis level for speech ("strong", "moderate", "reduced").
language: Language for synthesis. Defaults to English (US).
pitch: Voice pitch adjustment (e.g., "+10%", "-5Hz", "high").
rate: Speech rate multiplier, given as a string. Defaults to "1.05".
role: Voice role for expression (e.g., "YoungAdultFemale").
style: Speaking style (e.g., "cheerful", "sad", "excited").
style_degree: Intensity of the speaking style (0.01 to 2.0), given as a string.
volume: Volume level (e.g., "+20%", "loud", "x-soft").
"""
emphasis: Optional[str] = None
language: Optional[Language] = Language.EN_US
pitch: Optional[str] = None
rate: Optional[str] = "1.05"
# role and style_degree are only emitted when style is also set
role: Optional[str] = None
style: Optional[str] = None
style_degree: Optional[str] = None
volume: Optional[str] = None
def __init__(
self,
*,
api_key: str,
region: str,
voice="en-US-SaraNeural",
voice: str = "en-US-SaraNeural",
sample_rate: Optional[int] = None,
params: Optional[InputParams] = None,
params: Optional[AzureBaseTTSService.InputParams] = None,
aggregate_sentences: bool = True,
**kwargs,
):
"""Initialize the Azure streaming TTS service.
@@ -297,140 +260,28 @@ class AzureTTSService(WordTTSService):
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
sample_rate: Audio sample rate in Hz. If None, uses service default.
params: Voice and synthesis parameters configuration.
aggregate_sentences: Whether to aggregate sentences before synthesis.
**kwargs: Additional arguments passed to parent WordTTSService.
"""
# We want to push text frames ourselves with word-level timing
# Initialize WordTTSService first to set up word timestamp tracking
super().__init__(
aggregate_sentences=True,
push_text_frames=False,
aggregate_sentences=aggregate_sentences,
push_text_frames=False, # We'll push text frames based on word timestamps
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=sample_rate,
**kwargs,
)
params = params or AzureTTSService.InputParams()
# Initialize Azure-specific functionality from mixin
self._init_azure_base(api_key=api_key, region=region, voice=voice, params=params)
self._settings = {
"emphasis": params.emphasis,
"language": self.language_to_service_language(params.language)
if params.language
else "en-US",
"pitch": params.pitch,
"rate": params.rate,
"role": params.role,
"style": params.style,
"style_degree": params.style_degree,
"volume": params.volume,
}
self._api_key = api_key
self._region = region
self._voice_id = voice
self._speech_config = None
self._speech_synthesizer = None
self._audio_queue = asyncio.Queue()
self._context_id = None
self._word_timestamps_started = False
self._synthesis_lock = asyncio.Lock()
self._started = False
self._cumulative_audio_offset: float = 0.0 # Cumulative audio duration in seconds
def can_generate_metrics(self) -> bool:
    """Report whether this service produces processing metrics.

    Returns:
        Always True for the Azure TTS service.
    """
    supports_metrics = True
    return supports_metrics
def language_to_service_language(self, language: Language) -> Optional[str]:
    """Translate a Language enum value into the code Azure expects.

    Args:
        language: The language to convert.

    Returns:
        The Azure-specific language code, or None if not supported.
    """
    azure_code = language_to_azure_language(language)
    return azure_code
def _construct_ssml(self, text: str) -> str:
    """Wrap text in an SSML document reflecting the current voice settings.

    The text is escaped first, then nested (outermost to innermost) inside
    ``speak``, ``voice``, an optional ``mstts:express-as``, ``prosody`` and
    an optional ``emphasis`` element.

    Args:
        text: The text to convert to SSML.

    Returns:
        SSML string for Azure TTS synthesis.
    """
    settings = self._settings
    lang = settings["language"]

    # Escape special characters before embedding the text in markup
    body = self._escape_text(text)

    pieces = [
        f"<speak version='1.0' xml:lang='{lang}' "
        "xmlns='http://www.w3.org/2001/10/synthesis' "
        "xmlns:mstts='http://www.w3.org/2001/mstts'>",
        f"<voice name='{self._voice_id}'>",
        "<mstts:silence type='Sentenceboundary' value='20ms' />",
    ]

    # The express-as wrapper exists only when a style is configured;
    # style degree and role are attributes of that wrapper.
    style = settings["style"]
    if style:
        express_open = f"<mstts:express-as style='{style}'"
        if settings["style_degree"]:
            express_open += f" styledegree='{settings['style_degree']}'"
        if settings["role"]:
            express_open += f" role='{settings['role']}'"
        pieces.append(express_open + ">")

    # The prosody element is always emitted, even when no attribute is set.
    prosody = " ".join(
        f"{name}='{settings[name]}'"
        for name in ("rate", "pitch", "volume")
        if settings[name]
    )
    pieces.append(f"<prosody {prosody}>")

    emphasis = settings["emphasis"]
    if emphasis:
        pieces.append(f"<emphasis level='{emphasis}'>")
    pieces.append(body)
    if emphasis:
        pieces.append("</emphasis>")

    pieces.append("</prosody>")
    if style:
        pieces.append("</mstts:express-as>")
    pieces.append("</voice></speak>")

    return "".join(pieces)
def _escape_text(self, text: str) -> str:
    """Replace XML/SSML reserved characters with their entity escapes.

    The replacements come from ``AzureTTSService.SSML_ESCAPE_CHARS`` and are
    applied one after another in that mapping's order:
    - & becomes &amp;
    - < becomes &lt;
    - > becomes &gt;
    - " becomes &quot;
    - ' becomes &apos;

    Args:
        text: The text to escape.

    Returns:
        The escaped text.
    """
    result = text
    for reserved, entity in AzureTTSService.SSML_ESCAPE_CHARS.items():
        result = result.replace(reserved, entity)
    return result
async def start(self, frame: StartFrame):
"""Start the Azure TTS service and initialize speech synthesizer.
@@ -486,8 +337,8 @@ class AzureTTSService(WordTTSService):
# Queue the word timestamp for processing
# Use put_nowait since this is a synchronous callback
if self._context_id and word:
logger.debug(f"{self}: Word boundary - '{word}' at {absolute_seconds:.2f}s")
if word:
logger.trace(f"{self}: Word boundary - '{word}' at {absolute_seconds:.2f}s")
try:
# Convert to nanoseconds and put directly in queue (sync operation)
timestamp_ns = seconds_to_nanoseconds(absolute_seconds)
@@ -515,14 +366,6 @@ class AzureTTSService(WordTTSService):
self._cumulative_audio_offset += evt.result.audio_duration.total_seconds()
self._audio_queue.put_nowait(None) # Signal completion
# Add completion markers to word timestamp queue
# Use put_nowait since this is a synchronous callback
if self._context_id:
try:
# Add TTSStoppedFrame marker but NOT Reset - we maintain cumulative PTS
self._words_queue.put_nowait(("TTSStoppedFrame", 0))
except Exception as e:
logger.error(f"{self} error finalizing word timestamps: {e}")
def _handle_canceled(self, evt):
"""Handle synthesis cancellation.
@@ -533,11 +376,22 @@ class AzureTTSService(WordTTSService):
logger.error(f"Speech synthesis canceled: {evt.result.cancellation_details.reason}")
self._audio_queue.put_nowait(None)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
    """Forward a frame, then update synthesis state based on its type.

    The frame is pushed through the parent implementation first. Afterwards
    a TTSStoppedFrame or InterruptionFrame clears the started flag, and a
    TTSStoppedFrame additionally queues a ("Reset", 0) word timestamp.

    Args:
        frame: The frame to push.
        direction: The direction to push the frame.
    """
    await super().push_frame(frame, direction)

    tts_stopped = isinstance(frame, TTSStoppedFrame)
    if tts_stopped or isinstance(frame, InterruptionFrame):
        self._started = False
    if tts_stopped:
        await self.add_word_timestamps([("Reset", 0)])
async def flush_audio(self):
"""Flush any pending audio data.

As written here, this only logs a trace message and zeroes the cumulative
audio offset; it does not drain the audio queue.
"""
logger.trace(f"{self}: flushing audio")
# Reset cumulative audio offset at end of LLM response
# NOTE(review): this diff view does not show whether the reset belongs to
# the old or the new revision of this method; confirm against the file.
self._cumulative_audio_offset = 0.0
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
"""Handle interruption by stopping current synthesis.
@@ -557,7 +411,6 @@ class AzureTTSService(WordTTSService):
self._audio_queue.task_done()
except asyncio.QueueEmpty:
break
self._context_id = None
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
@@ -571,68 +424,58 @@ class AzureTTSService(WordTTSService):
"""
logger.debug(f"{self}: Generating TTS [{text}]")
# Ensure sequential sentence processing to prevent word boundary interleaving
async with self._synthesis_lock:
# Clear the audio queue in case there's still audio in it, causing the next audio response
# to be cut off by the 'None' element returned at the end of the previous audio synthesis.
# Empty the audio queue before processing the new text
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
self._audio_queue.task_done()
# Clear the audio queue in case there's still audio in it, causing the next audio response
# to be cut off by the 'None' element returned at the end of the previous audio synthesis.
# Empty the audio queue before processing the new text
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
self._audio_queue.task_done()
try:
if self._speech_synthesizer is None:
error_msg = "Speech synthesizer not initialized."
logger.error(error_msg)
yield ErrorFrame(error=error_msg)
return
try:
if self._speech_synthesizer is None:
error_msg = "Speech synthesizer not initialized."
logger.error(error_msg)
yield ErrorFrame(error=error_msg)
return
try:
if not self._started:
await self.start_ttfb_metrics()
await self.start_word_timestamps()
yield TTSStartedFrame()
self._started = True
self._cumulative_audio_offset = 0.0
# Mark that we're starting a new synthesis
self._context_id = str(id(text))
self._word_timestamps_started = False
ssml = self._construct_ssml(text)
self._speech_synthesizer.speak_ssml_async(ssml)
await self.start_tts_usage_metrics(text)
ssml = self._construct_ssml(text)
self._speech_synthesizer.speak_ssml_async(ssml)
await self.start_tts_usage_metrics(text)
# Stream audio chunks as they arrive
while True:
chunk = await self._audio_queue.get()
if chunk is None: # End of stream
break
# Stream audio chunks as they arrive
while True:
chunk = await self._audio_queue.get()
if chunk is None: # End of stream
break
await self.stop_ttfb_metrics()
# Start word timestamps only once when we receive first audio
if not self._word_timestamps_started:
await self.start_word_timestamps()
self._word_timestamps_started = True
frame = TTSAudioRawFrame(
audio=chunk,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
# Clear context ID when done
self._context_id = None
yield TTSStoppedFrame()
except Exception as e:
logger.error(f"{self} error during synthesis: {e}")
yield TTSStoppedFrame()
# Could add reconnection logic here if needed
return
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=chunk,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
except Exception as e:
logger.error(f"{self} exception: {e}")
logger.error(f"{self} error during synthesis: {e}")
yield TTSStoppedFrame()
self._started = False
# Could add reconnection logic here if needed
return
except Exception as e:
logger.error(f"{self} exception: {e}")
class AzureHttpTTSService(AzureBaseTTSService):
class AzureHttpTTSService(TTSService, AzureBaseTTSService):
"""Azure Cognitive Services HTTP-based TTS service.
Provides text-to-speech synthesis using Azure's HTTP API for simpler,
@@ -640,13 +483,31 @@ class AzureHttpTTSService(AzureBaseTTSService):
required and simpler integration is preferred.
"""
def __init__(self, **kwargs):
def __init__(
self,
*,
api_key: str,
region: str,
voice: str = "en-US-SaraNeural",
sample_rate: Optional[int] = None,
params: Optional[AzureBaseTTSService.InputParams] = None,
**kwargs,
):
"""Initialize the Azure HTTP TTS service.
Args:
**kwargs: All arguments passed to AzureBaseTTSService parent class.
api_key: Azure Cognitive Services subscription key.
region: Azure region identifier (e.g., "eastus", "westus2").
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
sample_rate: Audio sample rate in Hz. If None, uses service default.
params: Voice and synthesis parameters configuration.
**kwargs: Additional arguments passed to parent TTSService.
"""
super().__init__(**kwargs)
super().__init__(sample_rate=sample_rate, **kwargs)
# Initialize Azure-specific functionality from mixin
self._init_azure_base(api_key=api_key, region=region, voice=voice, params=params)
self._speech_config = None
self._speech_synthesizer = None