diff --git a/src/pipecat/processors/filters/stt_mute_filter.py b/src/pipecat/processors/filters/stt_mute_filter.py
index fcd492bb3..a5430a451 100644
--- a/src/pipecat/processors/filters/stt_mute_filter.py
+++ b/src/pipecat/processors/filters/stt_mute_filter.py
@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
     Frame,
     FunctionCallInProgressFrame,
     FunctionCallResultFrame,
+    StartFrame,
     StartInterruptionFrame,
     StopInterruptionFrame,
     STTMuteFrame,
@@ -108,7 +109,7 @@ class STTMuteFilter(FrameProcessor):
         self._first_speech_handled = False
         self._bot_is_speaking = False
         self._function_call_in_progress = False
-        self._is_muted = STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE in self._config.strategies
+        self._is_muted = False  # Initialize as unmuted, will set state on StartFrame if needed
 
     @property
     def is_muted(self) -> bool:
@@ -155,6 +156,13 @@ class STTMuteFilter(FrameProcessor):
         """Processes incoming frames and manages muting state."""
         await super().process_frame(frame, direction)
 
+        # Handle initial state on StartFrame
+        if isinstance(frame, StartFrame):
+            # Check if we should start muted
+            should_mute = await self._should_mute()
+            if should_mute:
+                await self._handle_mute_state(True)
+
         # First determine if we need to change mute state
         should_mute = None
 
diff --git a/tests/test_stt_mute_filter.py b/tests/test_stt_mute_filter.py
new file mode 100644
index 000000000..62bdd03c9
--- /dev/null
+++ b/tests/test_stt_mute_filter.py
@@ -0,0 +1,224 @@
+#
+# Copyright (c) 2024-2025 Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import unittest
+
+from pipecat.frames.frames import (
+    BotStartedSpeakingFrame,
+    BotStoppedSpeakingFrame,
+    FunctionCallInProgressFrame,
+    FunctionCallResultFrame,
+    STTMuteFrame,
+    UserStartedSpeakingFrame,
+    UserStoppedSpeakingFrame,
+)
+from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
+from pipecat.tests.utils import run_test
+
+
+class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
+    """Checks which frames STTMuteFilter forwards downstream under each mute strategy."""
+
+    async def test_first_speech_strategy(self):
+        """FIRST_SPEECH mutes only while the bot's first utterance is playing."""
+        stt_filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FIRST_SPEECH}))
+
+        frames_to_send = [
+            BotStartedSpeakingFrame(),  # First bot speech starts
+            UserStartedSpeakingFrame(),  # Should be suppressed
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStoppedSpeakingFrame(),  # First bot speech ends
+            BotStartedSpeakingFrame(),  # Second bot speech
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStoppedSpeakingFrame(),
+        ]
+
+        expected_returned_frames = [
+            BotStartedSpeakingFrame,
+            STTMuteFrame,  # mute=True
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False
+            BotStartedSpeakingFrame,
+            UserStartedSpeakingFrame,  # Now passes through
+            UserStoppedSpeakingFrame,  # Now passes through
+            BotStoppedSpeakingFrame,
+        ]
+
+        await run_test(
+            stt_filter,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_returned_frames,
+        )
+
+    async def test_always_strategy(self):
+        """ALWAYS mutes during every bot utterance and unmutes when each one ends."""
+        stt_filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
+
+        frames_to_send = [
+            BotStartedSpeakingFrame(),  # First speech starts
+            UserStartedSpeakingFrame(),  # Should be suppressed
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStoppedSpeakingFrame(),  # First speech ends
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStartedSpeakingFrame(),  # Second speech starts
+            UserStartedSpeakingFrame(),  # Should be suppressed again
+            UserStoppedSpeakingFrame(),  # Should be suppressed again
+            BotStoppedSpeakingFrame(),  # Second speech ends
+        ]
+
+        expected_returned_frames = [
+            BotStartedSpeakingFrame,
+            STTMuteFrame,  # mute=True
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            BotStartedSpeakingFrame,
+            STTMuteFrame,  # mute=True
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False
+        ]
+
+        await run_test(
+            stt_filter,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_returned_frames,
+        )
+
+    # TODO: Revisit once we figure out how to test SystemFrames and DataFrames
+    # async def test_function_call_strategy(self):
+    #     stt_filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
+
+    #     frames_to_send = [
+    #         UserStartedSpeakingFrame(),  # Should pass through initially
+    #         UserStoppedSpeakingFrame(),
+    #         FunctionCallInProgressFrame(
+    #             function_name="get_weather",
+    #             tool_call_id="call_123",
+    #             arguments='{"location": "San Francisco"}',
+    #         ),  # Start function call
+    #         UserStartedSpeakingFrame(),  # Should be suppressed
+    #         UserStoppedSpeakingFrame(),  # Should be suppressed
+    #         FunctionCallResultFrame(
+    #             function_name="get_weather",
+    #             tool_call_id="call_123",
+    #             arguments='{"location": "San Francisco"}',
+    #             result={"temperature": 22},
+    #         ),  # End function call
+    #         UserStartedSpeakingFrame(),  # Should pass through again
+    #         UserStoppedSpeakingFrame(),
+    #     ]
+
+    #     expected_returned_frames = [
+    #         UserStartedSpeakingFrame,
+    #         UserStoppedSpeakingFrame,
+    #         FunctionCallInProgressFrame,
+    #         STTMuteFrame,  # mute=True
+    #         FunctionCallResultFrame,
+    #         STTMuteFrame,  # mute=False
+    #         UserStartedSpeakingFrame,
+    #         UserStoppedSpeakingFrame,
+    #     ]
+
+    #     await run_test(
+    #         stt_filter,
+    #         frames_to_send=frames_to_send,
+    #         expected_down_frames=expected_returned_frames,
+    #     )
+
+    async def test_mute_until_first_bot_complete_strategy(self):
+        """MUTE_UNTIL_FIRST_BOT_COMPLETE starts muted and unmutes after the first bot speech."""
+        stt_filter = STTMuteFilter(
+            config=STTMuteConfig(strategies={STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE})
+        )
+
+        frames_to_send = [
+            UserStartedSpeakingFrame(),  # Should be suppressed (starts muted)
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStartedSpeakingFrame(),  # First bot speech
+            UserStartedSpeakingFrame(),  # Should be suppressed
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStoppedSpeakingFrame(),  # First speech ends, unmutes
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStartedSpeakingFrame(),  # Second speech
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStoppedSpeakingFrame(),
+        ]
+
+        expected_returned_frames = [
+            STTMuteFrame,  # mute=True at start, before any bot speech
+            BotStartedSpeakingFrame,
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False once the first bot speech completes
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            BotStartedSpeakingFrame,
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            BotStoppedSpeakingFrame,
+        ]
+
+        await run_test(
+            stt_filter,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_returned_frames,
+        )
+
+    async def test_incompatible_strategies(self):
+        """Combining FIRST_SPEECH with MUTE_UNTIL_FIRST_BOT_COMPLETE raises ValueError."""
+        with self.assertRaises(ValueError):
+            STTMuteFilter(
+                config=STTMuteConfig(
+                    strategies={
+                        STTMuteStrategy.FIRST_SPEECH,
+                        STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
+                    }
+                )
+            )
+
+    async def test_custom_strategy(self):
+        """CUSTOM defers the mute decision to the should_mute_callback coroutine."""
+        async def custom_mute_logic(processor: STTMuteFilter) -> bool:
+            return processor._bot_is_speaking
+
+        stt_filter = STTMuteFilter(
+            config=STTMuteConfig(
+                strategies={STTMuteStrategy.CUSTOM},
+                should_mute_callback=custom_mute_logic,
+            )
+        )
+
+        frames_to_send = [
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStartedSpeakingFrame(),  # Bot starts speaking
+            UserStartedSpeakingFrame(),  # Should be suppressed
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStoppedSpeakingFrame(),  # Bot stops speaking
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+        ]
+
+        expected_returned_frames = [
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            BotStartedSpeakingFrame,
+            STTMuteFrame,  # mute=True
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+        ]
+
+        await run_test(
+            stt_filter,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_returned_frames,
+        )