turns: add UserTurnStartedParams and BotTurnStartedParams
This commit is contained in:
@@ -63,10 +63,10 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
NotGiven,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.turns.bot import BaseBotTurnStartStrategy
|
||||
from pipecat.turns.bot import BaseBotTurnStartStrategy, BotTurnStartedParams
|
||||
from pipecat.turns.mute import BaseUserMuteStrategy
|
||||
from pipecat.turns.turn_start_strategies import ExternalTurnStartStrategies, TurnStartStrategies
|
||||
from pipecat.turns.user import BaseUserTurnStartStrategy
|
||||
from pipecat.turns.user import BaseUserTurnStartStrategy, UserTurnStartedParams
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
@@ -518,22 +518,14 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
async def _on_user_turn_started(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy,
|
||||
enable_user_speaking_frames: bool,
|
||||
params: UserTurnStartedParams,
|
||||
):
|
||||
await self._trigger_user_turn_start(
|
||||
strategy,
|
||||
enable_user_speaking_frames=enable_user_speaking_frames,
|
||||
)
|
||||
await self._trigger_user_turn_start(strategy, params)
|
||||
|
||||
async def _on_bot_turn_started(
|
||||
self,
|
||||
strategy: BaseBotTurnStartStrategy,
|
||||
enable_user_speaking_frames: bool,
|
||||
self, strategy: BaseBotTurnStartStrategy, params: BotTurnStartedParams
|
||||
):
|
||||
await self._trigger_bot_turn_start(
|
||||
strategy,
|
||||
enable_user_speaking_frames=enable_user_speaking_frames,
|
||||
)
|
||||
await self._trigger_bot_turn_start(strategy, params)
|
||||
|
||||
async def _on_push_frame(
|
||||
self,
|
||||
@@ -552,10 +544,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self.broadcast_frame(frame_cls, **kwargs)
|
||||
|
||||
async def _trigger_user_turn_start(
|
||||
self,
|
||||
strategy: Optional[BaseUserTurnStartStrategy],
|
||||
*,
|
||||
enable_user_speaking_frames: bool,
|
||||
self, strategy: Optional[BaseUserTurnStartStrategy], params: UserTurnStartedParams
|
||||
):
|
||||
# Prevent two consecutive user turn starts.
|
||||
if self._user_turn:
|
||||
@@ -571,7 +560,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
for s in self._turn_start_strategies.user:
|
||||
await s.reset()
|
||||
|
||||
if enable_user_speaking_frames:
|
||||
if params.enable_user_speaking_frames:
|
||||
# TODO(aleix): These frames should really come from the top of the pipeline.
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.broadcast_frame(InterruptionFrame)
|
||||
@@ -579,10 +568,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
async def _trigger_bot_turn_start(
|
||||
self,
|
||||
strategy: Optional[BaseBotTurnStartStrategy],
|
||||
*,
|
||||
enable_user_speaking_frames: bool,
|
||||
self, strategy: Optional[BaseBotTurnStartStrategy], params: BotTurnStartedParams
|
||||
):
|
||||
# Prevent two consecutive bot turn starts.
|
||||
if not self._user_turn:
|
||||
@@ -598,7 +584,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
for s in self._turn_start_strategies.bot:
|
||||
await s.reset()
|
||||
|
||||
if enable_user_speaking_frames:
|
||||
if params.enable_user_speaking_frames:
|
||||
# TODO(aleix): This frame should really come from the top of the pipeline.
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
@@ -618,7 +604,9 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
except asyncio.TimeoutError:
|
||||
if self._user_turn and not self._vad_user_speaking:
|
||||
await self._call_event_handler("on_user_turn_end_timeout")
|
||||
await self._trigger_bot_turn_start(None, enable_user_speaking_frames=True)
|
||||
await self._trigger_bot_turn_start(
|
||||
None, BotTurnStartedParams(enable_user_speaking_frames=True)
|
||||
)
|
||||
|
||||
|
||||
class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from pipecat.turns.bot.base_bot_turn_start_strategy import BaseBotTurnStartStrategy
|
||||
from pipecat.turns.bot.base_bot_turn_start_strategy import (
|
||||
BaseBotTurnStartStrategy,
|
||||
BotTurnStartedParams,
|
||||
)
|
||||
from pipecat.turns.bot.external_bot_turn_start_strategy import ExternalBotTurnStartStrategy
|
||||
from pipecat.turns.bot.transcription_bot_turn_start_strategy import (
|
||||
TranscriptionBotTurnStartStrategy,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
"""Base turn start strategy for determining when the bot should start speaking."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
@@ -14,6 +15,26 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotTurnStartedParams:
|
||||
"""Parameters emitted when a bot turn starts.
|
||||
|
||||
These parameters are passed to the `on_bot_turn_started` event and provide
|
||||
contextual information about how the bot turn should be handled by the user
|
||||
aggregator.
|
||||
|
||||
Attributes:
|
||||
enable_user_speaking_frames: Whether the user aggregator should emit
|
||||
frames indicating user speaking state (e.g., user stopped speaking)
|
||||
during the bot's turn. This is typically enabled by default, but may
|
||||
be disabled when another component (such as an STT service) is already
|
||||
responsible for generating user speaking frames.
|
||||
|
||||
"""
|
||||
|
||||
enable_user_speaking_frames: bool
|
||||
|
||||
|
||||
class BaseBotTurnStartStrategy(BaseObject):
|
||||
"""Base class for strategies that determine when the bot should start speaking.
|
||||
|
||||
@@ -28,9 +49,18 @@ class BaseBotTurnStartStrategy(BaseObject):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the base bot turn start strategy."""
|
||||
def __init__(self, *, enable_user_speaking_frames: bool = True, **kwargs):
|
||||
"""Initialize the base bot turn start strategy.
|
||||
|
||||
Args:
|
||||
enable_user_speaking_frames: If True, the aggregator will emit frames
|
||||
indicating when the user stops speaking. This is enabled by default,
|
||||
but you may want to disable it if another component (e.g., an STT
|
||||
service) is already generating these frames.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._enable_user_speaking_frames = enable_user_speaking_frames
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
self._register_event_handler("on_push_frame", sync=True)
|
||||
self._register_event_handler("on_broadcast_frame", sync=True)
|
||||
@@ -88,13 +118,9 @@ class BaseBotTurnStartStrategy(BaseObject):
|
||||
"""
|
||||
await self._call_event_handler("on_broadcast_frame", frame_cls, **kwargs)
|
||||
|
||||
async def trigger_bot_turn_started(self, *, enable_user_speaking_frames: bool = True):
|
||||
"""Trigger the `on_bot_turn_started` event.
|
||||
|
||||
Args:
|
||||
enable_user_speaking_frames: If True, the aggregator will emit frames
|
||||
indicating when the user stops speaking. This is enabled by default,
|
||||
but you may want to disable it if another component (e.g., an STT
|
||||
service) is already generating these frames.
|
||||
"""
|
||||
await self._call_event_handler("on_bot_turn_started", enable_user_speaking_frames)
|
||||
async def trigger_bot_turn_started(self):
|
||||
"""Trigger the `on_bot_turn_started` event."""
|
||||
await self._call_event_handler(
|
||||
"on_bot_turn_started",
|
||||
BotTurnStartedParams(enable_user_speaking_frames=self._enable_user_speaking_frames),
|
||||
)
|
||||
|
||||
@@ -30,13 +30,13 @@ class ExternalBotTurnStartStrategy(BaseBotTurnStartStrategy):
|
||||
"""
|
||||
|
||||
def __init__(self, *, timeout: float = 0.5):
|
||||
"""Initialize the transcription-based bot turn start strategy.
|
||||
"""Initialize the external bot turn start strategy.
|
||||
|
||||
Args:
|
||||
timeout: A short delay used internally to handle consecutive or
|
||||
slightly delayed transcriptions.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(enable_user_speaking_frames=False)
|
||||
self._timeout = timeout
|
||||
self._text = ""
|
||||
self._user_speaking = False
|
||||
@@ -124,4 +124,4 @@ class ExternalBotTurnStartStrategy(BaseBotTurnStartStrategy):
|
||||
|
||||
async def _maybe_trigger_bot_turn_started(self):
|
||||
if not self._user_speaking and not self._seen_interim_results and self._text:
|
||||
await self.trigger_bot_turn_started(enable_user_speaking_frames=False)
|
||||
await self.trigger_bot_turn_started()
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from pipecat.turns.user.base_user_turn_start_strategy import BaseUserTurnStartStrategy
|
||||
from pipecat.turns.user.base_user_turn_start_strategy import (
|
||||
BaseUserTurnStartStrategy,
|
||||
UserTurnStartedParams,
|
||||
)
|
||||
from pipecat.turns.user.external_user_turn_start_strategy import ExternalUserTurnStartStrategy
|
||||
from pipecat.turns.user.min_words_user_turn_start_strategy import MinWordsUserTurnStartStrategy
|
||||
from pipecat.turns.user.transcription_user_turn_start_strategy import (
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
"""Base turn start strategy for determining when the user starts speaking."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
@@ -14,6 +15,26 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserTurnStartedParams:
|
||||
"""Parameters emitted when a user turn starts.
|
||||
|
||||
These parameters are passed to the `on_user_turn_started` event and provide
|
||||
contextual information about how the user turn should be handled by the user
|
||||
aggregator.
|
||||
|
||||
Attributes:
|
||||
enable_user_speaking_frames: Whether the user aggregator should emit
|
||||
frames indicating user speaking state (e.g., user started speaking)
|
||||
during the user's turn. This is typically enabled by default, but may
|
||||
be disabled when another component (such as an STT service) is already
|
||||
responsible for generating user speaking frames.
|
||||
|
||||
"""
|
||||
|
||||
enable_user_speaking_frames: bool
|
||||
|
||||
|
||||
class BaseUserTurnStartStrategy(BaseObject):
|
||||
"""Base class for strategies that determine when a user starts speaking.
|
||||
|
||||
@@ -28,9 +49,19 @@ class BaseUserTurnStartStrategy(BaseObject):
|
||||
- `on_user_turn_started`: Signals that a user turn has started.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the base user turn start strategy."""
|
||||
def __init__(self, *, enable_user_speaking_frames: bool = True, **kwargs):
|
||||
"""Initialize the base user turn start strategy.
|
||||
|
||||
Args:
|
||||
enable_user_speaking_frames: If True, the aggregator will emit frames
|
||||
indicating when the user starts speaking, as well as interruption
|
||||
frames. This is enabled by default, but you may want to disable it
|
||||
if another component (e.g., an STT service) is already generating
|
||||
these frames.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._enable_user_speaking_frames = enable_user_speaking_frames
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
self._register_event_handler("on_push_frame", sync=True)
|
||||
self._register_event_handler("on_broadcast_frame", sync=True)
|
||||
@@ -88,15 +119,9 @@ class BaseUserTurnStartStrategy(BaseObject):
|
||||
"""
|
||||
await self._call_event_handler("on_broadcast_frame", frame_cls, **kwargs)
|
||||
|
||||
async def trigger_user_turn_started(self, *, enable_user_speaking_frames: bool = True):
|
||||
"""Trigger the `on_user_turn_started` event.
|
||||
|
||||
Args:
|
||||
enable_user_speaking_frames: If True, the aggregator will emit frames
|
||||
indicating when the user starts speaking, as well as interruption
|
||||
frames. This is enabled by default, but you may want to disable it
|
||||
if another component (e.g., an STT service) is already generating
|
||||
these frames.
|
||||
**kwargs: Keyword arguments to be passed to the frame's constructor.
|
||||
"""
|
||||
await self._call_event_handler("on_user_turn_started", enable_user_speaking_frames)
|
||||
async def trigger_user_turn_started(self):
|
||||
"""Trigger the `on_user_turn_started` event."""
|
||||
await self._call_event_handler(
|
||||
"on_user_turn_started",
|
||||
UserTurnStartedParams(enable_user_speaking_frames=self._enable_user_speaking_frames),
|
||||
)
|
||||
|
||||
@@ -19,6 +19,10 @@ class ExternalUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the external user turn start strategy."""
|
||||
super().__init__(enable_user_speaking_frames=False)
|
||||
|
||||
async def process_frame(self, frame: Frame):
|
||||
"""Process an incoming frame to detect user turn start.
|
||||
|
||||
@@ -28,4 +32,4 @@ class ExternalUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
await super().process_frame(frame)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.trigger_user_turn_started(enable_user_speaking_frames=False)
|
||||
await self.trigger_user_turn_started()
|
||||
|
||||
Reference in New Issue
Block a user