services(azure): support sample rates

This commit is contained in:
Aleix Conchillo Flaqué
2024-10-24 16:31:28 -07:00
parent 6d317c6e8e
commit cfb48200c2

View File

@@ -25,8 +25,14 @@ from pipecat.frames.frames import (
TTSStoppedFrame,
URLImageRawFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.ai_services import ImageGenService, STTService, TTSService
from pipecat.services.openai import BaseOpenAILLMService
from pipecat.services.openai import (
BaseOpenAILLMService,
OpenAIAssistantContextAggregator,
OpenAIContextAggregatorPair,
OpenAIUserContextAggregator,
)
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
@@ -38,6 +44,7 @@ try:
SpeechConfig,
SpeechRecognizer,
SpeechSynthesizer,
SpeechSynthesisOutputFormat,
)
from azure.cognitiveservices.speech.audio import (
AudioStreamFormat,
@@ -70,6 +77,33 @@ class AzureLLMService(BaseOpenAILLMService):
api_version=self._api_version,
)
@staticmethod
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> OpenAIContextAggregatorPair:
user = OpenAIUserContextAggregator(context)
assistant = OpenAIAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
)
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat:
match sample_rate:
case 8000:
return SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm
case 16000:
return SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm
case 22050:
return SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm
case 24000:
return SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm
case 44100:
return SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm
case 48000:
return SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm
return SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm
class AzureTTSService(TTSService):
class InputParams(BaseModel):
@@ -95,6 +129,8 @@ class AzureTTSService(TTSService):
super().__init__(sample_rate=sample_rate, **kwargs)
speech_config = SpeechConfig(subscription=api_key, region=region)
speech_config.set_speech_synthesis_output_format(sample_rate_to_output_format(sample_rate))
self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
self._settings = {