diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py
index 001a4116f..09d0a93b0 100644
--- a/src/pipecat/frames/frames.py
+++ b/src/pipecat/frames/frames.py
@@ -618,6 +618,13 @@ class FunctionCallInProgressFrame(SystemFrame):
     arguments: str
 
 
+@dataclass
+class STTMuteFrame(SystemFrame):
+    """System frame to mute/unmute the STT service."""
+
+    mute: bool
+
+
 @dataclass
 class TransportMessageUrgentFrame(SystemFrame):
     message: Any
@@ -752,13 +759,6 @@ class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame):
     pass
 
 
-@dataclass
-class STTMuteFrame(ControlFrame):
-    """Control frame to mute/unmute the STT service."""
-
-    mute: bool
-
-
 @dataclass
 class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
     pass
diff --git a/src/pipecat/processors/filters/stt_mute_filter.py b/src/pipecat/processors/filters/stt_mute_filter.py
index 3ea0cae2b..fb1753263 100644
--- a/src/pipecat/processors/filters/stt_mute_filter.py
+++ b/src/pipecat/processors/filters/stt_mute_filter.py
@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
     Frame,
     FunctionCallInProgressFrame,
     FunctionCallResultFrame,
+    StartFrame,
     StartInterruptionFrame,
     StopInterruptionFrame,
     STTMuteFrame,
@@ -108,7 +109,7 @@ class STTMuteFilter(FrameProcessor):
         self._first_speech_handled = False
         self._bot_is_speaking = False
         self._function_call_in_progress = False
-        self._is_muted = STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE in self._config.strategies
+        self._is_muted = False  # Initialize as unmuted, will set state on StartFrame if needed
 
     @property
     def is_muted(self) -> bool:
@@ -155,24 +156,28 @@ class STTMuteFilter(FrameProcessor):
         """Processes incoming frames and manages muting state."""
         await super().process_frame(frame, direction)
 
-        # Handle function call state changes
-        if isinstance(frame, FunctionCallInProgressFrame):
+        # Determine if we need to change mute state based on frame type
+        should_mute = None
+
+        # Process frames to determine mute state
+        if isinstance(frame, StartFrame):
+            should_mute = await self._should_mute()
+        elif isinstance(frame, FunctionCallInProgressFrame):
             self._function_call_in_progress = True
-            await self._handle_mute_state(await self._should_mute())
+            should_mute = await self._should_mute()
         elif isinstance(frame, FunctionCallResultFrame):
             self._function_call_in_progress = False
-            await self._handle_mute_state(await self._should_mute())
-        # Handle bot speaking state changes
+            should_mute = await self._should_mute()
         elif isinstance(frame, BotStartedSpeakingFrame):
             self._bot_is_speaking = True
-            await self._handle_mute_state(await self._should_mute())
+            should_mute = await self._should_mute()
         elif isinstance(frame, BotStoppedSpeakingFrame):
             self._bot_is_speaking = False
             if not self._first_speech_handled:
                 self._first_speech_handled = True
-            await self._handle_mute_state(await self._should_mute())
+            should_mute = await self._should_mute()
 
-        # Handle frame propagation
+        # Then push the original frame
         if isinstance(
             frame,
             (
@@ -190,3 +195,7 @@ class STTMuteFilter(FrameProcessor):
         else:
             # Pass all other frames through
             await self.push_frame(frame, direction)
+
+        # Finally apply the mute state change, after the frame was pushed and only if it differs
+        if should_mute is not None and should_mute != self.is_muted:
+            await self._handle_mute_state(should_mute)
diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py
index 50be621a4..b8d223176 100644
--- a/src/pipecat/services/ai_services.py
+++ b/src/pipecat/services/ai_services.py
@@ -525,9 +525,13 @@ class STTService(AIService):
             else:
                 logger.warning(f"Unknown setting for STT service: {key}")
 
-    async def process_audio_frame(self, frame: AudioRawFrame):
-        if not self._muted:
-            await self.process_generator(self.run_stt(frame.audio))
+    async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
+        if self._muted:
+            return  # NOTE(review): also skips audio passthrough while muted; confirm intent
+
+        await self.process_generator(self.run_stt(frame.audio))
+        if self._audio_passthrough:
+            await self.push_frame(frame, direction)
 
     async def process_frame(self, frame: Frame, direction: FrameDirection):
         """Processes a frame of audio data, either buffering or transcribing it."""
@@ -537,9 +541,7 @@ class STTService(AIService):
             # In this service we accumulate audio internally and at the end we
             # push a TextFrame. We also push audio downstream in case someone
             # else needs it.
-            await self.process_audio_frame(frame)
-            if self._audio_passthrough:
-                await self.push_frame(frame, direction)
+            await self.process_audio_frame(frame, direction)
         elif isinstance(frame, STTUpdateSettingsFrame):
             await self._update_settings(frame.settings)
         elif isinstance(frame, STTMuteFrame):
diff --git a/tests/test_stt_mute_filter.py b/tests/test_stt_mute_filter.py
new file mode 100644
index 000000000..62bdd03c9
--- /dev/null
+++ b/tests/test_stt_mute_filter.py
@@ -0,0 +1,217 @@
+#
+# Copyright (c) 2024-2025 Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import unittest
+
+from pipecat.frames.frames import (
+    BotStartedSpeakingFrame,
+    BotStoppedSpeakingFrame,
+    FunctionCallInProgressFrame,
+    FunctionCallResultFrame,
+    STTMuteFrame,
+    UserStartedSpeakingFrame,
+    UserStoppedSpeakingFrame,
+)
+from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
+from pipecat.tests.utils import run_test
+
+
+class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
+    async def test_first_speech_strategy(self):
+        filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FIRST_SPEECH}))
+
+        frames_to_send = [
+            BotStartedSpeakingFrame(),  # First bot speech starts
+            UserStartedSpeakingFrame(),  # Should be suppressed
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStoppedSpeakingFrame(),  # First bot speech ends
+            BotStartedSpeakingFrame(),  # Second bot speech
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStoppedSpeakingFrame(),
+        ]
+
+        expected_returned_frames = [
+            BotStartedSpeakingFrame,
+            STTMuteFrame,  # mute=True
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False
+            BotStartedSpeakingFrame,
+            UserStartedSpeakingFrame,  # Now passes through
+            UserStoppedSpeakingFrame,  # Now passes through
+            BotStoppedSpeakingFrame,
+        ]
+
+        await run_test(
+            filter,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_returned_frames,
+        )
+
+    async def test_always_strategy(self):
+        filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
+
+        frames_to_send = [
+            BotStartedSpeakingFrame(),  # First speech starts
+            UserStartedSpeakingFrame(),  # Should be suppressed
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStoppedSpeakingFrame(),  # First speech ends
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStartedSpeakingFrame(),  # Second speech starts
+            UserStartedSpeakingFrame(),  # Should be suppressed again
+            UserStoppedSpeakingFrame(),  # Should be suppressed again
+            BotStoppedSpeakingFrame(),  # Second speech ends
+        ]
+
+        expected_returned_frames = [
+            BotStartedSpeakingFrame,
+            STTMuteFrame,  # mute=True
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            BotStartedSpeakingFrame,
+            STTMuteFrame,  # mute=True
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False
+        ]
+
+        await run_test(
+            filter,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_returned_frames,
+        )
+
+    # TODO: Revisit once we figure out how to test SystemFrames and DataFrames
+    # async def test_function_call_strategy(self):
+    #     filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
+
+    #     frames_to_send = [
+    #         UserStartedSpeakingFrame(),  # Should pass through initially
+    #         UserStoppedSpeakingFrame(),
+    #         FunctionCallInProgressFrame(
+    #             function_name="get_weather",
+    #             tool_call_id="call_123",
+    #             arguments='{"location": "San Francisco"}',
+    #         ),  # Start function call
+    #         UserStartedSpeakingFrame(),  # Should be suppressed
+    #         UserStoppedSpeakingFrame(),  # Should be suppressed
+    #         FunctionCallResultFrame(
+    #             function_name="get_weather",
+    #             tool_call_id="call_123",
+    #             arguments='{"location": "San Francisco"}',
+    #             result={"temperature": 22},
+    #         ),  # End function call
+    #         UserStartedSpeakingFrame(),  # Should pass through again
+    #         UserStoppedSpeakingFrame(),
+    #     ]
+
+    #     expected_returned_frames = [
+    #         UserStartedSpeakingFrame,
+    #         UserStoppedSpeakingFrame,
+    #         FunctionCallInProgressFrame,
+    #         STTMuteFrame,  # mute=True
+    #         FunctionCallResultFrame,
+    #         STTMuteFrame,  # mute=False
+    #         UserStartedSpeakingFrame,
+    #         UserStoppedSpeakingFrame,
+    #     ]
+
+    #     await run_test(
+    #         filter,
+    #         frames_to_send=frames_to_send,
+    #         expected_down_frames=expected_returned_frames,
+    #     )
+
+    async def test_mute_until_first_bot_complete_strategy(self):
+        filter = STTMuteFilter(
+            config=STTMuteConfig(strategies={STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE})
+        )
+
+        frames_to_send = [
+            UserStartedSpeakingFrame(),  # Should be suppressed (starts muted)
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStartedSpeakingFrame(),  # First bot speech
+            UserStartedSpeakingFrame(),  # Should be suppressed
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStoppedSpeakingFrame(),  # First speech ends, unmutes
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStartedSpeakingFrame(),  # Second speech
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStoppedSpeakingFrame(),
+        ]
+
+        expected_returned_frames = [
+            STTMuteFrame,  # mute=True at start (before any bot speech)
+            BotStartedSpeakingFrame,
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False once the first bot speech ends
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            BotStartedSpeakingFrame,
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            BotStoppedSpeakingFrame,
+        ]
+
+        await run_test(
+            filter,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_returned_frames,
+        )
+
+    async def test_incompatible_strategies(self):
+        with self.assertRaises(ValueError):
+            STTMuteFilter(
+                config=STTMuteConfig(
+                    strategies={
+                        STTMuteStrategy.FIRST_SPEECH,
+                        STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
+                    }
+                )
+            )
+
+    async def test_custom_strategy(self):
+        async def custom_mute_logic(processor: STTMuteFilter) -> bool:
+            return processor._bot_is_speaking
+
+        filter = STTMuteFilter(
+            config=STTMuteConfig(
+                strategies={STTMuteStrategy.CUSTOM},
+                should_mute_callback=custom_mute_logic,
+            )
+        )
+
+        frames_to_send = [
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+            BotStartedSpeakingFrame(),  # Bot starts speaking
+            UserStartedSpeakingFrame(),  # Should be suppressed
+            UserStoppedSpeakingFrame(),  # Should be suppressed
+            BotStoppedSpeakingFrame(),  # Bot stops speaking
+            UserStartedSpeakingFrame(),  # Should pass through
+            UserStoppedSpeakingFrame(),  # Should pass through
+        ]
+
+        expected_returned_frames = [
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+            BotStartedSpeakingFrame,
+            STTMuteFrame,  # mute=True
+            BotStoppedSpeakingFrame,
+            STTMuteFrame,  # mute=False
+            UserStartedSpeakingFrame,
+            UserStoppedSpeakingFrame,
+        ]
+
+        await run_test(
+            filter,
+            frames_to_send=frames_to_send,
+            expected_down_frames=expected_returned_frames,
+        )