Merge pull request #4082 from pipecat-ai/filipi/sarvam_tts_improvements
Improvements to SarvamTTSService.
This commit is contained in:
1
changelog/4082.fixed.md
Normal file
1
changelog/4082.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `SarvamTTSService` so that audio and error frames are now routed through `append_to_audio_context()` instead of `push_frame()`, ensuring correct behavior with audio contexts and interruptions.
|
||||
@@ -1031,23 +1031,6 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error sending flush to Sarvam: {e}", exception=e)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame downstream with special handling for stop conditions.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process a frame and flush audio if it's the end of a full response."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# When the LLM finishes responding, flush any remaining text in Sarvam's buffer
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
|
||||
async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
|
||||
"""Apply a settings delta and resend config if voice changed."""
|
||||
changed = await super()._update_settings(delta)
|
||||
@@ -1168,14 +1151,13 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
async for message in self._get_websocket():
|
||||
if isinstance(message, str):
|
||||
msg = json.loads(message)
|
||||
context_id = self.get_active_audio_context_id()
|
||||
if msg.get("type") == "audio":
|
||||
# Check for interruption before processing audio
|
||||
await self.stop_ttfb_metrics()
|
||||
audio = base64.b64decode(msg["data"]["audio"])
|
||||
frame = TTSAudioRawFrame(
|
||||
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1, context_id=context_id)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
elif msg.get("type") == "error":
|
||||
error_msg = msg["data"]["message"]
|
||||
await self.push_error(error_msg=f"TTS Error: {error_msg}")
|
||||
@@ -1183,8 +1165,9 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
# If it's a timeout error, the connection might need to be reset
|
||||
if "too long" in error_msg.lower() or "timeout" in error_msg.lower():
|
||||
logger.warning("Connection timeout detected, service may need restart")
|
||||
|
||||
await self.push_frame(ErrorFrame(error=f"TTS Error: {error_msg}"))
|
||||
await self.append_to_audio_context(
|
||||
context_id, ErrorFrame(error=f"TTS Error: {error_msg}")
|
||||
)
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Handle keepalive messages to maintain WebSocket connection."""
|
||||
|
||||
Reference in New Issue
Block a user