Merge pull request #1938 from pipecat-ai/aleix/custom-interruption-strategies

allow custom interruption strategies
This commit is contained in:
Aleix Conchillo Flaqué
2025-06-02 12:05:50 -07:00
committed by GitHub
11 changed files with 149 additions and 65 deletions

View File

@@ -26,12 +26,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added OpenTelemetry tracing for `GeminiMultimodalLiveLLMService` and
`OpenAIRealtimeBetaLLMService`.
- Added `interruption_strategies` to `PipelineParams` using
`MinWordsInterruptionStrategy` to specify minimum words required to interrupt
the bot when it's speaking. Use
`interruption_strategies=[MinWordsInterruptionStrategy(min_words=N)]` to
require users to speak at least N words before interrupting. If not
specified, the normal interruption behavior applies.
- Added initial support for interruption strategies, which determine if the user
should interrupt the bot while the bot is speaking. Interruption strategies
can be based on factors such as audio volume or the number of words spoken by
the user. These can be specified via the new `interruption_strategies` field
in `PipelineParams`. A new `MinWordsInterruptionStrategy` strategy has been
introduced which triggers an interruption if the user has spoken a minimum
number of words. If no interruption strategies are specified, the normal
interruption behavior applies. If multiple strategies are provided, the first
one that evaluates to true will trigger the interruption.
- `BaseInputTransport` now handles `StopFrame`. When a `StopFrame` is received
the transport will pause sending frames downstream until a new `StartFrame` is

View File

@@ -9,11 +9,11 @@ import os
import sys
import aiohttp
from daily_runner import configure
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.examples.daily_runner import configure
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask

View File

@@ -0,0 +1,38 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from abc import ABC, abstractmethod
class BaseInterruptionStrategy(ABC):
    """Abstract contract for interruption strategies.

    An interruption strategy decides whether the user is allowed to interrupt
    the bot while the bot is speaking. Concrete strategies may base that
    decision on, for example, the audio volume or the number of words the user
    spoke.

    Subclasses must implement ``should_interrupt`` and ``reset``. The
    ``append_*`` hooks are optional and do nothing by default.
    """

    async def append_audio(self, audio: bytes, sample_rate: int):
        """Feed a chunk of audio to the strategy.

        The default implementation ignores the data, since not every strategy
        works with audio.

        Args:
            audio: The audio bytes to accumulate.
            sample_rate: The sample rate of ``audio``.
        """

    async def append_text(self, text: str):
        """Feed a piece of text to the strategy.

        The default implementation ignores the data, since not every strategy
        works with text.

        Args:
            text: The text to accumulate.
        """

    @abstractmethod
    async def should_interrupt(self) -> bool:
        """Decide whether the user should interrupt the bot.

        Called when the user stops speaking. The decision is based on the
        audio and/or text accumulated so far.

        Returns:
            True if the bot should be interrupted, False otherwise.
        """

    @abstractmethod
    async def reset(self):
        """Discard the text and/or audio accumulated so far."""

View File

@@ -0,0 +1,40 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from loguru import logger
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
class MinWordsInterruptionStrategy(BaseInterruptionStrategy):
    """Interruption strategy based on a minimum number of spoken words.

    The strategy evaluates to true once the text accumulated through
    ``append_text`` contains at least ``min_words`` whitespace-separated
    words.
    """

    def __init__(self, *, min_words: int):
        """Create the strategy.

        Args:
            min_words: Minimum number of words the user must say for the
                strategy to trigger an interruption. Must be greater than 0.

        Raises:
            ValueError: If ``min_words`` is not greater than 0.
        """
        super().__init__()
        # A zero or negative threshold would make the strategy always
        # evaluate to true, so reject it up front.
        if min_words <= 0:
            raise ValueError("min_words must be greater than 0")
        self._min_words = min_words
        self._text = ""

    async def append_text(self, text: str):
        """Append text for later analysis.

        The text is concatenated verbatim to what was appended before, so
        consecutive segments only count as separate words if one of them
        carries the separating whitespace.

        Args:
            text: The text to accumulate.
        """
        self._text += text

    async def should_interrupt(self) -> bool:
        """Check whether the accumulated text reaches the word threshold.

        Returns:
            True if the accumulated text has at least ``min_words``
            whitespace-separated words, False otherwise.
        """
        word_count = len(self._text.split())
        interrupt = word_count >= self._min_words
        logger.debug(
            f"should_interrupt={interrupt} num_spoken_words={word_count} min_words={self._min_words}"
        )
        return interrupt

    async def reset(self):
        """Discard the text accumulated so far."""
        self._text = ""

View File

@@ -19,6 +19,7 @@ from typing import (
Tuple,
)
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.metrics.metrics import MetricsData
from pipecat.transcriptions.language import Language
@@ -439,28 +440,6 @@ class OutputDTMFFrame(DTMFFrame, DataFrame):
#
@dataclass
class InterruptionStrategy:
"""Base class for interruption strategies."""
pass
@dataclass
class MinWordsInterruptionStrategy(InterruptionStrategy):
"""Strategy for interruption behavior based on a minimum number of words spoken by the user.
Args:
min_words: If set, user must speak at least this many words to interrupt
"""
min_words: int
def __post_init__(self):
if self.min_words <= 0:
raise ValueError("min_words must be greater than 0")
@dataclass
class StartFrame(SystemFrame):
"""This is the first frame that should be pushed down a pipeline."""
@@ -471,7 +450,7 @@ class StartFrame(SystemFrame):
enable_metrics: bool = False
enable_usage_metrics: bool = False
report_only_initial_ttfb: bool = False
interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None
interruption_strategies: List[BaseInterruptionStrategy] = field(default_factory=list)
@dataclass

View File

@@ -11,6 +11,7 @@ from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Sequence,
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
from pipecat.clocks.base_clock import BaseClock
from pipecat.clocks.system_clock import SystemClock
from pipecat.frames.frames import (
@@ -22,7 +23,6 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
HeartbeatFrame,
InterruptionStrategy,
LLMFullResponseEndFrame,
MetricsFrame,
StartFrame,
@@ -75,7 +75,7 @@ class PipelineParams(BaseModel):
report_only_initial_ttfb: bool = False
send_initial_empty_metrics: bool = True
start_metadata: Dict[str, Any] = Field(default_factory=dict)
interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
class PipelineTaskSource(FrameProcessor):

View File

@@ -11,6 +11,7 @@ from typing import Dict, List, Literal, Optional, Set
from loguru import logger
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
from pipecat.frames.frames import (
BotInterruptionFrame,
BotStartedSpeakingFrame,
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
FunctionCallsStartedFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
@@ -33,7 +35,6 @@ from pipecat.frames.frames import (
LLMSetToolChoiceFrame,
LLMSetToolsFrame,
LLMTextFrame,
MinWordsInterruptionStrategy,
OpenAILLMContextAssistantTimestampFrame,
StartFrame,
StartInterruptionFrame,
@@ -162,7 +163,7 @@ class BaseLLMResponseAggregator(FrameProcessor):
pass
@abstractmethod
def reset(self):
async def reset(self):
"""Reset the internals of this aggregator. This should not modify the
internal messages.
"""
@@ -229,7 +230,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
self._context.set_tool_choice(tool_choice)
def reset(self):
async def reset(self):
self._aggregation = ""
@@ -272,10 +273,11 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
self._aggregation_event = asyncio.Event()
self._aggregation_task = None
def reset(self):
super().reset()
async def reset(self):
await super().reset()
self._seen_interim_results = False
self._waiting_for_aggregation = False
[await s.reset() for s in self._interruption_strategies]
async def handle_aggregation(self, aggregation: str):
self._context.add_message({"role": self.role, "content": aggregation})
@@ -296,6 +298,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
elif isinstance(frame, CancelFrame):
await self._cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, InputAudioRawFrame):
await self._handle_input_audio(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, UserStartedSpeakingFrame):
await self._handle_user_started_speaking(frame)
await self.push_frame(frame, direction)
@@ -326,16 +331,16 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
async def _process_aggregation(self):
"""Process the current aggregation and push it downstream."""
aggregation = self._aggregation
self.reset()
await self.reset()
await self.handle_aggregation(aggregation)
frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)
async def push_aggregation(self):
"""Pushes the current aggregation based on interruption configuration and conditions."""
"""Pushes the current aggregation based on interruption strategies and conditions."""
if len(self._aggregation) > 0:
if self.interruption_strategies and self._bot_speaking:
should_interrupt = self._should_interrupt_based_on_strategies()
should_interrupt = await self._should_interrupt_based_on_strategies()
if should_interrupt:
logger.debug(
@@ -346,28 +351,19 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
else:
logger.debug("Interruption conditions not met - not pushing aggregation")
# Don't process aggregation, just reset it
self.reset()
await self.reset()
else:
# No interruption config - normal behavior (always push aggregation)
await self._process_aggregation()
def _should_interrupt_based_on_strategies(self) -> bool:
async def _should_interrupt_based_on_strategies(self) -> bool:
"""Check if interruption should occur based on configured strategies."""
if not self.interruption_strategies:
return False
# Check strategies one by one until first match
for strategy in self.interruption_strategies:
if isinstance(strategy, MinWordsInterruptionStrategy):
if self._should_interrupt_min_words(strategy):
return True
async def should_interrupt(strategy: BaseInterruptionStrategy):
await strategy.append_text(self._aggregation)
return await strategy.should_interrupt()
return False
def _should_interrupt_min_words(self, strategy: MinWordsInterruptionStrategy) -> bool:
"""Check if word count threshold is met."""
word_count = len(self._aggregation.split())
return word_count >= strategy.min_words
return any([await should_interrupt(s) for s in self._interruption_strategies])
async def _start(self, frame: StartFrame):
self._create_aggregation_task()
@@ -378,6 +374,10 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
async def _cancel(self, frame: CancelFrame):
await self._cancel_aggregation_task()
async def _handle_input_audio(self, frame: InputAudioRawFrame):
for s in self.interruption_strategies:
await s.append_audio(frame.audio, frame.sample_rate)
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
self._user_speaking = True
self._waiting_for_aggregation = True
@@ -463,7 +463,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
# If we reached this case and the bot is speaking, let's ignore
# what the user said.
logger.debug("Ignoring user speaking emulation, bot is speaking.")
self.reset()
await self.reset()
else:
# The bot is not speaking so, let's trigger user speaking
# emulation.
@@ -560,7 +560,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
return
aggregation = self._aggregation.strip()
self.reset()
await self.reset()
if aggregation:
await self.handle_aggregation(aggregation)
@@ -575,7 +575,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
async def _handle_interruptions(self, frame: StartInterruptionFrame):
await self.push_aggregation()
self._started = 0
self.reset()
await self.reset()
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
@@ -700,7 +700,7 @@ class LLMUserResponseAggregator(LLMUserContextAggregator):
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
self.reset()
await self.reset()
frame = LLMMessagesFrame(self._context.messages)
await self.push_frame(frame)
@@ -722,7 +722,7 @@ class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
self.reset()
await self.reset()
frame = LLMMessagesFrame(self._context.messages)
await self.push_frame(frame)

View File

@@ -7,16 +7,16 @@
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Awaitable, Callable, Coroutine, Optional, Sequence
from typing import Awaitable, Callable, Coroutine, List, Optional, Sequence
from loguru import logger
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
from pipecat.clocks.base_clock import BaseClock
from pipecat.frames.frames import (
CancelFrame,
ErrorFrame,
Frame,
InterruptionStrategy,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
@@ -68,7 +68,7 @@ class FrameProcessor(BaseObject):
self._enable_metrics = False
self._enable_usage_metrics = False
self._report_only_initial_ttfb = False
self._interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None
self._interruption_strategies: List[BaseInterruptionStrategy] = []
# Indicates whether we have received the StartFrame.
self.__started = False
@@ -122,7 +122,7 @@ class FrameProcessor(BaseObject):
return self._report_only_initial_ttfb
@property
def interruption_strategies(self) -> Optional[Sequence[InterruptionStrategy]]:
def interruption_strategies(self) -> Sequence[BaseInterruptionStrategy]:
return self._interruption_strategies
def can_generate_metrics(self) -> bool:

View File

@@ -246,7 +246,7 @@ class BaseInputTransport(FrameProcessor):
# 1. No interruption config is set, OR
# 2. Interruption config is set but bot is not speaking
should_push_immediate_interruption = (
self.interruption_strategies is None or not self._bot_speaking
not self.interruption_strategies or not self._bot_speaking
)
# Make sure we notify about interruptions quickly out-of-band.

View File

@@ -0,0 +1,24 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.audio.interruptions.min_words_interruption_strategy import MinWordsInterruptionStrategy
class TestInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
    """Tests for the interruption strategies."""

    async def test_min_words(self):
        """A two-word threshold triggers only once two words were appended.

        The scenario is run twice, with a ``reset()`` in between, to check
        that resetting discards the previously accumulated text.
        """
        strategy = MinWordsInterruptionStrategy(min_words=2)

        # Each step is (text to append, expected should_interrupt() result).
        before_reset = (("Hello", False), (" there!", True))
        after_reset = (("Hello!", False), (" How are you?", True))

        for chunk, expected in before_reset:
            await strategy.append_text(chunk)
            self.assertEqual(await strategy.should_interrupt(), expected)

        # Reset and check again
        await strategy.reset()

        for chunk, expected in after_reset:
            await strategy.append_text(chunk)
            self.assertEqual(await strategy.should_interrupt(), expected)