From f68b3222b323dc7fd2d94518a02ba2d4634ec8da Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 24 Mar 2026 11:46:28 -0400 Subject: [PATCH] Fix SmallestTTSService to use InterruptibleTTSService audio context system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Route audio through audio contexts (append_to_audio_context) instead of pushing frames directly, enabling proper turn management and interruptions - Add push_stop_frames and push_start_frame so the base class handles TTSStartedFrame/TTSStoppedFrame lifecycle - Remove manual context_id tracking (self._context_id) in favor of get_active_audio_context_id() - Don't call remove_audio_context on "complete" — Smallest sends one per request, not per turn; let the base class timeout handle cleanup - Guard v2-only params (consistency, similarity, enhancement) so they aren't sent to lightning-v3.1 - Remove request_id from request payload (not a documented request field) - Add flush_audio override to send flush to WebSocket --- src/pipecat/services/smallest/tts.py | 49 +++++++++++++--------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/src/pipecat/services/smallest/tts.py b/src/pipecat/services/smallest/tts.py index eca73a8e0..15d107ffd 100644 --- a/src/pipecat/services/smallest/tts.py +++ b/src/pipecat/services/smallest/tts.py @@ -26,7 +26,6 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven @@ -152,8 +151,8 @@ class SmallestTTSService(InterruptibleTTSService): default_settings.apply_update(settings) super().__init__( - aggregate_sentences=True, - push_text_frames=True, + push_stop_frames=True, + push_start_frame=True, pause_frame_processing=True, sample_rate=sample_rate, settings=default_settings, @@ -164,7 +163,6 @@ class SmallestTTSService(InterruptibleTTSService): self._base_url = base_url.rstrip("/") self._receive_task = None self._keepalive_task = None - self._context_id: Optional[str] = None def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. @@ -203,15 +201,15 @@ class SmallestTTSService(InterruptibleTTSService): if self._settings.speed is not None: msg["speed"] = self._settings.speed - if self._settings.consistency is not None: - msg["consistency"] = self._settings.consistency - if self._settings.similarity is not None: - msg["similarity"] = self._settings.similarity - if self._settings.enhancement is not None: - msg["enhancement"] = self._settings.enhancement - if self._context_id: - msg["request_id"] = self._context_id + # consistency, similarity, enhancement are only supported by lightning-v2 + if self._settings.model == SmallestTTSModel.LIGHTNING_V2.value: + if self._settings.consistency is not None: + msg["consistency"] = self._settings.consistency + if self._settings.similarity is not None: + msg["similarity"] = self._settings.similarity + if self._settings.enhancement is not None: + msg["enhancement"] = self._settings.enhancement return msg @@ -322,7 +320,6 @@ class SmallestTTSService(InterruptibleTTSService): error_msg=f"Smallest TTS error closing websocket: {e}", exception=e ) finally: - self._context_id = None self._websocket = None await self._call_event_handler("on_disconnected") @@ -352,6 +349,12 @@ class SmallestTTSService(InterruptibleTTSService): msg = {"flush": True} await self._websocket.send(json.dumps(msg)) + async def flush_audio(self, context_id: Optional[str] = None): + """Flush any pending audio synthesis.""" + if not self._websocket or self._websocket.state is State.CLOSED: + return + await self._get_websocket().send(json.dumps({"flush": True})) + async def _receive_messages(self): """Receive and process messages from the Smallest WebSocket API.""" async for message in self._get_websocket(): @@ -359,25 +362,22 @@ class SmallestTTSService(InterruptibleTTSService): status = msg.get("status") if status == "complete": - msg_request_id = msg.get("request_id") - if self._context_id and msg_request_id and msg_request_id == self._context_id: - await self.stop_all_metrics() - await self.push_frame(TTSStoppedFrame(context_id=self._context_id)) - self._context_id = None + await self.stop_all_metrics() elif status == "chunk": await self.stop_ttfb_metrics() + context_id = self.get_active_audio_context_id() frame = TTSAudioRawFrame( audio=base64.b64decode(msg["data"]["audio"]), sample_rate=self.sample_rate, num_channels=1, - context_id=self._context_id, + context_id=context_id, ) - await self.push_frame(frame) + await self.append_to_audio_context(context_id, frame) elif status == "error": - await self.push_frame(TTSStoppedFrame(context_id=self._context_id)) + context_id = self.get_active_audio_context_id() + await self.push_frame(TTSStoppedFrame(context_id=context_id)) await self.stop_all_metrics() await self.push_error(error_msg=f"Smallest TTS error: {msg.get('error', msg)}") - self._context_id = None else: logger.warning(f"{self} unknown message status: {msg}") @@ -390,7 +390,7 @@ class SmallestTTSService(InterruptibleTTSService): context_id: Unique identifier for this TTS context. Yields: - Frame: TTSStartedFrame to signal start; audio arrives via WebSocket. + Frame: Audio arrives via WebSocket receive task. """ logger.debug(f"{self}: Generating TTS [{text}]") @@ -399,9 +399,6 @@ class SmallestTTSService(InterruptibleTTSService): await self._connect() try: - self._context_id = context_id - yield TTSStartedFrame(context_id=context_id) - msg = self._build_msg(text=text) await self._get_websocket().send(json.dumps(msg)) await self.start_tts_usage_metrics(text)