From 5512de32210e26194d00254fdd48c721d65ddc1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sun, 1 Jun 2025 13:54:57 -0700 Subject: [PATCH 1/3] allow custom interruption strategies --- CHANGELOG.md | 15 ++++--- src/pipecat/audio/interruptions/__init__.py | 0 .../base_interruption_strategy.py | 38 ++++++++++++++++++ .../min_words_interruption_strategy.py | 40 +++++++++++++++++++ src/pipecat/frames/frames.py | 25 +----------- src/pipecat/pipeline/task.py | 4 +- .../processors/aggregators/llm_response.py | 36 +++++++++-------- src/pipecat/processors/frame_processor.py | 8 ++-- src/pipecat/transports/base_input.py | 2 +- tests/test_interruption_strategies.py | 24 +++++++++++ 10 files changed, 140 insertions(+), 52 deletions(-) create mode 100644 src/pipecat/audio/interruptions/__init__.py create mode 100644 src/pipecat/audio/interruptions/base_interruption_strategy.py create mode 100644 src/pipecat/audio/interruptions/min_words_interruption_strategy.py create mode 100644 tests/test_interruption_strategies.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 08f85f328..6eefa24b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,12 +26,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added OpenTelemetry tracing for `GeminiMultimodalLiveLLMService` and `OpenAIRealtimeBetaLLMService`. -- Added `interruption_strategies` to `PipelineParams` using - `MinWordsInterruptionStrategy` to specify minimum words required to interrupt - the bot when it's speaking. Use - `interruption_strategies=[MinWordsInterruptionStrategy(min_words=N)]` to - require users to speak at least N words before interrupting. If not - specified, the normal interruption behavior applies. +- Added initial support for interruption strategies, which determine if the user + should interrupt the bot while the bot is speaking. Interruption strategies + can be based on factors such as audio volume or the number of words spoken by + the user. 
These can be specified via the new `interruption_strategies` field + in `PipelineParams`. A new `MinWordsInterruptionStrategy` strategy has been + introduced which triggers an interruption if the user has spoken a minimum + number of words. If no interruption strategies are specified, the normal + interruption behavior applies. If multiple strategies are provided, the first + one that evaluates to true will trigger the interruption. - `BaseInputTransport` now handles `StopFrame`. When a `StopFrame` is received the transport will pause sending frames downstream until a new `StartFrame` is diff --git a/src/pipecat/audio/interruptions/__init__.py b/src/pipecat/audio/interruptions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pipecat/audio/interruptions/base_interruption_strategy.py b/src/pipecat/audio/interruptions/base_interruption_strategy.py new file mode 100644 index 000000000..7811e8418 --- /dev/null +++ b/src/pipecat/audio/interruptions/base_interruption_strategy.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import ABC, abstractmethod + + +class BaseInterruptionStrategy(ABC): + """This is a base class for interruption strategies. Interruption strategies + decide when the user can interrupt the bot while the bot is speaking. For + example, there could be strategies based on audio volume or strategies based + on the number of words the user spoke. + + """ + + async def append_audio(self, audio: bytes, sample_rate: int): + """Appends audio to the strategy. Not all strategies handle audio.""" + pass + + async def append_text(self, text: str): + """Appends text to the strategy. Not all strategies handle text.""" + pass + + @abstractmethod + async def should_interrupt(self) -> bool: + """This is called when the user stops speaking and it's time to decide + whether the user should interrupt the bot. 
The decision will be based on + the aggregated audio and/or text. + + """ + pass + + @abstractmethod + async def reset(self): + """Reset the current accumulated text and/or audio.""" + pass diff --git a/src/pipecat/audio/interruptions/min_words_interruption_strategy.py b/src/pipecat/audio/interruptions/min_words_interruption_strategy.py new file mode 100644 index 000000000..f9f7595ab --- /dev/null +++ b/src/pipecat/audio/interruptions/min_words_interruption_strategy.py @@ -0,0 +1,40 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from loguru import logger + +from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy + + +class MinWordsInterruptionStrategy(BaseInterruptionStrategy): + """This is an interruption strategy based on a minimum number of words said + by the user. That is, the strategy will be true if the user has said at + least that amount of words. + + """ + + def __init__(self, *, min_words: int): + super().__init__() + self._min_words = min_words + self._text = "" + + async def append_text(self, text: str): + """Appends text for later analysis. Not all strategies need to handle + text. 
+ + """ + self._text += text + + async def should_interrupt(self) -> bool: + word_count = len(self._text.split()) + interrupt = word_count >= self._min_words + logger.debug( + f"should_interrupt={interrupt} num_spoken_words={word_count} min_words={self._min_words}" + ) + return interrupt + + async def reset(self): + self._text = "" diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 6ad0f089f..63caf9a09 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -19,6 +19,7 @@ from typing import ( Tuple, ) +from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy from pipecat.audio.vad.vad_analyzer import VADParams from pipecat.metrics.metrics import MetricsData from pipecat.transcriptions.language import Language @@ -439,28 +440,6 @@ class OutputDTMFFrame(DTMFFrame, DataFrame): # -@dataclass -class InterruptionStrategy: - """Base class for interruption strategies.""" - - pass - - -@dataclass -class MinWordsInterruptionStrategy(InterruptionStrategy): - """Strategy for interruption behavior based on a minimum number of words spoken by the user. 
- - Args: - min_words: If set, user must speak at least this many words to interrupt - """ - - min_words: int - - def __post_init__(self): - if self.min_words <= 0: - raise ValueError("min_words must be greater than 0") - - @dataclass class StartFrame(SystemFrame): """This is the first frame that should be pushed down a pipeline.""" @@ -471,7 +450,7 @@ class StartFrame(SystemFrame): enable_metrics: bool = False enable_usage_metrics: bool = False report_only_initial_ttfb: bool = False - interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None + interruption_strategies: List[BaseInterruptionStrategy] = field(default_factory=list) @dataclass diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 5e421edc4..520998988 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -11,6 +11,7 @@ from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Sequence, from loguru import logger from pydantic import BaseModel, ConfigDict, Field +from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy from pipecat.clocks.base_clock import BaseClock from pipecat.clocks.system_clock import SystemClock from pipecat.frames.frames import ( @@ -22,7 +23,6 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, HeartbeatFrame, - InterruptionStrategy, LLMFullResponseEndFrame, MetricsFrame, StartFrame, @@ -75,7 +75,7 @@ class PipelineParams(BaseModel): report_only_initial_ttfb: bool = False send_initial_empty_metrics: bool = True start_metadata: Dict[str, Any] = Field(default_factory=dict) - interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None + interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list) class PipelineTaskSource(FrameProcessor): diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index ae2382cd6..b1641d7df 100644 --- 
a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -11,6 +11,7 @@ from typing import Dict, List, Literal, Optional, Set from loguru import logger +from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy from pipecat.frames.frames import ( BotInterruptionFrame, BotStartedSpeakingFrame, @@ -24,6 +25,7 @@ from pipecat.frames.frames import ( FunctionCallInProgressFrame, FunctionCallResultFrame, FunctionCallsStartedFrame, + InputAudioRawFrame, InterimTranscriptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, @@ -33,7 +35,6 @@ from pipecat.frames.frames import ( LLMSetToolChoiceFrame, LLMSetToolsFrame, LLMTextFrame, - MinWordsInterruptionStrategy, OpenAILLMContextAssistantTimestampFrame, StartFrame, StartInterruptionFrame, @@ -296,6 +297,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): elif isinstance(frame, CancelFrame): await self._cancel(frame) await self.push_frame(frame, direction) + elif isinstance(frame, InputAudioRawFrame): + await self._handle_input_audio(frame) + await self.push_frame(frame, direction) elif isinstance(frame, UserStartedSpeakingFrame): await self._handle_user_started_speaking(frame) await self.push_frame(frame, direction) @@ -332,10 +336,10 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): await self.push_frame(frame) async def push_aggregation(self): - """Pushes the current aggregation based on interruption configuration and conditions.""" + """Pushes the current aggregation based on interruption strategies and conditions.""" if len(self._aggregation) > 0: if self.interruption_strategies and self._bot_speaking: - should_interrupt = self._should_interrupt_based_on_strategies() + should_interrupt = await self._should_interrupt_based_on_strategies() if should_interrupt: logger.debug( @@ -351,23 +355,19 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): # No interruption config - normal 
behavior (always push aggregation) await self._process_aggregation() - def _should_interrupt_based_on_strategies(self) -> bool: + async def _should_interrupt_based_on_strategies(self) -> bool: """Check if interruption should occur based on configured strategies.""" - if not self.interruption_strategies: - return False - # Check strategies one by one until first match - for strategy in self.interruption_strategies: - if isinstance(strategy, MinWordsInterruptionStrategy): - if self._should_interrupt_min_words(strategy): - return True + async def should_interrupt(strategy: BaseInterruptionStrategy): + await strategy.append_text(self._aggregation) + return await strategy.should_interrupt() - return False + result = any([await should_interrupt(s) for s in self._interruption_strategies]) - def _should_interrupt_min_words(self, strategy: MinWordsInterruptionStrategy) -> bool: - """Check if word count threshold is met.""" - word_count = len(self._aggregation.split()) - return word_count >= strategy.min_words + # Reset all strategies. 
+ [await s.reset() for s in self._interruption_strategies] + + return result async def _start(self, frame: StartFrame): self._create_aggregation_task() @@ -378,6 +378,10 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): async def _cancel(self, frame: CancelFrame): await self._cancel_aggregation_task() + async def _handle_input_audio(self, frame: InputAudioRawFrame): + for s in self.interruption_strategies: + await s.append_audio(frame.audio, frame.sample_rate) + async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame): self._user_speaking = True self._waiting_for_aggregation = True diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 36055c7c0..3b66973dd 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -7,16 +7,16 @@ import asyncio from dataclasses import dataclass from enum import Enum -from typing import Awaitable, Callable, Coroutine, Optional, Sequence +from typing import Awaitable, Callable, Coroutine, List, Optional, Sequence from loguru import logger +from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy from pipecat.clocks.base_clock import BaseClock from pipecat.frames.frames import ( CancelFrame, ErrorFrame, Frame, - InterruptionStrategy, StartFrame, StartInterruptionFrame, StopInterruptionFrame, @@ -68,7 +68,7 @@ class FrameProcessor(BaseObject): self._enable_metrics = False self._enable_usage_metrics = False self._report_only_initial_ttfb = False - self._interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None + self._interruption_strategies: List[BaseInterruptionStrategy] = [] # Indicates whether we have received the StartFrame. 
self.__started = False @@ -122,7 +122,7 @@ class FrameProcessor(BaseObject): return self._report_only_initial_ttfb @property - def interruption_strategies(self) -> Optional[Sequence[InterruptionStrategy]]: + def interruption_strategies(self) -> Sequence[BaseInterruptionStrategy]: return self._interruption_strategies def can_generate_metrics(self) -> bool: diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index dce3f547d..68f58694a 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -246,7 +246,7 @@ class BaseInputTransport(FrameProcessor): # 1. No interruption config is set, OR # 2. Interruption config is set but bot is not speaking should_push_immediate_interruption = ( - self.interruption_strategies is None or not self._bot_speaking + not self.interruption_strategies or not self._bot_speaking ) # Make sure we notify about interruptions quickly out-of-band. diff --git a/tests/test_interruption_strategies.py b/tests/test_interruption_strategies.py new file mode 100644 index 000000000..aa1bd7625 --- /dev/null +++ b/tests/test_interruption_strategies.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2024-2025 Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import unittest + +from pipecat.audio.interruptions.min_words_interruption_strategy import MinWordsInterruptionStrategy + + +class TestInterruptionStrategy(unittest.IsolatedAsyncioTestCase): + async def test_min_words(self): + strategy = MinWordsInterruptionStrategy(min_words=2) + await strategy.append_text("Hello") + self.assertEqual(await strategy.should_interrupt(), False) + await strategy.append_text(" there!") + self.assertEqual(await strategy.should_interrupt(), True) + # Reset and check again + await strategy.reset() + await strategy.append_text("Hello!") + self.assertEqual(await strategy.should_interrupt(), False) + await strategy.append_text(" How are you?") + self.assertEqual(await strategy.should_interrupt(), True) 
From 532767cfa1896f53d792c2f896585d795f63de54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 2 Jun 2025 11:50:25 -0700 Subject: [PATCH 2/3] LLMUserContextAggregator: reset strategies when resetting the aggregator --- .../processors/aggregators/llm_response.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index b1641d7df..479199550 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -163,7 +163,7 @@ class BaseLLMResponseAggregator(FrameProcessor): pass @abstractmethod - def reset(self): + async def reset(self): """Reset the internals of this aggregator. This should not modify the internal messages. """ @@ -230,7 +230,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator): def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict): self._context.set_tool_choice(tool_choice) - def reset(self): + async def reset(self): self._aggregation = "" @@ -273,10 +273,11 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): self._aggregation_event = asyncio.Event() self._aggregation_task = None - def reset(self): - super().reset() + async def reset(self): + await super().reset() self._seen_interim_results = False self._waiting_for_aggregation = False + [await s.reset() for s in self._interruption_strategies] async def handle_aggregation(self, aggregation: str): self._context.add_message({"role": self.role, "content": aggregation}) @@ -330,7 +331,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): async def _process_aggregation(self): """Process the current aggregation and push it downstream.""" aggregation = self._aggregation - self.reset() + await self.reset() await self.handle_aggregation(aggregation) frame = OpenAILLMContextFrame(self._context) await 
self.push_frame(frame) @@ -350,7 +351,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): else: logger.debug("Interruption conditions not met - not pushing aggregation") # Don't process aggregation, just reset it - self.reset() + await self.reset() else: # No interruption config - normal behavior (always push aggregation) await self._process_aggregation() @@ -362,12 +363,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): await strategy.append_text(self._aggregation) return await strategy.should_interrupt() - result = any([await should_interrupt(s) for s in self._interruption_strategies]) - - # Reset all strategies. - [await s.reset() for s in self._interruption_strategies] - - return result + return any([await should_interrupt(s) for s in self._interruption_strategies]) async def _start(self, frame: StartFrame): self._create_aggregation_task() @@ -467,7 +463,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): # If we reached this case and the bot is speaking, let's ignore # what the user said. logger.debug("Ignoring user speaking emulation, bot is speaking.") - self.reset() + await self.reset() else: # The bot is not speaking so, let's trigger user speaking # emulation. @@ -564,7 +560,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): return aggregation = self._aggregation.strip() - self.reset() + await self.reset() if aggregation: await self.handle_aggregation(aggregation) @@ -579,7 +575,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): async def _handle_interruptions(self, frame: StartInterruptionFrame): await self.push_aggregation() self._started = 0 - self.reset() + await self.reset() async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame): function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls] @@ -704,7 +700,7 @@ class LLMUserResponseAggregator(LLMUserContextAggregator): # Reset the aggregation. 
Reset it before pushing it down, otherwise # if the tasks gets cancelled we won't be able to clear things up. - self.reset() + await self.reset() frame = LLMMessagesFrame(self._context.messages) await self.push_frame(frame) @@ -726,7 +722,7 @@ class LLMAssistantResponseAggregator(LLMAssistantContextAggregator): # Reset the aggregation. Reset it before pushing it down, otherwise # if the tasks gets cancelled we won't be able to clear things up. - self.reset() + await self.reset() frame = LLMMessagesFrame(self._context.messages) await self.push_frame(frame) From ab4b48c823d72d1b99463023f890619589c93206 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sun, 1 Jun 2025 16:16:26 -0700 Subject: [PATCH 3/3] examples(04a): fix daily_runner import --- examples/foundational/04a-transports-daily.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/foundational/04a-transports-daily.py b/examples/foundational/04a-transports-daily.py index 98060dbab..a968c3abb 100644 --- a/examples/foundational/04a-transports-daily.py +++ b/examples/foundational/04a-transports-daily.py @@ -9,11 +9,11 @@ import os import sys import aiohttp -from daily_runner import configure from dotenv import load_dotenv from loguru import logger from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.examples.daily_runner import configure from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask