Fix SmallestTTSService to use InterruptibleTTSService audio context system
- Route audio through audio contexts (append_to_audio_context) instead of pushing frames directly, enabling proper turn management and interruptions - Add push_stop_frames and push_start_frame so the base class handles TTSStartedFrame/TTSStoppedFrame lifecycle - Remove manual context_id tracking (self._context_id) in favor of get_active_audio_context_id() - Don't call remove_audio_context on "complete" — Smallest sends one per request, not per turn; let the base class timeout handle cleanup - Guard v2-only params (consistency, similarity, enhancement) so they aren't sent to lightning-v3.1 - Remove request_id from request payload (not a documented request field) - Add flush_audio override to send flush to WebSocket
This commit is contained in:
@@ -26,7 +26,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
@@ -152,8 +151,8 @@ class SmallestTTSService(InterruptibleTTSService):
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
super().__init__(
|
||||
aggregate_sentences=True,
|
||||
push_text_frames=True,
|
||||
push_stop_frames=True,
|
||||
push_start_frame=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
settings=default_settings,
|
||||
@@ -164,7 +163,6 @@ class SmallestTTSService(InterruptibleTTSService):
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
self._context_id: Optional[str] = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -203,15 +201,15 @@ class SmallestTTSService(InterruptibleTTSService):
|
||||
|
||||
if self._settings.speed is not None:
|
||||
msg["speed"] = self._settings.speed
|
||||
if self._settings.consistency is not None:
|
||||
msg["consistency"] = self._settings.consistency
|
||||
if self._settings.similarity is not None:
|
||||
msg["similarity"] = self._settings.similarity
|
||||
if self._settings.enhancement is not None:
|
||||
msg["enhancement"] = self._settings.enhancement
|
||||
|
||||
if self._context_id:
|
||||
msg["request_id"] = self._context_id
|
||||
# consistency, similarity, enhancement are only supported by lightning-v2
|
||||
if self._settings.model == SmallestTTSModel.LIGHTNING_V2.value:
|
||||
if self._settings.consistency is not None:
|
||||
msg["consistency"] = self._settings.consistency
|
||||
if self._settings.similarity is not None:
|
||||
msg["similarity"] = self._settings.similarity
|
||||
if self._settings.enhancement is not None:
|
||||
msg["enhancement"] = self._settings.enhancement
|
||||
|
||||
return msg
|
||||
|
||||
@@ -322,7 +320,6 @@ class SmallestTTSService(InterruptibleTTSService):
|
||||
error_msg=f"Smallest TTS error closing websocket: {e}", exception=e
|
||||
)
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -352,6 +349,12 @@ class SmallestTTSService(InterruptibleTTSService):
|
||||
msg = {"flush": True}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
|
||||
async def flush_audio(self, context_id: Optional[str] = None):
|
||||
"""Flush any pending audio synthesis."""
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
return
|
||||
await self._get_websocket().send(json.dumps({"flush": True}))
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and process messages from the Smallest WebSocket API."""
|
||||
async for message in self._get_websocket():
|
||||
@@ -359,25 +362,22 @@ class SmallestTTSService(InterruptibleTTSService):
|
||||
status = msg.get("status")
|
||||
|
||||
if status == "complete":
|
||||
msg_request_id = msg.get("request_id")
|
||||
if self._context_id and msg_request_id and msg_request_id == self._context_id:
|
||||
await self.stop_all_metrics()
|
||||
await self.push_frame(TTSStoppedFrame(context_id=self._context_id))
|
||||
self._context_id = None
|
||||
await self.stop_all_metrics()
|
||||
elif status == "chunk":
|
||||
await self.stop_ttfb_metrics()
|
||||
context_id = self.get_active_audio_context_id()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]["audio"]),
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
context_id=self._context_id,
|
||||
context_id=context_id,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
await self.append_to_audio_context(context_id, frame)
|
||||
elif status == "error":
|
||||
await self.push_frame(TTSStoppedFrame(context_id=self._context_id))
|
||||
context_id = self.get_active_audio_context_id()
|
||||
await self.push_frame(TTSStoppedFrame(context_id=context_id))
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Smallest TTS error: {msg.get('error', msg)}")
|
||||
self._context_id = None
|
||||
else:
|
||||
logger.warning(f"{self} unknown message status: {msg}")
|
||||
|
||||
@@ -390,7 +390,7 @@ class SmallestTTSService(InterruptibleTTSService):
|
||||
context_id: Unique identifier for this TTS context.
|
||||
|
||||
Yields:
|
||||
Frame: TTSStartedFrame to signal start; audio arrives via WebSocket.
|
||||
Frame: Audio arrives via WebSocket receive task.
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
@@ -399,9 +399,6 @@ class SmallestTTSService(InterruptibleTTSService):
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
self._context_id = context_id
|
||||
yield TTSStartedFrame(context_id=context_id)
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
Reference in New Issue
Block a user