Merge pull request #3308 from pipecat-ai/aleix/external-turn-start-strategies
turns: add external user and bot turn start strategies
This commit is contained in:
@@ -10,10 +10,12 @@
|
||||
- VADUserTurnStartStrategy
|
||||
- TranscriptionUserTurnStartStrategy
|
||||
- MinWordsUserTurnStartStrategy
|
||||
- ExternalUserTurnStartStrategy
|
||||
|
||||
Available bot turn start strategies:
|
||||
- TranscriptionBotTurnStartStrategy
|
||||
- TurnAnalyzerBotTurnStartStrategy
|
||||
- ExternalBotTurnStartStrategy
|
||||
|
||||
The default strategies are:
|
||||
|
||||
|
||||
1
changelog/3314.changed.md
Normal file
1
changelog/3314.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Updated `DeepgramSTTService` to push user started/stopped speaking and interruption frames when `vad_enabled` is set to true. This centralizes the frames into the service, removing the need to have your application code handle Deepgram's events and push these frames.
|
||||
@@ -29,6 +29,7 @@ from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.turn_start_strategies import ExternalTurnStartStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -132,7 +133,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(enable_user_speaking_frames=False),
|
||||
user_params=LLMUserAggregatorParams(
|
||||
turn_start_strategies=ExternalTurnStartStrategies()
|
||||
),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
|
||||
@@ -27,6 +27,7 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.turn_start_strategies import ExternalTurnStartStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -71,7 +72,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, user_params=LLMUserAggregatorParams(enable_user_speaking_frames=False)
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(turn_start_strategies=ExternalTurnStartStrategies()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
|
||||
@@ -11,12 +11,7 @@ from deepgram import LiveOptions
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -33,6 +28,7 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.turn_start_strategies import ExternalTurnStartStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -78,7 +74,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(enable_user_speaking_frames=False),
|
||||
user_params=LLMUserAggregatorParams(turn_start_strategies=ExternalTurnStartStrategies()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
@@ -102,14 +98,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@stt.event_handler("on_speech_started")
|
||||
async def on_speech_started(stt, *args, **kwargs):
|
||||
await task.queue_frames([UserStartedSpeakingFrame(), InterruptionFrame()])
|
||||
|
||||
@stt.event_handler("on_utterance_end")
|
||||
async def on_utterance_end(stt, *args, **kwargs):
|
||||
await task.queue_frames([UserStoppedSpeakingFrame()])
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
|
||||
@@ -37,9 +37,13 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.turns.turn_start_strategies import ExternalTurnStartStrategies
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
@@ -318,11 +322,13 @@ class ClassificationProcessor(FrameProcessor):
|
||||
# User started speaking - set the voicemail event
|
||||
if self._voicemail_detected:
|
||||
self._voicemail_event.set()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# User stopped speaking - clear the voicemail event
|
||||
if self._voicemail_detected:
|
||||
self._voicemail_event.clear()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
else:
|
||||
# Pass all non-LLM frames through
|
||||
@@ -621,7 +627,12 @@ VOICEMAIL SYSTEM (respond "VOICEMAIL"):
|
||||
|
||||
# Create the LLM context and aggregators for conversation management
|
||||
self._context = LLMContext(self._messages)
|
||||
self._context_aggregator = LLMContextAggregatorPair(self._context)
|
||||
self._context_aggregator = LLMContextAggregatorPair(
|
||||
self._context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
turn_start_strategies=ExternalTurnStartStrategies()
|
||||
),
|
||||
)
|
||||
|
||||
# Create notification system for coordinating between components
|
||||
self._gate_notifier = EventNotifier() # Signals classification completion
|
||||
|
||||
@@ -81,16 +81,11 @@ class LLMUserAggregatorParams:
|
||||
enable_emulated_vad_interruptions: When True, allows emulated VAD events
|
||||
to interrupt the bot when it's speaking. When False, emulated speech
|
||||
is ignored while the bot is speaking.
|
||||
enable_user_speaking_frames: [DO NOT USE] added for temporary backwards
|
||||
compatibility.
|
||||
|
||||
"""
|
||||
|
||||
aggregation_timeout: float = 0.5
|
||||
turn_emulated_vad_timeout: float = 0.8
|
||||
enable_emulated_vad_interruptions: bool = False
|
||||
# Added for backwards compatibility.
|
||||
enable_user_speaking_frames: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -63,10 +63,10 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
NotGiven,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.turns.bot import BaseBotTurnStartStrategy
|
||||
from pipecat.turns.bot import BaseBotTurnStartStrategy, BotTurnStartedParams
|
||||
from pipecat.turns.mute import BaseUserMuteStrategy
|
||||
from pipecat.turns.turn_start_strategies import TurnStartStrategies
|
||||
from pipecat.turns.user import BaseUserTurnStartStrategy
|
||||
from pipecat.turns.turn_start_strategies import ExternalTurnStartStrategies, TurnStartStrategies
|
||||
from pipecat.turns.user import BaseUserTurnStartStrategy, UserTurnStartedParams
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
@@ -76,18 +76,12 @@ class LLMUserAggregatorParams:
|
||||
"""Parameters for configuring LLM user aggregation behavior.
|
||||
|
||||
Parameters:
|
||||
enable_user_speaking_frames: If True, the aggregator will emit frames
|
||||
indicating when the user starts and stops speaking, as well as
|
||||
interruption frames. This is enabled by default, but you may want
|
||||
to disable it if another component (e.g., an STT service) is already
|
||||
generating these frames.
|
||||
turn_start_strategies: User and bot turn start strategies.
|
||||
user_mute_strategies: List of user mute strategies.
|
||||
user_turn_end_timeout: Time in seconds to wait before considering the
|
||||
user's turn finished and starting the bot turn.
|
||||
"""
|
||||
|
||||
enable_user_speaking_frames: bool = True
|
||||
turn_start_strategies: Optional[TurnStartStrategies] = None
|
||||
user_mute_strategies: List[BaseUserMuteStrategy] = field(default_factory=list)
|
||||
user_turn_end_timeout: float = 5.0
|
||||
@@ -365,9 +359,14 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
self._user_turn_end_timeout_task_handler()
|
||||
)
|
||||
|
||||
await self._setup_turn_start_strategies()
|
||||
await self._setup_user_mute_strategies()
|
||||
|
||||
async def _setup_user_mute_strategies(self):
|
||||
for s in self._params.user_mute_strategies:
|
||||
await s.setup(self.task_manager)
|
||||
|
||||
async def _setup_turn_start_strategies(self):
|
||||
if self._turn_start_strategies.user:
|
||||
for s in self._turn_start_strategies.user:
|
||||
await s.setup(self.task_manager)
|
||||
@@ -393,9 +392,10 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self.cancel_task(self._user_turn_end_timeout_task)
|
||||
self._user_turn_end_timeout_task = None
|
||||
|
||||
for s in self._params.user_mute_strategies:
|
||||
await s.cleanup()
|
||||
await self._cleanup_turn_start_strategies()
|
||||
await self._cleanup_user_mute_strategies()
|
||||
|
||||
async def _cleanup_turn_start_strategies(self):
|
||||
if self._turn_start_strategies.user:
|
||||
for s in self._turn_start_strategies.user:
|
||||
await s.cleanup()
|
||||
@@ -404,6 +404,10 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
for s in self._turn_start_strategies.bot:
|
||||
await s.cleanup()
|
||||
|
||||
async def _cleanup_user_mute_strategies(self):
|
||||
for s in self._params.user_mute_strategies:
|
||||
await s.cleanup()
|
||||
|
||||
async def _maybe_mute_frame(self, frame: Frame):
|
||||
should_mute_frame = self._user_is_muted and isinstance(
|
||||
frame,
|
||||
@@ -478,6 +482,10 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
" )"
|
||||
)
|
||||
|
||||
await self._cleanup_turn_start_strategies()
|
||||
self._turn_start_strategies = ExternalTurnStartStrategies()
|
||||
await self._setup_turn_start_strategies()
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, frame: VADUserStartedSpeakingFrame):
|
||||
self._vad_user_speaking = True
|
||||
|
||||
@@ -507,11 +515,17 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
)
|
||||
)
|
||||
|
||||
async def _on_user_turn_started(self, strategy: BaseUserTurnStartStrategy):
|
||||
await self._trigger_user_turn_start(strategy)
|
||||
async def _on_user_turn_started(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy,
|
||||
params: UserTurnStartedParams,
|
||||
):
|
||||
await self._trigger_user_turn_start(strategy, params)
|
||||
|
||||
async def _on_bot_turn_started(self, strategy: BaseBotTurnStartStrategy):
|
||||
await self._trigger_bot_turn_start(strategy)
|
||||
async def _on_bot_turn_started(
|
||||
self, strategy: BaseBotTurnStartStrategy, params: BotTurnStartedParams
|
||||
):
|
||||
await self._trigger_bot_turn_start(strategy, params)
|
||||
|
||||
async def _on_push_frame(
|
||||
self,
|
||||
@@ -529,11 +543,15 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
):
|
||||
await self.broadcast_frame(frame_cls, **kwargs)
|
||||
|
||||
async def _trigger_user_turn_start(self, strategy: Optional[BaseUserTurnStartStrategy]):
|
||||
async def _trigger_user_turn_start(
|
||||
self, strategy: Optional[BaseUserTurnStartStrategy], params: UserTurnStartedParams
|
||||
):
|
||||
# Prevent two consecutive user turn starts.
|
||||
if self._user_turn:
|
||||
return
|
||||
|
||||
logger.debug(f"User started speaking (user turn start strategy: {strategy})")
|
||||
|
||||
self._user_turn = True
|
||||
self._user_turn_end_timeout_event.set()
|
||||
|
||||
@@ -542,19 +560,22 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
for s in self._turn_start_strategies.user:
|
||||
await s.reset()
|
||||
|
||||
if self._params.enable_user_speaking_frames:
|
||||
logger.debug(f"User started speaking (user turn start strategy: {strategy})")
|
||||
if params.enable_user_speaking_frames:
|
||||
# TODO(aleix): These frames should really come from the top of the pipeline.
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.broadcast_frame(InterruptionFrame)
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
async def _trigger_bot_turn_start(self, strategy: Optional[BaseBotTurnStartStrategy]):
|
||||
async def _trigger_bot_turn_start(
|
||||
self, strategy: Optional[BaseBotTurnStartStrategy], params: BotTurnStartedParams
|
||||
):
|
||||
# Prevent two consecutive bot turn starts.
|
||||
if not self._user_turn:
|
||||
return
|
||||
|
||||
logger.debug(f"User stopped speaking (bot turn start strategy: {strategy})")
|
||||
|
||||
self._user_turn = False
|
||||
self._user_turn_end_timeout_event.set()
|
||||
|
||||
@@ -563,8 +584,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
for s in self._turn_start_strategies.bot:
|
||||
await s.reset()
|
||||
|
||||
if self._params.enable_user_speaking_frames:
|
||||
logger.debug(f"User stopped speaking (bot turn start strategy: {strategy})")
|
||||
if params.enable_user_speaking_frames:
|
||||
# TODO(aleix): This frame should really come from the top of the pipeline.
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
@@ -584,7 +604,9 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
except asyncio.TimeoutError:
|
||||
if self._user_turn and not self._vad_user_speaking:
|
||||
await self._call_event_handler("on_user_turn_end_timeout")
|
||||
await self._trigger_bot_turn_start(None)
|
||||
await self._trigger_bot_turn_start(
|
||||
None, BotTurnStartedParams(enable_user_speaking_frames=True)
|
||||
)
|
||||
|
||||
|
||||
class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
@@ -17,6 +17,8 @@ from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -271,9 +273,12 @@ class DeepgramSTTService(STTService):
|
||||
async def _on_speech_started(self, *args, **kwargs):
|
||||
await self.start_metrics()
|
||||
await self._call_event_handler("on_speech_started", *args, **kwargs)
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
async def _on_utterance_end(self, *args, **kwargs):
|
||||
await self._call_event_handler("on_utterance_end", *args, **kwargs)
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
|
||||
@@ -477,7 +477,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
)
|
||||
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
if should_push_immediate_interruption and self.interruptions_allowed:
|
||||
if should_push_immediate_interruption and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
elif self.interruption_strategies and self._bot_speaking:
|
||||
logger.debug(
|
||||
|
||||
@@ -498,7 +498,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
Args:
|
||||
_: The start interruption frame (unused).
|
||||
"""
|
||||
if not self._transport.interruptions_allowed:
|
||||
if not self._transport._allow_interruptions:
|
||||
return
|
||||
|
||||
# Cancel tasks.
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from pipecat.turns.bot.base_bot_turn_start_strategy import BaseBotTurnStartStrategy
|
||||
from pipecat.turns.bot.base_bot_turn_start_strategy import (
|
||||
BaseBotTurnStartStrategy,
|
||||
BotTurnStartedParams,
|
||||
)
|
||||
from pipecat.turns.bot.external_bot_turn_start_strategy import ExternalBotTurnStartStrategy
|
||||
from pipecat.turns.bot.transcription_bot_turn_start_strategy import (
|
||||
TranscriptionBotTurnStartStrategy,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
"""Base turn start strategy for determining when the bot should start speaking."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
@@ -14,6 +15,26 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
|
||||
@dataclass
class BotTurnStartedParams:
    """Parameters emitted when a bot turn starts.

    These parameters are passed to `on_bot_turn_started` event handlers and
    tell the user aggregator how the start of the bot turn should be handled.

    Attributes:
        enable_user_speaking_frames: Whether the user aggregator should emit
            a `UserStoppedSpeakingFrame` when the bot turn starts (that is,
            when the user's turn ends). This field is required and has no
            default here. It may be set to False when another component
            (such as an STT service) is already responsible for generating
            user speaking frames.

    """

    enable_user_speaking_frames: bool
|
||||
|
||||
|
||||
class BaseBotTurnStartStrategy(BaseObject):
|
||||
"""Base class for strategies that determine when the bot should start speaking.
|
||||
|
||||
@@ -28,9 +49,18 @@ class BaseBotTurnStartStrategy(BaseObject):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the base bot turn start strategy."""
|
||||
def __init__(self, *, enable_user_speaking_frames: bool = True, **kwargs):
|
||||
"""Initialize the base bot turn start strategy.
|
||||
|
||||
Args:
|
||||
enable_user_speaking_frames: If True, the aggregator will emit frames
|
||||
indicating when the user stops speaking. This is enabled by default,
|
||||
but you may want to disable it if another component (e.g., an STT
|
||||
service) is already generating these frames.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._enable_user_speaking_frames = enable_user_speaking_frames
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
self._register_event_handler("on_push_frame", sync=True)
|
||||
self._register_event_handler("on_broadcast_frame", sync=True)
|
||||
@@ -90,4 +120,7 @@ class BaseBotTurnStartStrategy(BaseObject):
|
||||
|
||||
async def trigger_bot_turn_started(self):
|
||||
"""Trigger the `on_bot_turn_started` event."""
|
||||
await self._call_event_handler("on_bot_turn_started")
|
||||
await self._call_event_handler(
|
||||
"on_bot_turn_started",
|
||||
BotTurnStartedParams(enable_user_speaking_frames=self._enable_user_speaking_frames),
|
||||
)
|
||||
|
||||
127
src/pipecat/turns/bot/external_bot_turn_start_strategy.py
Normal file
127
src/pipecat/turns/bot/external_bot_turn_start_strategy.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Bot turn start strategy triggered by externally emitted frames."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.bot.base_bot_turn_start_strategy import BaseBotTurnStartStrategy
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
|
||||
class ExternalBotTurnStartStrategy(BaseBotTurnStartStrategy):
    """Bot turn start strategy controlled by an external processor.

    This strategy does not decide on its own when the user's turn has ended.
    Instead, it relies on a different processor in the pipeline that is
    responsible for emitting `UserStartedSpeakingFrame` and
    `UserStoppedSpeakingFrame` frames. The bot turn is started once the user
    is no longer speaking, no interim transcription is pending and some final
    transcription text has been received.

    The strategy is always created with `enable_user_speaking_frames=False`.

    """

    def __init__(self, *, timeout: float = 0.5, **kwargs):
        """Initialize the external bot turn start strategy.

        Args:
            timeout: A short delay used internally to handle consecutive or
                slightly delayed transcriptions.
            **kwargs: Additional keyword arguments forwarded to the base
                strategy.
        """
        super().__init__(enable_user_speaking_frames=False, **kwargs)
        self._timeout = timeout
        # Final transcription text accumulated since the last reset.
        self._text = ""
        # Speaking state as reported by the external processor.
        self._user_speaking = False
        # True while an interim result has been seen without a final one.
        self._seen_interim_results = False
        # Set on every final transcription to restart the aggregation wait.
        self._event = asyncio.Event()
        self._task: Optional[asyncio.Task] = None

    async def reset(self):
        """Reset the strategy to its initial state."""
        await super().reset()
        self._text = ""
        self._user_speaking = False
        self._seen_interim_results = False
        self._event.clear()

    async def setup(self, task_manager: BaseTaskManager):
        """Initialize the strategy with the given task manager.

        Also starts the background task that watches for transcriptions.

        Args:
            task_manager: The task manager to be associated with this instance.
        """
        await super().setup(task_manager)
        self._task = task_manager.create_task(self._task_handler(), f"{self}::_task_handler")

    async def cleanup(self):
        """Cleanup the strategy and cancel its background task, if any."""
        await super().cleanup()
        if self._task:
            await self.task_manager.cancel_task(self._task)
            self._task = None

    async def process_frame(self, frame: Frame):
        """Process an incoming frame to update strategy state.

        Updates the user speaking state and the collected transcription text.
        The bot turn will be triggered when appropriate based on the collected
        frames.

        Args:
            frame: The frame to be analyzed.
        """
        if isinstance(frame, UserStartedSpeakingFrame):
            await self._handle_user_started_speaking(frame)
        elif isinstance(frame, UserStoppedSpeakingFrame):
            await self._handle_user_stopped_speaking(frame)
        elif isinstance(frame, InterimTranscriptionFrame):
            await self._handle_interim_transcription(frame)
        elif isinstance(frame, TranscriptionFrame):
            await self._handle_transcription(frame)

    async def _handle_user_started_speaking(self, _: UserStartedSpeakingFrame):
        """Handle when the external service indicates the user is speaking."""
        self._user_speaking = True

    async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame):
        """Handle when the external service indicates the user has stopped speaking."""
        self._user_speaking = False
        # A final transcription may already be available; if so there is no
        # need to wait for the background task's timeout.
        await self._maybe_trigger_bot_turn_started()

    async def _handle_interim_transcription(self, frame: InterimTranscriptionFrame):
        """Record that an interim result arrived and a final one is still pending."""
        self._seen_interim_results = True

    async def _handle_transcription(self, frame: TranscriptionFrame):
        """Handle user transcription."""
        self._text += frame.text
        # We just got a final result, so let's reset interim results.
        self._seen_interim_results = False
        # Reset aggregation timer.
        self._event.set()

    async def _task_handler(self):
        """Asynchronously monitor transcriptions and trigger bot turn when ready.

        Waits up to `timeout` seconds for a new final transcription. Each new
        transcription restarts the wait; when the wait times out, the bot turn
        is triggered if the conditions in `_maybe_trigger_bot_turn_started`
        hold. This handles multiple or delayed transcriptions gracefully.
        """
        while True:
            try:
                await asyncio.wait_for(self._event.wait(), timeout=self._timeout)
                # A transcription arrived in time: clear and wait again.
                self._event.clear()
            except asyncio.TimeoutError:
                await self._maybe_trigger_bot_turn_started()

    async def _maybe_trigger_bot_turn_started(self):
        """Trigger the bot turn if the user is done and text is available.

        The bot turn starts only when the user is not speaking, no interim
        transcription is pending and some transcription text was collected.
        """
        if not self._user_speaking and not self._seen_interim_results and self._text:
            await self.trigger_bot_turn_started()
|
||||
@@ -11,10 +11,12 @@ from typing import List, Optional
|
||||
|
||||
from pipecat.turns.bot import (
|
||||
BaseBotTurnStartStrategy,
|
||||
ExternalBotTurnStartStrategy,
|
||||
TranscriptionBotTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user import (
|
||||
BaseUserTurnStartStrategy,
|
||||
ExternalUserTurnStartStrategy,
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
VADUserTurnStartStrategy,
|
||||
)
|
||||
@@ -49,3 +51,25 @@ class TurnStartStrategies:
|
||||
self.user = [VADUserTurnStartStrategy(), TranscriptionUserTurnStartStrategy()]
|
||||
if not self.bot:
|
||||
self.bot = [TranscriptionBotTurnStartStrategy()]
|
||||
|
||||
|
||||
@dataclass
class ExternalTurnStartStrategies(TurnStartStrategies):
    """Ready-made container of external user and bot turn start strategies.

    This is a convenience for configuring external turn control: it fills
    `TurnStartStrategies` with one `ExternalUserTurnStartStrategy` and one
    `ExternalBotTurnStartStrategy`, so that external processors (such as
    services) control when user and bot turns start.

    Both strategies are created with user speaking frames disabled, so they
    ask the user aggregator not to push `UserStartedSpeakingFrame` or
    `UserStoppedSpeakingFrame` frames and not to generate interruptions. These
    signals are expected to be provided by an external processor.

    """

    def __post_init__(self):
        """Install the external user and bot strategies unconditionally."""
        external_user = ExternalUserTurnStartStrategy()
        self.user = [external_user]
        external_bot = ExternalBotTurnStartStrategy()
        self.bot = [external_bot]
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from pipecat.turns.user.base_user_turn_start_strategy import BaseUserTurnStartStrategy
|
||||
from pipecat.turns.user.base_user_turn_start_strategy import (
|
||||
BaseUserTurnStartStrategy,
|
||||
UserTurnStartedParams,
|
||||
)
|
||||
from pipecat.turns.user.external_user_turn_start_strategy import ExternalUserTurnStartStrategy
|
||||
from pipecat.turns.user.min_words_user_turn_start_strategy import MinWordsUserTurnStartStrategy
|
||||
from pipecat.turns.user.transcription_user_turn_start_strategy import (
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
"""Base turn start strategy for determining when the user starts speaking."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
@@ -14,6 +15,26 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
|
||||
@dataclass
class UserTurnStartedParams:
    """Parameters emitted when a user turn starts.

    These parameters are passed to `on_user_turn_started` event handlers and
    tell the user aggregator how the start of the user turn should be handled.

    Attributes:
        enable_user_speaking_frames: Whether the user aggregator should emit
            a `UserStartedSpeakingFrame` and an `InterruptionFrame` when the
            user turn starts. This field is required and has no default here.
            It may be set to False when another component (such as an STT
            service) is already responsible for generating user speaking
            frames.

    """

    enable_user_speaking_frames: bool
|
||||
|
||||
|
||||
class BaseUserTurnStartStrategy(BaseObject):
|
||||
"""Base class for strategies that determine when a user starts speaking.
|
||||
|
||||
@@ -28,9 +49,19 @@ class BaseUserTurnStartStrategy(BaseObject):
|
||||
- `on_user_turn_started`: Signals that a user turn has started.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the base user turn start strategy."""
|
||||
def __init__(self, *, enable_user_speaking_frames: bool = True, **kwargs):
|
||||
"""Initialize the base user turn start strategy.
|
||||
|
||||
Args:
|
||||
enable_user_speaking_frames: If True, the aggregator will emit frames
|
||||
indicating when the user starts speaking, as well as interruption
|
||||
frames. This is enabled by default, but you may want to disable it
|
||||
if another component (e.g., an STT service) is already generating
|
||||
these frames.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._enable_user_speaking_frames = enable_user_speaking_frames
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
self._register_event_handler("on_push_frame", sync=True)
|
||||
self._register_event_handler("on_broadcast_frame", sync=True)
|
||||
@@ -90,4 +121,7 @@ class BaseUserTurnStartStrategy(BaseObject):
|
||||
|
||||
async def trigger_user_turn_started(self):
|
||||
"""Trigger the `on_user_turn_started` event."""
|
||||
await self._call_event_handler("on_user_turn_started")
|
||||
await self._call_event_handler(
|
||||
"on_user_turn_started",
|
||||
UserTurnStartedParams(enable_user_speaking_frames=self._enable_user_speaking_frames),
|
||||
)
|
||||
|
||||
35
src/pipecat/turns/user/external_user_turn_start_strategy.py
Normal file
35
src/pipecat/turns/user/external_user_turn_start_strategy.py
Normal file
@@ -0,0 +1,35 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""User turn start strategy triggered by externally emitted frames."""
|
||||
|
||||
from pipecat.frames.frames import Frame, UserStartedSpeakingFrame
|
||||
from pipecat.turns.user.base_user_turn_start_strategy import BaseUserTurnStartStrategy
|
||||
|
||||
|
||||
class ExternalUserTurnStartStrategy(BaseUserTurnStartStrategy):
    """User turn start strategy controlled by an external processor.

    This strategy does not determine when a user turn starts on its own.
    Instead, it relies on a different processor in the pipeline which is
    responsible for emitting `UserStartedSpeakingFrame` frames, and triggers
    the user turn whenever one of those frames is received.

    The strategy is always created with `enable_user_speaking_frames=False`.

    """

    def __init__(self, **kwargs):
        """Initialize the external user turn start strategy.

        Args:
            **kwargs: Additional keyword arguments forwarded to the base
                strategy.
        """
        super().__init__(enable_user_speaking_frames=False, **kwargs)

    async def process_frame(self, frame: Frame):
        """Process an incoming frame to detect user turn start.

        Args:
            frame: The frame to be analyzed.
        """
        await super().process_frame(frame)

        # The external processor decides when the user starts speaking; we
        # only translate its frame into a user turn start.
        if isinstance(frame, UserStartedSpeakingFrame):
            await self.trigger_user_turn_started()
|
||||
@@ -10,10 +10,13 @@ import unittest
|
||||
from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.bot import TranscriptionBotTurnStartStrategy
|
||||
from pipecat.turns.bot.external_bot_turn_start_strategy import ExternalBotTurnStartStrategy
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
AGGREGATION_TIMEOUT = 0.1
|
||||
@@ -30,7 +33,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -55,7 +58,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -86,7 +89,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -133,7 +136,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -167,7 +170,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -209,7 +212,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -239,7 +242,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -275,7 +278,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -313,7 +316,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -347,7 +350,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -392,7 +395,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -412,7 +415,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -437,7 +440,7 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(strategy):
|
||||
async def on_bot_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -472,3 +475,35 @@ class TestTranscriptionBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
# at least the aggregation timeout.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
|
||||
class TestExternalBotTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
    """Tests for `ExternalBotTurnStartStrategy`."""

    async def test_external_strategy(self):
        """Check that the bot turn starts only on the final `UserStoppedSpeakingFrame`."""
        strategy = ExternalBotTurnStartStrategy()

        turn_started = None

        @strategy.event_handler("on_bot_turn_started")
        async def on_bot_turn_started(strategy, enable_user_speaking_frames):
            nonlocal turn_started
            turn_started = True

        # None of these frames, fed in this order, is expected to start the
        # bot turn: the event handler must not have fired after any of them.
        frames_without_turn_start = (
            VADUserStartedSpeakingFrame(),
            UserStartedSpeakingFrame(),
            UserStoppedSpeakingFrame(),
            UserStartedSpeakingFrame(),
            TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""),
        )
        for frame in frames_without_turn_start:
            await strategy.process_frame(frame)
            self.assertFalse(turn_started)

        # The stop frame that follows the transcription is the one expected to
        # fire the bot turn start event.
        await strategy.process_frame(UserStoppedSpeakingFrame())
        self.assertTrue(turn_started)
|
||||
|
||||
@@ -10,10 +10,12 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.user import (
|
||||
ExternalUserTurnStartStrategy,
|
||||
MinWordsUserTurnStartStrategy,
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
VADUserTurnStartStrategy,
|
||||
@@ -27,7 +29,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy):
|
||||
async def on_user_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -59,7 +61,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy):
|
||||
async def on_user_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -81,7 +83,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy):
|
||||
async def on_user_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -102,7 +104,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy):
|
||||
async def on_user_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -115,7 +117,7 @@ class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy):
|
||||
async def on_user_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -132,7 +134,7 @@ class TestVADUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy):
|
||||
async def on_user_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -150,7 +152,7 @@ class TestTranscriptionUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
should_start = None
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy):
|
||||
async def on_user_turn_started(strategy, enable_user_speaking_frames):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@@ -162,3 +164,21 @@ class TestTranscriptionUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="", timestamp="now"))
|
||||
self.assertTrue(should_start)
|
||||
|
||||
|
||||
class TestExternalUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
    """Tests for `ExternalUserTurnStartStrategy`."""

    async def test_external_strategy(self):
        """Check that only `UserStartedSpeakingFrame` starts the user turn."""
        strategy = ExternalUserTurnStartStrategy()

        turn_started = None

        @strategy.event_handler("on_user_turn_started")
        async def on_user_turn_started(strategy, enable_user_speaking_frames):
            nonlocal turn_started
            turn_started = True

        # A VAD-level start frame must be ignored by the external strategy.
        vad_frame = VADUserStartedSpeakingFrame()
        await strategy.process_frame(vad_frame)
        self.assertFalse(turn_started)

        # The externally emitted start frame is what fires the event.
        external_frame = UserStartedSpeakingFrame()
        await strategy.process_frame(external_frame)
        self.assertTrue(turn_started)
|
||||
|
||||
Reference in New Issue
Block a user