Fix SmallestTTSService to use InterruptibleTTSService audio context system

- Route audio through audio contexts (append_to_audio_context) instead of
  pushing frames directly, enabling proper turn management and interruptions
- Add push_stop_frames and push_start_frame so the base class handles
  TTSStartedFrame/TTSStoppedFrame lifecycle
- Remove manual context_id tracking (self._context_id) in favor of
  get_active_audio_context_id()
- Don't call remove_audio_context on "complete" — Smallest sends one
  "complete" message per request, not per turn; let the base class
  timeout handle cleanup
- Guard v2-only params (consistency, similarity, enhancement) so they
  aren't sent to lightning-v3.1
- Remove request_id from request payload (not a documented request field)
- Add flush_audio override to send flush to WebSocket
This commit is contained in:
Mark Backman
2026-03-24 11:46:28 -04:00
parent 51d28b4a9f
commit f68b3222b3

View File

@@ -26,7 +26,6 @@ from pipecat.frames.frames import (
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
@@ -152,8 +151,8 @@ class SmallestTTSService(InterruptibleTTSService):
default_settings.apply_update(settings)
super().__init__(
aggregate_sentences=True,
push_text_frames=True,
push_stop_frames=True,
push_start_frame=True,
pause_frame_processing=True,
sample_rate=sample_rate,
settings=default_settings,
@@ -164,7 +163,6 @@ class SmallestTTSService(InterruptibleTTSService):
self._base_url = base_url.rstrip("/")
self._receive_task = None
self._keepalive_task = None
self._context_id: Optional[str] = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -203,15 +201,15 @@ class SmallestTTSService(InterruptibleTTSService):
if self._settings.speed is not None:
msg["speed"] = self._settings.speed
if self._settings.consistency is not None:
msg["consistency"] = self._settings.consistency
if self._settings.similarity is not None:
msg["similarity"] = self._settings.similarity
if self._settings.enhancement is not None:
msg["enhancement"] = self._settings.enhancement
if self._context_id:
msg["request_id"] = self._context_id
# consistency, similarity, enhancement are only supported by lightning-v2
if self._settings.model == SmallestTTSModel.LIGHTNING_V2.value:
if self._settings.consistency is not None:
msg["consistency"] = self._settings.consistency
if self._settings.similarity is not None:
msg["similarity"] = self._settings.similarity
if self._settings.enhancement is not None:
msg["enhancement"] = self._settings.enhancement
return msg
@@ -322,7 +320,6 @@ class SmallestTTSService(InterruptibleTTSService):
error_msg=f"Smallest TTS error closing websocket: {e}", exception=e
)
finally:
self._context_id = None
self._websocket = None
await self._call_event_handler("on_disconnected")
@@ -352,6 +349,12 @@ class SmallestTTSService(InterruptibleTTSService):
msg = {"flush": True}
await self._websocket.send(json.dumps(msg))
async def flush_audio(self, context_id: Optional[str] = None):
"""Flush any pending audio synthesis.

Sends the JSON message ``{"flush": true}`` over the WebSocket. Returns
without sending anything when no WebSocket is set or when its state is
``State.CLOSED``.

Args:
context_id: Accepted but not read by this method.
"""
# Nothing to flush if the connection was never opened or is already closed.
if not self._websocket or self._websocket.state is State.CLOSED:
return
await self._get_websocket().send(json.dumps({"flush": True}))
async def _receive_messages(self):
"""Receive and process messages from the Smallest WebSocket API."""
async for message in self._get_websocket():
@@ -359,25 +362,22 @@ class SmallestTTSService(InterruptibleTTSService):
status = msg.get("status")
if status == "complete":
msg_request_id = msg.get("request_id")
if self._context_id and msg_request_id and msg_request_id == self._context_id:
await self.stop_all_metrics()
await self.push_frame(TTSStoppedFrame(context_id=self._context_id))
self._context_id = None
await self.stop_all_metrics()
elif status == "chunk":
await self.stop_ttfb_metrics()
context_id = self.get_active_audio_context_id()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]["audio"]),
sample_rate=self.sample_rate,
num_channels=1,
context_id=self._context_id,
context_id=context_id,
)
await self.push_frame(frame)
await self.append_to_audio_context(context_id, frame)
elif status == "error":
await self.push_frame(TTSStoppedFrame(context_id=self._context_id))
context_id = self.get_active_audio_context_id()
await self.push_frame(TTSStoppedFrame(context_id=context_id))
await self.stop_all_metrics()
await self.push_error(error_msg=f"Smallest TTS error: {msg.get('error', msg)}")
self._context_id = None
else:
logger.warning(f"{self} unknown message status: {msg}")
@@ -390,7 +390,7 @@ class SmallestTTSService(InterruptibleTTSService):
context_id: Unique identifier for this TTS context.
Yields:
Frame: TTSStartedFrame to signal start; audio arrives via WebSocket.
Frame: Audio arrives via WebSocket receive task.
"""
logger.debug(f"{self}: Generating TTS [{text}]")
@@ -399,9 +399,6 @@ class SmallestTTSService(InterruptibleTTSService):
await self._connect()
try:
self._context_id = context_id
yield TTSStartedFrame(context_id=context_id)
msg = self._build_msg(text=text)
await self._get_websocket().send(json.dumps(msg))
await self.start_tts_usage_metrics(text)