# Patch commentary (text placed ahead of the first "diff --git" header; patch tools normally skip it).
# NOTE(review): the line breaks of this patch appear to have been collapsed into spaces, so the exact
#   indentation and blank lines of the hunks are not recoverable from this text alone -- restore it
#   from the upstream commit before trying to apply it.
# File touched: src/pipecat/services/azure.py
# Hunk 1 (imports): adds OpenAILLMContext, and widens the pipecat.services.openai import so that it also
#   brings in OpenAIAssistantContextAggregator, OpenAIContextAggregatorPair and OpenAIUserContextAggregator.
# Hunk 2 (azure.cognitiveservices.speech imports): adds SpeechSynthesisOutputFormat.
# Hunk 3: adds a static method create_context_aggregator (per the hunk header, presumably on
#   AzureLLMService). It builds an OpenAIUserContextAggregator from the given context, builds an
#   OpenAIAssistantContextAggregator from that user aggregator (forwarding
#   assistant_expect_stripped_words as expect_stripped_words), and returns both wrapped in an
#   OpenAIContextAggregatorPair. The hunk also adds sample_rate_to_output_format, which maps
#   8000 / 16000 / 22050 / 24000 / 44100 / 48000 to the corresponding Raw...16BitMonoPcm member of
#   SpeechSynthesisOutputFormat; any other value returns Raw16Khz16BitMonoPcm.
# Hunk 4: AzureTTSService's constructor now calls
#   speech_config.set_speech_synthesis_output_format(sample_rate_to_output_format(sample_rate))
#   before the SpeechSynthesizer is created.
# NOTE(review): for a sample_rate outside the six mapped values, super().__init__ still receives the
#   requested rate while the output format is set to the 16 kHz one -- confirm this mismatch is intended.
diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py index 551fa40a8..56dde30c1 100644 --- a/src/pipecat/services/azure.py +++ b/src/pipecat/services/azure.py @@ -25,8 +25,14 @@ from pipecat.frames.frames import ( TTSStoppedFrame, URLImageRawFrame, ) +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.ai_services import ImageGenService, STTService, TTSService -from pipecat.services.openai import BaseOpenAILLMService +from pipecat.services.openai import ( + BaseOpenAILLMService, + OpenAIAssistantContextAggregator, + OpenAIContextAggregatorPair, + OpenAIUserContextAggregator, +) from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 @@ -38,6 +44,7 @@ try: SpeechConfig, SpeechRecognizer, SpeechSynthesizer, + SpeechSynthesisOutputFormat, ) from azure.cognitiveservices.speech.audio import ( AudioStreamFormat, @@ -70,6 +77,33 @@ class AzureLLMService(BaseOpenAILLMService): api_version=self._api_version, ) + @staticmethod + def create_context_aggregator( + context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + ) -> OpenAIContextAggregatorPair: + user = OpenAIUserContextAggregator(context) + assistant = OpenAIAssistantContextAggregator( + user, expect_stripped_words=assistant_expect_stripped_words + ) + return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) + + +def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat: + match sample_rate: + case 8000: + return SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm + case 16000: + return SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm + case 22050: + return SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm + case 24000: + return SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm + case 44100: + return SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm + case 48000: + return SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm + return 
SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm + class AzureTTSService(TTSService): class InputParams(BaseModel): @@ -95,6 +129,8 @@ class AzureTTSService(TTSService): super().__init__(sample_rate=sample_rate, **kwargs) speech_config = SpeechConfig(subscription=api_key, region=region) + speech_config.set_speech_synthesis_output_format(sample_rate_to_output_format(sample_rate)) + self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None) self._settings = {