VIVA SDK TT v3 support (#4252)

* VIVA SDK TT v3 support

* Format fix.

* Renamed the API, removing '3' from the name.

* Implementation of User turn start strategy using Krisp VIVA Interruption Prediction in scope of TT v3 support.

* Typo fix in voice-krisp-viva example to use KrispVivaFilter class

* style fix.

* test run error fixes.

* some test related changes.

* Fixed tests

* Style fixes.
This commit is contained in:
Garegin Harutyunyan
2026-04-17 15:53:41 +04:00
committed by GitHub
parent fc1c3b48dc
commit 4c19f5584c
7 changed files with 564 additions and 27 deletions

View File

@@ -4,20 +4,24 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Interruptible bot with Krisp VIVA noise filtering and turn detection.
"""Interruptible bot with Krisp VIVA noise filtering, turn detection, and IP.
This example demonstrates a conversational bot with:
- Krisp VIVA noise reduction on incoming audio
- Krisp VIVA Turn detection for natural interruptions
- Krisp VIVA Turn detection for end-of-turn
- Krisp Interruption Prediction (IP) to filter backchannels from real interruptions
- Voice activity detection (VAD)
Required environment variables:
- KRISP_VIVA_FILTER_MODEL_PATH: Path to the Krisp noise filter model file (.kef)
- KRISP_VIVA_TURN_MODEL_PATH: Path to the Krisp turn detection model file (.kef)
- DEEPGRAM_API_KEY: Deepgram API key for STT/TTS
- KRISP_VIVA_IP_MODEL_PATH: Path to the Krisp IP model file (.kef)
- DEEPGRAM_API_KEY: Deepgram API key for STT
- CARTESIA_API_KEY: Cartesia API key for TTS
- OPENAI_API_KEY: OpenAI API key for LLM
Optional environment variables:
- KRISP_VIVA_API_KEY: Krisp SDK API key (or set in code)
- KRISP_NOISE_SUPPRESSION_LEVEL: Noise suppression level 0-100 (default: 100)
Higher values = more aggressive noise reduction
"""
@@ -49,31 +53,30 @@ from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.turns.user_start import (
KrispVivaIPUserTurnStartStrategy,
TranscriptionUserTurnStartStrategy,
)
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
from pipecat.turns.user_turn_strategies import UserTurnStrategies
load_dotenv(override=True)
# We use lambdas to defer transport parameter creation until the transport
# type is selected at runtime.
krisp_viva_filter = KrispVivaFilter()
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_in_filter=krisp_viva_filter,
audio_in_filter=KrispVivaFilter(),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_in_filter=krisp_viva_filter,
audio_in_filter=KrispVivaFilter(),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_in_filter=krisp_viva_filter,
audio_in_filter=KrispVivaFilter(),
),
}
@@ -102,7 +105,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
context,
user_params=LLMUserAggregatorParams(
user_turn_strategies=UserTurnStrategies(
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=KrispVivaTurn())]
start=[
KrispVivaIPUserTurnStartStrategy(threshold=0.5),
TranscriptionUserTurnStartStrategy(),
],
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=KrispVivaTurn())],
),
vad_analyzer=SileroVADAnalyzer(), # or KrispVivaVadAnalyzer
),

View File

@@ -17,7 +17,7 @@ try:
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use the Krisp instance, you need to install krisp_audio.")
raise Exception(f"Missing module: {e}")
raise ImportError(f"Missing module: {e}") from e
# Mapping of sample rates (Hz) to Krisp SDK SamplingRate enums

View File

@@ -7,7 +7,9 @@
"""Krisp turn analyzer for end-of-turn detection using Krisp VIVA SDK.
This module provides a turn analyzer implementation using Krisp's turn detection
(Tt) API to determine when a user has finished speaking in a conversation.
v3 (Tt) API to determine when a user has finished speaking in a conversation.
The Tt API accepts an external VAD flag alongside audio frames, allowing the
model to leverage voice activity information for more accurate turn detection.
Note: This analyzer uses a different model than KrispVivaFilter. The model path
can be specified via the KRISP_VIVA_TURN_MODEL_PATH environment variable or
@@ -33,7 +35,7 @@ try:
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use KrispVivaTurn, you need to install krisp_audio.")
raise Exception(f"Missing module: {e}")
raise ImportError(f"Missing module: {e}") from e
class KrispTurnParams(BaseTurnParams):
@@ -53,8 +55,10 @@ class KrispTurnParams(BaseTurnParams):
class KrispVivaTurn(BaseTurnAnalyzer):
"""Turn analyzer using Krisp VIVA SDK for end-of-turn detection.
Uses Krisp's turn detection (Tt) API to determine when a user has finished
speaking. This analyzer requires a valid Krisp model file to operate.
Uses Krisp's turn detection v3 (Tt) API to determine when a user has
finished speaking. The Tt API receives an external VAD flag with each
audio frame, which the ``is_speech`` parameter of ``append_audio``
provides. This analyzer requires a valid Krisp model file to operate.
"""
def __init__(
@@ -158,14 +162,14 @@ class KrispVivaTurn(BaseTurnAnalyzer):
"""Create a turn detection session with the specified sample rate.
Args:
sample_rate: Sample rate for the session
sample_rate: Sample rate for the session.
Returns:
krisp_audio.TtFloat instance
krisp_audio.TtFloat instance.
Raises:
ValueError: If sample rate or frame duration is not supported
RuntimeError: If session creation fails
ValueError: If sample rate or frame duration is not supported.
RuntimeError: If session creation fails.
"""
try:
model_info = krisp_audio.ModelInfo()
@@ -306,12 +310,7 @@ class KrispVivaTurn(BaseTurnAnalyzer):
# Instead, we wait for the model's probability check below to confirm
# end-of-turn based on the threshold.
prob = self._tt_session.process(frame.tolist())
# Negative values indicate the model is not ready yet (working with 100ms data)
# Skip processing until we get positive probabilities
if prob < 0:
continue
prob = self._tt_session.process(frame.tolist(), is_speech, False)
# Store the probability for external access
self._last_probability = prob

View File

@@ -11,9 +11,15 @@ from .transcription_user_turn_start_strategy import TranscriptionUserTurnStartSt
from .vad_user_turn_start_strategy import VADUserTurnStartStrategy
from .wake_phrase_user_turn_start_strategy import WakePhraseUserTurnStartStrategy
try:
from .krisp_viva_ip_user_turn_start_strategy import KrispVivaIPUserTurnStartStrategy
except ImportError:
KrispVivaIPUserTurnStartStrategy = None
__all__ = [
"BaseUserTurnStartStrategy",
"ExternalUserTurnStartStrategy",
"KrispVivaIPUserTurnStartStrategy",
"MinWordsUserTurnStartStrategy",
"TranscriptionUserTurnStartStrategy",
"UserTurnStartedParams",

View File

@@ -0,0 +1,282 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""User turn start strategy using Krisp Interruption Prediction (IP).
This strategy uses Krisp's IP model to distinguish genuine user interruptions
from backchannels (e.g. "uh-huh", "yeah"). Instead of triggering a user turn
on every VAD speech event, it collects audio after VAD detects speech and runs
the IP model to predict whether the speech is a real interruption.
Only when the IP model's probability exceeds the configured threshold is
``trigger_user_turn_started()`` called. This prevents the bot from being
interrupted by brief acknowledgements or filler words.
"""
import os
import numpy as np
from loguru import logger
from pipecat.audio.krisp_instance import (
KrispVivaSDKManager,
int_to_krisp_frame_duration,
int_to_krisp_sample_rate,
)
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
Frame,
InputAudioRawFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.turns.types import ProcessFrameResult
from pipecat.turns.user_start.base_user_turn_start_strategy import BaseUserTurnStartStrategy
# krisp_audio is an optional dependency. Surface its absence as an
# ImportError (chained to the original error) so that callers guarding the
# import of this module with ``except ImportError`` can degrade gracefully.
try:
    import krisp_audio
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use KrispVivaIPUserTurnStartStrategy, you need to install krisp_audio."
    )
    raise ImportError(f"Missing module: {e}") from e
class KrispVivaIPUserTurnStartStrategy(BaseUserTurnStartStrategy):
    """User turn start strategy using Krisp VIVA Interruption Prediction.

    When VAD detects user speech, this strategy feeds audio frames into
    the Krisp VIVA IP model. The model outputs a probability indicating
    whether the speech is a genuine interruption (as opposed to a
    backchannel). A user turn is triggered only when this probability
    reaches the configured threshold.

    This strategy is designed to work alongside other start strategies
    (e.g. ``TranscriptionUserTurnStartStrategy`` as a fallback) via the
    strategy list in ``UserTurnStrategies``.

    Example::

        from pipecat.turns.user_start import KrispVivaIPUserTurnStartStrategy

        strategies = UserTurnStrategies(
            start=[
                KrispVivaIPUserTurnStartStrategy(
                    model_path="/path/to/ip_model.kef",
                    threshold=0.5,
                ),
                TranscriptionUserTurnStartStrategy(),
            ],
        )
    """

    def __init__(
        self,
        *,
        model_path: str | None = None,
        threshold: float = 0.5,
        frame_duration_ms: int = 20,
        api_key: str = "",
        **kwargs,
    ):
        """Initialize the Krisp VIVA IP user turn start strategy.

        Args:
            model_path: Path to the Krisp VIVA IP model file (.kef). If None,
                uses the KRISP_VIVA_IP_MODEL_PATH environment variable.
            threshold: IP probability threshold (0.0 to 1.0). When the model's
                output is greater than or equal to this value, the speech is
                classified as a genuine interruption.
            frame_duration_ms: Frame duration in milliseconds for IP processing.
                Supported values: 10, 15, 20, 30, 32.
            api_key: Krisp SDK API key. If empty, falls back to the
                KRISP_VIVA_API_KEY environment variable.
            **kwargs: Additional arguments passed to BaseUserTurnStartStrategy.

        Raises:
            ValueError: If no model path is available or it lacks the .kef
                extension.
            FileNotFoundError: If the model file does not exist.
            RuntimeError: If the Krisp SDK cannot be acquired.
        """
        super().__init__(**kwargs)

        self._threshold = threshold
        self._frame_duration_ms = frame_duration_ms
        self._api_key = api_key

        self._model_path = model_path or os.getenv("KRISP_VIVA_IP_MODEL_PATH")
        if not self._model_path:
            raise ValueError(
                "IP model path must be provided via model_path or "
                "KRISP_VIVA_IP_MODEL_PATH environment variable."
            )
        if not self._model_path.endswith(".kef"):
            raise ValueError("Model is expected with .kef extension")
        if not os.path.isfile(self._model_path):
            raise FileNotFoundError(f"IP model file not found: {self._model_path}")

        self._sdk_acquired = False
        self._ip_session = None
        self._samples_per_frame: int | None = None
        self._sample_rate: int | None = None

        # State tracking
        self._speech_active = False
        self._audio_buffer = bytearray()
        self._decision_made = False

        # Acquire the SDK last, once all validation has passed, so a rejected
        # configuration never leaves an SDK reference held.
        try:
            KrispVivaSDKManager.acquire(api_key=api_key)
        except Exception as e:
            raise RuntimeError(f"Failed to initialize Krisp SDK: {e}") from e
        self._sdk_acquired = True

    async def cleanup(self):
        """Release Krisp SDK resources.

        Errors raised while releasing are logged (with traceback) and
        swallowed so that cleanup stays best-effort.
        """
        if not self._sdk_acquired:
            return
        try:
            self._ip_session = None
            KrispVivaSDKManager.release()
            self._sdk_acquired = False
        except Exception as e:
            # loguru has no ``exc_info`` keyword; ``logger.exception`` is the
            # loguru way to attach the active traceback to an ERROR record.
            logger.exception(f"Error cleaning up Krisp VIVA IP strategy: {e}")

    def _ensure_session(self, sample_rate: int):
        """Create or re-create the IP session when sample rate changes.

        The cached sample rate and frame size are only updated once the new
        session exists, so a failed creation can never leave a session
        configured for a different rate in use.

        Args:
            sample_rate: Audio sample rate in Hz.
        """
        if self._ip_session is not None and self._sample_rate == sample_rate:
            return

        # Drop the previous session before attempting to build a new one.
        self._ip_session = None
        if self._sample_rate != sample_rate:
            # Bytes buffered at the previous rate cannot be framed at the new one.
            self._audio_buffer.clear()

        model_info = krisp_audio.ModelInfo()
        model_info.path = self._model_path

        ip_cfg = krisp_audio.IpSessionConfig()
        ip_cfg.inputSampleRate = int_to_krisp_sample_rate(sample_rate)
        ip_cfg.inputFrameDuration = int_to_krisp_frame_duration(self._frame_duration_ms)
        ip_cfg.modelInfo = model_info

        self._ip_session = krisp_audio.IpFloat.create(ip_cfg)
        self._sample_rate = sample_rate
        self._samples_per_frame = int((sample_rate * self._frame_duration_ms) / 1000)
        logger.debug(f"Krisp VIVA IP session created (sample_rate={sample_rate})")

    def _reset_state(self):
        """Reset speech tracking state for the next candidate interruption."""
        self._speech_active = False
        self._audio_buffer.clear()
        self._decision_made = False

    async def reset(self):
        """Reset the strategy to its initial state."""
        await super().reset()
        self._reset_state()

    async def process_frame(self, frame: Frame) -> ProcessFrameResult:
        """Process a frame to detect genuine user interruptions.

        On ``VADUserStartedSpeakingFrame``, begins collecting audio.
        On ``InputAudioRawFrame``, feeds audio through the IP model and
        triggers a user turn if the interruption probability reaches the
        threshold.
        On ``VADUserStoppedSpeakingFrame`` or ``BotStoppedSpeakingFrame``,
        resets the candidate state.

        Args:
            frame: The incoming frame.

        Returns:
            STOP if a genuine interruption was detected, CONTINUE otherwise.
        """
        if isinstance(frame, VADUserStartedSpeakingFrame):
            return await self._handle_vad_started(frame)
        if isinstance(frame, InputAudioRawFrame):
            return await self._handle_audio(frame)
        if isinstance(frame, (VADUserStoppedSpeakingFrame, BotStoppedSpeakingFrame)):
            return await self._handle_reset(frame)
        return ProcessFrameResult.CONTINUE

    async def _handle_vad_started(self, frame: VADUserStartedSpeakingFrame) -> ProcessFrameResult:
        """Begin collecting audio for interruption classification.

        Args:
            frame: The VAD speech-start frame.

        Returns:
            Always CONTINUE; the decision is deferred until enough audio is processed.
        """
        logger.trace("Krisp VIVA IP: VAD speech started, collecting audio for classification")
        self._speech_active = True
        self._audio_buffer.clear()
        self._decision_made = False
        return ProcessFrameResult.CONTINUE

    async def _handle_audio(self, frame: InputAudioRawFrame) -> ProcessFrameResult:
        """Feed audio to the IP model and check for genuine interruption.

        Audio is accumulated until at least one complete model frame is
        available; any trailing partial frame stays buffered for the next call.

        Args:
            frame: Raw audio input frame.

        Returns:
            STOP if the model detects a genuine interruption, CONTINUE otherwise.
        """
        # Only classify while a VAD speech segment is open and undecided.
        if not self._speech_active or self._decision_made:
            return ProcessFrameResult.CONTINUE

        self._ensure_session(frame.sample_rate)
        if self._ip_session is None or self._samples_per_frame is None:
            logger.warning("IP session not ready, skipping frame")
            return ProcessFrameResult.CONTINUE

        self._audio_buffer.extend(frame.audio)

        total_samples = len(self._audio_buffer) // 2  # 2 bytes per int16 sample
        num_complete_frames = total_samples // self._samples_per_frame
        if num_complete_frames == 0:
            return ProcessFrameResult.CONTINUE

        bytes_to_process = num_complete_frames * self._samples_per_frame * 2
        audio_to_process = bytes(self._audio_buffer[:bytes_to_process])
        # Consume the processed prefix in place instead of re-allocating the buffer.
        del self._audio_buffer[:bytes_to_process]

        # Convert int16 PCM to float32 scaled by 1/32768, one row per model frame.
        audio_int16 = np.frombuffer(audio_to_process, dtype=np.int16)
        audio_float32 = audio_int16.astype(np.float32) / 32768.0
        frames = audio_float32.reshape(-1, self._samples_per_frame)

        for ip_frame in frames:
            ip_prob = self._ip_session.process(ip_frame.tolist(), self._speech_active)
            if ip_prob >= self._threshold:
                logger.debug(
                    f"Krisp VIVA IP: genuine interruption detected (prob={ip_prob:.3f}, "
                    f"threshold={self._threshold})"
                )
                # Latch the decision so later audio in this segment is ignored.
                self._decision_made = True
                await self.trigger_user_turn_started()
                return ProcessFrameResult.STOP

        return ProcessFrameResult.CONTINUE

    async def _handle_reset(
        self, frame: VADUserStoppedSpeakingFrame | BotStoppedSpeakingFrame
    ) -> ProcessFrameResult:
        """Reset state when the candidate interruption window ends.

        Args:
            frame: The frame signaling end of speech or bot output.

        Returns:
            Always CONTINUE.
        """
        if self._speech_active:
            logger.trace("Krisp VIVA IP: speech segment ended, resetting state")
        self._reset_state()
        return ProcessFrameResult.CONTINUE

View File

@@ -0,0 +1,236 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
import sys
import tempfile
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
# Patch the package version lookup before pipecat is imported (development
# mode). The patcher is started here and is not stopped anywhere in this module.
_version_patcher = patch("importlib.metadata.version", return_value="0.0.0-dev")
_version_patcher.start()

# Register a stand-in krisp_audio module before any pipecat import that loads
# krisp_instance / the VIVA IP strategy. The enum-like members are plain
# values: ints for sampling rates, "<n>ms" strings for frame durations.
mock_krisp_audio = MagicMock()
mock_krisp_audio.SamplingRate.configure_mock(
    **{f"Sr{hz}Hz": hz for hz in (8000, 16000, 24000, 32000, 44100, 48000)}
)
mock_krisp_audio.FrameDuration.configure_mock(
    **{f"Fd{ms}ms": f"{ms}ms" for ms in (10, 15, 20, 30, 32)}
)
sys.modules["krisp_audio"] = mock_krisp_audio

# Stub the pipecat_ai_krisp package and its submodules so they resolve on import.
mock_pipecat_krisp = MagicMock()
sys.modules.update(
    {
        "pipecat_ai_krisp": mock_pipecat_krisp,
        "pipecat_ai_krisp.audio": MagicMock(),
        "pipecat_ai_krisp.audio.krisp_processor": MagicMock(),
    }
)
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
InputAudioRawFrame,
TranscriptionFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.turns.types import ProcessFrameResult
from pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy import (
KrispVivaIPUserTurnStartStrategy,
)
STRATEGY_MODULE = "pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy"
def _int16_silence(num_samples: int) -> bytes:
    """Return ``num_samples`` zero-valued int16 samples as raw bytes."""
    silence = np.zeros(num_samples, dtype=np.int16)
    return silence.tobytes()
class TestKrispVivaIPUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
    """Tests for KrispVivaIPUserTurnStartStrategy with mocked krisp_audio.

    Every resource created in ``setUp`` is released through ``addCleanup`` so
    that it is torn down even if ``setUp`` itself fails part-way (unittest
    does not call ``tearDown`` in that case).
    """

    def setUp(self):
        # A real file on disk is needed because the strategy checks that the
        # model path exists and ends with ".kef".
        self.temp_model_file = tempfile.NamedTemporaryFile(suffix=".kef", delete=False)
        self.model_path = self.temp_model_file.name
        self.addCleanup(self._remove_model_file)
        self.temp_model_file.write(b"dummy")
        self.temp_model_file.close()

        self.mock_krisp_audio = mock_krisp_audio
        self.mock_krisp_audio.reset_mock()
        self.mock_krisp_audio.ModelInfo.reset_mock()
        self.mock_krisp_audio.IpSessionConfig.reset_mock()
        self.mock_krisp_audio.IpFloat.reset_mock()

        self.mock_model_info = MagicMock()
        self.mock_krisp_audio.ModelInfo.return_value = self.mock_model_info

        self.mock_ip_cfg = MagicMock()
        self.mock_krisp_audio.IpSessionConfig.return_value = self.mock_ip_cfg

        self.mock_ip_session = MagicMock()
        self.mock_krisp_audio.IpFloat.create.return_value = self.mock_ip_session

        self.krisp_patch = patch(f"{STRATEGY_MODULE}.krisp_audio", self.mock_krisp_audio)
        self.krisp_patch.start()
        self.addCleanup(self.krisp_patch.stop)

        self.sdk_patcher = patch(f"{STRATEGY_MODULE}.KrispVivaSDKManager")
        self.mock_sdk_manager = self.sdk_patcher.start()
        self.addCleanup(self.sdk_patcher.stop)
        self.mock_sdk_manager.acquire = MagicMock()
        self.mock_sdk_manager.release = MagicMock()

    def _remove_model_file(self):
        """Delete the temporary model file if it is still present."""
        if os.path.exists(self.model_path):
            os.unlink(self.model_path)

    def _make_strategy(self, *, threshold: float = 0.5, frame_duration_ms: int = 20):
        """Build a strategy and schedule its ``cleanup()`` for the end of the test."""
        strategy = KrispVivaIPUserTurnStartStrategy(
            model_path=self.model_path,
            threshold=threshold,
            frame_duration_ms=frame_duration_ms,
            api_key="test-key",
        )
        # Cleanups run LIFO, so this runs before the patches from setUp stop.
        self.addAsyncCleanup(strategy.cleanup)
        return strategy

    def _audio_frame(self, sample_rate: int = 16000, frame_duration_ms: int = 20):
        """Return one silent mono input frame of the given duration."""
        samples = int(sample_rate * frame_duration_ms / 1000)
        return InputAudioRawFrame(
            audio=_int16_silence(samples),
            sample_rate=sample_rate,
            num_channels=1,
        )

    def _record_turn_starts(self, strategy) -> list:
        """Register an ``on_user_turn_started`` handler and return its call log."""
        calls = []

        @strategy.event_handler("on_user_turn_started")
        async def on_user_turn_started(strategy, params):
            calls.append(params)

        return calls

    async def test_interruption_detected_emits_turn_and_stop(self):
        self.mock_ip_session.process = MagicMock(return_value=0.87)
        strategy = self._make_strategy(threshold=0.5)
        turn_starts = self._record_turn_starts(strategy)

        await strategy.process_frame(VADUserStartedSpeakingFrame())
        result = await strategy.process_frame(self._audio_frame())

        self.assertTrue(turn_starts)
        self.assertEqual(result, ProcessFrameResult.STOP)
        self.mock_ip_session.process.assert_called()

    async def test_backchannel_suppressed_no_event_continue(self):
        self.mock_ip_session.process = MagicMock(return_value=0.23)
        strategy = self._make_strategy(threshold=0.5)
        turn_starts = self._record_turn_starts(strategy)

        await strategy.process_frame(VADUserStartedSpeakingFrame())
        result = await strategy.process_frame(self._audio_frame())

        self.assertFalse(turn_starts)
        self.assertEqual(result, ProcessFrameResult.CONTINUE)

    async def test_reset_on_vad_stopped_clears_state(self):
        self.mock_ip_session.process = MagicMock(return_value=0.1)
        strategy = self._make_strategy(threshold=0.5)

        await strategy.process_frame(VADUserStartedSpeakingFrame())
        await strategy.process_frame(self._audio_frame())
        self.mock_ip_session.process.reset_mock()

        await strategy.process_frame(VADUserStoppedSpeakingFrame())
        result = await strategy.process_frame(self._audio_frame())

        # After the segment ends, audio must no longer reach the model.
        self.assertEqual(result, ProcessFrameResult.CONTINUE)
        self.mock_ip_session.process.assert_not_called()

    async def test_reset_on_bot_stopped_clears_state(self):
        self.mock_ip_session.process = MagicMock(return_value=0.1)
        strategy = self._make_strategy(threshold=0.5)

        await strategy.process_frame(VADUserStartedSpeakingFrame())
        await strategy.process_frame(self._audio_frame())
        self.mock_ip_session.process.reset_mock()

        await strategy.process_frame(BotStoppedSpeakingFrame())
        result = await strategy.process_frame(self._audio_frame())

        self.assertEqual(result, ProcessFrameResult.CONTINUE)
        self.mock_ip_session.process.assert_not_called()

    async def test_no_op_before_vad_start(self):
        self.mock_ip_session.process = MagicMock(return_value=0.99)
        strategy = self._make_strategy()

        result = await strategy.process_frame(self._audio_frame())

        self.assertEqual(result, ProcessFrameResult.CONTINUE)
        self.mock_ip_session.process.assert_not_called()

    async def test_decision_sticks_no_double_trigger(self):
        self.mock_ip_session.process = MagicMock(return_value=0.9)
        strategy = self._make_strategy(threshold=0.5)
        turn_starts = self._record_turn_starts(strategy)

        await strategy.process_frame(VADUserStartedSpeakingFrame())
        r1 = await strategy.process_frame(self._audio_frame())
        r2 = await strategy.process_frame(self._audio_frame())

        # Only the first qualifying frame may trigger within one speech segment.
        self.assertEqual(r1, ProcessFrameResult.STOP)
        self.assertEqual(r2, ProcessFrameResult.CONTINUE)
        self.assertEqual(len(turn_starts), 1)

    async def test_unrelated_frames_continue(self):
        strategy = self._make_strategy()

        r1 = await strategy.process_frame(BotStartedSpeakingFrame())
        r2 = await strategy.process_frame(
            TranscriptionFrame(text="hi", user_id="", timestamp="")
        )

        self.assertEqual(r1, ProcessFrameResult.CONTINUE)
        self.assertEqual(r2, ProcessFrameResult.CONTINUE)


if __name__ == "__main__":
    unittest.main()

View File

@@ -62,8 +62,15 @@ class TestKrispVivaSDKManager:
def setup_method(self):
"""Reset mocks and SDK state before each test."""
mock_krisp_audio.reset_mock()
mock_krisp_audio.globalInit.side_effect = None
mock_krisp_audio.getVersion.return_value = mock_version
# Ensure krisp_instance module uses THIS test's mock, not a stale
# reference cached from a different test file's sys.modules entry.
import pipecat.audio.krisp_instance as _ki
_ki.krisp_audio = mock_krisp_audio
# Reset the SDK manager state for clean tests
# We access internal state to ensure tests are isolated
with KrispVivaSDKManager._lock: