Merge pull request #1214 from pipecat-ai/mb/stt-mute-tests
Improve STTMuteFilter, add tests
This commit is contained in:
@@ -618,6 +618,13 @@ class FunctionCallInProgressFrame(SystemFrame):
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTMuteFrame(SystemFrame):
|
||||
"""System frame to mute/unmute the STT service."""
|
||||
|
||||
mute: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransportMessageUrgentFrame(SystemFrame):
|
||||
message: Any
|
||||
@@ -752,13 +759,6 @@ class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTMuteFrame(ControlFrame):
|
||||
"""Control frame to mute/unmute the STT service."""
|
||||
|
||||
mute: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
STTMuteFrame,
|
||||
@@ -108,7 +109,7 @@ class STTMuteFilter(FrameProcessor):
|
||||
self._first_speech_handled = False
|
||||
self._bot_is_speaking = False
|
||||
self._function_call_in_progress = False
|
||||
self._is_muted = STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE in self._config.strategies
|
||||
self._is_muted = False # Initialize as unmuted, will set state on StartFrame if needed
|
||||
|
||||
@property
|
||||
def is_muted(self) -> bool:
|
||||
@@ -155,24 +156,28 @@ class STTMuteFilter(FrameProcessor):
|
||||
"""Processes incoming frames and manages muting state."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Handle function call state changes
|
||||
if isinstance(frame, FunctionCallInProgressFrame):
|
||||
# Determine if we need to change mute state based on frame type
|
||||
should_mute = None
|
||||
|
||||
# Process frames to determine mute state
|
||||
if isinstance(frame, StartFrame):
|
||||
should_mute = await self._should_mute()
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
self._function_call_in_progress = True
|
||||
await self._handle_mute_state(await self._should_mute())
|
||||
should_mute = await self._should_mute()
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
self._function_call_in_progress = False
|
||||
await self._handle_mute_state(await self._should_mute())
|
||||
# Handle bot speaking state changes
|
||||
should_mute = await self._should_mute()
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_is_speaking = True
|
||||
await self._handle_mute_state(await self._should_mute())
|
||||
should_mute = await self._should_mute()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_is_speaking = False
|
||||
if not self._first_speech_handled:
|
||||
self._first_speech_handled = True
|
||||
await self._handle_mute_state(await self._should_mute())
|
||||
should_mute = await self._should_mute()
|
||||
|
||||
# Handle frame propagation
|
||||
# Then push the original frame
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
@@ -190,3 +195,7 @@ class STTMuteFilter(FrameProcessor):
|
||||
else:
|
||||
# Pass all other frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# Finally handle mute state change if needed
|
||||
if should_mute is not None and should_mute != self.is_muted:
|
||||
await self._handle_mute_state(should_mute)
|
||||
|
||||
@@ -525,9 +525,13 @@ class STTService(AIService):
|
||||
else:
|
||||
logger.warning(f"Unknown setting for STT service: {key}")
|
||||
|
||||
async def process_audio_frame(self, frame: AudioRawFrame):
|
||||
if not self._muted:
|
||||
await self.process_generator(self.run_stt(frame.audio))
|
||||
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
|
||||
if self._muted:
|
||||
return
|
||||
|
||||
await self.process_generator(self.run_stt(frame.audio))
|
||||
if self._audio_passthrough:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Processes a frame of audio data, either buffering or transcribing it."""
|
||||
@@ -537,9 +541,7 @@ class STTService(AIService):
|
||||
# In this service we accumulate audio internally and at the end we
|
||||
# push a TextFrame. We also push audio downstream in case someone
|
||||
# else needs it.
|
||||
await self.process_audio_frame(frame)
|
||||
if self._audio_passthrough:
|
||||
await self.push_frame(frame, direction)
|
||||
await self.process_audio_frame(frame, direction)
|
||||
elif isinstance(frame, STTUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, STTMuteFrame):
|
||||
|
||||
217
tests/test_stt_mute_filter.py
Normal file
217
tests/test_stt_mute_filter.py
Normal file
@@ -0,0 +1,217 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
STTMuteFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
|
||||
from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_first_speech_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FIRST_SPEECH}))
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(), # First bot speech starts
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First bot speech ends
|
||||
BotStartedSpeakingFrame(), # Second bot speech
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame, # Now passes through
|
||||
UserStoppedSpeakingFrame, # Now passes through
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_always_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(), # First speech starts
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First speech ends
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Second speech starts
|
||||
UserStartedSpeakingFrame(), # Should be suppressed again
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed again
|
||||
BotStoppedSpeakingFrame(), # Second speech ends
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
|
||||
# async def test_function_call_strategy(self):
|
||||
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
|
||||
|
||||
# frames_to_send = [
|
||||
# UserStartedSpeakingFrame(), # Should pass through initially
|
||||
# UserStoppedSpeakingFrame(),
|
||||
# FunctionCallInProgressFrame(
|
||||
# function_name="get_weather",
|
||||
# tool_call_id="call_123",
|
||||
# arguments='{"location": "San Francisco"}',
|
||||
# ), # Start function call
|
||||
# UserStartedSpeakingFrame(), # Should be suppressed
|
||||
# UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
# FunctionCallResultFrame(
|
||||
# function_name="get_weather",
|
||||
# tool_call_id="call_123",
|
||||
# arguments='{"location": "San Francisco"}',
|
||||
# result={"temperature": 22},
|
||||
# ), # End function call
|
||||
# UserStartedSpeakingFrame(), # Should pass through again
|
||||
# UserStoppedSpeakingFrame(),
|
||||
# ]
|
||||
|
||||
# expected_returned_frames = [
|
||||
# UserStartedSpeakingFrame,
|
||||
# UserStoppedSpeakingFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# STTMuteFrame, # mute=True
|
||||
# FunctionCallResultFrame,
|
||||
# STTMuteFrame, # mute=False
|
||||
# UserStartedSpeakingFrame,
|
||||
# UserStoppedSpeakingFrame,
|
||||
# ]
|
||||
|
||||
# await run_test(
|
||||
# filter,
|
||||
# frames_to_send=frames_to_send,
|
||||
# expected_down_frames=expected_returned_frames,
|
||||
# )
|
||||
|
||||
async def test_mute_until_first_bot_complete_strategy(self):
|
||||
filter = STTMuteFilter(
|
||||
config=STTMuteConfig(strategies={STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE})
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(), # Should be suppressed (starts muted)
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStartedSpeakingFrame(), # First bot speech
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First speech ends, unmutes
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Second speech
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
STTMuteFrame, # mute=True after first speech
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False after first speech
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_incompatible_strategies(self):
|
||||
with self.assertRaises(ValueError):
|
||||
STTMuteFilter(
|
||||
config=STTMuteConfig(
|
||||
strategies={
|
||||
STTMuteStrategy.FIRST_SPEECH,
|
||||
STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def test_custom_strategy(self):
|
||||
async def custom_mute_logic(processor: STTMuteFilter) -> bool:
|
||||
return processor._bot_is_speaking
|
||||
|
||||
filter = STTMuteFilter(
|
||||
config=STTMuteConfig(
|
||||
strategies={STTMuteStrategy.CUSTOM},
|
||||
should_mute_callback=custom_mute_logic,
|
||||
)
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Bot starts speaking
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # Bot stops speaking
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
Reference in New Issue
Block a user