services(azure): support sample rates
This commit is contained in:
@@ -25,8 +25,14 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
URLImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.ai_services import ImageGenService, STTService, TTSService
|
||||
from pipecat.services.openai import BaseOpenAILLMService
|
||||
from pipecat.services.openai import (
|
||||
BaseOpenAILLMService,
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIContextAggregatorPair,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
@@ -38,6 +44,7 @@ try:
|
||||
SpeechConfig,
|
||||
SpeechRecognizer,
|
||||
SpeechSynthesizer,
|
||||
SpeechSynthesisOutputFormat,
|
||||
)
|
||||
from azure.cognitiveservices.speech.audio import (
|
||||
AudioStreamFormat,
|
||||
@@ -70,6 +77,33 @@ class AzureLLMService(BaseOpenAILLMService):
|
||||
api_version=self._api_version,
|
||||
)
|
||||
|
||||
@staticmethod
def create_context_aggregator(
    context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> OpenAIContextAggregatorPair:
    """Build the user/assistant context aggregator pair for ``context``.

    Args:
        context: The LLM context handed to the user-side aggregator.
        assistant_expect_stripped_words: Forwarded to the assistant-side
            aggregator as its ``expect_stripped_words`` argument.

    Returns:
        An ``OpenAIContextAggregatorPair`` wrapping both aggregators.
    """
    user_aggregator = OpenAIUserContextAggregator(context)
    # The assistant aggregator is constructed from the user aggregator
    # (not from the raw context), exactly as the pair expects.
    return OpenAIContextAggregatorPair(
        _user=user_aggregator,
        _assistant=OpenAIAssistantContextAggregator(
            user_aggregator, expect_stripped_words=assistant_expect_stripped_words
        ),
    )
|
||||
|
||||
|
||||
def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat:
    """Map a sample rate in Hz to a raw 16-bit mono PCM synthesis output format.

    Args:
        sample_rate: Requested output sample rate (8000, 16000, 22050,
            24000, 44100 or 48000 have dedicated formats).

    Returns:
        The matching ``SpeechSynthesisOutputFormat`` member. Any other
        sample rate falls back to ``Raw16Khz16BitMonoPcm``.
    """
    # (rate, enum member name) pairs, checked in order. Members are looked up
    # by name only for the matching rate, so nothing is resolved eagerly.
    known_formats = (
        (8000, "Raw8Khz16BitMonoPcm"),
        (16000, "Raw16Khz16BitMonoPcm"),
        (22050, "Raw22050Hz16BitMonoPcm"),
        (24000, "Raw24Khz16BitMonoPcm"),
        (44100, "Raw44100Hz16BitMonoPcm"),
        (48000, "Raw48Khz16BitMonoPcm"),
    )
    for rate, member_name in known_formats:
        if sample_rate == rate:
            return getattr(SpeechSynthesisOutputFormat, member_name)
    # Unsupported rate: default to 16 kHz output.
    return SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm
|
||||
|
||||
|
||||
class AzureTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
@@ -95,6 +129,8 @@ class AzureTTSService(TTSService):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
speech_config = SpeechConfig(subscription=api_key, region=region)
|
||||
speech_config.set_speech_synthesis_output_format(sample_rate_to_output_format(sample_rate))
|
||||
|
||||
self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
|
||||
|
||||
self._settings = {
|
||||
|
||||
Reference in New Issue
Block a user