Merge pull request #3372 from pipecat-ai/aleix/add-user-turn-controller-processor
add new UserTurnController and UserTurnProcessor
This commit is contained in:
1
changelog/3372.added.2.md
Normal file
1
changelog/3372.added.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `UserTurnProcessor`, a frame processor built on `UserTurnController` that pushes `UserStartedSpeakingFrame` and `UserStoppedSpeakingFrame` frames and interruptions based on the controller's user turn strategies.
|
||||
1
changelog/3372.added.md
Normal file
1
changelog/3372.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `UserTurnController` to manage user turns. It emits `on_user_turn_started`, `on_user_turn_stopped`, and `on_user_turn_stop_timeout` events, and can be integrated into processors to detect and handle user turns. `LLMUserAggregator` and `UserTurnProcessor` are implemented using this controller.
|
||||
1
changelog/3372.other.md
Normal file
1
changelog/3372.other.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added a new foundational example `53-concurrent-llm-evaluation.py` that shows how to use `UserTurnProcessor`.
|
||||
180
examples/foundational/53-concurrent-llm-evaluation.py
Normal file
180
examples/foundational/53-concurrent-llm-evaluation.py
Normal file
@@ -0,0 +1,180 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.groq.llm import GroqLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_processor import UserTurnProcessor
|
||||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies, UserTurnStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="d4db5fb9-f44b-4bd1-85fa-192e0f0d75f9", # Spanish-speaking Lady
|
||||
)
|
||||
|
||||
openai_llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
openai_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
groq_llm = GroqLLMService(
|
||||
api_key=os.getenv("GROQ_API_KEY"), model="meta-llama/llama-4-maverick-17b-128e-instruct"
|
||||
)
|
||||
|
||||
groq_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a very helpful assistant. Your goal is to demonstrate your capabilities in detail in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
openai_context = LLMContext(openai_messages)
|
||||
groq_context = LLMContext(groq_messages)
|
||||
|
||||
# We use this external user turn processor. This processor will push
|
||||
# UserStartedSpeakingFrame and UserStoppedSpeakingFrame as well as
|
||||
# interruptions. This can be used in advanced cases when there are multiple
|
||||
# aggregators in the pipeline.
|
||||
user_turn_processor = UserTurnProcessor(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
),
|
||||
)
|
||||
|
||||
# We use external user turn strategies for both aggregators since the turn
|
||||
# management is done by the common UserTurnProcessor.
|
||||
openai_context_aggregator = LLMContextAggregatorPair(
|
||||
openai_context,
|
||||
user_params=LLMUserAggregatorParams(user_turn_strategies=ExternalUserTurnStrategies()),
|
||||
)
|
||||
groq_context_aggregator = LLMContextAggregatorPair(
|
||||
groq_context,
|
||||
user_params=LLMUserAggregatorParams(user_turn_strategies=ExternalUserTurnStrategies()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
user_turn_processor,
|
||||
ParallelPipeline(
|
||||
[
|
||||
openai_context_aggregator.user(), # User responses
|
||||
openai_llm, # LLM
|
||||
tts, # TTS (bot will speak the chosen language)
|
||||
transport.output(), # Transport bot output
|
||||
openai_context_aggregator.assistant(), # Assistant spoken responses
|
||||
],
|
||||
[
|
||||
groq_context_aggregator.user(), # User responses
|
||||
groq_llm, # LLM
|
||||
groq_context_aggregator.assistant(), # Assistant responses
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
openai_messages.append(
|
||||
{"role": "system", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
groq_messages.append(
|
||||
{"role": "system", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -239,6 +239,10 @@ TESTS_51 = [
|
||||
("51-grok-realtime.py", EVAL_WEATHER),
|
||||
]
|
||||
|
||||
TESTS_53 = [
|
||||
("53-concurrent-llm-evaluation.py", EVAL_SIMPLE_MATH),
|
||||
]
|
||||
|
||||
TESTS = [
|
||||
*TESTS_07,
|
||||
*TESTS_12,
|
||||
@@ -254,6 +258,7 @@ TESTS = [
|
||||
*TESTS_49,
|
||||
*TESTS_50,
|
||||
*TESTS_51,
|
||||
*TESTS_53,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.turns.mute import BaseUserMuteStrategy
|
||||
from pipecat.turns.user_start import BaseUserTurnStartStrategy, UserTurnStartedParams
|
||||
from pipecat.turns.user_stop import BaseUserTurnStopStrategy, UserTurnStoppedParams
|
||||
from pipecat.turns.user_turn_controller import UserTurnController
|
||||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies, UserTurnStrategies
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -220,10 +221,10 @@ class LLMContextAggregator(FrameProcessor):
|
||||
class LLMUserAggregator(LLMContextAggregator):
|
||||
"""User LLM aggregator that aggregates user input during active user turns.
|
||||
|
||||
This aggregator operates within turn boundaries defined by the configured
|
||||
user and bot turn start strategies. User turn start strategies indicate when
|
||||
a user turn begins, while bot turn start strategies signal when the user
|
||||
turn has ended and control transitions to the bot turn.
|
||||
This aggregator uses a turn controller and operates within turn boundaries
|
||||
defined by the controller's configured user turn strategies. User turn start
|
||||
strategies indicate when a user turn begins, while user turn stop strategies
|
||||
signal when the user turn has ended.
|
||||
|
||||
The aggregator collects and aggregates speech-to-text transcriptions that
|
||||
occur while a user turn is active and pushes the final aggregation when the
|
||||
@@ -238,11 +239,11 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
Example::
|
||||
|
||||
@aggregator.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(aggregator, Optional[strategy]):
|
||||
async def on_user_turn_started(aggregator, strategy: BaseUserTurnStartStrategy]):
|
||||
...
|
||||
|
||||
@aggregator.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(aggregator, Optional[strategy]):
|
||||
async def on_user_turn_stopped(aggregator, strategy: BaseUserTurnStopStrategy):
|
||||
...
|
||||
|
||||
@aggregator.event_handler("on_user_turn_stop_timeout")
|
||||
@@ -268,20 +269,30 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._params = params or LLMUserAggregatorParams()
|
||||
|
||||
# Initialize default user turn strategies.
|
||||
self._user_turn_strategies = self._params.user_turn_strategies or UserTurnStrategies()
|
||||
|
||||
self._vad_user_speaking = False
|
||||
|
||||
self._user_turn = False
|
||||
self._user_is_muted = False
|
||||
self._user_turn_stop_timeout_event = asyncio.Event()
|
||||
self._user_turn_stop_timeout_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._register_event_handler("on_user_turn_started")
|
||||
self._register_event_handler("on_user_turn_stopped")
|
||||
self._register_event_handler("on_user_turn_stop_timeout")
|
||||
|
||||
user_turn_strategies = self._params.user_turn_strategies or UserTurnStrategies()
|
||||
|
||||
self._user_is_muted = False
|
||||
|
||||
self._user_turn_controller = UserTurnController(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
user_turn_stop_timeout=self._params.user_turn_stop_timeout,
|
||||
)
|
||||
self._user_turn_controller.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
self._user_turn_controller.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
self._user_turn_controller.add_event_handler(
|
||||
"on_user_turn_started", self._on_user_turn_started
|
||||
)
|
||||
self._user_turn_controller.add_event_handler(
|
||||
"on_user_turn_stopped", self._on_user_turn_stopped
|
||||
)
|
||||
self._user_turn_controller.add_event_handler(
|
||||
"on_user_turn_stop_timeout", self._on_user_turn_stop_timeout
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up processor resources."""
|
||||
await super().cleanup()
|
||||
@@ -312,12 +323,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._handle_vad_user_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
elif isinstance(frame, LLMRunFrame):
|
||||
@@ -341,7 +346,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
await self._user_turn_strategies_process_frame(frame)
|
||||
await self._user_turn_controller.process_frame(frame)
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Push the current aggregation."""
|
||||
@@ -354,33 +359,11 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
if not self._user_turn_stop_timeout_task:
|
||||
self._user_turn_stop_timeout_task = self.create_task(
|
||||
self._user_turn_stop_timeout_task_handler()
|
||||
)
|
||||
await self._user_turn_controller.setup(self.task_manager)
|
||||
|
||||
await self._setup_user_turn_strategies()
|
||||
await self._setup_user_mute_strategies()
|
||||
|
||||
async def _setup_user_mute_strategies(self):
|
||||
for s in self._params.user_mute_strategies:
|
||||
await s.setup(self.task_manager)
|
||||
|
||||
async def _setup_user_turn_strategies(self):
|
||||
if self._user_turn_strategies.start:
|
||||
for s in self._user_turn_strategies.start:
|
||||
await s.setup(self.task_manager)
|
||||
s.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
s.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
s.add_event_handler("on_user_turn_started", self._on_user_turn_started)
|
||||
|
||||
if self._user_turn_strategies.stop:
|
||||
for s in self._user_turn_strategies.stop:
|
||||
await s.setup(self.task_manager)
|
||||
s.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
s.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
s.add_event_handler("on_user_turn_stopped", self._on_user_turn_stopped)
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._cleanup()
|
||||
|
||||
@@ -388,23 +371,8 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._cleanup()
|
||||
|
||||
async def _cleanup(self):
|
||||
if self._user_turn_stop_timeout_task:
|
||||
await self.cancel_task(self._user_turn_stop_timeout_task)
|
||||
self._user_turn_stop_timeout_task = None
|
||||
await self._user_turn_controller.cleanup()
|
||||
|
||||
await self._cleanup_user_turn_strategies()
|
||||
await self._cleanup_user_mute_strategies()
|
||||
|
||||
async def _cleanup_user_turn_strategies(self):
|
||||
if self._user_turn_strategies.start:
|
||||
for s in self._user_turn_strategies.start:
|
||||
await s.cleanup()
|
||||
|
||||
if self._user_turn_strategies.stop:
|
||||
for s in self._user_turn_strategies.stop:
|
||||
await s.cleanup()
|
||||
|
||||
async def _cleanup_user_mute_strategies(self):
|
||||
for s in self._params.user_mute_strategies:
|
||||
await s.cleanup()
|
||||
|
||||
@@ -436,15 +404,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
return should_mute_frame
|
||||
|
||||
async def _user_turn_strategies_process_frame(self, frame: Frame):
|
||||
if self._user_turn_strategies.start:
|
||||
for strategy in self._user_turn_strategies.start:
|
||||
await strategy.process_frame(frame)
|
||||
|
||||
if self._user_turn_strategies.stop:
|
||||
for strategy in self._user_turn_strategies.stop:
|
||||
await strategy.process_frame(frame)
|
||||
|
||||
async def _handle_llm_run(self, frame: LLMRunFrame):
|
||||
await self.push_context_frame()
|
||||
|
||||
@@ -482,21 +441,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
" )"
|
||||
)
|
||||
|
||||
await self._cleanup_user_turn_strategies()
|
||||
self._user_turn_strategies = ExternalUserTurnStrategies()
|
||||
await self._setup_user_turn_strategies()
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, frame: VADUserStartedSpeakingFrame):
|
||||
self._vad_user_speaking = True
|
||||
|
||||
# The user started talking, let's reset the user turn timeout.
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
|
||||
self._vad_user_speaking = False
|
||||
|
||||
# The user stopped talking, let's reset the user turn timeout.
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
await self._user_turn_controller.update_strategies(ExternalUserTurnStrategies())
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
text = frame.text
|
||||
@@ -505,9 +450,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
# We have creceived a transcription, let's reset the user turn timeout.
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
# Transcriptions never include inter-part spaces (so far).
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
@@ -515,101 +457,48 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
)
|
||||
)
|
||||
|
||||
async def _on_user_turn_started(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy,
|
||||
params: UserTurnStartedParams,
|
||||
):
|
||||
await self._trigger_user_turn_start(strategy, params)
|
||||
|
||||
async def _on_user_turn_stopped(
|
||||
self, strategy: BaseUserTurnStopStrategy, params: UserTurnStoppedParams
|
||||
):
|
||||
await self._trigger_user_turn_stop(strategy, params)
|
||||
|
||||
async def _on_push_frame(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy | BaseUserTurnStopStrategy,
|
||||
frame: Frame,
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
self, controller, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _on_broadcast_frame(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy | BaseUserTurnStopStrategy,
|
||||
frame_cls: Type[Frame],
|
||||
**kwargs,
|
||||
):
|
||||
async def _on_broadcast_frame(self, controller, frame_cls: Type[Frame], **kwargs):
|
||||
await self.broadcast_frame(frame_cls, **kwargs)
|
||||
|
||||
async def _trigger_user_turn_start(
|
||||
self, strategy: Optional[BaseUserTurnStartStrategy], params: UserTurnStartedParams
|
||||
async def _on_user_turn_started(
|
||||
self,
|
||||
controller: UserTurnController,
|
||||
strategy: BaseUserTurnStartStrategy,
|
||||
params: UserTurnStartedParams,
|
||||
):
|
||||
# Prevent two consecutive user turn starts.
|
||||
if self._user_turn:
|
||||
return
|
||||
|
||||
logger.debug(f"User started speaking (user turn start strategy: {strategy})")
|
||||
|
||||
self._user_turn = True
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
# Reset all user turn start strategies to start fresh.
|
||||
if self._user_turn_strategies.start:
|
||||
for s in self._user_turn_strategies.start:
|
||||
await s.reset()
|
||||
logger.debug(f"{self}: User started speaking (user turn start strategy: {strategy})")
|
||||
|
||||
if params.enable_user_speaking_frames:
|
||||
# TODO(aleix): This frame should really come from the top of the pipeline.
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
# TODO(aleix): This frame should really come from the top of the pipeline.
|
||||
await self.broadcast_frame(InterruptionFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
async def _trigger_user_turn_stop(
|
||||
self, strategy: Optional[BaseUserTurnStopStrategy], params: UserTurnStoppedParams
|
||||
async def _on_user_turn_stopped(
|
||||
self,
|
||||
controller: UserTurnController,
|
||||
strategy: BaseUserTurnStopStrategy,
|
||||
params: UserTurnStoppedParams,
|
||||
):
|
||||
# Prevent two consecutive user turn stops.
|
||||
if not self._user_turn:
|
||||
return
|
||||
|
||||
logger.debug(f"User stopped speaking (user turn stop strategy: {strategy})")
|
||||
|
||||
self._user_turn = False
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
# Reset all user turn stop strategies to start fresh.
|
||||
if self._user_turn_strategies.stop:
|
||||
for s in self._user_turn_strategies.stop:
|
||||
await s.reset()
|
||||
logger.debug(f"{self}: User stopped speaking (user turn stop strategy: {strategy})")
|
||||
|
||||
if params.enable_user_speaking_frames:
|
||||
# TODO(aleix): This frame should really come from the top of the pipeline.
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
await self._call_event_handler("on_user_turn_stopped", strategy)
|
||||
|
||||
# Always push context frame.
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _user_turn_stop_timeout_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._user_turn_stop_timeout_event.wait(),
|
||||
timeout=self._params.user_turn_stop_timeout,
|
||||
)
|
||||
self._user_turn_stop_timeout_event.clear()
|
||||
except asyncio.TimeoutError:
|
||||
if self._user_turn and not self._vad_user_speaking:
|
||||
await self._call_event_handler("on_user_turn_stop_timeout")
|
||||
await self._trigger_user_turn_stop(
|
||||
None, UserTurnStoppedParams(enable_user_speaking_frames=True)
|
||||
)
|
||||
await self._call_event_handler("on_user_turn_stopped", strategy)
|
||||
|
||||
async def _on_user_turn_stop_timeout(self, controller):
|
||||
await self._call_event_handler("on_user_turn_stop_timeout")
|
||||
|
||||
|
||||
class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
267
src/pipecat/turns/user_turn_controller.py
Normal file
267
src/pipecat/turns/user_turn_controller.py
Normal file
@@ -0,0 +1,267 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""This module defines a controller for managing user turn lifecycle."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.turns.user_start import BaseUserTurnStartStrategy, UserTurnStartedParams
|
||||
from pipecat.turns.user_stop import BaseUserTurnStopStrategy, UserTurnStoppedParams
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
|
||||
class UserTurnController(BaseObject):
|
||||
"""Controller for managing user turn lifecycle.
|
||||
|
||||
This class manages user turn state (active/inactive), handles start and stop
|
||||
strategies, and emits events when user turns begin, end, or timeout occurs.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_user_turn_started: Emitted when a user turn starts.
|
||||
- on_user_turn_stopped: Emitted when a user turn stops.
|
||||
- on_user_turn_stop_timeout: Emitted if no stop strategy triggers before timeout.
|
||||
- on_push_frame: Emitted when a strategy wants to push a frame.
|
||||
- on_broadcast_frame: Emitted when a strategy wants to broadcast a frame.
|
||||
|
||||
Example::
|
||||
|
||||
@controller.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(controller, strategy: BaseUserTurnStartStrategy, params: UserTurnStartedParams):
|
||||
...
|
||||
|
||||
@controller.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(controller, strategy: BaseUserTurnStopStrategy, params: UserTurnStoppedParams):
|
||||
...
|
||||
|
||||
@controller.event_handler("on_user_turn_stop_timeout")
|
||||
async def on_user_turn_stop_timeout(controller):
|
||||
...
|
||||
|
||||
@controller.event_handler("on_push_frame")
|
||||
async def on_push_frame(controller, frame: Frame, direction: FrameDirection):
|
||||
...
|
||||
|
||||
@controller.event_handler("on_broadcast_frame")
|
||||
async def on_broadcast_frame(controller, frame_cls: Type[Frame], **kwargs):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_turn_strategies: UserTurnStrategies,
|
||||
user_turn_stop_timeout: float = 5.0,
|
||||
):
|
||||
"""Initialize the user turn controller.
|
||||
|
||||
Args:
|
||||
user_turn_strategies: Configured strategies for starting and stopping user turns.
|
||||
user_turn_stop_timeout: Timeout in seconds to automatically stop a user turn
|
||||
if no activity is detected.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._user_turn_strategies = user_turn_strategies
|
||||
self._user_turn_stop_timeout = user_turn_stop_timeout
|
||||
|
||||
self._task_manager: Optional[BaseTaskManager] = None
|
||||
|
||||
self._vad_user_speaking = False
|
||||
|
||||
self._user_turn = False
|
||||
self._user_turn_stop_timeout_event = asyncio.Event()
|
||||
self._user_turn_stop_timeout_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._register_event_handler("on_push_frame", sync=True)
|
||||
self._register_event_handler("on_broadcast_frame", sync=True)
|
||||
self._register_event_handler("on_user_turn_started", sync=True)
|
||||
self._register_event_handler("on_user_turn_stopped", sync=True)
|
||||
self._register_event_handler("on_user_turn_stop_timeout", sync=True)
|
||||
|
||||
@property
|
||||
def task_manager(self) -> BaseTaskManager:
|
||||
"""Returns the configured task manager."""
|
||||
if not self._task_manager:
|
||||
raise RuntimeError(f"{self} user turn controller was not properly setup")
|
||||
return self._task_manager
|
||||
|
||||
async def setup(self, task_manager: BaseTaskManager):
|
||||
"""Initialize the controller with the given task manager.
|
||||
|
||||
Args:
|
||||
task_manager: The task manager to be associated with this instance.
|
||||
"""
|
||||
self._task_manager = task_manager
|
||||
|
||||
if not self._user_turn_stop_timeout_task:
|
||||
self._user_turn_stop_timeout_task = self.task_manager.create_task(
|
||||
self._user_turn_stop_timeout_task_handler(),
|
||||
f"{self}::_user_turn_stop_timeout_task_handler",
|
||||
)
|
||||
|
||||
await self._setup_strategies()
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup the controller."""
|
||||
await super().cleanup()
|
||||
|
||||
if self._user_turn_stop_timeout_task:
|
||||
await self.task_manager.cancel_task(self._user_turn_stop_timeout_task)
|
||||
self._user_turn_stop_timeout_task = None
|
||||
|
||||
await self._cleanup_strategies()
|
||||
|
||||
async def update_strategies(self, strategies: UserTurnStrategies):
|
||||
"""Replace the current strategies with the given ones.
|
||||
|
||||
Args:
|
||||
strategies: The new user turn strategies the controller should use.
|
||||
"""
|
||||
await self._cleanup_strategies()
|
||||
self._user_turn_strategies = strategies
|
||||
await self._setup_strategies()
|
||||
|
||||
async def process_frame(self, frame: Frame):
|
||||
"""Process an incoming frame to detect user turn start or stop.
|
||||
|
||||
The frame is passed to the configured user turn strategies, which are
|
||||
responsible for deciding when a user turn starts or stops and emitting
|
||||
the corresponding events.
|
||||
|
||||
Args:
|
||||
frame: The frame to be processed.
|
||||
|
||||
"""
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._handle_vad_user_stopped_speaking(frame)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
|
||||
for strategy in self._user_turn_strategies.start or []:
|
||||
await strategy.process_frame(frame)
|
||||
|
||||
for strategy in self._user_turn_strategies.stop or []:
|
||||
await strategy.process_frame(frame)
|
||||
|
||||
async def _setup_strategies(self):
|
||||
for s in self._user_turn_strategies.start or []:
|
||||
await s.setup(self.task_manager)
|
||||
s.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
s.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
s.add_event_handler("on_user_turn_started", self._on_user_turn_started)
|
||||
|
||||
for s in self._user_turn_strategies.stop or []:
|
||||
await s.setup(self.task_manager)
|
||||
s.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
s.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
s.add_event_handler("on_user_turn_stopped", self._on_user_turn_stopped)
|
||||
|
||||
async def _cleanup_strategies(self):
|
||||
for s in self._user_turn_strategies.start or []:
|
||||
await s.cleanup()
|
||||
|
||||
for s in self._user_turn_strategies.stop or []:
|
||||
await s.cleanup()
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, frame: VADUserStartedSpeakingFrame):
|
||||
self._vad_user_speaking = True
|
||||
|
||||
# The user started talking, let's reset the user turn timeout.
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
|
||||
self._vad_user_speaking = False
|
||||
|
||||
# The user stopped talking, let's reset the user turn timeout.
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
# We have creceived a transcription, let's reset the user turn timeout.
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
async def _on_push_frame(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy | BaseUserTurnStopStrategy,
|
||||
frame: Frame,
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
):
|
||||
await self._call_event_handler("on_push_frame", frame, direction)
|
||||
|
||||
async def _on_broadcast_frame(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy | BaseUserTurnStopStrategy,
|
||||
frame_cls: Type[Frame],
|
||||
**kwargs,
|
||||
):
|
||||
await self._call_event_handler("on_broadcast_frame", frame_cls, **kwargs)
|
||||
|
||||
async def _on_user_turn_started(
|
||||
self,
|
||||
strategy: BaseUserTurnStartStrategy,
|
||||
params: UserTurnStartedParams,
|
||||
):
|
||||
await self._trigger_user_turn_start(strategy, params)
|
||||
|
||||
async def _on_user_turn_stopped(
|
||||
self, strategy: BaseUserTurnStopStrategy, params: UserTurnStoppedParams
|
||||
):
|
||||
await self._trigger_user_turn_stop(strategy, params)
|
||||
|
||||
async def _trigger_user_turn_start(
|
||||
self, strategy: Optional[BaseUserTurnStartStrategy], params: UserTurnStartedParams
|
||||
):
|
||||
# Prevent two consecutive user turn starts.
|
||||
if self._user_turn:
|
||||
return
|
||||
|
||||
self._user_turn = True
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy, params)
|
||||
|
||||
async def _trigger_user_turn_stop(
|
||||
self, strategy: Optional[BaseUserTurnStopStrategy], params: UserTurnStoppedParams
|
||||
):
|
||||
# Prevent two consecutive user turn stops.
|
||||
if not self._user_turn:
|
||||
return
|
||||
|
||||
self._user_turn = False
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
# Reset all user turn stop strategies to start fresh.
|
||||
for s in self._user_turn_strategies.stop or []:
|
||||
await s.reset()
|
||||
|
||||
await self._call_event_handler("on_user_turn_stopped", strategy, params)
|
||||
|
||||
async def _user_turn_stop_timeout_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._user_turn_stop_timeout_event.wait(),
|
||||
timeout=self._user_turn_stop_timeout,
|
||||
)
|
||||
self._user_turn_stop_timeout_event.clear()
|
||||
except asyncio.TimeoutError:
|
||||
if self._user_turn and not self._vad_user_speaking:
|
||||
await self._call_event_handler("on_user_turn_stop_timeout")
|
||||
await self._trigger_user_turn_stop(
|
||||
None, UserTurnStoppedParams(enable_user_speaking_frames=True)
|
||||
)
|
||||
182
src/pipecat/turns/user_turn_processor.py
Normal file
182
src/pipecat/turns/user_turn_processor.py
Normal file
@@ -0,0 +1,182 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Frame processor for managing the user turn lifecycle."""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.turns.user_start import BaseUserTurnStartStrategy, UserTurnStartedParams
|
||||
from pipecat.turns.user_stop import BaseUserTurnStopStrategy, UserTurnStoppedParams
|
||||
from pipecat.turns.user_turn_controller import UserTurnController
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
|
||||
class UserTurnProcessor(FrameProcessor):
|
||||
"""Frame processor for managing the user turn lifecycle.
|
||||
|
||||
This processor uses a turn controller to determine when a user turn starts
|
||||
or stops. The actual frames emitted (e.g., UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame) or interruptions depend on the configured
|
||||
strategies.
|
||||
|
||||
Event handlers available:
|
||||
|
||||
- on_user_turn_started: Emitted when a user turn starts.
|
||||
- on_user_turn_stopped: Emitted when a user turn stops.
|
||||
- on_user_turn_stop_timeout: Emitted if no stop strategy triggers before timeout.
|
||||
|
||||
Example::
|
||||
|
||||
@processor.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(processor, strategy: BaseUserTurnStartStrategy):
|
||||
...
|
||||
|
||||
@processor.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(processor, strategy: BaseUserTurnStopStrategy):
|
||||
...
|
||||
|
||||
@processor.event_handler("on_user_turn_stop_timeout")
|
||||
async def on_user_turn_stop_timeout(processor):
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_turn_strategies: Optional[UserTurnStrategies] = None,
|
||||
user_turn_stop_timeout: float = 5.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the user turn processor.
|
||||
|
||||
Args:
|
||||
user_turn_strategies: Configured strategies for starting and stopping user turns.
|
||||
user_turn_stop_timeout: Timeout in seconds to automatically stop a user turn
|
||||
if no activity is detected.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._register_event_handler("on_user_turn_started")
|
||||
self._register_event_handler("on_user_turn_stopped")
|
||||
self._register_event_handler("on_user_turn_stop_timeout")
|
||||
|
||||
self._user_turn_controller = UserTurnController(
|
||||
user_turn_strategies=user_turn_strategies or UserTurnStrategies(),
|
||||
user_turn_stop_timeout=user_turn_stop_timeout,
|
||||
)
|
||||
self._user_turn_controller.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
self._user_turn_controller.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
self._user_turn_controller.add_event_handler(
|
||||
"on_user_turn_started", self._on_user_turn_started
|
||||
)
|
||||
self._user_turn_controller.add_event_handler(
|
||||
"on_user_turn_stopped", self._on_user_turn_stopped
|
||||
)
|
||||
self._user_turn_controller.add_event_handler(
|
||||
"on_user_turn_stop_timeout", self._on_user_turn_stop_timeout
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up processor resources."""
|
||||
await super().cleanup()
|
||||
await self._cleanup()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process an incoming frame to detect user turn start or stop.
|
||||
|
||||
The frame is passed to the user turn controlled which is responsible for
|
||||
deciding when a user turn starts or stops and emitting the corresponding
|
||||
events.
|
||||
|
||||
Args:
|
||||
frame: The frame to be processed.
|
||||
direction: The direction of the incoming frame.
|
||||
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
# Push StartFrame before start(), because we want StartFrame to be
|
||||
# processed by every processor before any other frame is processed.
|
||||
await self.push_frame(frame, direction)
|
||||
await self._start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self.push_frame(frame, direction)
|
||||
await self._stop(frame)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
await self._user_turn_controller.process_frame(frame)
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
await self._user_turn_controller.setup(self.task_manager)
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._cleanup()
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
await self._cleanup()
|
||||
|
||||
async def _cleanup(self):
|
||||
await self._user_turn_controller.cleanup()
|
||||
|
||||
async def _on_push_frame(
|
||||
self, controller, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _on_broadcast_frame(self, controller, frame_cls: Type[Frame], **kwargs):
|
||||
await self.broadcast_frame(frame_cls, **kwargs)
|
||||
|
||||
async def _on_user_turn_started(
|
||||
self,
|
||||
controller: UserTurnController,
|
||||
strategy: BaseUserTurnStartStrategy,
|
||||
params: UserTurnStartedParams,
|
||||
):
|
||||
logger.debug(f"{self}: User started speaking (user turn start strategy: {strategy})")
|
||||
|
||||
if params.enable_user_speaking_frames:
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
async def _on_user_turn_stopped(
|
||||
self,
|
||||
controller: UserTurnController,
|
||||
strategy: BaseUserTurnStopStrategy,
|
||||
params: UserTurnStoppedParams,
|
||||
):
|
||||
logger.debug(f"{self}: User stopped speaking (user turn stop strategy: {strategy})")
|
||||
|
||||
if params.enable_user_speaking_frames:
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
await self._call_event_handler("on_user_turn_stopped", strategy)
|
||||
|
||||
async def _on_user_turn_stop_timeout(self, controller):
|
||||
await self._call_event_handler("on_user_turn_stop_timeout")
|
||||
@@ -38,7 +38,7 @@ USER_TURN_STOP_TIMEOUT = 0.2
|
||||
TRANSCRIPTION_TIMEOUT = 0.1
|
||||
|
||||
|
||||
class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_llm_run(self):
|
||||
context = LLMContext()
|
||||
|
||||
@@ -141,6 +141,19 @@ class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
context = LLMContext()
|
||||
user_aggregator = LLMUserAggregator(context)
|
||||
|
||||
should_start = None
|
||||
should_stop = None
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(aggregator, strategy):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(aggregator, strategy):
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
pipeline = Pipeline([user_aggregator])
|
||||
|
||||
frames_to_send = [
|
||||
@@ -162,6 +175,8 @@ class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
|
||||
async def test_user_turn_stop_timeout_no_transcription(self):
|
||||
context = LLMContext()
|
||||
@@ -171,7 +186,19 @@ class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
params=LLMUserAggregatorParams(user_turn_stop_timeout=USER_TURN_STOP_TIMEOUT),
|
||||
)
|
||||
|
||||
timeout = False
|
||||
should_start = None
|
||||
should_stop = None
|
||||
timeout = None
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(aggregator, strategy):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(aggregator, strategy):
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stop_timeout")
|
||||
async def on_user_turn_stop_timeout(aggregator):
|
||||
@@ -190,6 +217,8 @@ class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send=frames_to_send,
|
||||
)
|
||||
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
self.assertTrue(timeout)
|
||||
|
||||
async def test_user_turn_stop_timeout_transcription(self):
|
||||
@@ -205,13 +234,19 @@ class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
),
|
||||
)
|
||||
|
||||
timeout = False
|
||||
bot_turn = False
|
||||
should_start = None
|
||||
should_stop = None
|
||||
timeout = None
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(aggregator, strategy):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(aggregator, strategy):
|
||||
nonlocal bot_turn
|
||||
bot_turn = True
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stop_timeout")
|
||||
async def on_user_turn_stop_timeout(aggregator):
|
||||
@@ -234,7 +269,8 @@ class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
|
||||
# The transcription strategy should kick-in before the user turn end timeout.
|
||||
self.assertTrue(bot_turn)
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
self.assertFalse(timeout)
|
||||
|
||||
async def test_user_mute_strategies(self):
|
||||
|
||||
98
tests/test_user_turn_controller.py
Normal file
98
tests/test_user_turn_controller.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.user_turn_controller import UserTurnController
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
USER_TURN_STOP_TIMEOUT = 0.2
|
||||
|
||||
|
||||
class TestUserTurnController(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.task_manager = TaskManager()
|
||||
self.task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
|
||||
async def test_default_user_turn_strategies(self):
|
||||
controller = UserTurnController(user_turn_strategies=UserTurnStrategies())
|
||||
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
should_start = None
|
||||
should_stop = None
|
||||
|
||||
@controller.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(controller, strategy, params):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@controller.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(controller, strategy, params):
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
await controller.process_frame(VADUserStartedSpeakingFrame())
|
||||
self.assertTrue(should_start)
|
||||
self.assertFalse(should_stop)
|
||||
|
||||
await controller.process_frame(
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now")
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
self.assertFalse(should_stop)
|
||||
|
||||
await controller.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
|
||||
async def test_user_turn_stop_timeout_no_transcription(self):
|
||||
controller = UserTurnController(
|
||||
user_turn_strategies=UserTurnStrategies(),
|
||||
user_turn_stop_timeout=USER_TURN_STOP_TIMEOUT,
|
||||
)
|
||||
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
should_start = None
|
||||
should_stop = None
|
||||
timeout = None
|
||||
|
||||
@controller.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(controller, strategy, params):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@controller.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(controller, strategy, params):
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
@controller.event_handler("on_user_turn_stop_timeout")
|
||||
async def on_user_turn_stop_timeout(controller):
|
||||
nonlocal timeout
|
||||
timeout = True
|
||||
|
||||
await controller.process_frame(VADUserStartedSpeakingFrame())
|
||||
self.assertTrue(should_start)
|
||||
self.assertFalse(should_stop)
|
||||
self.assertFalse(timeout)
|
||||
|
||||
await controller.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertTrue(should_start)
|
||||
self.assertFalse(should_stop)
|
||||
|
||||
await asyncio.sleep(USER_TURN_STOP_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
self.assertTrue(timeout)
|
||||
154
tests/test_user_turn_processor.py
Normal file
154
tests/test_user_turn_processor.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.turns.user_stop import TranscriptionUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_processor import UserTurnProcessor
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
USER_TURN_STOP_TIMEOUT = 0.2
|
||||
TRANSCRIPTION_TIMEOUT = 0.1
|
||||
|
||||
|
||||
class TestUserTurnProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_default_user_turn_strategies(self):
|
||||
user_turn_processor = UserTurnProcessor(user_turn_strategies=UserTurnStrategies())
|
||||
|
||||
should_start = None
|
||||
should_stop = None
|
||||
|
||||
@user_turn_processor.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(processor, strategy):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@user_turn_processor.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(processor, strategy):
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
pipeline = Pipeline([user_turn_processor])
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
|
||||
SleepFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
|
||||
async def test_user_turn_stop_timeout_no_transcription(self):
|
||||
user_turn_processor = UserTurnProcessor(
|
||||
user_turn_strategies=UserTurnStrategies(),
|
||||
user_turn_stop_timeout=USER_TURN_STOP_TIMEOUT,
|
||||
)
|
||||
|
||||
should_start = None
|
||||
should_stop = None
|
||||
timeout = None
|
||||
|
||||
@user_turn_processor.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(processor, strategy):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@user_turn_processor.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(processor, strategy):
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
@user_turn_processor.event_handler("on_user_turn_stop_timeout")
|
||||
async def on_user_turn_stop_timeout(processor):
|
||||
nonlocal timeout
|
||||
timeout = True
|
||||
|
||||
pipeline = Pipeline([user_turn_processor])
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=USER_TURN_STOP_TIMEOUT + 0.1),
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
)
|
||||
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
self.assertTrue(timeout)
|
||||
|
||||
async def test_user_turn_stop_timeout_transcription(self):
|
||||
user_turn_processor = UserTurnProcessor(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TranscriptionUserTurnStopStrategy(timeout=TRANSCRIPTION_TIMEOUT)],
|
||||
),
|
||||
user_turn_stop_timeout=USER_TURN_STOP_TIMEOUT,
|
||||
)
|
||||
|
||||
should_start = None
|
||||
should_stop = None
|
||||
timeout = None
|
||||
|
||||
@user_turn_processor.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(processor, strategy):
|
||||
nonlocal should_start
|
||||
should_start = True
|
||||
|
||||
@user_turn_processor.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(processor, strategy):
|
||||
nonlocal should_stop
|
||||
should_stop = True
|
||||
|
||||
@user_turn_processor.event_handler("on_user_turn_stop_timeout")
|
||||
async def on_user_turn_stop_timeout(processor):
|
||||
nonlocal timeout
|
||||
timeout = True
|
||||
|
||||
pipeline = Pipeline([user_turn_processor])
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=USER_TURN_STOP_TIMEOUT - 0.1),
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
|
||||
SleepFrame(sleep=USER_TURN_STOP_TIMEOUT - 0.1),
|
||||
SleepFrame(sleep=TRANSCRIPTION_TIMEOUT),
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
)
|
||||
|
||||
# The transcription strategy should kick-in before the user turn end timeout.
|
||||
self.assertTrue(should_start)
|
||||
self.assertTrue(should_stop)
|
||||
self.assertFalse(timeout)
|
||||
Reference in New Issue
Block a user