Updated tts.py to match Mark's version
This commit is contained in:
@@ -13,8 +13,6 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
@@ -66,11 +64,12 @@ def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputForma
|
||||
return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm)
|
||||
|
||||
|
||||
class AzureBaseTTSService(TTSService):
|
||||
"""Base class for Azure Cognitive Services text-to-speech implementations.
|
||||
class AzureBaseTTSService:
|
||||
"""Base mixin class for Azure Cognitive Services text-to-speech implementations.
|
||||
|
||||
Provides common functionality for Azure TTS services including SSML
|
||||
construction, voice configuration, and parameter management.
|
||||
This is a mixin class and should be used alongside TTSService or its subclasses.
|
||||
"""
|
||||
|
||||
# Define SSML escape mappings based on SSML reserved characters
|
||||
@@ -106,28 +105,24 @@ class AzureBaseTTSService(TTSService):
|
||||
style_degree: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
def _init_azure_base(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
voice="en-US-SaraNeural",
|
||||
sample_rate: Optional[int] = None,
|
||||
voice: str = "en-US-SaraNeural",
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Azure TTS service with configuration parameters.
|
||||
"""Initialize Azure-specific configuration.
|
||||
|
||||
This method should be called by subclasses after initializing their TTSService parent.
|
||||
|
||||
Args:
|
||||
api_key: Azure Cognitive Services subscription key.
|
||||
region: Azure region identifier (e.g., "eastus", "westus2").
|
||||
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
params: Voice and synthesis parameters configuration.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or AzureBaseTTSService.InputParams()
|
||||
|
||||
self._settings = {
|
||||
@@ -238,7 +233,7 @@ class AzureBaseTTSService(TTSService):
|
||||
return escaped_text
|
||||
|
||||
|
||||
class AzureTTSService(WordTTSService):
|
||||
class AzureTTSService(WordTTSService, AzureBaseTTSService):
|
||||
"""Azure Cognitive Services streaming TTS service with word timestamps.
|
||||
|
||||
Provides real-time text-to-speech synthesis using Azure's WebSocket-based
|
||||
@@ -246,47 +241,15 @@ class AzureTTSService(WordTTSService):
|
||||
available for lower latency playback and accurate word-level synchronization.
|
||||
"""
|
||||
|
||||
# Define SSML escape mappings based on SSML reserved characters
|
||||
# See - https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-structure
|
||||
SSML_ESCAPE_CHARS = {
|
||||
"&": "&amp;",
|
||||
"<": "&lt;",
|
||||
">": "&gt;",
|
||||
'"': "&quot;",
|
||||
"'": "&apos;",
|
||||
}
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Azure TTS voice configuration.
|
||||
|
||||
Parameters:
|
||||
emphasis: Emphasis level for speech ("strong", "moderate", "reduced").
|
||||
language: Language for synthesis. Defaults to English (US).
|
||||
pitch: Voice pitch adjustment (e.g., "+10%", "-5Hz", "high").
|
||||
rate: Speech rate multiplier. Defaults to "1.05".
|
||||
role: Voice role for expression (e.g., "YoungAdultFemale").
|
||||
style: Speaking style (e.g., "cheerful", "sad", "excited").
|
||||
style_degree: Intensity of the speaking style (0.01 to 2.0).
|
||||
volume: Volume level (e.g., "+20%", "loud", "x-soft").
|
||||
"""
|
||||
|
||||
emphasis: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN_US
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = "1.05"
|
||||
role: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
style_degree: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
voice="en-US-SaraNeural",
|
||||
voice: str = "en-US-SaraNeural",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
params: Optional[AzureBaseTTSService.InputParams] = None,
|
||||
aggregate_sentences: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Azure streaming TTS service.
|
||||
@@ -297,140 +260,28 @@ class AzureTTSService(WordTTSService):
|
||||
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
params: Voice and synthesis parameters configuration.
|
||||
aggregate_sentences: Whether to aggregate sentences before synthesis.
|
||||
**kwargs: Additional arguments passed to parent WordTTSService.
|
||||
"""
|
||||
# We want to push text frames ourselves with word-level timing
|
||||
# Initialize WordTTSService first to set up word timestamp tracking
|
||||
super().__init__(
|
||||
aggregate_sentences=True,
|
||||
push_text_frames=False,
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False, # We'll push text frames based on word timestamps
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
params = params or AzureTTSService.InputParams()
|
||||
# Initialize Azure-specific functionality from mixin
|
||||
self._init_azure_base(api_key=api_key, region=region, voice=voice, params=params)
|
||||
|
||||
self._settings = {
|
||||
"emphasis": params.emphasis,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-US",
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"role": params.role,
|
||||
"style": params.style,
|
||||
"style_degree": params.style_degree,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self._api_key = api_key
|
||||
self._region = region
|
||||
self._voice_id = voice
|
||||
self._speech_config = None
|
||||
self._speech_synthesizer = None
|
||||
self._audio_queue = asyncio.Queue()
|
||||
self._context_id = None
|
||||
self._word_timestamps_started = False
|
||||
self._synthesis_lock = asyncio.Lock()
|
||||
self._started = False
|
||||
self._cumulative_audio_offset: float = 0.0 # Cumulative audio duration in seconds
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Azure TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Azure language format.
|
||||
|
||||
Args:
|
||||
language: The language to convert.
|
||||
|
||||
Returns:
|
||||
The Azure-specific language code, or None if not supported.
|
||||
"""
|
||||
return language_to_azure_language(language)
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
"""Construct SSML from text with current voice settings.
|
||||
|
||||
Args:
|
||||
text: The text to convert to SSML.
|
||||
|
||||
Returns:
|
||||
SSML string for Azure TTS synthesis.
|
||||
"""
|
||||
language = self._settings["language"]
|
||||
|
||||
# Escape special characters
|
||||
escaped_text = self._escape_text(text)
|
||||
|
||||
ssml = (
|
||||
f"<speak version='1.0' xml:lang='{language}' "
|
||||
"xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>"
|
||||
f"<voice name='{self._voice_id}'>"
|
||||
"<mstts:silence type='Sentenceboundary' value='20ms' />"
|
||||
)
|
||||
|
||||
if self._settings["style"]:
|
||||
ssml += f"<mstts:express-as style='{self._settings['style']}'"
|
||||
if self._settings["style_degree"]:
|
||||
ssml += f" styledegree='{self._settings['style_degree']}'"
|
||||
if self._settings["role"]:
|
||||
ssml += f" role='{self._settings['role']}'"
|
||||
ssml += ">"
|
||||
|
||||
prosody_attrs = []
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
ssml += escaped_text
|
||||
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._settings["style"]:
|
||||
ssml += "</mstts:express-as>"
|
||||
|
||||
ssml += "</voice></speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
def _escape_text(self, text: str) -> str:
|
||||
"""Escapes XML/SSML reserved characters according to Microsoft documentation.
|
||||
|
||||
This method escapes the following characters:
|
||||
- & becomes &amp;
|
||||
- < becomes &lt;
|
||||
- > becomes &gt;
|
||||
- " becomes &quot;
|
||||
- ' becomes &apos;
|
||||
|
||||
Args:
|
||||
text: The text to escape.
|
||||
|
||||
Returns:
|
||||
The escaped text.
|
||||
"""
|
||||
escaped_text = text
|
||||
for char, escape_code in AzureTTSService.SSML_ESCAPE_CHARS.items():
|
||||
escaped_text = escaped_text.replace(char, escape_code)
|
||||
return escaped_text
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Azure TTS service and initialize speech synthesizer.
|
||||
|
||||
@@ -486,8 +337,8 @@ class AzureTTSService(WordTTSService):
|
||||
|
||||
# Queue the word timestamp for processing
|
||||
# Use put_nowait since this is a synchronous callback
|
||||
if self._context_id and word:
|
||||
logger.debug(f"{self}: Word boundary - '{word}' at {absolute_seconds:.2f}s")
|
||||
if word:
|
||||
logger.trace(f"{self}: Word boundary - '{word}' at {absolute_seconds:.2f}s")
|
||||
try:
|
||||
# Convert to nanoseconds and put directly in queue (sync operation)
|
||||
timestamp_ns = seconds_to_nanoseconds(absolute_seconds)
|
||||
@@ -515,14 +366,6 @@ class AzureTTSService(WordTTSService):
|
||||
self._cumulative_audio_offset += evt.result.audio_duration.total_seconds()
|
||||
|
||||
self._audio_queue.put_nowait(None) # Signal completion
|
||||
# Add completion markers to word timestamp queue
|
||||
# Use put_nowait since this is a synchronous callback
|
||||
if self._context_id:
|
||||
try:
|
||||
# Add TTSStoppedFrame marker but NOT Reset - we maintain cumulative PTS
|
||||
self._words_queue.put_nowait(("TTSStoppedFrame", 0))
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error finalizing word timestamps: {e}")
|
||||
|
||||
def _handle_canceled(self, evt):
|
||||
"""Handle synthesis cancellation.
|
||||
@@ -533,11 +376,22 @@ class AzureTTSService(WordTTSService):
|
||||
logger.error(f"Speech synthesis canceled: {evt.result.cancellation_details.reason}")
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio data."""
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
# Reset cumulative audio offset at end of LLM response
|
||||
self._cumulative_audio_offset = 0.0
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by stopping current synthesis.
|
||||
@@ -557,7 +411,6 @@ class AzureTTSService(WordTTSService):
|
||||
self._audio_queue.task_done()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
self._context_id = None
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -571,68 +424,58 @@ class AzureTTSService(WordTTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
# Ensure sequential sentence processing to prevent word boundary interleaving
|
||||
async with self._synthesis_lock:
|
||||
# Clear the audio queue in case there's still audio in it, causing the next audio response
|
||||
# to be cut off by the 'None' element returned at the end of the previous audio synthesis.
|
||||
# Empty the audio queue before processing the new text
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
self._audio_queue.task_done()
|
||||
# Clear the audio queue in case there's still audio in it, causing the next audio response
|
||||
# to be cut off by the 'None' element returned at the end of the previous audio synthesis.
|
||||
# Empty the audio queue before processing the new text
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
self._audio_queue.task_done()
|
||||
|
||||
try:
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error=error_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error=error_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
if not self._started:
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_word_timestamps()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
self._cumulative_audio_offset = 0.0
|
||||
|
||||
# Mark that we're starting a new synthesis
|
||||
self._context_id = str(id(text))
|
||||
self._word_timestamps_started = False
|
||||
ssml = self._construct_ssml(text)
|
||||
self._speech_synthesizer.speak_ssml_async(ssml)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
ssml = self._construct_ssml(text)
|
||||
self._speech_synthesizer.speak_ssml_async(ssml)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
# Stream audio chunks as they arrive
|
||||
while True:
|
||||
chunk = await self._audio_queue.get()
|
||||
if chunk is None: # End of stream
|
||||
break
|
||||
|
||||
# Stream audio chunks as they arrive
|
||||
while True:
|
||||
chunk = await self._audio_queue.get()
|
||||
if chunk is None: # End of stream
|
||||
break
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
# Start word timestamps only once when we receive first audio
|
||||
if not self._word_timestamps_started:
|
||||
await self.start_word_timestamps()
|
||||
self._word_timestamps_started = True
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=chunk,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
|
||||
# Clear context ID when done
|
||||
self._context_id = None
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error during synthesis: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=chunk,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
logger.error(f"{self} error during synthesis: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
self._started = False
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
|
||||
class AzureHttpTTSService(AzureBaseTTSService):
|
||||
class AzureHttpTTSService(TTSService, AzureBaseTTSService):
|
||||
"""Azure Cognitive Services HTTP-based TTS service.
|
||||
|
||||
Provides text-to-speech synthesis using Azure's HTTP API for simpler,
|
||||
@@ -640,13 +483,31 @@ class AzureHttpTTSService(AzureBaseTTSService):
|
||||
required and simpler integration is preferred.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
voice: str = "en-US-SaraNeural",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[AzureBaseTTSService.InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Azure HTTP TTS service.
|
||||
|
||||
Args:
|
||||
**kwargs: All arguments passed to AzureBaseTTSService parent class.
|
||||
api_key: Azure Cognitive Services subscription key.
|
||||
region: Azure region identifier (e.g., "eastus", "westus2").
|
||||
voice: Voice name to use for synthesis. Defaults to "en-US-SaraNeural".
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
params: Voice and synthesis parameters configuration.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
# Initialize Azure-specific functionality from mixin
|
||||
self._init_azure_base(api_key=api_key, region=region, voice=voice, params=params)
|
||||
|
||||
self._speech_config = None
|
||||
self._speech_synthesizer = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user