diff --git a/src/pipecat/services/smallest/tts.py b/src/pipecat/services/smallest/tts.py index eca73a8e0..15d107ffd 100644 --- a/src/pipecat/services/smallest/tts.py +++ b/src/pipecat/services/smallest/tts.py @@ -26,7 +26,6 @@ from pipecat.frames.frames import ( Frame, StartFrame, TTSAudioRawFrame, - TTSStartedFrame, TTSStoppedFrame, ) from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven @@ -152,8 +151,8 @@ class SmallestTTSService(InterruptibleTTSService): default_settings.apply_update(settings) super().__init__( - aggregate_sentences=True, - push_text_frames=True, + push_stop_frames=True, + push_start_frame=True, pause_frame_processing=True, sample_rate=sample_rate, settings=default_settings, @@ -164,7 +163,6 @@ class SmallestTTSService(InterruptibleTTSService): self._base_url = base_url.rstrip("/") self._receive_task = None self._keepalive_task = None - self._context_id: Optional[str] = None def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. @@ -203,15 +201,15 @@ class SmallestTTSService(InterruptibleTTSService): if self._settings.speed is not None: msg["speed"] = self._settings.speed - if self._settings.consistency is not None: - msg["consistency"] = self._settings.consistency - if self._settings.similarity is not None: - msg["similarity"] = self._settings.similarity - if self._settings.enhancement is not None: - msg["enhancement"] = self._settings.enhancement - if self._context_id: - msg["request_id"] = self._context_id + # consistency, similarity, enhancement are only supported by lightning-v2 + if self._settings.model == SmallestTTSModel.LIGHTNING_V2.value: + if self._settings.consistency is not None: + msg["consistency"] = self._settings.consistency + if self._settings.similarity is not None: + msg["similarity"] = self._settings.similarity + if self._settings.enhancement is not None: + msg["enhancement"] = self._settings.enhancement return msg @@ -322,7 +320,6 @@ class SmallestTTSService(InterruptibleTTSService): error_msg=f"Smallest TTS error closing websocket: {e}", exception=e ) finally: - self._context_id = None self._websocket = None await self._call_event_handler("on_disconnected") @@ -352,6 +349,12 @@ class SmallestTTSService(InterruptibleTTSService): msg = {"flush": True} await self._websocket.send(json.dumps(msg)) + async def flush_audio(self, context_id: Optional[str] = None): + """Flush any pending audio synthesis.""" + if not self._websocket or self._websocket.state is State.CLOSED: + return + await self._get_websocket().send(json.dumps({"flush": True})) + async def _receive_messages(self): """Receive and process messages from the Smallest WebSocket API.""" async for message in self._get_websocket(): @@ -359,25 +362,22 @@ class SmallestTTSService(InterruptibleTTSService): status = msg.get("status") if status == "complete": - msg_request_id = msg.get("request_id") - if self._context_id and msg_request_id and msg_request_id == self._context_id: - await self.stop_all_metrics() - await self.push_frame(TTSStoppedFrame(context_id=self._context_id)) - self._context_id = None + await self.stop_all_metrics() elif status == "chunk": await self.stop_ttfb_metrics() + context_id = self.get_active_audio_context_id() frame = TTSAudioRawFrame( audio=base64.b64decode(msg["data"]["audio"]), sample_rate=self.sample_rate, num_channels=1, - context_id=self._context_id, + context_id=context_id, ) - await self.push_frame(frame) + await self.append_to_audio_context(context_id, frame) elif status == "error": - await self.push_frame(TTSStoppedFrame(context_id=self._context_id)) + context_id = self.get_active_audio_context_id() + await self.push_frame(TTSStoppedFrame(context_id=context_id)) await self.stop_all_metrics() await self.push_error(error_msg=f"Smallest TTS error: {msg.get('error', msg)}") - self._context_id = None else: logger.warning(f"{self} unknown message status: {msg}") @@ -390,7 +390,7 @@ class SmallestTTSService(InterruptibleTTSService): context_id: Unique identifier for this TTS context. Yields: - Frame: TTSStartedFrame to signal start; audio arrives via WebSocket. + Frame: Audio arrives via WebSocket receive task. """ logger.debug(f"{self}: Generating TTS [{text}]") @@ -399,9 +399,6 @@ class SmallestTTSService(InterruptibleTTSService): await self._connect() try: - self._context_id = context_id - yield TTSStartedFrame(context_id=context_id) - msg = self._build_msg(text=text) await self._get_websocket().send(json.dumps(msg)) await self.start_tts_usage_metrics(text)