diff --git a/src/pipecat/services/groq/tts.py b/src/pipecat/services/groq/tts.py index 6f73b1629..33fd3ce34 100644 --- a/src/pipecat/services/groq/tts.py +++ b/src/pipecat/services/groq/tts.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import io +import wave from typing import AsyncGenerator, Optional from loguru import logger @@ -78,22 +80,26 @@ class GroqTTSService(TTSService): await self.start_ttfb_metrics() yield TTSStartedFrame() - response = await self._client.audio.speech.create( - model=self._model_name, - voice=self._voice_id, - response_format=self._output_format, - input=text, - ) + try: + response = await self._client.audio.speech.create( + model=self._model_name, + voice=self._voice_id, + response_format=self._output_format, + input=text, + ) - async for data in response.iter_bytes(): - if measuring_ttfb: - await self.stop_ttfb_metrics() - measuring_ttfb = False - # remove wav header if present - if data.startswith(b"RIFF"): - data = data[44:] - if len(data) == 0: - continue - yield TTSAudioRawFrame(data, self.sample_rate, 1) + async for data in response.iter_bytes(): + if measuring_ttfb: + await self.stop_ttfb_metrics() + measuring_ttfb = False + + with wave.open(io.BytesIO(data)) as w: + channels = w.getnchannels() + frame_rate = w.getframerate() + num_frames = w.getnframes() + bytes = w.readframes(num_frames) + yield TTSAudioRawFrame(bytes, frame_rate, channels) + except Exception as e: + logger.error(f"{self} exception: {e}") yield TTSStoppedFrame()