Compare commits
148 Commits
kompfner-p
...
mb/gradium
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5b34759d7 | ||
|
|
0d697d184a | ||
|
|
717e1ccc01 | ||
|
|
671cc8eb74 | ||
|
|
b4dce656f0 | ||
|
|
253a1d1114 | ||
|
|
ca613bcb79 | ||
|
|
0423acd8a0 | ||
|
|
7eabaaa0ef | ||
|
|
bbb8b53d03 | ||
|
|
f3b72e9263 | ||
|
|
b77a50de73 | ||
|
|
433c1b9b92 | ||
|
|
bd00587092 | ||
|
|
5a85e27cc5 | ||
|
|
11daa43b1b | ||
|
|
875614ff7a | ||
|
|
eb1bf1e446 | ||
|
|
7456a0a55f | ||
|
|
27277ed3d9 | ||
|
|
5543bc56f3 | ||
|
|
c8496dfb8e | ||
|
|
d3f4cbb620 | ||
|
|
c9f922c479 | ||
|
|
49bd3da26b | ||
|
|
f3ef488925 | ||
|
|
4f08098917 | ||
|
|
a7cd5b0322 | ||
|
|
55dadc9118 | ||
|
|
01bbf61e0d | ||
|
|
10fb77c0e2 | ||
|
|
2612fae527 | ||
|
|
c5be67f293 | ||
|
|
312caaba86 | ||
|
|
ff0eb6d286 | ||
|
|
ef6bbace98 | ||
|
|
06ec21387f | ||
|
|
bdae177125 | ||
|
|
468e159f9b | ||
|
|
a4acafd3be | ||
|
|
105824a372 | ||
|
|
55e0d4ecc4 | ||
|
|
9102e81cb8 | ||
|
|
d7d8e93a3d | ||
|
|
bf9b166464 | ||
|
|
e80e0eab29 | ||
|
|
61242e6575 | ||
|
|
8841387121 | ||
|
|
ee695ae9fe | ||
|
|
52012b0fb2 | ||
|
|
f7a1c6b719 | ||
|
|
6aa77ccc13 | ||
|
|
45b7ec4e2c | ||
|
|
1c434c6ad5 | ||
|
|
4591affba9 | ||
|
|
91346f5f37 | ||
|
|
6a66ebe332 | ||
|
|
c1d4180042 | ||
|
|
81a53c699c | ||
|
|
60168f7f69 | ||
|
|
23d7608e5f | ||
|
|
99242c0a93 | ||
|
|
3a71865cf4 | ||
|
|
ecf2e69f3f | ||
|
|
febd52274d | ||
|
|
1542d922e7 | ||
|
|
15d5d1159e | ||
|
|
884630a6bd | ||
|
|
1cf137c6a8 | ||
|
|
98fcfd7c91 | ||
|
|
2f23f2e39c | ||
|
|
9c6b11cecf | ||
|
|
fc1444c9d6 | ||
|
|
ea94939add | ||
|
|
0c69ae6371 | ||
|
|
8b88280bb1 | ||
|
|
960d0faea5 | ||
|
|
b9390ccb1b | ||
|
|
061a0dc43d | ||
|
|
328bbe069f | ||
|
|
dc32ecc872 | ||
|
|
ca2eb1904f | ||
|
|
4bce58f270 | ||
|
|
7572d63f8f | ||
|
|
3c463c9416 | ||
|
|
bd618d64e3 | ||
|
|
a824660df7 | ||
|
|
58b9019852 | ||
|
|
afcdef8c81 | ||
|
|
bd92104fb3 | ||
|
|
34e9f224a8 | ||
|
|
dca7f3b5b0 | ||
|
|
70a85cd192 | ||
|
|
91e86658b7 | ||
|
|
0a8588669c | ||
|
|
0e99400148 | ||
|
|
648f20db6d | ||
|
|
09b5b6b12d | ||
|
|
0e6a423955 | ||
|
|
dc8972cd94 | ||
|
|
e4e2231958 | ||
|
|
18b3ee743b | ||
|
|
65b8e0e89c | ||
|
|
b77f8b065f | ||
|
|
5fd43faec3 | ||
|
|
abebcf37bd | ||
|
|
ca4e3c79f9 | ||
|
|
e8d1bec03b | ||
|
|
f0cc54589e | ||
|
|
22b9aac2ff | ||
|
|
7f86f4ac27 | ||
|
|
dcab79753b | ||
|
|
bdded9b026 | ||
|
|
1e1e275fea | ||
|
|
effb6aa8f4 | ||
|
|
a4a9bae79e | ||
|
|
c943ef9261 | ||
|
|
f05809520b | ||
|
|
ec17dc6626 | ||
|
|
4e85e81d9b | ||
|
|
a1cc88a233 | ||
|
|
61a230ec53 | ||
|
|
a13380b574 | ||
|
|
2a927189d9 | ||
|
|
a90c15362c | ||
|
|
d3bdd2d246 | ||
|
|
465ae4f706 | ||
|
|
a0d801b658 | ||
|
|
35919a84e3 | ||
|
|
f94a60f381 | ||
|
|
a446bca72d | ||
|
|
8ae834366b | ||
|
|
a4acc12f91 | ||
|
|
e93112e76e | ||
|
|
680bcaac66 | ||
|
|
d2ac9006a2 | ||
|
|
bcb019e8ab | ||
|
|
4ea546785f | ||
|
|
f128cdd19a | ||
|
|
7921bce4af | ||
|
|
cadced3f79 | ||
|
|
3b3c7aa8cc | ||
|
|
fa5da3b0be | ||
|
|
7e82a0cf49 | ||
|
|
0b1a4792b8 | ||
|
|
14bd3b1b32 | ||
|
|
f733e77496 | ||
|
|
38506f51f7 |
9
.github/workflows/coverage.yaml
vendored
9
.github/workflows/coverage.yaml
vendored
@@ -33,7 +33,14 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra livekit --extra websocket
|
||||
uv sync --group dev \
|
||||
--extra anthropic \
|
||||
--extra aws \
|
||||
--extra google \
|
||||
--extra langchain \
|
||||
--extra livekit \
|
||||
--extra piper \
|
||||
--extra websocket
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
|
||||
9
.github/workflows/tests.yaml
vendored
9
.github/workflows/tests.yaml
vendored
@@ -37,7 +37,14 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra livekit --extra websocket
|
||||
uv sync --group dev \
|
||||
--extra anthropic \
|
||||
--extra aws \
|
||||
--extra google \
|
||||
--extra langchain \
|
||||
--extra livekit \
|
||||
--extra piper \
|
||||
--extra websocket
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
|
||||
1
changelog/3406.fixed.md
Normal file
1
changelog/3406.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed an issue where `OpenRouterLLMService`, when used with a Gemini model, didn't handle multiple `"system"` messages as expected (and as `GoogleLLMService` does), which is to convert subsequent ones into `"user"` messages. Instead, the latest `"system"` message would overwrite the previous ones.
|
||||
4
changelog/3408.added.md
Normal file
4
changelog/3408.added.md
Normal file
@@ -0,0 +1,4 @@
|
||||
- Additions for `AICFilter` and `AICVADAnalyzer`:
|
||||
- Added model downloading support to `AICFilter` with `model_id` and `model_download_dir` parameters.
|
||||
- Added `model_path` parameter to `AICFilter` for loading local `.aicmodel` files.
|
||||
- Added unit tests for `AICFilter` and `AICVADAnalyzer`.
|
||||
1
changelog/3408.changed.md
Normal file
1
changelog/3408.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Updated `AICFilter` and `AICVADAnalyzer` to use aic-sdk ~= 2.0.1.
|
||||
1
changelog/3408.removed.md
Normal file
1
changelog/3408.removed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Removed deprecated `AICFilter` parameters: `enhancement_level`, `voice_gain`, `noise_gate_enable`.
|
||||
1
changelog/3429.added.md
Normal file
1
changelog/3429.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added handling for `server_content.interrupted` signal in the Gemini Live service for faster interruption response in the case where there isn't already turn tracking in the pipeline, e.g. local VAD + context aggregators. When there is already turn tracking in the pipeline, the additional interruption does no harm.
|
||||
1
changelog/3495.changed.2.md
Normal file
1
changelog/3495.changed.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- `SarvamSTTService` now defaults `vad_signals` and `high_vad_sensitivity` to `None` (omitted from connection parameters), improving latency by ~300ms compared to the previous defaults.
|
||||
1
changelog/3495.changed.md
Normal file
1
changelog/3495.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Improved the STT TTFB (Time To First Byte) measurement, reporting the delay between when the user stops speaking and when the final transcription is received. Note: Unlike traditional TTFB which measures from a discrete request, STT services receive continuous audio input—so we measure from speech end to final transcript, which captures the latency that matters for voice AI applications. In support of this change, added `finalized` field to `TranscriptionFrame` to indicate when a transcript is the final result for an utterance.
|
||||
1
changelog/3500.added.md
Normal file
1
changelog/3500.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added new `GenesysFrameSerializer` for the Genesys AudioHook WebSocket protocol, enabling bidirectional audio streaming between Pipecat pipelines and Genesys Cloud contact center.
|
||||
@@ -1 +1 @@
|
||||
- Added new `SMART_TURN_LOG_DATA` environment variable, which causes Smart Turn input data to be saved to disk
|
||||
- Added new `PIPECAT_SMART_TURN_LOG_DATA` environment variable, which causes Smart Turn input data to be saved to disk
|
||||
|
||||
1
changelog/3529.fixed.md
Normal file
1
changelog/3529.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed OpenAI LLM services to emit `ErrorFrame` on completion timeout, enabling proper error handling and LLMSwitcher failover.
|
||||
1
changelog/3536.fixed.md
Normal file
1
changelog/3536.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed a logging issue where non-ASCII characters (e.g., Japanese, Chinese) were being unnecessarily escaped to Unicode sequences when a function call occurred.
|
||||
1
changelog/3541.fixed.md
Normal file
1
changelog/3541.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed how audio tracks are synchronized inside the `AudioBufferProcessor` to fix timing issues where silence and audio were misaligned between user and bot buffers.
|
||||
1
changelog/3560.changed.md
Normal file
1
changelog/3560.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- `FrameSerializer` now subclasses from `BaseObject` to enable event support.
|
||||
2
changelog/3562.changed.md
Normal file
2
changelog/3562.changed.md
Normal file
@@ -0,0 +1,2 @@
|
||||
- Added support for TTFS in `SpeechmaticsSTTService` and set the default mode to `EXTERNAL` to support Pipecat-controlled VAD.
|
||||
- Changed dependency to `speechmatics-voice[smart]>=0.2.8`
|
||||
1
changelog/3567.fixed.md
Normal file
1
changelog/3567.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed race condition in `OpenAIRealtimeBetaLLMService` that could cause an error when truncating the conversation.
|
||||
1
changelog/3571.added.2.md
Normal file
1
changelog/3571.added.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `function_call_timeout_secs` parameter to `LLMService` to configure timeout for deferred function calls (defaults to 10.0 seconds).
|
||||
1
changelog/3571.added.md
Normal file
1
changelog/3571.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `result_callback` parameter to `UserImageRequestFrame` to support deferred function call results.
|
||||
4
changelog/3571.changed.md
Normal file
4
changelog/3571.changed.md
Normal file
@@ -0,0 +1,4 @@
|
||||
- ⚠️ Changed function call handling to use timeout-based completion instead of immediate callback execution.
|
||||
- Function calls that defer their results (e.g., `UserImageRequestFrame`) now use a timeout mechanism
|
||||
- The `result_callback` is invoked automatically when the deferred operation completes or after timeout
|
||||
- This change affects examples using `UserImageRequestFrame` - the `result_callback` should now be passed to the frame instead of being called immediately
|
||||
1
changelog/3574.fixed.md
Normal file
1
changelog/3574.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed an infinite loop in `WebsocketService` that blocked the event loop when a remote server closed the connection gracefully.
|
||||
1
changelog/3575.fixed.md
Normal file
1
changelog/3575.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `LLMUserAggregator` and `LLMAssistantAggregator` not emitting pending transcripts via `on_user_turn_stopped` and `on_assistant_turn_stopped` events when the conversation ends (`EndFrame`) or is cancelled (`CancelFrame`).
|
||||
1
changelog/3580.fixed.md
Normal file
1
changelog/3580.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added missing `LiveKitRunnerArguments` and `LiveKitTransport` support in runner utilities to enable LiveKit transport configuration.
|
||||
1
changelog/3581.fixed.md
Normal file
1
changelog/3581.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed race condition in `OpenAIRealtimeLLMService` that could cause an error when truncating the conversation.
|
||||
1
changelog/3582.change.md
Normal file
1
changelog/3582.change.md
Normal file
@@ -0,0 +1 @@
|
||||
- Pipecat runner now uses `DAILY_ROOM_URL` instead of `DAILY_SAMPLE_ROOM_URL`.
|
||||
1
changelog/3585.added.md
Normal file
1
changelog/3585.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added local `PiperTTSService` for offline text-to-speech using Piper voice models. The existing HTTP-based service has been renamed to `PiperHttpTTSService`.
|
||||
1
changelog/3585.fixed.md
Normal file
1
changelog/3585.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `PiperHttpTTSService` (old `PiperTTSService`) to resample audio output based on the model's sample rate parsed from the WAV header.
|
||||
3
changelog/3587.changed.md
Normal file
3
changelog/3587.changed.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- Updates to `GradiumSTTService`:
|
||||
- Now flushes pending transcriptions when VAD detects the user stopped speaking, improving response latency.
|
||||
- `GradiumSTTService` now supports `InputParams` for configuring `language` and `delay_in_frames` settings.
|
||||
1
changelog/3590.added.md
Normal file
1
changelog/3590.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- `main()` in `pipecat.runner.run` now accepts an optional `argparse.ArgumentParser`, allowing bots to define custom CLI arguments accessible via `runner_args.cli_args`.
|
||||
1
changelog/3594.fixed.md
Normal file
1
changelog/3594.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `UserTurnController` to reset user turn timeout when interim transcriptions are received.
|
||||
1
changelog/3596.fixed.md
Normal file
1
changelog/3596.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed an issue in `GradiumTTSService` where the websocket was being disconnected at the end of every bot turn.
|
||||
@@ -43,7 +43,7 @@ CEREBRAS_API_KEY=...
|
||||
|
||||
# Daily
|
||||
DAILY_API_KEY=...
|
||||
DAILY_SAMPLE_ROOM_URL=https://...
|
||||
DAILY_ROOM_URL=https://...
|
||||
|
||||
# Deepgram
|
||||
DEEPGRAM_API_KEY=...
|
||||
|
||||
@@ -16,7 +16,7 @@ from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.piper.tts import PiperTTSService
|
||||
from pipecat.services.piper.tts import PiperHttpTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -39,7 +39,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Create an HTTP session
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tts = PiperTTSService(
|
||||
tts = PiperHttpTTSService(
|
||||
base_url=os.getenv("PIPER_BASE_URL"), aiohttp_session=session, sample_rate=24000
|
||||
)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def _create_aic_filter() -> AICFilter:
|
||||
|
||||
return AICFilter(
|
||||
license_key=license_key,
|
||||
enhancement_level=0.5,
|
||||
model_id="quail-vf-l-16khz",
|
||||
)
|
||||
|
||||
|
||||
@@ -62,7 +62,9 @@ transport_params = {
|
||||
lambda aic: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
|
||||
vad_analyzer=aic.create_vad_analyzer(
|
||||
speech_hold_duration=0.05, minimum_speech_duration=0.0, sensitivity=6.0
|
||||
),
|
||||
audio_in_filter=aic,
|
||||
)
|
||||
)(_create_aic_filter()),
|
||||
@@ -70,7 +72,9 @@ transport_params = {
|
||||
lambda aic: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
|
||||
vad_analyzer=aic.create_vad_analyzer(
|
||||
speech_hold_duration=0.05, minimum_speech_duration=0.0, sensitivity=6.0
|
||||
),
|
||||
audio_in_filter=aic,
|
||||
)
|
||||
)(_create_aic_filter()),
|
||||
@@ -78,7 +82,9 @@ transport_params = {
|
||||
lambda aic: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
|
||||
vad_analyzer=aic.create_vad_analyzer(
|
||||
speech_hold_duration=0.05, minimum_speech_duration=0.0, sensitivity=6.0
|
||||
),
|
||||
audio_in_filter=aic,
|
||||
)
|
||||
)(_create_aic_filter()),
|
||||
|
||||
@@ -26,6 +26,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gradium.stt import GradiumSTTService
|
||||
from pipecat.services.gradium.tts import GradiumTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -59,11 +60,18 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = GradiumSTTService(api_key=os.getenv("GRADIUM_API_KEY"))
|
||||
stt = GradiumSTTService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
api_endpoint_base_url="wss://us.api.gradium.ai/api/speech/asr",
|
||||
params=GradiumSTTService.InputParams(
|
||||
language=Language.EN,
|
||||
),
|
||||
)
|
||||
|
||||
tts = GradiumTTSService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
voice_id="YTpq7expH9539ERJ",
|
||||
url="wss://us.api.gradium.ai/api/speech/tts",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
132
examples/foundational/07zi-interruptible-piper.py
Normal file
132
examples/foundational/07zi-interruptible-piper.py
Normal file
@@ -0,0 +1,132 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.piper.tts import PiperTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = PiperTTSService(voice_id="en_US-ryan-high")
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
86
examples/foundational/13l-gradium-transcription.py
Normal file
86
examples/foundational/13l-gradium-transcription.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import Frame, TranscriptionFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.gradium.stt import GradiumSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class TranscriptionLogger(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(audio_in_enabled=True),
|
||||
"twilio": lambda: FastAPIWebsocketParams(audio_in_enabled=True),
|
||||
"webrtc": lambda: TransportParams(audio_in_enabled=True),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = GradiumSTTService(
|
||||
api_key=os.getenv("GRADIUM_API_KEY"),
|
||||
api_endpoint_base_url="wss://us.api.gradium.ai/api/speech/asr",
|
||||
params=GradiumSTTService.InputParams(language=Language.EN, delay_in_frames=8),
|
||||
)
|
||||
|
||||
tl = TranscriptionLogger()
|
||||
|
||||
pipeline = Pipeline([transport.input(), stt, tl])
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -48,14 +48,16 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
When called, this function pushes a UserImageRequestFrame upstream to the
|
||||
transport. As a result, the transport will request the user image and push a
|
||||
UserImageRawFrame downstream which will be added to the context by the LLM
|
||||
assistant aggregator.
|
||||
assistant aggregator. The result_callback will be invoked once the image is
|
||||
retrieved and processed.
|
||||
"""
|
||||
user_id = params.arguments["user_id"]
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={user_id}, question={question}")
|
||||
|
||||
# Request a user image frame and indicate that it should be added to the
|
||||
# context. Also associate it to the function call.
|
||||
# context. Also associate it to the function call. Pass the result_callback
|
||||
# so it can be invoked when the image is actually retrieved.
|
||||
await params.llm.push_frame(
|
||||
UserImageRequestFrame(
|
||||
user_id=user_id,
|
||||
@@ -63,16 +65,11 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
append_to_context=True,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
result_callback=params.result_callback,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
await params.result_callback(None)
|
||||
|
||||
# Instead of None, it's possible to also provide a tool call answer to
|
||||
# tell the LLM that we are grabbing the image to analyze.
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -48,14 +48,16 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
When called, this function pushes a UserImageRequestFrame upstream to the
|
||||
transport. As a result, the transport will request the user image and push a
|
||||
UserImageRawFrame downstream which will be added to the context by the LLM
|
||||
assistant aggregator.
|
||||
assistant aggregator. The result_callback will be invoked once the image is
|
||||
retrieved and processed.
|
||||
"""
|
||||
user_id = params.arguments["user_id"]
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={user_id}, question={question}")
|
||||
|
||||
# Request a user image frame and indicate that it should be added to the
|
||||
# context. Also associate it to the function call.
|
||||
# context. Also associate it to the function call. Pass the result_callback
|
||||
# so it can be invoked when the image is actually retrieved.
|
||||
await params.llm.push_frame(
|
||||
UserImageRequestFrame(
|
||||
user_id=user_id,
|
||||
@@ -63,16 +65,11 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
append_to_context=True,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
result_callback=params.result_callback,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
await params.result_callback(None)
|
||||
|
||||
# Instead of None, it's possible to also provide a tool call answer to
|
||||
# tell the LLM that we are grabbing the image to analyze.
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -48,14 +48,16 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
When called, this function pushes a UserImageRequestFrame upstream to the
|
||||
transport. As a result, the transport will request the user image and push a
|
||||
UserImageRawFrame downstream which will be added to the context by the LLM
|
||||
assistant aggregator.
|
||||
assistant aggregator. The result_callback will be invoked once the image is
|
||||
retrieved and processed.
|
||||
"""
|
||||
user_id = params.arguments["user_id"]
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={user_id}, question={question}")
|
||||
|
||||
# Request a user image frame and indicate that it should be added to the
|
||||
# context. Also associate it to the function call.
|
||||
# context. Also associate it to the function call. Pass the result_callback
|
||||
# so it can be invoked when the image is actually retrieved.
|
||||
await params.llm.push_frame(
|
||||
UserImageRequestFrame(
|
||||
user_id=user_id,
|
||||
@@ -63,16 +65,11 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
append_to_context=True,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
result_callback=params.result_callback,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
await params.result_callback(None)
|
||||
|
||||
# Instead of None, it's possible to also provide a tool call answer to
|
||||
# tell the LLM that we are grabbing the image to analyze.
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -57,7 +57,8 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
|
||||
When called, this function pushes a UserImageRequestFrame upstream to the
|
||||
transport. As a result, the transport will request the user image and push a
|
||||
UserImageRawFrame downstream.
|
||||
UserImageRawFrame downstream. The result_callback will be invoked once the
|
||||
image is retrieved and processed.
|
||||
"""
|
||||
user_id = params.arguments["user_id"]
|
||||
question = params.arguments["question"]
|
||||
@@ -65,7 +66,8 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
|
||||
# Request a user image frame. In this case, we don't want the requested
|
||||
# image to be added to the context because we will process it with
|
||||
# Moondream. Also associate it to the function call.
|
||||
# Moondream. Also associate it to the function call. Pass the result_callback
|
||||
# so it can be invoked when the image is actually retrieved.
|
||||
await params.llm.push_frame(
|
||||
UserImageRequestFrame(
|
||||
user_id=user_id,
|
||||
@@ -73,16 +75,11 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
append_to_context=False,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
result_callback=params.result_callback,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
await params.result_callback(None)
|
||||
|
||||
# Instead of None, it's possible to also provide a tool call answer to
|
||||
# tell the LLM that we are grabbing the image to analyze.
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
class MoondreamTextFrameWrapper(FrameProcessor):
|
||||
"""Wraps Moondream-provided TextFrames with LLM response start/end frames.
|
||||
|
||||
@@ -49,14 +49,16 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
When called, this function pushes a UserImageRequestFrame upstream to the
|
||||
transport. As a result, the transport will request the user image and push a
|
||||
UserImageRawFrame downstream which will be added to the context by the LLM
|
||||
assistant aggregator.
|
||||
assistant aggregator. The result_callback will be invoked once the image is
|
||||
retrieved and processed.
|
||||
"""
|
||||
user_id = params.arguments["user_id"]
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={user_id}, question={question}")
|
||||
|
||||
# Request a user image frame and indicate that it should be added to the
|
||||
# context. Also associate it to the function call.
|
||||
# context. Also associate it to the function call. Pass the result_callback
|
||||
# so it can be invoked when the image is actually retrieved.
|
||||
await params.llm.push_frame(
|
||||
UserImageRequestFrame(
|
||||
user_id=user_id,
|
||||
@@ -64,16 +66,11 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
append_to_context=True,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
result_callback=params.result_callback,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
await params.result_callback(None)
|
||||
|
||||
# Instead of None, it's possible to also provide a tool call answer to
|
||||
# tell the LLM that we are grabbing the image to analyze.
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -58,14 +58,16 @@ async def get_image(params: FunctionCallParams):
|
||||
When called, this function pushes a UserImageRequestFrame upstream to the
|
||||
transport. As a result, the transport will request the user image and push a
|
||||
UserImageRawFrame downstream which will be added to the context by the LLM
|
||||
assistant aggregator.
|
||||
assistant aggregator. The result_callback will be invoked once the image is
|
||||
retrieved and processed.
|
||||
"""
|
||||
user_id = params.arguments["user_id"]
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={user_id}, question={question}")
|
||||
|
||||
# Request a user image frame and indicate that it should be added to the
|
||||
# context. Also associate it to the function call.
|
||||
# context. Also associate it to the function call. Pass the result_callback
|
||||
# so it can be invoked when the image is actually retrieved.
|
||||
await params.llm.push_frame(
|
||||
UserImageRequestFrame(
|
||||
user_id=user_id,
|
||||
@@ -73,16 +75,11 @@ async def get_image(params: FunctionCallParams):
|
||||
append_to_context=True,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
result_callback=params.result_callback,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
await params.result_callback(None)
|
||||
|
||||
# Instead of None, it's possible to also provide a tool call answer to
|
||||
# tell the LLM that we are grabbing the image to analyze.
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -20,10 +20,6 @@ from pipecat.transports.daily.transport import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Pipecat Video Streaming Bot")
|
||||
parser.add_argument("-i", "--input", type=str, required=True, help="Input video file")
|
||||
args = parser.parse_args()
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -46,10 +42,10 @@ transport_params = {
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot with video input: {args.input}")
|
||||
logger.info(f"Starting bot with video input: {runner_args.cli_args.input}")
|
||||
|
||||
gst = GStreamerPipelineSource(
|
||||
pipeline=f"filesrc location={args.input}",
|
||||
pipeline=f"filesrc location={runner_args.cli_args.input}",
|
||||
out_params=GStreamerPipelineSource.OutputParams(
|
||||
video_width=1280,
|
||||
video_height=720,
|
||||
@@ -68,6 +64,15 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
@@ -82,4 +87,7 @@ async def bot(runner_args: RunnerArguments):
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
parser = argparse.ArgumentParser(description="Pipecat Video Streaming Bot")
|
||||
parser.add_argument("-i", "--input", type=str, required=True, help="Input video file")
|
||||
|
||||
main(parser)
|
||||
|
||||
@@ -66,7 +66,8 @@ async def get_image(params: FunctionCallParams):
|
||||
logger.debug(f"Requesting image with user_id={user_id}, question={question}")
|
||||
|
||||
# Request a user image frame and indicate that it should be added to the
|
||||
# context. Also associate it to the function call.
|
||||
# context. Also associate it to the function call. Pass the result_callback
|
||||
# so it can be invoked when the image is actually retrieved.
|
||||
await params.llm.push_frame(
|
||||
UserImageRequestFrame(
|
||||
user_id=user_id,
|
||||
@@ -74,16 +75,11 @@ async def get_image(params: FunctionCallParams):
|
||||
append_to_context=True,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
result_callback=params.result_callback,
|
||||
),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
|
||||
await params.result_callback(None)
|
||||
|
||||
# Instead of None, it's possible to also provide a tool call answer to
|
||||
# tell the LLM that we are grabbing the image to analyze.
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
async def get_saved_conversation_filenames(params: FunctionCallParams):
|
||||
# Construct the full pattern including the BASE_FILENAME
|
||||
|
||||
@@ -31,7 +31,7 @@ Requirements:
|
||||
- [Optional] Anthropic API key (if using Claude with local config)
|
||||
|
||||
Environment variables (set in .env or in your terminal using `export`):
|
||||
DAILY_SAMPLE_ROOM_URL=daily_sample_room_url
|
||||
DAILY_ROOM_URL=daily_room_url
|
||||
DAILY_API_KEY=daily_api_key
|
||||
OPENAI_API_KEY=openai_api_key
|
||||
ELEVENLABS_API_KEY=elevenlabs_api_key
|
||||
|
||||
@@ -4,7 +4,7 @@ This directory contains examples showing how to build voice and multimodal agent
|
||||
|
||||
## Setup
|
||||
|
||||
1. Follow the [README](../../README.md#%EF%B8%8F-contributing-to-the-framework) steps to get your local environment configured.
|
||||
1. Follow the [README](https://github.com/pipecat-ai/pipecat/blob/main/README.md#%EF%B8%8F-contributing-to-the-framework) steps to get your local environment configured.
|
||||
|
||||
> **Run from root directory**: Make sure you are running the steps from the root directory.
|
||||
|
||||
@@ -37,7 +37,7 @@ Most examples support running with other transports, like Twilio or Daily.
|
||||
|
||||
### Daily
|
||||
|
||||
You need to create a Daily account at https://dashboard.daily.co/u/signup. Once signed up, you can create your own room from the dashboard and set the environment variables `DAILY_SAMPLE_ROOM_URL` and `DAILY_API_KEY`. Alternatively, you can let the example create a room for you (still needs `DAILY_API_KEY` environment variable). Then, start any example with `-t daily`:
|
||||
You need to create a Daily account at https://dashboard.daily.co/u/signup. Once signed up, you can create your own room from the dashboard and set the environment variables `DAILY_ROOM_URL` and `DAILY_API_KEY`. Alternatively, you can let the example create a room for you (this still requires the `DAILY_API_KEY` environment variable). Then, start any example with `-t daily`:
|
||||
|
||||
```bash
|
||||
uv run 07-interruptible.py -t daily
|
||||
@@ -140,4 +140,4 @@ uv run python <example-name> --host 0.0.0.0 --port 8080
|
||||
- **Connection errors**: Verify API keys in `.env` file
|
||||
- **Port conflicts**: Use `--port` to change the port
|
||||
|
||||
For more examples, visit our the [`pipecat-examples repository](https://github.com/pipecat-ai/pipecat-examples).
|
||||
For more examples, visit the [pipecat-examples repository](https://github.com/pipecat-ai/pipecat-examples).
|
||||
|
||||
@@ -48,13 +48,13 @@ Issues = "https://github.com/pipecat-ai/pipecat/issues"
|
||||
Changelog = "https://github.com/pipecat-ai/pipecat/blob/main/CHANGELOG.md"
|
||||
|
||||
[project.optional-dependencies]
|
||||
aic = [ "aic-sdk~=1.2.0" ]
|
||||
aic = [ "aic-sdk~=2.0.1" ]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "pipecat-ai[websockets-base]" ]
|
||||
asyncai = [ "pipecat-ai[websockets-base]" ]
|
||||
aws = [ "aioboto3~=15.5.0", "pipecat-ai[websockets-base]" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.2.0; python_version>='3.12'" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.44.0"]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.47.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
camb = [ "camb-sdk>=1.5.4" ]
|
||||
cerebras = []
|
||||
@@ -95,6 +95,7 @@ rnnoise = [ "pyrnnoise~=0.4.1" ]
|
||||
openpipe = [ "openpipe>=4.50.0,<6" ]
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
piper = [ "piper-tts>=1.3.0,<2" ]
|
||||
playht = [ "pipecat-ai[websockets-base]" ]
|
||||
qwen = []
|
||||
remote-smart-turn = []
|
||||
@@ -109,7 +110,7 @@ silero = [ "onnxruntime>=1.20.1,<2" ]
|
||||
simli = [ "simli-ai~=1.0.3"]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soundfile = [ "soundfile~=0.13.1" ]
|
||||
speechmatics = [ "speechmatics-voice[smart]>=0.2.6" ]
|
||||
speechmatics = [ "speechmatics-voice[smart]>=0.2.8" ]
|
||||
strands = [ "strands-agents>=1.9.1,<2" ]
|
||||
tavus=[]
|
||||
together = []
|
||||
|
||||
@@ -195,7 +195,7 @@ class EvalRunner:
|
||||
|
||||
|
||||
async def run_example_pipeline(script_path: Path, eval_config: EvalConfig):
|
||||
room_url = os.getenv("DAILY_SAMPLE_ROOM_URL")
|
||||
room_url = os.getenv("DAILY_ROOM_URL")
|
||||
|
||||
module = load_module_from_path(script_path)
|
||||
|
||||
@@ -225,7 +225,7 @@ async def run_eval_pipeline(
|
||||
):
|
||||
logger.info(f"Starting eval bot")
|
||||
|
||||
room_url = os.getenv("DAILY_SAMPLE_ROOM_URL")
|
||||
room_url = os.getenv("DAILY_ROOM_URL")
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
|
||||
@@ -133,12 +133,12 @@ TESTS_07 = [
|
||||
("07zb-interruptible-inworld-http.py", EVAL_SIMPLE_MATH),
|
||||
("07zc-interruptible-asyncai.py", EVAL_SIMPLE_MATH),
|
||||
("07zc-interruptible-asyncai-http.py", EVAL_SIMPLE_MATH),
|
||||
# Need license key to run
|
||||
# ("07zd-interruptible-aicoustics.py", EVAL_SIMPLE_MATH),
|
||||
("07zd-interruptible-aicoustics.py", EVAL_SIMPLE_MATH),
|
||||
("07ze-interruptible-hume.py", EVAL_SIMPLE_MATH),
|
||||
("07zf-interruptible-gradium.py", EVAL_SIMPLE_MATH),
|
||||
("07zg-interruptible-camb.py", EVAL_SIMPLE_MATH),
|
||||
("07zh-interruptible-hathora.py", EVAL_SIMPLE_MATH),
|
||||
("07zi-interruptible-piper.py", EVAL_SIMPLE_MATH),
|
||||
# Needs a local XTTS docker instance running.
|
||||
# ("07i-interruptible-xtts.py", EVAL_SIMPLE_MATH),
|
||||
# Needs a Krisp license.
|
||||
|
||||
@@ -28,7 +28,7 @@ def check_env_variables() -> bool:
|
||||
"CARTESIA_API_KEY",
|
||||
"DEEPGRAM_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"DAILY_SAMPLE_ROOM_URL",
|
||||
"DAILY_ROOM_URL",
|
||||
]
|
||||
for env in required_envs:
|
||||
if not os.getenv(env):
|
||||
|
||||
@@ -9,129 +9,145 @@
|
||||
This module provides an audio filter implementation using ai-coustics' AIC SDK to
|
||||
enhance audio streams in real time. It mirrors the structure of other filters like
|
||||
the Koala filter and integrates with Pipecat's input transport pipeline.
|
||||
|
||||
Classes:
|
||||
AICFilter: For aic-sdk (uses 'aic_sdk' module)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from aic_sdk import (
|
||||
Model,
|
||||
ParameterFixedError,
|
||||
ProcessorAsync,
|
||||
ProcessorConfig,
|
||||
ProcessorParameter,
|
||||
set_sdk_id,
|
||||
)
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
|
||||
from pipecat.audio.vad.aic_vad import AICVADAnalyzer
|
||||
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
|
||||
|
||||
try:
|
||||
# AIC SDK (https://ai-coustics.github.io/aic-sdk-py/api/)
|
||||
from aic import AICModelType, AICParameter, Model
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use the AIC filter, you need to `pip install pipecat-ai[aic]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AICFilter(BaseAudioFilter):
|
||||
"""Audio filter using ai-coustics' AIC SDK for real-time enhancement.
|
||||
|
||||
Buffers incoming audio to the model's preferred block size and processes
|
||||
planar frames in-place using float32 samples in the linear -1..+1 range.
|
||||
frames using float32 samples normalized to the range -1 to +1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
license_key: str = "",
|
||||
model_type: AICModelType = AICModelType.QUAIL_STT,
|
||||
enhancement_level: Optional[float] = 1.0,
|
||||
voice_gain: Optional[float] = 1.0,
|
||||
noise_gate_enable: Optional[bool] = True,
|
||||
license_key: str,
|
||||
model_id: Optional[str] = None,
|
||||
model_path: Optional[Path] = None,
|
||||
model_download_dir: Optional[Path] = None,
|
||||
) -> None:
|
||||
"""Initialize the AIC filter.
|
||||
|
||||
Args:
|
||||
license_key: ai-coustics license key for authentication.
|
||||
model_type: Model variant to load.
|
||||
enhancement_level: Optional overall enhancement strength (0.0..1.0).
|
||||
voice_gain: Optional linear gain applied to detected speech (0.0..4.0).
|
||||
noise_gate_enable: Optional enable/disable noise gate (default: True).
|
||||
model_id: Model identifier to download from CDN. Required if model_path
|
||||
is not provided. See https://artifacts.ai-coustics.io/ for available models.
|
||||
model_path: Optional path to a local .aicmodel file. If provided,
|
||||
model_id is ignored and no download occurs.
|
||||
model_download_dir: Directory for downloading models as a Path object.
|
||||
Defaults to a cache directory in user's home folder.
|
||||
|
||||
.. deprecated:: 1.3.0
|
||||
The `noise_gate_enable` parameter is deprecated and no longer has any effect.
|
||||
It will be removed in a future version.
|
||||
Raises:
|
||||
ValueError: If neither model_id nor model_path is provided.
|
||||
"""
|
||||
# Set SDK ID for telemetry identification (6 = pipecat)
|
||||
set_sdk_id(6)
|
||||
|
||||
if model_id is None and model_path is None:
|
||||
raise ValueError(
|
||||
"Either 'model_id' or 'model_path' must be provided. "
|
||||
"See https://artifacts.ai-coustics.io/ for available models."
|
||||
)
|
||||
|
||||
self._license_key = license_key
|
||||
self._model_type = model_type
|
||||
self._model_id = model_id
|
||||
self._model_path = model_path
|
||||
self._model_download_dir = model_download_dir or (
|
||||
Path.home() / ".cache" / "pipecat" / "aic-models"
|
||||
)
|
||||
|
||||
self._enhancement_level = enhancement_level
|
||||
self._voice_gain = voice_gain
|
||||
if noise_gate_enable is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter `noise_gate_enable` is deprecated and no longer has any effect. "
|
||||
"It will be removed in a future version. Use AIC VAD instead (create_vad_analyzer()).",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self._noise_gate_enable = noise_gate_enable
|
||||
|
||||
self._enabled = True
|
||||
self._bypass = False
|
||||
self._sample_rate = 0
|
||||
self._aic_ready = False
|
||||
self._frames_per_block = 0
|
||||
self._audio_buffer = bytearray()
|
||||
# Model will be created in start() since the API now requires sample_rate
|
||||
self._aic = None
|
||||
|
||||
def get_vad_factory(self):
|
||||
"""Return a zero-arg factory that will create the VAD once the model exists.
|
||||
# Audio format constants
|
||||
self._bytes_per_sample = 2 # int16 = 2 bytes
|
||||
self._dtype = np.int16
|
||||
self._scale = (
|
||||
32768.0 # 2^15, for normalizing int16 (-32768 to 32767) to float32 (-1.0 to 1.0)
|
||||
)
|
||||
|
||||
# AIC SDK objects
|
||||
self._model = None
|
||||
self._processor = None
|
||||
self._processor_ctx = None
|
||||
self._vad_ctx = None
|
||||
|
||||
# Pre-allocated buffers (resized in start() once frames_per_block is known)
|
||||
self._in_f32 = None
|
||||
self._out_i16 = None
|
||||
|
||||
def get_vad_context(self):
|
||||
"""Return the VAD context once the processor exists.
|
||||
|
||||
Returns:
|
||||
A zero-argument callable that, when invoked, returns an initialized
|
||||
VoiceActivityDetector bound to the underlying AIC model. Raises a
|
||||
RuntimeError if the model has not been initialized (i.e. start()
|
||||
has not been called successfully).
|
||||
The VadContext instance bound to the underlying processor.
|
||||
Raises RuntimeError if the processor has not been initialized.
|
||||
"""
|
||||
|
||||
def _factory():
|
||||
if self._aic is None:
|
||||
raise RuntimeError("AIC model not initialized yet. Call start(sample_rate) first.")
|
||||
return self._aic.create_vad()
|
||||
|
||||
return _factory
|
||||
if self._vad_ctx is None:
|
||||
raise RuntimeError("AIC processor not initialized yet. Call start(sample_rate) first.")
|
||||
return self._vad_ctx
|
||||
|
||||
def create_vad_analyzer(
|
||||
self,
|
||||
*,
|
||||
lookback_buffer_size: Optional[float] = None,
|
||||
speech_hold_duration: Optional[float] = None,
|
||||
minimum_speech_duration: Optional[float] = None,
|
||||
sensitivity: Optional[float] = None,
|
||||
):
|
||||
"""Return an analyzer that will lazily instantiate the AIC VAD when ready.
|
||||
|
||||
AIC VAD parameters:
|
||||
- lookback_buffer_size:
|
||||
Number of window-length audio buffers used as a lookback buffer.
|
||||
Higher values increase prediction stability but add latency.
|
||||
Range: 1.0 .. 20.0, Default (SDK): 6.0
|
||||
- speech_hold_duration:
|
||||
How long VAD continues detecting after speech ends (in seconds).
|
||||
Range: 0.0 to 100x model window length, Default (SDK): 0.05s
|
||||
- minimum_speech_duration:
|
||||
Minimum duration of speech required before VAD reports speech detected
|
||||
(in seconds). Range: 0.0 to 1.0, Default (SDK): 0.0s
|
||||
- sensitivity:
|
||||
Energy threshold sensitivity. Energy threshold = 10 ** (-sensitivity).
|
||||
Range: 1.0 .. 15.0, Default (SDK): 6.0
|
||||
Range: 1.0 to 15.0, Default (SDK): 6.0
|
||||
|
||||
Args:
|
||||
lookback_buffer_size: Optional lookback buffer size to configure on the VAD.
|
||||
Range: 1.0 .. 20.0. If None, SDK default is used.
|
||||
speech_hold_duration: Optional speech hold duration to configure on the VAD.
|
||||
If None, SDK default (0.05s) is used.
|
||||
minimum_speech_duration: Optional minimum speech duration before VAD reports
|
||||
speech detected. If None, SDK default (0.0s) is used.
|
||||
sensitivity: Optional sensitivity (energy threshold) to configure on the VAD.
|
||||
Range: 1.0 .. 15.0. If None, SDK default is used.
|
||||
Range: 1.0 to 15.0. If None, SDK default (6.0) is used.
|
||||
|
||||
Returns:
|
||||
A lazily-initialized AICVADAnalyzer that will bind to the VAD backend
|
||||
once the filter's model has been created (after start(sample_rate)).
|
||||
A lazily-initialized AICVADAnalyzer that will bind to the VAD context
|
||||
once the filter's processor has been created (after start(sample_rate)).
|
||||
"""
|
||||
from pipecat.audio.vad.aic_vad import AICVADAnalyzer
|
||||
|
||||
return AICVADAnalyzer(
|
||||
vad_factory=self.get_vad_factory(),
|
||||
lookback_buffer_size=lookback_buffer_size,
|
||||
vad_context_factory=lambda: self.get_vad_context(),
|
||||
speech_hold_duration=speech_hold_duration,
|
||||
minimum_speech_duration=minimum_speech_duration,
|
||||
sensitivity=sensitivity,
|
||||
)
|
||||
|
||||
@@ -146,52 +162,83 @@ class AICFilter(BaseAudioFilter):
|
||||
"""
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
# Load or download model
|
||||
if self._model_path:
|
||||
logger.debug(f"Loading AIC model from: {self._model_path}")
|
||||
self._model = Model.from_file(str(self._model_path))
|
||||
else:
|
||||
logger.debug(f"Downloading AIC model: {self._model_id}")
|
||||
self._model_download_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_path = await Model.download_async(self._model_id, str(self._model_download_dir))
|
||||
logger.debug(f"Model downloaded to: {model_path}")
|
||||
self._model = Model.from_file(model_path)
|
||||
|
||||
# Get optimal frames for this sample rate
|
||||
self._frames_per_block = self._model.get_optimal_num_frames(self._sample_rate)
|
||||
|
||||
# Allocate processing buffers now that we know the block size
|
||||
self._in_f32 = np.zeros((1, self._frames_per_block), dtype=np.float32)
|
||||
self._out_i16 = np.zeros(self._frames_per_block, dtype=np.int16)
|
||||
|
||||
# Create configuration
|
||||
config = ProcessorConfig.optimal(
|
||||
self._model,
|
||||
sample_rate=self._sample_rate,
|
||||
)
|
||||
|
||||
# Create async processor
|
||||
try:
|
||||
# Create model with required runtime parameters
|
||||
self._aic = Model(
|
||||
model_type=self._model_type,
|
||||
license_key=self._license_key or None,
|
||||
sample_rate=self._sample_rate,
|
||||
channels=1,
|
||||
)
|
||||
self._frames_per_block = self._aic.optimal_num_frames()
|
||||
|
||||
# Optional parameter configuration
|
||||
if self._enhancement_level is not None:
|
||||
self._aic.set_parameter(
|
||||
AICParameter.ENHANCEMENT_LEVEL,
|
||||
float(self._enhancement_level if self._enabled else 0.0),
|
||||
)
|
||||
if self._voice_gain is not None:
|
||||
self._aic.set_parameter(AICParameter.VOICE_GAIN, float(self._voice_gain))
|
||||
|
||||
self._aic_ready = True
|
||||
|
||||
# Log processor information
|
||||
logger.debug(f"ai-coustics filter started:")
|
||||
logger.debug(f" Sample rate: {self._sample_rate} Hz")
|
||||
logger.debug(f" Frames per chunk: {self._frames_per_block}")
|
||||
logger.debug(f" Enhancement strength: {int(self._enhancement_level * 100)}%")
|
||||
logger.debug(f" Optimal input buffer size: {self._aic.optimal_num_frames()} samples")
|
||||
logger.debug(f" Optimal sample rate: {self._aic.optimal_sample_rate()} Hz")
|
||||
logger.debug(
|
||||
f" Current algorithmic latency: {self._aic.processing_latency() / self._sample_rate * 1000:.2f}ms"
|
||||
)
|
||||
self._processor = ProcessorAsync(self._model, self._license_key, config)
|
||||
except Exception as e: # noqa: BLE001 - surfacing SDK initialization errors
|
||||
logger.error(f"AIC model initialization failed: {e}")
|
||||
self._aic_ready = False
|
||||
self._processor = None
|
||||
|
||||
self._aic_ready = self._processor is not None
|
||||
|
||||
if not self._aic_ready:
|
||||
logger.debug(f"ai-coustics filter is not ready.")
|
||||
return
|
||||
|
||||
# Get contexts for parameter control and VAD
|
||||
self._processor_ctx = self._processor.get_processor_context()
|
||||
self._vad_ctx = self._processor.get_vad_context()
|
||||
|
||||
# Apply initial parameters
|
||||
try:
|
||||
self._processor_ctx.set_parameter(
|
||||
ProcessorParameter.Bypass, 1.0 if self._bypass else 0.0
|
||||
)
|
||||
except ParameterFixedError as e:
|
||||
logger.error(f"AIC parameter update failed: {e}")
|
||||
|
||||
# Log processor information
|
||||
logger.debug(f"ai-coustics filter started:")
|
||||
logger.debug(f" Model ID: {self._model.get_id()}")
|
||||
logger.debug(f" Sample rate: {self._sample_rate} Hz")
|
||||
logger.debug(f" Frames per chunk: {self._frames_per_block}")
|
||||
logger.debug(f" Optimal sample rate: {self._model.get_optimal_sample_rate()} Hz")
|
||||
logger.debug(
|
||||
f" Optimal number of frames for {self._sample_rate} Hz: {self._model.get_optimal_num_frames(self._sample_rate)}"
|
||||
)
|
||||
logger.debug(
|
||||
f" Output delay: {self._processor_ctx.get_output_delay()} samples "
|
||||
f"({self._processor_ctx.get_output_delay() / self._sample_rate * 1000:.2f}ms)"
|
||||
)
|
||||
|
||||
async def stop(self):
|
||||
"""Clean up the AIC model when stopping.
|
||||
"""Clean up the AIC processor when stopping.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
if self._aic is not None:
|
||||
self._aic.close()
|
||||
if self._processor_ctx is not None:
|
||||
self._processor_ctx.reset()
|
||||
finally:
|
||||
self._aic = None
|
||||
self._processor = None
|
||||
self._processor_ctx = None
|
||||
self._vad_ctx = None
|
||||
self._model = None
|
||||
self._aic_ready = False
|
||||
self._audio_buffer.clear()
|
||||
|
||||
@@ -205,11 +252,12 @@ class AICFilter(BaseAudioFilter):
|
||||
None
|
||||
"""
|
||||
if isinstance(frame, FilterEnableFrame):
|
||||
self._enabled = frame.enable
|
||||
if self._aic is not None:
|
||||
self._bypass = not frame.enable
|
||||
if self._processor_ctx is not None:
|
||||
try:
|
||||
level = float(self._enhancement_level if self._enabled else 0.0)
|
||||
self._aic.set_parameter(AICParameter.ENHANCEMENT_LEVEL, level)
|
||||
self._processor_ctx.set_parameter(
|
||||
ProcessorParameter.Bypass, 1.0 if self._bypass else 0.0
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"AIC set_parameter failed: {e}")
|
||||
|
||||
@@ -220,43 +268,41 @@ class AICFilter(BaseAudioFilter):
|
||||
model's required block length. Returns enhanced audio data.
|
||||
|
||||
Args:
|
||||
audio: Raw audio data as bytes to be filtered (int16 PCM, planar).
|
||||
audio: Raw audio data as bytes (int16 PCM).
|
||||
|
||||
Returns:
|
||||
Enhanced audio data as bytes (int16 PCM, planar).
|
||||
Enhanced audio data as bytes (int16 PCM).
|
||||
"""
|
||||
if not self._aic_ready or self._aic is None:
|
||||
if not self._aic_ready or self._processor is None:
|
||||
return audio
|
||||
|
||||
self._audio_buffer.extend(audio)
|
||||
available_frames = len(self._audio_buffer) // self._bytes_per_sample
|
||||
num_blocks = available_frames // self._frames_per_block
|
||||
|
||||
if num_blocks == 0:
|
||||
return b""
|
||||
|
||||
filtered_chunks: List[bytes] = []
|
||||
mv = memoryview(self._audio_buffer)
|
||||
block_size = self._frames_per_block * self._bytes_per_sample
|
||||
|
||||
# Number of int16 samples currently buffered
|
||||
available_frames = len(self._audio_buffer) // 2
|
||||
for i in range(num_blocks):
|
||||
start = i * block_size
|
||||
block_i16 = np.frombuffer(mv[start : start + block_size], dtype=self._dtype)
|
||||
|
||||
while available_frames >= self._frames_per_block:
|
||||
# Consume exactly one block worth of frames
|
||||
samples_to_consume = self._frames_per_block * 1
|
||||
bytes_to_consume = samples_to_consume * 2
|
||||
block_bytes = bytes(self._audio_buffer[:bytes_to_consume])
|
||||
# Reuse input buffer, in-place divide
|
||||
np.copyto(self._in_f32[0], block_i16)
|
||||
self._in_f32 /= self._scale
|
||||
|
||||
# Convert to float32 in -1..+1 range and reshape to planar (channels, frames)
|
||||
block_i16 = np.frombuffer(block_bytes, dtype=np.int16)
|
||||
block_f32 = (block_i16.astype(np.float32) / 32768.0).reshape(
|
||||
(1, self._frames_per_block)
|
||||
)
|
||||
out_f32 = await self._processor.process_async(self._in_f32)
|
||||
|
||||
# Process planar in-place; returns ndarray (same shape)
|
||||
out_f32 = await self._aic.process_async(block_f32)
|
||||
# Convert float32 output back to int16
|
||||
np.multiply(out_f32, self._scale, out=self._in_f32) # reuse in_f32 as temp
|
||||
np.clip(self._in_f32, -self._scale, self._scale - 1, out=self._in_f32)
|
||||
np.copyto(self._out_i16, self._in_f32[0].astype(self._dtype))
|
||||
|
||||
# Convert back to int16 bytes, planar layout
|
||||
out_i16 = np.clip(out_f32 * 32768.0, -32768, 32767).astype(np.int16)
|
||||
filtered_chunks.append(out_i16.reshape(-1).tobytes())
|
||||
filtered_chunks.append(self._out_i16.tobytes())
|
||||
|
||||
# Slide buffer
|
||||
self._audio_buffer = self._audio_buffer[bytes_to_consume:]
|
||||
available_frames = len(self._audio_buffer) // 2
|
||||
|
||||
# Do not flush incomplete frames; keep them buffered for the next call
|
||||
self._audio_buffer = self._audio_buffer[num_blocks * block_size :]
|
||||
return b"".join(filtered_chunks)
|
||||
|
||||
@@ -49,7 +49,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._log_data = env_truthy("SMART_TURN_LOG_DATA", default=False)
|
||||
self._log_data = env_truthy("PIPECAT_SMART_TURN_LOG_DATA", default=False)
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Load bundled model
|
||||
|
||||
@@ -1,44 +1,44 @@
|
||||
"""AIC-integrated VAD analyzer that lazily binds to the AIC SDK backend.
|
||||
|
||||
This analyzer queries the backend's is_speech_detected() and maps it to a float
|
||||
confidence (1.0/0.0). It uses 10 ms windows based on the sample rate and applies
|
||||
optional AIC VAD parameters (lookback_buffer_size, sensitivity) when available.
|
||||
This module provides VAD analyzer implementations that query the AIC SDK's
|
||||
is_speech_detected() and map it to a float confidence (1.0/0.0).
|
||||
|
||||
Classes:
|
||||
AICVADAnalyzer: For aic-sdk (uses 'aic_sdk' module)
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from aic_sdk import VadParameter
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
|
||||
|
||||
try:
|
||||
from aic import AICVadParameter
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use the AIC filter, you need to `pip install pipecat-ai[aic]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AICVADAnalyzer(VADAnalyzer):
|
||||
"""VAD analyzer that lazily instantiates the AIC VoiceActivityDetector via a factory.
|
||||
"""VAD analyzer that lazily binds to the AIC VadContext via a factory.
|
||||
|
||||
The analyzer can be constructed before the AIC Model exists. Once the filter has
|
||||
started and the Model is available, the provided factory will succeed and the
|
||||
backend VAD will be created. We then switch to single-sample updates where
|
||||
num_frames_required() returns 1 and confidence is derived from the backend's
|
||||
boolean is_speech_detected() state.
|
||||
The analyzer can be constructed before the AIC Processor exists. Once the filter has
|
||||
started and the Processor is available, the provided factory will succeed and the
|
||||
VadContext will be obtained. The context's is_speech_detected() boolean state is
|
||||
then mapped to 1.0 (speech) or 0.0 (no speech) to satisfy the VADAnalyzer interface.
|
||||
|
||||
AIC VAD runtime parameters:
|
||||
- lookback_buffer_size:
|
||||
Controls the lookback buffer size used by the VAD, i.e. the number of
|
||||
window-length audio buffers used as a lookback buffer. Larger values improve
|
||||
stability but increase latency.
|
||||
Range: 1.0 .. 20.0
|
||||
Default (SDK): 6.0
|
||||
- speech_hold_duration:
|
||||
Controls for how long the VAD continues to detect speech after the audio signal
|
||||
no longer contains speech (in seconds).
|
||||
Range: 0.0 to 100x model window length
|
||||
Default (SDK): 0.05s (50ms)
|
||||
- minimum_speech_duration:
|
||||
Controls for how long speech needs to be present in the audio signal before the
|
||||
VAD considers it speech (in seconds).
|
||||
Range: 0.0 to 1.0
|
||||
Default (SDK): 0.0s
|
||||
- sensitivity:
|
||||
Controls the energy threshold sensitivity. Higher values make the detector
|
||||
less sensitive (require more energy to count as speech).
|
||||
Range: 1.0 .. 15.0
|
||||
Controls the sensitivity (energy threshold) of the VAD. This value is used by
|
||||
the VAD as the threshold a speech audio signal's energy has to exceed in order
|
||||
to be considered speech.
|
||||
Range: 1.0 to 15.0
|
||||
Formula: Energy threshold = 10 ** (-sensitivity)
|
||||
Default (SDK): 6.0
|
||||
"""
|
||||
@@ -46,69 +46,80 @@ class AICVADAnalyzer(VADAnalyzer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vad_factory: Optional[Callable[[], Any]] = None,
|
||||
lookback_buffer_size: Optional[float] = None,
|
||||
vad_context_factory: Optional[Callable[[], Any]] = None,
|
||||
speech_hold_duration: Optional[float] = None,
|
||||
minimum_speech_duration: Optional[float] = None,
|
||||
sensitivity: Optional[float] = None,
|
||||
):
|
||||
"""Create an AIC VAD analyzer.
|
||||
|
||||
Args:
|
||||
vad_factory:
|
||||
Zero-arg callable that returns an initialized AIC VoiceActivityDetector.
|
||||
This may raise until the filter's Model has been created; the analyzer
|
||||
vad_context_factory:
|
||||
Zero-arg callable that returns the AIC VadContext.
|
||||
This may raise until the filter's Processor has been created; the analyzer
|
||||
will retry on set_sample_rate/first use.
|
||||
lookback_buffer_size:
|
||||
Optional override for AIC VAD lookback buffer size.
|
||||
Range: 1.0 .. 20.0. Larger values increase stability at the cost of latency.
|
||||
If None, the SDK default (6.0) is used.
|
||||
speech_hold_duration:
|
||||
Optional override for AIC VAD speech hold duration (in seconds).
|
||||
Range: 0.0 to 100x model window length.
|
||||
If None, the SDK default (0.05s) is used.
|
||||
minimum_speech_duration:
|
||||
Optional override for minimum speech duration before VAD reports
|
||||
speech detected (in seconds).
|
||||
Range: 0.0 to 1.0.
|
||||
If None, the SDK default (0.0s) is used.
|
||||
sensitivity:
|
||||
Optional override for AIC VAD sensitivity (energy threshold).
|
||||
Range: 1.0 .. 15.0. Energy threshold = 10 ** (-sensitivity).
|
||||
Range: 1.0 to 15.0. Energy threshold = 10 ** (-sensitivity).
|
||||
If None, the SDK default (6.0) is used.
|
||||
"""
|
||||
# Use fixed VAD parameters for AIC: no user override
|
||||
fixed_params = VADParams(confidence=0.5, start_secs=0.0, stop_secs=0.0, min_volume=0.0)
|
||||
super().__init__(sample_rate=None, params=fixed_params)
|
||||
self._vad_factory = vad_factory
|
||||
self._backend_vad: Optional[Any] = None
|
||||
self._pending_lookback: Optional[float] = lookback_buffer_size
|
||||
|
||||
self._vad_context_factory = vad_context_factory
|
||||
self._vad_ctx: Optional[Any] = None
|
||||
self._pending_speech_hold_duration: Optional[float] = speech_hold_duration
|
||||
self._pending_minimum_speech_duration: Optional[float] = minimum_speech_duration
|
||||
self._pending_sensitivity: Optional[float] = sensitivity
|
||||
|
||||
def bind_vad_factory(self, vad_factory: Callable[[], Any]):
|
||||
def bind_vad_context_factory(self, vad_context_factory: Callable[[], Any]):
|
||||
"""Attach or replace the factory post-construction."""
|
||||
self._vad_factory = vad_factory
|
||||
self._ensure_backend_initialized()
|
||||
self._vad_context_factory = vad_context_factory
|
||||
self._ensure_vad_context_initialized()
|
||||
|
||||
def _apply_backend_params(self):
|
||||
def _apply_vad_params(self):
|
||||
"""Apply optional AIC VAD parameters if available."""
|
||||
if self._backend_vad is None or AICVadParameter is None:
|
||||
if self._vad_ctx is None or VadParameter is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if self._pending_lookback is not None:
|
||||
self._backend_vad.set_parameter(
|
||||
AICVadParameter.LOOKBACK_BUFFER_SIZE, float(self._pending_lookback)
|
||||
if self._pending_speech_hold_duration is not None:
|
||||
self._vad_ctx.set_parameter(
|
||||
VadParameter.SpeechHoldDuration, self._pending_speech_hold_duration
|
||||
)
|
||||
if self._pending_minimum_speech_duration is not None:
|
||||
self._vad_ctx.set_parameter(
|
||||
VadParameter.MinimumSpeechDuration, self._pending_minimum_speech_duration
|
||||
)
|
||||
if self._pending_sensitivity is not None:
|
||||
self._backend_vad.set_parameter(
|
||||
AICVadParameter.SENSITIVITY, float(self._pending_sensitivity)
|
||||
)
|
||||
self._vad_ctx.set_parameter(VadParameter.Sensitivity, self._pending_sensitivity)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug(f"AIC VAD parameter application deferred/failed: {e}")
|
||||
|
||||
def _ensure_backend_initialized(self):
|
||||
if self._backend_vad is not None:
|
||||
def _ensure_vad_context_initialized(self):
|
||||
if self._vad_ctx is not None:
|
||||
return
|
||||
if not self._vad_factory:
|
||||
if not self._vad_context_factory:
|
||||
return
|
||||
try:
|
||||
self._backend_vad = self._vad_factory()
|
||||
self._apply_backend_params()
|
||||
# With backend ready, recompute internal frame sizing
|
||||
self._vad_ctx = self._vad_context_factory()
|
||||
self._apply_vad_params()
|
||||
# With VAD context ready, recompute internal frame sizing
|
||||
super().set_params(self._params)
|
||||
logger.debug("AIC VAD backend initialized in analyzer.")
|
||||
logger.debug("AIC VAD context initialized in analyzer.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
# Filter may not be started yet; try again later
|
||||
logger.debug(f"Deferring AIC VAD backend initialization: {e}")
|
||||
logger.debug(f"Deferring AIC VAD context initialization: {e}")
|
||||
|
||||
def set_sample_rate(self, sample_rate: int):
|
||||
"""Set the sample rate for audio processing.
|
||||
@@ -116,10 +127,10 @@ class AICVADAnalyzer(VADAnalyzer):
|
||||
Args:
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
"""
|
||||
# Set rate and attempt backend initialization once we know SR
|
||||
# Set rate and attempt VAD context initialization once we know SR
|
||||
self._sample_rate = self._init_sample_rate or sample_rate
|
||||
self._ensure_backend_initialized()
|
||||
# Ensure params are initialized even if backend not ready yet
|
||||
self._ensure_vad_context_initialized()
|
||||
# Ensure params are initialized even if VAD context not ready yet
|
||||
try:
|
||||
super().set_params(self._params)
|
||||
except Exception:
|
||||
@@ -135,23 +146,29 @@ class AICVADAnalyzer(VADAnalyzer):
|
||||
return int(self.sample_rate * 0.01) if self.sample_rate > 0 else 160
|
||||
|
||||
def voice_confidence(self, buffer: bytes) -> float:
|
||||
"""Calculate voice activity confidence for the given audio buffer.
|
||||
"""Return voice activity detection result for the given audio buffer.
|
||||
|
||||
Note:
|
||||
The AIC SDK provides binary speech detection (not a probability score).
|
||||
This method returns 1.0 when speech is detected and 0.0 otherwise,
|
||||
rather than a true confidence value.
|
||||
|
||||
Args:
|
||||
buffer: Audio buffer to analyze.
|
||||
buffer: Audio buffer (unused - AIC VAD state is updated internally
|
||||
by the enhancement pipeline).
|
||||
|
||||
Returns:
|
||||
Voice confidence score is 0.0 or 1.0.
|
||||
1.0 if speech is detected, 0.0 otherwise.
|
||||
"""
|
||||
# Ensure backend exists (filter might have started since last call)
|
||||
self._ensure_backend_initialized()
|
||||
if self._backend_vad is None:
|
||||
# Ensure VAD context exists (filter might have started since last call)
|
||||
self._ensure_vad_context_initialized()
|
||||
if self._vad_ctx is None:
|
||||
return 0.0
|
||||
|
||||
# We do not need to analyze 'buffer' here since the model's VAD is updated
|
||||
# We do not need to analyze 'buffer' here since the processor's VAD is updated
|
||||
# as part of the enhancement pipeline. Simply query the boolean and map it.
|
||||
try:
|
||||
is_speech = self._backend_vad.is_speech_detected()
|
||||
is_speech = self._vad_ctx.is_speech_detected()
|
||||
return 1.0 if is_speech else 0.0
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"AIC VAD inference error: {e}")
|
||||
|
||||
@@ -426,12 +426,15 @@ class TranscriptionFrame(TextFrame):
|
||||
timestamp: When the transcription occurred.
|
||||
language: Detected or specified language of the speech.
|
||||
result: Raw result from the STT service.
|
||||
finalized: Whether this is the final transcription for an utterance.
|
||||
Set by STT services that support commit/finalize signals.
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
timestamp: str
|
||||
language: Optional[Language] = None
|
||||
result: Optional[Any] = None
|
||||
finalized: bool = False
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
|
||||
@@ -1463,6 +1466,7 @@ class UserImageRequestFrame(SystemFrame):
|
||||
video_source: Specific video source to capture from.
|
||||
function_name: Name of function that generated this request (if any).
|
||||
tool_call_id: Tool call ID if generated by function call (if any).
|
||||
result_callback: Optional callback to invoke when the image is retrieved.
|
||||
context: [DEPRECATED] Optional context for the image request.
|
||||
"""
|
||||
|
||||
@@ -1472,6 +1476,7 @@ class UserImageRequestFrame(SystemFrame):
|
||||
video_source: Optional[str] = None
|
||||
function_name: Optional[str] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
result_callback: Optional[Any] = None
|
||||
context: Optional[Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -1042,6 +1042,11 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
del self._function_calls_in_progress[frame.request.tool_call_id]
|
||||
|
||||
# Call the result_callback if provided. This signals that the image
|
||||
# has been retrieved and the function call can now complete.
|
||||
if frame.request and frame.request.result_callback:
|
||||
await frame.request.result_callback(None)
|
||||
|
||||
await self.handle_user_image_frame(frame)
|
||||
await self.push_aggregation()
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
@@ -464,9 +464,11 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await s.setup(self.task_manager)
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._maybe_emit_user_turn_stopped(on_session_end=True)
|
||||
await self._cleanup()
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
await self._maybe_emit_user_turn_stopped(on_session_end=True)
|
||||
await self._cleanup()
|
||||
|
||||
async def _cleanup(self):
|
||||
@@ -602,14 +604,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if params.enable_user_speaking_frames:
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
# Always push context frame.
|
||||
aggregation = await self.push_aggregation()
|
||||
|
||||
message = UserTurnStoppedMessage(
|
||||
content=aggregation, timestamp=self._user_turn_start_timestamp
|
||||
)
|
||||
await self._call_event_handler("on_user_turn_stopped", strategy, message)
|
||||
self._user_turn_start_timestamp = ""
|
||||
await self._maybe_emit_user_turn_stopped(strategy)
|
||||
|
||||
async def _on_user_turn_stop_timeout(self, controller):
|
||||
await self._call_event_handler("on_user_turn_stop_timeout")
|
||||
@@ -617,6 +612,26 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
async def _on_user_turn_idle(self, controller):
|
||||
await self._call_event_handler("on_user_turn_idle")
|
||||
|
||||
async def _maybe_emit_user_turn_stopped(
|
||||
self,
|
||||
strategy: Optional[BaseUserTurnStopStrategy] = None,
|
||||
on_session_end: bool = False,
|
||||
):
|
||||
"""Maybe emit user turn stopped event.
|
||||
|
||||
Args:
|
||||
strategy: The strategy that triggered the turn stop.
|
||||
on_session_end: If True, only emit if there's unemitted content
|
||||
(avoids duplicate events when session ends).
|
||||
"""
|
||||
aggregation = await self.push_aggregation()
|
||||
if not on_session_end or aggregation:
|
||||
message = UserTurnStoppedMessage(
|
||||
content=aggregation, timestamp=self._user_turn_start_timestamp
|
||||
)
|
||||
await self._call_event_handler("on_user_turn_stopped", strategy, message)
|
||||
self._user_turn_start_timestamp = ""
|
||||
|
||||
|
||||
class LLMAssistantAggregator(LLMContextAggregator):
|
||||
"""Assistant LLM aggregator that processes bot responses and function calls.
|
||||
@@ -739,6 +754,9 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._handle_end_or_cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._handle_llm_start(frame)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
@@ -813,6 +831,10 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
async def _handle_end_or_cancel(self, frame: Frame):
|
||||
await self._trigger_assistant_turn_stopped()
|
||||
self._started = 0
|
||||
|
||||
async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame):
|
||||
function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls]
|
||||
logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}")
|
||||
@@ -833,7 +855,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
"arguments": json.dumps(frame.arguments, ensure_ascii=False),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
@@ -866,7 +888,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
# Update context with the function call result
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result)
|
||||
result = json.dumps(frame.result, ensure_ascii=False)
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, "COMPLETED")
|
||||
@@ -919,16 +941,18 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
async def _handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
image_appended = False
|
||||
|
||||
# Check if this image is a result of a function call if so, let's cache.
|
||||
# TODO(aleix): The function call might have already been executed
|
||||
# because FunctionCallResultFrame was just faster, in that case we just
|
||||
# push the context frame now.
|
||||
# Check if this image is a result of a function call.
|
||||
if (
|
||||
frame.request
|
||||
and frame.request.tool_call_id
|
||||
and frame.request.tool_call_id in self._function_calls_in_progress
|
||||
):
|
||||
self._function_calls_image_results[frame.request.tool_call_id] = frame
|
||||
|
||||
# Call the result_callback if provided. This signals that the image
|
||||
# has been retrieved and the function call can now complete.
|
||||
if frame.request.result_callback:
|
||||
await frame.request.result_callback(None)
|
||||
else:
|
||||
image_appended = await self._maybe_append_image_to_context(frame)
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ of audio from both user input and bot output sources, with support for various a
|
||||
configurations and event-driven processing.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.audio.utils import create_stream_resampler, interleave_stereo_audio, mix_audio
|
||||
@@ -104,10 +103,6 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
self._user_turn_audio_buffer = bytearray()
|
||||
self._bot_turn_audio_buffer = bytearray()
|
||||
|
||||
# Intermittent (non continous user stream variables)
|
||||
self._last_user_frame_at = 0
|
||||
self._last_bot_frame_at = 0
|
||||
|
||||
self._recording = False
|
||||
|
||||
self._input_resampler = create_stream_resampler()
|
||||
@@ -211,23 +206,31 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
"""Process audio frames for recording."""
|
||||
resampled = None
|
||||
if isinstance(frame, InputAudioRawFrame):
|
||||
# Add silence if we need to.
|
||||
silence = self._compute_silence(self._last_user_frame_at)
|
||||
self._user_audio_buffer.extend(silence)
|
||||
# Add user audio.
|
||||
resampled = await self._resample_input_audio(frame)
|
||||
self._user_audio_buffer.extend(resampled)
|
||||
# Save time of frame so we can compute silence.
|
||||
self._last_user_frame_at = time.time()
|
||||
# Ignoring in case we don't have audio
|
||||
if len(resampled) > 0:
|
||||
# Sync bot buffer to current user position before adding user audio.
|
||||
# We sync BEFORE extending to align both buffers at the same starting timestamp.
|
||||
# For example, user buffer is at 100 bytes, and you receive 20 bytes of new audio
|
||||
# - Bot buffer sees User is at 100. Bot pads itself to 100.
|
||||
# - User buffer adds 20. User is now at 120.
|
||||
# - Outcome: At index 100-120, we have User Audio and (potentially) Bot Audio or silence. They are aligned
|
||||
# This gives the opportunity to the bot to send audio.
|
||||
#
|
||||
# If we synced AFTER, we'd pad the bot buffer with silence for the same
|
||||
# window we just gave to the user, effectively "overwriting" that time slot
|
||||
# with silence and causing the bot's audio to flicker or cut out.
|
||||
self._sync_buffer_to_position(self._bot_audio_buffer, len(self._user_audio_buffer))
|
||||
# Add user audio.
|
||||
self._user_audio_buffer.extend(resampled)
|
||||
elif self._recording and isinstance(frame, OutputAudioRawFrame):
|
||||
# Add silence if we need to.
|
||||
silence = self._compute_silence(self._last_bot_frame_at)
|
||||
self._bot_audio_buffer.extend(silence)
|
||||
# Add bot audio.
|
||||
resampled = await self._resample_output_audio(frame)
|
||||
self._bot_audio_buffer.extend(resampled)
|
||||
# Save time of frame so we can compute silence.
|
||||
self._last_bot_frame_at = time.time()
|
||||
# Ignoring in case we don't have audio
|
||||
if len(resampled) > 0:
|
||||
# Sync user buffer to current bot position before adding bot audio
|
||||
self._sync_buffer_to_position(self._user_audio_buffer, len(self._bot_audio_buffer))
|
||||
# Add bot audio.
|
||||
self._bot_audio_buffer.extend(resampled)
|
||||
|
||||
if self._buffer_size > 0 and (
|
||||
len(self._user_audio_buffer) >= self._buffer_size
|
||||
@@ -240,6 +243,21 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
if self._enable_turn_audio:
|
||||
await self._process_turn_recording(frame, resampled)
|
||||
|
||||
def _sync_buffer_to_position(self, buffer: bytearray, target_position: int):
|
||||
"""Pad buffer with silence if it's behind the target position.
|
||||
|
||||
This ensures both buffers stay synchronized by padding the lagging
|
||||
buffer before new audio is added to the other buffer.
|
||||
|
||||
Args:
|
||||
buffer: The buffer to potentially pad.
|
||||
target_position: The position (in bytes) the buffer should reach.
|
||||
"""
|
||||
current_len = len(buffer)
|
||||
if current_len < target_position:
|
||||
silence_needed = target_position - current_len
|
||||
buffer.extend(b"\x00" * silence_needed)
|
||||
|
||||
async def _process_turn_recording(self, frame: Frame, resampled_audio: Optional[bytes] = None):
|
||||
"""Process frames for turn-based audio recording."""
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
@@ -281,8 +299,8 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
if len(self._user_audio_buffer) == 0 and len(self._bot_audio_buffer) == 0:
|
||||
return
|
||||
|
||||
# Final alignment before we send the audio
|
||||
self._align_track_buffers()
|
||||
flush_time = time.time()
|
||||
|
||||
# Call original handler with merged audio
|
||||
merged_audio = self.merge_audio_buffers()
|
||||
@@ -299,9 +317,6 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
self._num_channels,
|
||||
)
|
||||
|
||||
self._last_user_frame_at = flush_time
|
||||
self._last_bot_frame_at = flush_time
|
||||
|
||||
def _buffer_has_audio(self, buffer: bytearray) -> bool:
|
||||
"""Check if a buffer contains audio data."""
|
||||
return buffer is not None and len(buffer) > 0
|
||||
@@ -309,8 +324,6 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
def _reset_recording(self):
|
||||
"""Reset recording state and buffers."""
|
||||
self._reset_all_audio_buffers()
|
||||
self._last_user_frame_at = time.time()
|
||||
self._last_bot_frame_at = time.time()
|
||||
|
||||
def _reset_all_audio_buffers(self):
|
||||
"""Reset all audio buffers to empty state."""
|
||||
@@ -336,11 +349,9 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
|
||||
target_len = max(user_len, bot_len)
|
||||
if user_len < target_len:
|
||||
self._user_audio_buffer.extend(b"\x00" * (target_len - user_len))
|
||||
self._last_user_frame_at = max(self._last_user_frame_at, self._last_bot_frame_at)
|
||||
self._sync_buffer_to_position(self._user_audio_buffer, target_len)
|
||||
if bot_len < target_len:
|
||||
self._bot_audio_buffer.extend(b"\x00" * (target_len - bot_len))
|
||||
self._last_bot_frame_at = max(self._last_bot_frame_at, self._last_user_frame_at)
|
||||
self._sync_buffer_to_position(self._bot_audio_buffer, target_len)
|
||||
|
||||
async def _resample_input_audio(self, frame: InputAudioRawFrame) -> bytes:
|
||||
"""Resample audio frame to the target sample rate."""
|
||||
@@ -353,14 +364,3 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
return await self._output_resampler.resample(
|
||||
frame.audio, frame.sample_rate, self._sample_rate
|
||||
)
|
||||
|
||||
def _compute_silence(self, from_time: float) -> bytes:
|
||||
"""Compute silence to insert based on time gap."""
|
||||
quiet_time = time.time() - from_time
|
||||
# We should get audio frames very frequently. We introduce silence only
|
||||
# if there's a big enough gap of 1s.
|
||||
if from_time == 0 or quiet_time < 1.0:
|
||||
return b""
|
||||
num_bytes = int(quiet_time * self._sample_rate) * 2
|
||||
silence = b"\x00" * num_bytes
|
||||
return silence
|
||||
|
||||
@@ -14,7 +14,6 @@ management, and frame flow control mechanisms.
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
@@ -775,20 +774,21 @@ class FrameProcessor(BaseObject):
|
||||
"""Broadcasts a frame of the specified class upstream and downstream.
|
||||
|
||||
This method creates two instances of the given frame class using the
|
||||
provided keyword arguments and pushes them upstream and downstream.
|
||||
provided keyword arguments (without deep-copying them) and pushes them
|
||||
upstream and downstream.
|
||||
|
||||
Args:
|
||||
frame_cls: The class of the frame to be broadcasted.
|
||||
**kwargs: Keyword arguments to be passed to the frame's constructor.
|
||||
"""
|
||||
await self.push_frame(frame_cls(**deepcopy(kwargs)))
|
||||
await self.push_frame(frame_cls(**deepcopy(kwargs)), FrameDirection.UPSTREAM)
|
||||
await self.push_frame(frame_cls(**kwargs))
|
||||
await self.push_frame(frame_cls(**kwargs), FrameDirection.UPSTREAM)
|
||||
|
||||
async def broadcast_frame_instance(self, frame: Frame):
|
||||
"""Broadcasts a frame instance upstream and downstream.
|
||||
|
||||
This method creates two new frame instances copying all fields from the
|
||||
original frame except `id` and `name`, which get fresh values.
|
||||
This method creates two new frame instances shallow-copying all fields
|
||||
from the original frame except `id` and `name`, which get fresh values.
|
||||
|
||||
Args:
|
||||
frame: The frame instance to broadcast.
|
||||
@@ -806,13 +806,13 @@ class FrameProcessor(BaseObject):
|
||||
if not f.init and f.name not in ("id", "name")
|
||||
}
|
||||
|
||||
new_frame = frame_cls(**deepcopy(init_fields))
|
||||
for k, v in deepcopy(extra_fields).items():
|
||||
new_frame = frame_cls(**init_fields)
|
||||
for k, v in extra_fields.items():
|
||||
setattr(new_frame, k, v)
|
||||
await self.push_frame(new_frame)
|
||||
|
||||
new_frame = frame_cls(**deepcopy(init_fields))
|
||||
for k, v in deepcopy(extra_fields).items():
|
||||
new_frame = frame_cls(**init_fields)
|
||||
for k, v in extra_fields.items():
|
||||
setattr(new_frame, k, v)
|
||||
await self.push_frame(new_frame, FrameDirection.UPSTREAM)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ Functions:
|
||||
Environment variables:
|
||||
|
||||
- DAILY_API_KEY - Daily API key for room/token creation (required)
|
||||
- DAILY_SAMPLE_ROOM_URL (optional) - Existing room URL to use. If not provided,
|
||||
- DAILY_ROOM_URL (optional) - Existing room URL to use. If not provided,
|
||||
a temporary room will be created automatically.
|
||||
|
||||
Example::
|
||||
@@ -91,7 +91,7 @@ async def configure(
|
||||
"""Configure Daily room URL and token with optional SIP capabilities.
|
||||
|
||||
This function will either:
|
||||
1. Use an existing room URL from DAILY_SAMPLE_ROOM_URL environment variable (standard mode only)
|
||||
1. Use an existing room URL from DAILY_ROOM_URL environment variable (standard mode only)
|
||||
2. Create a new temporary room automatically if no URL is provided
|
||||
|
||||
Args:
|
||||
@@ -177,7 +177,7 @@ async def configure(
|
||||
)
|
||||
|
||||
# Check for existing room URL (only in standard mode)
|
||||
existing_room_url = os.getenv("DAILY_SAMPLE_ROOM_URL")
|
||||
existing_room_url = os.getenv("DAILY_ROOM_URL")
|
||||
if existing_room_url and not sip_enabled:
|
||||
# Use existing room (standard mode only)
|
||||
logger.info(f"Using existing Daily room: {existing_room_url}")
|
||||
|
||||
@@ -153,26 +153,18 @@ def _get_bot_module():
|
||||
)
|
||||
|
||||
|
||||
async def _run_telephony_bot(websocket: WebSocket):
|
||||
async def _run_telephony_bot(websocket: WebSocket, args: argparse.Namespace):
|
||||
"""Run a bot for telephony transports."""
|
||||
bot_module = _get_bot_module()
|
||||
|
||||
# Just pass the WebSocket - let the bot handle parsing
|
||||
runner_args = WebSocketRunnerArguments(websocket=websocket)
|
||||
runner_args.cli_args = args
|
||||
|
||||
await bot_module.bot(runner_args)
|
||||
|
||||
|
||||
def _create_server_app(
|
||||
*,
|
||||
transport_type: str,
|
||||
host: str = "localhost",
|
||||
proxy: str,
|
||||
esp32_mode: bool = False,
|
||||
whatsapp_enabled: bool = False,
|
||||
folder: Optional[str] = None,
|
||||
dialin_enabled: bool = False,
|
||||
):
|
||||
def _create_server_app(args: argparse.Namespace):
|
||||
"""Create FastAPI app with transport-specific routes."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -185,23 +177,21 @@ def _create_server_app(
|
||||
)
|
||||
|
||||
# Set up transport-specific routes
|
||||
if transport_type == "webrtc":
|
||||
_setup_webrtc_routes(app, esp32_mode=esp32_mode, host=host, folder=folder)
|
||||
if whatsapp_enabled:
|
||||
_setup_whatsapp_routes(app)
|
||||
elif transport_type == "daily":
|
||||
_setup_daily_routes(app, dialin_enabled=dialin_enabled)
|
||||
elif transport_type in TELEPHONY_TRANSPORTS:
|
||||
_setup_telephony_routes(app, transport_type=transport_type, proxy=proxy)
|
||||
if args.transport == "webrtc":
|
||||
_setup_webrtc_routes(app, args)
|
||||
if args.whatsapp:
|
||||
_setup_whatsapp_routes(app, args)
|
||||
elif args.transport == "daily":
|
||||
_setup_daily_routes(app, args)
|
||||
elif args.transport in TELEPHONY_TRANSPORTS:
|
||||
_setup_telephony_routes(app, args)
|
||||
else:
|
||||
logger.warning(f"Unknown transport type: {transport_type}")
|
||||
logger.warning(f"Unknown transport type: {args.transport}")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _setup_webrtc_routes(
|
||||
app: FastAPI, *, esp32_mode: bool = False, host: str = "localhost", folder: Optional[str] = None
|
||||
):
|
||||
def _setup_webrtc_routes(app: FastAPI, args: argparse.Namespace):
|
||||
"""Set up WebRTC-specific routes."""
|
||||
try:
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
@@ -241,11 +231,11 @@ def _setup_webrtc_routes(
|
||||
@app.get("/files/{filename:path}")
|
||||
async def download_file(filename: str):
|
||||
"""Handle file downloads."""
|
||||
if not folder:
|
||||
if not args.folder:
|
||||
logger.warning(f"Attempting to dowload {filename}, but downloads folder not setup.")
|
||||
return
|
||||
|
||||
file_path = Path(folder) / filename
|
||||
file_path = Path(args.folder) / filename
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(404)
|
||||
|
||||
@@ -255,7 +245,7 @@ def _setup_webrtc_routes(
|
||||
|
||||
# Initialize the SmallWebRTC request handler
|
||||
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler(
|
||||
esp32_mode=esp32_mode, host=host
|
||||
esp32_mode=args.esp32, host=args.host
|
||||
)
|
||||
|
||||
@app.post("/api/offer")
|
||||
@@ -269,6 +259,7 @@ def _setup_webrtc_routes(
|
||||
runner_args = SmallWebRTCRunnerArguments(
|
||||
webrtc_connection=connection, body=request.request_data
|
||||
)
|
||||
runner_args.cli_args = args
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
# Delegate handling to SmallWebRTCRequestHandler
|
||||
@@ -298,7 +289,7 @@ def _setup_webrtc_routes(
|
||||
|
||||
# Store session info immediately in memory, replicate the behavior expected on Pipecat Cloud
|
||||
session_id = str(uuid.uuid4())
|
||||
active_sessions[session_id] = request_data
|
||||
active_sessions[session_id] = request_data.get("body", {})
|
||||
|
||||
result: StartBotResult = {"sessionId": session_id}
|
||||
if request_data.get("enableDefaultIceServers"):
|
||||
@@ -331,7 +322,8 @@ def _setup_webrtc_routes(
|
||||
pc_id=request_data.get("pc_id"),
|
||||
restart_pc=request_data.get("restart_pc"),
|
||||
request_data=request_data.get("request_data")
|
||||
or request_data.get("requestData"),
|
||||
or request_data.get("requestData")
|
||||
or active_session,
|
||||
)
|
||||
return await offer(webrtc_request, background_tasks)
|
||||
elif request.method == HTTPMethod.PATCH.value:
|
||||
@@ -380,8 +372,8 @@ def _add_lifespan_to_app(app: FastAPI, new_lifespan):
|
||||
app.router.lifespan_context = new_lifespan
|
||||
|
||||
|
||||
def _setup_whatsapp_routes(app: FastAPI):
|
||||
"""Set up WebRTC-specific routes."""
|
||||
def _setup_whatsapp_routes(app: FastAPI, args: argparse.Namespace):
|
||||
"""Set up WhatsApp-specific routes."""
|
||||
WHATSAPP_APP_SECRET = os.getenv("WHATSAPP_APP_SECRET")
|
||||
WHATSAPP_PHONE_NUMBER_ID = os.getenv("WHATSAPP_PHONE_NUMBER_ID")
|
||||
WHATSAPP_TOKEN = os.getenv("WHATSAPP_TOKEN")
|
||||
@@ -483,6 +475,7 @@ def _setup_whatsapp_routes(app: FastAPI):
|
||||
"""
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
|
||||
runner_args.cli_args = args
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
try:
|
||||
@@ -528,13 +521,8 @@ def _setup_whatsapp_routes(app: FastAPI):
|
||||
_add_lifespan_to_app(app, whatsapp_lifespan)
|
||||
|
||||
|
||||
def _setup_daily_routes(app: FastAPI, dialin_enabled: bool = False):
|
||||
"""Set up Daily-specific routes.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
dialin_enabled: If True, adds /daily-dialin-webhook endpoint for PSTN dial-in handling
|
||||
"""
|
||||
def _setup_daily_routes(app: FastAPI, args: argparse.Namespace):
|
||||
"""Set up Daily-specific routes."""
|
||||
|
||||
@app.get("/")
|
||||
async def create_room_and_start_agent():
|
||||
@@ -551,6 +539,7 @@ def _setup_daily_routes(app: FastAPI, dialin_enabled: bool = False):
|
||||
# Start the bot in the background with empty body for GET requests
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = DailyRunnerArguments(room_url=room_url, token=token)
|
||||
runner_args.cli_args = args
|
||||
asyncio.create_task(bot_module.bot(runner_args))
|
||||
return RedirectResponse(room_url)
|
||||
|
||||
@@ -583,13 +572,13 @@ def _setup_daily_routes(app: FastAPI, dialin_enabled: bool = False):
|
||||
|
||||
bot_module = _get_bot_module()
|
||||
|
||||
existing_room_url = os.getenv("DAILY_SAMPLE_ROOM_URL")
|
||||
existing_room_url = os.getenv("DAILY_ROOM_URL")
|
||||
|
||||
result = None
|
||||
|
||||
# Configure room if:
|
||||
# 1. Explicitly requested via createDailyRoom in payload
|
||||
# 2. Using pre-configured room from DAILY_SAMPLE_ROOM_URL env var
|
||||
# 2. Using pre-configured room from DAILY_ROOM_URL env var
|
||||
if create_daily_room or existing_room_url:
|
||||
import aiohttp
|
||||
|
||||
@@ -634,12 +623,15 @@ def _setup_daily_routes(app: FastAPI, dialin_enabled: bool = False):
|
||||
else:
|
||||
runner_args = RunnerArguments(body=body)
|
||||
|
||||
# Update CLI args.
|
||||
runner_args.cli_args = args
|
||||
|
||||
# Start the bot in the background
|
||||
asyncio.create_task(bot_module.bot(runner_args))
|
||||
|
||||
return result
|
||||
|
||||
if dialin_enabled:
|
||||
if args.dialin:
|
||||
|
||||
@app.post("/daily-dialin-webhook")
|
||||
async def handle_dialin_webhook(request: Request):
|
||||
@@ -736,6 +728,7 @@ def _setup_daily_routes(app: FastAPI, dialin_enabled: bool = False):
|
||||
token=room_config.token,
|
||||
body=request_body.model_dump(),
|
||||
)
|
||||
runner_args.cli_args = args
|
||||
|
||||
asyncio.create_task(bot_module.bot(runner_args))
|
||||
|
||||
@@ -750,44 +743,44 @@ def _setup_daily_routes(app: FastAPI, dialin_enabled: bool = False):
|
||||
}
|
||||
|
||||
|
||||
def _setup_telephony_routes(app: FastAPI, *, transport_type: str, proxy: str):
|
||||
def _setup_telephony_routes(app: FastAPI, args: argparse.Namespace):
|
||||
"""Set up telephony-specific routes."""
|
||||
# XML response templates (Exotel doesn't use XML webhooks)
|
||||
XML_TEMPLATES = {
|
||||
"twilio": f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
<Connect>
|
||||
<Stream url="wss://{proxy}/ws"></Stream>
|
||||
<Stream url="wss://{args.proxy}/ws"></Stream>
|
||||
</Connect>
|
||||
<Pause length="40"/>
|
||||
</Response>""",
|
||||
"telnyx": f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
<Connect>
|
||||
<Stream url="wss://{proxy}/ws" bidirectionalMode="rtp"></Stream>
|
||||
<Stream url="wss://{args.proxy}/ws" bidirectionalMode="rtp"></Stream>
|
||||
</Connect>
|
||||
<Pause length="40"/>
|
||||
</Response>""",
|
||||
"plivo": f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
<Stream bidirectional="true" keepCallAlive="true" contentType="audio/x-mulaw;rate=8000">wss://{proxy}/ws</Stream>
|
||||
<Stream bidirectional="true" keepCallAlive="true" contentType="audio/x-mulaw;rate=8000">wss://{args.proxy}/ws</Stream>
|
||||
</Response>""",
|
||||
}
|
||||
|
||||
@app.post("/")
|
||||
async def start_call():
|
||||
"""Handle telephony webhook and return XML response."""
|
||||
if transport_type == "exotel":
|
||||
if args.transport == "exotel":
|
||||
# Exotel doesn't use POST webhooks - redirect to proper documentation
|
||||
logger.debug("POST Exotel endpoint - not used")
|
||||
return {
|
||||
"error": "Exotel doesn't use POST webhooks",
|
||||
"websocket_url": f"wss://{proxy}/ws",
|
||||
"websocket_url": f"wss://{args.proxy}/ws",
|
||||
"note": "Configure the WebSocket URL above in your Exotel App Bazaar Voicebot Applet",
|
||||
}
|
||||
else:
|
||||
logger.debug(f"POST {transport_type.upper()} XML")
|
||||
xml_content = XML_TEMPLATES.get(transport_type, "<Response></Response>")
|
||||
logger.debug(f"POST {args.transport.upper()} XML")
|
||||
xml_content = XML_TEMPLATES.get(args.transport, "<Response></Response>")
|
||||
return HTMLResponse(content=xml_content, media_type="application/xml")
|
||||
|
||||
@app.websocket("/ws")
|
||||
@@ -795,15 +788,15 @@ def _setup_telephony_routes(app: FastAPI, *, transport_type: str, proxy: str):
|
||||
"""Handle WebSocket connections for telephony."""
|
||||
await websocket.accept()
|
||||
logger.debug("WebSocket connection accepted")
|
||||
await _run_telephony_bot(websocket)
|
||||
await _run_telephony_bot(websocket, args)
|
||||
|
||||
@app.get("/")
|
||||
async def start_agent():
|
||||
"""Simple status endpoint for telephony transports."""
|
||||
return {"status": f"Bot started with {transport_type}"}
|
||||
return {"status": f"Bot started with {args.transport}"}
|
||||
|
||||
|
||||
async def _run_daily_direct():
|
||||
async def _run_daily_direct(args: argparse.Namespace):
|
||||
"""Run Daily bot with direct connection (no FastAPI server)."""
|
||||
try:
|
||||
from pipecat.runner.daily import configure
|
||||
@@ -819,6 +812,7 @@ async def _run_daily_direct():
|
||||
# Direct connections have no request body, so use empty dict
|
||||
runner_args = DailyRunnerArguments(room_url=room_url, token=token)
|
||||
runner_args.handle_sigint = True
|
||||
runner_args.cli_args = args
|
||||
|
||||
# Get the bot module and run it directly
|
||||
bot_module = _get_bot_module()
|
||||
@@ -866,29 +860,38 @@ def runner_port() -> int:
|
||||
return RUNNER_PORT
|
||||
|
||||
|
||||
def main():
|
||||
def main(parser: Optional[argparse.ArgumentParser] = None):
|
||||
"""Start the Pipecat development runner.
|
||||
|
||||
Parses command-line arguments and starts a FastAPI server configured
|
||||
for the specified transport type. The runner will discover and run
|
||||
any bot() function found in the current directory.
|
||||
for the specified transport type.
|
||||
|
||||
The runner discovers and runs any ``bot(runner_args)`` function found in the
|
||||
calling module.
|
||||
|
||||
Command-line arguments:
|
||||
- --host: Server host address (default: localhost)
|
||||
- --port: Server port (default: 7860)
|
||||
- -t/--transport: Transport type (daily, webrtc, twilio, telnyx, plivo, exotel)
|
||||
- -x/--proxy: Public proxy hostname for telephony webhooks
|
||||
- -d/--direct: Connect directly to Daily room (automatically sets transport to daily)
|
||||
- -f/--folder: Path to downloads folder
|
||||
- --dialin: Enable Daily PSTN dial-in webhook handling (requires Daily transport)
|
||||
- --esp32: Enable SDP munging for ESP32 compatibility (requires --host with IP address)
|
||||
- --whatsapp: Ensure required WhatsApp environment variables are present
|
||||
- -v/--verbose: Increase logging verbosity
|
||||
|
||||
Args:
|
||||
--host: Server host address (default: localhost)
|
||||
--port: Server port (default: 7860)
|
||||
-t/--transport: Transport type (daily, webrtc, twilio, telnyx, plivo, exotel)
|
||||
-x/--proxy: Public proxy hostname for telephony webhooks
|
||||
--esp32: Enable SDP munging for ESP32 compatibility (requires --host with IP address)
|
||||
-d/--direct: Connect directly to Daily room (automatically sets transport to daily)
|
||||
-v/--verbose: Increase logging verbosity
|
||||
parser: Optional custom argument parser. If provided, default runner
|
||||
arguments are added to it so bots can define their own CLI
|
||||
arguments. Custom arguments should not conflict with the default
|
||||
ones. Custom args are accessible via `runner_args.cli_args`.
|
||||
|
||||
The bot file must contain a `bot(runner_args)` function as the entry point.
|
||||
"""
|
||||
global RUNNER_DOWNLOADS_FOLDER, RUNNER_HOST, RUNNER_PORT
|
||||
|
||||
parser = argparse.ArgumentParser(description="Pipecat Development Runner")
|
||||
if not parser:
|
||||
parser = argparse.ArgumentParser(description="Pipecat Development Runner")
|
||||
parser.add_argument("--host", type=str, default=RUNNER_HOST, help="Host address")
|
||||
parser.add_argument("--port", type=int, default=RUNNER_PORT, help="Port number")
|
||||
parser.add_argument(
|
||||
@@ -899,13 +902,7 @@ def main():
|
||||
default="webrtc",
|
||||
help="Transport type",
|
||||
)
|
||||
parser.add_argument("--proxy", "-x", help="Public proxy host name")
|
||||
parser.add_argument(
|
||||
"--esp32",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable SDP munging for ESP32 compatibility (requires --host with IP address)",
|
||||
)
|
||||
parser.add_argument("-x", "--proxy", help="Public proxy host name")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--direct",
|
||||
@@ -915,13 +912,7 @@ def main():
|
||||
)
|
||||
parser.add_argument("-f", "--folder", type=str, help="Path to downloads folder")
|
||||
parser.add_argument(
|
||||
"--verbose", "-v", action="count", default=0, help="Increase logging verbosity"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whatsapp",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Ensure requried WhatsApp environment variables are present",
|
||||
"-v", "--verbose", action="count", default=0, help="Increase logging verbosity"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dialin",
|
||||
@@ -929,6 +920,18 @@ def main():
|
||||
default=False,
|
||||
help="Enable Daily PSTN dial-in webhook handling (requires Daily transport)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--esp32",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable SDP munging for ESP32 compatibility (requires --host with IP address)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whatsapp",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Ensure requried WhatsApp environment variables are present",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -964,7 +967,7 @@ def main():
|
||||
print()
|
||||
|
||||
# Run direct Daily connection
|
||||
asyncio.run(_run_daily_direct())
|
||||
asyncio.run(_run_daily_direct(args))
|
||||
return
|
||||
|
||||
# Print startup message for server-based transports
|
||||
@@ -995,15 +998,7 @@ def main():
|
||||
RUNNER_PORT = args.port
|
||||
|
||||
# Create the app with transport-specific setup
|
||||
app = _create_server_app(
|
||||
transport_type=args.transport,
|
||||
host=args.host,
|
||||
proxy=args.proxy,
|
||||
esp32_mode=args.esp32,
|
||||
whatsapp_enabled=args.whatsapp,
|
||||
folder=args.folder,
|
||||
dialin_enabled=args.dialin,
|
||||
)
|
||||
app = _create_server_app(args)
|
||||
|
||||
# Run the server
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
@@ -10,6 +10,7 @@ These types are used by the development runner to pass transport-specific
|
||||
information to bot functions.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -64,6 +65,7 @@ class RunnerArguments:
|
||||
handle_sigterm: bool = field(init=False, kw_only=True)
|
||||
pipeline_idle_timeout_secs: int = field(init=False, kw_only=True)
|
||||
body: Optional[Any] = field(default_factory=dict, kw_only=True)
|
||||
cli_args: Optional[argparse.Namespace] = field(default=None, init=False, kw_only=True)
|
||||
|
||||
def __post_init__(self):
|
||||
self.handle_sigint = False
|
||||
@@ -106,3 +108,18 @@ class SmallWebRTCRunnerArguments(RunnerArguments):
|
||||
"""
|
||||
|
||||
webrtc_connection: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class LiveKitRunnerArguments(RunnerArguments):
|
||||
"""LiveKit transport session arguments for the runner.
|
||||
|
||||
Parameters:
|
||||
room_name: LiveKit room name to join
|
||||
token: Authentication token for the room
|
||||
body: Additional request data
|
||||
"""
|
||||
|
||||
room_name: str
|
||||
url: str
|
||||
token: Optional[str] = None
|
||||
|
||||
@@ -39,6 +39,7 @@ from loguru import logger
|
||||
|
||||
from pipecat.runner.types import (
|
||||
DailyRunnerArguments,
|
||||
LiveKitRunnerArguments,
|
||||
SmallWebRTCRunnerArguments,
|
||||
WebSocketRunnerArguments,
|
||||
)
|
||||
@@ -568,6 +569,17 @@ async def create_transport(
|
||||
return await _create_telephony_transport(
|
||||
runner_args.websocket, params, transport_type, call_data
|
||||
)
|
||||
elif isinstance(runner_args, LiveKitRunnerArguments):
|
||||
params = _get_transport_params("livekit", transport_params)
|
||||
|
||||
from pipecat.transports.livekit.transport import LiveKitTransport
|
||||
|
||||
return LiveKitTransport(
|
||||
runner_args.url,
|
||||
runner_args.token,
|
||||
runner_args.room_name,
|
||||
params=params,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported runner arguments type: {type(runner_args)}")
|
||||
|
||||
@@ -9,9 +9,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pipecat.frames.frames import Frame, StartFrame
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
|
||||
class FrameSerializer(ABC):
|
||||
class FrameSerializer(BaseObject):
|
||||
"""Abstract base class for frame serialization implementations.
|
||||
|
||||
Defines the interface for converting frames to/from serialized formats
|
||||
|
||||
963
src/pipecat/serializers/genesys.py
Normal file
963
src/pipecat/serializers/genesys.py
Normal file
@@ -0,0 +1,963 @@
|
||||
"""Genesys AudioHook Serializer for Pipecat.
|
||||
|
||||
This module provides a serializer for integrating Pipecat pipelines with
|
||||
Genesys Cloud Contact Center via the AudioHook protocol.
|
||||
|
||||
Features:
|
||||
- Bidirectional audio streaming (PCMU μ-law at 8kHz)
|
||||
- Automatic protocol handshake handling (open/opened, close/closed, ping/pong)
|
||||
- Input/output variables for Architect flow integration
|
||||
- DTMF event support
|
||||
- Barge-in (interruption) events
|
||||
- Pause/resume support for hold scenarios (optional)
|
||||
|
||||
Protocol Reference:
|
||||
- https://developer.genesys.cloud/devapps/audiohook
|
||||
|
||||
Audio Format:
|
||||
- PCMU (μ-law) at 8kHz sample rate (preferred)
|
||||
- L16 (16-bit linear PCM) at 8kHz also supported
|
||||
- Mono (external channel) or Stereo (external on left, internal on right)
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.audio.resamplers.soxr_stream_resampler import SOXRStreamAudioResampler
|
||||
from pipecat.audio.utils import pcm_to_ulaw, ulaw_to_pcm
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
OutputTransportMessageFrame,
|
||||
OutputTransportMessageUrgentFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
|
||||
|
||||
class AudioHookMessageType(str, Enum):
    """Values of the ``type`` field used in AudioHook JSON control messages.

    Members are ``str`` subclasses, so each one compares equal to its wire
    value (for example ``AudioHookMessageType.PONG == "pong"``).
    """

    # Session establishment handshake.
    OPEN = "open"
    OPENED = "opened"

    # Session teardown handshake.
    CLOSE = "close"
    CLOSED = "closed"

    # Hold handling.
    PAUSE = "pause"
    RESUMED = "resumed"

    # Keep-alive exchange.
    PING = "ping"
    PONG = "pong"

    # Remaining message kinds.
    UPDATE = "update"
    EVENT = "event"
    ERROR = "error"
    DISCONNECT = "disconnect"
|
||||
|
||||
|
||||
class AudioHookChannel(str, Enum):
    """Selects which AudioHook audio channel(s) a session carries."""

    # Mono stream containing only the customer side of the call.
    EXTERNAL = "external"
    # Mono stream containing only the agent side of the call.
    INTERNAL = "internal"
    # Stereo stream: external on the left channel, internal on the right.
    BOTH = "both"
|
||||
|
||||
|
||||
class AudioHookMediaFormat(str, Enum):
    """Audio encodings that can be negotiated for an AudioHook session."""

    # μ-law encoded audio at 8kHz.
    PCMU = "PCMU"
    # 16-bit linear PCM at 8kHz.
    L16 = "L16"
|
||||
|
||||
|
||||
class GenesysAudioHookSerializer(FrameSerializer):
|
||||
"""Serializer for Genesys AudioHook WebSocket protocol.
|
||||
|
||||
This serializer handles converting between Pipecat frames and Genesys
|
||||
AudioHook protocol messages. It supports:
|
||||
|
||||
- Bidirectional audio streaming (PCMU at 8kHz)
|
||||
- Automatic protocol handshake (open/opened, close/closed, ping/pong)
|
||||
- Session lifecycle management with pause/resume support
|
||||
- Custom input/output variables for Architect flow integration
|
||||
- DTMF event handling
|
||||
- Barge-in events for interruption support
|
||||
|
||||
The AudioHook protocol uses:
|
||||
- Text WebSocket frames for JSON control messages
|
||||
- Binary WebSocket frames for audio data
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
serializer = GenesysAudioHookSerializer(
|
||||
params=GenesysAudioHookSerializer.InputParams(
|
||||
channel=AudioHookChannel.EXTERNAL,
|
||||
supported_languages=["en-US", "es-ES"],
|
||||
selected_language="en-US",
|
||||
)
|
||||
)
|
||||
|
||||
# Use with FastAPI WebSocket transport
|
||||
transport = FastAPIWebsocketTransport(
|
||||
websocket=websocket,
|
||||
params=FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
serializer=serializer,
|
||||
audio_out_fixed_packet_size=1600, # Important: prevents 429 rate limiting from Genesys
|
||||
),
|
||||
)
|
||||
|
||||
# Access call information after connection
|
||||
participant = serializer.participant # ani, dnis, etc.
|
||||
input_vars = serializer.input_variables # Custom vars from Architect
|
||||
|
||||
# Set output variables to return to Architect
|
||||
serializer.set_output_variables({"intent": "billing", "resolved": True})
|
||||
```
|
||||
|
||||
Attributes:
|
||||
PROTOCOL_VERSION: The AudioHook protocol version (currently "2").
|
||||
"""
|
||||
|
||||
PROTOCOL_VERSION = "2"
|
||||
|
||||
class InputParams(BaseModel):
    """Settings accepted by GenesysAudioHookSerializer.

    Attributes:
        genesys_sample_rate: Sample rate of the Genesys side, in Hz (8000 by default).
        sample_rate: Pipeline input sample rate override; when unset the
            rate is taken from the StartFrame during setup.
        channel: Audio channel selection (external, internal or both).
        media_format: Audio encoding to use (PCMU or L16).
        process_external: Whether external (customer) audio is processed.
        process_internal: Whether internal (agent) audio is processed.
        supported_languages: Language codes the bot supports, e.g. ["en-US", "es-ES"].
        selected_language: Language code to use by default.
        start_paused: Whether the session should begin in the paused state.
    """

    # Sample rates.
    genesys_sample_rate: int = 8000
    sample_rate: Optional[int] = None

    # Media negotiation.
    channel: AudioHookChannel = AudioHookChannel.EXTERNAL
    media_format: AudioHookMediaFormat = AudioHookMediaFormat.PCMU

    # Channel processing switches.
    process_external: bool = True
    process_internal: bool = False

    # Language negotiation.
    supported_languages: Optional[List[str]] = None
    selected_language: Optional[str] = None

    # Initial session state.
    start_paused: bool = False
|
||||
|
||||
def __init__(
    self,
    params: Optional[InputParams] = None,
    **kwargs,
):
    """Create a serializer with fresh protocol state.

    Args:
        params: Serializer configuration; defaults are used when omitted.
        **kwargs: Forwarded unchanged to the parent initializer (e.g., name).
    """
    super().__init__(**kwargs)

    if not params:
        params = GenesysAudioHookSerializer.InputParams()
    self._params = params

    # Audio rates: the Genesys rate comes from the params, the pipeline
    # rate stays 0 until setup() runs.
    self._genesys_sample_rate = self._params.genesys_sample_rate
    self._sample_rate = 0
    self._session_id = str(uuid.uuid4())

    # One SOXR stream resampler per direction.
    self._input_resampler = SOXRStreamAudioResampler()
    self._output_resampler = SOXRStreamAudioResampler()

    # Sequence counters, open/paused flags and stream position.
    self._client_seq = 0
    self._server_seq = 0
    self._is_open = False
    self._is_paused = False
    self._position = timedelta(0)

    # Session metadata, all unset at construction time.
    self._conversation_id: Optional[str] = None
    self._participant: Optional[Dict[str, Any]] = None
    self._custom_config: Optional[Dict[str, Any]] = None
    self._media_info: Optional[List[Dict[str, Any]]] = None
    self._input_variables: Optional[Dict[str, Any]] = None  # custom input from Genesys
    self._output_variables: Optional[Dict[str, Any]] = None  # custom output to Genesys

    # Events that callers can attach handlers to.
    for event_name in (
        "on_open",
        "on_close",
        "on_ping",
        "on_pause",
        "on_update",
        "on_error",
        "on_dtmf",
    ):
        self._register_event_handler(event_name)
|
||||
|
||||
@property
|
||||
def session_id(self) -> str:
|
||||
"""Get the Genesys AudioHook session ID generated by the serializer."""
|
||||
return self._session_id
|
||||
|
||||
@property
|
||||
def conversation_id(self) -> Optional[str]:
|
||||
"""Get the Genesys conversation ID."""
|
||||
return self._conversation_id
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Check if the AudioHook session is open."""
|
||||
return self._is_open
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
"""Check if audio streaming is paused."""
|
||||
return self._is_paused
|
||||
|
||||
@property
|
||||
def participant(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get participant info (ani, dnis, etc.) from the open message."""
|
||||
return self._participant
|
||||
|
||||
@property
|
||||
def input_variables(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get custom input variables from the open message."""
|
||||
return self._input_variables
|
||||
|
||||
@property
|
||||
def output_variables(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get custom output variables to send back to Genesys."""
|
||||
return self._output_variables
|
||||
|
||||
def set_output_variables(self, variables: Dict[str, Any]) -> None:
    """Store custom output variables destined for Genesys.

    The dictionary replaces any previously stored output variables and is
    kept by reference (no copy is made).

    Args:
        variables: Custom variables to hand back to Genesys.

    Example:
        ```python
        serializer.set_output_variables({
            "intent": "billing_inquiry",
            "customer_verified": True,
            "summary": "Customer asked about their bill",
            "transfer_to": "billing_queue"
        })
        ```
    """
    self._output_variables = variables
    logger.debug(f"Output variables set: {variables}")
|
||||
|
||||
async def setup(self, frame: StartFrame):
    """Pick the pipeline sample rate once the pipeline starts.

    An explicit (truthy) ``sample_rate`` in the params wins; otherwise the
    StartFrame's ``audio_in_sample_rate`` is used.

    Args:
        frame: The StartFrame carrying the pipeline configuration.
    """
    configured_rate = self._params.sample_rate
    self._sample_rate = configured_rate if configured_rate else frame.audio_in_sample_rate
    logger.debug(f"GenesysAudioHookSerializer setup with sample_rate={self._sample_rate}")
|
||||
|
||||
def _format_position(self, position: timedelta) -> str:
|
||||
"""Format a timedelta as ISO 8601 duration string.
|
||||
|
||||
Args:
|
||||
position: The timedelta to format.
|
||||
|
||||
Returns:
|
||||
ISO 8601 duration string (e.g., "PT1.5S").
|
||||
"""
|
||||
total_seconds = position.total_seconds()
|
||||
return f"PT{total_seconds:.3f}S"
|
||||
|
||||
def _parse_position(self, position_str: str) -> timedelta:
|
||||
"""Parse an ISO 8601 duration string to timedelta.
|
||||
|
||||
Args:
|
||||
position_str: ISO 8601 duration string (e.g., "PT1.5S").
|
||||
|
||||
Returns:
|
||||
Corresponding timedelta.
|
||||
"""
|
||||
# Simple parser for PT#S or PT#.#S format
|
||||
if position_str.startswith("PT") and position_str.endswith("S"):
|
||||
try:
|
||||
seconds = float(position_str[2:-1])
|
||||
return timedelta(seconds=seconds)
|
||||
except ValueError:
|
||||
pass
|
||||
return timedelta(0)
|
||||
|
||||
def _next_server_seq(self) -> int:
|
||||
"""Get the next server sequence number."""
|
||||
self._server_seq += 1
|
||||
return self._server_seq
|
||||
|
||||
def _create_message(
    self,
    msg_type: AudioHookMessageType,
    parameters: Optional[Dict[str, Any]] = None,
    include_position: bool = True,
) -> Dict[str, Any]:
    """Build the common envelope shared by all outgoing protocol messages.

    Every call consumes one server sequence number. The envelope carries:

    - seq: The server sequence number allocated for this message
    - clientseq: The currently stored client sequence number

    Args:
        msg_type: The message type to emit.
        parameters: Optional parameters object; omitted when empty or None.
        include_position: Whether to add the formatted stream position.

    Returns:
        The message dictionary.
    """
    # Allocate the sequence number before anything else is read.
    allocated_seq = self._next_server_seq()

    envelope: Dict[str, Any] = dict(
        version=self.PROTOCOL_VERSION,
        type=msg_type.value,
        seq=allocated_seq,
        clientseq=self._client_seq,
        id=self._session_id,
    )

    if include_position:
        envelope["position"] = self._format_position(self._position)

    if parameters:
        envelope["parameters"] = parameters

    return envelope
|
||||
|
||||
def create_opened_response(
    self,
    start_paused: bool = False,
    supported_languages: Optional[List[str]] = None,
    selected_language: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the 'opened' message answering an 'open' from Genesys.

    Also marks the session as open.

    Args:
        start_paused: Whether the session should start paused.
        supported_languages: Supported language codes; omitted when empty.
        selected_language: Selected language code; omitted when empty.

    Returns:
        Dictionary of the opened response message.
    """
    # Translate the configured channel into the list of channel names;
    # an unrecognized setting leaves the list empty.
    channel_layouts = (
        (AudioHookChannel.EXTERNAL, ["external"]),
        (AudioHookChannel.INTERNAL, ["internal"]),
        (AudioHookChannel.BOTH, ["external", "internal"]),
    )
    channels: list[str] = []
    for member, names in channel_layouts:
        if self._params.channel == member:
            channels = names
            break

    media_entry = {
        "type": "audio",
        "format": self._params.media_format.value,
        "channels": channels,
        "rate": self._genesys_sample_rate,
    }
    parameters = {"startPaused": start_paused, "media": [media_entry]}

    if supported_languages:
        parameters["supportedLanguages"] = supported_languages
    if selected_language:
        parameters["selectedLanguage"] = selected_language

    # The opened message is sent without a position field.
    response = self._create_message(
        AudioHookMessageType.OPENED,
        parameters=parameters,
        include_position=False,
    )

    self._is_open = True
    logger.debug(f"AudioHook session opened: {self._session_id}")

    return response
|
||||
|
||||
def create_closed_response(
    self,
    output_variables: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the 'closed' message answering a 'close' from Genesys.

    Also marks the session as no longer open.

    Args:
        output_variables: Optional custom variables to hand back to Genesys.
            When empty or None, no parameters are attached to the message.

    Returns:
        Dictionary of the closed response message.

    Example:
        ```python
        serializer.create_closed_response(
            output_variables={
                "intent": "billing_inquiry",
                "customer_verified": True,
                "summary": "Customer asked about their bill"
            }
        )
        ```
    """
    parameters: Optional[Dict[str, Any]] = (
        {"outputVariables": output_variables} if output_variables else None
    )

    response = self._create_message(AudioHookMessageType.CLOSED, parameters=parameters)

    self._is_open = False
    logger.debug(f"AudioHook session closed: {self._session_id}")

    return response
|
||||
|
||||
def create_pong_response(self) -> Dict[str, Any]:
    """Build the 'pong' message answering a 'ping' from Genesys.

    Returns:
        Dictionary of the pong response message.
    """
    return self._create_message(AudioHookMessageType.PONG)
|
||||
|
||||
def create_resumed_response(self) -> Dict[str, Any]:
    """Build a 'resumed' message and clear the paused flag.

    Intended as the reply to a 'pause' message once streaming can resume.

    Returns:
        Dictionary of the resumed response message.
    """
    response = self._create_message(AudioHookMessageType.RESUMED)

    self._is_paused = False
    logger.debug(f"AudioHook session resumed: {self._session_id}")

    return response
|
||||
|
||||
def create_barge_in_event(self) -> Dict[str, Any]:
    """Build an 'event' message carrying a single barge-in entity.

    Used to tell Genesys Cloud that the user interrupted the bot's audio
    output.

    Returns:
        Dictionary of the barge-in event message.
    """
    barge_in_entity = {"type": "barge_in", "data": {}}
    event = self._create_message(
        AudioHookMessageType.EVENT,
        parameters={"entities": [barge_in_entity]},
    )

    logger.debug("🔇 Barge-in event sent to Genesys")

    return event
|
||||
|
||||
def create_disconnect_message(
    self,
    reason: str = "completed",
    action: str = "transfer",
    output_variables: Optional[Dict[str, Any]] = None,
    info: Optional[str] = None,
) -> Dict[str, Any]:
    """Build a 'disconnect' message that starts session termination.

    Args:
        reason: Disconnect reason (e.g., "completed", "error").
        action: Action to report ("transfer" to agent, "finished" if completed).
        output_variables: Custom output variables merged on top of the
            action entry; a key named "action" here overrides ``action``.
        info: Optional additional information; omitted when empty.

    Returns:
        Dictionary of the disconnect message.
    """
    # outputVariables is always present and always starts with the action.
    payload: Dict[str, Any] = {"reason": reason, "outputVariables": {"action": action}}
    if output_variables:
        payload["outputVariables"].update(output_variables)

    if info:
        payload["info"] = info

    message = self._create_message(AudioHookMessageType.DISCONNECT, parameters=payload)

    logger.debug(f"AudioHook disconnect: reason={reason}, action={action}")
    return message
|
||||
|
||||
def create_error_message(
    self,
    code: int,
    message: str,
    retryable: bool = False,
) -> Dict[str, Any]:
    """Build an 'error' message and log it at error level.

    Args:
        code: Error code.
        message: Error message.
        retryable: Whether the operation can be retried.

    Returns:
        Dictionary of the error message.
    """
    error_payload = self._create_message(
        AudioHookMessageType.ERROR,
        parameters=dict(code=code, message=message, retryable=retryable),
    )
    logger.error(f"AudioHook error: {code} - {message}")
    return error_payload
|
||||
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
    """Serialize a Pipecat frame to Genesys AudioHook wire format.

    Frame handling:
    - EndFrame/CancelFrame -> Disconnect message (JSON) carrying the
      current output variables and reason "completed"
    - AudioRawFrame -> Binary PCMU audio converted via ``pcm_to_ulaw`` to
      the Genesys sample rate (dropped while the session is closed/paused)
    - InterruptionFrame -> Barge-in event (JSON)
    - OutputTransportMessageFrame / OutputTransportMessageUrgentFrame ->
      Pass-through JSON, only for dict messages that contain "version"

    Args:
        frame: The Pipecat frame to serialize.

    Returns:
        Serialized data as string (JSON) or bytes (audio), or None if the
        frame is not handled, the audio cannot be sent right now, or the
        media format is not supported.
    """
    if isinstance(frame, (EndFrame, CancelFrame)):
        disconnect = self.create_disconnect_message(
            output_variables=self.output_variables, reason="completed"
        )
        return json.dumps(disconnect)

    if isinstance(frame, AudioRawFrame):
        # Audio only flows while the session is open and not paused.
        if self._is_paused or not self._is_open:
            return None

        if self._params.media_format != AudioHookMediaFormat.PCMU:
            # Anything other than PCMU (i.e. L16) is not supported yet.
            logger.warning("L16 format not yet fully implemented")
            return None

        # Convert PCM to μ-law at the Genesys sample rate.
        encoded = await pcm_to_ulaw(
            frame.audio,
            frame.sample_rate,
            self._genesys_sample_rate,
            self._output_resampler,
        )
        if encoded is None or len(encoded) == 0:
            return None
        return bytes(encoded)

    if isinstance(frame, InterruptionFrame):
        return json.dumps(self.create_barge_in_event())

    if isinstance(frame, (OutputTransportMessageFrame, OutputTransportMessageUrgentFrame)):
        # Only AudioHook protocol messages (those with a "version" field)
        # are forwarded; RTVI and other transport messages are dropped.
        payload = frame.message
        is_audiohook = isinstance(payload, dict) and "version" in payload
        return json.dumps(payload) if is_audiohook else None

    # Every other frame type is irrelevant to the AudioHook wire protocol.
    return None
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
    """Deserialize Genesys AudioHook data to Pipecat frames.

    Handles:
    - Binary data -> InputAudioRawFrame (converted from PCMU to PCM)
    - JSON 'open' -> OutputTransportMessageUrgentFrame with 'opened' response
    - JSON 'close' -> OutputTransportMessageUrgentFrame with 'closed' response
    - JSON 'ping' -> OutputTransportMessageUrgentFrame with 'pong' response
    - JSON 'pause' -> Sets is_paused=True, returns None
    - JSON 'dtmf' -> InputDTMFFrame
    - JSON 'update' -> Updates participant info, returns None
    - JSON 'error' -> Logs error, returns None

    Protocol responses (opened, closed, pong) are returned as urgent frames
    to be sent immediately through the transport.

    Args:
        data: The raw WebSocket data from Genesys (binary audio or JSON text).

    Returns:
        A Pipecat frame to process, or None if the data was handled
        internally, could not be parsed as JSON, or was valid JSON that is
        not an object.
    """
    # Binary data = audio
    if isinstance(data, bytes):
        logger.debug(f"[AUDIO IN] Received {len(data)} bytes from Genesys")
        return await self._deserialize_audio(data)

    # Text data = JSON control message
    try:
        message = json.loads(data)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse AudioHook message: {e}")
        return None

    # Well-formed JSON is not necessarily an object ("[]", "42", "null").
    # The control-message handler calls .get() on the payload, so anything
    # other than a dict would raise AttributeError; reject it here instead.
    if not isinstance(message, dict):
        logger.error(
            f"Ignoring AudioHook message that is not a JSON object: {type(message).__name__}"
        )
        return None

    return await self._handle_control_message(message)
|
||||
|
||||
async def _deserialize_audio(self, data: bytes) -> Frame | None:
    """Deserialize binary audio data to an InputAudioRawFrame.

    Audio is dropped while the session is not open or is paused. Only the
    PCMU media format is handled; any other format logs a warning and is
    dropped. When both channels are configured, the interleaved stereo
    stream is reduced to its first (external) channel before decoding.

    Args:
        data: Raw audio bytes (PCMU or L16).

    Returns:
        A mono InputAudioRawFrame with PCM audio at the pipeline sample
        rate, or None if the audio was dropped or decoding produced no
        data.
    """
    if not self._is_open or self._is_paused:
        return None

    # Check the format first: the de-interleaving below assumes one byte
    # per sample, which only holds for PCMU.
    if self._params.media_format != AudioHookMediaFormat.PCMU:
        # L16 format
        logger.warning("L16 format not yet fully implemented")
        return None

    audio_data = data

    # If Genesys sends stereo audio (BOTH channels), keep only the external
    # channel (left). Stereo PCMU is interleaved as [L0, R0, L1, R1, ...],
    # so the external samples sit at the even byte positions.
    if self._params.channel == AudioHookChannel.BOTH and len(data) > 0:
        # Extended slicing copies every other byte at C speed instead of
        # running a Python-level loop for each audio packet.
        audio_data = data[::2]
        logger.debug(
            f"🔊 Stereo audio: {len(data)} bytes → {len(audio_data)} bytes (external channel)"
        )

    # Convert μ-law at the Genesys rate to PCM at the pipeline rate.
    pcm = await ulaw_to_pcm(
        audio_data,
        self._genesys_sample_rate,
        self._sample_rate,
        self._input_resampler,
    )
    if pcm is None or len(pcm) == 0:
        return None

    # Always emit mono: only a single channel is passed downstream for STT.
    return InputAudioRawFrame(
        audio=pcm,
        num_channels=1,
        sample_rate=self._sample_rate,
    )
|
||||
|
||||
async def _handle_control_message(self, message: Dict[str, Any]) -> Frame | None:
    """Route a parsed JSON control message from Genesys to its handler.

    Records the client sequence number (0 when absent) and, when the
    message has a "position" field, the parsed stream position before
    dispatching on the message "type".

    Args:
        message: Parsed JSON message.

    Returns:
        Frame if the message should be passed to the pipeline, None otherwise.
    """
    kind = message.get("type", "")
    self._client_seq = message.get("seq", 0)

    # Update position if provided
    if "position" in message:
        self._position = self._parse_position(message["position"])

    # Ordered (wire type, handler) pairs; the first equal type wins.
    routes = (
        (AudioHookMessageType.OPEN.value, self._handle_open),
        (AudioHookMessageType.CLOSE.value, self._handle_close),
        (AudioHookMessageType.PING.value, self._handle_ping),
        (AudioHookMessageType.PAUSE.value, self._handle_pause),
        (AudioHookMessageType.UPDATE.value, self._handle_update),
        (AudioHookMessageType.ERROR.value, self._handle_error),
        ("dtmf", self._handle_dtmf),
    )
    for wire_type, handler in routes:
        if kind == wire_type:
            return await handler(message)

    # Playback notifications are informational only; everything else is
    # an unrecognised type.
    if kind == "playback_started":
        logger.debug("Playback started (from Genesys)")
    elif kind == "playback_completed":
        logger.debug("Playback completed (from Genesys)")
    else:
        logger.warning(f"Unknown AudioHook message type: {kind}")
    return None
|
||||
|
||||
async def _handle_open(self, message: Dict[str, Any]) -> Frame | None:
    """Handle an 'open' message from Genesys.

    Stores the session metadata carried by the message, adapts the
    configured channel mode to the channels Genesys announces, fires the
    "on_open" event handler, and answers with an 'opened' response built
    from the configured params (start_paused, supported_languages,
    selected_language).

    Extracts and stores:
    - session_id: The message "id" (a fresh UUID when missing)
    - conversation_id: The Genesys conversation ID
    - participant: Caller info (ani, dnis, etc.)
    - custom_config: The "customConfig" parameter
    - media_info: Audio configuration from Genesys (a list of media objects)
    - input_variables: Custom variables from the Architect flow

    Args:
        message: The open message from Genesys.

    Returns:
        OutputTransportMessageUrgentFrame with the 'opened' response.
    """
    self._session_id = message.get("id", str(uuid.uuid4()))

    open_params = message.get("parameters", {})
    self._conversation_id = open_params.get("conversationId")
    self._participant = open_params.get("participant")
    self._custom_config = open_params.get("customConfig")
    self._media_info = open_params.get("media")
    self._input_variables = open_params.get("inputVariables")

    # "media" looks like:
    # [{"type": "audio", "format": "PCMU", "channels": ["external"], "rate": 8000}]
    # Only the first entry is inspected.
    media_entries = self._media_info
    if isinstance(media_entries, list) and len(media_entries) > 0:
        first_media: Dict[str, Any] = media_entries[0]
        channels = first_media.get("channels", [])
        logger.debug(
            f"📡 Genesys audio config: format={first_media.get('format')}, channels={channels}, rate={first_media.get('rate')}"
        )
        # channels is a list like ["external"] or ["external", "internal"]
        if isinstance(channels, list):
            has_external = "external" in channels
            has_internal = "internal" in channels
            if has_external and has_internal:
                self._params.channel = AudioHookChannel.BOTH
                logger.debug("📡 Stereo mode: extracting external channel")
            elif has_external:
                self._params.channel = AudioHookChannel.EXTERNAL
                logger.debug("📡 Mono mode: external channel")
            elif has_internal:
                self._params.channel = AudioHookChannel.INTERNAL
                logger.debug("📡 Mono mode: internal channel")

    # Caller number for the log line; falls back when no participant is known.
    caller_ani = "unknown"
    if self._participant:
        caller_ani = self._participant.get("ani", "unknown")
    logger.info(
        f"AudioHook open request: session={self._session_id}, "
        f"conversation={self._conversation_id}, ani={caller_ani}"
    )

    await self._call_event_handler("on_open", message)

    opened = self.create_opened_response(
        start_paused=self._params.start_paused,
        supported_languages=self._params.supported_languages,
        selected_language=self._params.selected_language,
    )
    return OutputTransportMessageUrgentFrame(message=opened)
|
||||
|
||||
async def _handle_close(self, message: Dict[str, Any]) -> Frame | None:
    """Handle a 'close' message from Genesys.

    Marks the session as no longer open, fires the "on_close" event
    handler, and answers with a 'closed' response. The response is built
    with the serializer's stored output variables so they can be made
    available to the Architect flow.

    Args:
        message: The close message from Genesys.

    Returns:
        OutputTransportMessageUrgentFrame with the closed response
        (includes outputVariables if set).
    """
    close_reason = message.get("parameters", {}).get("reason", "unknown")
    logger.info(f"🔴 Genesys closed the connection: {close_reason}")

    self._is_open = False

    logger.info("Sending closed response to Genesys...")
    await self._call_event_handler("on_close", message)

    # Wrapped in an urgent frame so the reply goes out immediately.
    closed = self.create_closed_response(output_variables=self._output_variables)
    return OutputTransportMessageUrgentFrame(message=closed)
|
||||
|
||||
async def _handle_ping(self, message: Dict[str, Any]) -> Frame | None:
    """Handle a 'ping' message from Genesys.

    Fires the "on_ping" event handler and answers with a 'pong' message
    to keep the connection alive.

    Args:
        message: The ping message from Genesys.

    Returns:
        OutputTransportMessageUrgentFrame with pong response.
    """
    logger.info("Sending pong response to Genesys...")
    await self._call_event_handler("on_ping", message)

    # Wrapped in an urgent frame so the reply goes out immediately.
    pong = self.create_pong_response()
    return OutputTransportMessageUrgentFrame(message=pong)
|
||||
|
||||
async def _handle_pause(self, message: Dict[str, Any]) -> Frame | None:
    """Handle a 'pause' message from Genesys.

    Sent when audio streaming is temporarily suspended (e.g., during
    hold). Flags the session as paused and fires the "on_pause" event
    handler.

    Args:
        message: The pause message.

    Returns:
        None. No reply is produced here; the application is expected to
        call create_resumed_response() when it is ready to resume.
    """
    pause_reason = message.get("parameters", {}).get("reason", "unknown")
    logger.info(f"AudioHook pause request: reason={pause_reason}")

    self._is_paused = True
    await self._call_event_handler("on_pause", message)
    return None
|
||||
|
||||
async def _handle_update(self, message: Dict[str, Any]) -> Frame | None:
    """Handle an 'update' message from Genesys.

    Replaces the stored participant info when the update carries a
    "participant" entry, then fires the "on_update" event handler.

    Args:
        message: The update message.

    Returns:
        None.
    """
    update_params = message.get("parameters", {})

    has_participant = "participant" in update_params
    if has_participant:
        self._participant = update_params["participant"]

    logger.debug(f"AudioHook update received: {update_params}")
    await self._call_event_handler("on_update", message)
    return None
|
||||
|
||||
async def _handle_error(self, message: Dict[str, Any]) -> Frame | None:
    """Handle an 'error' message from Genesys.

    Logs the reported code and message at error level, then fires the
    "on_error" event handler.

    Args:
        message: The error message.

    Returns:
        None.
    """
    error_params = message.get("parameters", {})
    error_code = error_params.get("code", 0)
    error_text = error_params.get("message", "Unknown error")
    logger.error(f"AudioHook error from Genesys: {error_code} - {error_text}")

    await self._call_event_handler("on_error", message)
    return None
|
||||
|
||||
async def _handle_dtmf(self, message: Dict[str, Any]) -> Frame | None:
    """Handle a 'dtmf' message from Genesys.

    DTMF (Dual-Tone Multi-Frequency) events are sent when the user
    presses keys on their phone keypad. The "on_dtmf" event handler is
    fired for any non-empty digit, before the digit is validated.

    Args:
        message: The DTMF message.

    Returns:
        InputDTMFFrame with the pressed digit, or None when the digit is
        missing or is rejected with a ValueError.
    """
    pressed = message.get("parameters", {}).get("digit", "")
    if not pressed:
        logger.warning("DTMF message received without digit")
        return None

    logger.info(f"DTMF received: {pressed}")
    await self._call_event_handler("on_dtmf", message)

    try:
        dtmf_frame = InputDTMFFrame(KeypadEntry(pressed))
    except ValueError:
        # Invalid digit
        logger.warning(f"Invalid DTMF digit: {pressed}")
        return None
    return dtmf_frame
|
||||
@@ -161,7 +161,7 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self.start_ttfb_metrics()
|
||||
pass
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
if (
|
||||
self._vad_force_turn_endpoint
|
||||
@@ -354,7 +354,6 @@ class AssemblyAISTTService(WebsocketSTTService):
|
||||
"""Handle transcription results."""
|
||||
if not message.transcript:
|
||||
return
|
||||
await self.stop_ttfb_metrics()
|
||||
if message.end_of_turn and (
|
||||
not self._connection_params.formatted_finals or message.turn_is_formatted
|
||||
):
|
||||
|
||||
@@ -158,7 +158,6 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
await self._websocket.send(event_message)
|
||||
# Start metrics after first chunk sent
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Error sending audio: {e}")
|
||||
|
||||
@@ -470,7 +469,6 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
is_final = not result.get("IsPartial", True)
|
||||
|
||||
if transcript:
|
||||
await self.stop_ttfb_metrics()
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
|
||||
@@ -116,7 +116,6 @@ class AzureSTTService(STTService):
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.write(audio)
|
||||
yield None
|
||||
@@ -191,7 +190,6 @@ class AzureSTTService(STTService):
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
def _on_handle_recognized(self, event):
|
||||
|
||||
@@ -90,7 +90,7 @@ class AzureBaseTTSService:
|
||||
emphasis: Emphasis level for speech ("strong", "moderate", "reduced").
|
||||
language: Language for synthesis. Defaults to English (US).
|
||||
pitch: Voice pitch adjustment (e.g., "+10%", "-5Hz", "high").
|
||||
rate: Speech rate multiplier. Defaults to "1.05".
|
||||
rate: Speech rate adjustment (e.g., "1.0", "1.25", "slow", "fast").
|
||||
role: Voice role for expression (e.g., "YoungAdultFemale").
|
||||
style: Speaking style (e.g., "cheerful", "sad", "excited").
|
||||
style_degree: Intensity of the speaking style (0.01 to 2.0).
|
||||
@@ -100,7 +100,7 @@ class AzureBaseTTSService:
|
||||
emphasis: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN_US
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = "1.05"
|
||||
rate: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
style_degree: Optional[str] = None
|
||||
@@ -185,7 +185,9 @@ class AzureBaseTTSService:
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
# Only wrap in prosody tag if there are prosody attributes
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
@@ -195,7 +197,8 @@ class AzureBaseTTSService:
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
|
||||
ssml += "</prosody>"
|
||||
if prosody_attrs:
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._settings["style"]:
|
||||
ssml += "</mstts:express-as>"
|
||||
@@ -277,6 +280,11 @@ class AzureTTSService(WordTTSService, AzureBaseTTSService):
|
||||
self._started = False
|
||||
self._first_chunk = True
|
||||
self._cumulative_audio_offset: float = 0.0 # Cumulative audio duration in seconds
|
||||
self._current_sentence_base_offset: float = 0.0 # Base offset for current sentence
|
||||
self._current_sentence_duration: float = 0.0 # Duration from Azure callback
|
||||
self._current_sentence_max_word_offset: float = (
|
||||
0.0 # Max word boundary offset seen in current sentence (for 8kHz workaround)
|
||||
)
|
||||
self._last_word: Optional[str] = None # Track last word for punctuation merging
|
||||
self._last_timestamp: Optional[float] = None # Track last timestamp
|
||||
|
||||
@@ -386,8 +394,14 @@ class AzureTTSService(WordTTSService, AzureBaseTTSService):
|
||||
word = evt.text
|
||||
sentence_relative_seconds = evt.audio_offset / 10_000_000.0
|
||||
|
||||
# Add cumulative offset to get absolute timestamp across sentences
|
||||
absolute_seconds = self._cumulative_audio_offset + sentence_relative_seconds
|
||||
# Use base offset captured at start of run_tts to avoid race conditions
|
||||
# with callbacks from overlapping TTS requests
|
||||
absolute_seconds = self._current_sentence_base_offset + sentence_relative_seconds
|
||||
|
||||
# Track max word offset for accurate cumulative timing
|
||||
# (audio_duration from Azure doesn't always match word boundary offsets at 8kHz)
|
||||
if sentence_relative_seconds > self._current_sentence_max_word_offset:
|
||||
self._current_sentence_max_word_offset = sentence_relative_seconds
|
||||
|
||||
if not word:
|
||||
return
|
||||
@@ -492,9 +506,9 @@ class AzureTTSService(WordTTSService, AzureBaseTTSService):
|
||||
self._last_word = None
|
||||
self._last_timestamp = None
|
||||
|
||||
# Update cumulative audio offset for next sentence
|
||||
# Store duration for cumulative offset calculation
|
||||
if evt.result and evt.result.audio_duration:
|
||||
self._cumulative_audio_offset += evt.result.audio_duration.total_seconds()
|
||||
self._current_sentence_duration = evt.result.audio_duration.total_seconds()
|
||||
|
||||
self._audio_queue.put_nowait(None) # Signal completion
|
||||
|
||||
@@ -530,6 +544,9 @@ class AzureTTSService(WordTTSService, AzureBaseTTSService):
|
||||
self._started = False
|
||||
self._first_chunk = True
|
||||
self._cumulative_audio_offset = 0.0
|
||||
self._current_sentence_base_offset = 0.0
|
||||
self._current_sentence_duration = 0.0
|
||||
self._current_sentence_max_word_offset = 0.0
|
||||
self._last_word = None
|
||||
self._last_timestamp = None
|
||||
|
||||
@@ -604,6 +621,12 @@ class AzureTTSService(WordTTSService, AzureBaseTTSService):
|
||||
self._started = True
|
||||
self._first_chunk = True
|
||||
|
||||
# Capture base offset BEFORE starting synthesis to avoid race conditions
|
||||
# Word boundary callbacks will use this value
|
||||
self._current_sentence_base_offset = self._cumulative_audio_offset
|
||||
self._current_sentence_duration = 0.0
|
||||
self._current_sentence_max_word_offset = 0.0
|
||||
|
||||
ssml = self._construct_ssml(text)
|
||||
self._speech_synthesizer.speak_ssml_async(ssml)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
@@ -627,6 +650,16 @@ class AzureTTSService(WordTTSService, AzureBaseTTSService):
|
||||
)
|
||||
yield frame
|
||||
|
||||
# Update cumulative offset for next sentence
|
||||
# At 8kHz, Azure's audio_duration doesn't match word boundary offsets,
|
||||
# so we use max_word_offset as a workaround. At other sample rates,
|
||||
# audio_duration is accurate.
|
||||
# TODO: Remove after Azure fixes word boundary timing at 8kHz
|
||||
if self.sample_rate == 8000:
|
||||
self._cumulative_audio_offset += self._current_sentence_max_word_offset
|
||||
else:
|
||||
self._cumulative_audio_offset += self._current_sentence_duration
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -207,9 +207,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def start_metrics(self):
|
||||
async def _start_metrics(self):
|
||||
"""Start performance metrics collection for transcription processing."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -222,7 +221,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self.start_metrics()
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# Send finalize command to flush the transcription session
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
@@ -342,7 +341,6 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
pass
|
||||
|
||||
if len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
|
||||
@@ -659,6 +659,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
average_confidence = self._calculate_average_confidence(data)
|
||||
|
||||
if not self._params.min_confidence or average_confidence > self._params.min_confidence:
|
||||
# EndOfTurn means Flux has determined the turn is complete,
|
||||
# so this TranscriptionFrame is always finalized
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
@@ -666,6 +668,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
result=data,
|
||||
finalized=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -276,9 +276,8 @@ class DeepgramSTTService(STTService):
|
||||
# GH issue: https://github.com/deepgram/deepgram-python-sdk/issues/570
|
||||
await self._connection.finish()
|
||||
|
||||
async def start_metrics(self):
|
||||
"""Start TTFB and processing metrics collection."""
|
||||
await self.start_ttfb_metrics()
|
||||
async def _start_metrics(self):
|
||||
"""Start processing metrics collection for this utterance."""
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def _on_error(self, *args, **kwargs):
|
||||
@@ -292,7 +291,7 @@ class DeepgramSTTService(STTService):
|
||||
await self._connect()
|
||||
|
||||
async def _on_speech_started(self, *args, **kwargs):
|
||||
await self.start_metrics()
|
||||
await self._start_metrics()
|
||||
await self._call_event_handler("on_speech_started", *args, **kwargs)
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
@@ -320,8 +319,12 @@ class DeepgramSTTService(STTService):
|
||||
language = result.channel.alternatives[0].languages[0]
|
||||
language = Language(language)
|
||||
if len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if is_final:
|
||||
# Check if this response is from a finalize() call.
|
||||
# Only mark as finalized when both we requested it AND Deepgram confirms it.
|
||||
from_finalize = getattr(result, "from_finalize", False)
|
||||
if from_finalize:
|
||||
self.confirm_finalize()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
@@ -356,8 +359,10 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame) and not self.vad_enabled:
|
||||
# Start metrics if Deepgram VAD is disabled & pipeline VAD has detected speech
|
||||
await self.start_metrics()
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# https://developers.deepgram.com/docs/finalize
|
||||
# Mark that we're awaiting a from_finalize response
|
||||
self.request_finalize()
|
||||
await self._connection.finalize()
|
||||
logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")
|
||||
|
||||
@@ -363,9 +363,6 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
if not transcript.strip():
|
||||
return
|
||||
|
||||
# Stop TTFB metrics on first transcript
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
is_final = parsed.get("is_final", False)
|
||||
speech_final = parsed.get("speech_final", False)
|
||||
|
||||
@@ -417,9 +414,8 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def start_metrics(self):
|
||||
"""Start TTFB and processing metrics collection."""
|
||||
await self.start_ttfb_metrics()
|
||||
async def _start_metrics(self):
|
||||
"""Start processing metrics collection."""
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -433,7 +429,7 @@ class DeepgramSageMakerSTTService(STTService):
|
||||
|
||||
# Start metrics when user starts speaking (if VAD is not provided by Deepgram)
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self.start_metrics()
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# Send finalize message to Deepgram when user stops speaking
|
||||
# This tells Deepgram to flush any remaining audio and return final results
|
||||
|
||||
@@ -310,7 +310,6 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
self, transcript: str, is_final: bool, language: Optional[str] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
@@ -328,7 +327,6 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Upload audio and get transcription result directly
|
||||
result = await self._transcribe_audio(audio)
|
||||
@@ -539,9 +537,8 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def start_metrics(self):
|
||||
async def _start_metrics(self):
|
||||
"""Start performance metrics collection for transcription processing."""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -555,7 +552,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
# Start metrics when user starts speaking
|
||||
await self.start_metrics()
|
||||
await self._start_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
# Send commit when user stops speaking (manual commit mode)
|
||||
if self._params.commit_strategy == CommitStrategy.MANUAL:
|
||||
@@ -764,8 +761,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
# Get language if provided
|
||||
language = data.get("language_code")
|
||||
|
||||
@@ -803,7 +798,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Get language if provided
|
||||
@@ -845,7 +839,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Get language if provided
|
||||
|
||||
@@ -249,7 +249,6 @@ class FalSTTService(SegmentedSTTService):
|
||||
self, transcript: str, is_final: bool, language: Optional[str] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
@@ -267,7 +266,6 @@ class FalSTTService(SegmentedSTTService):
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Send to Fal directly (audio is already in WAV format from base class)
|
||||
data_uri = fal_client.encode(audio, "audio/x-wav")
|
||||
|
||||
@@ -385,7 +385,6 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
Yields:
|
||||
None (processing is handled asynchronously via WebSocket).
|
||||
"""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
# Add audio to buffer
|
||||
@@ -513,7 +512,6 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[str] = None
|
||||
):
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
async def _on_speech_started(self):
|
||||
|
||||
@@ -1198,7 +1198,20 @@ class GeminiLiveLLMService(LLMService):
|
||||
# Reset failure counter if connection has been stable
|
||||
self._check_and_reset_failure_counter()
|
||||
|
||||
if message.server_content and message.server_content.model_turn:
|
||||
if message.server_content and message.server_content.interrupted:
|
||||
# NOTE: while the service triggers interruptions in
|
||||
# the specific case of barge-ins, it does *not*
|
||||
# emit UserStarted/StoppedSpeakingFrames, as the
|
||||
# Gemini Live API does not give us broadly reliable
|
||||
# signals to base those off of. Pipelines that
|
||||
# require turn tracking (like those using context
|
||||
# aggregators) still need an independent way to
|
||||
# track turns, such as local Silero VAD in
|
||||
# combination with the context aggregator default
|
||||
# turn strategies.
|
||||
logger.debug("Gemini VAD: interrupted signal received")
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
elif message.server_content and message.server_content.model_turn:
|
||||
await self._handle_msg_model_turn(message)
|
||||
elif (
|
||||
message.server_content
|
||||
|
||||
@@ -40,7 +40,6 @@ from pipecat.frames.frames import (
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
@@ -199,22 +198,6 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
if part.function_response and part.function_response.id == tool_call_id:
|
||||
part.function_response.response = {"value": json.dumps(result)}
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
"""Handle user image frame.
|
||||
|
||||
Args:
|
||||
frame: Frame containing user image data and request context.
|
||||
"""
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleContextAggregatorPair:
|
||||
|
||||
@@ -823,7 +823,6 @@ class GoogleSTTService(STTService):
|
||||
"""
|
||||
if self._streaming_task:
|
||||
# Queue the audio data
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._request_queue.put(audio)
|
||||
yield None
|
||||
@@ -875,7 +874,6 @@ class GoogleSTTService(STTService):
|
||||
)
|
||||
else:
|
||||
self._last_transcript_was_final = False
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
|
||||
@@ -12,9 +12,10 @@ WebSocket API for streaming audio transcription.
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
@@ -22,9 +23,12 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
@@ -39,6 +43,26 @@ except ModuleNotFoundError as e:
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
def language_to_gradium_language(language: Language) -> Optional[str]:
    """Translate a pipecat Language value into a Gradium language code.

    Args:
        language: The Language enum value to translate.

    Returns:
        The result of ``resolve_language`` for the supported-language table
        below: the Gradium language code string, or None if not supported.
    """
    # Languages listed here are the only ones mapped to a Gradium code.
    supported_codes = dict(
        (
            (Language.DE, "de"),
            (Language.EN, "en"),
            (Language.ES, "es"),
            (Language.FR, "fr"),
            (Language.PT, "pt"),
        )
    )

    gradium_code = resolve_language(language, supported_codes, use_base_code=True)
    return gradium_code
|
||||
|
||||
|
||||
class GradiumSTTService(WebsocketSTTService):
|
||||
"""Gradium real-time speech-to-text service.
|
||||
|
||||
@@ -47,12 +71,29 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
for audio processing and connection management.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
    """Optional settings forwarded to the Gradium STT API.

    Both fields default to None, in which case nothing is sent for them.

    Parameters:
        language: Expected language of the audio, given as a ``Language``
            value. Grounding the model to a specific language improves
            transcription quality.
        delay_in_frames: Number of audio frames (80ms each) the model waits
            before generating text. Higher delays allow more context but
            increase latency. Allowed values: 7, 8, 10, 12, 14, 16, 20, 24,
            36, 48. The service default is 10 (800ms); lower values such as
            7-8 give a faster response.
    """

    # Expected spoken language; None leaves the choice to the service.
    language: Optional[Language] = None
    # Model look-ahead in 80ms frames; None keeps the service default.
    delay_in_frames: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
|
||||
json_config: str | None = None,
|
||||
params: Optional[InputParams] = None,
|
||||
json_config: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Gradium STT service.
|
||||
@@ -60,14 +101,29 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
Args:
|
||||
api_key: Gradium API key for authentication.
|
||||
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
|
||||
params: Configuration parameters for language and delay settings.
|
||||
json_config: Optional JSON configuration string for additional model settings.
|
||||
|
||||
.. deprecated:: 0.0.101
|
||||
Use `params` instead for type-safe configuration.
|
||||
|
||||
**kwargs: Additional arguments passed to parent STTService class.
|
||||
"""
|
||||
super().__init__(sample_rate=SAMPLE_RATE, **kwargs)
|
||||
|
||||
if json_config is not None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Parameter 'json_config' is deprecated and will be removed in a future version, use 'params' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._websocket = None
|
||||
self._params = params or GradiumSTTService.InputParams()
|
||||
self._json_config = json_config
|
||||
|
||||
self._receive_task = None
|
||||
@@ -76,6 +132,11 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
self._chunk_size_ms = 80
|
||||
self._chunk_size_bytes = 0
|
||||
|
||||
# Set from the ready message when connecting to the service.
|
||||
# These values are used for flushing transcription.
|
||||
self._delay_in_frames = 0
|
||||
self._frame_size = 0
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
|
||||
@@ -84,6 +145,17 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_language(self, language: Language):
    """Switch the recognition language by reconnecting to the service.

    The new language is stored on the params object and the connection is
    torn down and re-established.

    Args:
        language: The language to use for speech recognition.
    """
    self._params.language = language
    logger.info(f"Switching STT language to: [{language}]")

    # Drop the current connection, then open a fresh one.
    await self._disconnect()
    await self._connect()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech-to-text service.
|
||||
|
||||
@@ -112,6 +184,57 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
    """Forward the frame to the base class, then react to VAD events.

    A VAD "started speaking" frame starts processing metrics. A VAD
    "stopped speaking" frame flushes the transcription by sending silence,
    which gets the final transcription sooner without closing the
    connection.

    Args:
        frame: The frame to process.
        direction: The direction of frame processing.
    """
    await super().process_frame(frame, direction)

    if isinstance(frame, VADUserStartedSpeakingFrame):
        await self.start_processing_metrics()
        return

    if isinstance(frame, VADUserStoppedSpeakingFrame):
        await self._flush_transcription()
|
||||
|
||||
async def _flush_transcription(self):
    """Flush the transcription by sending silence frames.

    Sends ``self._delay_in_frames`` audio messages over the websocket, each
    carrying ``self._frame_size`` zeroed 16-bit samples, so the model emits
    the text for audio it is still holding. This allows a faster
    turn-around without closing the connection.

    Nothing is sent when the websocket is missing or not open, or when
    ``self._delay_in_frames`` is not positive. Sending stops at the first
    failure, which is logged as a warning rather than raised.

    From Gradium docs: "feed in delay_in_frames chunks of silence (vectors
    of zeros). If those are fed in faster than realtime, the API also has
    a possibility to process them faster."
    """
    if not self._websocket or self._websocket.state is not State.OPEN:
        return

    if self._delay_in_frames <= 0:
        logger.debug("No delay_in_frames set, skipping flush")
        return

    # Create a silence chunk (zeros) of frame_size samples.
    # Each sample is 2 bytes (16-bit PCM).
    silence_bytes = bytes(self._frame_size * 2)
    silence_b64 = base64.b64encode(silence_bytes).decode("utf-8")

    # Every silence message is identical, so serialize it once instead of
    # once per loop iteration.
    payload = json.dumps({"type": "audio", "audio": silence_b64})

    logger.debug(f"Flushing Gradium STT with {self._delay_in_frames} silence frames")

    for _ in range(self._delay_in_frames):
        try:
            await self._websocket.send(payload)
        except Exception as e:
            # Best effort: a failed flush must not take the pipeline down.
            logger.warning(f"Failed to send silence frame: {e}")
            break
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text conversion.
|
||||
|
||||
@@ -122,8 +245,6 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
None (processing handled via WebSocket messages).
|
||||
"""
|
||||
self._audio_buffer.extend(audio)
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
while len(self._audio_buffer) >= self._chunk_size_bytes:
|
||||
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
|
||||
@@ -152,6 +273,9 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug("Connecting to Gradium STT")
|
||||
|
||||
ws_url = self._api_endpoint_base_url
|
||||
headers = {
|
||||
"x-api-key": self._api_key,
|
||||
@@ -166,8 +290,18 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
"type": "setup",
|
||||
"input_format": "pcm",
|
||||
}
|
||||
if self._json_config is not None:
|
||||
setup_msg["json_config"] = self._json_config
|
||||
# Build json_config: start with deprecated json_config, then override with params
|
||||
json_config = {}
|
||||
if self._json_config:
|
||||
json_config = json.loads(self._json_config)
|
||||
if self._params.language:
|
||||
gradium_language = language_to_gradium_language(self._params.language)
|
||||
if gradium_language:
|
||||
json_config["language"] = gradium_language
|
||||
if self._params.delay_in_frames:
|
||||
json_config["delay_in_frames"] = self._params.delay_in_frames
|
||||
if json_config:
|
||||
setup_msg["json_config"] = json_config
|
||||
await self._websocket.send(json.dumps(setup_msg))
|
||||
ready_msg = await self._websocket.recv()
|
||||
ready_msg = json.loads(ready_msg)
|
||||
@@ -176,6 +310,14 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
if ready_msg["type"] != "ready":
|
||||
raise Exception(f"unexpected first message type {ready_msg['type']}")
|
||||
|
||||
# Store delay_in_frames and frame_size for silence flushing
|
||||
self._delay_in_frames = ready_msg.get("delay_in_frames", 0)
|
||||
self._frame_size = ready_msg.get("frame_size", 1920)
|
||||
logger.debug(
|
||||
f"Connected to Gradium STT (delay_in_frames={self._delay_in_frames}, "
|
||||
f"frame_size={self._frame_size})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
|
||||
raise
|
||||
@@ -241,3 +383,5 @@ class GradiumSTTService(WebsocketSTTService):
|
||||
time_now_iso8601(),
|
||||
)
|
||||
)
|
||||
await self._trace_transcription(text, is_final=True, language=None)
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
@@ -232,11 +232,15 @@ class GradiumTTSService(InterruptibleWordTTSService):
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def flush_audio(self):
|
||||
"""Flush any pending audio synthesis."""
|
||||
"""Flush any pending audio synthesis.
|
||||
|
||||
Sends a <flush> tag to force the model to output audio for all text
|
||||
that has been input so far, without closing the connection.
|
||||
"""
|
||||
if not self._websocket:
|
||||
return
|
||||
try:
|
||||
msg = {"type": "end_of_stream"}
|
||||
msg = {"type": "text", "text": "<flush>"}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except ConnectionClosedOK:
|
||||
logger.debug(f"{self}: connection closed normally during flush")
|
||||
|
||||
@@ -111,7 +111,6 @@ class HathoraSTTService(SegmentedSTTService):
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
url = f"{self._base_url}"
|
||||
|
||||
@@ -153,7 +152,6 @@ class HathoraSTTService(SegmentedSTTService):
|
||||
result=response,
|
||||
)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -170,17 +170,22 @@ class LLMService(AIService):
|
||||
# However, subclasses should override this with a more specific adapter when necessary.
|
||||
adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter
|
||||
|
||||
def __init__(self, run_in_parallel: bool = True, **kwargs):
|
||||
def __init__(
|
||||
self, run_in_parallel: bool = True, function_call_timeout_secs: float = 10.0, **kwargs
|
||||
):
|
||||
"""Initialize the LLM service.
|
||||
|
||||
Args:
|
||||
run_in_parallel: Whether to run function calls in parallel or sequentially.
|
||||
Defaults to True.
|
||||
function_call_timeout_secs: Timeout in seconds for deferred function calls.
|
||||
Defaults to 10.0 seconds.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._run_in_parallel = run_in_parallel
|
||||
self._function_call_timeout_secs = function_call_timeout_secs
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
@@ -596,14 +601,26 @@ class LLMService(AIService):
|
||||
cancel_on_interruption=item.cancel_on_interruption,
|
||||
)
|
||||
|
||||
callback_executed = False
|
||||
# Start a timeout task for deferred function calls
|
||||
async def timeout_handler():
|
||||
await asyncio.sleep(self._function_call_timeout_secs)
|
||||
logger.warning(
|
||||
f"{self} Function call [{runner_item.function_name}:{runner_item.tool_call_id}] timed out after {self._function_call_timeout_secs} seconds"
|
||||
)
|
||||
await function_call_result_callback(None)
|
||||
|
||||
timeout_task = self.create_task(timeout_handler())
|
||||
|
||||
# Define a callback function that pushes a FunctionCallResultFrame upstream & downstream.
|
||||
async def function_call_result_callback(
|
||||
result: Any, *, properties: Optional[FunctionCallResultProperties] = None
|
||||
):
|
||||
nonlocal callback_executed
|
||||
callback_executed = True
|
||||
nonlocal timeout_task
|
||||
|
||||
# Cancel timeout task if it exists
|
||||
if timeout_task and not timeout_task.done():
|
||||
await self.cancel_task(timeout_task)
|
||||
|
||||
await self.broadcast_frame(
|
||||
FunctionCallResultFrame,
|
||||
function_name=runner_item.function_name,
|
||||
@@ -653,9 +670,6 @@ class LLMService(AIService):
|
||||
error_message = f"Error executing function call [{runner_item.function_name}]: {e}"
|
||||
logger.error(f"{self} {error_message}")
|
||||
await self.push_error(error_msg=error_message, exception=e, fatal=False)
|
||||
finally:
|
||||
if not callback_executed:
|
||||
await function_call_result_callback(None)
|
||||
|
||||
async def _cancel_function_call(self, function_name: Optional[str]):
|
||||
cancelled_tasks = set()
|
||||
|
||||
@@ -307,7 +307,6 @@ class NvidiaSTTService(STTService):
|
||||
|
||||
transcript = result.alternatives[0].transcript
|
||||
if transcript and len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if result.is_final:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(
|
||||
@@ -344,7 +343,6 @@ class NvidiaSTTService(STTService):
|
||||
Yields:
|
||||
None - transcription results are pushed to the pipeline via frames.
|
||||
"""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._queue.put(audio)
|
||||
yield None
|
||||
@@ -598,12 +596,10 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
|
||||
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Process the response - handle different possible return types
|
||||
|
||||
@@ -492,8 +492,11 @@ class BaseOpenAILLMService(LLMService):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self._process_context(context)
|
||||
except httpx.TimeoutException:
|
||||
except httpx.TimeoutException as e:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
await self.push_error(error_msg="LLM completion timeout", exception=e)
|
||||
except Exception as e:
|
||||
await self.push_error(error_msg=f"Error during completion: {e}", exception=e)
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -599,6 +599,14 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# note: ttfb is faster by 1/2 RTT than ttfb as measured for other services, since we're getting
|
||||
# this event from the server
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if self._current_audio_response and self._current_audio_response.item_id != evt.item_id:
|
||||
logger.warning(
|
||||
f"Received a new audio delta for an already completed audio response before receiving the BotStoppedSpeakingFrame."
|
||||
)
|
||||
logger.debug("Forcing previous audio response to None")
|
||||
self._current_audio_response = None
|
||||
|
||||
if not self._current_audio_response:
|
||||
self._current_audio_response = CurrentAudioResponse(
|
||||
item_id=evt.item_id,
|
||||
|
||||
@@ -525,6 +525,14 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
# note: ttfb is faster by 1/2 RTT than ttfb as measured for other services, since we're getting
|
||||
# this event from the server
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if self._current_audio_response and self._current_audio_response.item_id != evt.item_id:
|
||||
logger.warning(
|
||||
f"Received a new audio delta for an already completed audio response before receiving the BotStoppedSpeakingFrame."
|
||||
)
|
||||
logger.debug("Forcing previous audio response to None")
|
||||
self._current_audio_response = None
|
||||
|
||||
if not self._current_audio_response:
|
||||
self._current_audio_response = CurrentAudioResponse(
|
||||
item_id=evt.item_id,
|
||||
|
||||
@@ -10,7 +10,7 @@ This module provides an OpenAI-compatible interface for interacting with OpenRou
|
||||
extending the base OpenAI LLM service functionality.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -61,3 +61,35 @@ class OpenRouterLLMService(OpenAILLMService):
|
||||
"""
|
||||
logger.debug(f"Creating OpenRouter client with api {base_url}")
|
||||
return super().create_client(api_key, base_url, **kwargs)
|
||||
|
||||
def build_chat_completion_params(self, params_from_context: Dict[str, Any]) -> Dict[str, Any]:
    """Builds chat parameters, handling model-specific constraints.

    For models whose name contains "gemini", only the first message with
    role "system" is kept as-is; every later system message is copied and
    re-labelled with role "user". Other models get the parent's parameters
    unchanged.

    Args:
        params_from_context: Parameters from the LLM context.

    Returns:
        Transformed parameters ready for the API call.
    """
    params = super().build_chat_completion_params(params_from_context)

    # The model attribute may be present but None; coerce it so the lookup
    # cannot raise AttributeError on `.lower()`.
    model = str(getattr(self, "model_name", getattr(self, "model", "")) or "").lower()
    if "gemini" not in model:
        return params

    messages = params.get("messages", [])
    if not messages:
        return params

    transformed_messages = []
    system_message_seen = False
    for msg in messages:
        if msg.get("role") != "system":
            transformed_messages.append(msg)
        elif not system_message_seen:
            transformed_messages.append(msg)
            system_message_seen = True
        else:
            # Copy before changing the role so the original message dict is
            # left untouched.
            new_msg = msg.copy()
            new_msg["role"] = "user"
            transformed_messages.append(new_msg)

    params["messages"] = transformed_messages
    return params
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
|
||||
"""Piper TTS service implementation."""
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, AsyncIterator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -20,11 +22,128 @@ from pipecat.frames.frames import (
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from piper import PiperVoice
|
||||
from piper.download_voices import download_voice
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Piper, you need to `pip install pipecat-ai[piper]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
# This assumes a running TTS service running: https://github.com/OHF-Voice/piper1-gpl/blob/main/docs/API_HTTP.md
|
||||
class PiperTTSService(TTSService):
    """Piper TTS service implementation.

    Provides local text-to-speech synthesis using Piper voice models. The
    voice model is downloaded when it is missing from the download directory
    (or when a re-download is forced), and the model's native sample rate is
    passed to the base-class streaming helper when audio frames are produced.
    """

    def __init__(
        self,
        *,
        voice_id: str,
        download_dir: Optional[Path] = None,
        force_redownload: bool = False,
        use_cuda: bool = False,
        **kwargs,
    ):
        """Initialize the Piper TTS service.

        Args:
            voice_id: Piper voice model identifier (e.g. `en_US-ryan-high`).
            download_dir: Directory for storing voice model files. Defaults to
                the current working directory.
            force_redownload: Re-download the voice model even if it already exists.
            use_cuda: Use CUDA for GPU-accelerated inference.
            **kwargs: Additional arguments passed to the parent `TTSService`.
        """
        super().__init__(**kwargs)

        self._voice_id = voice_id

        download_dir = download_dir or Path.cwd()

        model_file = f"{voice_id}.onnx"
        model_path = Path(download_dir) / model_file

        # Honor force_redownload even when the model file is already present;
        # checking only for existence would silently ignore the flag.
        if force_redownload or not model_path.exists():
            logger.debug(f"Downloading Piper '{voice_id}' model")
            download_voice(voice_id, download_dir, force_redownload=force_redownload)

        logger.debug(f"Loading Piper '{voice_id}' model from {model_path}")

        self._voice = PiperVoice.load(model_path, use_cuda=use_cuda)

        logger.debug(f"Loaded Piper '{voice_id}' model")

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Piper service supports metrics generation.
        """
        return True

    @traced_tts
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Piper.

        Args:
            text: The text to convert to speech.

        Yields:
            Frame: Audio frames containing the synthesized speech and status frames.
        """

        def async_next(it):
            # Runs in a worker thread; StopIteration is mapped to None so the
            # end of the iterator can be signalled across the thread boundary.
            try:
                return next(it)
            except StopIteration:
                return None

        async def async_iterator(iterator) -> AsyncIterator[bytes]:
            # Pull each synthesized chunk off the event loop thread.
            while True:
                item = await asyncio.to_thread(async_next, iterator)
                if item is None:
                    return
                yield item.audio_int16_bytes

        logger.debug(f"{self}: Generating TTS [{text}]")

        try:
            await self.start_ttfb_metrics()

            await self.start_tts_usage_metrics(text)

            yield TTSStartedFrame()

            async for frame in self._stream_audio_frames_from_iterator(
                async_iterator(self._voice.synthesize(text)),
                in_sample_rate=self._voice.config.sample_rate,
            ):
                await self.stop_ttfb_metrics()
                yield frame
        except Exception as e:
            logger.error(f"{self} exception: {e}")
            yield ErrorFrame(error=f"Unknown error occurred: {e}")
        finally:
            logger.debug(f"{self}: Finished TTS [{text}]")
            await self.stop_ttfb_metrics()
            yield TTSStoppedFrame()
|
||||
|
||||
|
||||
# This assumes a Piper HTTP TTS server is already running:
|
||||
# https://github.com/OHF-Voice/piper1-gpl/blob/main/docs/API_HTTP.md
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# $ uv pip install "piper-tts[http]"
|
||||
# $ uv run python -m piper.http_server -m en_US-ryan-high
|
||||
#
|
||||
class PiperHttpTTSService(TTSService):
|
||||
"""Piper HTTP TTS service implementation.
|
||||
|
||||
Provides integration with Piper's HTTP TTS server for text-to-speech
|
||||
synthesis. Supports streaming audio generation with configurable sample
|
||||
rates and automatic WAV header removal.
|
||||
@@ -35,9 +154,7 @@ class PiperTTSService(TTSService):
|
||||
*,
|
||||
base_url: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
# When using Piper, the sample rate of the generated audio depends on the
|
||||
# voice model being used.
|
||||
sample_rate: Optional[int] = None,
|
||||
voice_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Piper TTS service.
|
||||
@@ -45,10 +162,10 @@ class PiperTTSService(TTSService):
|
||||
Args:
|
||||
base_url: Base URL for the Piper TTS HTTP server.
|
||||
aiohttp_session: aiohttp ClientSession for making HTTP requests.
|
||||
sample_rate: Output sample rate. If None, uses the voice model's native rate.
|
||||
voice_id: Piper voice model identifier (e.g. `en_US-ryan-high`).
|
||||
**kwargs: Additional arguments passed to the parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if base_url.endswith("/"):
|
||||
logger.warning("Base URL ends with a slash, this is not allowed.")
|
||||
@@ -56,7 +173,7 @@ class PiperTTSService(TTSService):
|
||||
|
||||
self._base_url = base_url
|
||||
self._session = aiohttp_session
|
||||
self._settings = {"base_url": base_url}
|
||||
self._model_id = voice_id
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -83,9 +200,12 @@ class PiperTTSService(TTSService):
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._session.post(
|
||||
self._base_url, json={"text": text}, headers=headers
|
||||
) as response:
|
||||
data = {
|
||||
"text": text,
|
||||
"voice": self._model_id,
|
||||
}
|
||||
|
||||
async with self._session.post(self._base_url, json=data, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error = await response.text()
|
||||
yield ErrorFrame(
|
||||
|
||||
@@ -15,9 +15,15 @@ from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.sarvam._sdk import sdk_headers
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -75,14 +81,14 @@ class SarvamSTTService(STTService):
|
||||
language: Target language for transcription. Defaults to None (required for saarika models).
|
||||
prompt: Optional prompt to guide translation style/context for STT-Translate models.
|
||||
Only applicable to saaras (STT-Translate) models. Defaults to None.
|
||||
vad_signals: Enable VAD signals in response. Defaults to True.
|
||||
high_vad_sensitivity: Enable high VAD (Voice Activity Detection) sensitivity. Defaults to False.
|
||||
vad_signals: Enable VAD signals in response. Defaults to None.
|
||||
high_vad_sensitivity: Enable high VAD (Voice Activity Detection) sensitivity. Defaults to None.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = None
|
||||
prompt: Optional[str] = None
|
||||
vad_signals: bool = True
|
||||
high_vad_sensitivity: bool = False
|
||||
vad_signals: bool = None
|
||||
high_vad_sensitivity: bool = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -155,6 +161,7 @@ class SarvamSTTService(STTService):
|
||||
self._websocket_context = None
|
||||
self._socket_client = None
|
||||
self._receive_task = None
|
||||
logger.info(f"Sarvam STT initialized with SDK headers: {self._sdk_headers}")
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str:
|
||||
"""Convert pipecat Language enum to Sarvam's language code.
|
||||
@@ -175,6 +182,22 @@ class SarvamSTTService(STTService):
|
||||
"""
|
||||
return True
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
    """Process incoming frames.

    After the base class has handled the frame, VAD frames are used for
    metrics and flushing, but only when Sarvam's own VAD signals are not
    enabled (i.e. Pipecat's VAD is the one in use).

    Args:
        frame: The frame to process.
        direction: The direction of frame processing.
    """
    await super().process_frame(frame, direction)

    # With Sarvam's VAD signals enabled there is nothing more to do here.
    if self._vad_signals:
        return

    if isinstance(frame, VADUserStartedSpeakingFrame):
        await self._start_metrics()
    elif isinstance(frame, VADUserStoppedSpeakingFrame) and self._socket_client:
        await self._socket_client.flush()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the recognition language and reconnect.
|
||||
|
||||
@@ -411,16 +434,18 @@ class SarvamSTTService(STTService):
|
||||
logger.debug(f"VAD Signal: {signal}, Occurred at: {timestamp}")
|
||||
|
||||
if signal == "START_SPEECH":
|
||||
await self.start_metrics()
|
||||
await self._start_metrics()
|
||||
logger.debug("User started speaking")
|
||||
await self._call_event_handler("on_speech_started")
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
elif signal == "END_SPEECH":
|
||||
logger.debug("User stopped speaking")
|
||||
await self._call_event_handler("on_speech_stopped")
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
elif message.type == "data":
|
||||
await self.stop_ttfb_metrics()
|
||||
transcript = message.data.transcript
|
||||
language_code = message.data.language_code
|
||||
# Prefer language from message (auto-detected for translate models). Fallback to configured.
|
||||
@@ -482,7 +507,6 @@ class SarvamSTTService(STTService):
|
||||
}
|
||||
return mapping.get(language_code, Language.HI_IN)
|
||||
|
||||
async def start_metrics(self):
|
||||
"""Start TTFB and processing metrics collection."""
|
||||
await self.start_ttfb_metrics()
|
||||
async def _start_metrics(self):
|
||||
"""Start processing metrics collection."""
|
||||
await self.start_processing_metrics()
|
||||
|
||||
@@ -21,7 +21,7 @@ from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
@@ -162,7 +162,7 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
sample_rate: Audio sample rate.
|
||||
params: Additional configuration parameters, such as language hints, context and
|
||||
speaker diarization.
|
||||
vad_force_turn_endpoint: Listen to `UserStoppedSpeakingFrame` to send finalize message to Soniox. If disabled, Soniox will detect the end of the speech.
|
||||
vad_force_turn_endpoint: Listen to `VADUserStoppedSpeakingFrame` to send finalize message to Soniox. If disabled, Soniox will detect the end of the speech.
|
||||
**kwargs: Additional arguments passed to the STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
@@ -247,7 +247,7 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStoppedSpeakingFrame) and self._vad_force_turn_endpoint:
|
||||
if isinstance(frame, VADUserStoppedSpeakingFrame) and self._vad_force_turn_endpoint:
|
||||
# Send finalize message to Soniox so we get the final tokens asap.
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(FINALIZE_MESSAGE)
|
||||
@@ -374,12 +374,15 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
async def send_endpoint_transcript():
|
||||
if self._final_transcription_buffer:
|
||||
text = "".join(map(lambda token: token["text"], self._final_transcription_buffer))
|
||||
# Soniox only pushes TranscriptionFrame when an end token is received,
|
||||
# so every TranscriptionFrame is inherently finalized
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text=text,
|
||||
user_id=self._user_id,
|
||||
timestamp=time_now_iso8601(),
|
||||
result=self._final_transcription_buffer,
|
||||
finalized=True,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(text, is_final=True)
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
@@ -67,7 +66,7 @@ class TurnDetectionMode(str, Enum):
|
||||
"""Endpoint and turn detection handling mode.
|
||||
|
||||
How the STT engine handles the endpointing of speech. If using Pipecat's built-in endpointing,
|
||||
then use `TurnDetectionMode.FIXED` (default).
|
||||
then use `TurnDetectionMode.EXTERNAL` (default).
|
||||
|
||||
To use the STT engine's built-in endpointing, then use `TurnDetectionMode.ADAPTIVE` for simple
|
||||
voice activity detection or `TurnDetectionMode.SMART_TURN` for more advanced ML-based
|
||||
@@ -107,7 +106,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
|
||||
turn_detection_mode: Endpoint handling, one of `TurnDetectionMode.FIXED`,
|
||||
`TurnDetectionMode.EXTERNAL`, `TurnDetectionMode.ADAPTIVE` and
|
||||
`TurnDetectionMode.SMART_TURN`. Defaults to `TurnDetectionMode.FIXED`.
|
||||
`TurnDetectionMode.SMART_TURN`. Defaults to `TurnDetectionMode.EXTERNAL`.
|
||||
|
||||
speaker_active_format: Formatter for active speaker ID. This formatter is used to format
|
||||
the text output for individual speakers and ensures that the context is clear for
|
||||
@@ -201,6 +200,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
extra_params: Extra parameters to pass to the STT engine. This is a dictionary of
|
||||
additional parameters that can be used to configure the STT engine.
|
||||
Default to None.
|
||||
|
||||
"""
|
||||
|
||||
# Service configuration
|
||||
@@ -208,7 +208,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
language: Language | str = Language.EN
|
||||
|
||||
# Endpointing mode
|
||||
turn_detection_mode: TurnDetectionMode = TurnDetectionMode.FIXED
|
||||
turn_detection_mode: TurnDetectionMode = TurnDetectionMode.EXTERNAL
|
||||
|
||||
# Output formatting
|
||||
speaker_active_format: str | None = None
|
||||
@@ -346,7 +346,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
params.speaker_passive_format or params.speaker_active_format
|
||||
)
|
||||
|
||||
# Metrics
|
||||
# Model + metrics
|
||||
self.set_model_name(self._config.operating_point.value)
|
||||
|
||||
# Message queue
|
||||
@@ -598,9 +598,6 @@ class SpeechmaticsSTTService(STTService):
|
||||
if segments:
|
||||
await self._send_frames(segments)
|
||||
|
||||
# Update metrics
|
||||
await self._emit_metrics(message.get("metadata", {}).get("processing_time", 0.0))
|
||||
|
||||
async def _handle_segment(self, message: dict[str, Any]) -> None:
|
||||
"""Handle AddSegment events.
|
||||
|
||||
@@ -695,6 +692,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
f"{self} VADUserStoppedSpeakingFrame received but internal VAD is being used"
|
||||
)
|
||||
elif not self._enable_vad and self._client is not None:
|
||||
self.request_finalize()
|
||||
self._client.finalize()
|
||||
|
||||
async def _send_frames(self, segments: list[dict[str, Any]], finalized: bool = False) -> None:
|
||||
@@ -738,16 +736,33 @@ class SpeechmaticsSTTService(STTService):
|
||||
|
||||
# If final, then re-parse into TranscriptionFrame
|
||||
if finalized:
|
||||
# Do any segments have `is_eou` set to True?
|
||||
if (
|
||||
any(segment.get("is_eou", False) for segment in segments)
|
||||
and self._finalize_requested
|
||||
):
|
||||
self.confirm_finalize()
|
||||
|
||||
# Add the finalized frames
|
||||
frames += [TranscriptionFrame(**attr_from_segment(segment)) for segment in segments]
|
||||
|
||||
# Handle the text (for metrics reporting)
|
||||
finalized_text = "|".join([s["text"] for s in segments])
|
||||
await self._handle_transcription(finalized_text, True, segments[0]["language"])
|
||||
await self._handle_transcription(
|
||||
finalized_text, is_final=True, language=segments[0]["language"]
|
||||
)
|
||||
|
||||
# Log the frames
|
||||
logger.debug(f"{self} finalized transcript: {[f.text for f in frames]}")
|
||||
|
||||
# Return as interim results (unformatted)
|
||||
else:
|
||||
# Add the interim frames
|
||||
frames += [
|
||||
InterimTranscriptionFrame(**attr_from_segment(segment)) for segment in segments
|
||||
]
|
||||
|
||||
# Log the frames
|
||||
logger.debug(f"{self} interim transcript: {[f.text for f in frames]}")
|
||||
|
||||
# Send the frames
|
||||
@@ -804,28 +819,6 @@ class SpeechmaticsSTTService(STTService):
|
||||
yield ErrorFrame(f"Speechmatics error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
async def _emit_metrics(self, processing_time: float) -> None:
|
||||
"""Create TTFB metrics.
|
||||
|
||||
The TTFB is the seconds between the person speaking and the STT
|
||||
engine emitting the first partial. This is only calculated at the
|
||||
start of an utterance.
|
||||
"""
|
||||
# Skip if metrics not available
|
||||
if not self._metrics or processing_time == 0.0:
|
||||
return
|
||||
|
||||
# Calculate time as time.time() - ttfb (which is seconds)
|
||||
start_time = time.time() - processing_time
|
||||
|
||||
# Update internal metrics
|
||||
self._metrics._start_ttfb_time = start_time
|
||||
self._metrics._start_processing_time = start_time
|
||||
|
||||
# Stop TTFB metrics
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# ============================================================================
|
||||
# HELPERS
|
||||
# ============================================================================
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
|
||||
"""Base classes for Speech-to-Text services with continuous and segmented processing."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import time
|
||||
import wave
|
||||
from abc import abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Mapping, Optional
|
||||
@@ -17,12 +19,17 @@ from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
MetricsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
STTMuteFrame,
|
||||
STTUpdateSettingsFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import TTFBMetricsData
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
@@ -61,6 +68,8 @@ class STTService(AIService):
|
||||
audio_passthrough=True,
|
||||
# STT input sample rate
|
||||
sample_rate: Optional[int] = None,
|
||||
# STT TTFB timeout - time to wait after VAD stop before reporting TTFB
|
||||
stt_ttfb_timeout: float = 2.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the STT service.
|
||||
@@ -70,6 +79,12 @@ class STTService(AIService):
|
||||
Defaults to True.
|
||||
sample_rate: The sample rate for audio input. If None, will be determined
|
||||
from the start frame.
|
||||
stt_ttfb_timeout: Time in seconds to wait after VAD stop before reporting
|
||||
TTFB. This delay allows the final transcription to arrive. Defaults to 2.0.
|
||||
Note: STT "TTFB" differs from traditional TTFB (which measures from a discrete
|
||||
request to first response byte). Since STT receives continuous audio, we measure
|
||||
from when the user stops speaking to when the final transcript arrives—capturing
|
||||
the latency that matters for voice AI applications.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
@@ -81,6 +96,16 @@ class STTService(AIService):
|
||||
self._muted: bool = False
|
||||
self._user_id: str = ""
|
||||
|
||||
# STT TTFB tracking state
|
||||
self._stt_ttfb_timeout = stt_ttfb_timeout
|
||||
self._ttfb_timeout_task: Optional[asyncio.Task] = None
|
||||
self._vad_stop_secs: Optional[float] = None
|
||||
self._speech_end_time: Optional[float] = None
|
||||
self._user_speaking: bool = False
|
||||
self._last_transcription_time: Optional[float] = None
|
||||
self._finalize_pending: bool = False
|
||||
self._finalize_requested: bool = False
|
||||
|
||||
self._register_event_handler("on_connected")
|
||||
self._register_event_handler("on_disconnected")
|
||||
self._register_event_handler("on_connection_error")
|
||||
@@ -94,6 +119,31 @@ class STTService(AIService):
|
||||
"""
|
||||
return self._muted
|
||||
|
||||
def request_finalize(self):
|
||||
"""Mark that a finalize request has been sent, awaiting server confirmation.
|
||||
|
||||
For providers that have explicit server confirmation of finalization
|
||||
(e.g., Deepgram's from_finalize field), call this when sending the finalize
|
||||
request. Then call confirm_finalize() when the server confirms.
|
||||
|
||||
For providers without server confirmation, don't call this method - just
|
||||
send the finalize/flush/commit command and rely on the TTFB timeout.
|
||||
"""
|
||||
self._finalize_requested = True
|
||||
|
||||
def confirm_finalize(self):
|
||||
"""Confirm that the server has acknowledged the finalize request.
|
||||
|
||||
Call this when the server response confirms finalization (e.g., Deepgram's
|
||||
from_finalize=True). The next TranscriptionFrame pushed will be marked
|
||||
as finalized.
|
||||
|
||||
Only has effect if request_finalize() was previously called.
|
||||
"""
|
||||
if self._finalize_requested:
|
||||
self._finalize_pending = True
|
||||
self._finalize_requested = False
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""Get the current sample rate for audio processing.
|
||||
@@ -144,6 +194,11 @@ class STTService(AIService):
|
||||
self._sample_rate = self._init_sample_rate or frame.audio_in_sample_rate
|
||||
self._tracing_enabled = frame.enable_tracing
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up STT service resources."""
|
||||
await super().cleanup()
|
||||
await self._cancel_ttfb_timeout()
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
logger.info(f"Updating STT settings: {self._settings}")
|
||||
for key, value in settings.items():
|
||||
@@ -206,14 +261,168 @@ class STTService(AIService):
|
||||
await self.process_audio_frame(frame, direction)
|
||||
if self._audio_passthrough:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, SpeechControlParamsFrame):
|
||||
await self._handle_speech_control_params(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._handle_vad_user_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, STTUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, STTMuteFrame):
|
||||
self._muted = frame.mute
|
||||
logger.debug(f"STT service {'muted' if frame.mute else 'unmuted'}")
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._reset_stt_ttfb_state()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame downstream, tracking TranscriptionFrame timestamps for TTFB.
|
||||
|
||||
Stores the timestamp of each TranscriptionFrame for TTFB calculation.
|
||||
If the frame is marked as finalized (via request_finalize/confirm_finalize),
|
||||
reports TTFB immediately and cancels any pending timeout. Otherwise, TTFB is
|
||||
reported after a timeout.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
# Store the transcription time for TTFB calculation
|
||||
self._last_transcription_time = time.time()
|
||||
|
||||
# Set finalized from pending state and auto-reset
|
||||
if self._finalize_pending:
|
||||
frame.finalized = True
|
||||
self._finalize_pending = False
|
||||
|
||||
# If this is a finalized transcription, report TTFB immediately
|
||||
if frame.finalized and self._speech_end_time is not None:
|
||||
ttfb = self._last_transcription_time - self._speech_end_time
|
||||
await self._emit_stt_ttfb_metric(ttfb)
|
||||
# Cancel the timeout since we've already reported
|
||||
await self._cancel_ttfb_timeout()
|
||||
# Clear state
|
||||
self._speech_end_time = None
|
||||
self._last_transcription_time = None
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def _handle_speech_control_params(self, frame: SpeechControlParamsFrame):
|
||||
"""Handle speech control parameters frame to extract VAD stop_secs.
|
||||
|
||||
Args:
|
||||
frame: The speech control parameters frame.
|
||||
"""
|
||||
if frame.vad_params is not None:
|
||||
self._vad_stop_secs = frame.vad_params.stop_secs
|
||||
|
||||
async def _cancel_ttfb_timeout(self):
|
||||
"""Cancel any pending TTFB timeout task."""
|
||||
if self._ttfb_timeout_task:
|
||||
await self.cancel_task(self._ttfb_timeout_task)
|
||||
self._ttfb_timeout_task = None
|
||||
|
||||
async def _reset_stt_ttfb_state(self):
|
||||
"""Reset STT TTFB measurement state.
|
||||
|
||||
Called when starting a new utterance or on interruption to ensure
|
||||
we don't use stale state for TTFB calculations. This specifically guards
|
||||
against the case where a TranscriptionFrame is received without corresponding
|
||||
VADUserStartedSpeakingFrame and VADUserStoppedSpeakingFrame frames.
|
||||
|
||||
Note: Does not reset _user_speaking since InterruptionFrame can arrive
|
||||
while user is still speaking.
|
||||
"""
|
||||
await self._cancel_ttfb_timeout()
|
||||
self._speech_end_time = None
|
||||
self._last_transcription_time = None
|
||||
|
||||
async def _handle_vad_user_started_speaking(self, frame: VADUserStartedSpeakingFrame):
|
||||
"""Handle VAD user started speaking frame to start tracking transcriptions.
|
||||
|
||||
Cancels any pending TTFB timeout, resets TTFB tracking state, and marks user as speaking.
|
||||
Also resets finalization state to prevent stale finalization from a previous utterance.
|
||||
|
||||
Args:
|
||||
frame: The VAD user started speaking frame.
|
||||
"""
|
||||
await self._reset_stt_ttfb_state()
|
||||
self._user_speaking = True
|
||||
self._finalize_requested = False
|
||||
self._finalize_pending = False
|
||||
|
||||
async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame):
|
||||
"""Handle VAD user stopped speaking frame.
|
||||
|
||||
Calculates the actual speech end time and starts a timeout task to wait
|
||||
for the final transcription before reporting TTFB.
|
||||
|
||||
Args:
|
||||
frame: The VAD user stopped speaking frame.
|
||||
"""
|
||||
self._user_speaking = False
|
||||
|
||||
# Skip TTFB measurement if we don't have VAD params
|
||||
if self._vad_stop_secs is None:
|
||||
return
|
||||
|
||||
# Calculate the actual speech end time (current time minus VAD stop delay).
|
||||
# This approximates when the last user audio was sent to the STT service,
|
||||
# which we use to measure against the eventual transcription response.
|
||||
self._speech_end_time = time.time() - self._vad_stop_secs
|
||||
|
||||
# Start timeout task (any previous timeout was cancelled by VADUserStartedSpeakingFrame
|
||||
# or InterruptionFrame)
|
||||
self._ttfb_timeout_task = self.create_task(
|
||||
self._ttfb_timeout_handler(), name="stt_ttfb_timeout"
|
||||
)
|
||||
|
||||
async def _ttfb_timeout_handler(self):
|
||||
"""Wait for timeout then report TTFB using the last transcription timestamp.
|
||||
|
||||
This timeout allows the final transcription to arrive before we calculate
|
||||
and report TTFB. If no transcription arrived, no TTFB is reported.
|
||||
"""
|
||||
try:
|
||||
await asyncio.sleep(self._stt_ttfb_timeout)
|
||||
|
||||
# Report TTFB if we have both speech end time and transcription time
|
||||
if self._speech_end_time is not None and self._last_transcription_time is not None:
|
||||
ttfb = self._last_transcription_time - self._speech_end_time
|
||||
await self._emit_stt_ttfb_metric(ttfb)
|
||||
|
||||
# Clear state after reporting
|
||||
self._speech_end_time = None
|
||||
self._last_transcription_time = None
|
||||
except asyncio.CancelledError:
|
||||
# Task was cancelled (new utterance or interruption), which is expected behavior
|
||||
pass
|
||||
finally:
|
||||
self._ttfb_timeout_task = None
|
||||
|
||||
async def _emit_stt_ttfb_metric(self, ttfb: float):
|
||||
"""Emit STT TTFB metric if value is non-negative.
|
||||
|
||||
Args:
|
||||
ttfb: The TTFB value in seconds.
|
||||
"""
|
||||
if ttfb >= 0:
|
||||
logger.debug(f"{self} TTFB: {ttfb:.3f}s")
|
||||
if self.metrics_enabled:
|
||||
ttfb_data = TTFBMetricsData(
|
||||
processor=self.name,
|
||||
model=self.model_name,
|
||||
value=ttfb,
|
||||
)
|
||||
await super().push_frame(MetricsFrame(data=[ttfb_data]))
|
||||
|
||||
|
||||
class SegmentedSTTService(STTService):
|
||||
"""STT service that processes speech in segments using VAD events.
|
||||
@@ -250,6 +459,20 @@ class SegmentedSTTService(STTService):
|
||||
await super().start(frame)
|
||||
self._audio_buffer_size_1s = self.sample_rate * 2
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame, marking TranscriptionFrames as finalized.
|
||||
|
||||
Segmented STT services process complete speech segments and return a single
|
||||
TranscriptionFrame per segment, so every transcription is inherently finalized.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
frame.finalized = True
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames, handling VAD events and audio segmentation."""
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -274,11 +497,11 @@ class SegmentedSTTService(STTService):
|
||||
wav.close()
|
||||
content.seek(0)
|
||||
|
||||
await self.process_generator(self.run_stt(content.read()))
|
||||
|
||||
# Start clean.
|
||||
self._audio_buffer.clear()
|
||||
|
||||
await self.process_generator(self.run_stt(content.read()))
|
||||
|
||||
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
|
||||
"""Process audio frames by buffering them for segmented transcription.
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import (
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.utils import create_stream_resampler
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
AggregationType,
|
||||
@@ -202,6 +203,8 @@ class TTSService(AIService):
|
||||
)
|
||||
self._text_filters = [text_filter]
|
||||
|
||||
self._resampler = create_stream_resampler()
|
||||
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
@@ -505,12 +508,40 @@ class TTSService(AIService):
|
||||
await self._stop_frame_queue.put(frame)
|
||||
|
||||
async def _stream_audio_frames_from_iterator(
|
||||
self, iterator: AsyncIterator[bytes], *, strip_wav_header: bool
|
||||
self,
|
||||
iterator: AsyncIterator[bytes],
|
||||
*,
|
||||
strip_wav_header: bool = False,
|
||||
in_sample_rate: Optional[int] = None,
|
||||
) -> AsyncGenerator[Frame, None]:
|
||||
"""Stream audio frames from an async byte iterator with optional resampling.
|
||||
|
||||
For WAV data, use `strip_wav_header=True` to strip the header and
|
||||
auto-detect the source sample rate. For raw PCM data, pass
|
||||
`in_sample_rate` directly. Audio is resampled to `self.sample_rate` when
|
||||
the source rate differs.
|
||||
|
||||
Args:
|
||||
iterator: Async iterator yielding audio bytes.
|
||||
strip_wav_header: Strip WAV header and parse source sample rate from it.
|
||||
in_sample_rate: Source sample rate for raw PCM data. Overrides
|
||||
WAV-detected rate if both are provided.
|
||||
|
||||
"""
|
||||
buffer = bytearray()
|
||||
source_sample_rate = in_sample_rate
|
||||
need_to_strip_wav_header = strip_wav_header
|
||||
|
||||
async def maybe_resample(audio: bytes) -> bytes:
|
||||
if source_sample_rate and source_sample_rate != self.sample_rate:
|
||||
return await self._resampler.resample(audio, source_sample_rate, self.sample_rate)
|
||||
return audio
|
||||
|
||||
async for chunk in iterator:
|
||||
if need_to_strip_wav_header and chunk.startswith(b"RIFF"):
|
||||
# Parse sample rate from WAV header (bytes 24-28, little-endian uint32).
|
||||
if len(chunk) >= 44 and source_sample_rate is None:
|
||||
source_sample_rate = int.from_bytes(chunk[24:28], "little")
|
||||
chunk = chunk[44:]
|
||||
need_to_strip_wav_header = False
|
||||
|
||||
@@ -520,19 +551,18 @@ class TTSService(AIService):
|
||||
# Round to nearest even number.
|
||||
aligned_length = len(buffer) & ~1 # 111111111...11110
|
||||
if aligned_length > 0:
|
||||
aligned_chunk = buffer[:aligned_length]
|
||||
aligned_chunk = await maybe_resample(bytes(buffer[:aligned_length]))
|
||||
buffer = buffer[aligned_length:] # keep any leftover byte
|
||||
|
||||
if len(aligned_chunk) > 0:
|
||||
frame = TTSAudioRawFrame(bytes(aligned_chunk), self.sample_rate, 1)
|
||||
yield frame
|
||||
yield TTSAudioRawFrame(aligned_chunk, self.sample_rate, 1)
|
||||
|
||||
if len(buffer) > 0:
|
||||
# Make sure we don't need an extra padding byte.
|
||||
if len(buffer) % 2 == 1:
|
||||
buffer.extend(b"\x00")
|
||||
frame = TTSAudioRawFrame(bytes(buffer), self.sample_rate, 1)
|
||||
yield frame
|
||||
audio = await maybe_resample(bytes(buffer))
|
||||
yield TTSAudioRawFrame(audio, self.sample_rate, 1)
|
||||
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
self._processing_text = False
|
||||
|
||||
@@ -123,26 +123,29 @@ class WebsocketService(ABC):
|
||||
|
||||
async def _maybe_try_reconnect(
|
||||
self,
|
||||
error: Exception,
|
||||
error_message: str,
|
||||
report_error: Callable[[ErrorFrame], Awaitable[None]],
|
||||
error: Optional[Exception] = None,
|
||||
) -> bool:
|
||||
"""Check if reconnection should be attempted and try if appropriate.
|
||||
|
||||
Args:
|
||||
error: The exception that occurred.
|
||||
error_message: Human-readable error message for logging.
|
||||
report_error: Callback function to report connection errors.
|
||||
error: The exception that occurred (optional, may be None for graceful closes).
|
||||
|
||||
Returns:
|
||||
True if should continue the receive loop, False if should break.
|
||||
"""
|
||||
# Don't reconnect if we're intentionally disconnecting
|
||||
if self._disconnecting:
|
||||
logger.warning(f"{self} error during disconnect: {error}")
|
||||
if error:
|
||||
logger.warning(f"{self} error during disconnect: {error}")
|
||||
else:
|
||||
logger.debug(f"{self} receive loop ended during disconnect")
|
||||
return False
|
||||
|
||||
# Log the error
|
||||
# Log the message
|
||||
logger.warning(error_message)
|
||||
|
||||
# Try to reconnect if enabled
|
||||
@@ -167,6 +170,14 @@ class WebsocketService(ABC):
|
||||
while True:
|
||||
try:
|
||||
await self._receive_messages()
|
||||
# _receive_messages() returned normally. This happens when the websocket
|
||||
# closes gracefully (server sent close frame). The async for loop over
|
||||
# the websocket exits without raising an exception in this case.
|
||||
# We must handle this to avoid an infinite loop.
|
||||
message = f"{self} connection closed by server"
|
||||
should_continue = await self._maybe_try_reconnect(message, report_error)
|
||||
if not should_continue:
|
||||
break
|
||||
except ConnectionClosedOK as e:
|
||||
# Normal closure, don't retry
|
||||
logger.debug(f"{self} connection closed normally: {e}")
|
||||
@@ -175,13 +186,13 @@ class WebsocketService(ABC):
|
||||
# Connection closed with error (e.g., no close frame received/sent)
|
||||
# This often indicates network issues, server problems, or abrupt disconnection
|
||||
message = f"{self} connection closed, but with an error: {e}"
|
||||
should_continue = await self._maybe_try_reconnect(e, message, report_error)
|
||||
should_continue = await self._maybe_try_reconnect(message, report_error, e)
|
||||
if not should_continue:
|
||||
break
|
||||
except Exception as e:
|
||||
# General error during message receiving
|
||||
message = f"{self} error receiving messages: {e}"
|
||||
should_continue = await self._maybe_try_reconnect(e, message, report_error)
|
||||
should_continue = await self._maybe_try_reconnect(message, report_error, e)
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
|
||||
@@ -204,11 +204,9 @@ class BaseWhisperSTTService(SegmentedSTTService):
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
response = await self._transcribe(audio)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
text = response.text.strip()
|
||||
|
||||
@@ -289,7 +289,6 @@ class WhisperSTTService(SegmentedSTTService):
|
||||
return
|
||||
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Divide by 32768 because we have signed 16-bit data.
|
||||
audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
@@ -303,7 +302,6 @@ class WhisperSTTService(SegmentedSTTService):
|
||||
if segment.no_speech_prob < self._no_speech_prob:
|
||||
text += f"{segment.text} "
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
if text:
|
||||
@@ -388,7 +386,6 @@ class WhisperSTTServiceMLX(WhisperSTTService):
|
||||
import mlx_whisper
|
||||
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Divide by 32768 because we have signed 16-bit data.
|
||||
audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
@@ -413,7 +410,6 @@ class WhisperSTTServiceMLX(WhisperSTTService):
|
||||
if len(text.strip()) == 0:
|
||||
text = None
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
if text:
|
||||
|
||||
@@ -1733,7 +1733,7 @@ class DailyInputTransport(BaseInputTransport):
|
||||
message: The message data to send.
|
||||
sender: ID of the message sender.
|
||||
"""
|
||||
await self.broadcast_frame_class(
|
||||
await self.broadcast_frame(
|
||||
DailyInputTransportMessageFrame, message=message, participant_id=sender
|
||||
)
|
||||
|
||||
|
||||
@@ -698,7 +698,7 @@ class SmallWebRTCInputTransport(BaseInputTransport):
|
||||
message: The application message to process.
|
||||
"""
|
||||
logger.debug(f"Received app message inside SmallWebRTCInputTransport {message}")
|
||||
await self.broadcast_frame_class(InputTransportMessageFrame, message=message)
|
||||
await self.broadcast_frame(InputTransportMessageFrame, message=message)
|
||||
|
||||
# Add this method similar to DailyInputTransport.request_participant_image
|
||||
async def request_participant_image(self, frame: UserImageRequestFrame):
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Optional, Type
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
@@ -156,7 +157,7 @@ class UserTurnController(BaseObject):
|
||||
await self._handle_vad_user_started_speaking(frame)
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._handle_vad_user_stopped_speaking(frame)
|
||||
elif isinstance(frame, TranscriptionFrame):
|
||||
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
await self._handle_transcription(frame)
|
||||
|
||||
for strategy in self._user_turn_strategies.start or []:
|
||||
@@ -209,8 +210,8 @@ class UserTurnController(BaseObject):
|
||||
# The user stopped talking, let's reset the user turn timeout.
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
# We have creceived a transcription, let's reset the user turn timeout.
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
|
||||
# We have received a transcription, let's reset the user turn timeout.
|
||||
self._user_turn_stop_timeout_event.set()
|
||||
|
||||
async def _on_push_frame(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user