diff --git a/examples/voice/voice-krisp-viva.py b/examples/voice/voice-krisp-viva.py index 5fdefd2e1..06a236e0d 100644 --- a/examples/voice/voice-krisp-viva.py +++ b/examples/voice/voice-krisp-viva.py @@ -4,20 +4,24 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""Interruptible bot with Krisp VIVA noise filtering and turn detection. +"""Interruptible bot with Krisp VIVA noise filtering, turn detection, and IP. This example demonstrates a conversational bot with: - Krisp VIVA noise reduction on incoming audio -- Krisp VIVA Turn detection for natural interruptions +- Krisp VIVA Turn detection for end-of-turn +- Krisp Interruption Prediction (IP) to filter backchannels from real interruptions - Voice activity detection (VAD) Required environment variables: - KRISP_VIVA_FILTER_MODEL_PATH: Path to the Krisp noise filter model file (.kef) - KRISP_VIVA_TURN_MODEL_PATH: Path to the Krisp turn detection model file (.kef) -- DEEPGRAM_API_KEY: Deepgram API key for STT/TTS +- KRISP_VIVA_IP_MODEL_PATH: Path to the Krisp IP model file (.kef) +- DEEPGRAM_API_KEY: Deepgram API key for STT +- CARTESIA_API_KEY: Cartesia API key for TTS - OPENAI_API_KEY: OpenAI API key for LLM Optional environment variables: +- KRISP_VIVA_API_KEY: Krisp SDK API key (or set in code) - KRISP_NOISE_SUPPRESSION_LEVEL: Noise suppression level 0-100 (default: 100) Higher values = more aggressive noise reduction """ @@ -49,31 +53,30 @@ from pipecat.services.openai.llm import OpenAILLMService from pipecat.transports.base_transport import BaseTransport, TransportParams from pipecat.transports.daily.transport import DailyParams from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams +from pipecat.turns.user_start import ( + KrispVivaIPUserTurnStartStrategy, + TranscriptionUserTurnStartStrategy, +) from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy from pipecat.turns.user_turn_strategies import UserTurnStrategies load_dotenv(override=True) -# We use lambdas to defer transport parameter creation until the transport -# type is selected at runtime. - -krisp_viva_filter = KrispVivaFilter() - transport_params = { "daily": lambda: DailyParams( audio_in_enabled=True, audio_out_enabled=True, - audio_in_filter=krisp_viva_filter, + audio_in_filter=KrispVivaFilter(), ), "twilio": lambda: FastAPIWebsocketParams( audio_in_enabled=True, audio_out_enabled=True, - audio_in_filter=krisp_viva_filter, + audio_in_filter=KrispVivaFilter(), ), "webrtc": lambda: TransportParams( audio_in_enabled=True, audio_out_enabled=True, - audio_in_filter=krisp_viva_filter, + audio_in_filter=KrispVivaFilter(), ), } @@ -102,7 +105,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): context, user_params=LLMUserAggregatorParams( user_turn_strategies=UserTurnStrategies( - stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=KrispVivaTurn())] + start=[ + KrispVivaIPUserTurnStartStrategy(threshold=0.5), + TranscriptionUserTurnStartStrategy(), + ], + stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=KrispVivaTurn())], ), vad_analyzer=SileroVADAnalyzer(), # or KrispVivaVadAnalyzer ), diff --git a/src/pipecat/audio/krisp_instance.py b/src/pipecat/audio/krisp_instance.py index 5ebfd24cc..94e3dfe3f 100644 --- a/src/pipecat/audio/krisp_instance.py +++ b/src/pipecat/audio/krisp_instance.py @@ -17,7 +17,7 @@ try: except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error("In order to use the Krisp instance, you need to install krisp_audio.") - raise Exception(f"Missing module: {e}") + raise ImportError(f"Missing module: {e}") from e # Mapping of sample rates (Hz) to Krisp SDK SamplingRate enums diff --git a/src/pipecat/audio/turn/krisp_viva_turn.py b/src/pipecat/audio/turn/krisp_viva_turn.py index 5235a94be..a9cf6a847 100644 --- a/src/pipecat/audio/turn/krisp_viva_turn.py +++ b/src/pipecat/audio/turn/krisp_viva_turn.py @@ -7,7 +7,9 @@ """Krisp turn analyzer for end-of-turn detection using Krisp VIVA SDK. This module provides a turn analyzer implementation using Krisp's turn detection -(Tt) API to determine when a user has finished speaking in a conversation. +v3 (Tt) API to determine when a user has finished speaking in a conversation. +The Tt API accepts an external VAD flag alongside audio frames, allowing the +model to leverage voice activity information for more accurate turn detection. Note: This analyzer uses a different model than KrispVivaFilter. The model path can be specified via the KRISP_VIVA_TURN_MODEL_PATH environment variable or @@ -33,7 +35,7 @@ try: except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error("In order to use KrispVivaTurn, you need to install krisp_audio.") - raise Exception(f"Missing module: {e}") + raise ImportError(f"Missing module: {e}") from e class KrispTurnParams(BaseTurnParams): @@ -53,8 +55,10 @@ class KrispTurnParams(BaseTurnParams): class KrispVivaTurn(BaseTurnAnalyzer): """Turn analyzer using Krisp VIVA SDK for end-of-turn detection. - Uses Krisp's turn detection (Tt) API to determine when a user has finished - speaking. This analyzer requires a valid Krisp model file to operate. + Uses Krisp's turn detection v3 (Tt) API to determine when a user has + finished speaking. The Tt API receives an external VAD flag with each + audio frame, which the ``is_speech`` parameter of ``append_audio`` + provides. This analyzer requires a valid Krisp model file to operate. """ def __init__( @@ -158,14 +162,14 @@ class KrispVivaTurn(BaseTurnAnalyzer): """Create a turn detection session with the specified sample rate. Args: - sample_rate: Sample rate for the session + sample_rate: Sample rate for the session. Returns: - krisp_audio.TtFloat instance + krisp_audio.TtFloat instance. Raises: - ValueError: If sample rate or frame duration is not supported - RuntimeError: If session creation fails + ValueError: If sample rate or frame duration is not supported. + RuntimeError: If session creation fails. """ try: model_info = krisp_audio.ModelInfo() @@ -306,12 +310,7 @@ class KrispVivaTurn(BaseTurnAnalyzer): # Instead, we wait for the model's probability check below to confirm # end-of-turn based on the threshold. - prob = self._tt_session.process(frame.tolist()) - - # Negative values indicate the model is not ready yet (working with 100ms data) - # Skip processing until we get positive probabilities - if prob < 0: - continue + prob = self._tt_session.process(frame.tolist(), is_speech, False) # Store the probability for external access self._last_probability = prob diff --git a/src/pipecat/turns/user_start/__init__.py b/src/pipecat/turns/user_start/__init__.py index 94d12708d..14de5d28b 100644 --- a/src/pipecat/turns/user_start/__init__.py +++ b/src/pipecat/turns/user_start/__init__.py @@ -11,9 +11,15 @@ from .transcription_user_turn_start_strategy import TranscriptionUserTurnStartSt from .vad_user_turn_start_strategy import VADUserTurnStartStrategy from .wake_phrase_user_turn_start_strategy import WakePhraseUserTurnStartStrategy +try: + from .krisp_viva_ip_user_turn_start_strategy import KrispVivaIPUserTurnStartStrategy +except ImportError: + KrispVivaIPUserTurnStartStrategy = None + __all__ = [ "BaseUserTurnStartStrategy", "ExternalUserTurnStartStrategy", + "KrispVivaIPUserTurnStartStrategy", "MinWordsUserTurnStartStrategy", "TranscriptionUserTurnStartStrategy", "UserTurnStartedParams", diff --git a/src/pipecat/turns/user_start/krisp_viva_ip_user_turn_start_strategy.py b/src/pipecat/turns/user_start/krisp_viva_ip_user_turn_start_strategy.py new file mode 100644 index 000000000..807bc8154 --- /dev/null +++ b/src/pipecat/turns/user_start/krisp_viva_ip_user_turn_start_strategy.py @@ -0,0 +1,282 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""User turn start strategy using Krisp Interruption Prediction (IP). + +This strategy uses Krisp's IP model to distinguish genuine user interruptions +from backchannels (e.g. "uh-huh", "yeah"). Instead of triggering a user turn +on every VAD speech event, it collects audio after VAD detects speech and runs +the IP model to predict whether the speech is a real interruption. + +Only when the IP model's probability exceeds the configured threshold is +``trigger_user_turn_started()`` called. This prevents the bot from being +interrupted by brief acknowledgements or filler words. +""" + +import os + +import numpy as np +from loguru import logger + +from pipecat.audio.krisp_instance import ( + KrispVivaSDKManager, + int_to_krisp_frame_duration, + int_to_krisp_sample_rate, +) +from pipecat.frames.frames import ( + BotStoppedSpeakingFrame, + Frame, + InputAudioRawFrame, + VADUserStartedSpeakingFrame, + VADUserStoppedSpeakingFrame, +) +from pipecat.turns.types import ProcessFrameResult +from pipecat.turns.user_start.base_user_turn_start_strategy import BaseUserTurnStartStrategy + +try: + import krisp_audio +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use KrispVivaIPUserTurnStartStrategy, you need to install krisp_audio." + ) + raise Exception(f"Missing module: {e}") + + +class KrispVivaIPUserTurnStartStrategy(BaseUserTurnStartStrategy): + """User turn start strategy using Krisp VIVA Interruption Prediction. + + When VAD detects user speech, this strategy feeds audio frames into + the Krisp VIVA IP model. The model outputs a probability indicating + whether the speech is a genuine interruption (as opposed to a + backchannel). A user turn is triggered only when this probability + exceeds the configured threshold. + + This strategy is designed to work alongside other start strategies + (e.g. ``TranscriptionUserTurnStartStrategy`` as a fallback) via the + strategy list in ``UserTurnStrategies``. + + Example:: + + from pipecat.turns.user_start import KrispVivaIPUserTurnStartStrategy + + strategies = UserTurnStrategies( + start=[ + KrispVivaIPUserTurnStartStrategy( + model_path="/path/to/ip_model.kef", + threshold=0.5, + ), + TranscriptionUserTurnStartStrategy(), + ], + ) + """ + + def __init__( + self, + *, + model_path: str | None = None, + threshold: float = 0.5, + frame_duration_ms: int = 20, + api_key: str = "", + **kwargs, + ): + """Initialize the Krisp VIVA IP user turn start strategy. + + Args: + model_path: Path to the Krisp VIVA IP model file (.kef). If None, + uses the KRISP_VIVA_IP_MODEL_PATH environment variable. + threshold: IP probability threshold (0.0 to 1.0). When the model's + output exceeds this value, the speech is classified as a genuine + interruption. + frame_duration_ms: Frame duration in milliseconds for IP processing. + Supported values: 10, 15, 20, 30, 32. + api_key: Krisp SDK API key. If empty, falls back to the + KRISP_VIVA_API_KEY environment variable. + **kwargs: Additional arguments passed to BaseUserTurnStartStrategy. + """ + super().__init__(**kwargs) + + self._threshold = threshold + self._frame_duration_ms = frame_duration_ms + self._api_key = api_key + + self._model_path = model_path or os.getenv("KRISP_VIVA_IP_MODEL_PATH") + if not self._model_path: + raise ValueError( + "IP model path must be provided via model_path or " + "KRISP_VIVA_IP_MODEL_PATH environment variable." + ) + if not self._model_path.endswith(".kef"): + raise ValueError("Model is expected with .kef extension") + if not os.path.isfile(self._model_path): + raise FileNotFoundError(f"IP model file not found: {self._model_path}") + + self._sdk_acquired = False + self._ip_session = None + self._samples_per_frame: int | None = None + self._sample_rate: int | None = None + + # State tracking + self._speech_active = False + self._audio_buffer = bytearray() + self._decision_made = False + + # Acquire SDK + try: + KrispVivaSDKManager.acquire(api_key=api_key) + self._sdk_acquired = True + except Exception as e: + raise RuntimeError(f"Failed to initialize Krisp SDK: {e}") + + async def cleanup(self): + """Release Krisp SDK resources.""" + if self._sdk_acquired: + try: + self._ip_session = None + KrispVivaSDKManager.release() + self._sdk_acquired = False + except Exception as e: + logger.error(f"Error cleaning up Krisp VIVA IP strategy: {e}", exc_info=True) + + def _ensure_session(self, sample_rate: int): + """Create or re-create the IP session when sample rate changes. + + Args: + sample_rate: Audio sample rate in Hz. + """ + if self._sample_rate == sample_rate and self._ip_session is not None: + return + + self._sample_rate = sample_rate + self._samples_per_frame = int((sample_rate * self._frame_duration_ms) / 1000) + + model_info = krisp_audio.ModelInfo() + model_info.path = self._model_path + + ip_cfg = krisp_audio.IpSessionConfig() + ip_cfg.inputSampleRate = int_to_krisp_sample_rate(sample_rate) + ip_cfg.inputFrameDuration = int_to_krisp_frame_duration(self._frame_duration_ms) + ip_cfg.modelInfo = model_info + + self._ip_session = krisp_audio.IpFloat.create(ip_cfg) + logger.debug(f"Krisp VIVA IP session created (sample_rate={sample_rate})") + + def _reset_state(self): + """Reset speech tracking state for the next candidate interruption.""" + self._speech_active = False + self._audio_buffer.clear() + self._decision_made = False + + async def reset(self): + """Reset the strategy to its initial state.""" + await super().reset() + self._reset_state() + + async def process_frame(self, frame: Frame) -> ProcessFrameResult: + """Process a frame to detect genuine user interruptions. + + On ``VADUserStartedSpeakingFrame``, begins collecting audio. + On ``InputAudioRawFrame``, feeds audio through the IP model and + triggers a user turn if the interruption probability exceeds the + threshold. + On ``VADUserStoppedSpeakingFrame`` or ``BotStoppedSpeakingFrame``, + resets the candidate state. + + Args: + frame: The incoming frame. + + Returns: + STOP if a genuine interruption was detected, CONTINUE otherwise. + """ + if isinstance(frame, VADUserStartedSpeakingFrame): + return await self._handle_vad_started(frame) + elif isinstance(frame, InputAudioRawFrame): + return await self._handle_audio(frame) + elif isinstance(frame, (VADUserStoppedSpeakingFrame, BotStoppedSpeakingFrame)): + return await self._handle_reset(frame) + + return ProcessFrameResult.CONTINUE + + async def _handle_vad_started(self, frame: VADUserStartedSpeakingFrame) -> ProcessFrameResult: + """Begin collecting audio for interruption classification. + + Args: + frame: The VAD speech-start frame. + + Returns: + Always CONTINUE; the decision is deferred until enough audio is processed. + """ + logger.trace("Krisp VIVA IP: VAD speech started, collecting audio for classification") + self._speech_active = True + self._audio_buffer.clear() + self._decision_made = False + return ProcessFrameResult.CONTINUE + + async def _handle_audio(self, frame: InputAudioRawFrame) -> ProcessFrameResult: + """Feed audio to the IP model and check for genuine interruption. + + Args: + frame: Raw audio input frame. + + Returns: + STOP if the model detects a genuine interruption, CONTINUE otherwise. + """ + if not self._speech_active or self._decision_made: + return ProcessFrameResult.CONTINUE + + self._ensure_session(frame.sample_rate) + + if self._ip_session is None or self._samples_per_frame is None: + logger.warning("IP session not ready, skipping frame") + return ProcessFrameResult.CONTINUE + + self._audio_buffer.extend(frame.audio) + + total_samples = len(self._audio_buffer) // 2 # 2 bytes per int16 sample + num_complete_frames = total_samples // self._samples_per_frame + + if num_complete_frames == 0: + return ProcessFrameResult.CONTINUE + + complete_samples_count = num_complete_frames * self._samples_per_frame + bytes_to_process = complete_samples_count * 2 + + audio_to_process = bytes(self._audio_buffer[:bytes_to_process]) + self._audio_buffer = self._audio_buffer[bytes_to_process:] + + audio_int16 = np.frombuffer(audio_to_process, dtype=np.int16) + audio_float32 = audio_int16.astype(np.float32) / 32768.0 + frames = audio_float32.reshape(-1, self._samples_per_frame) + + for ip_frame in frames: + ip_prob = self._ip_session.process(ip_frame.tolist(), self._speech_active) + + if ip_prob >= self._threshold: + logger.debug( + f"Krisp VIVA IP: genuine interruption detected (prob={ip_prob:.3f}, " + f"threshold={self._threshold})" + ) + self._decision_made = True + await self.trigger_user_turn_started() + return ProcessFrameResult.STOP + + return ProcessFrameResult.CONTINUE + + async def _handle_reset( + self, frame: VADUserStoppedSpeakingFrame | BotStoppedSpeakingFrame + ) -> ProcessFrameResult: + """Reset state when the candidate interruption window ends. + + Args: + frame: The frame signaling end of speech or bot output. + + Returns: + Always CONTINUE. + """ + if self._speech_active: + logger.trace("Krisp VIVA IP: speech segment ended, resetting state") + self._reset_state() + return ProcessFrameResult.CONTINUE diff --git a/tests/test_krisp_ip_user_turn_start_strategy.py b/tests/test_krisp_ip_user_turn_start_strategy.py new file mode 100644 index 000000000..bb34d879a --- /dev/null +++ b/tests/test_krisp_ip_user_turn_start_strategy.py @@ -0,0 +1,236 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np + +# Mock package version check before importing pipecat (development mode) +_version_patcher = patch("importlib.metadata.version", return_value="0.0.0-dev") +_version_patcher.start() + +# Mock krisp_audio before any pipecat import that loads krisp_instance / VIVA IP strategy +mock_krisp_audio = MagicMock() +mock_krisp_audio.SamplingRate.Sr8000Hz = 8000 +mock_krisp_audio.SamplingRate.Sr16000Hz = 16000 +mock_krisp_audio.SamplingRate.Sr24000Hz = 24000 +mock_krisp_audio.SamplingRate.Sr32000Hz = 32000 +mock_krisp_audio.SamplingRate.Sr44100Hz = 44100 +mock_krisp_audio.SamplingRate.Sr48000Hz = 48000 +mock_krisp_audio.FrameDuration.Fd10ms = "10ms" +mock_krisp_audio.FrameDuration.Fd15ms = "15ms" +mock_krisp_audio.FrameDuration.Fd20ms = "20ms" +mock_krisp_audio.FrameDuration.Fd30ms = "30ms" +mock_krisp_audio.FrameDuration.Fd32ms = "32ms" + +sys.modules["krisp_audio"] = mock_krisp_audio + +mock_pipecat_krisp = MagicMock() +sys.modules["pipecat_ai_krisp"] = mock_pipecat_krisp +sys.modules["pipecat_ai_krisp.audio"] = MagicMock() +sys.modules["pipecat_ai_krisp.audio.krisp_processor"] = MagicMock() + +from pipecat.frames.frames import ( + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + InputAudioRawFrame, + TranscriptionFrame, + VADUserStartedSpeakingFrame, + VADUserStoppedSpeakingFrame, +) +from pipecat.turns.types import ProcessFrameResult +from pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy import ( + KrispVivaIPUserTurnStartStrategy, +) + +STRATEGY_MODULE = "pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy" + + +def _int16_silence(num_samples: int) -> bytes: + return np.zeros(num_samples, dtype=np.int16).tobytes() + + +class TestKrispVivaIPUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase): + """Tests for KrispVivaIPUserTurnStartStrategy with mocked krisp_audio.""" + + def setUp(self): + self.temp_model_file = tempfile.NamedTemporaryFile(suffix=".kef", delete=False) + self.temp_model_file.write(b"dummy") + self.temp_model_file.close() + self.model_path = self.temp_model_file.name + + self.mock_krisp_audio = mock_krisp_audio + self.mock_krisp_audio.reset_mock() + self.mock_krisp_audio.ModelInfo.reset_mock() + self.mock_krisp_audio.IpSessionConfig.reset_mock() + self.mock_krisp_audio.IpFloat.reset_mock() + + self.mock_model_info = MagicMock() + self.mock_krisp_audio.ModelInfo.return_value = self.mock_model_info + + self.mock_ip_cfg = MagicMock() + self.mock_krisp_audio.IpSessionConfig.return_value = self.mock_ip_cfg + + self.mock_ip_session = MagicMock() + self.mock_krisp_audio.IpFloat.create.return_value = self.mock_ip_session + + self.krisp_patch = patch(f"{STRATEGY_MODULE}.krisp_audio", self.mock_krisp_audio) + self.krisp_patch.start() + + self.sdk_patcher = patch(f"{STRATEGY_MODULE}.KrispVivaSDKManager") + self.mock_sdk_manager = self.sdk_patcher.start() + self.mock_sdk_manager.acquire = MagicMock() + self.mock_sdk_manager.release = MagicMock() + + def tearDown(self): + self.krisp_patch.stop() + self.sdk_patcher.stop() + if os.path.exists(self.model_path): + os.unlink(self.model_path) + + def _make_strategy(self, *, threshold: float = 0.5, frame_duration_ms: int = 20): + return KrispVivaIPUserTurnStartStrategy( + model_path=self.model_path, + threshold=threshold, + frame_duration_ms=frame_duration_ms, + api_key="test-key", + ) + + def _audio_frame(self, sample_rate: int = 16000, frame_duration_ms: int = 20): + samples = int(sample_rate * frame_duration_ms / 1000) + return InputAudioRawFrame( + audio=_int16_silence(samples), + sample_rate=sample_rate, + num_channels=1, + ) + + async def test_interruption_detected_emits_turn_and_stop(self): + self.mock_ip_session.process = MagicMock(return_value=0.87) + + strategy = self._make_strategy(threshold=0.5) + try: + fired = False + + @strategy.event_handler("on_user_turn_started") + async def on_user_turn_started(strategy, params): + nonlocal fired + fired = True + + await strategy.process_frame(VADUserStartedSpeakingFrame()) + result = await strategy.process_frame(self._audio_frame()) + + self.assertTrue(fired) + self.assertEqual(result, ProcessFrameResult.STOP) + self.mock_ip_session.process.assert_called() + finally: + await strategy.cleanup() + + async def test_backchannel_suppressed_no_event_continue(self): + self.mock_ip_session.process = MagicMock(return_value=0.23) + + strategy = self._make_strategy(threshold=0.5) + try: + fired = False + + @strategy.event_handler("on_user_turn_started") + async def on_user_turn_started(strategy, params): + nonlocal fired + fired = True + + await strategy.process_frame(VADUserStartedSpeakingFrame()) + result = await strategy.process_frame(self._audio_frame()) + + self.assertFalse(fired) + self.assertEqual(result, ProcessFrameResult.CONTINUE) + finally: + await strategy.cleanup() + + async def test_reset_on_vad_stopped_clears_state(self): + self.mock_ip_session.process = MagicMock(return_value=0.1) + + strategy = self._make_strategy(threshold=0.5) + try: + await strategy.process_frame(VADUserStartedSpeakingFrame()) + await strategy.process_frame(self._audio_frame()) + self.mock_ip_session.process.reset_mock() + + await strategy.process_frame(VADUserStoppedSpeakingFrame()) + result = await strategy.process_frame(self._audio_frame()) + + self.assertEqual(result, ProcessFrameResult.CONTINUE) + self.mock_ip_session.process.assert_not_called() + finally: + await strategy.cleanup() + + async def test_reset_on_bot_stopped_clears_state(self): + self.mock_ip_session.process = MagicMock(return_value=0.1) + + strategy = self._make_strategy(threshold=0.5) + try: + await strategy.process_frame(VADUserStartedSpeakingFrame()) + await strategy.process_frame(self._audio_frame()) + self.mock_ip_session.process.reset_mock() + + await strategy.process_frame(BotStoppedSpeakingFrame()) + result = await strategy.process_frame(self._audio_frame()) + + self.assertEqual(result, ProcessFrameResult.CONTINUE) + self.mock_ip_session.process.assert_not_called() + finally: + await strategy.cleanup() + + async def test_no_op_before_vad_start(self): + self.mock_ip_session.process = MagicMock(return_value=0.99) + + strategy = self._make_strategy() + try: + result = await strategy.process_frame(self._audio_frame()) + self.assertEqual(result, ProcessFrameResult.CONTINUE) + self.mock_ip_session.process.assert_not_called() + finally: + await strategy.cleanup() + + async def test_decision_sticks_no_double_trigger(self): + self.mock_ip_session.process = MagicMock(return_value=0.9) + + strategy = self._make_strategy(threshold=0.5) + try: + count = 0 + + @strategy.event_handler("on_user_turn_started") + async def on_user_turn_started(strategy, params): + nonlocal count + count += 1 + + await strategy.process_frame(VADUserStartedSpeakingFrame()) + r1 = await strategy.process_frame(self._audio_frame()) + r2 = await strategy.process_frame(self._audio_frame()) + + self.assertEqual(r1, ProcessFrameResult.STOP) + self.assertEqual(r2, ProcessFrameResult.CONTINUE) + self.assertEqual(count, 1) + finally: + await strategy.cleanup() + + async def test_unrelated_frames_continue(self): + strategy = self._make_strategy() + try: + r1 = await strategy.process_frame(BotStartedSpeakingFrame()) + r2 = await strategy.process_frame( + TranscriptionFrame(text="hi", user_id="", timestamp="") + ) + self.assertEqual(r1, ProcessFrameResult.CONTINUE) + self.assertEqual(r2, ProcessFrameResult.CONTINUE) + finally: + await strategy.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_krisp_sdk_manager.py b/tests/test_krisp_sdk_manager.py index 78a4d955f..2edbf4598 100644 --- a/tests/test_krisp_sdk_manager.py +++ b/tests/test_krisp_sdk_manager.py @@ -62,8 +62,15 @@ class TestKrispVivaSDKManager: def setup_method(self): """Reset mocks and SDK state before each test.""" mock_krisp_audio.reset_mock() + mock_krisp_audio.globalInit.side_effect = None mock_krisp_audio.getVersion.return_value = mock_version + # Ensure krisp_instance module uses THIS test's mock, not a stale + # reference cached from a different test file's sys.modules entry. + import pipecat.audio.krisp_instance as _ki + + _ki.krisp_audio = mock_krisp_audio + # Reset the SDK manager state for clean tests # We access internal state to ensure tests are isolated with KrispVivaSDKManager._lock: