diff --git a/CHANGELOG.md b/CHANGELOG.md index a643b7a3e..8ae588300 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added support for switching between audio+text and text-only modes within the + same pipeline. This is done by pushing + `LLMConfigureOutputFrame(skip_tts=True)` to enter text-only mode, and + pushing it again with `skip_tts=False` to return to audio+text. The LLM will still generate tokens and + add them to the context, but they will not be sent to TTS. + +- Added `skip_tts` field to `TextFrame`. This lets a text frame bypass TTS while + still being included in the LLM context. Useful for cases like structured text + that isn’t meant to be spoken but should still contribute to context. + - Added a `cancel_timeout_secs` argument to `PipelineTask` which defines how long the pipeline has to complete cancellation. When `PipelineTask.cancel()` is called, a `CancelFrame` is pushed through the pipeline and must reach the diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 139857297..691a09f5b 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -305,6 +305,11 @@ class TextFrame(DataFrame): """ text: str + skip_tts: bool = field(init=False) + + def __post_init__(self): + super().__post_init__() + self.skip_tts = False def __str__(self): pts = format_pts(self.pts) @@ -602,6 +607,21 @@ class LLMEnablePromptCachingFrame(DataFrame): enable: bool +@dataclass +class LLMConfigureOutputFrame(DataFrame): + """Frame to configure LLM output. + + This frame is used to configure how the LLM produces output. For example, it + can tell the LLM to generate tokens that should be added to the context but + not spoken by the TTS service (if one is present in the pipeline). + + Parameters: + skip_tts: Whether LLM tokens should skip the TTS service (if any). 
+ """ + + skip_tts: bool + + @dataclass class TTSSpeakFrame(DataFrame): """Frame containing text that should be spoken by TTS. @@ -1331,14 +1351,22 @@ class LLMFullResponseStartFrame(ControlFrame): more TextFrames and a final LLMFullResponseEndFrame. """ - pass + skip_tts: bool = field(init=False) + + def __post_init__(self): + super().__post_init__() + self.skip_tts = False @dataclass class LLMFullResponseEndFrame(ControlFrame): """Frame indicating the end of an LLM response.""" - pass + skip_tts: bool = field(init=False) + + def __post_init__(self): + super().__post_init__() + self.skip_tts = False @dataclass diff --git a/src/pipecat/processors/aggregators/dtmf_aggregator.py b/src/pipecat/processors/aggregators/dtmf_aggregator.py index 24ef2a1e1..38e1296f6 100644 --- a/src/pipecat/processors/aggregators/dtmf_aggregator.py +++ b/src/pipecat/processors/aggregators/dtmf_aggregator.py @@ -103,7 +103,7 @@ class DTMFAggregator(FrameProcessor): digit_value = frame.button.value self._aggregation += digit_value - # For first digit, schedule interruption in separate task + # For first digit, schedule interruption. 
if is_first_digit: await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM) diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 3152a0083..a44f7ab26 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -14,7 +14,6 @@ from typing import ( Awaitable, Callable, Dict, - List, Mapping, Optional, Protocol, @@ -37,6 +36,10 @@ from pipecat.frames.frames import ( FunctionCallResultFrame, FunctionCallResultProperties, FunctionCallsStartedFrame, + LLMConfigureOutputFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMTextFrame, StartFrame, StartInterruptionFrame, UserImageRequestFrame, @@ -179,6 +182,7 @@ class LLMService(AIService): self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {} self._sequential_runner_task: Optional[asyncio.Task] = None self._tracing_enabled: bool = False + self._skip_tts: bool = False self._register_event_handler("on_function_calls_started") self._register_event_handler("on_completion_timeout") @@ -272,6 +276,20 @@ class LLMService(AIService): if isinstance(frame, StartInterruptionFrame): await self._handle_interruptions(frame) + elif isinstance(frame, LLMConfigureOutputFrame): + self._skip_tts = frame.skip_tts + + async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): + """Pushes a frame. + + Args: + frame: The frame to push. + direction: The direction of frame pushing. 
+ """ + if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)): + frame.skip_tts = self._skip_tts + + await super().push_frame(frame, direction) async def _handle_interruptions(self, _: StartInterruptionFrame): for function_name, entry in self._functions.items(): diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index 2a210f093..9e34adf25 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -297,6 +297,11 @@ class TTSService(AIService): await super().process_frame(frame, direction) if ( + isinstance(frame, (TextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)) + and frame.skip_tts + ): + await self.push_frame(frame, direction) + elif ( isinstance(frame, TextFrame) and not isinstance(frame, InterimTranscriptionFrame) and not isinstance(frame, TranscriptionFrame)