Merge pull request #1938 from pipecat-ai/aleix/custom-interruption-strategies
allow custom interruption strategies
This commit is contained in:
15
CHANGELOG.md
15
CHANGELOG.md
@@ -26,12 +26,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added OpenTelemetry tracing for `GeminiMultimodalLiveLLMService` and
|
||||
`OpenAIRealtimeBetaLLMService`.
|
||||
|
||||
- Added `interruption_strategies` to `PipelineParams` using
|
||||
`MinWordsInterruptionStrategy` to specify minimum words required to interrupt
|
||||
the bot when it's speaking. Use
|
||||
`interruption_strategies=[MinWordsInterruptionStrategy(min_words=N)]` to
|
||||
require users to speak at least N words before interrupting. If not
|
||||
specified, the normal interruption behavior applies.
|
||||
- Added initial support for interruption strategies, which determine if the user
|
||||
should interrupt the bot while the bot is speaking. Interruption strategies
|
||||
can be based on factors such as audio volume or the number of words spoken by
|
||||
the user. These can be specified via the new `interruption_strategies` field
|
||||
in `PipelineParams`. A new `MinWordsInterruptionStrategy` strategy has been
|
||||
introduced which triggers an interruption if the user has spoken a minimum
|
||||
number of words. If no interruption strategies are specified, the normal
|
||||
interruption behavior applies. If multiple strategies are provided, the first
|
||||
one that evaluates to true will trigger the interruption.
|
||||
|
||||
- `BaseInputTransport` now handles `StopFrame`. When a `StopFrame` is received
|
||||
the transport will pause sending frames downstream until a new `StartFrame` is
|
||||
|
||||
@@ -9,11 +9,11 @@ import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from daily_runner import configure
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.examples.daily_runner import configure
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
|
||||
0
src/pipecat/audio/interruptions/__init__.py
Normal file
0
src/pipecat/audio/interruptions/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseInterruptionStrategy(ABC):
|
||||
"""This is a base class for interruption strategies. Interruption strategies
|
||||
decide when the user can interrupt the bot while the bot is speaking. For
|
||||
example, there could be strategies based on audio volume or strategies based
|
||||
on the number of words the user spoke.
|
||||
|
||||
"""
|
||||
|
||||
async def append_audio(self, audio: bytes, sample_rate: int):
|
||||
"""Appends audio to the strategy. Not all strategies handle audio."""
|
||||
pass
|
||||
|
||||
async def append_text(self, text: str):
|
||||
"""Appends text to the strategy. Not all strategies handle text."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def should_interrupt(self) -> bool:
|
||||
"""This is called when the user stops speaking and it's time to decide
|
||||
whether the user should interrupt the bot. The decision will be based on
|
||||
the aggregated audio and/or text.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self):
|
||||
"""Reset the current accumulated text and/or audio."""
|
||||
pass
|
||||
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
|
||||
|
||||
class MinWordsInterruptionStrategy(BaseInterruptionStrategy):
|
||||
"""This is an interruption strategy based on a minimum number of words said
|
||||
by the user. That is, the strategy will be true if the user has said at
|
||||
least that amount of words.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *, min_words: int):
|
||||
super().__init__()
|
||||
self._min_words = min_words
|
||||
self._text = ""
|
||||
|
||||
async def append_text(self, text: str):
|
||||
"""Appends text for later analysis. Not all strategies need to handle
|
||||
text.
|
||||
|
||||
"""
|
||||
self._text += text
|
||||
|
||||
async def should_interrupt(self) -> bool:
|
||||
word_count = len(self._text.split())
|
||||
interrupt = word_count >= self._min_words
|
||||
logger.debug(
|
||||
f"should_interrupt={interrupt} num_spoken_words={word_count} min_words={self._min_words}"
|
||||
)
|
||||
return interrupt
|
||||
|
||||
async def reset(self):
|
||||
self._text = ""
|
||||
@@ -19,6 +19,7 @@ from typing import (
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -439,28 +440,6 @@ class OutputDTMFFrame(DTMFFrame, DataFrame):
|
||||
#
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterruptionStrategy:
|
||||
"""Base class for interruption strategies."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinWordsInterruptionStrategy(InterruptionStrategy):
|
||||
"""Strategy for interruption behavior based on a minimum number of words spoken by the user.
|
||||
|
||||
Args:
|
||||
min_words: If set, user must speak at least this many words to interrupt
|
||||
"""
|
||||
|
||||
min_words: int
|
||||
|
||||
def __post_init__(self):
|
||||
if self.min_words <= 0:
|
||||
raise ValueError("min_words must be greater than 0")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartFrame(SystemFrame):
|
||||
"""This is the first frame that should be pushed down a pipeline."""
|
||||
@@ -471,7 +450,7 @@ class StartFrame(SystemFrame):
|
||||
enable_metrics: bool = False
|
||||
enable_usage_metrics: bool = False
|
||||
report_only_initial_ttfb: bool = False
|
||||
interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None
|
||||
interruption_strategies: List[BaseInterruptionStrategy] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Sequence,
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.frames.frames import (
|
||||
@@ -22,7 +23,6 @@ from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
HeartbeatFrame,
|
||||
InterruptionStrategy,
|
||||
LLMFullResponseEndFrame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
@@ -75,7 +75,7 @@ class PipelineParams(BaseModel):
|
||||
report_only_initial_ttfb: bool = False
|
||||
send_initial_empty_metrics: bool = True
|
||||
start_metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None
|
||||
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PipelineTaskSource(FrameProcessor):
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Dict, List, Literal, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
@@ -24,6 +25,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
@@ -33,7 +35,6 @@ from pipecat.frames.frames import (
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
MinWordsInterruptionStrategy,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
@@ -162,7 +163,7 @@ class BaseLLMResponseAggregator(FrameProcessor):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
async def reset(self):
|
||||
"""Reset the internals of this aggregator. This should not modify the
|
||||
internal messages.
|
||||
"""
|
||||
@@ -229,7 +230,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
|
||||
self._context.set_tool_choice(tool_choice)
|
||||
|
||||
def reset(self):
|
||||
async def reset(self):
|
||||
self._aggregation = ""
|
||||
|
||||
|
||||
@@ -272,10 +273,11 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
self._aggregation_event = asyncio.Event()
|
||||
self._aggregation_task = None
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
async def reset(self):
|
||||
await super().reset()
|
||||
self._seen_interim_results = False
|
||||
self._waiting_for_aggregation = False
|
||||
[await s.reset() for s in self._interruption_strategies]
|
||||
|
||||
async def handle_aggregation(self, aggregation: str):
|
||||
self._context.add_message({"role": self.role, "content": aggregation})
|
||||
@@ -296,6 +298,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
await self._handle_input_audio(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -326,16 +331,16 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
async def _process_aggregation(self):
|
||||
"""Process the current aggregation and push it downstream."""
|
||||
aggregation = self._aggregation
|
||||
self.reset()
|
||||
await self.reset()
|
||||
await self.handle_aggregation(aggregation)
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Pushes the current aggregation based on interruption configuration and conditions."""
|
||||
"""Pushes the current aggregation based on interruption strategies and conditions."""
|
||||
if len(self._aggregation) > 0:
|
||||
if self.interruption_strategies and self._bot_speaking:
|
||||
should_interrupt = self._should_interrupt_based_on_strategies()
|
||||
should_interrupt = await self._should_interrupt_based_on_strategies()
|
||||
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
@@ -346,28 +351,19 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
# Don't process aggregation, just reset it
|
||||
self.reset()
|
||||
await self.reset()
|
||||
else:
|
||||
# No interruption config - normal behavior (always push aggregation)
|
||||
await self._process_aggregation()
|
||||
|
||||
def _should_interrupt_based_on_strategies(self) -> bool:
|
||||
async def _should_interrupt_based_on_strategies(self) -> bool:
|
||||
"""Check if interruption should occur based on configured strategies."""
|
||||
if not self.interruption_strategies:
|
||||
return False
|
||||
|
||||
# Check strategies one by one until first match
|
||||
for strategy in self.interruption_strategies:
|
||||
if isinstance(strategy, MinWordsInterruptionStrategy):
|
||||
if self._should_interrupt_min_words(strategy):
|
||||
return True
|
||||
async def should_interrupt(strategy: BaseInterruptionStrategy):
|
||||
await strategy.append_text(self._aggregation)
|
||||
return await strategy.should_interrupt()
|
||||
|
||||
return False
|
||||
|
||||
def _should_interrupt_min_words(self, strategy: MinWordsInterruptionStrategy) -> bool:
|
||||
"""Check if word count threshold is met."""
|
||||
word_count = len(self._aggregation.split())
|
||||
return word_count >= strategy.min_words
|
||||
return any([await should_interrupt(s) for s in self._interruption_strategies])
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
self._create_aggregation_task()
|
||||
@@ -378,6 +374,10 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
await self._cancel_aggregation_task()
|
||||
|
||||
async def _handle_input_audio(self, frame: InputAudioRawFrame):
|
||||
for s in self.interruption_strategies:
|
||||
await s.append_audio(frame.audio, frame.sample_rate)
|
||||
|
||||
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
|
||||
self._user_speaking = True
|
||||
self._waiting_for_aggregation = True
|
||||
@@ -463,7 +463,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
# If we reached this case and the bot is speaking, let's ignore
|
||||
# what the user said.
|
||||
logger.debug("Ignoring user speaking emulation, bot is speaking.")
|
||||
self.reset()
|
||||
await self.reset()
|
||||
else:
|
||||
# The bot is not speaking so, let's trigger user speaking
|
||||
# emulation.
|
||||
@@ -560,7 +560,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
return
|
||||
|
||||
aggregation = self._aggregation.strip()
|
||||
self.reset()
|
||||
await self.reset()
|
||||
|
||||
if aggregation:
|
||||
await self.handle_aggregation(aggregation)
|
||||
@@ -575,7 +575,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
async def _handle_interruptions(self, frame: StartInterruptionFrame):
|
||||
await self.push_aggregation()
|
||||
self._started = 0
|
||||
self.reset()
|
||||
await self.reset()
|
||||
|
||||
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
|
||||
function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
|
||||
@@ -700,7 +700,7 @@ class LLMUserResponseAggregator(LLMUserContextAggregator):
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self.reset()
|
||||
await self.reset()
|
||||
|
||||
frame = LLMMessagesFrame(self._context.messages)
|
||||
await self.push_frame(frame)
|
||||
@@ -722,7 +722,7 @@ class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self.reset()
|
||||
await self.reset()
|
||||
|
||||
frame = LLMMessagesFrame(self._context.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -7,16 +7,16 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, Coroutine, Optional, Sequence
|
||||
from typing import Awaitable, Callable, Coroutine, List, Optional, Sequence
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionStrategy,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
@@ -68,7 +68,7 @@ class FrameProcessor(BaseObject):
|
||||
self._enable_metrics = False
|
||||
self._enable_usage_metrics = False
|
||||
self._report_only_initial_ttfb = False
|
||||
self._interruption_strategies: Optional[Sequence[InterruptionStrategy]] = None
|
||||
self._interruption_strategies: List[BaseInterruptionStrategy] = []
|
||||
|
||||
# Indicates whether we have received the StartFrame.
|
||||
self.__started = False
|
||||
@@ -122,7 +122,7 @@ class FrameProcessor(BaseObject):
|
||||
return self._report_only_initial_ttfb
|
||||
|
||||
@property
|
||||
def interruption_strategies(self) -> Optional[Sequence[InterruptionStrategy]]:
|
||||
def interruption_strategies(self) -> Sequence[BaseInterruptionStrategy]:
|
||||
return self._interruption_strategies
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
|
||||
@@ -246,7 +246,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
# 1. No interruption config is set, OR
|
||||
# 2. Interruption config is set but bot is not speaking
|
||||
should_push_immediate_interruption = (
|
||||
self.interruption_strategies is None or not self._bot_speaking
|
||||
not self.interruption_strategies or not self._bot_speaking
|
||||
)
|
||||
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
|
||||
24
tests/test_interruption_strategies.py
Normal file
24
tests/test_interruption_strategies.py
Normal file
@@ -0,0 +1,24 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.audio.interruptions.min_words_interruption_strategy import MinWordsInterruptionStrategy
|
||||
|
||||
|
||||
class TestInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_min_words(self):
|
||||
strategy = MinWordsInterruptionStrategy(min_words=2)
|
||||
await strategy.append_text("Hello")
|
||||
self.assertEqual(await strategy.should_interrupt(), False)
|
||||
await strategy.append_text(" there!")
|
||||
self.assertEqual(await strategy.should_interrupt(), True)
|
||||
# Reset and check again
|
||||
await strategy.reset()
|
||||
await strategy.append_text("Hello!")
|
||||
self.assertEqual(await strategy.should_interrupt(), False)
|
||||
await strategy.append_text(" How are you?")
|
||||
self.assertEqual(await strategy.should_interrupt(), True)
|
||||
Reference in New Issue
Block a user