Improve user turn stop timing by triggering timeout from VAD stop
Refactor TranscriptionUserTurnStopStrategy and TurnAnalyzerUserTurnStopStrategy to use VADUserStoppedSpeakingFrame as the ground truth for when speech ended, rather than triggering timeouts from transcription frames.
This commit is contained in:
1
changelog/3637.added.3.md
Normal file
1
changelog/3637.added.3.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `RequestMetadataFrame` and metadata handling for `ServiceSwitcher` to ensure STT services correctly emit `STTMetadataFrame` when switching between services. Only the active service's metadata is propagated downstream, switching services triggers the newly active service to re-emit its metadata, and proper frame ordering is maintained at startup.
|
||||
6
changelog/3637.added.md
Normal file
6
changelog/3637.added.md
Normal file
@@ -0,0 +1,6 @@
|
||||
- Added `STTMetadataFrame` to broadcast STT service latency information at pipeline start.
|
||||
- STT services broadcast P99 time-to-final-segment (`ttfs_p99_latency`) to downstream processors
|
||||
- Turn stop strategies automatically configure their STT timeout from this metadata
|
||||
- Developers can override `ttfs_p99_latency` via constructor argument for custom deployments
|
||||
- Added measured P99 values for STT providers.
|
||||
- See [stt-benchmark](https://github.com/pipecat-ai/stt-benchmark) to measure latency for your configuration
|
||||
5
changelog/3637.changed.2.md
Normal file
5
changelog/3637.changed.2.md
Normal file
@@ -0,0 +1,5 @@
|
||||
- Improved user turn stop timing in `TranscriptionUserTurnStopStrategy` and `TurnAnalyzerUserTurnStopStrategy`.
|
||||
- Timeout now starts on `VADUserStoppedSpeakingFrame` for tighter, more predictable timing
|
||||
- Added support for finalized transcripts (`TranscriptionFrame.finalized=True`) to trigger earlier
|
||||
- Added fallback timeout for edge cases where transcripts arrive without VAD events
|
||||
- Removed `InterimTranscriptionFrame` handling (no longer affects timing)
|
||||
1
changelog/3637.changed.3.md
Normal file
1
changelog/3637.changed.3.md
Normal file
@@ -0,0 +1 @@
|
||||
- Updated the `VADUserStartedSpeakingFrame` to include `start_secs` and `timestamp` and `VADUserStoppedSpeakingFrame` to include `stop_secs` and `timestamp`, removing the need to separately handle the `SpeechControlParamsFrame` for VADParams values.
|
||||
1
changelog/3637.changed.4.md
Normal file
1
changelog/3637.changed.4.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Renamed `TranscriptionUserTurnStopStrategy` to `SpeechTimeoutUserTurnStopStrategy`. The old name is deprecated and will be removed in a future release.
|
||||
1
changelog/3637.changed.md
Normal file
1
changelog/3637.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Renamed `timeout` parameter to `user_speech_timeout` in `TranscriptionUserTurnStopStrategy`.
|
||||
1
changelog/3637.removed.md
Normal file
1
changelog/3637.removed.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed `timeout` parameter from `TurnAnalyzerUserTurnStopStrategy`. The timeout is now managed internally based on STT latency.
|
||||
@@ -12,6 +12,7 @@ and LLM processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
@@ -1270,16 +1271,32 @@ class EmulateUserStoppedSpeakingFrame(SystemFrame):
|
||||
|
||||
@dataclass
|
||||
class VADUserStartedSpeakingFrame(SystemFrame):
|
||||
"""Frame emitted when VAD definitively detects user started speaking."""
|
||||
"""Frame emitted when VAD definitively detects user started speaking.
|
||||
|
||||
pass
|
||||
Parameters:
|
||||
start_secs: The VAD start_secs duration that was used to confirm the user
|
||||
started speaking. This represents the speech duration that had to
|
||||
elapse before the VAD determined speech began.
|
||||
timestamp: Wall-clock time when the VAD made its determination.
|
||||
"""
|
||||
|
||||
start_secs: float = 0.0
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VADUserStoppedSpeakingFrame(SystemFrame):
|
||||
"""Frame emitted when VAD definitively detects user stopped speaking."""
|
||||
"""Frame emitted when VAD definitively detects user stopped speaking.
|
||||
|
||||
pass
|
||||
Parameters:
|
||||
stop_secs: The VAD stop_secs duration that was used to confirm the user
|
||||
stopped speaking. This represents the silence duration that had to
|
||||
elapse before the VAD determined speech ended.
|
||||
timestamp: Wall-clock time when the VAD made its determination.
|
||||
"""
|
||||
|
||||
stop_secs: float = 0.0
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1651,6 +1668,49 @@ class SpeechControlParamsFrame(SystemFrame):
|
||||
turn_params: Optional[BaseTurnParams] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceMetadataFrame(SystemFrame):
|
||||
"""Base metadata frame for services.
|
||||
|
||||
Broadcast by services at pipeline start to share service-specific
|
||||
configuration and performance characteristics with downstream processors.
|
||||
|
||||
Parameters:
|
||||
service_name: The name of the service broadcasting this metadata.
|
||||
"""
|
||||
|
||||
service_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTMetadataFrame(ServiceMetadataFrame):
|
||||
"""Metadata from STT service.
|
||||
|
||||
Broadcast by STT services to inform downstream processors (like turn
|
||||
strategies) about STT latency characteristics.
|
||||
|
||||
Parameters:
|
||||
ttfs_p99_latency: Time to final segment P99 latency in seconds.
|
||||
This is the expected time from when speech ends to when the
|
||||
final transcript is received, at the 99th percentile.
|
||||
"""
|
||||
|
||||
ttfs_p99_latency: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestMetadataFrame(ControlFrame):
|
||||
"""Request services to re-emit their metadata frames.
|
||||
|
||||
Used by ServiceSwitcher when switching active services to ensure
|
||||
downstream processors receive updated metadata from the newly active service.
|
||||
Services that receive this frame should re-push their metadata frame
|
||||
(e.g., STTMetadataFrame for STT services).
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# Task frames
|
||||
#
|
||||
|
||||
@@ -79,7 +79,7 @@ class UserBotLatencyLogObserver(BaseObserver):
|
||||
if isinstance(data.frame, VADUserStartedSpeakingFrame):
|
||||
self._user_stopped_time = 0
|
||||
elif isinstance(data.frame, VADUserStoppedSpeakingFrame):
|
||||
self._user_stopped_time = time.time()
|
||||
self._user_stopped_time = data.frame.timestamp
|
||||
elif isinstance(data.frame, (EndFrame, CancelFrame)):
|
||||
self._log_summary()
|
||||
elif isinstance(data.frame, BotStartedSpeakingFrame) and self._user_stopped_time:
|
||||
|
||||
@@ -73,7 +73,7 @@ class UserBotLatencyObserver(BaseObserver):
|
||||
self._user_stopped_time = None
|
||||
elif isinstance(data.frame, VADUserStoppedSpeakingFrame):
|
||||
# Record timestamp when user stops speaking
|
||||
self._user_stopped_time = time.time()
|
||||
self._user_stopped_time = data.frame.timestamp
|
||||
elif isinstance(data.frame, BotStartedSpeakingFrame) and self._user_stopped_time:
|
||||
# Calculate and emit latency
|
||||
latency = time.time() - self._user_stopped_time
|
||||
|
||||
@@ -13,6 +13,8 @@ from pipecat.frames.frames import (
|
||||
ControlFrame,
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
RequestMetadataFrame,
|
||||
ServiceMetadataFrame,
|
||||
ServiceSwitcherFrame,
|
||||
)
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
@@ -123,7 +125,18 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
self.strategy = strategy
|
||||
|
||||
class ServiceSwitcherFilter(FunctionFilter):
|
||||
"""An internal filter that allows frames to pass through to the wrapped service only if it's the active service."""
|
||||
"""An internal filter that gates frame flow based on active service.
|
||||
|
||||
Two filters "sandwich" each service, allowing frames through only
|
||||
when the wrapped service is active. The pipeline layout is::
|
||||
|
||||
DownstreamFilter → Service → UpstreamFilter
|
||||
|
||||
The filter names refer to which *direction* of frame flow they
|
||||
filter, not their physical position: the downstream filter sits
|
||||
*before* the service (filtering frames flowing into it) and the
|
||||
upstream filter sits *after* it (filtering frames flowing back out).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -136,7 +149,9 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
Args:
|
||||
wrapped_service: The service that this filter wraps.
|
||||
active_service: The currently active service.
|
||||
direction: The direction of frame flow to filter.
|
||||
direction: The direction of frame flow this filter gates
|
||||
(DOWNSTREAM for the filter before the service,
|
||||
UPSTREAM for the filter after it).
|
||||
"""
|
||||
self._wrapped_service = wrapped_service
|
||||
self._active_service = active_service
|
||||
@@ -149,19 +164,54 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
async def process_frame(self, frame, direction):
|
||||
"""Process a frame through the filter, handling special internal filter-updating frames."""
|
||||
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
|
||||
old_active = self._active_service
|
||||
self._active_service = frame.active_service
|
||||
# Two ServiceSwitcherFilters "sandwich" a service. Push the
|
||||
# frame only to update the other side of the sandwich, but
|
||||
# otherwise don't let it leave the sandwich.
|
||||
# Two ServiceSwitcherFilters "sandwich" a service. The
|
||||
# frame enters via the downstream filter first. Push it
|
||||
# through so the upstream filter also updates its state.
|
||||
if direction == self._direction:
|
||||
await self.push_frame(frame, direction)
|
||||
# This is the upstream filter (the second to update). At
|
||||
# this point both filters know the new active service, so
|
||||
# it's safe to request metadata — the resulting
|
||||
# ServiceMetadataFrame will pass both filters on its way
|
||||
# out. Only do this for the newly active service's sandwich.
|
||||
elif (
|
||||
self._direction == FrameDirection.UPSTREAM
|
||||
and self._wrapped_service == frame.active_service
|
||||
and old_active != self._wrapped_service
|
||||
):
|
||||
await self.push_frame(RequestMetadataFrame(), FrameDirection.UPSTREAM)
|
||||
return
|
||||
|
||||
# RequestMetadataFrame is pushed upstream by the upstream filter
|
||||
# (above) and consumed by the service. Guard against services
|
||||
# that don't consume it: only forward in the filter's own
|
||||
# direction (so it can reach the service) and only for the
|
||||
# active service. Block in all other cases to prevent it from
|
||||
# escaping the sandwich.
|
||||
if isinstance(frame, RequestMetadataFrame):
|
||||
if direction == self._direction and self._wrapped_service == self._active_service:
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
# Block ServiceMetadataFrame from inactive services.
|
||||
if isinstance(frame, ServiceMetadataFrame):
|
||||
if self._wrapped_service != self._active_service:
|
||||
return
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@dataclass
|
||||
class ServiceSwitcherFilterFrame(ControlFrame):
|
||||
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
|
||||
"""An internal frame used to update filter state on service switch.
|
||||
|
||||
Sent when a service switch occurs to update the active service in
|
||||
the sandwich filters and trigger metadata emission from the newly
|
||||
active service.
|
||||
"""
|
||||
|
||||
active_service: FrameProcessor
|
||||
|
||||
@@ -178,6 +228,7 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
def _make_pipeline_definition(
|
||||
service: FrameProcessor, strategy: ServiceSwitcherStrategy
|
||||
) -> Any:
|
||||
# Layout: DownstreamFilter → Service → UpstreamFilter
|
||||
return [
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
|
||||
@@ -664,10 +664,16 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self._queued_broadcast_frame(frame_cls, **kwargs)
|
||||
|
||||
async def _on_vad_speech_started(self, controller):
|
||||
await self._queued_broadcast_frame(VADUserStartedSpeakingFrame)
|
||||
await self._queued_broadcast_frame(
|
||||
VADUserStartedSpeakingFrame,
|
||||
start_secs=controller._vad_analyzer.params.start_secs,
|
||||
)
|
||||
|
||||
async def _on_vad_speech_stopped(self, controller):
|
||||
await self._queued_broadcast_frame(VADUserStoppedSpeakingFrame)
|
||||
await self._queued_broadcast_frame(
|
||||
VADUserStoppedSpeakingFrame,
|
||||
stop_secs=controller._vad_analyzer.params.stop_secs,
|
||||
)
|
||||
|
||||
async def _on_vad_speech_activity(self, controller):
|
||||
await self._queued_broadcast_frame(UserSpeakingFrame)
|
||||
|
||||
@@ -64,12 +64,18 @@ class VADProcessor(FrameProcessor):
|
||||
@self._vad_controller.event_handler("on_speech_started")
|
||||
async def on_speech_started(_controller):
|
||||
logger.debug(f"{self}: User started speaking")
|
||||
await self.broadcast_frame(VADUserStartedSpeakingFrame)
|
||||
await self.broadcast_frame(
|
||||
VADUserStartedSpeakingFrame,
|
||||
start_secs=_controller._vad_analyzer.params.start_secs,
|
||||
)
|
||||
|
||||
@self._vad_controller.event_handler("on_speech_stopped")
|
||||
async def on_speech_stopped(_controller):
|
||||
logger.debug(f"{self}: User stopped speaking")
|
||||
await self.broadcast_frame(VADUserStoppedSpeakingFrame)
|
||||
await self.broadcast_frame(
|
||||
VADUserStoppedSpeakingFrame,
|
||||
stop_secs=_controller._vad_analyzer.params.stop_secs,
|
||||
)
|
||||
|
||||
@self._vad_controller.event_handler("on_speech_activity")
|
||||
async def on_speech_activity(_controller):
|
||||
|
||||
@@ -12,7 +12,7 @@ WebSocket API for streaming audio transcription.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
@@ -29,6 +29,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_latency import ASSEMBLYAI_TTFS_P99
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -67,6 +68,7 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
api_endpoint_base_url: str = "wss://streaming.assemblyai.com/v3/ws",
|
||||
connection_params: AssemblyAIConnectionParams = AssemblyAIConnectionParams(),
|
||||
vad_force_turn_endpoint: bool = True,
|
||||
ttfs_p99_latency: Optional[float] = ASSEMBLYAI_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the AssemblyAI STT service.
|
||||
@@ -77,9 +79,13 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
api_endpoint_base_url: WebSocket endpoint URL. Defaults to AssemblyAI's streaming endpoint.
|
||||
connection_params: Connection configuration parameters. Defaults to AssemblyAIConnectionParams().
|
||||
vad_force_turn_endpoint: Whether to force turn endpoint on VAD stop. Defaults to True.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(sample_rate=connection_params.sample_rate, **kwargs)
|
||||
super().__init__(
|
||||
sample_rate=connection_params.sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._language = language
|
||||
|
||||
@@ -28,6 +28,7 @@ from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
|
||||
from pipecat.services.stt_latency import AWS_TRANSCRIBE_TTFS_P99
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -59,6 +60,7 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
region: Optional[str] = None,
|
||||
sample_rate: int = 16000,
|
||||
language: Language = Language.EN,
|
||||
ttfs_p99_latency: Optional[float] = AWS_TRANSCRIBE_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the AWS Transcribe STT service.
|
||||
@@ -70,9 +72,11 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
region: AWS region for the service.
|
||||
sample_rate: Audio sample rate in Hz. Must be 8000 or 16000. Defaults to 16000.
|
||||
language: Language for transcription. Defaults to English.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
|
||||
@@ -25,6 +25,7 @@ from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.azure.common import language_to_azure_language
|
||||
from pipecat.services.stt_latency import AZURE_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -63,6 +64,7 @@ class AzureSTTService(STTService):
|
||||
language: Language = Language.EN_US,
|
||||
sample_rate: Optional[int] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
ttfs_p99_latency: Optional[float] = AZURE_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Azure STT service.
|
||||
@@ -73,9 +75,11 @@ class AzureSTTService(STTService):
|
||||
language: Language for speech recognition. Defaults to English (US).
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
endpoint_id: Custom model endpoint id.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to parent STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
self._speech_config = SpeechConfig(
|
||||
subscription=api_key,
|
||||
|
||||
@@ -27,6 +27,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_latency import CARTESIA_TTFS_P99
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -137,6 +138,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
base_url: str = "",
|
||||
sample_rate: int = 16000,
|
||||
live_options: Optional[CartesiaLiveOptions] = None,
|
||||
ttfs_p99_latency: Optional[float] = CARTESIA_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize CartesiaSTTService with API key and options.
|
||||
@@ -146,10 +148,12 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
base_url: Custom API endpoint URL. If empty, uses default.
|
||||
sample_rate: Audio sample rate in Hz. Defaults to 16000.
|
||||
live_options: Configuration options for transcription service.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to parent STTService.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
default_options = CartesiaLiveOptions(
|
||||
model="ink-whisper",
|
||||
|
||||
@@ -161,7 +161,11 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
# was never destroyed.
|
||||
# So we can keep it here as false, because inside the method send_with_retry, it will
|
||||
# already try to reconnect if needed.
|
||||
super().__init__(sample_rate=sample_rate, reconnect_on_error=False, **kwargs)
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
reconnect_on_error=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_latency import DEEPGRAM_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -61,6 +62,7 @@ class DeepgramSTTService(STTService):
|
||||
live_options: Optional[LiveOptions] = None,
|
||||
addons: Optional[Dict] = None,
|
||||
should_interrupt: bool = True,
|
||||
ttfs_p99_latency: Optional[float] = DEEPGRAM_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram STT service.
|
||||
@@ -81,13 +83,15 @@ class DeepgramSTTService(STTService):
|
||||
.. deprecated:: 0.0.99
|
||||
This parameter will be removed along with `vad_events` support.
|
||||
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
|
||||
Note:
|
||||
The `vad_events` option in LiveOptions is deprecated as of version 0.0.99 and will be removed in a future version. Please use the Silero VAD instead.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
if url:
|
||||
import warnings
|
||||
|
||||
@@ -31,6 +31,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
|
||||
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -81,6 +82,7 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
region: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
live_options: Optional[LiveOptions] = None,
|
||||
ttfs_p99_latency: Optional[float] = DEEPGRAM_SAGEMAKER_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Deepgram SageMaker STT service.
|
||||
@@ -93,10 +95,12 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
live_options or defaults to the value from StartFrame.
|
||||
live_options: Deepgram LiveOptions for detailed configuration. If None,
|
||||
uses sensible defaults (nova-3 model, English, interim results enabled).
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
self._region = region
|
||||
|
||||
@@ -33,6 +33,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_latency import ELEVENLABS_REALTIME_TTFS_P99, ELEVENLABS_TTFS_P99
|
||||
from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -194,6 +195,7 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
model: str = "scribe_v2",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
ttfs_p99_latency: Optional[float] = ELEVENLABS_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the ElevenLabs STT service.
|
||||
@@ -205,10 +207,13 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
model: Model ID for transcription. Defaults to "scribe_v2".
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate.
|
||||
params: Configuration parameters for the STT service.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -436,6 +441,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
model: str = "scribe_v2_realtime",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
ttfs_p99_latency: Optional[float] = ELEVENLABS_REALTIME_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the ElevenLabs Realtime STT service.
|
||||
@@ -446,10 +452,13 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
model: Model ID for transcription. Defaults to "scribe_v2_realtime".
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate.
|
||||
params: Configuration parameters for the STT service.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to WebsocketSTTService.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
|
||||
from pipecat.services.stt_latency import FAL_TTFS_P99
|
||||
from pipecat.services.stt_service import SegmentedSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -173,6 +174,7 @@ class FalSTTService(SegmentedSTTService):
|
||||
api_key: Optional[str] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
ttfs_p99_latency: Optional[float] = FAL_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the FalSTTService with API key and parameters.
|
||||
@@ -181,10 +183,13 @@ class FalSTTService(SegmentedSTTService):
|
||||
api_key: Fal API key. If not provided, will check FAL_KEY environment variable.
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate.
|
||||
params: Configuration parameters for the Wizper API.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.services.gladia.config import GladiaInputParams
|
||||
from pipecat.services.stt_latency import GLADIA_TTFS_P99
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -205,6 +206,7 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
params: Optional[GladiaInputParams] = None,
|
||||
max_buffer_size: int = 1024 * 1024 * 20, # 20MB default buffer
|
||||
should_interrupt: bool = True,
|
||||
ttfs_p99_latency: Optional[float] = GLADIA_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gladia STT service.
|
||||
@@ -225,9 +227,11 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
max_buffer_size: Maximum size of audio buffer in bytes. Defaults to 20MB.
|
||||
should_interrupt: Determine whether the bot should be interrupted when
|
||||
Gladia VAD detects user speech. Defaults to True.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the STTService parent class.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
params = params or GladiaInputParams()
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_latency import GOOGLE_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -438,6 +439,7 @@ class GoogleSTTService(STTService):
|
||||
location: str = "global",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
ttfs_p99_latency: Optional[float] = GOOGLE_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Google STT service.
|
||||
@@ -448,9 +450,11 @@ class GoogleSTTService(STTService):
|
||||
location: Google Cloud location (e.g., "global", "us-central1").
|
||||
sample_rate: Audio sample rate in Hertz.
|
||||
params: Configuration parameters for the service.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
params = params or GoogleSTTService.InputParams()
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_latency import GRADIUM_TTFS_P99
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -94,6 +95,7 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
|
||||
params: Optional[InputParams] = None,
|
||||
json_config: Optional[str] = None,
|
||||
ttfs_p99_latency: Optional[float] = GRADIUM_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium STT service.
|
||||
@@ -107,9 +109,11 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
.. deprecated:: 0.0.101
|
||||
Use `params` instead for type-safe configuration.
|
||||
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(sample_rate=SAMPLE_RATE, **kwargs)
|
||||
super().__init__(sample_rate=SAMPLE_RATE, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
if json_config is not None:
|
||||
import warnings
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.services.stt_latency import GROQ_TTFS_P99
|
||||
from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
@@ -28,6 +29,7 @@ class GroqSTTService(BaseWhisperSTTService):
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
ttfs_p99_latency: Optional[float] = GROQ_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Groq STT service.
|
||||
@@ -39,6 +41,8 @@ class GroqSTTService(BaseWhisperSTTService):
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional text to guide the model's style or continue a previous segment.
|
||||
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to BaseWhisperSTTService.
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -48,6 +52,7 @@ class GroqSTTService(BaseWhisperSTTService):
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_latency import HATHORA_TTFS_P99
|
||||
from pipecat.services.stt_service import SegmentedSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -53,6 +54,7 @@ class HathoraSTTService(SegmentedSTTService):
|
||||
api_key: Optional[str] = None,
|
||||
base_url: str = "https://api.models.hathora.dev/inference/v1/stt",
|
||||
params: Optional[InputParams] = None,
|
||||
ttfs_p99_latency: Optional[float] = HATHORA_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Hathora STT service.
|
||||
@@ -66,10 +68,13 @@ class HathoraSTTService(SegmentedSTTService):
|
||||
provision one [here](https://models.hathora.dev/tokens).
|
||||
base_url: Base API URL for the Hathora STT service.
|
||||
params: Configuration parameters.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the parent class.
|
||||
"""
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
**kwargs,
|
||||
)
|
||||
self._model = model
|
||||
|
||||
@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_latency import NVIDIA_TTFS_P99
|
||||
from pipecat.services.stt_service import SegmentedSTTService, STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -117,6 +118,7 @@ class NvidiaSTTService(STTService):
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
use_ssl: bool = True,
|
||||
ttfs_p99_latency: Optional[float] = NVIDIA_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva STT service.
|
||||
@@ -128,9 +130,11 @@ class NvidiaSTTService(STTService):
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for NVIDIA Riva.
|
||||
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
params = params or NvidiaSTTService.InputParams()
|
||||
|
||||
@@ -413,6 +417,7 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
use_ssl: bool = True,
|
||||
ttfs_p99_latency: Optional[float] = NVIDIA_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva segmented STT service.
|
||||
@@ -424,9 +429,11 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
|
||||
params: Additional configuration parameters for NVIDIA Riva
|
||||
use_ssl: Whether to use SSL for the NVIDIA Riva server. Defaults to True.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
params = params or NvidiaSegmentedSTTService.InputParams()
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_latency import OPENAI_REALTIME_TTFS_P99, OPENAI_TTFS_P99
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -64,6 +65,7 @@ class OpenAISTTService(BaseWhisperSTTService):
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
ttfs_p99_latency: Optional[float] = OPENAI_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize OpenAI STT service.
|
||||
@@ -75,6 +77,8 @@ class OpenAISTTService(BaseWhisperSTTService):
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional text to guide the model's style or continue a previous segment.
|
||||
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to BaseWhisperSTTService.
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -84,6 +88,7 @@ class OpenAISTTService(BaseWhisperSTTService):
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -162,6 +167,7 @@ class OpenAIRealtimeSTTService(WebsocketSTTService):
|
||||
turn_detection: Optional[Union[dict, Literal[False]]] = False,
|
||||
noise_reduction: Optional[Literal["near_field", "far_field"]] = None,
|
||||
should_interrupt: bool = True,
|
||||
ttfs_p99_latency: Optional[float] = OPENAI_REALTIME_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the OpenAI Realtime STT service.
|
||||
@@ -187,6 +193,8 @@ class OpenAIRealtimeSTTService(WebsocketSTTService):
|
||||
should_interrupt: Whether to interrupt bot output when
|
||||
speech is detected by server-side VAD. Only applies when
|
||||
turn detection is enabled. Defaults to True.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to parent
|
||||
WebsocketSTTService.
|
||||
"""
|
||||
@@ -196,7 +204,10 @@ class OpenAIRealtimeSTTService(WebsocketSTTService):
|
||||
"Install it with: pip install pipecat-ai[openai]"
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.services.stt_latency import SAMBANOVA_TTFS_P99
|
||||
from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
@@ -30,6 +31,7 @@ class SambaNovaSTTService(BaseWhisperSTTService): # type: ignore
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
ttfs_p99_latency: Optional[float] = SAMBANOVA_TTFS_P99,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize SambaNova STT service.
|
||||
@@ -41,6 +43,8 @@ class SambaNovaSTTService(BaseWhisperSTTService): # type: ignore
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional text to guide the model's style or continue a previous segment.
|
||||
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to `pipecat.services.whisper.base_stt.BaseWhisperSTTService`.
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -50,6 +54,7 @@ class SambaNovaSTTService(BaseWhisperSTTService): # type: ignore
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
ttfs_p99_latency=ttfs_p99_latency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.sarvam._sdk import sdk_headers
|
||||
from pipecat.services.stt_latency import SARVAM_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -159,6 +160,7 @@ class SarvamSTTService(STTService):
|
||||
sample_rate: Optional[int] = None,
|
||||
input_audio_codec: str = "wav",
|
||||
params: Optional[InputParams] = None,
|
||||
ttfs_p99_latency: Optional[float] = SARVAM_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Sarvam STT service.
|
||||
@@ -172,6 +174,8 @@ class SarvamSTTService(STTService):
|
||||
sample_rate: Audio sample rate. Defaults to 16000 if not specified.
|
||||
input_audio_codec: Audio codec/format of the input file. Defaults to "wav".
|
||||
params: Configuration parameters for Sarvam STT service.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the parent STTService.
|
||||
"""
|
||||
params = params or SarvamSTTService.InputParams()
|
||||
@@ -193,7 +197,7 @@ class SarvamSTTService(STTService):
|
||||
f"Model '{model}' does not support language parameter (auto-detects language)."
|
||||
)
|
||||
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
self.set_model_name(model)
|
||||
self._api_key = api_key
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_latency import SONIOX_TTFS_P99
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -152,6 +153,7 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[SonioxInputParams] = None,
|
||||
vad_force_turn_endpoint: bool = False,
|
||||
ttfs_p99_latency: Optional[float] = SONIOX_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Soniox STT service.
|
||||
@@ -163,9 +165,11 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
params: Additional configuration parameters, such as language hints, context and
|
||||
speaker diarization.
|
||||
vad_force_turn_endpoint: Listen to `VADUserStoppedSpeakingFrame` to send finalize message to Soniox. If disabled, Soniox will detect the end of the speech.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
params = params or SonioxInputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
|
||||
@@ -31,6 +31,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_latency import SPEECHMATICS_TTFS_P99
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
@@ -288,6 +289,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
sample_rate: int | None = None,
|
||||
params: InputParams | None = None,
|
||||
should_interrupt: bool = True,
|
||||
ttfs_p99_latency: float | None = SPEECHMATICS_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Speechmatics STT service.
|
||||
@@ -300,9 +302,11 @@ class SpeechmaticsSTTService(STTService):
|
||||
sample_rate: Optional audio sample rate in Hz.
|
||||
params: Optional[InputParams]: Input parameters for the service.
|
||||
should_interrupt: Determine whether the bot should be interrupted when Speechmatics turn_detection_mode is configured to detect user speech.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
# Service parameters
|
||||
self._api_key: str = api_key or os.getenv("SPEECHMATICS_API_KEY")
|
||||
|
||||
53
src/pipecat/services/stt_latency.py
Normal file
53
src/pipecat/services/stt_latency.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""STT service latency defaults.
|
||||
|
||||
This module contains P99 time-to-final-segment (TTFS) latency values for STT
|
||||
services. TTFS measures the time from when speech ends to when the final
|
||||
transcript is received.
|
||||
|
||||
These values are used by turn stop strategies to optimize timing. Each STT
|
||||
service publishes its latency via STTMetadataFrame at pipeline start.
|
||||
|
||||
To measure latency for your specific deployment (region, network conditions,
|
||||
self-hosted instances), use the STT benchmark tool:
|
||||
https://github.com/pipecat-ai/stt-benchmark
|
||||
|
||||
Run the TTFS benchmark for your service and configuration, then pass the
|
||||
measured value to your STT service constructor:
|
||||
|
||||
stt = DeepgramSTTService(api_key="...", ttfs_p99_latency=0.45)
|
||||
"""
|
||||
|
||||
# Conservative fallback for services without measured values
|
||||
DEFAULT_TTFS_P99: float = 1.0
|
||||
|
||||
# Measured P99 TTFS latency values (in seconds)
|
||||
ASSEMBLYAI_TTFS_P99: float = 0.42
|
||||
AWS_TRANSCRIBE_TTFS_P99: float = 1.90
|
||||
AZURE_TTFS_P99: float = 1.80
|
||||
CARTESIA_TTFS_P99: float = 0.81
|
||||
DEEPGRAM_TTFS_P99: float = 0.35
|
||||
DEEPGRAM_SAGEMAKER_TTFS_P99: float = 0.35
|
||||
ELEVENLABS_TTFS_P99: float = 2.01
|
||||
ELEVENLABS_REALTIME_TTFS_P99: float = 0.41
|
||||
FAL_TTFS_P99: float = 2.07
|
||||
GLADIA_TTFS_P99: float = 1.49
|
||||
GOOGLE_TTFS_P99: float = 1.57
|
||||
GRADIUM_TTFS_P99: float = 1.61
|
||||
GROQ_TTFS_P99: float = 1.54
|
||||
HATHORA_TTFS_P99: float = 0.87
|
||||
OPENAI_TTFS_P99: float = 2.01
|
||||
OPENAI_REALTIME_TTFS_P99: float = 1.66
|
||||
SAMBANOVA_TTFS_P99: float = 2.20
|
||||
SARVAM_TTFS_P99: float = 1.17
|
||||
SONIOX_TTFS_P99: float = 0.35
|
||||
SPEECHMATICS_TTFS_P99: float = 0.74
|
||||
|
||||
# These services run locally and should be replaced with measured values
|
||||
NVIDIA_TTFS_P99: float = DEFAULT_TTFS_P99
|
||||
WHISPER_TTFS_P99: float = DEFAULT_TTFS_P99
|
||||
@@ -21,8 +21,9 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
MetricsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
RequestMetadataFrame,
|
||||
StartFrame,
|
||||
STTMetadataFrame,
|
||||
STTMuteFrame,
|
||||
STTUpdateSettingsFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -32,6 +33,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.metrics.metrics import TTFBMetricsData
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.stt_latency import DEFAULT_TTFS_P99
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
@@ -65,11 +67,11 @@ class STTService(AIService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
audio_passthrough=True,
|
||||
# STT input sample rate
|
||||
sample_rate: Optional[int] = None,
|
||||
# STT TTFB timeout - time to wait after VAD stop before reporting TTFB
|
||||
stt_ttfb_timeout: float = 2.0,
|
||||
ttfs_p99_latency: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the STT service.
|
||||
@@ -85,6 +87,10 @@ class STTService(AIService):
|
||||
request to first response byte). Since STT receives continuous audio, we measure
|
||||
from when the user stops speaking to when the final transcript arrives—capturing
|
||||
the latency that matters for voice AI applications.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
This is broadcast via STTMetadataFrame at pipeline start for downstream
|
||||
processors (e.g., turn strategies) to optimize timing. Subclasses provide
|
||||
measured defaults; pass a value here to override for your deployment.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
@@ -95,11 +101,11 @@ class STTService(AIService):
|
||||
self._tracing_enabled: bool = False
|
||||
self._muted: bool = False
|
||||
self._user_id: str = ""
|
||||
self._ttfs_p99_latency = ttfs_p99_latency
|
||||
|
||||
# STT TTFB tracking state
|
||||
self._stt_ttfb_timeout = stt_ttfb_timeout
|
||||
self._ttfb_timeout_task: Optional[asyncio.Task] = None
|
||||
self._vad_stop_secs: Optional[float] = None
|
||||
self._speech_end_time: Optional[float] = None
|
||||
self._user_speaking: bool = False
|
||||
self._last_transcription_time: Optional[float] = None
|
||||
@@ -254,16 +260,20 @@ class STTService(AIService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
if isinstance(frame, StartFrame):
|
||||
# Push StartFrame first, then metadata so downstream receives them in order
|
||||
await self.push_frame(frame, direction)
|
||||
await self._push_stt_metadata()
|
||||
elif isinstance(frame, RequestMetadataFrame):
|
||||
# Don't push the RequestMetadataFrame, just push the metadata
|
||||
await self._push_stt_metadata()
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
# In this service we accumulate audio internally and at the end we
|
||||
# push a TextFrame. We also push audio downstream in case someone
|
||||
# else needs it.
|
||||
await self.process_audio_frame(frame, direction)
|
||||
if self._audio_passthrough:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, SpeechControlParamsFrame):
|
||||
await self._handle_speech_control_params(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -314,14 +324,13 @@ class STTService(AIService):
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def _handle_speech_control_params(self, frame: SpeechControlParamsFrame):
|
||||
"""Handle speech control parameters frame to extract VAD stop_secs.
|
||||
|
||||
Args:
|
||||
frame: The speech control parameters frame.
|
||||
"""
|
||||
if frame.vad_params is not None:
|
||||
self._vad_stop_secs = frame.vad_params.stop_secs
|
||||
async def _push_stt_metadata(self):
|
||||
"""Push STT metadata frame for downstream processors (e.g., turn strategies)."""
|
||||
ttfs = self._ttfs_p99_latency
|
||||
if ttfs is None:
|
||||
ttfs = DEFAULT_TTFS_P99
|
||||
logger.warning(f"{self.name}: ttfs_p99_latency not set, using default {ttfs}s")
|
||||
await self.broadcast_frame(STTMetadataFrame, service_name=self.name, ttfs_p99_latency=ttfs)
|
||||
|
||||
async def _cancel_ttfb_timeout(self):
|
||||
"""Cancel any pending TTFB timeout task."""
|
||||
@@ -369,14 +378,14 @@ class STTService(AIService):
|
||||
"""
|
||||
self._user_speaking = False
|
||||
|
||||
# Skip TTFB measurement if we don't have VAD params
|
||||
if self._vad_stop_secs is None:
|
||||
# Skip TTFB measurement if stop_secs is not set
|
||||
if frame.stop_secs == 0.0:
|
||||
return
|
||||
|
||||
# Calculate the actual speech end time (current time minus VAD stop delay).
|
||||
# This approximates when the last user audio was sent to the STT service,
|
||||
# which we use to measure against the eventual transcription response.
|
||||
self._speech_end_time = time.time() - self._vad_stop_secs
|
||||
self._speech_end_time = frame.timestamp - frame.stop_secs
|
||||
|
||||
# Start timeout task (any previous timeout was cancelled by VADUserStartedSpeakingFrame
|
||||
# or InterruptionFrame)
|
||||
|
||||
@@ -17,6 +17,7 @@ from openai import AsyncOpenAI
|
||||
from openai.types.audio import Transcription
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
|
||||
from pipecat.services.stt_latency import WHISPER_TTFS_P99
|
||||
from pipecat.services.stt_service import SegmentedSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
@@ -115,6 +116,7 @@ class BaseWhisperSTTService(SegmentedSTTService):
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
include_prob_metrics: bool = False,
|
||||
ttfs_p99_latency: Optional[float] = WHISPER_TTFS_P99,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Whisper STT service.
|
||||
@@ -129,9 +131,11 @@ class BaseWhisperSTTService(SegmentedSTTService):
|
||||
include_prob_metrics: If True, enables probability metrics in API response.
|
||||
Each service implements this differently (see child classes).
|
||||
Defaults to False.
|
||||
ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
self.set_model_name(model)
|
||||
self._client = self._create_client(api_key, base_url)
|
||||
self._language = self.language_to_service_language(language or Language.EN)
|
||||
|
||||
@@ -464,7 +464,12 @@ class BaseInputTransport(FrameProcessor):
|
||||
if self._params.turn_analyzer:
|
||||
await self._deprecated_handle_user_interruption(VADState.QUIET)
|
||||
else:
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame())
|
||||
stop_secs = (
|
||||
self._params.vad_analyzer.params.stop_secs
|
||||
if self._params.vad_analyzer
|
||||
else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame(stop_secs=stop_secs))
|
||||
###################################################################
|
||||
|
||||
#
|
||||
@@ -492,9 +497,17 @@ class BaseInputTransport(FrameProcessor):
|
||||
and new_vad_state != VADState.STOPPING
|
||||
):
|
||||
if new_vad_state == VADState.SPEAKING:
|
||||
await self.push_frame(VADUserStartedSpeakingFrame())
|
||||
start_secs = (
|
||||
self._params.vad_analyzer.params.start_secs
|
||||
if self._params.vad_analyzer
|
||||
else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStartedSpeakingFrame(start_secs=start_secs))
|
||||
elif new_vad_state == VADState.QUIET:
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame())
|
||||
stop_secs = (
|
||||
self._params.vad_analyzer.params.stop_secs if self._params.vad_analyzer else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame(stop_secs=stop_secs))
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
@@ -574,11 +587,19 @@ class BaseInputTransport(FrameProcessor):
|
||||
or not self._params.turn_analyzer.speech_triggered
|
||||
)
|
||||
if new_vad_state == VADState.SPEAKING:
|
||||
await self.push_frame(VADUserStartedSpeakingFrame())
|
||||
start_secs = (
|
||||
self._params.vad_analyzer.params.start_secs
|
||||
if self._params.vad_analyzer
|
||||
else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStartedSpeakingFrame(start_secs=start_secs))
|
||||
if can_create_user_frames:
|
||||
interruption_state = VADState.SPEAKING
|
||||
elif new_vad_state == VADState.QUIET:
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame())
|
||||
stop_secs = (
|
||||
self._params.vad_analyzer.params.stop_secs if self._params.vad_analyzer else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame(stop_secs=stop_secs))
|
||||
if can_create_user_frames:
|
||||
interruption_state = VADState.QUIET
|
||||
|
||||
|
||||
@@ -6,13 +6,13 @@
|
||||
|
||||
from .base_user_turn_stop_strategy import BaseUserTurnStopStrategy, UserTurnStoppedParams
|
||||
from .external_user_turn_stop_strategy import ExternalUserTurnStopStrategy
|
||||
from .transcription_user_turn_stop_strategy import TranscriptionUserTurnStopStrategy
|
||||
from .speech_timeout_user_turn_stop_strategy import SpeechTimeoutUserTurnStopStrategy
|
||||
from .turn_analyzer_user_turn_stop_strategy import TurnAnalyzerUserTurnStopStrategy
|
||||
|
||||
__all__ = [
|
||||
"BaseUserTurnStopStrategy",
|
||||
"ExternalUserTurnStopStrategy",
|
||||
"SpeechTimeoutUserTurnStopStrategy",
|
||||
"UserTurnStoppedParams",
|
||||
"TranscriptionUserTurnStopStrategy",
|
||||
"TurnAnalyzerUserTurnStopStrategy",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Speech timeout-based user turn stop strategy."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
STTMetadataFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.user_stop.base_user_turn_stop_strategy import BaseUserTurnStopStrategy
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
|
||||
class SpeechTimeoutUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
"""User turn stop strategy that uses a configurable timeout to determine if the user is done speaking.
|
||||
|
||||
After the user stops speaking (detected by VAD), this strategy waits for a
|
||||
configurable timeout before triggering the end of the user's turn. The
|
||||
timeout accounts for two factors:
|
||||
|
||||
- user_speech_timeout: Time to wait for the user to potentially say more
|
||||
after they pause.
|
||||
- stt_timeout: The P99 time for the STT service to return a transcription
|
||||
after the user stops speaking, adjusted by the VAD stop_secs.
|
||||
|
||||
For services that support finalization (TranscriptionFrame.finalized=True),
|
||||
the turn can be triggered immediately once the finalized transcript is
|
||||
received and the user resume speaking timeout has elapsed.
|
||||
"""
|
||||
|
||||
def __init__(self, *, user_speech_timeout: float = 0.6, **kwargs):
|
||||
"""Initialize the speech timeout-based user turn stop strategy.
|
||||
|
||||
Args:
|
||||
user_speech_timeout: Time to wait for the user to potentially
|
||||
say more after they pause speaking. Defaults to 0.6 seconds.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._user_speech_timeout = user_speech_timeout
|
||||
self._stt_timeout: float = 0.0 # STT P99 latency from STTMetadataFrame
|
||||
self._stop_secs: float = 0.0 # VAD stop_secs from VADUserStoppedSpeakingFrame
|
||||
|
||||
self._text = ""
|
||||
self._vad_user_speaking = False
|
||||
self._transcript_finalized = False
|
||||
self._vad_stopped_time: Optional[float] = None
|
||||
self._timeout_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
await super().reset()
|
||||
self._text = ""
|
||||
self._vad_user_speaking = False
|
||||
self._transcript_finalized = False
|
||||
self._vad_stopped_time = None
|
||||
|
||||
async def setup(self, task_manager: BaseTaskManager):
|
||||
"""Initialize the strategy with the given task manager.
|
||||
|
||||
Args:
|
||||
task_manager: The task manager to be associated with this instance.
|
||||
"""
|
||||
await super().setup(task_manager)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup the strategy."""
|
||||
await super().cleanup()
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
self._timeout_task = None
|
||||
|
||||
async def process_frame(self, frame: Frame):
|
||||
"""Process an incoming frame to update strategy state.
|
||||
|
||||
Updates internal transcription text and VAD state. The user end turn
|
||||
will be triggered when appropriate based on the collected frames.
|
||||
|
||||
Args:
|
||||
frame: The frame to be analyzed.
|
||||
|
||||
"""
|
||||
if isinstance(frame, STTMetadataFrame):
|
||||
self._stt_timeout = frame.ttfs_p99_latency
|
||||
elif isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._handle_vad_user_stopped_speaking(frame)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, _: VADUserStartedSpeakingFrame):
|
||||
"""Handle when the VAD indicates the user is speaking."""
|
||||
self._vad_user_speaking = True
|
||||
self._transcript_finalized = False
|
||||
self._vad_stopped_time = None
|
||||
# Cancel any pending timeout
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
self._timeout_task = None
|
||||
|
||||
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
|
||||
"""Handle when the VAD indicates the user has stopped speaking."""
|
||||
self._vad_user_speaking = False
|
||||
self._stop_secs = frame.stop_secs
|
||||
self._vad_stopped_time = frame.timestamp
|
||||
|
||||
# Start the timeout task
|
||||
timeout = self._calculate_timeout()
|
||||
self._timeout_task = self.task_manager.create_task(
|
||||
self._timeout_handler(timeout), f"{self}::_timeout_handler"
|
||||
)
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
"""Handle user transcription."""
|
||||
self._text += frame.text
|
||||
if frame.finalized:
|
||||
self._transcript_finalized = True
|
||||
# For finalized transcripts, check if we can trigger early
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
|
||||
# Fallback: handle transcripts when no VAD stop was received.
|
||||
# This handles edge cases where transcripts arrive without VAD firing.
|
||||
# _vad_stopped_time is None means VAD stopped hasn't been received yet.
|
||||
# In fallback mode, reset timeout on each transcript to wait for inactivity.
|
||||
if not self._vad_user_speaking and self._vad_stopped_time is None:
|
||||
# Cancel existing fallback timeout if any
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
timeout = self._calculate_timeout()
|
||||
self._timeout_task = self.task_manager.create_task(
|
||||
self._timeout_handler(timeout), f"{self}::_timeout_handler"
|
||||
)
|
||||
|
||||
def _calculate_timeout(self) -> float:
|
||||
"""Calculate the timeout value based on current state.
|
||||
|
||||
Returns:
|
||||
The timeout in seconds to wait after VAD stopped speaking.
|
||||
"""
|
||||
# Adjust STT timeout by VAD stop_secs since that time has already elapsed
|
||||
effective_stt_wait = max(0, self._stt_timeout - self._stop_secs)
|
||||
|
||||
# If transcript is already finalized, we don't need to wait for STT
|
||||
if self._transcript_finalized:
|
||||
return self._user_speech_timeout
|
||||
|
||||
return max(effective_stt_wait, self._user_speech_timeout)
|
||||
|
||||
async def _timeout_handler(self, timeout: float):
|
||||
"""Wait for the timeout then trigger user turn stopped if conditions met.
|
||||
|
||||
Args:
|
||||
timeout: The timeout in seconds to wait.
|
||||
"""
|
||||
try:
|
||||
await asyncio.sleep(timeout)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
self._timeout_task = None
|
||||
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
|
||||
async def _maybe_trigger_user_turn_stopped(self):
|
||||
"""Trigger user turn stopped if conditions are met.
|
||||
|
||||
Conditions:
|
||||
- User is not currently speaking
|
||||
- We have transcription text
|
||||
- Either the timeout has elapsed OR we have a finalized transcript
|
||||
and user_speech_timeout has elapsed
|
||||
"""
|
||||
if self._vad_user_speaking or not self._text:
|
||||
return
|
||||
|
||||
# For finalized transcripts, check if user_speech_timeout has elapsed.
|
||||
# If elapsed, trigger user turn stopped immediately. Else, wait for user resume
|
||||
# speaking timeout.
|
||||
if self._transcript_finalized and self._vad_stopped_time is not None:
|
||||
elapsed = time.time() - self._vad_stopped_time
|
||||
if elapsed >= self._user_speech_timeout:
|
||||
# Cancel any remaining timeout since we're triggering now
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
self._timeout_task = None
|
||||
await self.trigger_user_turn_stopped()
|
||||
return
|
||||
|
||||
# For non-finalized, only trigger if timeout task has completed
|
||||
if self._timeout_task is None:
|
||||
await self.trigger_user_turn_stopped()
|
||||
@@ -4,124 +4,28 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Transcription time-based user turn stop strategy."""
|
||||
"""Transcription-based user turn stop strategy (deprecated).
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
.. deprecated:: 0.0.102
|
||||
This module is deprecated. Please use
|
||||
``pipecat.turns.user_stop.speech_timeout_user_turn_stop_strategy.SpeechTimeoutUserTurnStopStrategy``
|
||||
instead.
|
||||
"""
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
import warnings
|
||||
|
||||
from pipecat.turns.user_stop.speech_timeout_user_turn_stop_strategy import (
|
||||
SpeechTimeoutUserTurnStopStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop.base_user_turn_stop_strategy import BaseUserTurnStopStrategy
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"TranscriptionUserTurnStopStrategy is deprecated. "
|
||||
"Please use SpeechTimeoutUserTurnStopStrategy from "
|
||||
"pipecat.turns.user_stop.speech_timeout_user_turn_stop_strategy instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
class TranscriptionUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
"""User turn stop strategy based on transcriptions.
|
||||
|
||||
This strategy assumes the user stops speaking once a transcription has been
|
||||
received. It handles multiple or delayed transcription frames gracefully.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *, timeout: float = 0.5, **kwargs):
|
||||
"""Initialize the transcription-based user turn stop strategy.
|
||||
|
||||
Args:
|
||||
timeout: A short delay used internally to handle consecutive or
|
||||
slightly delayed transcriptions.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._timeout = timeout
|
||||
self._text = ""
|
||||
self._vad_user_speaking = False
|
||||
self._seen_interim_results = False
|
||||
self._event = asyncio.Event()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
await super().reset()
|
||||
self._text = ""
|
||||
self._vad_user_speaking = False
|
||||
self._seen_interim_results = False
|
||||
self._event.clear()
|
||||
|
||||
async def setup(self, task_manager: BaseTaskManager):
|
||||
"""Initialize the strategy with the given task manager.
|
||||
|
||||
Args:
|
||||
task_manager: The task manager to be associated with this instance.
|
||||
"""
|
||||
await super().setup(task_manager)
|
||||
self._task = task_manager.create_task(self._task_handler(), f"{self}::_task_handler")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup the strategy."""
|
||||
await super().cleanup()
|
||||
if self._task:
|
||||
await self.task_manager.cancel_task(self._task)
|
||||
self._task = None
|
||||
|
||||
async def process_frame(self, frame: Frame):
|
||||
"""Process an incoming frame to update strategy state.
|
||||
|
||||
Updates internal transcription text and VAD state. The user end turn
|
||||
will be triggered when appropriate based on the collected frames.
|
||||
|
||||
Args:
|
||||
frame: The frame to be analyzed.
|
||||
|
||||
"""
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._handle_vad_user_stopped_speaking(frame)
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._handle_interim_transcription(frame)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, _: VADUserStartedSpeakingFrame):
|
||||
"""Handle when the VAD indicates the user is speaking."""
|
||||
self._vad_user_speaking = True
|
||||
|
||||
async def _handle_vad_user_stopped_speaking(self, _: VADUserStoppedSpeakingFrame):
|
||||
"""Handle when the VAD indicates the user has stopped speaking."""
|
||||
self._vad_user_speaking = False
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
|
||||
async def _handle_interim_transcription(self, frame: InterimTranscriptionFrame):
|
||||
self._seen_interim_results = True
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
"""Handle user transcription."""
|
||||
self._text += frame.text
|
||||
# We just got a final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Reset aggregation timer.
|
||||
self._event.set()
|
||||
|
||||
async def _task_handler(self):
|
||||
"""Asynchronously monitor transcriptions and trigger user end turn when ready.
|
||||
|
||||
If transcription text exists and the user is not currently speaking,
|
||||
triggers the user end turn. Handles multiple or delayed transcriptions
|
||||
gracefully.
|
||||
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(self._event.wait(), timeout=self._timeout)
|
||||
self._event.clear()
|
||||
except asyncio.TimeoutError:
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
|
||||
async def _maybe_trigger_user_turn_stopped(self):
|
||||
if not self._vad_user_speaking and not self._seen_interim_results and self._text:
|
||||
await self.trigger_user_turn_stopped()
|
||||
TranscriptionUserTurnStopStrategy = SpeechTimeoutUserTurnStopStrategy
|
||||
|
||||
@@ -13,10 +13,10 @@ from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnSta
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
MetricsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
STTMetadataFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
@@ -27,30 +27,38 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
|
||||
class TurnAnalyzerUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
"""User turn stop strategy using a turn detection model to detect end of user turn.
|
||||
"""User turn stop strategy that uses a turn detection model to determine if the user is done speaking.
|
||||
|
||||
This strategy uses the turn detection models to determine when the user has
|
||||
finished speaking, combining audio, VAD, and transcription frames. Once the
|
||||
turn is considered complete, the user end of turn is triggered.
|
||||
This strategy feeds audio, VAD, and transcription frames to a turn
|
||||
detection model (``BaseTurnAnalyzer``) that predicts when the user has
|
||||
finished their turn. Once the model indicates the turn is complete, the
|
||||
strategy waits for a final transcription before triggering the end of
|
||||
the user's turn.
|
||||
|
||||
For services that support finalization (TranscriptionFrame.finalized=True),
|
||||
the turn can be triggered immediately once the finalized transcript is
|
||||
received. Otherwise, an STT timeout (adjusted by VAD stop_secs) is used
|
||||
as a fallback.
|
||||
"""
|
||||
|
||||
def __init__(self, *, turn_analyzer: BaseTurnAnalyzer, timeout: float = 0.5, **kwargs):
|
||||
def __init__(self, *, turn_analyzer: BaseTurnAnalyzer, **kwargs):
|
||||
"""Initialize the user turn stop strategy.
|
||||
|
||||
Args:
|
||||
turn_analyzer: The turn detection analyzer instance to detect end of user turn.
|
||||
timeout: Short delay used internally to handle frame timing and event triggering.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._turn_analyzer = turn_analyzer
|
||||
self._timeout = timeout
|
||||
self._stt_timeout: float = 0.0 # STT P99 latency from STTMetadataFrame
|
||||
self._stop_secs: float = 0.0 # VAD stop_secs from VADUserStoppedSpeakingFrame
|
||||
|
||||
self._text = ""
|
||||
self._turn_complete = False
|
||||
self._vad_user_speaking = False
|
||||
self._event = asyncio.Event()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._vad_stopped_time: Optional[float] = None # Track when VAD stopped was received
|
||||
self._transcript_finalized = False
|
||||
self._timeout_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
@@ -58,7 +66,8 @@ class TurnAnalyzerUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
self._text = ""
|
||||
self._turn_complete = False
|
||||
self._vad_user_speaking = False
|
||||
self._event.clear()
|
||||
self._vad_stopped_time = None
|
||||
self._transcript_finalized = False
|
||||
|
||||
async def setup(self, task_manager: BaseTaskManager):
|
||||
"""Initialize the strategy with the given task manager.
|
||||
@@ -67,15 +76,14 @@ class TurnAnalyzerUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
task_manager: The task manager to be associated with this instance.
|
||||
"""
|
||||
await super().setup(task_manager)
|
||||
self._task = task_manager.create_task(self._task_handler(), f"{self}::_task_handler")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup the strategy."""
|
||||
await super().cleanup()
|
||||
await self._turn_analyzer.cleanup()
|
||||
if self._task:
|
||||
await self.task_manager.cancel_task(self._task)
|
||||
self._task = None
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
self._timeout_task = None
|
||||
|
||||
async def process_frame(self, frame: Frame):
|
||||
"""Process an incoming frame to update the turn analyzer and strategy state.
|
||||
@@ -87,8 +95,8 @@ class TurnAnalyzerUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
elif isinstance(frame, SpeechControlParamsFrame):
|
||||
await self._handle_speech_control_params(frame)
|
||||
elif isinstance(frame, STTMetadataFrame):
|
||||
self._stt_timeout = frame.ttfs_p99_latency
|
||||
elif isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
@@ -97,25 +105,12 @@ class TurnAnalyzerUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
await self._handle_input_audio(frame)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
await self._handle_transcription(frame)
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._handle_interim_transcription(frame)
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
"""Process the start frame to configure the turn analyzer."""
|
||||
self._turn_analyzer.set_sample_rate(frame.audio_in_sample_rate)
|
||||
await self.broadcast_frame(SpeechControlParamsFrame, turn_params=self._turn_analyzer.params)
|
||||
|
||||
async def _handle_speech_control_params(self, frame: SpeechControlParamsFrame):
|
||||
"""Sync Smart Turn pre-speech buffering with VAD start delay.
|
||||
|
||||
`VADUserStartedSpeakingFrame` is emitted only once VAD has confirmed speech
|
||||
(after `vad_params.start_secs`). Smart Turn should still include the initial
|
||||
audio collected during that confirmation window, so we let the analyzer know
|
||||
when this value has changed.
|
||||
"""
|
||||
if frame.vad_params:
|
||||
self._turn_analyzer.update_vad_start_secs(frame.vad_params.start_secs)
|
||||
|
||||
async def _handle_input_audio(self, frame: InputAudioRawFrame):
|
||||
"""Handle input audio to check if the turn is completed."""
|
||||
state = self._turn_analyzer.append_audio(frame.audio, self._vad_user_speaking)
|
||||
@@ -127,14 +122,24 @@ class TurnAnalyzerUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
self._turn_complete = True
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, _: VADUserStartedSpeakingFrame):
|
||||
async def _handle_vad_user_started_speaking(self, frame: VADUserStartedSpeakingFrame):
|
||||
"""Handle when the VAD indicates the user is speaking."""
|
||||
# Sync Smart Turn pre-speech buffering with VAD start delay
|
||||
self._turn_analyzer.update_vad_start_secs(frame.start_secs)
|
||||
self._turn_complete = False
|
||||
self._vad_user_speaking = True
|
||||
self._vad_stopped_time = None
|
||||
self._transcript_finalized = False
|
||||
# Cancel any pending timeout
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
self._timeout_task = None
|
||||
|
||||
async def _handle_vad_user_stopped_speaking(self, _: VADUserStoppedSpeakingFrame):
|
||||
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
|
||||
"""Handle when the VAD indicates the user has stopped speaking."""
|
||||
self._vad_user_speaking = False
|
||||
self._stop_secs = frame.stop_secs
|
||||
self._vad_stopped_time = frame.timestamp
|
||||
|
||||
state, prediction = await self._turn_analyzer.analyze_end_of_turn()
|
||||
await self._handle_prediction_result(prediction)
|
||||
@@ -143,41 +148,76 @@ class TurnAnalyzerUserTurnStopStrategy(BaseUserTurnStopStrategy):
|
||||
# wait for transcriptions.
|
||||
self._turn_complete = state == EndOfTurnState.COMPLETE
|
||||
|
||||
# Reset transcription timeout.
|
||||
self._event.set()
|
||||
# Start the STT timeout (adjusted by VAD stop_secs since that time already elapsed)
|
||||
timeout = max(0, self._stt_timeout - self._stop_secs)
|
||||
self._timeout_task = self.task_manager.create_task(
|
||||
self._timeout_handler(timeout), f"{self}::_timeout_handler"
|
||||
)
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
"""Handle user transcription."""
|
||||
# We don't really care about the content.
|
||||
self._text = frame.text
|
||||
# Reset transcription timeout.
|
||||
self._event.set()
|
||||
if frame.finalized:
|
||||
self._transcript_finalized = True
|
||||
# For finalized transcripts, trigger immediately if turn is complete
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
|
||||
async def _handle_interim_transcription(self, frame: InterimTranscriptionFrame):
|
||||
"""Handle user interim transcription."""
|
||||
# Reset transcription timeout.
|
||||
self._event.set()
|
||||
# Fallback: handle transcripts when no VAD stop was received.
|
||||
# This handles edge cases where transcripts arrive without VAD firing.
|
||||
# _vad_stopped_time is None means VAD stopped hasn't been received yet.
|
||||
# In fallback mode, reset timeout on each transcript to wait for inactivity.
|
||||
if not self._vad_user_speaking and self._vad_stopped_time is None:
|
||||
# Cancel existing fallback timeout if any
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
# Without VAD/turn analyzer data, assume turn is complete
|
||||
self._turn_complete = True
|
||||
timeout = max(0, self._stt_timeout - self._stop_secs)
|
||||
self._timeout_task = self.task_manager.create_task(
|
||||
self._timeout_handler(timeout), f"{self}::_timeout_handler"
|
||||
)
|
||||
|
||||
async def _handle_prediction_result(self, result: Optional[MetricsData]):
|
||||
"""Handle a prediction result event from the turn analyzer."""
|
||||
if result:
|
||||
await self.push_frame(MetricsFrame(data=[result]))
|
||||
|
||||
async def _task_handler(self):
|
||||
"""Asynchronously monitor events and trigger user end of turn when appropriate.
|
||||
|
||||
If we have not received a transcription in the specified amount of time
|
||||
(and we initially received one) and the turn analyzer said the turn is
|
||||
done, then the user is done speaking.
|
||||
async def _timeout_handler(self, timeout: float):
|
||||
"""Wait for the timeout then trigger user turn stopped if conditions met.
|
||||
|
||||
Args:
|
||||
timeout: The timeout in seconds to wait.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(self._event.wait(), timeout=self._timeout)
|
||||
self._event.clear()
|
||||
except asyncio.TimeoutError:
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
try:
|
||||
await asyncio.sleep(timeout)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
self._timeout_task = None
|
||||
|
||||
await self._maybe_trigger_user_turn_stopped()
|
||||
|
||||
async def _maybe_trigger_user_turn_stopped(self):
|
||||
if self._text and self._turn_complete:
|
||||
"""Trigger user turn stopped if conditions are met.
|
||||
|
||||
Conditions:
|
||||
- We have transcription text
|
||||
- Turn analyzer indicates turn is complete
|
||||
- Either the timeout has elapsed OR we have a finalized transcript
|
||||
"""
|
||||
if not self._text or not self._turn_complete:
|
||||
return
|
||||
|
||||
# For finalized transcripts, trigger immediately
|
||||
if self._transcript_finalized:
|
||||
# Cancel any remaining timeout since we're triggering now
|
||||
if self._timeout_task:
|
||||
await self.task_manager.cancel_task(self._timeout_task)
|
||||
self._timeout_task = None
|
||||
await self.trigger_user_turn_stopped()
|
||||
return
|
||||
|
||||
# For non-finalized, only trigger if timeout task has completed
|
||||
if self._timeout_task is None:
|
||||
await self.trigger_user_turn_stopped()
|
||||
|
||||
@@ -18,7 +18,7 @@ from pipecat.turns.user_start import (
|
||||
from pipecat.turns.user_stop import (
|
||||
BaseUserTurnStopStrategy,
|
||||
ExternalUserTurnStopStrategy,
|
||||
TranscriptionUserTurnStopStrategy,
|
||||
SpeechTimeoutUserTurnStopStrategy,
|
||||
)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class UserTurnStrategies:
|
||||
If no strategies are specified, the following defaults are used:
|
||||
|
||||
start: [VADUserTurnStartStrategy, TranscriptionUserTurnStartStrategy]
|
||||
stop: [TranscriptionUserTurnStopStrategy]
|
||||
stop: [SpeechTimeoutUserTurnStopStrategy]
|
||||
|
||||
Attributes:
|
||||
start: A list of user turn start strategies used to detect when
|
||||
@@ -46,7 +46,7 @@ class UserTurnStrategies:
|
||||
if not self.start:
|
||||
self.start = [VADUserTurnStartStrategy(), TranscriptionUserTurnStartStrategy()]
|
||||
if not self.stop:
|
||||
self.stop = [TranscriptionUserTurnStopStrategy()]
|
||||
self.stop = [SpeechTimeoutUserTurnStopStrategy()]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -41,7 +41,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
)
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.turns.user_mute import FirstSpeechUserMuteStrategy, FunctionCallUserMuteStrategy
|
||||
from pipecat.turns.user_stop import TranscriptionUserTurnStopStrategy
|
||||
from pipecat.turns.user_stop import SpeechTimeoutUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
USER_TURN_STOP_TIMEOUT = 0.2
|
||||
@@ -149,7 +149,16 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_default_user_turn_strategies(self):
|
||||
context = LLMContext()
|
||||
user_aggregator = LLMUserAggregator(context)
|
||||
user_aggregator = LLMUserAggregator(
|
||||
context,
|
||||
params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[
|
||||
SpeechTimeoutUserTurnStopStrategy(user_speech_timeout=TRANSCRIPTION_TIMEOUT)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
should_start = None
|
||||
should_stop = None
|
||||
@@ -173,6 +182,8 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
|
||||
SleepFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
# Wait for user_speech_timeout to elapse
|
||||
SleepFrame(sleep=TRANSCRIPTION_TIMEOUT + 0.1),
|
||||
]
|
||||
expected_down_frames = [
|
||||
VADUserStartedSpeakingFrame,
|
||||
@@ -241,7 +252,9 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
context,
|
||||
params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TranscriptionUserTurnStopStrategy(timeout=TRANSCRIPTION_TIMEOUT)],
|
||||
stop=[
|
||||
SpeechTimeoutUserTurnStopStrategy(user_speech_timeout=TRANSCRIPTION_TIMEOUT)
|
||||
],
|
||||
),
|
||||
user_turn_stop_timeout=USER_TURN_STOP_TIMEOUT,
|
||||
),
|
||||
@@ -270,13 +283,13 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
pipeline = Pipeline([user_aggregator])
|
||||
|
||||
# Transcript arrives before VAD stop, then we wait for user_speech_timeout
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=USER_TURN_STOP_TIMEOUT - 0.1),
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
|
||||
SleepFrame(sleep=USER_TURN_STOP_TIMEOUT - 0.1),
|
||||
SleepFrame(sleep=TRANSCRIPTION_TIMEOUT),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
# Wait for user_speech_timeout (TRANSCRIPTION_TIMEOUT=0.1s) to elapse
|
||||
SleepFrame(sleep=TRANSCRIPTION_TIMEOUT + 0.05),
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
|
||||
@@ -12,6 +12,9 @@ from dataclasses import dataclass
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
RequestMetadataFrame,
|
||||
ServiceMetadataFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
)
|
||||
@@ -54,6 +57,47 @@ class MockFrameProcessor(FrameProcessor):
|
||||
self.frame_count = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMetadataFrame(ServiceMetadataFrame):
|
||||
"""A mock metadata frame for testing ServiceMetadataFrame handling."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MockMetadataService(FrameProcessor):
|
||||
"""A mock service that emits ServiceMetadataFrame like STT services.
|
||||
|
||||
Pushes MockMetadataFrame on StartFrame and RequestMetadataFrame.
|
||||
"""
|
||||
|
||||
def __init__(self, test_name: str, **kwargs):
|
||||
super().__init__(name=test_name, **kwargs)
|
||||
self.test_name = test_name
|
||||
self.processed_frames = []
|
||||
self.metadata_push_count = 0
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
self.processed_frames.append(frame)
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self._push_metadata()
|
||||
elif isinstance(frame, RequestMetadataFrame):
|
||||
# Don't push RequestMetadataFrame downstream (it's internal)
|
||||
await self._push_metadata()
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _push_metadata(self):
|
||||
self.metadata_push_count += 1
|
||||
await self.push_frame(MockMetadataFrame(service_name=self.test_name))
|
||||
|
||||
def reset_counters(self):
|
||||
self.processed_frames = []
|
||||
self.metadata_push_count = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummySystemFrame(SystemFrame):
|
||||
"""A dummy system frame for testing purposes."""
|
||||
@@ -336,5 +380,84 @@ class TestServiceSwitcher(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(switcher2_service2_texts[0].text, "After switching second switcher")
|
||||
|
||||
|
||||
class TestServiceSwitcherMetadata(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test cases for ServiceMetadataFrame handling in ServiceSwitcher."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures with mock metadata services."""
|
||||
self.service1 = MockMetadataService("service1")
|
||||
self.service2 = MockMetadataService("service2")
|
||||
self.services = [self.service1, self.service2]
|
||||
|
||||
async def test_only_active_service_metadata_at_startup(self):
|
||||
"""Test that only the active service's metadata leaves the ServiceSwitcher at startup."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Run the pipeline (StartFrame triggers metadata emission)
|
||||
output_frames = []
|
||||
|
||||
async def capture_frame(frame: Frame):
|
||||
output_frames.append(frame)
|
||||
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=[TextFrame(text="test")],
|
||||
expected_down_frames=[MockMetadataFrame, TextFrame],
|
||||
expected_up_frames=[],
|
||||
)
|
||||
|
||||
# Both services push metadata internally on StartFrame, but only the
|
||||
# active service's metadata passes through the filter
|
||||
self.assertEqual(self.service1.metadata_push_count, 1) # StartFrame (passes filter)
|
||||
self.assertEqual(self.service2.metadata_push_count, 1) # StartFrame (blocked by filter)
|
||||
|
||||
async def test_metadata_emitted_on_service_switch(self):
|
||||
"""Test that switching services triggers metadata emission from the new active service."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Reset counters after startup
|
||||
self.service1.reset_counters()
|
||||
self.service2.reset_counters()
|
||||
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=[
|
||||
TextFrame(text="before switch"),
|
||||
ManuallySwitchServiceFrame(service=self.service2),
|
||||
TextFrame(text="after switch"),
|
||||
],
|
||||
expected_down_frames=[
|
||||
MockMetadataFrame, # From startup (service1)
|
||||
TextFrame,
|
||||
ManuallySwitchServiceFrame,
|
||||
TextFrame,
|
||||
MockMetadataFrame, # From service2 after switch
|
||||
],
|
||||
expected_up_frames=[],
|
||||
)
|
||||
|
||||
# service2 should have received RequestMetadataFrame after becoming active
|
||||
request_frames = [
|
||||
f for f in self.service2.processed_frames if isinstance(f, RequestMetadataFrame)
|
||||
]
|
||||
self.assertEqual(len(request_frames), 1)
|
||||
|
||||
async def test_inactive_service_metadata_blocked(self):
|
||||
"""Test that metadata from inactive services is blocked."""
|
||||
switcher = ServiceSwitcher(self.services, ServiceSwitcherStrategyManual)
|
||||
|
||||
# Run and collect output frames
|
||||
await run_test(
|
||||
switcher,
|
||||
frames_to_send=[TextFrame(text="test")],
|
||||
expected_down_frames=[MockMetadataFrame, TextFrame],
|
||||
expected_up_frames=[],
|
||||
)
|
||||
|
||||
# service2 pushed metadata on StartFrame, but it should have been blocked
|
||||
self.assertGreaterEqual(self.service2.metadata_push_count, 1)
|
||||
# Only one MockMetadataFrame should have left (from service1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -18,11 +18,13 @@ from pipecat.frames.frames import (
|
||||
from pipecat.turns.user_start.min_words_user_turn_start_strategy import (
|
||||
MinWordsUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop import SpeechTimeoutUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_controller import UserTurnController
|
||||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies, UserTurnStrategies
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
USER_TURN_STOP_TIMEOUT = 0.2
|
||||
TRANSCRIPTION_TIMEOUT = 0.1
|
||||
|
||||
|
||||
class TestUserTurnController(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -31,7 +33,11 @@ class TestUserTurnController(unittest.IsolatedAsyncioTestCase):
|
||||
self.task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
|
||||
async def test_default_user_turn_strategies(self):
|
||||
controller = UserTurnController(user_turn_strategies=UserTurnStrategies())
|
||||
controller = UserTurnController(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[SpeechTimeoutUserTurnStopStrategy(user_speech_timeout=TRANSCRIPTION_TIMEOUT)],
|
||||
)
|
||||
)
|
||||
|
||||
await controller.setup(self.task_manager)
|
||||
|
||||
@@ -60,6 +66,8 @@ class TestUserTurnController(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
await controller.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertTrue(should_start)
|
||||
# Wait for user_speech_timeout to elapse
|
||||
await asyncio.sleep(TRANSCRIPTION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_stop)
|
||||
|
||||
async def test_user_turn_start_reset(self):
|
||||
|
||||
@@ -16,7 +16,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
from pipecat.turns.user_stop import TranscriptionUserTurnStopStrategy
|
||||
from pipecat.turns.user_stop import SpeechTimeoutUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_processor import UserTurnProcessor
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
@@ -26,7 +26,11 @@ TRANSCRIPTION_TIMEOUT = 0.1
|
||||
|
||||
class TestUserTurnProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_default_user_turn_strategies(self):
|
||||
user_turn_processor = UserTurnProcessor(user_turn_strategies=UserTurnStrategies())
|
||||
user_turn_processor = UserTurnProcessor(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[SpeechTimeoutUserTurnStopStrategy(user_speech_timeout=TRANSCRIPTION_TIMEOUT)],
|
||||
)
|
||||
)
|
||||
|
||||
should_start = None
|
||||
should_stop = None
|
||||
@@ -48,6 +52,8 @@ class TestUserTurnProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
|
||||
SleepFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
# Wait for user_speech_timeout to elapse
|
||||
SleepFrame(sleep=TRANSCRIPTION_TIMEOUT + 0.1),
|
||||
]
|
||||
expected_down_frames = [
|
||||
VADUserStartedSpeakingFrame,
|
||||
@@ -109,7 +115,7 @@ class TestUserTurnProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_user_turn_stop_timeout_transcription(self):
|
||||
user_turn_processor = UserTurnProcessor(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TranscriptionUserTurnStopStrategy(timeout=TRANSCRIPTION_TIMEOUT)],
|
||||
stop=[SpeechTimeoutUserTurnStopStrategy(user_speech_timeout=TRANSCRIPTION_TIMEOUT)],
|
||||
),
|
||||
user_turn_stop_timeout=USER_TURN_STOP_TIMEOUT,
|
||||
)
|
||||
@@ -135,13 +141,13 @@ class TestUserTurnProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
pipeline = Pipeline([user_turn_processor])
|
||||
|
||||
# Transcript arrives before VAD stop, then we wait for user_speech_timeout
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
SleepFrame(sleep=USER_TURN_STOP_TIMEOUT - 0.1),
|
||||
TranscriptionFrame(text="Hello!", user_id="", timestamp="now"),
|
||||
SleepFrame(sleep=USER_TURN_STOP_TIMEOUT - 0.1),
|
||||
SleepFrame(sleep=TRANSCRIPTION_TIMEOUT),
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
# Wait for user_speech_timeout (TRANSCRIPTION_TIMEOUT=0.1s) to elapse
|
||||
SleepFrame(sleep=TRANSCRIPTION_TIMEOUT + 0.05),
|
||||
]
|
||||
await run_test(
|
||||
pipeline,
|
||||
|
||||
@@ -9,25 +9,38 @@ import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
STTMetadataFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.user_stop import ExternalUserTurnStopStrategy, TranscriptionUserTurnStopStrategy
|
||||
from pipecat.turns.user_stop import ExternalUserTurnStopStrategy, SpeechTimeoutUserTurnStopStrategy
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
AGGREGATION_TIMEOUT = 0.1
|
||||
# Use 0 STT timeout for deterministic test timing
|
||||
STT_TIMEOUT = 0.0
|
||||
|
||||
|
||||
class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
class TestSpeechTimeoutUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.task_manager = TaskManager()
|
||||
self.task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
|
||||
async def _create_strategy(self, user_speech_timeout=AGGREGATION_TIMEOUT):
|
||||
"""Create strategy and configure STT timeout via metadata frame."""
|
||||
strategy = SpeechTimeoutUserTurnStopStrategy(user_speech_timeout=user_speech_timeout)
|
||||
await strategy.setup(self.task_manager)
|
||||
# Set STT timeout via metadata frame (as would happen in real pipeline)
|
||||
await strategy.process_frame(
|
||||
STTMetadataFrame(service_name="test", ttfs_p99_latency=STT_TIMEOUT)
|
||||
)
|
||||
return strategy
|
||||
|
||||
async def test_ste(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy()
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -46,13 +59,15 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# E
|
||||
await strategy.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Transcription comes in between user started/stopped and there are not
|
||||
# interim, we just trigger bot speech.
|
||||
# Transcription came in between user started/stopped. Now we wait for
|
||||
# timeout before triggering.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_site(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy()
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -77,13 +92,15 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# E
|
||||
await strategy.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Transcription comes in between user started/stopped, so we trigger
|
||||
# speech right away.
|
||||
# Transcription came in between user started/stopped. Now we wait for
|
||||
# timeout before triggering.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_st1iest2e(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy()
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -122,15 +139,14 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# E
|
||||
await strategy.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# There was an interim before the first user stopped speaking, then we
|
||||
# got a transcription comes in between user started/stopped, so we
|
||||
# trigger speech right away.
|
||||
# Now we wait for timeout before triggering.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_siet(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -163,8 +179,7 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_sieit(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -205,8 +220,7 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_set(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -235,8 +249,7 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_seit(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -271,8 +284,7 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_st1et2(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -291,26 +303,37 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
# E
|
||||
await strategy.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Transcription comes between user start/stopped speaking, we need to
|
||||
# trigger speech right away.
|
||||
# Transcription came between user start/stopped speaking, wait for timeout.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
should_start = None
|
||||
|
||||
# Reset for next turn (in real usage, UserTurnController would do this)
|
||||
await strategy.reset()
|
||||
|
||||
# S - new turn starts
|
||||
await strategy.process_frame(VADUserStartedSpeakingFrame())
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# T2
|
||||
await strategy.process_frame(
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp="")
|
||||
)
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# E
|
||||
await strategy.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Transcription comes after user stopped speaking, we need to wait for
|
||||
# at least the aggregation timeout.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_set1t2(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -343,8 +366,7 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_siet1it2(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -388,8 +410,8 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_t(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
"""Transcription without VAD - uses fallback timeout."""
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -402,14 +424,13 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""))
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Transcription comes after user stopped speaking, we need to wait for
|
||||
# at least the aggregation timeout.
|
||||
# Transcription without VAD triggers fallback timeout.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_it(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
"""Interim + Transcription without VAD - uses fallback timeout."""
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -427,14 +448,12 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
await strategy.process_frame(TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""))
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Transcription comes after user stopped speaking, we need to wait for
|
||||
# at least the aggregation timeout.
|
||||
# Transcription without VAD triggers fallback timeout.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
self.assertTrue(should_start)
|
||||
|
||||
async def test_sie_delay_it(self):
|
||||
strategy = TranscriptionUserTurnStopStrategy(timeout=AGGREGATION_TIMEOUT)
|
||||
await strategy.setup(self.task_manager)
|
||||
strategy = await self._create_strategy()
|
||||
|
||||
should_start = None
|
||||
|
||||
@@ -456,23 +475,22 @@ class TestTranscriptionUserTurnStopStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
await strategy.process_frame(VADUserStoppedSpeakingFrame())
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Delay
|
||||
# Delay - timeout expires but no transcript yet
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
# Still no trigger because no transcript received
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# I
|
||||
await strategy.process_frame(
|
||||
InterimTranscriptionFrame(text="How", user_id="cat", timestamp="")
|
||||
)
|
||||
|
||||
# T
|
||||
# T (finalized) - triggers immediately since timeout already elapsed
|
||||
await strategy.process_frame(
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp="")
|
||||
TranscriptionFrame(text="How are you?", user_id="cat", timestamp="", finalized=True)
|
||||
)
|
||||
self.assertIsNone(should_start)
|
||||
|
||||
# Transcription comes after user stopped speaking, we need to wait for
|
||||
# at least the aggregation timeout.
|
||||
await asyncio.sleep(AGGREGATION_TIMEOUT + 0.1)
|
||||
# Finalized transcript received after timeout, triggers immediately
|
||||
self.assertTrue(should_start)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user