Merge pull request #2512 from pipecat-ai/aleix/textframe-skip-tts

TextFrame: add skip_tts field
This commit is contained in:
Aleix Conchillo Flaqué
2025-08-27 16:26:03 -07:00
committed by GitHub
5 changed files with 65 additions and 4 deletions

View File

@@ -9,6 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added support for switching between audio+text and text-only modes within the
same pipeline. This is done by pushing
`LLMConfigureOutputFrame(skip_tts=True)` to enter text-only mode, and
`LLMConfigureOutputFrame(skip_tts=False)` to return to audio+text. The LLM will
still generate tokens and add them to the context, but they will not be sent to
TTS.
- Added `skip_tts` field to `TextFrame`. This lets a text frame bypass TTS while
still being included in the LLM context. Useful for cases like structured text
that isn't meant to be spoken but should still contribute to context.
- Added a `cancel_timeout_secs` argument to `PipelineTask` which defines how
long the pipeline has to complete cancellation. When `PipelineTask.cancel()`
is called, a `CancelFrame` is pushed through the pipeline and must reach the

View File

@@ -305,6 +305,11 @@ class TextFrame(DataFrame):
"""
text: str
skip_tts: bool = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
def __str__(self):
pts = format_pts(self.pts)
@@ -602,6 +607,21 @@ class LLMEnablePromptCachingFrame(DataFrame):
enable: bool
@dataclass
class LLMConfigureOutputFrame(DataFrame):
    """Frame that configures how the LLM produces its output.

    For example, it can instruct the LLM to keep generating tokens that are
    added to the context while preventing those tokens from being spoken by
    the TTS service (if one is present in the pipeline).

    Parameters:
        skip_tts: Whether LLM tokens should skip the TTS service (if any).
    """

    skip_tts: bool
@dataclass
class TTSSpeakFrame(DataFrame):
"""Frame containing text that should be spoken by TTS.
@@ -1331,14 +1351,22 @@ class LLMFullResponseStartFrame(ControlFrame):
more TextFrames and a final LLMFullResponseEndFrame.
"""
pass
skip_tts: bool = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
@dataclass
class LLMFullResponseEndFrame(ControlFrame):
"""Frame indicating the end of an LLM response."""
pass
skip_tts: bool = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
@dataclass

View File

@@ -103,7 +103,7 @@ class DTMFAggregator(FrameProcessor):
digit_value = frame.button.value
self._aggregation += digit_value
# For first digit, schedule interruption in separate task
# For first digit, schedule interruption.
if is_first_digit:
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)

View File

@@ -14,7 +14,6 @@ from typing import (
Awaitable,
Callable,
Dict,
List,
Mapping,
Optional,
Protocol,
@@ -37,6 +36,10 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
FunctionCallResultProperties,
FunctionCallsStartedFrame,
LLMConfigureOutputFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
StartFrame,
StartInterruptionFrame,
UserImageRequestFrame,
@@ -179,6 +182,7 @@ class LLMService(AIService):
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
self._sequential_runner_task: Optional[asyncio.Task] = None
self._tracing_enabled: bool = False
self._skip_tts: bool = False
self._register_event_handler("on_function_calls_started")
self._register_event_handler("on_completion_timeout")
@@ -272,6 +276,20 @@ class LLMService(AIService):
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruptions(frame)
elif isinstance(frame, LLMConfigureOutputFrame):
self._skip_tts = frame.skip_tts
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
    """Push a frame, stamping LLM output frames with the TTS-skip setting.

    LLM text frames and the full-response start/end frames get their
    ``skip_tts`` attribute set from this service's current ``_skip_tts``
    value before being handed to the parent implementation. All other
    frames are forwarded untouched.

    Args:
        frame: The frame to push.
        direction: The direction of frame pushing.
    """
    llm_output_types = (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)
    if isinstance(frame, llm_output_types):
        # Propagate the currently configured output mode onto the frame.
        frame.skip_tts = self._skip_tts

    await super().push_frame(frame, direction)
async def _handle_interruptions(self, _: StartInterruptionFrame):
for function_name, entry in self._functions.items():

View File

@@ -297,6 +297,11 @@ class TTSService(AIService):
await super().process_frame(frame, direction)
if (
isinstance(frame, (TextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame))
and frame.skip_tts
):
await self.push_frame(frame, direction)
elif (
isinstance(frame, TextFrame)
and not isinstance(frame, InterimTranscriptionFrame)
and not isinstance(frame, TranscriptionFrame)