fix: route TTS audio through audio context queue in Fish, LMNT, Neuphonic, Rime NonJson
These services were pushing audio frames directly via push_frame() in their WebSocket receive loops, bypassing the base TTSService audio context serialization queue. This caused incorrect frame ordering and broken interruption handling.

Changes per service:
- Fish Audio: use append_to_audio_context(); replace _handle_interruption() with on_audio_context_interrupted()
- LMNT: use append_to_audio_context(); remove the redundant push_frame() override
- Neuphonic: use append_to_audio_context(); remove the redundant push_frame() and process_frame() overrides (the base class handles pause/resume)
- Rime NonJson: use append_to_audio_context(); remove the redundant push_frame() override
This commit is contained in:
@@ -21,12 +21,10 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import InterruptibleTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -362,8 +360,8 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
async def on_audio_context_interrupted(self, context_id: str):
|
||||
"""Stop all metrics when audio context is interrupted."""
|
||||
await self.stop_all_metrics()
|
||||
|
||||
async def _receive_messages(self):
|
||||
@@ -377,8 +375,14 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
audio_data = msg.get("audio")
|
||||
# Only process larger chunks to remove msgpack overhead
|
||||
if audio_data and len(audio_data) > 1024:
|
||||
frame = TTSAudioRawFrame(audio_data, self.sample_rate, 1)
|
||||
await self.push_frame(frame)
|
||||
context_id = self.get_active_audio_context_id()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio_data,
|
||||
self.sample_rate,
|
||||
1,
|
||||
context_id=context_id,
|
||||
)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
await self.stop_ttfb_metrics()
|
||||
elif event == "finish":
|
||||
reason = msg.get("reason", "unknown")
|
||||
|
||||
@@ -21,7 +21,6 @@ from pipecat.frames.frames import (
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import TTSSettings
|
||||
from pipecat.services.tts_service import InterruptibleTTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -212,15 +211,6 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame downstream with special handling for stop conditions.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to LMNT WebSocket and start receive task."""
|
||||
await super()._connect()
|
||||
@@ -322,18 +312,22 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
if isinstance(message, bytes):
|
||||
# Raw audio data
|
||||
await self.stop_ttfb_metrics()
|
||||
context_id = self.get_active_audio_context_id()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=message,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=self.get_active_audio_context_id(),
|
||||
context_id=context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
else:
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
if "error" in msg:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await self.append_to_audio_context(
|
||||
context_id, TTSStoppedFrame(context_id=context_id)
|
||||
)
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg['error']}")
|
||||
return
|
||||
|
||||
@@ -21,18 +21,14 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import InterruptibleTTSService, TextAggregationMode, TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -180,6 +176,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
text_aggregation_mode=text_aggregation_mode,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
pause_frame_processing=True,
|
||||
stop_frame_timeout_s=2.0,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
@@ -254,34 +251,6 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
msg = {"text": "<STOP>"}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame downstream with special handling for stop conditions.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames with special handling for speech control.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we received a TTSSpeakFrame and the LLM response included text (it
|
||||
# might be that it's only a function calling response) we pause
|
||||
# processing more frames until we receive a BotStoppedSpeakingFrame.
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
await self.pause_processing_frames()
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self.pause_processing_frames()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self.resume_processing_frames()
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to Neuphonic WebSocket and start background tasks."""
|
||||
await super()._connect()
|
||||
@@ -366,10 +335,14 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
audio = base64.b64decode(msg["data"]["audio"])
|
||||
context_id = self.get_active_audio_context_id()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio, self.sample_rate, 1, context_id=self.get_active_audio_context_id()
|
||||
audio,
|
||||
self.sample_rate,
|
||||
1,
|
||||
context_id=context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Handle keepalive messages to maintain WebSocket connection."""
|
||||
|
||||
@@ -1054,15 +1054,6 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame downstream with special handling for stop conditions.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish WebSocket connection and start receive task."""
|
||||
await super()._connect()
|
||||
@@ -1153,13 +1144,14 @@ class RimeNonJsonTTSService(InterruptibleTTSService):
|
||||
if isinstance(message, bytes):
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
context_id = self.get_active_audio_context_id()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=message,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=self.get_active_audio_context_id(),
|
||||
context_id=context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error: {e}", exception=e)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user