VIVA SDK TT v3 support (#4252)
* VIVA SDK TT v3 support * Format fix. * Renamed the API naming, removed '3' from the name. * Implementation of User turn start strategy using Krisp VIVA Interruption Prediction in scope of TT v3 support. * Typo fix in voice-krisp-viva example to use KrispVivaFilter class * style fix. * test run error fixes. * some test related changes. * Fixed tests * Stule fixes.
This commit is contained in:
committed by
GitHub
parent
fc1c3b48dc
commit
4c19f5584c
@@ -4,20 +4,24 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Interruptible bot with Krisp VIVA noise filtering and turn detection.
|
||||
"""Interruptible bot with Krisp VIVA noise filtering, turn detection, and IP.
|
||||
|
||||
This example demonstrates a conversational bot with:
|
||||
- Krisp VIVA noise reduction on incoming audio
|
||||
- Krisp VIVA Turn detection for natural interruptions
|
||||
- Krisp VIVA Turn detection for end-of-turn
|
||||
- Krisp Interruption Prediction (IP) to filter backchannels from real interruptions
|
||||
- Voice activity detection (VAD)
|
||||
|
||||
Required environment variables:
|
||||
- KRISP_VIVA_FILTER_MODEL_PATH: Path to the Krisp noise filter model file (.kef)
|
||||
- KRISP_VIVA_TURN_MODEL_PATH: Path to the Krisp turn detection model file (.kef)
|
||||
- DEEPGRAM_API_KEY: Deepgram API key for STT/TTS
|
||||
- KRISP_VIVA_IP_MODEL_PATH: Path to the Krisp IP model file (.kef)
|
||||
- DEEPGRAM_API_KEY: Deepgram API key for STT
|
||||
- CARTESIA_API_KEY: Cartesia API key for TTS
|
||||
- OPENAI_API_KEY: OpenAI API key for LLM
|
||||
|
||||
Optional environment variables:
|
||||
- KRISP_VIVA_API_KEY: Krisp SDK API key (or set in code)
|
||||
- KRISP_NOISE_SUPPRESSION_LEVEL: Noise suppression level 0-100 (default: 100)
|
||||
Higher values = more aggressive noise reduction
|
||||
"""
|
||||
@@ -49,31 +53,30 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_start import (
|
||||
KrispVivaIPUserTurnStartStrategy,
|
||||
TranscriptionUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
|
||||
krisp_viva_filter = KrispVivaFilter()
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_filter=krisp_viva_filter,
|
||||
audio_in_filter=KrispVivaFilter(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_filter=krisp_viva_filter,
|
||||
audio_in_filter=KrispVivaFilter(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
audio_in_filter=krisp_viva_filter,
|
||||
audio_in_filter=KrispVivaFilter(),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -102,7 +105,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=KrispVivaTurn())]
|
||||
start=[
|
||||
KrispVivaIPUserTurnStartStrategy(threshold=0.5),
|
||||
TranscriptionUserTurnStartStrategy(),
|
||||
],
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=KrispVivaTurn())],
|
||||
),
|
||||
vad_analyzer=SileroVADAnalyzer(), # or KrispVivaVadAnalyzer
|
||||
),
|
||||
|
||||
@@ -17,7 +17,7 @@ try:
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use the Krisp instance, you need to install krisp_audio.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
raise ImportError(f"Missing module: {e}") from e
|
||||
|
||||
|
||||
# Mapping of sample rates (Hz) to Krisp SDK SamplingRate enums
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
"""Krisp turn analyzer for end-of-turn detection using Krisp VIVA SDK.
|
||||
|
||||
This module provides a turn analyzer implementation using Krisp's turn detection
|
||||
(Tt) API to determine when a user has finished speaking in a conversation.
|
||||
v3 (Tt) API to determine when a user has finished speaking in a conversation.
|
||||
The Tt API accepts an external VAD flag alongside audio frames, allowing the
|
||||
model to leverage voice activity information for more accurate turn detection.
|
||||
|
||||
Note: This analyzer uses a different model than KrispVivaFilter. The model path
|
||||
can be specified via the KRISP_VIVA_TURN_MODEL_PATH environment variable or
|
||||
@@ -33,7 +35,7 @@ try:
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use KrispVivaTurn, you need to install krisp_audio.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
raise ImportError(f"Missing module: {e}") from e
|
||||
|
||||
|
||||
class KrispTurnParams(BaseTurnParams):
|
||||
@@ -53,8 +55,10 @@ class KrispTurnParams(BaseTurnParams):
|
||||
class KrispVivaTurn(BaseTurnAnalyzer):
|
||||
"""Turn analyzer using Krisp VIVA SDK for end-of-turn detection.
|
||||
|
||||
Uses Krisp's turn detection (Tt) API to determine when a user has finished
|
||||
speaking. This analyzer requires a valid Krisp model file to operate.
|
||||
Uses Krisp's turn detection v3 (Tt) API to determine when a user has
|
||||
finished speaking. The Tt API receives an external VAD flag with each
|
||||
audio frame, which the ``is_speech`` parameter of ``append_audio``
|
||||
provides. This analyzer requires a valid Krisp model file to operate.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -158,14 +162,14 @@ class KrispVivaTurn(BaseTurnAnalyzer):
|
||||
"""Create a turn detection session with the specified sample rate.
|
||||
|
||||
Args:
|
||||
sample_rate: Sample rate for the session
|
||||
sample_rate: Sample rate for the session.
|
||||
|
||||
Returns:
|
||||
krisp_audio.TtFloat instance
|
||||
krisp_audio.TtFloat instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If sample rate or frame duration is not supported
|
||||
RuntimeError: If session creation fails
|
||||
ValueError: If sample rate or frame duration is not supported.
|
||||
RuntimeError: If session creation fails.
|
||||
"""
|
||||
try:
|
||||
model_info = krisp_audio.ModelInfo()
|
||||
@@ -306,12 +310,7 @@ class KrispVivaTurn(BaseTurnAnalyzer):
|
||||
# Instead, we wait for the model's probability check below to confirm
|
||||
# end-of-turn based on the threshold.
|
||||
|
||||
prob = self._tt_session.process(frame.tolist())
|
||||
|
||||
# Negative values indicate the model is not ready yet (working with 100ms data)
|
||||
# Skip processing until we get positive probabilities
|
||||
if prob < 0:
|
||||
continue
|
||||
prob = self._tt_session.process(frame.tolist(), is_speech, False)
|
||||
|
||||
# Store the probability for external access
|
||||
self._last_probability = prob
|
||||
|
||||
@@ -11,9 +11,15 @@ from .transcription_user_turn_start_strategy import TranscriptionUserTurnStartSt
|
||||
from .vad_user_turn_start_strategy import VADUserTurnStartStrategy
|
||||
from .wake_phrase_user_turn_start_strategy import WakePhraseUserTurnStartStrategy
|
||||
|
||||
try:
|
||||
from .krisp_viva_ip_user_turn_start_strategy import KrispVivaIPUserTurnStartStrategy
|
||||
except ImportError:
|
||||
KrispVivaIPUserTurnStartStrategy = None
|
||||
|
||||
__all__ = [
|
||||
"BaseUserTurnStartStrategy",
|
||||
"ExternalUserTurnStartStrategy",
|
||||
"KrispVivaIPUserTurnStartStrategy",
|
||||
"MinWordsUserTurnStartStrategy",
|
||||
"TranscriptionUserTurnStartStrategy",
|
||||
"UserTurnStartedParams",
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""User turn start strategy using Krisp Interruption Prediction (IP).
|
||||
|
||||
This strategy uses Krisp's IP model to distinguish genuine user interruptions
|
||||
from backchannels (e.g. "uh-huh", "yeah"). Instead of triggering a user turn
|
||||
on every VAD speech event, it collects audio after VAD detects speech and runs
|
||||
the IP model to predict whether the speech is a real interruption.
|
||||
|
||||
Only when the IP model's probability exceeds the configured threshold is
|
||||
``trigger_user_turn_started()`` called. This prevents the bot from being
|
||||
interrupted by brief acknowledgements or filler words.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.krisp_instance import (
|
||||
KrispVivaSDKManager,
|
||||
int_to_krisp_frame_duration,
|
||||
int_to_krisp_sample_rate,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.types import ProcessFrameResult
|
||||
from pipecat.turns.user_start.base_user_turn_start_strategy import BaseUserTurnStartStrategy
|
||||
|
||||
try:
|
||||
import krisp_audio
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use KrispVivaIPUserTurnStartStrategy, you need to install krisp_audio."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class KrispVivaIPUserTurnStartStrategy(BaseUserTurnStartStrategy):
|
||||
"""User turn start strategy using Krisp VIVA Interruption Prediction.
|
||||
|
||||
When VAD detects user speech, this strategy feeds audio frames into
|
||||
the Krisp VIVA IP model. The model outputs a probability indicating
|
||||
whether the speech is a genuine interruption (as opposed to a
|
||||
backchannel). A user turn is triggered only when this probability
|
||||
exceeds the configured threshold.
|
||||
|
||||
This strategy is designed to work alongside other start strategies
|
||||
(e.g. ``TranscriptionUserTurnStartStrategy`` as a fallback) via the
|
||||
strategy list in ``UserTurnStrategies``.
|
||||
|
||||
Example::
|
||||
|
||||
from pipecat.turns.user_start import KrispVivaIPUserTurnStartStrategy
|
||||
|
||||
strategies = UserTurnStrategies(
|
||||
start=[
|
||||
KrispVivaIPUserTurnStartStrategy(
|
||||
model_path="/path/to/ip_model.kef",
|
||||
threshold=0.5,
|
||||
),
|
||||
TranscriptionUserTurnStartStrategy(),
|
||||
],
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_path: str | None = None,
|
||||
threshold: float = 0.5,
|
||||
frame_duration_ms: int = 20,
|
||||
api_key: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Krisp VIVA IP user turn start strategy.
|
||||
|
||||
Args:
|
||||
model_path: Path to the Krisp VIVA IP model file (.kef). If None,
|
||||
uses the KRISP_VIVA_IP_MODEL_PATH environment variable.
|
||||
threshold: IP probability threshold (0.0 to 1.0). When the model's
|
||||
output exceeds this value, the speech is classified as a genuine
|
||||
interruption.
|
||||
frame_duration_ms: Frame duration in milliseconds for IP processing.
|
||||
Supported values: 10, 15, 20, 30, 32.
|
||||
api_key: Krisp SDK API key. If empty, falls back to the
|
||||
KRISP_VIVA_API_KEY environment variable.
|
||||
**kwargs: Additional arguments passed to BaseUserTurnStartStrategy.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._threshold = threshold
|
||||
self._frame_duration_ms = frame_duration_ms
|
||||
self._api_key = api_key
|
||||
|
||||
self._model_path = model_path or os.getenv("KRISP_VIVA_IP_MODEL_PATH")
|
||||
if not self._model_path:
|
||||
raise ValueError(
|
||||
"IP model path must be provided via model_path or "
|
||||
"KRISP_VIVA_IP_MODEL_PATH environment variable."
|
||||
)
|
||||
if not self._model_path.endswith(".kef"):
|
||||
raise ValueError("Model is expected with .kef extension")
|
||||
if not os.path.isfile(self._model_path):
|
||||
raise FileNotFoundError(f"IP model file not found: {self._model_path}")
|
||||
|
||||
self._sdk_acquired = False
|
||||
self._ip_session = None
|
||||
self._samples_per_frame: int | None = None
|
||||
self._sample_rate: int | None = None
|
||||
|
||||
# State tracking
|
||||
self._speech_active = False
|
||||
self._audio_buffer = bytearray()
|
||||
self._decision_made = False
|
||||
|
||||
# Acquire SDK
|
||||
try:
|
||||
KrispVivaSDKManager.acquire(api_key=api_key)
|
||||
self._sdk_acquired = True
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize Krisp SDK: {e}")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Release Krisp SDK resources."""
|
||||
if self._sdk_acquired:
|
||||
try:
|
||||
self._ip_session = None
|
||||
KrispVivaSDKManager.release()
|
||||
self._sdk_acquired = False
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up Krisp VIVA IP strategy: {e}", exc_info=True)
|
||||
|
||||
def _ensure_session(self, sample_rate: int):
|
||||
"""Create or re-create the IP session when sample rate changes.
|
||||
|
||||
Args:
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
"""
|
||||
if self._sample_rate == sample_rate and self._ip_session is not None:
|
||||
return
|
||||
|
||||
self._sample_rate = sample_rate
|
||||
self._samples_per_frame = int((sample_rate * self._frame_duration_ms) / 1000)
|
||||
|
||||
model_info = krisp_audio.ModelInfo()
|
||||
model_info.path = self._model_path
|
||||
|
||||
ip_cfg = krisp_audio.IpSessionConfig()
|
||||
ip_cfg.inputSampleRate = int_to_krisp_sample_rate(sample_rate)
|
||||
ip_cfg.inputFrameDuration = int_to_krisp_frame_duration(self._frame_duration_ms)
|
||||
ip_cfg.modelInfo = model_info
|
||||
|
||||
self._ip_session = krisp_audio.IpFloat.create(ip_cfg)
|
||||
logger.debug(f"Krisp VIVA IP session created (sample_rate={sample_rate})")
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset speech tracking state for the next candidate interruption."""
|
||||
self._speech_active = False
|
||||
self._audio_buffer.clear()
|
||||
self._decision_made = False
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
await super().reset()
|
||||
self._reset_state()
|
||||
|
||||
async def process_frame(self, frame: Frame) -> ProcessFrameResult:
|
||||
"""Process a frame to detect genuine user interruptions.
|
||||
|
||||
On ``VADUserStartedSpeakingFrame``, begins collecting audio.
|
||||
On ``InputAudioRawFrame``, feeds audio through the IP model and
|
||||
triggers a user turn if the interruption probability exceeds the
|
||||
threshold.
|
||||
On ``VADUserStoppedSpeakingFrame`` or ``BotStoppedSpeakingFrame``,
|
||||
resets the candidate state.
|
||||
|
||||
Args:
|
||||
frame: The incoming frame.
|
||||
|
||||
Returns:
|
||||
STOP if a genuine interruption was detected, CONTINUE otherwise.
|
||||
"""
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
return await self._handle_vad_started(frame)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
return await self._handle_audio(frame)
|
||||
elif isinstance(frame, (VADUserStoppedSpeakingFrame, BotStoppedSpeakingFrame)):
|
||||
return await self._handle_reset(frame)
|
||||
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
async def _handle_vad_started(self, frame: VADUserStartedSpeakingFrame) -> ProcessFrameResult:
|
||||
"""Begin collecting audio for interruption classification.
|
||||
|
||||
Args:
|
||||
frame: The VAD speech-start frame.
|
||||
|
||||
Returns:
|
||||
Always CONTINUE; the decision is deferred until enough audio is processed.
|
||||
"""
|
||||
logger.trace("Krisp VIVA IP: VAD speech started, collecting audio for classification")
|
||||
self._speech_active = True
|
||||
self._audio_buffer.clear()
|
||||
self._decision_made = False
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
async def _handle_audio(self, frame: InputAudioRawFrame) -> ProcessFrameResult:
|
||||
"""Feed audio to the IP model and check for genuine interruption.
|
||||
|
||||
Args:
|
||||
frame: Raw audio input frame.
|
||||
|
||||
Returns:
|
||||
STOP if the model detects a genuine interruption, CONTINUE otherwise.
|
||||
"""
|
||||
if not self._speech_active or self._decision_made:
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
self._ensure_session(frame.sample_rate)
|
||||
|
||||
if self._ip_session is None or self._samples_per_frame is None:
|
||||
logger.warning("IP session not ready, skipping frame")
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
self._audio_buffer.extend(frame.audio)
|
||||
|
||||
total_samples = len(self._audio_buffer) // 2 # 2 bytes per int16 sample
|
||||
num_complete_frames = total_samples // self._samples_per_frame
|
||||
|
||||
if num_complete_frames == 0:
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
complete_samples_count = num_complete_frames * self._samples_per_frame
|
||||
bytes_to_process = complete_samples_count * 2
|
||||
|
||||
audio_to_process = bytes(self._audio_buffer[:bytes_to_process])
|
||||
self._audio_buffer = self._audio_buffer[bytes_to_process:]
|
||||
|
||||
audio_int16 = np.frombuffer(audio_to_process, dtype=np.int16)
|
||||
audio_float32 = audio_int16.astype(np.float32) / 32768.0
|
||||
frames = audio_float32.reshape(-1, self._samples_per_frame)
|
||||
|
||||
for ip_frame in frames:
|
||||
ip_prob = self._ip_session.process(ip_frame.tolist(), self._speech_active)
|
||||
|
||||
if ip_prob >= self._threshold:
|
||||
logger.debug(
|
||||
f"Krisp VIVA IP: genuine interruption detected (prob={ip_prob:.3f}, "
|
||||
f"threshold={self._threshold})"
|
||||
)
|
||||
self._decision_made = True
|
||||
await self.trigger_user_turn_started()
|
||||
return ProcessFrameResult.STOP
|
||||
|
||||
return ProcessFrameResult.CONTINUE
|
||||
|
||||
async def _handle_reset(
|
||||
self, frame: VADUserStoppedSpeakingFrame | BotStoppedSpeakingFrame
|
||||
) -> ProcessFrameResult:
|
||||
"""Reset state when the candidate interruption window ends.
|
||||
|
||||
Args:
|
||||
frame: The frame signaling end of speech or bot output.
|
||||
|
||||
Returns:
|
||||
Always CONTINUE.
|
||||
"""
|
||||
if self._speech_active:
|
||||
logger.trace("Krisp VIVA IP: speech segment ended, resetting state")
|
||||
self._reset_state()
|
||||
return ProcessFrameResult.CONTINUE
|
||||
236
tests/test_krisp_ip_user_turn_start_strategy.py
Normal file
236
tests/test_krisp_ip_user_turn_start_strategy.py
Normal file
@@ -0,0 +1,236 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Mock package version check before importing pipecat (development mode)
|
||||
_version_patcher = patch("importlib.metadata.version", return_value="0.0.0-dev")
|
||||
_version_patcher.start()
|
||||
|
||||
# Mock krisp_audio before any pipecat import that loads krisp_instance / VIVA IP strategy
|
||||
mock_krisp_audio = MagicMock()
|
||||
mock_krisp_audio.SamplingRate.Sr8000Hz = 8000
|
||||
mock_krisp_audio.SamplingRate.Sr16000Hz = 16000
|
||||
mock_krisp_audio.SamplingRate.Sr24000Hz = 24000
|
||||
mock_krisp_audio.SamplingRate.Sr32000Hz = 32000
|
||||
mock_krisp_audio.SamplingRate.Sr44100Hz = 44100
|
||||
mock_krisp_audio.SamplingRate.Sr48000Hz = 48000
|
||||
mock_krisp_audio.FrameDuration.Fd10ms = "10ms"
|
||||
mock_krisp_audio.FrameDuration.Fd15ms = "15ms"
|
||||
mock_krisp_audio.FrameDuration.Fd20ms = "20ms"
|
||||
mock_krisp_audio.FrameDuration.Fd30ms = "30ms"
|
||||
mock_krisp_audio.FrameDuration.Fd32ms = "32ms"
|
||||
|
||||
sys.modules["krisp_audio"] = mock_krisp_audio
|
||||
|
||||
mock_pipecat_krisp = MagicMock()
|
||||
sys.modules["pipecat_ai_krisp"] = mock_pipecat_krisp
|
||||
sys.modules["pipecat_ai_krisp.audio"] = MagicMock()
|
||||
sys.modules["pipecat_ai_krisp.audio.krisp_processor"] = MagicMock()
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.types import ProcessFrameResult
|
||||
from pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy import (
|
||||
KrispVivaIPUserTurnStartStrategy,
|
||||
)
|
||||
|
||||
STRATEGY_MODULE = "pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy"
|
||||
|
||||
|
||||
def _int16_silence(num_samples: int) -> bytes:
|
||||
return np.zeros(num_samples, dtype=np.int16).tobytes()
|
||||
|
||||
|
||||
class TestKrispVivaIPUserTurnStartStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for KrispVivaIPUserTurnStartStrategy with mocked krisp_audio."""
|
||||
|
||||
def setUp(self):
|
||||
self.temp_model_file = tempfile.NamedTemporaryFile(suffix=".kef", delete=False)
|
||||
self.temp_model_file.write(b"dummy")
|
||||
self.temp_model_file.close()
|
||||
self.model_path = self.temp_model_file.name
|
||||
|
||||
self.mock_krisp_audio = mock_krisp_audio
|
||||
self.mock_krisp_audio.reset_mock()
|
||||
self.mock_krisp_audio.ModelInfo.reset_mock()
|
||||
self.mock_krisp_audio.IpSessionConfig.reset_mock()
|
||||
self.mock_krisp_audio.IpFloat.reset_mock()
|
||||
|
||||
self.mock_model_info = MagicMock()
|
||||
self.mock_krisp_audio.ModelInfo.return_value = self.mock_model_info
|
||||
|
||||
self.mock_ip_cfg = MagicMock()
|
||||
self.mock_krisp_audio.IpSessionConfig.return_value = self.mock_ip_cfg
|
||||
|
||||
self.mock_ip_session = MagicMock()
|
||||
self.mock_krisp_audio.IpFloat.create.return_value = self.mock_ip_session
|
||||
|
||||
self.krisp_patch = patch(f"{STRATEGY_MODULE}.krisp_audio", self.mock_krisp_audio)
|
||||
self.krisp_patch.start()
|
||||
|
||||
self.sdk_patcher = patch(f"{STRATEGY_MODULE}.KrispVivaSDKManager")
|
||||
self.mock_sdk_manager = self.sdk_patcher.start()
|
||||
self.mock_sdk_manager.acquire = MagicMock()
|
||||
self.mock_sdk_manager.release = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
self.krisp_patch.stop()
|
||||
self.sdk_patcher.stop()
|
||||
if os.path.exists(self.model_path):
|
||||
os.unlink(self.model_path)
|
||||
|
||||
def _make_strategy(self, *, threshold: float = 0.5, frame_duration_ms: int = 20):
|
||||
return KrispVivaIPUserTurnStartStrategy(
|
||||
model_path=self.model_path,
|
||||
threshold=threshold,
|
||||
frame_duration_ms=frame_duration_ms,
|
||||
api_key="test-key",
|
||||
)
|
||||
|
||||
def _audio_frame(self, sample_rate: int = 16000, frame_duration_ms: int = 20):
|
||||
samples = int(sample_rate * frame_duration_ms / 1000)
|
||||
return InputAudioRawFrame(
|
||||
audio=_int16_silence(samples),
|
||||
sample_rate=sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
async def test_interruption_detected_emits_turn_and_stop(self):
|
||||
self.mock_ip_session.process = MagicMock(return_value=0.87)
|
||||
|
||||
strategy = self._make_strategy(threshold=0.5)
|
||||
try:
|
||||
fired = False
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy, params):
|
||||
nonlocal fired
|
||||
fired = True
|
||||
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
result = await strategy.process_frame(self._audio_frame())
|
||||
|
||||
self.assertTrue(fired)
|
||||
self.assertEqual(result, ProcessFrameResult.STOP)
|
||||
self.mock_ip_session.process.assert_called()
|
||||
finally:
|
||||
await strategy.cleanup()
|
||||
|
||||
async def test_backchannel_suppressed_no_event_continue(self):
|
||||
self.mock_ip_session.process = MagicMock(return_value=0.23)
|
||||
|
||||
strategy = self._make_strategy(threshold=0.5)
|
||||
try:
|
||||
fired = False
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy, params):
|
||||
nonlocal fired
|
||||
fired = True
|
||||
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
result = await strategy.process_frame(self._audio_frame())
|
||||
|
||||
self.assertFalse(fired)
|
||||
self.assertEqual(result, ProcessFrameResult.CONTINUE)
|
||||
finally:
|
||||
await strategy.cleanup()
|
||||
|
||||
async def test_reset_on_vad_stopped_clears_state(self):
|
||||
self.mock_ip_session.process = MagicMock(return_value=0.1)
|
||||
|
||||
strategy = self._make_strategy(threshold=0.5)
|
||||
try:
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
await strategy.process_frame(self._audio_frame())
|
||||
self.mock_ip_session.process.reset_mock()
|
||||
|
||||
await strategy.process_frame(VADUserStoppedSpeakingFrame())
|
||||
result = await strategy.process_frame(self._audio_frame())
|
||||
|
||||
self.assertEqual(result, ProcessFrameResult.CONTINUE)
|
||||
self.mock_ip_session.process.assert_not_called()
|
||||
finally:
|
||||
await strategy.cleanup()
|
||||
|
||||
async def test_reset_on_bot_stopped_clears_state(self):
|
||||
self.mock_ip_session.process = MagicMock(return_value=0.1)
|
||||
|
||||
strategy = self._make_strategy(threshold=0.5)
|
||||
try:
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
await strategy.process_frame(self._audio_frame())
|
||||
self.mock_ip_session.process.reset_mock()
|
||||
|
||||
await strategy.process_frame(BotStoppedSpeakingFrame())
|
||||
result = await strategy.process_frame(self._audio_frame())
|
||||
|
||||
self.assertEqual(result, ProcessFrameResult.CONTINUE)
|
||||
self.mock_ip_session.process.assert_not_called()
|
||||
finally:
|
||||
await strategy.cleanup()
|
||||
|
||||
async def test_no_op_before_vad_start(self):
|
||||
self.mock_ip_session.process = MagicMock(return_value=0.99)
|
||||
|
||||
strategy = self._make_strategy()
|
||||
try:
|
||||
result = await strategy.process_frame(self._audio_frame())
|
||||
self.assertEqual(result, ProcessFrameResult.CONTINUE)
|
||||
self.mock_ip_session.process.assert_not_called()
|
||||
finally:
|
||||
await strategy.cleanup()
|
||||
|
||||
async def test_decision_sticks_no_double_trigger(self):
|
||||
self.mock_ip_session.process = MagicMock(return_value=0.9)
|
||||
|
||||
strategy = self._make_strategy(threshold=0.5)
|
||||
try:
|
||||
count = 0
|
||||
|
||||
@strategy.event_handler("on_user_turn_started")
|
||||
async def on_user_turn_started(strategy, params):
|
||||
nonlocal count
|
||||
count += 1
|
||||
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
r1 = await strategy.process_frame(self._audio_frame())
|
||||
r2 = await strategy.process_frame(self._audio_frame())
|
||||
|
||||
self.assertEqual(r1, ProcessFrameResult.STOP)
|
||||
self.assertEqual(r2, ProcessFrameResult.CONTINUE)
|
||||
self.assertEqual(count, 1)
|
||||
finally:
|
||||
await strategy.cleanup()
|
||||
|
||||
async def test_unrelated_frames_continue(self):
|
||||
strategy = self._make_strategy()
|
||||
try:
|
||||
r1 = await strategy.process_frame(BotStartedSpeakingFrame())
|
||||
r2 = await strategy.process_frame(
|
||||
TranscriptionFrame(text="hi", user_id="", timestamp="")
|
||||
)
|
||||
self.assertEqual(r1, ProcessFrameResult.CONTINUE)
|
||||
self.assertEqual(r2, ProcessFrameResult.CONTINUE)
|
||||
finally:
|
||||
await strategy.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -62,8 +62,15 @@ class TestKrispVivaSDKManager:
|
||||
def setup_method(self):
|
||||
"""Reset mocks and SDK state before each test."""
|
||||
mock_krisp_audio.reset_mock()
|
||||
mock_krisp_audio.globalInit.side_effect = None
|
||||
mock_krisp_audio.getVersion.return_value = mock_version
|
||||
|
||||
# Ensure krisp_instance module uses THIS test's mock, not a stale
|
||||
# reference cached from a different test file's sys.modules entry.
|
||||
import pipecat.audio.krisp_instance as _ki
|
||||
|
||||
_ki.krisp_audio = mock_krisp_audio
|
||||
|
||||
# Reset the SDK manager state for clean tests
|
||||
# We access internal state to ensure tests are isolated
|
||||
with KrispVivaSDKManager._lock:
|
||||
|
||||
Reference in New Issue
Block a user