Merge pull request #4082 from pipecat-ai/filipi/sarvam_tts_improvements

Improvements to SarvamTTSService.
This commit is contained in:
Filipi da Silva Fuchter
2026-03-20 10:28:02 -04:00
committed by GitHub
2 changed files with 7 additions and 23 deletions

1
changelog/4082.fixed.md Normal file
View File

@@ -0,0 +1 @@
- Fixed `SarvamTTSService` so that audio and error frames are routed through `append_to_audio_context()` instead of `push_frame()`, ensuring correct behavior with audio contexts and interruptions.

View File

@@ -1031,23 +1031,6 @@ class SarvamTTSService(InterruptibleTTSService):
except Exception as e:
await self.push_error(error_msg=f"Error sending flush to Sarvam: {e}", exception=e)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame downstream with special handling for stop conditions.
Args:
frame: The frame to push.
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process a frame and flush audio if it's the end of a full response."""
await super().process_frame(frame, direction)
# When the LLM finishes responding, flush any remaining text in Sarvam's buffer
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
"""Apply a settings delta and resend config if voice changed."""
changed = await super()._update_settings(delta)
@@ -1168,14 +1151,13 @@ class SarvamTTSService(InterruptibleTTSService):
async for message in self._get_websocket():
if isinstance(message, str):
msg = json.loads(message)
context_id = self.get_active_audio_context_id()
if msg.get("type") == "audio":
# Check for interruption before processing audio
await self.stop_ttfb_metrics()
audio = base64.b64decode(msg["data"]["audio"])
frame = TTSAudioRawFrame(
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
)
await self.push_frame(frame)
frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=context_id)
await self.append_to_audio_context(context_id, frame)
elif msg.get("type") == "error":
error_msg = msg["data"]["message"]
await self.push_error(error_msg=f"TTS Error: {error_msg}")
@@ -1183,8 +1165,9 @@ class SarvamTTSService(InterruptibleTTSService):
# If it's a timeout error, the connection might need to be reset
if "too long" in error_msg.lower() or "timeout" in error_msg.lower():
logger.warning("Connection timeout detected, service may need restart")
await self.push_frame(ErrorFrame(error=f"TTS Error: {error_msg}"))
await self.append_to_audio_context(
context_id, ErrorFrame(error=f"TTS Error: {error_msg}")
)
async def _keepalive_task_handler(self):
"""Handle keepalive messages to maintain WebSocket connection."""