Merge pull request #2512 from pipecat-ai/aleix/textframe-skip-tts
TextFrame: add skip_tts field
This commit is contained in:
10
CHANGELOG.md
10
CHANGELOG.md
@@ -9,6 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for switching between audio+text and text-only modes within the
|
||||
same pipeline. This is done by pushing
|
||||
`LLMConfigureOutputFrame(skip_tts=True)` to enter text-only mode, and
|
||||
pushing it again with `skip_tts=False` to return to audio+text. The LLM will still generate tokens and
|
||||
add them to the context, but they will not be sent to TTS.
|
||||
|
||||
- Added `skip_tts` field to `TextFrame`. This lets a text frame bypass TTS while
|
||||
still being included in the LLM context. Useful for cases like structured text
|
||||
that isn’t meant to be spoken but should still contribute to context.
|
||||
|
||||
- Added a `cancel_timeout_secs` argument to `PipelineTask` which defines how
|
||||
long the pipeline has to complete cancellation. When `PipelineTask.cancel()`
|
||||
is called, a `CancelFrame` is pushed through the pipeline and must reach the
|
||||
|
||||
@@ -305,6 +305,11 @@ class TextFrame(DataFrame):
|
||||
"""
|
||||
|
||||
text: str
|
||||
skip_tts: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
@@ -602,6 +607,21 @@ class LLMEnablePromptCachingFrame(DataFrame):
|
||||
enable: bool
|
||||
|
||||
|
||||
@dataclass
class LLMConfigureOutputFrame(DataFrame):
    """Frame that configures how an LLM service emits its output.

    Pushing this frame tells the LLM whether the tokens it generates should be
    routed to the TTS service (when the pipeline has one). The tokens are still
    meant to be added to the context either way; only whether they are spoken
    changes.

    Parameters:
        skip_tts: True if LLM tokens should bypass the TTS service (if any),
            False if they should be spoken as usual.
    """

    skip_tts: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSSpeakFrame(DataFrame):
|
||||
"""Frame containing text that should be spoken by TTS.
|
||||
@@ -1331,14 +1351,22 @@ class LLMFullResponseStartFrame(ControlFrame):
|
||||
more TextFrames and a final LLMFullResponseEndFrame.
|
||||
"""
|
||||
|
||||
pass
|
||||
skip_tts: bool = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
|
||||
|
||||
@dataclass
class LLMFullResponseEndFrame(ControlFrame):
    """Frame indicating the end of an LLM response.

    Parameters:
        skip_tts: Whether this frame should bypass the TTS service. It is not
            a constructor argument; it is initialized to False after
            construction and may be reassigned afterwards.
    """

    # Excluded from the generated __init__; the value is assigned in
    # __post_init__ below.
    skip_tts: bool = field(init=False)

    def __post_init__(self):
        # Let the parent frame finish its own post-init setup first.
        super().__post_init__()
        self.skip_tts = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -103,7 +103,7 @@ class DTMFAggregator(FrameProcessor):
|
||||
digit_value = frame.button.value
|
||||
self._aggregation += digit_value
|
||||
|
||||
# For first digit, schedule interruption in separate task
|
||||
# For first digit, schedule interruption.
|
||||
if is_first_digit:
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
@@ -37,6 +36,10 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallResultProperties,
|
||||
FunctionCallsStartedFrame,
|
||||
LLMConfigureOutputFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
UserImageRequestFrame,
|
||||
@@ -179,6 +182,7 @@ class LLMService(AIService):
|
||||
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
|
||||
self._sequential_runner_task: Optional[asyncio.Task] = None
|
||||
self._tracing_enabled: bool = False
|
||||
self._skip_tts: bool = False
|
||||
|
||||
self._register_event_handler("on_function_calls_started")
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
@@ -272,6 +276,20 @@ class LLMService(AIService):
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, LLMConfigureOutputFrame):
|
||||
self._skip_tts = frame.skip_tts
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
    """Push a frame, stamping LLM output frames with the current TTS setting.

    LLM text frames and the frames that delimit a full LLM response get their
    ``skip_tts`` attribute overwritten with this service's ``_skip_tts`` value
    before being forwarded. All other frames are forwarded untouched.

    Args:
        frame: The frame to push.
        direction: The direction of frame pushing.
    """
    # Frame types whose skip_tts flag is controlled by this service.
    llm_output_types = (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)

    if isinstance(frame, llm_output_types):
        frame.skip_tts = self._skip_tts

    await super().push_frame(frame, direction)
|
||||
|
||||
async def _handle_interruptions(self, _: StartInterruptionFrame):
|
||||
for function_name, entry in self._functions.items():
|
||||
|
||||
@@ -297,6 +297,11 @@ class TTSService(AIService):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if (
|
||||
isinstance(frame, (TextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame))
|
||||
and frame.skip_tts
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
elif (
|
||||
isinstance(frame, TextFrame)
|
||||
and not isinstance(frame, InterimTranscriptionFrame)
|
||||
and not isinstance(frame, TranscriptionFrame)
|
||||
|
||||
Reference in New Issue
Block a user