diff --git a/changelog/4082.fixed.md b/changelog/4082.fixed.md new file mode 100644 index 000000000..e17e84fea --- /dev/null +++ b/changelog/4082.fixed.md @@ -0,0 +1 @@ +- Fixed `SarvamTTSService` so that audio and error frames are routed through `append_to_audio_context()` instead of `push_frame()`, ensuring correct behavior with audio contexts and interruptions. diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py index c926270de..8bfeea8c6 100644 --- a/src/pipecat/services/sarvam/tts.py +++ b/src/pipecat/services/sarvam/tts.py @@ -1031,23 +1031,6 @@ class SarvamTTSService(InterruptibleTTSService): except Exception as e: await self.push_error(error_msg=f"Error sending flush to Sarvam: {e}", exception=e) - async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): - """Push a frame downstream with special handling for stop conditions. - - Args: - frame: The frame to push. - direction: The direction to push the frame. - """ - await super().push_frame(frame, direction) - - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process a frame and flush audio if it's the end of a full response.""" - await super().process_frame(frame, direction) - - # When the LLM finishes responding, flush any remaining text in Sarvam's buffer - if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)): - await self.flush_audio() - async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]: """Apply a settings delta and resend config if voice changed.""" changed = await super()._update_settings(delta) @@ -1168,14 +1151,13 @@ class SarvamTTSService(InterruptibleTTSService): async for message in self._get_websocket(): if isinstance(message, str): msg = json.loads(message) + context_id = self.get_active_audio_context_id() if msg.get("type") == "audio": # Check for interruption before processing audio await self.stop_ttfb_metrics() audio = base64.b64decode(msg["data"]["audio"]) - frame = 
TTSAudioRawFrame( - audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id() - ) - await self.push_frame(frame) + frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=context_id) + await self.append_to_audio_context(context_id, frame) elif msg.get("type") == "error": error_msg = msg["data"]["message"] await self.push_error(error_msg=f"TTS Error: {error_msg}") @@ -1183,8 +1165,9 @@ class SarvamTTSService(InterruptibleTTSService): # If it's a timeout error, the connection might need to be reset if "too long" in error_msg.lower() or "timeout" in error_msg.lower(): logger.warning("Connection timeout detected, service may need restart") - - await self.push_frame(ErrorFrame(error=f"TTS Error: {error_msg}")) + await self.append_to_audio_context( + context_id, ErrorFrame(error=f"TTS Error: {error_msg}") + ) async def _keepalive_task_handler(self): """Handle keepalive messages to maintain WebSocket connection."""