TurnAnalyzerBotTurnStartStrategy: broadcast SpeechControlParamsFrame
This commit is contained in:
@@ -30,7 +30,7 @@ from typing import (
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.dtmf.types import KeypadEntry as NewKeypadEntry
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -1550,7 +1550,7 @@ class SpeechControlParamsFrame(SystemFrame):
|
||||
"""
|
||||
|
||||
vad_params: Optional[VADParams] = None
|
||||
turn_params: Optional[SmartTurnParams] = None
|
||||
turn_params: Optional[BaseTurnParams] = None
|
||||
|
||||
|
||||
#
|
||||
|
||||
@@ -16,7 +16,7 @@ import json
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -310,11 +310,13 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
for s in self.turn_start_strategies.user:
|
||||
await s.setup(self.task_manager)
|
||||
s.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
s.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
s.add_event_handler("on_user_turn_started", self._on_user_turn_started)
|
||||
|
||||
for s in self.turn_start_strategies.bot:
|
||||
await s.setup(self.task_manager)
|
||||
s.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
s.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
s.add_event_handler("on_bot_turn_started", self._on_bot_turn_started)
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
@@ -375,10 +377,18 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy | BaseBotTurnStartStrategy,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _on_broadcast_frame(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy | BaseBotTurnStartStrategy,
|
||||
frame_cls: Type[Frame],
|
||||
**kwargs,
|
||||
):
|
||||
await self.broadcast_frame(frame_cls, **kwargs)
|
||||
|
||||
async def _trigger_user_turn_start(self, strategy: BaseUserTurnStartStrategy):
|
||||
if self._user_speaking:
|
||||
return
|
||||
|
||||
@@ -6,9 +6,10 @@
|
||||
|
||||
"""Base turn start strategy for determining when the bot should start speaking."""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
@@ -32,6 +33,7 @@ class BaseBotTurnStartStrategy(BaseObject):
|
||||
super().__init__(**kwargs)
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
self._register_event_handler("on_push_frame", sync=True)
|
||||
self._register_event_handler("on_broadcast_frame", sync=True)
|
||||
self._register_event_handler("on_bot_turn_started", sync=True)
|
||||
|
||||
@property
|
||||
@@ -69,6 +71,24 @@ class BaseBotTurnStartStrategy(BaseObject):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Emit on_push_frame to push a frame using the user aggreagtor.
|
||||
|
||||
Args:
|
||||
frame: The frame to be pushed.
|
||||
direction: What direction the frame should be pushed to.
|
||||
"""
|
||||
await self._call_event_handler("on_push_frame", frame, direction)
|
||||
|
||||
async def broadcast_frame(self, frame_cls: Type[Frame], **kwargs):
|
||||
"""Emit on_broadcast_frame to broadcast a frame using the user aggreagtor.
|
||||
|
||||
Args:
|
||||
frame_cls: The class of the frame to be broadcasted.
|
||||
**kwargs: Keyword arguments to be passed to the frame's constructor.
|
||||
"""
|
||||
await self._call_event_handler("on_broadcast_frame", frame_cls, **kwargs)
|
||||
|
||||
async def trigger_bot_turn_started(self):
|
||||
"""Trigger the `on_bot_turn_started` event."""
|
||||
await self._call_event_handler("on_bot_turn_started")
|
||||
|
||||
@@ -15,13 +15,13 @@ from pipecat.frames.frames import (
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
MetricsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.turns.bot.base_bot_turn_start_strategy import BaseBotTurnStartStrategy
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
@@ -95,6 +95,7 @@ class TurnAnalyzerBotTurnStartStrategy(BaseBotTurnStartStrategy):
|
||||
async def _start(self, frame: StartFrame):
|
||||
"""Process the start frame to configure the turn analyzer."""
|
||||
self._turn_analyzer.set_sample_rate(frame.audio_in_sample_rate)
|
||||
await self.broadcast_frame(SpeechControlParamsFrame, turn_params=self._turn_analyzer.params)
|
||||
|
||||
async def _handle_input_audio(self, frame: InputAudioRawFrame):
|
||||
"""Handle input audio to check if the turn is completed."""
|
||||
@@ -129,11 +130,7 @@ class TurnAnalyzerBotTurnStartStrategy(BaseBotTurnStartStrategy):
|
||||
async def _handle_prediction_result(self, result: Optional[MetricsData]):
|
||||
"""Handle a prediction result event from the turn analyzer."""
|
||||
if result:
|
||||
await self._call_event_handler(
|
||||
"on_push_frame",
|
||||
MetricsFrame(data=[result]),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
await self.push_frame(MetricsFrame(data=[result]))
|
||||
|
||||
async def _task_handler(self):
|
||||
"""Asynchronously monitor events and trigger bot turn when appropriate.
|
||||
|
||||
@@ -6,9 +6,10 @@
|
||||
|
||||
"""Base turn start strategy for determining when the user starts speaking."""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
@@ -23,6 +24,7 @@ class BaseUserTurnStartStrategy(BaseObject):
|
||||
Events triggered by user turn start strategies:
|
||||
|
||||
- `on_push_frame`: Indicates the strategy wants to push a frame.
|
||||
- `on_broadcast_frame`: Indicates the strategy wants to broadcast a frame.
|
||||
- `on_user_turn_started`: Signals that a user turn has started.
|
||||
"""
|
||||
|
||||
@@ -31,6 +33,7 @@ class BaseUserTurnStartStrategy(BaseObject):
|
||||
super().__init__(**kwargs)
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
self._register_event_handler("on_push_frame", sync=True)
|
||||
self._register_event_handler("on_broadcast_frame", sync=True)
|
||||
self._register_event_handler("on_user_turn_started", sync=True)
|
||||
|
||||
@property
|
||||
@@ -68,6 +71,24 @@ class BaseUserTurnStartStrategy(BaseObject):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Emit on_push_frame to push a frame using the user aggreagtor.
|
||||
|
||||
Args:
|
||||
frame: The frame to be pushed.
|
||||
direction: What direction the frame should be pushed to.
|
||||
"""
|
||||
await self._call_event_handler("on_push_frame", frame, direction)
|
||||
|
||||
async def broadcast_frame(self, frame_cls: Type[Frame], **kwargs):
|
||||
"""Emit on_broadcast_frame to broadcast a frame using the user aggreagtor.
|
||||
|
||||
Args:
|
||||
frame_cls: The class of the frame to be broadcasted.
|
||||
**kwargs: Keyword arguments to be passed to the frame's constructor.
|
||||
"""
|
||||
await self._call_event_handler("on_broadcast_frame", frame_cls, **kwargs)
|
||||
|
||||
async def trigger_user_turn_started(self):
|
||||
"""Trigger the `on_user_turn_started` event."""
|
||||
await self._call_event_handler("on_user_turn_started")
|
||||
|
||||
Reference in New Issue
Block a user