Handle starting muted, add tests

This commit is contained in:
Mark Backman
2025-02-12 18:57:17 -05:00
parent ee93e2a2b1
commit 1e8a86de63
2 changed files with 226 additions and 1 deletions

View File

@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
STTMuteFrame,
@@ -108,7 +109,7 @@ class STTMuteFilter(FrameProcessor):
self._first_speech_handled = False
self._bot_is_speaking = False
self._function_call_in_progress = False
self._is_muted = STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE in self._config.strategies
self._is_muted = False # Initialize as unmuted, will set state on StartFrame if needed
@property
def is_muted(self) -> bool:
@@ -155,6 +156,13 @@ class STTMuteFilter(FrameProcessor):
"""Processes incoming frames and manages muting state."""
await super().process_frame(frame, direction)
# Handle initial state on StartFrame
if isinstance(frame, StartFrame):
# Check if we should start muted
should_mute = await self._should_mute()
if should_mute:
await self._handle_mute_state(True)
# First determine if we need to change mute state
should_mute = None

View File

@@ -0,0 +1,217 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
STTMuteFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
from pipecat.tests.utils import run_test
class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
async def test_first_speech_strategy(self):
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FIRST_SPEECH}))
frames_to_send = [
BotStartedSpeakingFrame(), # First bot speech starts
UserStartedSpeakingFrame(), # Should be suppressed
UserStoppedSpeakingFrame(), # Should be suppressed
BotStoppedSpeakingFrame(), # First bot speech ends
BotStartedSpeakingFrame(), # Second bot speech
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStoppedSpeakingFrame(),
]
expected_returned_frames = [
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
BotStartedSpeakingFrame,
UserStartedSpeakingFrame, # Now passes through
UserStoppedSpeakingFrame, # Now passes through
BotStoppedSpeakingFrame,
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
async def test_always_strategy(self):
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
frames_to_send = [
BotStartedSpeakingFrame(), # First speech starts
UserStartedSpeakingFrame(), # Should be suppressed
UserStoppedSpeakingFrame(), # Should be suppressed
BotStoppedSpeakingFrame(), # First speech ends
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStartedSpeakingFrame(), # Second speech starts
UserStartedSpeakingFrame(), # Should be suppressed again
UserStoppedSpeakingFrame(), # Should be suppressed again
BotStoppedSpeakingFrame(), # Second speech ends
]
expected_returned_frames = [
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
# async def test_function_call_strategy(self):
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
# frames_to_send = [
# UserStartedSpeakingFrame(), # Should pass through initially
# UserStoppedSpeakingFrame(),
# FunctionCallInProgressFrame(
# function_name="get_weather",
# tool_call_id="call_123",
# arguments='{"location": "San Francisco"}',
# ), # Start function call
# UserStartedSpeakingFrame(), # Should be suppressed
# UserStoppedSpeakingFrame(), # Should be suppressed
# FunctionCallResultFrame(
# function_name="get_weather",
# tool_call_id="call_123",
# arguments='{"location": "San Francisco"}',
# result={"temperature": 22},
# ), # End function call
# UserStartedSpeakingFrame(), # Should pass through again
# UserStoppedSpeakingFrame(),
# ]
# expected_returned_frames = [
# UserStartedSpeakingFrame,
# UserStoppedSpeakingFrame,
# FunctionCallInProgressFrame,
# STTMuteFrame, # mute=True
# FunctionCallResultFrame,
# STTMuteFrame, # mute=False
# UserStartedSpeakingFrame,
# UserStoppedSpeakingFrame,
# ]
# await run_test(
# filter,
# frames_to_send=frames_to_send,
# expected_down_frames=expected_returned_frames,
# )
async def test_mute_until_first_bot_complete_strategy(self):
filter = STTMuteFilter(
config=STTMuteConfig(strategies={STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE})
)
frames_to_send = [
UserStartedSpeakingFrame(), # Should be suppressed (starts muted)
UserStoppedSpeakingFrame(), # Should be suppressed
BotStartedSpeakingFrame(), # First bot speech
UserStartedSpeakingFrame(), # Should be suppressed
UserStoppedSpeakingFrame(), # Should be suppressed
BotStoppedSpeakingFrame(), # First speech ends, unmutes
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStartedSpeakingFrame(), # Second speech
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStoppedSpeakingFrame(),
]
expected_returned_frames = [
STTMuteFrame, # mute=True after first speech
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False after first speech
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStartedSpeakingFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStoppedSpeakingFrame,
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
async def test_incompatible_strategies(self):
with self.assertRaises(ValueError):
STTMuteFilter(
config=STTMuteConfig(
strategies={
STTMuteStrategy.FIRST_SPEECH,
STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
}
)
)
async def test_custom_strategy(self):
async def custom_mute_logic(processor: STTMuteFilter) -> bool:
return processor._bot_is_speaking
filter = STTMuteFilter(
config=STTMuteConfig(
strategies={STTMuteStrategy.CUSTOM},
should_mute_callback=custom_mute_logic,
)
)
frames_to_send = [
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStartedSpeakingFrame(), # Bot starts speaking
UserStartedSpeakingFrame(), # Should be suppressed
UserStoppedSpeakingFrame(), # Should be suppressed
BotStoppedSpeakingFrame(), # Bot stops speaking
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
]
expected_returned_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)