Handle starting muted, add tests
This commit is contained in:
@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
STTMuteFrame,
|
||||
@@ -108,7 +109,7 @@ class STTMuteFilter(FrameProcessor):
|
||||
self._first_speech_handled = False
|
||||
self._bot_is_speaking = False
|
||||
self._function_call_in_progress = False
|
||||
self._is_muted = STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE in self._config.strategies
|
||||
self._is_muted = False # Initialize as unmuted, will set state on StartFrame if needed
|
||||
|
||||
@property
|
||||
def is_muted(self) -> bool:
|
||||
@@ -155,6 +156,13 @@ class STTMuteFilter(FrameProcessor):
|
||||
"""Processes incoming frames and manages muting state."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Handle initial state on StartFrame
|
||||
if isinstance(frame, StartFrame):
|
||||
# Check if we should start muted
|
||||
should_mute = await self._should_mute()
|
||||
if should_mute:
|
||||
await self._handle_mute_state(True)
|
||||
|
||||
# First determine if we need to change mute state
|
||||
should_mute = None
|
||||
|
||||
|
||||
217
tests/test_stt_mute_filter.py
Normal file
217
tests/test_stt_mute_filter.py
Normal file
@@ -0,0 +1,217 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
STTMuteFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_first_speech_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FIRST_SPEECH}))
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(), # First bot speech starts
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First bot speech ends
|
||||
BotStartedSpeakingFrame(), # Second bot speech
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame, # Now passes through
|
||||
UserStoppedSpeakingFrame, # Now passes through
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_always_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(), # First speech starts
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First speech ends
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Second speech starts
|
||||
UserStartedSpeakingFrame(), # Should be suppressed again
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed again
|
||||
BotStoppedSpeakingFrame(), # Second speech ends
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
|
||||
# async def test_function_call_strategy(self):
|
||||
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
|
||||
|
||||
# frames_to_send = [
|
||||
# UserStartedSpeakingFrame(), # Should pass through initially
|
||||
# UserStoppedSpeakingFrame(),
|
||||
# FunctionCallInProgressFrame(
|
||||
# function_name="get_weather",
|
||||
# tool_call_id="call_123",
|
||||
# arguments='{"location": "San Francisco"}',
|
||||
# ), # Start function call
|
||||
# UserStartedSpeakingFrame(), # Should be suppressed
|
||||
# UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
# FunctionCallResultFrame(
|
||||
# function_name="get_weather",
|
||||
# tool_call_id="call_123",
|
||||
# arguments='{"location": "San Francisco"}',
|
||||
# result={"temperature": 22},
|
||||
# ), # End function call
|
||||
# UserStartedSpeakingFrame(), # Should pass through again
|
||||
# UserStoppedSpeakingFrame(),
|
||||
# ]
|
||||
|
||||
# expected_returned_frames = [
|
||||
# UserStartedSpeakingFrame,
|
||||
# UserStoppedSpeakingFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# STTMuteFrame, # mute=True
|
||||
# FunctionCallResultFrame,
|
||||
# STTMuteFrame, # mute=False
|
||||
# UserStartedSpeakingFrame,
|
||||
# UserStoppedSpeakingFrame,
|
||||
# ]
|
||||
|
||||
# await run_test(
|
||||
# filter,
|
||||
# frames_to_send=frames_to_send,
|
||||
# expected_down_frames=expected_returned_frames,
|
||||
# )
|
||||
|
||||
async def test_mute_until_first_bot_complete_strategy(self):
|
||||
filter = STTMuteFilter(
|
||||
config=STTMuteConfig(strategies={STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE})
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(), # Should be suppressed (starts muted)
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStartedSpeakingFrame(), # First bot speech
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First speech ends, unmutes
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Second speech
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
STTMuteFrame, # mute=True on StartFrame (filter starts muted)
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False after first speech
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_incompatible_strategies(self):
|
||||
with self.assertRaises(ValueError):
|
||||
STTMuteFilter(
|
||||
config=STTMuteConfig(
|
||||
strategies={
|
||||
STTMuteStrategy.FIRST_SPEECH,
|
||||
STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def test_custom_strategy(self):
|
||||
async def custom_mute_logic(processor: STTMuteFilter) -> bool:
|
||||
return processor._bot_is_speaking
|
||||
|
||||
filter = STTMuteFilter(
|
||||
config=STTMuteConfig(
|
||||
strategies={STTMuteStrategy.CUSTOM},
|
||||
should_mute_callback=custom_mute_logic,
|
||||
)
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Bot starts speaking
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # Bot stops speaking
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
Reference in New Issue
Block a user