Merge pull request #3291 from pipecat-ai/aleix/llm-user-aggregator-timeout
LLMUserAggregator bot turn start strategies timeout fallback
This commit is contained in:
4
changelog/3291.added.md
Normal file
4
changelog/3291.added.md
Normal file
@@ -0,0 +1,4 @@
|
||||
- `LLMUserAggregator` now exposes the following events:
|
||||
- `on_user_turn_started`: triggered when a user turn starts
|
||||
- `on_bot_turn_started`: triggered when a user turn ends and a bot turn starts
|
||||
- `on_user_turn_end_timeout`: triggered when a user turn does not stop and times out
|
||||
@@ -51,6 +51,8 @@ from pipecat.frames.frames import (
|
||||
UserImageRawFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
@@ -75,9 +77,12 @@ class LLMUserAggregatorParams:
|
||||
interruption frames. This is enabled by default, but you may want
|
||||
to disable it if another component (e.g., an STT service) is already
|
||||
generating these frames.
|
||||
user_turn_end_timeout: Time in seconds to wait before considering the
|
||||
user's turn finished and starting the bot turn.
|
||||
"""
|
||||
|
||||
enable_user_speaking_frames: bool = True
|
||||
user_turn_end_timeout: float = 5.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -226,15 +231,20 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
- on_user_turn_started: Called when the user turn starts
|
||||
- on_bot_turn_started: Called when the user turn ends and it is now the bot’s turn
|
||||
- on_user_turn_end_timeout: Called when no bot turn start strategy triggers
|
||||
|
||||
Example::
|
||||
|
||||
@aggregator.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(aggregator, strategy):
|
||||
async def on_user_turn_started(aggregator, Optional[strategy]):
|
||||
...
|
||||
|
||||
@aggregator.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(aggregator, strategy):
|
||||
async def on_bot_turn_started(aggregator, Optional[strategy]):
|
||||
...
|
||||
|
||||
@aggregator.event_handler("on_user_turn_end_timeout")
|
||||
async def on_user_turn_end_timeout(aggregator):
|
||||
...
|
||||
|
||||
"""
|
||||
@@ -255,9 +265,15 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
"""
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._params = params or LLMUserAggregatorParams()
|
||||
self._user_speaking = False
|
||||
|
||||
self._vad_user_speaking = False
|
||||
|
||||
self._user_turn = False
|
||||
self._user_turn_end_timeout_event = asyncio.Event()
|
||||
self._user_turn_end_timeout_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._register_event_handler("on_user_turn_started")
|
||||
self._register_event_handler("on_user_turn_end_timeout")
|
||||
self._register_event_handler("on_bot_turn_started")
|
||||
|
||||
async def cleanup(self):
|
||||
@@ -299,6 +315,12 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._handle_vad_user_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
elif isinstance(frame, LLMRunFrame):
|
||||
@@ -335,6 +357,11 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
if not self._user_turn_end_timeout_task:
|
||||
self._user_turn_end_timeout_task = self.create_task(
|
||||
self._user_turn_end_timeout_task_handler()
|
||||
)
|
||||
|
||||
if self.turn_start_strategies and self.turn_start_strategies.user:
|
||||
for s in self.turn_start_strategies.user:
|
||||
await s.setup(self.task_manager)
|
||||
@@ -356,6 +383,10 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._cleanup()
|
||||
|
||||
async def _cleanup(self):
|
||||
if self._user_turn_end_timeout_task:
|
||||
await self.cancel_task(self._user_turn_end_timeout_task)
|
||||
self._user_turn_end_timeout_task = None
|
||||
|
||||
if self.turn_start_strategies and self.turn_start_strategies.user:
|
||||
for s in self.turn_start_strategies.user:
|
||||
await s.cleanup()
|
||||
@@ -406,6 +437,18 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
" )"
|
||||
)
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, frame: VADUserStartedSpeakingFrame):
|
||||
self._vad_user_speaking = True
|
||||
|
||||
# The user started talking, let's reset the user turn timeout.
|
||||
self._user_turn_end_timeout_event.set()
|
||||
|
||||
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
|
||||
self._vad_user_speaking = False
|
||||
|
||||
# The user stopped talking, let's reset the user turn timeout.
|
||||
self._user_turn_end_timeout_event.set()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
text = frame.text
|
||||
|
||||
@@ -413,6 +456,9 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
# We have creceived a transcription, let's reset the user turn timeout.
|
||||
self._user_turn_end_timeout_event.set()
|
||||
|
||||
# Transcriptions never include inter-part spaces (so far).
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
@@ -442,12 +488,13 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
):
|
||||
await self.broadcast_frame(frame_cls, **kwargs)
|
||||
|
||||
async def _trigger_user_turn_start(self, strategy: BaseUserTurnStartStrategy):
|
||||
async def _trigger_user_turn_start(self, strategy: Optional[BaseUserTurnStartStrategy]):
|
||||
# Prevent two consecutive user turn starts.
|
||||
if self._user_speaking:
|
||||
if self._user_turn:
|
||||
return
|
||||
|
||||
self._user_speaking = True
|
||||
self._user_turn = True
|
||||
self._user_turn_end_timeout_event.set()
|
||||
|
||||
# Reset all user turn start strategies to start fresh.
|
||||
if self.turn_start_strategies and self.turn_start_strategies.user:
|
||||
@@ -462,12 +509,13 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
async def _trigger_bot_turn_start(self, strategy: BaseBotTurnStartStrategy):
|
||||
async def _trigger_bot_turn_start(self, strategy: Optional[BaseBotTurnStartStrategy]):
|
||||
# Prevent two consecutive bot turn starts.
|
||||
if not self._user_speaking:
|
||||
if not self._user_turn:
|
||||
return
|
||||
|
||||
self._user_speaking = False
|
||||
self._user_turn = False
|
||||
self._user_turn_end_timeout_event.set()
|
||||
|
||||
# Reset all bot turn start strategies to start fresh.
|
||||
if self.turn_start_strategies and self.turn_start_strategies.bot:
|
||||
@@ -484,6 +532,19 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
# Always push context frame.
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _user_turn_end_timeout_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._user_turn_end_timeout_event.wait(),
|
||||
timeout=self._params.user_turn_end_timeout,
|
||||
)
|
||||
self._user_turn_end_timeout_event.clear()
|
||||
except asyncio.TimeoutError:
|
||||
if self._user_turn and not self._vad_user_speaking:
|
||||
await self._call_event_handler("on_user_turn_end_timeout")
|
||||
await self._trigger_bot_turn_start(None)
|
||||
|
||||
|
||||
class LLMAssistantAggregator(LLMContextAggregator):
|
||||
"""Assistant LLM aggregator that processes bot responses and function calls.
|
||||
|
||||
235
tests/test_context_aggregators_universal.py
Normal file
235
tests/test_context_aggregators_universal.py
Normal file
@@ -0,0 +1,235 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMRunFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMUserAggregator,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.turns.bot.transcription_bot_turn_start_strategy import (
|
||||
TranscriptionBotTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.turn_start_strategies import TurnStartStrategies
|
||||
|
||||
USER_TURN_END_TIMEOUT = 0.2
|
||||
TRANSCRIPTION_TIMEOUT = 0.1
|
||||
|
||||
|
||||
class TestUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_llm_run(self):
|
||||
context = LLMContext()
|
||||
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
frames_to_send = [LLMRunFrame()]
|
||||
expected_down_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_llm_messages_append(self):
|
||||
context = LLMContext()
|
||||
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
frames_to_send = [
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
]
|
||||
)
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_append_run(self):
|
||||
context = LLMContext()
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
frames_to_send = [
|
||||
LLMMessagesAppendFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
],
|
||||
run_llm=True,
|
||||
)
|
||||
]
|
||||
expected_down_frames = [LLMContextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_update(self):
|
||||
context = LLMContext()
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
frames_to_send = [
|
||||
LLMMessagesUpdateFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
]
|
||||
)
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_llm_messages_update_run(self):
|
||||
context = LLMContext()
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
frames_to_send = [
|
||||
LLMMessagesUpdateFrame(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi there!",
|
||||
}
|
||||
],
|
||||
run_llm=True,
|
||||
)
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
)
|
||||
assert context.messages[0]["content"] == "Hi there!"
|
||||
|
||||
async def test_default_turn_start_strategies(self):
|
||||
context = LLMContext()
|
||||
user_aggregator = LLMUserAggregator(context)
|
||||
|
||||
pipeline = Pipeline([user_aggregator])
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
|
||||
SleepFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
]
|
||||
expected_down_frames = [
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InterruptionFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
LLMContextFrame,
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
async def test_user_turn_end_timeout_no_transcription(self):
|
||||
context = LLMContext()
|
||||
|
||||
user_aggregator = LLMUserAggregator(
|
||||
context,
|
||||
params=LLMUserAggregatorParams(user_turn_end_timeout=USER_TURN_END_TIMEOUT),
|
||||
)
|
||||
|
||||
timeout = False
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_end_timeout")
|
||||
async def on_user_turn_end_timeout(aggregator):
|
||||
nonlocal timeout
|
||||
timeout = True
|
||||
|
||||
pipeline = Pipeline([user_aggregator])
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=USER_TURN_END_TIMEOUT + 0.1),
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
)
|
||||
|
||||
self.assertTrue(timeout)
|
||||
|
||||
async def test_user_turn_end_timeout_transcription(self):
|
||||
context = LLMContext()
|
||||
|
||||
user_aggregator = LLMUserAggregator(
|
||||
context,
|
||||
params=LLMUserAggregatorParams(user_turn_end_timeout=USER_TURN_END_TIMEOUT),
|
||||
)
|
||||
|
||||
timeout = False
|
||||
bot_turn = False
|
||||
|
||||
@user_aggregator.event_handler("on_bot_turn_started")
|
||||
async def on_bot_turn_started(aggregator, strategy):
|
||||
nonlocal bot_turn
|
||||
bot_turn = True
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_end_timeout")
|
||||
async def on_user_turn_end_timeout(aggregator):
|
||||
nonlocal timeout
|
||||
timeout = True
|
||||
|
||||
pipeline = Pipeline([user_aggregator])
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=USER_TURN_END_TIMEOUT - 0.1),
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
|
||||
SleepFrame(sleep=USER_TURN_END_TIMEOUT - 0.1),
|
||||
SleepFrame(sleep=TRANSCRIPTION_TIMEOUT),
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
pipeline_params=PipelineParams(
|
||||
turn_start_strategies=TurnStartStrategies(
|
||||
bot=[TranscriptionBotTurnStartStrategy(timeout=TRANSCRIPTION_TIMEOUT)],
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# The transcription strategy should kick-in before the user turn end timeout.
|
||||
self.assertTrue(bot_turn)
|
||||
self.assertFalse(timeout)
|
||||
Reference in New Issue
Block a user