Merge pull request #1214 from pipecat-ai/mb/stt-mute-tests

Improve STTMuteFilter, add tests
This commit is contained in:
Mark Backman
2025-02-13 09:50:43 -05:00
committed by GitHub
4 changed files with 250 additions and 22 deletions

View File

@@ -618,6 +618,13 @@ class FunctionCallInProgressFrame(SystemFrame):
arguments: str
@dataclass
class STTMuteFrame(SystemFrame):
"""System frame to mute/unmute the STT service."""
mute: bool
@dataclass
class TransportMessageUrgentFrame(SystemFrame):
message: Any
@@ -752,13 +759,6 @@ class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame):
pass
@dataclass
class STTMuteFrame(ControlFrame):
"""Control frame to mute/unmute the STT service."""
mute: bool
@dataclass
class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
pass

View File

@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
STTMuteFrame,
@@ -108,7 +109,7 @@ class STTMuteFilter(FrameProcessor):
self._first_speech_handled = False
self._bot_is_speaking = False
self._function_call_in_progress = False
self._is_muted = STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE in self._config.strategies
self._is_muted = False # Initialize as unmuted, will set state on StartFrame if needed
@property
def is_muted(self) -> bool:
@@ -155,24 +156,28 @@ class STTMuteFilter(FrameProcessor):
"""Processes incoming frames and manages muting state."""
await super().process_frame(frame, direction)
# Handle function call state changes
if isinstance(frame, FunctionCallInProgressFrame):
# Determine if we need to change mute state based on frame type
should_mute = None
# Process frames to determine mute state
if isinstance(frame, StartFrame):
should_mute = await self._should_mute()
elif isinstance(frame, FunctionCallInProgressFrame):
self._function_call_in_progress = True
await self._handle_mute_state(await self._should_mute())
should_mute = await self._should_mute()
elif isinstance(frame, FunctionCallResultFrame):
self._function_call_in_progress = False
await self._handle_mute_state(await self._should_mute())
# Handle bot speaking state changes
should_mute = await self._should_mute()
elif isinstance(frame, BotStartedSpeakingFrame):
self._bot_is_speaking = True
await self._handle_mute_state(await self._should_mute())
should_mute = await self._should_mute()
elif isinstance(frame, BotStoppedSpeakingFrame):
self._bot_is_speaking = False
if not self._first_speech_handled:
self._first_speech_handled = True
await self._handle_mute_state(await self._should_mute())
should_mute = await self._should_mute()
# Handle frame propagation
# Then push the original frame
if isinstance(
frame,
(
@@ -190,3 +195,7 @@ class STTMuteFilter(FrameProcessor):
else:
# Pass all other frames through
await self.push_frame(frame, direction)
# Finally handle mute state change if needed
if should_mute is not None and should_mute != self.is_muted:
await self._handle_mute_state(should_mute)

View File

@@ -525,9 +525,13 @@ class STTService(AIService):
else:
logger.warning(f"Unknown setting for STT service: {key}")
async def process_audio_frame(self, frame: AudioRawFrame):
if not self._muted:
await self.process_generator(self.run_stt(frame.audio))
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
if self._muted:
return
await self.process_generator(self.run_stt(frame.audio))
if self._audio_passthrough:
await self.push_frame(frame, direction)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Processes a frame of audio data, either buffering or transcribing it."""
@@ -537,9 +541,7 @@ class STTService(AIService):
# In this service we accumulate audio internally and at the end we
# push a TextFrame. We also push audio downstream in case someone
# else needs it.
await self.process_audio_frame(frame)
if self._audio_passthrough:
await self.push_frame(frame, direction)
await self.process_audio_frame(frame, direction)
elif isinstance(frame, STTUpdateSettingsFrame):
await self._update_settings(frame.settings)
elif isinstance(frame, STTMuteFrame):

View File

@@ -0,0 +1,217 @@
#
# Copyright (c) 2024-2025 Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import unittest
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
STTMuteFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
from pipecat.tests.utils import run_test
class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
async def test_first_speech_strategy(self):
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FIRST_SPEECH}))
frames_to_send = [
BotStartedSpeakingFrame(), # First bot speech starts
UserStartedSpeakingFrame(), # Should be suppressed
UserStoppedSpeakingFrame(), # Should be suppressed
BotStoppedSpeakingFrame(), # First bot speech ends
BotStartedSpeakingFrame(), # Second bot speech
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStoppedSpeakingFrame(),
]
expected_returned_frames = [
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
BotStartedSpeakingFrame,
UserStartedSpeakingFrame, # Now passes through
UserStoppedSpeakingFrame, # Now passes through
BotStoppedSpeakingFrame,
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
async def test_always_strategy(self):
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
frames_to_send = [
BotStartedSpeakingFrame(), # First speech starts
UserStartedSpeakingFrame(), # Should be suppressed
UserStoppedSpeakingFrame(), # Should be suppressed
BotStoppedSpeakingFrame(), # First speech ends
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStartedSpeakingFrame(), # Second speech starts
UserStartedSpeakingFrame(), # Should be suppressed again
UserStoppedSpeakingFrame(), # Should be suppressed again
BotStoppedSpeakingFrame(), # Second speech ends
]
expected_returned_frames = [
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
# async def test_function_call_strategy(self):
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
# frames_to_send = [
# UserStartedSpeakingFrame(), # Should pass through initially
# UserStoppedSpeakingFrame(),
# FunctionCallInProgressFrame(
# function_name="get_weather",
# tool_call_id="call_123",
# arguments='{"location": "San Francisco"}',
# ), # Start function call
# UserStartedSpeakingFrame(), # Should be suppressed
# UserStoppedSpeakingFrame(), # Should be suppressed
# FunctionCallResultFrame(
# function_name="get_weather",
# tool_call_id="call_123",
# arguments='{"location": "San Francisco"}',
# result={"temperature": 22},
# ), # End function call
# UserStartedSpeakingFrame(), # Should pass through again
# UserStoppedSpeakingFrame(),
# ]
# expected_returned_frames = [
# UserStartedSpeakingFrame,
# UserStoppedSpeakingFrame,
# FunctionCallInProgressFrame,
# STTMuteFrame, # mute=True
# FunctionCallResultFrame,
# STTMuteFrame, # mute=False
# UserStartedSpeakingFrame,
# UserStoppedSpeakingFrame,
# ]
# await run_test(
# filter,
# frames_to_send=frames_to_send,
# expected_down_frames=expected_returned_frames,
# )
async def test_mute_until_first_bot_complete_strategy(self):
filter = STTMuteFilter(
config=STTMuteConfig(strategies={STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE})
)
frames_to_send = [
UserStartedSpeakingFrame(), # Should be suppressed (starts muted)
UserStoppedSpeakingFrame(), # Should be suppressed
BotStartedSpeakingFrame(), # First bot speech
UserStartedSpeakingFrame(), # Should be suppressed
UserStoppedSpeakingFrame(), # Should be suppressed
BotStoppedSpeakingFrame(), # First speech ends, unmutes
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStartedSpeakingFrame(), # Second speech
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStoppedSpeakingFrame(),
]
expected_returned_frames = [
STTMuteFrame, # mute=True after first speech
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False after first speech
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStartedSpeakingFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStoppedSpeakingFrame,
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
async def test_incompatible_strategies(self):
with self.assertRaises(ValueError):
STTMuteFilter(
config=STTMuteConfig(
strategies={
STTMuteStrategy.FIRST_SPEECH,
STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
}
)
)
async def test_custom_strategy(self):
async def custom_mute_logic(processor: STTMuteFilter) -> bool:
return processor._bot_is_speaking
filter = STTMuteFilter(
config=STTMuteConfig(
strategies={STTMuteStrategy.CUSTOM},
should_mute_callback=custom_mute_logic,
)
)
frames_to_send = [
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
BotStartedSpeakingFrame(), # Bot starts speaking
UserStartedSpeakingFrame(), # Should be suppressed
UserStoppedSpeakingFrame(), # Should be suppressed
BotStoppedSpeakingFrame(), # Bot stops speaking
UserStartedSpeakingFrame(), # Should pass through
UserStoppedSpeakingFrame(), # Should pass through
]
expected_returned_frames = [
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)