Compare commits
78 Commits
mb/fix-pip
...
mb/review-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b57a9d2546 | ||
|
|
54ff49c4fc | ||
|
|
8549331b32 | ||
|
|
6b38510ace | ||
|
|
bbd5c3b487 | ||
|
|
c6fcf58be4 | ||
|
|
a475d09b90 | ||
|
|
4b11bcb4a4 | ||
|
|
c1a1d40f9a | ||
|
|
ede995c563 | ||
|
|
510f3df6b7 | ||
|
|
68292bd75f | ||
|
|
42423bff41 | ||
|
|
c3d2a25229 | ||
|
|
cf1a9c1548 | ||
|
|
51ba245e10 | ||
|
|
39b4e61837 | ||
|
|
ceaf53fdb0 | ||
|
|
f93276c64f | ||
|
|
62a0f0c0f5 | ||
|
|
793aca6b8b | ||
|
|
1fcaf3a4bf | ||
|
|
fdf3c8b4cf | ||
|
|
50bef86d33 | ||
|
|
79f43ece74 | ||
|
|
08b2365244 | ||
|
|
6484855139 | ||
|
|
771469b834 | ||
|
|
a60618b0ca | ||
|
|
3d21faaac2 | ||
|
|
f325eeb95b | ||
|
|
4c3fd42b1c | ||
|
|
c2309efd7e | ||
|
|
4ae1819645 | ||
|
|
a38f208135 | ||
|
|
d1eb837890 | ||
|
|
153201542b | ||
|
|
9137e50043 | ||
|
|
8dbe119a73 | ||
|
|
26f96d0be8 | ||
|
|
9944e6faf0 | ||
|
|
c1573c1f76 | ||
|
|
9f45ad4d2e | ||
|
|
fccc91e923 | ||
|
|
a510b276e6 | ||
|
|
6481094638 | ||
|
|
3132e12265 | ||
|
|
12af3f79d0 | ||
|
|
4835617b16 | ||
|
|
9283108240 | ||
|
|
515eaeeb1a | ||
|
|
5095fc6a64 | ||
|
|
7eedb33d50 | ||
|
|
47f78df497 | ||
|
|
74154b26a2 | ||
|
|
0c3c26b7b8 | ||
|
|
64417ef4ff | ||
|
|
f3b254e335 | ||
|
|
f27119a712 | ||
|
|
2a51d0f1e5 | ||
|
|
9156e21727 | ||
|
|
a5145be16e | ||
|
|
b104a59b10 | ||
|
|
04dbbabc03 | ||
|
|
19cc0177b8 | ||
|
|
77cd106795 | ||
|
|
71869a116d | ||
|
|
2f2bde9856 | ||
|
|
7de8838deb | ||
|
|
9bf88bbf14 | ||
|
|
35ff44b799 | ||
|
|
d1116d149e | ||
|
|
d01876ee60 | ||
|
|
74a0e8c88d | ||
|
|
fbbad27d37 | ||
|
|
2fab3e2286 | ||
|
|
a7b2052b38 | ||
|
|
3c76917c1e |
46
CHANGELOG.md
46
CHANGELOG.md
@@ -5,19 +5,52 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
## [0.0.95] - 2025-11-18
|
||||
|
||||
### Added
|
||||
|
||||
- Added ai-coustics integrated VAD (`AICVADAnalyzer`) with `AICFilter` factory and
|
||||
example wiring; leverages the enhancement model for robust detection with no
|
||||
ONNX dependency or added processing complexity.
|
||||
|
||||
- Added a watchdog to `DeepgramFluxSTTService` to prevent dangling tasks in case the
|
||||
user was speaking and we stop receiving audio.
|
||||
|
||||
- Introduced a minimum confidence parameter in `DeepgramFluxSTTService` to avoid
|
||||
generating transcriptions below a defined threshold.
|
||||
|
||||
- Added `ElevenLabsRealtimeSTTService` which implements the Realtime STT
|
||||
service from ElevenLabs.
|
||||
|
||||
- Added a `TTSService.includes_inter_frame_spaces` property getter, so that TTS
|
||||
services that subclass `TTSService` can indicate whether the text in the
|
||||
`TTSTextFrame`s they push already contain any necessary inter-frame spaces.
|
||||
- Added word-level timestamps support to Hume TTS service
|
||||
|
||||
### Changed
|
||||
|
||||
- ⚠️ Breaking change: `LLMContext.create_image_message()`,
|
||||
`LLMContext.create_audio_message()`, `LLMContext.add_image_frame_message()`
|
||||
and `LLMContext.add_audio_frames_message()` are now async methods. This fixes
|
||||
an issue where the asyncio event loop would be blocked while encoding audio or
|
||||
images.
|
||||
|
||||
- `ConsumerProcessor` now queues frames from the producer internally instead of
|
||||
pushing them directly. This allows us to subclass consumer processors and
|
||||
manipulate frames before they are pushed.
|
||||
|
||||
- `BaseTextFilter` only require subclasses to implement the `filter()` method.
|
||||
|
||||
- Extracted the logic for retrying connections, and create a new `send_with_retry`
|
||||
method inside `WebSocketService`.
|
||||
|
||||
- Refactored `DeepgramFluxSTTService` to automatically reconnect if sending a
|
||||
message fails.
|
||||
|
||||
- Updated all STT and TTS services to use consistent error handling pattern with
|
||||
`push_error()` method for better pipeline error event integration.
|
||||
|
||||
- Added support for `maybe_capture_participant_camera()` and
|
||||
`maybe_capture_participant_screen()` for `SmallWebRTCTransport` in the runner
|
||||
utils.
|
||||
|
||||
- Added Hindi support for Rime TTS services.
|
||||
|
||||
- Updated `GeminiTTSService` to use Google Cloud Text-to-Speech streaming API
|
||||
@@ -37,6 +70,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `SimliVideoService` connection issue.
|
||||
|
||||
- Fixed an issue in the `Runner` where, when using `SmallWebRTCTransport`, the
|
||||
`request_data` was not being passed to the `SmallWebRTCRunnerArguments` body.
|
||||
|
||||
- Fixed subtle issue of assistant context messages ending up with double spaces
|
||||
between words or sentences.
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from loguru import logger
|
||||
from pipecat.audio.filters.aic_filter import AICFilter
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
@@ -48,7 +47,7 @@ def _create_aic_filter() -> AICFilter:
|
||||
|
||||
return AICFilter(
|
||||
license_key=license_key,
|
||||
enhancement_level=1.0,
|
||||
enhancement_level=0.5,
|
||||
)
|
||||
|
||||
|
||||
@@ -56,27 +55,33 @@ def _create_aic_filter() -> AICFilter:
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
audio_in_filter=_create_aic_filter(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
audio_in_filter=_create_aic_filter(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
audio_in_filter=_create_aic_filter(),
|
||||
),
|
||||
"daily": lambda: (
|
||||
lambda aic: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
audio_in_filter=aic,
|
||||
)
|
||||
)(_create_aic_filter()),
|
||||
"twilio": lambda: (
|
||||
lambda aic: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
audio_in_filter=aic,
|
||||
)
|
||||
)(_create_aic_filter()),
|
||||
"webrtc": lambda: (
|
||||
lambda aic: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=aic.create_vad_analyzer(lookback_buffer_size=6.0, sensitivity=6.0),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
audio_in_filter=aic,
|
||||
)
|
||||
)(_create_aic_filter()),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -13,24 +13,29 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSTextFrame
|
||||
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.hume.tts import HUME_SAMPLE_RATE, HumeTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -88,7 +93,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
tts, # TTS (HumeTTSService with word timestamps)
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
@@ -102,7 +107,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
audio_out_sample_rate=HUME_SAMPLE_RATE,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
observers=[RTVIObserver(rtvi)],
|
||||
observers=[
|
||||
RTVIObserver(rtvi),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
}
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@rtvi.event_handler("on_client_ready")
|
||||
@@ -112,6 +124,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
logger.info(
|
||||
"💡 Word timestamps are enabled! Watch the console for TTSTextFrame logs showing each word with its PTS."
|
||||
)
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@@ -52,7 +52,10 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramFluxSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = DeepgramFluxSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY"),
|
||||
params=DeepgramFluxSTTService.InputParams(min_confidence=0.3),
|
||||
)
|
||||
|
||||
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = LLMContext.create_image_message(
|
||||
message = await LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = LLMContext.create_image_message(
|
||||
message = await LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -117,7 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = LLMContext.create_image_message(
|
||||
message = await LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# Kick off the conversation.
|
||||
image = Image.open(image_path)
|
||||
message = LLMContext.create_image_message(
|
||||
message = await LLMContext.create_image_message(
|
||||
image=image.tobytes(),
|
||||
format="RGB",
|
||||
size=image.size,
|
||||
|
||||
@@ -15,14 +15,21 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame, UserImageRequestFrame
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
TextFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
create_transport,
|
||||
@@ -66,6 +73,27 @@ async def fetch_user_image(params: FunctionCallParams):
|
||||
# await params.result_callback({"result": "Image is being captured."})
|
||||
|
||||
|
||||
class MoondreamTextFrameWrapper(FrameProcessor):
|
||||
"""Wraps Moondream-provided TextFrames with LLM response start/end frames.
|
||||
|
||||
This processor detects TextFrames and automatically wraps them with
|
||||
LLMFullResponseStartFrame and LLMFullResponseEndFrame to provide proper
|
||||
response boundaries for downstream processors.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we receive a TextFrame, wrap it with response start/end frames
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_frame(LLMFullResponseStartFrame(), direction)
|
||||
await self.push_frame(frame, direction)
|
||||
await self.push_frame(LLMFullResponseEndFrame(), direction)
|
||||
else:
|
||||
# For all other frames, just pass them through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -130,6 +158,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# If you run into weird description, try with use_cpu=True
|
||||
moondream = MoondreamService()
|
||||
|
||||
# Wrap TextFrames with LLM response start/end frames, which makes Moondream
|
||||
# output be treated like LLM responses for the purpose of context
|
||||
# aggregation. Without this, the assistant context aggregator would ignore
|
||||
# Moondream output (if the TTS service is disabled).
|
||||
moondream_text_wrapper = MoondreamTextFrameWrapper()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
@@ -137,7 +171,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context_aggregator.user(), # User responses
|
||||
ParallelPipeline(
|
||||
[llm], # LLM
|
||||
[moondream],
|
||||
[moondream, moondream_text_wrapper],
|
||||
),
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
|
||||
@@ -391,7 +391,7 @@ class AudioAccumulator(FrameProcessor):
|
||||
)
|
||||
self._user_speaking = False
|
||||
context = LLMContext()
|
||||
context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self.push_frame(LLMContextFrame(context=context))
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
# Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest
|
||||
|
||||
@@ -150,7 +150,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
LLMLogObserver(),
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.DESTINATION),
|
||||
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
UserStartedSpeakingFrame: (BaseInputTransport, FrameEndpoint.SOURCE),
|
||||
EndFrame: None,
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
|
||||
Offer, for example, to show a floral still life, use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from mcp.client.session_group import SseServerParameters
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.mcp_service import MCPClient
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-7-sonnet-latest"
|
||||
)
|
||||
|
||||
try:
|
||||
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
|
||||
mcp = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to a number of tools provided by mcp.run. Use any and all tools to help users.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
When asked for today's date, use 'https://www.datetoday.net/'.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
messages = [{"role": "system", "content": system}]
|
||||
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("MCP_RUN_SSE_URL"):
|
||||
logger.error(
|
||||
f"Please set MCP_RUN_SSE_URL environment variable for this example. See https://mcp.run"
|
||||
)
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -15,7 +16,7 @@ import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.client.session_group import SseServerParameters
|
||||
from mcp.client.session_group import StreamableHttpParameters
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
@@ -66,10 +67,12 @@ class UrlToImageProcessor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
pattern = r"!\[[^\]]*\]\((https?://[^)]+\.(png|jpg|jpeg|PNG|JPG|JPEG|gif))\)"
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
|
||||
return None
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
@@ -132,10 +135,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
system = f"""
|
||||
You are a helpful LLM in a WebRTC call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
|
||||
You have access to tools to search the Rijksmuseum collection and the user's GitHub repositories and account.
|
||||
Offer, for example, to show a floral still life, use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
You can also offer to answer users questions about their GitHub repositories and account.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
@@ -145,11 +149,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages = [{"role": "system", "content": system}]
|
||||
|
||||
try:
|
||||
mcp = MCPClient(
|
||||
rijksmuseum_mcp = MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-error setting up mcp"],
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
)
|
||||
@@ -157,24 +161,32 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.error(f"error setting up rijksmuseum mcp")
|
||||
logger.exception("error trace:")
|
||||
try:
|
||||
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
|
||||
# ie. "https://www.mcp.run/api/mcp/sse?..."
|
||||
# ensure the profile has a tool or few installed
|
||||
mcp_run = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
github_mcp = MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={
|
||||
"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp.run")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = {}
|
||||
run_tools = {}
|
||||
rijksmuseum_tools = {}
|
||||
github_tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
run_tools = await mcp_run.register_tools(llm)
|
||||
rijksmuseum_tools = await rijksmuseum_mcp.register_tools(llm)
|
||||
github_tools = await github_mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
all_standard_tools = run_tools.standard_tools + tools.standard_tools
|
||||
all_standard_tools = rijksmuseum_tools.standard_tools + github_tools.standard_tools
|
||||
all_tools = ToolsSchema(standard_tools=all_standard_tools)
|
||||
|
||||
context = LLMContext(messages, all_tools)
|
||||
@@ -226,9 +238,9 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("MCP_RUN_SSE_URL"):
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("GITHUB_PERSONAL_ACCESS_TOKEN"):
|
||||
logger.error(
|
||||
f"Please set RIJKSMUSEUM_API_KEY and MCP_RUN_SSE_URL environment variables. See https://github.com/r-huijts/rijksmuseum-mcp and https://mcp.run"
|
||||
f"Please set `RIJKSMUSEUM_API_KEY` and `GITHUB_PERSONAL_ACCESS_TOKEN` environment variables. See https://github.com/r-huijts/rijksmuseum-mcp."
|
||||
)
|
||||
import sys
|
||||
|
||||
@@ -45,7 +45,7 @@ Source = "https://github.com/pipecat-ai/pipecat"
|
||||
Website = "https://pipecat.ai"
|
||||
|
||||
[project.optional-dependencies]
|
||||
aic = [ "aic-sdk~=1.0.1" ]
|
||||
aic = [ "aic-sdk~=1.1.0" ]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "pipecat-ai[websockets-base]" ]
|
||||
asyncai = [ "pipecat-ai[websockets-base]" ]
|
||||
@@ -99,7 +99,7 @@ local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "tor
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]
|
||||
remote-smart-turn = []
|
||||
silero = [ "onnxruntime>=1.20.1,<2" ]
|
||||
simli = [ "simli-ai~=0.1.25"]
|
||||
simli = [ "simli-ai~=1.0.3"]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soundfile = [ "soundfile~=0.13.1" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.5.0" ]
|
||||
|
||||
@@ -30,8 +30,8 @@ EVAL_SIMPLE_MATH = EvalConfig(
|
||||
)
|
||||
|
||||
EVAL_WEATHER = EvalConfig(
|
||||
prompt="What's the weather in San Francisco?",
|
||||
eval="The user says something specific about the current weather in San Francisco, including the degrees.",
|
||||
prompt="What's the weather in San Francisco (in farhenheit or celsius)?",
|
||||
eval="The user says something specific about the current weather in San Francisco, including the degrees (in farhenheit or celsius).",
|
||||
)
|
||||
|
||||
EVAL_ONLINE_SEARCH = EvalConfig(
|
||||
@@ -70,7 +70,7 @@ EVAL_VOICEMAIL = EvalConfig(
|
||||
|
||||
EVAL_CONVERSATION = EvalConfig(
|
||||
prompt="Hello, this is Mark.",
|
||||
eval="The user replies with a greeting.",
|
||||
eval="The user acknowledges the greeting.",
|
||||
eval_speaks_first=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -68,6 +68,58 @@ class AICFilter(BaseAudioFilter):
|
||||
# Model will be created in start() since the API now requires sample_rate
|
||||
self._aic = None
|
||||
|
||||
def get_vad_factory(self):
|
||||
"""Return a zero-arg factory that will create the VAD once the model exists.
|
||||
|
||||
Returns:
|
||||
A zero-argument callable that, when invoked, returns an initialized
|
||||
VoiceActivityDetector bound to the underlying AIC model. Raises a
|
||||
RuntimeError if the model has not been initialized (i.e. start()
|
||||
has not been called successfully).
|
||||
"""
|
||||
|
||||
def _factory():
|
||||
if self._aic is None:
|
||||
raise RuntimeError("AIC model not initialized yet. Call start(sample_rate) first.")
|
||||
return self._aic.create_vad()
|
||||
|
||||
return _factory
|
||||
|
||||
def create_vad_analyzer(
|
||||
self,
|
||||
*,
|
||||
lookback_buffer_size: Optional[float] = None,
|
||||
sensitivity: Optional[float] = None,
|
||||
):
|
||||
"""Return an analyzer that will lazily instantiate the AIC VAD when ready.
|
||||
|
||||
AIC VAD parameters:
|
||||
- lookback_buffer_size:
|
||||
Number of window-length audio buffers used as a lookback buffer.
|
||||
Higher values increase prediction stability but add latency.
|
||||
Range: 1.0 .. 20.0, Default (SDK): 6.0
|
||||
- sensitivity:
|
||||
Energy threshold sensitivity. Energy threshold = 10 ** (-sensitivity).
|
||||
Range: 1.0 .. 15.0, Default (SDK): 6.0
|
||||
|
||||
Args:
|
||||
lookback_buffer_size: Optional lookback buffer size to configure on the VAD.
|
||||
Range: 1.0 .. 20.0. If None, SDK default is used.
|
||||
sensitivity: Optional sensitivity (energy threshold) to configure on the VAD.
|
||||
Range: 1.0 .. 15.0. If None, SDK default is used.
|
||||
|
||||
Returns:
|
||||
A lazily-initialized AICVADAnalyzer that will bind to the VAD backend
|
||||
once the filter's model has been created (after start(sample_rate)).
|
||||
"""
|
||||
from pipecat.audio.vad.aic_vad import AICVADAnalyzer
|
||||
|
||||
return AICVADAnalyzer(
|
||||
vad_factory=self.get_vad_factory(),
|
||||
lookback_buffer_size=lookback_buffer_size,
|
||||
sensitivity=sensitivity,
|
||||
)
|
||||
|
||||
async def start(self, sample_rate: int):
|
||||
"""Initialize the filter with the transport's sample rate.
|
||||
|
||||
@@ -185,7 +237,7 @@ class AICFilter(BaseAudioFilter):
|
||||
)
|
||||
|
||||
# Process planar in-place; returns ndarray (same shape)
|
||||
out_f32 = self._aic.process(block_f32)
|
||||
out_f32 = await self._aic.process_async(block_f32)
|
||||
|
||||
# Convert back to int16 bytes, planar layout
|
||||
out_i16 = np.clip(out_f32 * 32768.0, -32768, 32767).astype(np.int16)
|
||||
|
||||
158
src/pipecat/audio/vad/aic_vad.py
Normal file
158
src/pipecat/audio/vad/aic_vad.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""AIC-integrated VAD analyzer that lazily binds to the AIC SDK backend.
|
||||
|
||||
This analyzer queries the backend's is_speech_detected() and maps it to a float
|
||||
confidence (1.0/0.0). It uses 10 ms windows based on the sample rate and applies
|
||||
optional AIC VAD parameters (lookback_buffer_size, sensitivity) when available.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADParams
|
||||
|
||||
try:
|
||||
from aic import AICVadParameter
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use the AIC filter, you need to `pip install pipecat-ai[aic]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AICVADAnalyzer(VADAnalyzer):
|
||||
"""VAD analyzer that lazily instantiates the AIC VoiceActivityDetector via a factory.
|
||||
|
||||
The analyzer can be constructed before the AIC Model exists. Once the filter has
|
||||
started and the Model is available, the provided factory will succeed and the
|
||||
backend VAD will be created. We then switch to single-sample updates where
|
||||
num_frames_required() returns 1 and confidence is derived from the backend's
|
||||
boolean is_speech_detected() state.
|
||||
|
||||
AIC VAD runtime parameters:
|
||||
- lookback_buffer_size:
|
||||
Controls the lookback buffer size used by the VAD, i.e. the number of
|
||||
window-length audio buffers used as a lookback buffer. Larger values improve
|
||||
stability but increase latency.
|
||||
Range: 1.0 .. 20.0
|
||||
Default (SDK): 6.0
|
||||
- sensitivity:
|
||||
Controls the energy threshold sensitivity. Higher values make the detector
|
||||
less sensitive (require more energy to count as speech).
|
||||
Range: 1.0 .. 15.0
|
||||
Formula: Energy threshold = 10 ** (-sensitivity)
|
||||
Default (SDK): 6.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vad_factory: Optional[Callable[[], Any]] = None,
|
||||
lookback_buffer_size: Optional[float] = None,
|
||||
sensitivity: Optional[float] = None,
|
||||
):
|
||||
"""Create an AIC VAD analyzer.
|
||||
|
||||
Args:
|
||||
vad_factory:
|
||||
Zero-arg callable that returns an initialized AIC VoiceActivityDetector.
|
||||
This may raise until the filter's Model has been created; the analyzer
|
||||
will retry on set_sample_rate/first use.
|
||||
lookback_buffer_size:
|
||||
Optional override for AIC VAD lookback buffer size.
|
||||
Range: 1.0 .. 20.0. Larger values increase stability at the cost of latency.
|
||||
If None, the SDK default (6.0) is used.
|
||||
sensitivity:
|
||||
Optional override for AIC VAD sensitivity (energy threshold).
|
||||
Range: 1.0 .. 15.0. Energy threshold = 10 ** (-sensitivity).
|
||||
If None, the SDK default (6.0) is used.
|
||||
"""
|
||||
# Use fixed VAD parameters for AIC: no user override
|
||||
fixed_params = VADParams(confidence=0.5, start_secs=0.0, stop_secs=0.0, min_volume=0.0)
|
||||
super().__init__(sample_rate=None, params=fixed_params)
|
||||
self._vad_factory = vad_factory
|
||||
self._backend_vad: Optional[Any] = None
|
||||
self._pending_lookback: Optional[float] = lookback_buffer_size
|
||||
self._pending_sensitivity: Optional[float] = sensitivity
|
||||
|
||||
def bind_vad_factory(self, vad_factory: Callable[[], Any]):
|
||||
"""Attach or replace the factory post-construction."""
|
||||
self._vad_factory = vad_factory
|
||||
self._ensure_backend_initialized()
|
||||
|
||||
def _apply_backend_params(self):
|
||||
"""Apply optional AIC VAD parameters if available."""
|
||||
if self._backend_vad is None or AICVadParameter is None:
|
||||
return
|
||||
try:
|
||||
if self._pending_lookback is not None:
|
||||
self._backend_vad.set_parameter(
|
||||
AICVadParameter.LOOKBACK_BUFFER_SIZE, float(self._pending_lookback)
|
||||
)
|
||||
if self._pending_sensitivity is not None:
|
||||
self._backend_vad.set_parameter(
|
||||
AICVadParameter.SENSITIVITY, float(self._pending_sensitivity)
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug(f"AIC VAD parameter application deferred/failed: {e}")
|
||||
|
||||
def _ensure_backend_initialized(self):
|
||||
if self._backend_vad is not None:
|
||||
return
|
||||
if not self._vad_factory:
|
||||
return
|
||||
try:
|
||||
self._backend_vad = self._vad_factory()
|
||||
self._apply_backend_params()
|
||||
# With backend ready, recompute internal frame sizing
|
||||
super().set_params(self._params)
|
||||
logger.debug("AIC VAD backend initialized in analyzer.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
# Filter may not be started yet; try again later
|
||||
logger.debug(f"Deferring AIC VAD backend initialization: {e}")
|
||||
|
||||
def set_sample_rate(self, sample_rate: int):
|
||||
"""Set the sample rate for audio processing.
|
||||
|
||||
Args:
|
||||
sample_rate: Audio sample rate in Hz.
|
||||
"""
|
||||
# Set rate and attempt backend initialization once we know SR
|
||||
self._sample_rate = self._init_sample_rate or sample_rate
|
||||
self._ensure_backend_initialized()
|
||||
# Ensure params are initialized even if backend not ready yet
|
||||
try:
|
||||
super().set_params(self._params)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def num_frames_required(self) -> int:
|
||||
"""Get the number of audio frames required for analysis.
|
||||
|
||||
Returns:
|
||||
Number of frames needed for VAD processing.
|
||||
"""
|
||||
# Use 10 ms windows based on sample rate
|
||||
return int(self.sample_rate * 0.01) if self.sample_rate > 0 else 160
|
||||
|
||||
def voice_confidence(self, buffer: bytes) -> float:
|
||||
"""Calculate voice activity confidence for the given audio buffer.
|
||||
|
||||
Args:
|
||||
buffer: Audio buffer to analyze.
|
||||
|
||||
Returns:
|
||||
Voice confidence score is 0.0 or 1.0.
|
||||
"""
|
||||
# Ensure backend exists (filter might have started since last call)
|
||||
self._ensure_backend_initialized()
|
||||
if self._backend_vad is None:
|
||||
return 0.0
|
||||
|
||||
# We do not need to analyze 'buffer' here since the model's VAD is updated
|
||||
# as part of the enhancement pipeline. Simply query the boolean and map it.
|
||||
try:
|
||||
is_speech = self._backend_vad.is_speech_detected()
|
||||
return 1.0 if is_speech else 0.0
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"AIC VAD inference error: {e}")
|
||||
return 0.0
|
||||
@@ -352,7 +352,10 @@ class TextFrame(DataFrame):
|
||||
class LLMTextFrame(TextFrame):
|
||||
"""Text frame generated by LLM services."""
|
||||
|
||||
pass
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# LLM services send text frames with all necessary spaces included
|
||||
self.includes_inter_frame_spaces = True
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -804,11 +807,13 @@ class ErrorFrame(SystemFrame):
|
||||
error: Description of the error that occurred.
|
||||
fatal: Whether the error is fatal and requires bot shutdown.
|
||||
processor: The frame processor that generated the error.
|
||||
exception: The exception that occurred.
|
||||
"""
|
||||
|
||||
error: str
|
||||
fatal: bool = False
|
||||
processor: Optional["FrameProcessor"] = None
|
||||
exception: Optional[Exception] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(error: {self.error}, fatal: {self.fatal})"
|
||||
|
||||
@@ -14,6 +14,7 @@ translation from this universal context into whatever format it needs, using a
|
||||
service-specific adapter.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import wave
|
||||
@@ -137,7 +138,7 @@ class LLMContext:
|
||||
return {"role": role, "content": content}
|
||||
|
||||
@staticmethod
|
||||
def create_image_message(
|
||||
async def create_image_message(
|
||||
*,
|
||||
role: str = "user",
|
||||
format: str,
|
||||
@@ -154,15 +155,21 @@ class LLMContext:
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
def encode_image():
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
return encoded_image
|
||||
|
||||
encoded_image = await asyncio.to_thread(encode_image)
|
||||
|
||||
url = f"data:image/jpeg;base64,{encoded_image}"
|
||||
|
||||
return LLMContext.create_image_url_message(role=role, url=url, text=text)
|
||||
|
||||
@staticmethod
|
||||
def create_audio_message(
|
||||
async def create_audio_message(
|
||||
*, role: str = "user", audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
) -> LLMContextMessage:
|
||||
"""Create a context message containing audio.
|
||||
@@ -172,21 +179,26 @@ class LLMContext:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
async def encode_audio():
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(data)
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
|
||||
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
with io.BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setsampwidth(2)
|
||||
wf.setnchannels(num_channels)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(data)
|
||||
|
||||
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
return encoded_audio
|
||||
|
||||
encoded_audio = await asyncio.to_thread(encode_audio)
|
||||
|
||||
content.append(
|
||||
{
|
||||
@@ -321,7 +333,7 @@ class LLMContext:
|
||||
"""
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
def add_image_frame_message(
|
||||
async def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
|
||||
):
|
||||
"""Add a message containing an image frame.
|
||||
@@ -332,10 +344,12 @@ class LLMContext:
|
||||
image: Raw image bytes.
|
||||
text: Optional text to include with the image.
|
||||
"""
|
||||
message = LLMContext.create_image_message(format=format, size=size, image=image, text=text)
|
||||
message = await LLMContext.create_image_message(
|
||||
format=format, size=size, image=image, text=text
|
||||
)
|
||||
self.add_message(message)
|
||||
|
||||
def add_audio_frames_message(
|
||||
async def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
):
|
||||
"""Add a message containing audio frames.
|
||||
@@ -344,7 +358,7 @@ class LLMContext:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
message = LLMContext.create_audio_message(audio_frames=audio_frames, text=text)
|
||||
message = await LLMContext.create_audio_message(audio_frames=audio_frames, text=text)
|
||||
self.add_message(message)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -66,7 +66,7 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import concatenate_aggregated_text
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@@ -90,15 +90,7 @@ class LLMContextAggregator(FrameProcessor):
|
||||
self._context = context
|
||||
self._role = role
|
||||
|
||||
self._aggregation: List[str] = []
|
||||
|
||||
# Whether to add spaces between text parts.
|
||||
# (Currently only used by LLMAssistantAggregator, but could be expanded
|
||||
# to LLMUserAggregator in the future if needed; that would require
|
||||
# additional work since LLMUserAggregator currently trims spaces from
|
||||
# incoming frames before determining whether it "really" received any
|
||||
# text).
|
||||
self._add_spaces = True
|
||||
self._aggregation: List[TextPartForConcatenation] = []
|
||||
|
||||
@property
|
||||
def messages(self) -> List[LLMContextMessage]:
|
||||
@@ -191,7 +183,7 @@ class LLMContextAggregator(FrameProcessor):
|
||||
Returns:
|
||||
The concatenated aggregation string.
|
||||
"""
|
||||
return concatenate_aggregated_text(self._aggregation, self._add_spaces)
|
||||
return concatenate_aggregated_text(self._aggregation)
|
||||
|
||||
|
||||
class LLMUserAggregator(LLMContextAggregator):
|
||||
@@ -441,7 +433,12 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
self._aggregation.append(text)
|
||||
# Transcriptions never include inter-part spaces (so far).
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
# We just got a final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Reset aggregation timer.
|
||||
@@ -796,7 +793,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
logger.debug(f"{self} Appending UserImageRawFrame to LLM context (size: {frame.size})")
|
||||
|
||||
self._context.add_image_frame_message(
|
||||
await self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
@@ -821,11 +818,11 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if len(frame.text) == 0:
|
||||
return
|
||||
|
||||
# Track whether we need to add spaces between text parts
|
||||
# Assumption: we can just keep track of the latest frame's value
|
||||
self._add_spaces = not frame.includes_inter_frame_spaces
|
||||
|
||||
self._aggregation.append(frame.text)
|
||||
self._aggregation.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
|
||||
def _context_updated_task_finished(self, task: asyncio.Task):
|
||||
self._context_updated_tasks.discard(task)
|
||||
|
||||
@@ -83,4 +83,4 @@ class ConsumerProcessor(FrameProcessor):
|
||||
while True:
|
||||
frame = await self._queue.get()
|
||||
new_frame = await self._transformer(frame)
|
||||
await self.push_frame(new_frame, self._direction)
|
||||
await self.queue_frame(new_frame, self._direction)
|
||||
|
||||
@@ -126,6 +126,4 @@ class WakeCheckFilter(FrameProcessor):
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in wake word filter: {e}"
|
||||
logger.exception(error_msg)
|
||||
await self.push_error(ErrorFrame(error_msg))
|
||||
await self.push_error(error_msg=f"Error in wake word filter: {e}", exception=e)
|
||||
|
||||
@@ -142,6 +142,7 @@ class FrameProcessor(BaseObject):
|
||||
- on_after_process_frame: Called after a frame is processed
|
||||
- on_before_push_frame: Called before a frame is pushed
|
||||
- on_after_push_frame: Called after a frame is pushed
|
||||
- on_error: Called when an error is raised in the frame processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -234,6 +235,7 @@ class FrameProcessor(BaseObject):
|
||||
self._register_event_handler("on_after_process_frame", sync=True)
|
||||
self._register_event_handler("on_before_push_frame", sync=True)
|
||||
self._register_event_handler("on_after_push_frame", sync=True)
|
||||
self._register_event_handler("on_error", sync=True)
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -630,7 +632,45 @@ class FrameProcessor(BaseObject):
|
||||
elif isinstance(frame, (FrameProcessorResumeFrame, FrameProcessorResumeUrgentFrame)):
|
||||
await self.__resume(frame)
|
||||
|
||||
async def push_error(self, error: ErrorFrame):
|
||||
async def push_error(
|
||||
self,
|
||||
error_msg: Optional[str] = None,
|
||||
exception: Optional[Exception] = None,
|
||||
fatal: bool = False,
|
||||
):
|
||||
"""Creates and pushes an ErrorFrame upstream.
|
||||
|
||||
Creates and pushes an ErrorFrame upstream to notify other processors in the
|
||||
pipeline about an error condition. The error frame will include context about
|
||||
which processor generated the error.
|
||||
|
||||
Args:
|
||||
error_msg: Optional descriptive message explaining the error condition.
|
||||
exception: Optional exception object that caused the error, if available.
|
||||
This provides additional context for debugging and error handling.
|
||||
fatal: Whether this error should be considered fatal to the pipeline.
|
||||
Fatal errors typically cause the entire pipeline to stop processing.
|
||||
Defaults to False for non-fatal errors.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Non-fatal error
|
||||
await self.push_error("Failed to process audio chunk, skipping")
|
||||
|
||||
# Fatal error with exception context
|
||||
try:
|
||||
result = some_critical_operation()
|
||||
except Exception as e:
|
||||
await self.push_error("Critical operation failed", exception=e, fatal=True)
|
||||
```
|
||||
"""
|
||||
error_message = error_msg or f"{self} exception: {exception}"
|
||||
error_frame = ErrorFrame(
|
||||
error=error_message, fatal=fatal, exception=exception, processor=self
|
||||
)
|
||||
await self.push_error_frame(error=error_frame)
|
||||
|
||||
async def push_error_frame(self, error: ErrorFrame):
|
||||
"""Push an error frame upstream.
|
||||
|
||||
Args:
|
||||
@@ -638,6 +678,8 @@ class FrameProcessor(BaseObject):
|
||||
"""
|
||||
if not error.processor:
|
||||
error.processor = self
|
||||
await self._call_event_handler("on_error", error)
|
||||
logger.error(error.error)
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -759,8 +801,10 @@ class FrameProcessor(BaseObject):
|
||||
await self.__cancel_process_task()
|
||||
self.__create_process_task()
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
await self.push_error(
|
||||
error_msg=f"Uncaught exception in {self} when handling _start_interruption: {e}",
|
||||
exception=e,
|
||||
)
|
||||
|
||||
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Internal method to push frames to adjacent processors.
|
||||
@@ -797,8 +841,7 @@ class FrameProcessor(BaseObject):
|
||||
await self._observer.on_push_frame(data)
|
||||
await self._prev.queue_frame(frame, direction)
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
await self.push_error(exception=e)
|
||||
|
||||
def _check_started(self, frame: Frame):
|
||||
"""Check if the processor has been started.
|
||||
@@ -874,8 +917,7 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self._call_event_handler("on_after_process_frame", frame)
|
||||
except Exception as e:
|
||||
logger.exception(f"{self}: error processing frame: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
await self.push_error(error_msg=f"{self}: error processing frame: {e}", exception=e)
|
||||
|
||||
async def __input_frame_task_handler(self):
|
||||
"""Handle frames from the input queue.
|
||||
|
||||
@@ -24,7 +24,7 @@ try:
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables import Runnable
|
||||
except ModuleNotFoundError as e:
|
||||
logger.exception("In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. ")
|
||||
logger.error("In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. ")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -113,6 +113,6 @@ class LangchainProcessor(FrameProcessor):
|
||||
except GeneratorExit:
|
||||
logger.warning(f"{self} generator was closed prematurely")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} an unknown error occurred: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -23,7 +23,7 @@ try:
|
||||
from strands import Agent
|
||||
from strands.multiagent.graph import Graph
|
||||
except ModuleNotFoundError as e:
|
||||
logger.exception("In order to use Strands Agents, you need to `pip install strands-agents`.")
|
||||
logger.error("In order to use Strands Agents, you need to `pip install strands-agents`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ class StrandsAgentsProcessor(FrameProcessor):
|
||||
except GeneratorExit:
|
||||
logger.warning(f"{self} generator was closed prematurely")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} an unknown error occurred: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
if ttfb_tracking:
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -26,7 +26,7 @@ from pipecat.frames.frames import (
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import concatenate_aggregated_text
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@@ -98,15 +98,9 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._current_text_parts: List[str] = []
|
||||
self._current_text_parts: List[TextPartForConcatenation] = []
|
||||
self._aggregation_start_time: Optional[str] = None
|
||||
|
||||
# Whether to add spaces between text parts.
|
||||
# (The use of this could be expanded to the UserTranscriptProcessor in
|
||||
# the future if needed; currently the UserTranscriptProcessor assumes
|
||||
# that user transcription frames do not need aggregation).
|
||||
self._add_spaces = True
|
||||
|
||||
async def _emit_aggregated_text(self):
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
@@ -147,7 +141,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
Result: "Hello there how are you"
|
||||
"""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
content = concatenate_aggregated_text(self._current_text_parts, self._add_spaces)
|
||||
content = concatenate_aggregated_text(self._current_text_parts)
|
||||
if content:
|
||||
logger.trace(f"Emitting aggregated assistant message: {content}")
|
||||
message = TranscriptionMessage(
|
||||
@@ -191,11 +185,11 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
if not self._aggregation_start_time:
|
||||
self._aggregation_start_time = time_now_iso8601()
|
||||
|
||||
# Track whether we need to add spaces between text parts
|
||||
# Assumption: we can just keep track of the latest frame's value
|
||||
self._add_spaces = not frame.includes_inter_frame_spaces
|
||||
|
||||
self._current_text_parts.append(frame.text)
|
||||
self._current_text_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -264,7 +264,10 @@ def _setup_webrtc_routes(
|
||||
# Prepare runner arguments with the callback to run your bot
|
||||
async def webrtc_connection_callback(connection):
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
|
||||
|
||||
runner_args = SmallWebRTCRunnerArguments(
|
||||
webrtc_connection=connection, body=request.request_data
|
||||
)
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
# Delegate handling to SmallWebRTCRequestHandler
|
||||
@@ -326,7 +329,8 @@ def _setup_webrtc_routes(
|
||||
type=request_data["type"],
|
||||
pc_id=request_data.get("pc_id"),
|
||||
restart_pc=request_data.get("restart_pc"),
|
||||
request_data=request_data,
|
||||
request_data=request_data.get("request_data")
|
||||
or request_data.get("requestData"),
|
||||
)
|
||||
return await offer(webrtc_request, background_tasks)
|
||||
elif request.method == HTTPMethod.PATCH.value:
|
||||
|
||||
@@ -281,6 +281,14 @@ async def maybe_capture_participant_camera(
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
|
||||
|
||||
if isinstance(transport, SmallWebRTCTransport):
|
||||
await transport.capture_participant_video(video_source="camera")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
async def maybe_capture_participant_screen(
|
||||
transport: BaseTransport, client: Any, framerate: int = 0
|
||||
@@ -303,6 +311,14 @@ async def maybe_capture_participant_screen(
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
|
||||
|
||||
if isinstance(transport, SmallWebRTCTransport):
|
||||
await transport.capture_participant_video(video_source="screenVideo")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _smallwebrtc_sdp_cleanup_ice_candidates(text: str, pattern: str) -> str:
|
||||
"""Clean up ICE candidates in SDP text for SmallWebRTC.
|
||||
|
||||
@@ -199,7 +199,7 @@ class PlivoFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to hang up Plivo call: {e}")
|
||||
logger.error(f"Failed to hang up Plivo call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Plivo WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -225,7 +225,7 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to hang up Telnyx call: {e}")
|
||||
logger.error(f"Failed to hang up Telnyx call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Telnyx WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -236,7 +236,7 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to hang up Twilio call: {e}")
|
||||
logger.error(f"Failed to hang up Twilio call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Twilio WebSocket data to Pipecat frames.
|
||||
|
||||
@@ -166,6 +166,6 @@ class AIService(FrameProcessor):
|
||||
async for f in generator:
|
||||
if f:
|
||||
if isinstance(f, ErrorFrame):
|
||||
await self.push_error(f)
|
||||
await self.push_error_frame(f)
|
||||
else:
|
||||
await self.push_frame(f)
|
||||
|
||||
@@ -373,9 +373,7 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
if event.type == "content_block_delta":
|
||||
if hasattr(event.delta, "text"):
|
||||
frame = LLMTextFrame(event.delta.text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(LLMTextFrame(event.delta.text))
|
||||
completion_tokens_estimate += self._estimate_tokens(event.delta.text)
|
||||
elif hasattr(event.delta, "partial_json") and tool_use_block:
|
||||
json_accumulator += event.delta.partial_json
|
||||
@@ -460,8 +458,7 @@ class AnthropicLLMService(LLMService):
|
||||
except httpx.TimeoutException:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"{e}"))
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -21,6 +21,7 @@ from pipecat import __version__ as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -205,8 +206,8 @@ class AssemblyAISTTService(STTService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to AssemblyAI: {e}")
|
||||
self._connected = False
|
||||
await self.push_error(exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -231,7 +232,7 @@ class AssemblyAISTTService(STTService):
|
||||
logger.warning("Timed out waiting for termination message from server")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during termination handshake: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
@@ -239,7 +240,7 @@ class AssemblyAISTTService(STTService):
|
||||
await self._websocket.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
finally:
|
||||
self._websocket = None
|
||||
@@ -258,11 +259,11 @@ class AssemblyAISTTService(STTService):
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing WebSocket message: {e}")
|
||||
await self.push_error(exception=e)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in receive handler: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
|
||||
"""Parse a raw message into the appropriate message type."""
|
||||
@@ -291,7 +292,7 @@ class AssemblyAISTTService(STTService):
|
||||
elif isinstance(parsed_message, TerminationMessage):
|
||||
await self._handle_termination(parsed_message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
async def _handle_termination(self, message: TerminationMessage):
|
||||
"""Handle termination message."""
|
||||
|
||||
@@ -146,15 +146,6 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AsyncAI TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AsyncAI's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Async language format.
|
||||
|
||||
@@ -237,7 +228,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -249,7 +240,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Async")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._started = False
|
||||
@@ -294,12 +285,11 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
elif msg.get("error_code"):
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
|
||||
await self.push_error(error_msg=f"{self} error: {msg['message']}")
|
||||
else:
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
await self.push_error(error_msg=f"{self} error, unknown message type: {msg}")
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic keepalive messages to maintain WebSocket connection."""
|
||||
@@ -342,14 +332,14 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class AsyncAIHttpTTSService(TTSService):
|
||||
@@ -429,15 +419,6 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AsyncAI TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AsyncAI's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Async language format.
|
||||
|
||||
@@ -491,8 +472,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Async API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Async API error: {error_text}"))
|
||||
yield ErrorFrame(error=f"Async API error: {error_text}")
|
||||
raise Exception(f"Async API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
@@ -508,8 +488,7 @@ class AsyncAIHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -1078,9 +1078,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
if "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
frame = LLMTextFrame(delta["text"])
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(LLMTextFrame(delta["text"]))
|
||||
completion_tokens_estimate += self._estimate_tokens(delta["text"])
|
||||
elif "toolUse" in delta and "input" in delta["toolUse"]:
|
||||
# Handle partial JSON for tool use
|
||||
@@ -1138,7 +1136,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
except (ReadTimeoutError, asyncio.TimeoutError):
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -452,7 +452,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._ready_to_send_context = True
|
||||
await self._finish_connecting_if_context_available()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(exception=e)
|
||||
await self._disconnect()
|
||||
|
||||
async def _process_completed_function_calls(self, send_new_results: bool):
|
||||
@@ -576,7 +576,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
logger.info("Finished disconnecting")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
|
||||
def _create_client(self) -> BedrockRuntimeClient:
|
||||
config = Config(
|
||||
@@ -884,7 +884,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# Errors are kind of expected while disconnecting, so just
|
||||
# ignore them and do nothing
|
||||
return
|
||||
logger.error(f"{self} error processing responses: {e}")
|
||||
await self.push_error(exception=e)
|
||||
if self._wants_connection:
|
||||
await self.reset_conversation()
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
return
|
||||
logger.warning("WebSocket connection not established after connect")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect (attempt {retry_count + 1}/{max_retries}): {e}")
|
||||
await self.push_error(exception=e)
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
await asyncio.sleep(1) # Wait before retrying
|
||||
@@ -181,8 +181,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
try:
|
||||
await self._connect()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reconnect: {e}")
|
||||
yield ErrorFrame("Failed to reconnect to AWS Transcribe", fatal=False)
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
return
|
||||
|
||||
# Format the audio data according to AWS event stream format
|
||||
@@ -199,13 +198,11 @@ class AWSTranscribeSTTService(STTService):
|
||||
await self._disconnect()
|
||||
# Don't yield error here - we'll retry on next frame
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio: {e}")
|
||||
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_stt: {e}")
|
||||
yield ErrorFrame(f"AWS Transcribe error: {str(e)}", fatal=False)
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
async def _connect(self):
|
||||
@@ -288,7 +285,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Failed to connect to AWS Transcribe: {e}")
|
||||
await self.push_error(exception=e)
|
||||
await self._disconnect()
|
||||
raise
|
||||
|
||||
@@ -308,7 +305,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
await self._ws_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} Error closing WebSocket connection: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
self._ws_client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -526,16 +523,15 @@ class AWSTranscribeSTTService(STTService):
|
||||
)
|
||||
elif headers.get(":message-type") == "exception":
|
||||
error_msg = payload.get("Message", "Unknown error")
|
||||
logger.error(f"{self} Exception from AWS: {error_msg}")
|
||||
await self.push_frame(
|
||||
ErrorFrame(f"AWS Transcribe error: {error_msg}", fatal=False)
|
||||
)
|
||||
await self.push_error(error_msg=f"AWS Transcribe error: {error_msg}")
|
||||
else:
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
|
||||
await self.push_error(
|
||||
error_msg=f"WebSocket connection closed in receive loop", exception=e
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Unexpected error in receive loop: {e}")
|
||||
await self.push_error(exception=e)
|
||||
break
|
||||
|
||||
@@ -209,15 +209,6 @@ class AWSPollyTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that AWS TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that AWS's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to AWS Polly language format.
|
||||
|
||||
@@ -321,7 +312,6 @@ class AWSPollyTTSService(TTSService):
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
except (BotoCoreError, ClientError) as error:
|
||||
logger.exception(f"{self} error generating TTS: {error}")
|
||||
error_message = f"AWS Polly TTS error: {str(error)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
|
||||
@@ -91,7 +91,6 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
logger.error(f"{self} error: image generation timed out")
|
||||
yield ErrorFrame("Image generation timed out")
|
||||
return
|
||||
|
||||
@@ -104,7 +103,6 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
|
||||
@@ -61,5 +61,5 @@ class AzureRealtimeLLMService(OpenAIRealtimeLLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._websocket = None
|
||||
|
||||
@@ -18,6 +18,7 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -111,13 +112,16 @@ class AzureSTTService(STTService):
|
||||
audio: Raw audio bytes to process.
|
||||
|
||||
Yields:
|
||||
None - actual transcription frames are pushed via callbacks.
|
||||
Frame: Either None for successful processing or ErrorFrame on failure.
|
||||
"""
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.write(audio)
|
||||
yield None
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.write(audio)
|
||||
yield None
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the speech recognition service.
|
||||
@@ -133,17 +137,22 @@ class AzureSTTService(STTService):
|
||||
if self._audio_stream:
|
||||
return
|
||||
|
||||
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
|
||||
self._audio_stream = PushAudioInputStream(stream_format)
|
||||
try:
|
||||
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
|
||||
self._audio_stream = PushAudioInputStream(stream_format)
|
||||
|
||||
audio_config = AudioConfig(stream=self._audio_stream)
|
||||
audio_config = AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._speech_recognizer = SpeechRecognizer(
|
||||
speech_config=self._speech_config, audio_config=audio_config
|
||||
)
|
||||
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
self._speech_recognizer = SpeechRecognizer(
|
||||
speech_config=self._speech_config, audio_config=audio_config
|
||||
)
|
||||
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"{self} exception during initialization: {e}", exception=e
|
||||
)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the speech recognition service.
|
||||
|
||||
@@ -151,15 +151,6 @@ class AzureBaseTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Azure TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Azure's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Azure language format.
|
||||
|
||||
@@ -336,8 +327,7 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
try:
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error_msg)
|
||||
yield ErrorFrame(error=error_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -364,13 +354,13 @@ class AzureTTSService(AzureBaseTTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error during synthesis: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class AzureHttpTTSService(AzureBaseTTSService):
|
||||
@@ -447,4 +437,4 @@ class AzureHttpTTSService(AzureBaseTTSService):
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
if cancellation_details.reason == CancellationReason.Error:
|
||||
logger.error(f"{self} error: {cancellation_details.error_details}")
|
||||
yield ErrorFrame(error=f"{self} error: {cancellation_details.error_details}")
|
||||
|
||||
@@ -20,6 +20,7 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -275,7 +276,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
self._websocket = await websocket_connect(ws_url, additional_headers=headers)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self}: unable to connect to Cartesia: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
try:
|
||||
@@ -283,7 +284,7 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from Cartesia STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(error_msg=f"{self} error closing websocket: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -315,7 +316,8 @@ class CartesiaSTTService(WebsocketSTTService):
|
||||
await self._on_transcript(data)
|
||||
|
||||
elif data["type"] == "error":
|
||||
logger.error(f"Cartesia error: {data.get('message', 'Unknown error')}")
|
||||
error_msg = data.get("message", "Unknown error")
|
||||
await self.push_error(error_msg=error_msg)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
|
||||
@@ -397,7 +397,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -409,7 +409,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
logger.debug("Disconnecting from Cartesia")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
@@ -462,13 +462,12 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
)
|
||||
await self.append_to_audio_context(msg["context_id"], frame)
|
||||
elif msg["type"] == "error":
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['error']}"))
|
||||
await self.push_error(error_msg=f"{self} error: {msg}")
|
||||
self._context_id = None
|
||||
else:
|
||||
logger.error(f"{self} error, unknown message type: {msg}")
|
||||
await self.push_error(error_msg=f"{self} error, unknown message type: {msg}")
|
||||
|
||||
async def _receive_messages(self):
|
||||
while True:
|
||||
@@ -506,14 +505,14 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class CartesiaHttpTTSService(TTSService):
|
||||
@@ -704,8 +703,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Cartesia API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Cartesia API error: {error_text}"))
|
||||
yield ErrorFrame(error=f"Cartesia API error: {error_text}")
|
||||
raise Exception(f"Cartesia API returned status {response.status}: {error_text}")
|
||||
|
||||
audio_data = await response.read()
|
||||
@@ -721,8 +719,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
|
||||
"""Deepgram Flux speech-to-text service implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
@@ -94,6 +96,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
mip_opt_out: Optional. Opts out requests from the Deepgram Model Improvement Program
|
||||
(default False).
|
||||
tag: List of tags to label requests for identification during usage reporting.
|
||||
min_confidence: Optional. Minimum confidence required confidence to create a TranscriptionFrame
|
||||
"""
|
||||
|
||||
eager_eot_threshold: Optional[float] = None
|
||||
@@ -102,6 +105,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
keyterm: list = []
|
||||
mip_opt_out: Optional[bool] = None
|
||||
tag: list = []
|
||||
min_confidence: Optional[float] = None # New parameter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -163,6 +167,13 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
self._register_event_handler("on_end_of_turn")
|
||||
self._register_event_handler("on_eager_end_of_turn")
|
||||
self._register_event_handler("on_update")
|
||||
self._connection_established_event = asyncio.Event()
|
||||
# Watchdog task to prevent dangling tasks
|
||||
# If we stop sending audio to Flux after we have received that the User has started speaking
|
||||
# we never receive the user stopped speaking event unless we resume sending audio to it.
|
||||
self._last_stt_time = None
|
||||
self._watchdog_task = None
|
||||
self._user_is_speaking = False
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to WebSocket and start background tasks.
|
||||
@@ -172,9 +183,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
"""
|
||||
await self._connect_websocket()
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from WebSocket and clean up tasks.
|
||||
|
||||
@@ -182,20 +190,32 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
and cleans up resources to prevent memory leaks.
|
||||
"""
|
||||
try:
|
||||
# Cancel background tasks BEFORE closing websocket
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=2.0)
|
||||
self._receive_task = None
|
||||
|
||||
# Now close the websocket
|
||||
await self._disconnect_websocket()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
# Reset state only after everything is cleaned up
|
||||
self._websocket = None
|
||||
|
||||
async def _send_silence(self, duration_secs: float = 0.5):
|
||||
"""Send a block of silence of the specified duration (default 500 ms)."""
|
||||
sample_width = 2 # bytes per sample for 16-bit PCM
|
||||
num_channels = 1 # mono
|
||||
num_samples = int(self.sample_rate * duration_secs)
|
||||
silence = b"\x00" * (num_samples * sample_width * num_channels)
|
||||
await self._websocket.send(silence)
|
||||
|
||||
async def _watchdog_task_handler(self):
|
||||
while self._websocket and self._websocket.state is State.OPEN:
|
||||
now = time.monotonic()
|
||||
# More than 500 ms without sending new audio to Flux
|
||||
if self._user_is_speaking and self._last_stt_time and now - self._last_stt_time > 0.5:
|
||||
logger.warning("Sending silence to Flux to prevent dangling task")
|
||||
await self._send_silence()
|
||||
self._last_stt_time = time.monotonic()
|
||||
# check every 100ms
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Establish WebSocket connection to API.
|
||||
|
||||
@@ -207,14 +227,30 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
self._connection_established_event.clear()
|
||||
self._user_is_speaking = False
|
||||
self._websocket = await websocket_connect(
|
||||
self._websocket_url,
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
|
||||
# Creating the receiver task
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(
|
||||
self._receive_task_handler(self._report_error)
|
||||
)
|
||||
|
||||
# Creating the watchdog task
|
||||
if not self._watchdog_task:
|
||||
self._watchdog_task = self.create_task(self._watchdog_task_handler())
|
||||
|
||||
# Now wait for the connection established event
|
||||
logger.debug("WebSocket connected, waiting for server confirmation...")
|
||||
await self._connection_established_event.wait()
|
||||
logger.debug("Connected to Deepgram Flux Websocket")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -225,6 +261,16 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
metrics collection. Handles disconnection errors gracefully.
|
||||
"""
|
||||
try:
|
||||
# Cancel background tasks BEFORE closing websocket
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=2.0)
|
||||
self._receive_task = None
|
||||
if self._watchdog_task:
|
||||
await self.cancel_task(self._watchdog_task, timeout=2.0)
|
||||
self._watchdog_task = None
|
||||
self._last_stt_time = None
|
||||
|
||||
self._connection_established_event.clear()
|
||||
await self.stop_all_metrics()
|
||||
|
||||
if self._websocket:
|
||||
@@ -232,7 +278,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from Deepgram Flux Websocket")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(error_msg=f"{self} error closing websocket: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -332,15 +378,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
are issues sending the audio data.
|
||||
"""
|
||||
if not self._websocket:
|
||||
logger.error("Not connected to Deepgram Flux.")
|
||||
yield ErrorFrame("Not connected to Deepgram Flux.", fatal=True)
|
||||
yield ErrorFrame("Not connected to Deepgram Flux.")
|
||||
return
|
||||
|
||||
try:
|
||||
await self._websocket.send(audio)
|
||||
self._last_stt_time = time.monotonic()
|
||||
await self.send_with_retry(audio, self._report_error)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send audio to Flux: {e}")
|
||||
yield ErrorFrame(f"Failed to send audio to Flux: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
return
|
||||
|
||||
yield None
|
||||
@@ -417,7 +462,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
# Skip malformed messages
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
await self.push_error(exception=e)
|
||||
# Error will be handled inside WebsocketService->_receive_task_handler
|
||||
raise
|
||||
else:
|
||||
@@ -459,6 +504,8 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
transcription processing.
|
||||
"""
|
||||
logger.info("Connected to Flux - ready to stream audio")
|
||||
# Notify connection is established
|
||||
self._connection_established_event.set()
|
||||
|
||||
async def _handle_fatal_error(self, data: Dict[str, Any]):
|
||||
"""Handle fatal error messages from Deepgram Flux.
|
||||
@@ -526,6 +573,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
transcript: maybe the first few words of the turn.
|
||||
"""
|
||||
logger.debug("User started speaking")
|
||||
self._user_is_speaking = True
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
await self.start_metrics()
|
||||
@@ -546,6 +594,22 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
logger.trace(f"Received event TurnResumed: {event}")
|
||||
await self._call_event_handler("on_turn_resumed")
|
||||
|
||||
def _calculate_average_confidence(self, transcript_data) -> Optional[float]:
|
||||
"""Calculate the average confidence from transcript data.
|
||||
|
||||
Return None if the data is missing or invalid.
|
||||
"""
|
||||
# Example: Assume transcript_data has a list of words with confidence
|
||||
words = transcript_data.get("words")
|
||||
if not words or not isinstance(words, list):
|
||||
return None
|
||||
confidences = [
|
||||
w.get("confidence") for w in words if isinstance(w.get("confidence"), (float, int))
|
||||
]
|
||||
if not confidences:
|
||||
return None
|
||||
return sum(confidences) / len(confidences)
|
||||
|
||||
async def _handle_end_of_turn(self, transcript: str, data: Dict[str, Any]):
|
||||
"""Handle EndOfTurn events from Deepgram Flux.
|
||||
|
||||
@@ -565,16 +629,26 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
data: The TurnInfo message data containing event type, transcript and some extra metadata.
|
||||
"""
|
||||
logger.debug("User stopped speaking")
|
||||
self._user_is_speaking = False
|
||||
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
result=data,
|
||||
# Compute the average confidence
|
||||
average_confidence = self._calculate_average_confidence(data)
|
||||
|
||||
if not self._params.min_confidence or average_confidence > self._params.min_confidence:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
result=data,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Transcription confidence below min_confidence threshold: {average_confidence}"
|
||||
)
|
||||
|
||||
await self._handle_transcription(transcript, True, self._language)
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(UserStoppedSpeakingFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
@@ -233,7 +233,7 @@ class DeepgramSTTService(STTService):
|
||||
)
|
||||
|
||||
if not await self._connection.start(options=self._settings, addons=self._addons):
|
||||
logger.error(f"{self}: unable to connect to Deepgram")
|
||||
await self.push_error(error_msg=f"Unable to connect to Deepgram")
|
||||
|
||||
async def _disconnect(self):
|
||||
if await self._connection.is_connected():
|
||||
@@ -256,7 +256,7 @@ class DeepgramSTTService(STTService):
|
||||
async def _on_error(self, *args, **kwargs):
|
||||
error: ErrorResponse = kwargs["error"]
|
||||
logger.warning(f"{self} connection error, will retry: {error}")
|
||||
await self.push_error(ErrorFrame(f"{error}"))
|
||||
await self.push_error(error_msg=f"{error}")
|
||||
await self.stop_all_metrics()
|
||||
# NOTE(aleix): we don't disconnect (i.e. call finish on the connection)
|
||||
# because this triggers more errors internally in the Deepgram SDK. So,
|
||||
|
||||
@@ -79,15 +79,6 @@ class DeepgramTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Deepgram TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Deepgram's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram's TTS API.
|
||||
@@ -125,8 +116,7 @@ class DeepgramTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class DeepgramHttpTTSService(TTSService):
|
||||
@@ -177,15 +167,6 @@ class DeepgramHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Deepgram TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Deepgram's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Deepgram's TTS API.
|
||||
@@ -245,5 +226,4 @@ class DeepgramHttpTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
|
||||
@@ -351,8 +351,7 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ElevenLabs STT error: {e}")
|
||||
yield ErrorFrame(f"ElevenLabs STT error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
def audio_format_from_sample_rate(sample_rate: int) -> str:
|
||||
@@ -586,7 +585,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
}
|
||||
await self._websocket.send(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio: {e}")
|
||||
yield ErrorFrame(f"ElevenLabs Realtime STT error: {str(e)}")
|
||||
|
||||
yield None
|
||||
@@ -645,8 +643,9 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
await self._call_event_handler("on_connected")
|
||||
logger.debug("Connected to ElevenLabs Realtime STT")
|
||||
except Exception as e:
|
||||
logger.error(f"{self}: unable to connect to ElevenLabs Realtime STT: {e}")
|
||||
await self.push_error(ErrorFrame(f"Connection error: {str(e)}"))
|
||||
await self.push_error(
|
||||
error_msg=f"{self}: unable to connect to ElevenLabs Realtime STT: {e}", exception=e
|
||||
)
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Disconnect from ElevenLabs Realtime STT WebSocket."""
|
||||
@@ -655,7 +654,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
logger.debug("Disconnecting from ElevenLabs Realtime STT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
@@ -714,13 +713,11 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
|
||||
elif message_type == "input_error":
|
||||
error_msg = data.get("error", "Unknown input error")
|
||||
logger.error(f"ElevenLabs input error: {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"Input error: {error_msg}"))
|
||||
await self.push_error(error_msg=f"ElevenLabs input error: {error_msg}")
|
||||
|
||||
elif message_type in ["auth_error", "quota_exceeded", "transcriber_error", "error"]:
|
||||
error_msg = data.get("error", data.get("message", "Unknown error"))
|
||||
logger.error(f"ElevenLabs error ({message_type}): {error_msg}")
|
||||
await self.push_error(ErrorFrame(f"{message_type}: {error_msg}"))
|
||||
await self.push_error(error_msg=f"ElevenLabs error ({message_type}): {error_msg}")
|
||||
|
||||
else:
|
||||
logger.debug(f"Unknown message type: {message_type}")
|
||||
|
||||
@@ -424,7 +424,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing context for voice settings update: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
|
||||
@@ -535,8 +535,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
await self.push_error(exception=e)
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
@@ -551,7 +551,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.close()
|
||||
logger.debug("Disconnected from ElevenLabs")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
self._started = False
|
||||
self._context_id = None
|
||||
@@ -581,7 +581,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
json.dumps({"context_id": self._context_id, "close_context": True})
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing context on interruption: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._context_id = None
|
||||
self._started = False
|
||||
self._partial_word = ""
|
||||
@@ -736,13 +736,13 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
else:
|
||||
await self._send_text(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
self._started = False
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class ElevenLabsHttpTTSService(WordTTSService):
|
||||
@@ -1037,7 +1037,6 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"{self} error: {error_text}")
|
||||
yield ErrorFrame(error=f"ElevenLabs API error: {error_text}")
|
||||
return
|
||||
|
||||
@@ -1085,7 +1084,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
logger.warning(f"Failed to parse JSON from stream: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing response: {e}", exc_info=True)
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
continue
|
||||
|
||||
# After processing all chunks, emit any remaining partial word
|
||||
@@ -1109,8 +1108,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
self._previous_text = text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
# Let the parent class handle TTSStoppedFrame
|
||||
|
||||
@@ -110,7 +110,6 @@ class FalImageGenService(ImageGenService):
|
||||
image_url = response["images"][0]["url"] if response else None
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
|
||||
@@ -290,5 +290,4 @@ class FalSTTService(SegmentedSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fal Wizper error: {e}")
|
||||
yield ErrorFrame(f"Fal Wizper error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -159,15 +159,6 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Fish Audio TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Fish Audio's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the TTS model and reconnect.
|
||||
|
||||
@@ -237,7 +228,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"Fish Audio initialization error: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -251,7 +242,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
await self._websocket.send(ormsgpack.packb(stop_message))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing websocket: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._started = False
|
||||
@@ -293,7 +284,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -329,7 +320,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
flush_message = {"event": "flush"}
|
||||
await self._get_websocket().send(ormsgpack.packb(flush_message))
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -337,5 +328,4 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(f"Error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -23,6 +23,7 @@ from pipecat import __version__ as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
@@ -467,7 +468,7 @@ class GladiaSTTService(STTService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in connection handler: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._connection_active = False
|
||||
|
||||
if not self._should_reconnect:
|
||||
@@ -557,7 +558,7 @@ class GladiaSTTService(STTService):
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.debug("Connection closed during keepalive")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gladia keepalive task: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
@@ -620,7 +621,7 @@ class GladiaSTTService(STTService):
|
||||
# Expected when closing the connection
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gladia WebSocket handler: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
async def _maybe_reconnect(self) -> bool:
|
||||
"""Handle exponential backoff reconnection logic."""
|
||||
@@ -628,7 +629,9 @@ class GladiaSTTService(STTService):
|
||||
return False
|
||||
self._reconnection_attempts += 1
|
||||
if self._reconnection_attempts > self._max_reconnection_attempts:
|
||||
logger.error(f"Max reconnection attempts ({self._max_reconnection_attempts}) reached")
|
||||
await self.push_error(
|
||||
error_msg=f"Max reconnection attempts ({self._max_reconnection_attempts}) reached"
|
||||
)
|
||||
self._should_reconnect = False
|
||||
return False
|
||||
delay = self._reconnection_delay * (2 ** (self._reconnection_attempts - 1))
|
||||
|
||||
@@ -1174,7 +1174,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._connection_task = self.create_task(self._connection_task_handler(config=config))
|
||||
|
||||
except Exception as e:
|
||||
await self.push_error(ErrorFrame(error=f"{self} Initialization error: {e}", fatal=True))
|
||||
await self.push_error(exception=e)
|
||||
|
||||
async def _connection_task_handler(self, config: LiveConnectConfig):
|
||||
async with self._client.aio.live.connect(model=self._model_name, config=config) as session:
|
||||
@@ -1251,13 +1251,11 @@ class GeminiLiveLLMService(LLMService):
|
||||
)
|
||||
|
||||
if self._consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
||||
logger.error(
|
||||
error_msg = (
|
||||
f"Max consecutive failures ({MAX_CONSECUTIVE_FAILURES}) reached, "
|
||||
"treating as fatal error"
|
||||
)
|
||||
await self.push_error(
|
||||
ErrorFrame(error=f"{self} Error in receive loop: {error}", fatal=True)
|
||||
)
|
||||
await self.push_error(error_msg=error_msg, exception=error)
|
||||
return False
|
||||
else:
|
||||
logger.info(
|
||||
@@ -1285,7 +1283,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._completed_tool_calls = set()
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
"""Send user audio frame to Gemini Live API."""
|
||||
@@ -1454,8 +1452,6 @@ class GeminiLiveLLMService(LLMService):
|
||||
self._bot_text_buffer += text
|
||||
self._search_result_buffer += text # Also accumulate for grounding
|
||||
frame = LLMTextFrame(text=text)
|
||||
# Gemini Live text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Check for grounding metadata in server content
|
||||
@@ -1746,7 +1742,7 @@ class GeminiLiveLLMService(LLMService):
|
||||
# state management, and that exponential backoff for retries can have
|
||||
# cost/stability implications for a service cluster, let's just treat a
|
||||
# send-side error as fatal.
|
||||
await self.push_error(ErrorFrame(error=f"{self} Send error: {error}", fatal=True))
|
||||
await self.push_error(ErrorFrame(error=f"{self} Send error: {error}"))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
|
||||
@@ -110,7 +110,6 @@ class GoogleImageGenService(ImageGenService):
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not response or not response.generated_images:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
@@ -128,5 +127,4 @@ class GoogleImageGenService(ImageGenService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating image: {e}")
|
||||
yield ErrorFrame(f"Image generation error: {str(e)}")
|
||||
|
||||
@@ -793,7 +793,7 @@ class GoogleLLMService(LLMService):
|
||||
return
|
||||
generation_params.setdefault("thinking_config", {})["thinking_budget"] = 0
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to unset thinking budget: {e}")
|
||||
logger.error(f"Failed to unset thinking budget: {e}")
|
||||
|
||||
async def _stream_content(
|
||||
self, params_from_context: GeminiLLMInvocationParams
|
||||
@@ -920,9 +920,7 @@ class GoogleLLMService(LLMService):
|
||||
for part in candidate.content.parts:
|
||||
if not part.thought and part.text:
|
||||
search_result += part.text
|
||||
frame = LLMTextFrame(part.text)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(LLMTextFrame(part.text))
|
||||
elif part.function_call:
|
||||
function_call = part.function_call
|
||||
id = function_call.id or str(uuid.uuid4())
|
||||
@@ -985,7 +983,7 @@ class GoogleLLMService(LLMService):
|
||||
except DeadlineExceeded:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
if grounding_metadata and isinstance(grounding_metadata, dict):
|
||||
llm_search_frame = LLMSearchResponseFrame(
|
||||
|
||||
@@ -774,7 +774,7 @@ class GoogleSTTService(STTService):
|
||||
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in request generator: {e}")
|
||||
await self.push_error(exception=e)
|
||||
raise
|
||||
|
||||
async def _stream_audio(self):
|
||||
@@ -805,14 +805,13 @@ class GoogleSTTService(STTService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} Reconnecting: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
await asyncio.sleep(1) # Brief delay before reconnecting
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming task: {e}")
|
||||
await self.push_frame(ErrorFrame(str(e)))
|
||||
await self.push_error(exception=e)
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process an audio chunk for STT transcription.
|
||||
@@ -900,7 +899,7 @@ class GoogleSTTService(STTService):
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Google STT responses: {e}")
|
||||
await self.push_error(exception=e)
|
||||
# Re-raise the exception to let it propagate (e.g. in the case of a
|
||||
# timeout, propagate to _stream_audio to reconnect)
|
||||
raise
|
||||
|
||||
@@ -596,15 +596,6 @@ class GoogleHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Google TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Google's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Google TTS language format.
|
||||
|
||||
@@ -746,7 +737,6 @@ class GoogleHttpTTSService(TTSService):
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -803,15 +793,6 @@ class GoogleBaseTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Google and Gemini TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Google's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Google TTS language format.
|
||||
|
||||
@@ -1014,9 +995,7 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class GeminiTTSService(GoogleBaseTTSService):
|
||||
@@ -1266,6 +1245,5 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"Gemini TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -13,7 +13,13 @@ from typing import AsyncGenerator, Optional
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
@@ -105,15 +111,6 @@ class GroqTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Groq TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Groq's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Groq's TTS API.
|
||||
@@ -149,6 +146,6 @@ class GroqTTSService(TTSService):
|
||||
bytes = w.readframes(num_frames)
|
||||
yield TTSAudioRawFrame(bytes, frame_rate, channels)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -179,7 +179,7 @@ class HeyGenClient:
|
||||
await self._task_manager.cancel_task(self._event_task)
|
||||
self._event_task = None
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during cleanup: {e}")
|
||||
logger.error(f"Exception during cleanup: {e}")
|
||||
|
||||
async def start(self, frame: StartFrame, audio_chunk_size: int) -> None:
|
||||
"""Start the client and establish all necessary connections.
|
||||
|
||||
@@ -14,12 +14,14 @@ from pydantic import BaseModel
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.tts_service import WordTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
@@ -29,6 +31,7 @@ try:
|
||||
PostedUtterance,
|
||||
PostedUtteranceVoiceWithId,
|
||||
)
|
||||
from hume.tts.types import TimestampMessage
|
||||
except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Hume, you need to `pip install pipecat-ai[hume]`.")
|
||||
@@ -38,7 +41,7 @@ except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
|
||||
HUME_SAMPLE_RATE = 48_000 # Hume TTS streams at 48 kHz
|
||||
|
||||
|
||||
class HumeTTSService(TTSService):
|
||||
class HumeTTSService(WordTTSService):
|
||||
"""Hume Octave Text-to-Speech service.
|
||||
|
||||
Streams PCM audio via Hume's HTTP output streaming (JSON chunks) endpoint
|
||||
@@ -48,6 +51,7 @@ class HumeTTSService(TTSService):
|
||||
|
||||
- Generates speech from text using Hume TTS.
|
||||
- Streams PCM audio.
|
||||
- Supports word-level timestamps for precise audio-text synchronization.
|
||||
- Supports dynamic updates of voice and synthesis parameters at runtime.
|
||||
- Provides metrics for Time To First Byte (TTFB) and TTS usage.
|
||||
"""
|
||||
@@ -92,7 +96,13 @@ class HumeTTSService(TTSService):
|
||||
f"Hume TTS streams at {HUME_SAMPLE_RATE} Hz; configured sample_rate={sample_rate}"
|
||||
)
|
||||
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
# WordTTSService sets push_text_frames=False by default, which we want
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._client = AsyncHumeClient(api_key=api_key)
|
||||
self._params = params or HumeTTSService.InputParams()
|
||||
@@ -102,6 +112,10 @@ class HumeTTSService(TTSService):
|
||||
|
||||
self._audio_bytes = b""
|
||||
|
||||
# Track cumulative time for word timestamps across utterances
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Can generate metrics.
|
||||
|
||||
@@ -110,15 +124,6 @@ class HumeTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Hume TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Hume's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame) -> None:
|
||||
"""Start the service.
|
||||
|
||||
@@ -126,6 +131,27 @@ class HumeTTSService(TTSService):
|
||||
frame: The start frame.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._reset_state()
|
||||
|
||||
def _reset_state(self):
|
||||
"""Reset internal state variables."""
|
||||
self._cumulative_time = 0.0
|
||||
self._started = False
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
"""Push a frame and handle state changes.
|
||||
|
||||
Args:
|
||||
frame: The frame to push.
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (InterruptionFrame, TTSStoppedFrame)):
|
||||
# Reset timing on interruption or stop
|
||||
self._reset_state()
|
||||
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
|
||||
async def update_setting(self, key: str, value: Any) -> None:
|
||||
"""Runtime updates via `TTSUpdateSettingsFrame`.
|
||||
@@ -142,7 +168,7 @@ class HumeTTSService(TTSService):
|
||||
|
||||
if key_l == "voice_id":
|
||||
self.set_voice(str(value))
|
||||
logger.info(f"HumeTTSService voice_id set to: {self.voice}")
|
||||
logger.debug(f"HumeTTSService voice_id set to: {self.voice}")
|
||||
elif key_l == "description":
|
||||
self._params.description = None if value is None else str(value)
|
||||
elif key_l == "speed":
|
||||
@@ -155,7 +181,7 @@ class HumeTTSService(TTSService):
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Hume TTS.
|
||||
"""Generate speech from text using Hume TTS with word timestamps.
|
||||
|
||||
Args:
|
||||
text: The text to be synthesized.
|
||||
@@ -186,7 +212,12 @@ class HumeTTSService(TTSService):
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# Start TTS sequence if not already started
|
||||
if not self._started:
|
||||
self.start_word_timestamps()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
|
||||
try:
|
||||
# Instant mode is always enabled here (not user-configurable)
|
||||
@@ -197,23 +228,50 @@ class HumeTTSService(TTSService):
|
||||
# Use version "2" by default if no description is provided
|
||||
# Version "1" is needed when description is used
|
||||
version = "1" if self._params.description is not None else "2"
|
||||
|
||||
# Track the duration of this utterance based on the last timestamp
|
||||
utterance_duration = 0.0
|
||||
|
||||
async for chunk in self._client.tts.synthesize_json_streaming(
|
||||
utterances=[utterance],
|
||||
format=pcm_fmt,
|
||||
instant_mode=True,
|
||||
version=version,
|
||||
include_timestamp_types=["word"], # Request word-level timestamps
|
||||
):
|
||||
# Process audio chunks
|
||||
audio_b64 = getattr(chunk, "audio", None)
|
||||
if not audio_b64:
|
||||
continue
|
||||
if audio_b64:
|
||||
await self.stop_ttfb_metrics()
|
||||
pcm_bytes = base64.b64decode(audio_b64)
|
||||
self._audio_bytes += pcm_bytes
|
||||
|
||||
pcm_bytes = base64.b64decode(audio_b64)
|
||||
self._audio_bytes += pcm_bytes
|
||||
# Buffer audio until we have enough to avoid glitches
|
||||
if len(self._audio_bytes) >= self.chunk_size:
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=self._audio_bytes,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
self._audio_bytes = b""
|
||||
|
||||
# Buffer audio until we have enough to avoid glitches
|
||||
if len(self._audio_bytes) < self.chunk_size:
|
||||
continue
|
||||
# Process timestamp messages
|
||||
if isinstance(chunk, TimestampMessage):
|
||||
timestamp = chunk.timestamp
|
||||
if timestamp.type == "word":
|
||||
# Convert milliseconds to seconds and add cumulative offset
|
||||
word_start_time = self._cumulative_time + (timestamp.time.begin / 1000.0)
|
||||
word_end_time = self._cumulative_time + (timestamp.time.end / 1000.0)
|
||||
|
||||
# Track the maximum end time for this utterance
|
||||
utterance_duration = max(utterance_duration, word_end_time)
|
||||
|
||||
# Add word timestamp
|
||||
await self.add_word_timestamps([(timestamp.text, word_start_time)])
|
||||
|
||||
# Flush any remaining audio bytes
|
||||
if self._audio_bytes:
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=self._audio_bytes,
|
||||
sample_rate=self.sample_rate,
|
||||
@@ -224,10 +282,13 @@ class HumeTTSService(TTSService):
|
||||
|
||||
self._audio_bytes = b""
|
||||
|
||||
# Update cumulative time for next utterance
|
||||
if utterance_duration > 0:
|
||||
self._cumulative_time = utterance_duration
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
# Ensure TTFB timer is stopped even on early failures
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
# Let the parent class handle TTSStoppedFrame via push_stop_frames
|
||||
|
||||
@@ -250,15 +250,6 @@ class InworldTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Inworld TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Inworld's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Inworld TTS service.
|
||||
|
||||
@@ -374,7 +365,7 @@ class InworldTTSService(TTSService):
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Inworld API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Inworld API error: {error_text}"))
|
||||
yield ErrorFrame(error=f"Inworld API error: {error_text}")
|
||||
return
|
||||
|
||||
# ================================================================================
|
||||
@@ -401,8 +392,7 @@ class InworldTTSService(TTSService):
|
||||
# STEP 7: ERROR HANDLING
|
||||
# ================================================================================
|
||||
# Log any unexpected errors and notify the pipeline
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
# ================================================================================
|
||||
# STEP 8: CLEANUP AND COMPLETION
|
||||
@@ -517,7 +507,7 @@ class InworldTTSService(TTSService):
|
||||
# Extract the base64-encoded audio content from response
|
||||
if "audioContent" not in response_data:
|
||||
logger.error("No audioContent in Inworld API response")
|
||||
await self.push_error(ErrorFrame("No audioContent in response"))
|
||||
await self.push_error(ErrorFrame(error="No audioContent in response"))
|
||||
return
|
||||
|
||||
# ================================================================================
|
||||
|
||||
@@ -124,15 +124,6 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that LMNT TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that LMNT's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to LMNT service language format.
|
||||
|
||||
@@ -223,7 +214,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -239,7 +230,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
# await self._websocket.send(json.dumps({"eof": True}))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(error_msg=f"Error disconnecting from LMNT: {e}", exception=e)
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -273,10 +264,9 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
try:
|
||||
msg = json.loads(message)
|
||||
if "error" in msg:
|
||||
logger.error(f"{self} error: {msg['error']}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['error']}"))
|
||||
await self.push_error(error_msg=f"{self} error: {msg['error']}")
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
@@ -309,11 +299,11 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(json.dumps({"flush": True}))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -176,7 +176,6 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _stdio_list_tools(self) -> ToolsSchema:
|
||||
@@ -207,7 +206,6 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _streamable_http_list_tools(self) -> ToolsSchema:
|
||||
@@ -246,7 +244,6 @@ class MCPClient(BaseObject):
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception("Full exception details:")
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _call_tool(self, session, function_name, arguments, result_callback):
|
||||
@@ -302,7 +299,6 @@ class MCPClient(BaseObject):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read tool '{tool_name}': {str(e)}")
|
||||
logger.exception("Full exception details:")
|
||||
continue
|
||||
|
||||
logger.debug(f"Completed reading {len(tool_schemas)} tools")
|
||||
|
||||
@@ -253,8 +253,7 @@ class Mem0MemoryService(FrameProcessor):
|
||||
# Otherwise, pass the enhanced context frame downstream
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing with Mem0: {str(e)}")
|
||||
await self.push_frame(ErrorFrame(f"Error processing with Mem0: {str(e)}"))
|
||||
await self.push_error(exception=e)
|
||||
await self.push_frame(frame) # Still pass the original frame through
|
||||
else:
|
||||
# For non-context frames, just pass them through
|
||||
|
||||
@@ -194,15 +194,6 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that MiniMax TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that MiniMax's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to MiniMax service language format.
|
||||
|
||||
@@ -273,7 +264,6 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_message = f"MiniMax TTS error: HTTP {response.status}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
@@ -347,8 +337,7 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"MiniMax TTS error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -110,7 +110,6 @@ class MoondreamService(VisionService):
|
||||
if analysis fails.
|
||||
"""
|
||||
if not self._model:
|
||||
logger.error(f"{self} error: Moondream model not available ({self.model_name})")
|
||||
yield ErrorFrame("Moondream model not available")
|
||||
return
|
||||
|
||||
|
||||
@@ -151,15 +151,6 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Neuphonic TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Neuphonic's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Neuphonic service language format.
|
||||
|
||||
@@ -294,7 +285,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -307,7 +298,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Neuphonic")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -372,14 +363,14 @@ class NeuphonicTTSService(InterruptibleTTSService):
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class NeuphonicHttpTTSService(TTSService):
|
||||
@@ -445,15 +436,6 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Neuphonic TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Neuphonic's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Neuphonic service language format.
|
||||
|
||||
@@ -553,7 +535,6 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
error_text = await response.text()
|
||||
error_message = f"Neuphonic API error: HTTP {response.status} - {error_text}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
@@ -582,7 +563,7 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
yield TTSAudioRawFrame(audio_bytes, self.sample_rate, 1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SSE message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
# Don't yield error frame for individual message failures
|
||||
continue
|
||||
|
||||
@@ -590,8 +571,7 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
logger.debug("TTS generation cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=f"Neuphonic TTS error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -390,9 +390,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
frame = LLMTextFrame(chunk.choices[0].delta.content)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
|
||||
# we need to get LLMTextFrame for the transcript
|
||||
|
||||
@@ -76,7 +76,6 @@ class OpenAIImageGenService(ImageGenService):
|
||||
image_url = image.data[0].url
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} No image provided in response: {image}")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
|
||||
@@ -443,7 +443,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -460,7 +460,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
self._completed_tool_calls = set()
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
|
||||
async def _ws_send(self, realtime_message):
|
||||
try:
|
||||
@@ -473,12 +473,11 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# somehow *started* the websocket send attempt while we still
|
||||
# had a connection)
|
||||
return
|
||||
logger.error(f"Error sending message to websocket: {e}")
|
||||
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}", fatal=True))
|
||||
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
|
||||
|
||||
async def _update_settings(self):
|
||||
settings = self._session_properties
|
||||
@@ -667,9 +666,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
self._current_assistant_response = None
|
||||
# error handling
|
||||
if evt.response.status == "failed":
|
||||
await self.push_error(
|
||||
ErrorFrame(error=evt.response.status_details["error"]["message"], fatal=True)
|
||||
)
|
||||
await self.push_error(ErrorFrame(error=evt.response.status_details["error"]["message"]))
|
||||
return
|
||||
# response content
|
||||
for item in evt.response.output:
|
||||
@@ -680,8 +677,6 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
# the output modality is "text"
|
||||
if evt.delta:
|
||||
frame = LLMTextFrame(evt.delta)
|
||||
# OpenAI Realtime text already includes any necessary inter-chunk spaces
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
@@ -763,7 +758,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_error(self, evt):
|
||||
# Errors are fatal to this connection. Send an ErrorFrame.
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}", fatal=True))
|
||||
await self.push_error(error_msg=f"Error: {evt}")
|
||||
|
||||
#
|
||||
# state and client events for the current conversation
|
||||
|
||||
@@ -131,15 +131,6 @@ class OpenAITTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that OpenAI TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that OpenAI's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the TTS model to use.
|
||||
|
||||
@@ -199,7 +190,7 @@ class OpenAITTSService(TTSService):
|
||||
f"{self} error getting audio (status: {r.status_code}, error: {error})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {r.status_code}, error: {error})"
|
||||
error=f"Error getting audio (status: {r.status_code}, error: {error})"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -215,4 +206,4 @@ class OpenAITTSService(TTSService):
|
||||
yield frame
|
||||
yield TTSStoppedFrame()
|
||||
except BadRequestError as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -79,5 +79,5 @@ class AzureRealtimeBetaLLMService(OpenAIRealtimeBetaLLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
self._websocket = None
|
||||
|
||||
@@ -424,7 +424,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
)
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
@@ -440,7 +440,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
self._receive_task = None
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
|
||||
async def _ws_send(self, realtime_message):
|
||||
try:
|
||||
@@ -449,12 +449,11 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
except Exception as e:
|
||||
if self._disconnecting:
|
||||
return
|
||||
logger.error(f"Error sending message to websocket: {e}")
|
||||
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}", fatal=True))
|
||||
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
|
||||
|
||||
async def _update_settings(self):
|
||||
settings = self._session_properties
|
||||
@@ -627,9 +626,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
self._current_assistant_response = None
|
||||
# error handling
|
||||
if evt.response.status == "failed":
|
||||
await self.push_error(
|
||||
ErrorFrame(error=evt.response.status_details["error"]["message"], fatal=True)
|
||||
)
|
||||
await self.push_error(ErrorFrame(error=evt.response.status_details["error"]["message"]))
|
||||
return
|
||||
# response content
|
||||
for item in evt.response.output:
|
||||
@@ -687,7 +684,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_error(self, evt):
|
||||
# Errors are fatal to this connection. Send an ErrorFrame.
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}", fatal=True))
|
||||
await self.push_error(error_msg=f"Error: {evt}")
|
||||
|
||||
async def _handle_assistant_output(self, output):
|
||||
# We haven't seen intermixed audio and function_call items in the same response. But let's
|
||||
|
||||
@@ -66,15 +66,6 @@ class PiperTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Piper TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Piper's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Piper's HTTP API.
|
||||
@@ -97,11 +88,8 @@ class PiperTTSService(TTSService):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error = await response.text()
|
||||
logger.error(
|
||||
f"{self} error getting audio (status: {response.status}, error: {error})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {response.status}, error: {error})"
|
||||
error=f"Error getting audio (status: {response.status}, error: {error})"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -117,8 +105,8 @@ class PiperTTSService(TTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
logger.debug(f"{self}: Finished TTS [{text}]")
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -266,7 +266,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -279,7 +279,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from PlayHT")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
finally:
|
||||
self._request_id = None
|
||||
self._websocket = None
|
||||
@@ -349,8 +349,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
self._request_id = None
|
||||
elif "error" in msg:
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['error']}"))
|
||||
await self.push_error(error_msg=f"{self} error: {msg['error']}")
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
|
||||
@@ -392,7 +391,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
await self._get_websocket().send(json.dumps(tts_command))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
@@ -402,8 +401,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
yield ErrorFrame(f"{self} error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class PlayHTHttpTTSService(TTSService):
|
||||
@@ -623,7 +621,7 @@ class PlayHTHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -259,7 +259,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(error_msg=f"Error connecting: {e}", exception=e)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -271,7 +271,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
await self._websocket.send(json.dumps(self._build_eos_msg()))
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(error_msg=f"Error disconnecting: {e}", exception=e)
|
||||
finally:
|
||||
self._context_id = None
|
||||
self._websocket = None
|
||||
@@ -364,10 +364,9 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
logger.debug(f"Updated cumulative time to: {self._cumulative_time}")
|
||||
|
||||
elif msg["type"] == "error":
|
||||
logger.error(f"{self} error: {msg}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(ErrorFrame(f"{self} error: {msg['message']}"))
|
||||
await self.push_error(error_msg=f"{self} error: {msg['message']}")
|
||||
self._context_id = None
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
@@ -409,14 +408,14 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
await self._get_websocket().send(json.dumps(msg))
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class RimeHttpTTSService(TTSService):
|
||||
@@ -497,15 +496,6 @@ class RimeHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Rime TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Rime's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
"""Convert pipecat language to Rime language code.
|
||||
|
||||
@@ -556,7 +546,6 @@ class RimeHttpTTSService(TTSService):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_message = f"Rime TTS error: HTTP {response.status}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
@@ -574,8 +563,7 @@ class RimeHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"Rime TTS error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -655,12 +655,10 @@ class RivaSegmentedSTTService(SegmentedSTTService):
|
||||
logger.debug("No transcription results found in Riva response")
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Unexpected response structure from Riva: {ae}")
|
||||
yield ErrorFrame(f"Unexpected Riva response format: {str(ae)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Riva Canary ASR error: {e}")
|
||||
yield ErrorFrame(f"Riva Canary ASR error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
|
||||
class ParakeetSTTService(RivaSTTService):
|
||||
|
||||
@@ -23,6 +23,7 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -112,15 +113,6 @@ class RivaTTSService(TTSService):
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Riva TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Riva's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
|
||||
@@ -188,7 +180,7 @@ class RivaTTSService(TTSService):
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -176,9 +176,7 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
frame = LLMTextFrame(chunk.choices[0].delta.content)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
|
||||
# we need to get LLMTextFrame for the transcript
|
||||
|
||||
@@ -275,8 +275,7 @@ class SarvamSTTService(STTService):
|
||||
await self._socket_client.translate(**method_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending audio to Sarvam: {e}")
|
||||
await self.push_error(ErrorFrame(f"Failed to send audio: {e}"))
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
yield None
|
||||
|
||||
@@ -332,13 +331,11 @@ class SarvamSTTService(STTService):
|
||||
logger.info("Connected to Sarvam successfully")
|
||||
|
||||
except ApiError as e:
|
||||
logger.error(f"Sarvam API error: {e}")
|
||||
await self.push_error(ErrorFrame(f"Sarvam API error: {e}"))
|
||||
await self.push_error(error_msg=f"Sarvam API error: {e}", exception=e)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Sarvam: {e}")
|
||||
self._socket_client = None
|
||||
self._websocket_context = None
|
||||
await self.push_error(ErrorFrame(f"Failed to connect to Sarvam: {e}"))
|
||||
await self.push_error(error_msg=f"Failed to connect to Sarvam: {e}", exception=e)
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from Sarvam WebSocket API using SDK."""
|
||||
@@ -351,7 +348,9 @@ class SarvamSTTService(STTService):
|
||||
# Exit the async context manager
|
||||
await self._websocket_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing WebSocket connection: {e}")
|
||||
await self.push_error(
|
||||
error_msg=f"Error closing WebSocket connection: {e}", exception=e
|
||||
)
|
||||
finally:
|
||||
logger.debug("Disconnected from Sarvam WebSocket")
|
||||
self._socket_client = None
|
||||
@@ -371,8 +370,7 @@ class SarvamSTTService(STTService):
|
||||
# Messages will be handled via the _message_handler callback
|
||||
await self._socket_client.start_listening()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Sarvam receive task: {e}")
|
||||
await self.push_error(ErrorFrame(f"Sarvam receive task error: {e}"))
|
||||
await self.push_error(error_msg=f"Sarvam receive task error: {e}", exception=e)
|
||||
|
||||
async def _handle_message(self, message):
|
||||
"""Handle incoming WebSocket message from Sarvam SDK.
|
||||
@@ -427,8 +425,7 @@ class SarvamSTTService(STTService):
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling Sarvam message: {e}")
|
||||
await self.push_error(ErrorFrame(f"Failed to handle message: {e}"))
|
||||
await self.push_error(error_msg=f"Failed to handle message: {e}", exception=e)
|
||||
await self.stop_all_metrics()
|
||||
|
||||
@traced_stt
|
||||
|
||||
@@ -195,15 +195,6 @@ class SarvamHttpTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Sarvam TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Sarvam's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Sarvam AI language format.
|
||||
|
||||
@@ -263,8 +254,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Sarvam API error: {error_text}")
|
||||
await self.push_error(ErrorFrame(f"Sarvam API error: {error_text}"))
|
||||
yield ErrorFrame(error=f"Sarvam API error: {error_text}")
|
||||
return
|
||||
|
||||
response_data = await response.json()
|
||||
@@ -273,8 +263,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
|
||||
# Decode base64 audio data
|
||||
if "audios" not in response_data or not response_data["audios"]:
|
||||
logger.error("No audio data received from Sarvam API")
|
||||
await self.push_error(ErrorFrame("No audio data received"))
|
||||
yield ErrorFrame(error="No audio data received")
|
||||
return
|
||||
|
||||
# Get the first audio (there should be only one for single text input)
|
||||
@@ -295,8 +284,7 @@ class SarvamHttpTTSService(TTSService):
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
await self.push_error(ErrorFrame(f"Error generating TTS: {e}"))
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
@@ -467,15 +455,6 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Sarvam TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Sarvam's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Sarvam AI language format.
|
||||
|
||||
@@ -578,7 +557,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect: {e}")
|
||||
await self.push_error(exception=e)
|
||||
finally:
|
||||
# Reset state only after everything is cleaned up
|
||||
self._started = False
|
||||
@@ -602,7 +581,9 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
await self.push_error(
|
||||
error_msg=f"Error connecting to Sarvam TTS Websocket: {e}", exception=e
|
||||
)
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_connection_error", f"{e}")
|
||||
|
||||
@@ -618,8 +599,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._websocket.send(json.dumps(config_message))
|
||||
logger.debug("Configuration sent successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send config: {str(e)}")
|
||||
await self.push_frame(ErrorFrame(f"Failed to send config: {str(e)}"))
|
||||
await self.push_error(exception=e)
|
||||
raise
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
@@ -631,7 +611,7 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
logger.debug("Disconnecting from Sarvam")
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error closing websocket: {e}")
|
||||
await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
|
||||
finally:
|
||||
self._started = False
|
||||
self._websocket = None
|
||||
@@ -655,13 +635,13 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self.push_frame(frame)
|
||||
elif msg.get("type") == "error":
|
||||
error_msg = msg["data"]["message"]
|
||||
logger.error(f"TTS Error: {error_msg}")
|
||||
await self.push_error(error_msg=f"TTS Error: {error_msg}")
|
||||
|
||||
# If it's a timeout error, the connection might need to be reset
|
||||
if "too long" in error_msg.lower() or "timeout" in error_msg.lower():
|
||||
logger.warning("Connection timeout detected, service may need restart")
|
||||
|
||||
await self.push_frame(ErrorFrame(f"TTS Error: {error_msg}"))
|
||||
await self.push_frame(ErrorFrame(error=f"TTS Error: {error_msg}"))
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Handle keepalive messages to maintain WebSocket connection."""
|
||||
@@ -717,11 +697,11 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
await self._send_text(text)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -84,6 +84,10 @@ class SimliVideoService(FrameProcessor):
|
||||
Please use 'api_key' and 'face_id' parameters instead.
|
||||
|
||||
use_turn_server: Whether to use TURN server for connection. Defaults to False.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
The 'use_turn_server' parameter is deprecated and will be removed in a future version.
|
||||
|
||||
latency_interval: Latency interval setting for sending health checks to check
|
||||
the latency to Simli Servers. Defaults to 0.
|
||||
simli_url: URL of the simli servers. Can be changed for custom deployments
|
||||
@@ -135,14 +139,20 @@ class SimliVideoService(FrameProcessor):
|
||||
|
||||
config = SimliConfig(**config_kwargs)
|
||||
|
||||
if use_turn_server:
|
||||
warnings.warn(
|
||||
"The 'use_turn_server' parameter is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._initialized = False
|
||||
# Add buffer time to session limits
|
||||
config.maxIdleTime += 5
|
||||
config.maxSessionLength += 5
|
||||
self._simli_client = SimliClient(
|
||||
config,
|
||||
use_turn_server,
|
||||
latency_interval,
|
||||
config=config,
|
||||
latencyInterval=latency_interval,
|
||||
simliURL=simli_url,
|
||||
)
|
||||
|
||||
@@ -168,7 +178,7 @@ class SimliVideoService(FrameProcessor):
|
||||
self._audio_task = self.create_task(self._consume_and_process_audio())
|
||||
self._video_task = self.create_task(self._consume_and_process_video())
|
||||
except Exception as e:
|
||||
logger.error(f"{self}: unable to start connection: {e}")
|
||||
await self.push_error(error_msg=f"Unable to start connection: {e}", exception=e)
|
||||
|
||||
async def _consume_and_process_audio(self):
|
||||
"""Consume audio frames from Simli and push them downstream."""
|
||||
@@ -246,7 +256,7 @@ class SimliVideoService(FrameProcessor):
|
||||
await self._simli_client.send(audioBytes)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
await self.push_error(error_msg=f"Error sending audio: {e}", exception=e)
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
try:
|
||||
if self._previously_interrupted and len(self._audio_buffer) > 0:
|
||||
@@ -254,7 +264,7 @@ class SimliVideoService(FrameProcessor):
|
||||
self._previously_interrupted = False
|
||||
self._audio_buffer = bytearray()
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
await self.push_error(error_msg=f"Error stopping TTS: {e}", exception=e)
|
||||
return
|
||||
elif isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
|
||||
@@ -194,7 +194,7 @@ class SonioxSTTService(STTService):
|
||||
self._websocket = await websocket_connect(self._url)
|
||||
|
||||
if not self._websocket:
|
||||
logger.error(f"Unable to connect to Soniox API at {self._url}")
|
||||
await self.push_error(error_msg=f"Unable to connect to Soniox API at {self._url}")
|
||||
|
||||
# If vad_force_turn_endpoint is not enabled, we need to enable endpoint detection.
|
||||
# Either one or the other is required.
|
||||
@@ -327,8 +327,7 @@ class SonioxSTTService(STTService):
|
||||
# Expected when closing the connection
|
||||
logger.debug("WebSocket connection closed, keepalive task stopped.")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error (_keepalive_task_handler): {e}")
|
||||
await self.push_error(ErrorFrame(f"{self} error (_keepalive_task_handler): {e}"))
|
||||
await self.push_error(exception=e)
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
if not self._websocket:
|
||||
@@ -404,13 +403,8 @@ class SonioxSTTService(STTService):
|
||||
if error_code or error_message:
|
||||
# In case of error, still send the final transcript (if any remaining in the buffer).
|
||||
await send_endpoint_transcript()
|
||||
logger.error(
|
||||
f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
|
||||
)
|
||||
await self.push_error(
|
||||
ErrorFrame(
|
||||
f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
|
||||
)
|
||||
error_msg=f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
|
||||
)
|
||||
|
||||
finished = content.get("finished")
|
||||
@@ -425,5 +419,4 @@ class SonioxSTTService(STTService):
|
||||
# Expected when closing the connection.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error: {e}")
|
||||
await self.push_error(ErrorFrame(f"{self} error: {e}"))
|
||||
await self.push_error(error_msg=f"Error receiving message: {e}", exception=e)
|
||||
|
||||
@@ -467,8 +467,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
await self._client.send_audio(audio)
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"Speechmatics error: {e}")
|
||||
yield ErrorFrame(f"Speechmatics error: {e}", fatal=False)
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
await self._disconnect()
|
||||
|
||||
def update_params(
|
||||
@@ -514,6 +513,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
self._client.send_message(payload), self.get_event_loop()
|
||||
)
|
||||
except Exception as e:
|
||||
await self.push_error(exception=e)
|
||||
raise RuntimeError(f"error sending message to STT: {e}")
|
||||
|
||||
async def _connect(self) -> None:
|
||||
@@ -579,7 +579,7 @@ class SpeechmaticsSTTService(STTService):
|
||||
logger.debug(f"{self} Connected to Speechmatics STT service")
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Error connecting to Speechmatics: {e}")
|
||||
await self.push_error(error_msg=f"Error connecting to Speechmatics: {e}", exception=e)
|
||||
self._client = None
|
||||
|
||||
async def _disconnect(self) -> None:
|
||||
@@ -593,7 +593,9 @@ class SpeechmaticsSTTService(STTService):
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"{self} Timeout while closing Speechmatics client connection")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Error closing Speechmatics client: {e}")
|
||||
await self.push_error(
|
||||
error_msg=f"Error disconnecting from Speechmatics: {e}", exception=e
|
||||
)
|
||||
finally:
|
||||
self._client = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
@@ -105,15 +105,6 @@ class SpeechmaticsTTSService(TTSService):
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates that Speechmatics TTSTextFrames include necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
True, indicating that Speechmatics's text frames include necessary inter-frame spaces.
|
||||
"""
|
||||
return True
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Speechmatics' HTTP API.
|
||||
@@ -183,16 +174,13 @@ class SpeechmaticsTTSService(TTSService):
|
||||
|
||||
except (ValueError, ArithmeticError):
|
||||
yield ErrorFrame(
|
||||
error=f"{self} Service unavailable [503] (attempts {attempt})",
|
||||
fatal=True,
|
||||
error=f"{self} Service unavailable [503] (attempts {attempt})"
|
||||
)
|
||||
return
|
||||
|
||||
# != 200 : Error
|
||||
if response.status != 200:
|
||||
yield ErrorFrame(
|
||||
error=f"{self} Service unavailable [{response.status}]", fatal=True
|
||||
)
|
||||
yield ErrorFrame(error=f"{self} Service unavailable [{response.status}]")
|
||||
return
|
||||
|
||||
# Update Pipecat metrics
|
||||
@@ -234,7 +222,7 @@ class SpeechmaticsTTSService(TTSService):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"{self}: Error generating TTS: {e}", fatal=True)
|
||||
yield ErrorFrame(error=f"{self}: Error generating TTS: {e}")
|
||||
finally:
|
||||
# Emit the TTS stopped frame
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -329,4 +329,4 @@ class WebsocketSTTService(STTService, WebsocketService):
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
await self.push_error(error)
|
||||
await self.push_error_frame(error)
|
||||
|
||||
@@ -142,6 +142,7 @@ class TTSService(AIService):
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
self._aggregated_text_includes_inter_frame_spaces: bool = False
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
self._tracing_enabled: bool = False
|
||||
@@ -192,23 +193,6 @@ class TTSService(AIService):
|
||||
CHUNK_SECONDS = 0.5
|
||||
return int(self.sample_rate * CHUNK_SECONDS * 2) # 2 bytes/sample
|
||||
|
||||
@property
|
||||
def includes_inter_frame_spaces(self) -> bool:
|
||||
"""Indicates whether TTSTextFrames include necesary inter-frame spaces.
|
||||
|
||||
When True, the TTSTextFrame objects pushed by this service already
|
||||
include all necessary spaces between subsequent frames. When False,
|
||||
downstream processors (like the assistant context aggregator) may need
|
||||
to add spacing.
|
||||
|
||||
Subclasses should override this property to return True if their text
|
||||
generation process already includes necessary inter-frame spaces.
|
||||
|
||||
Returns:
|
||||
False by default. Subclasses can override to return True.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the TTS model to use.
|
||||
|
||||
@@ -369,9 +353,16 @@ class TTSService(AIService):
|
||||
await self._maybe_pause_frame_processing()
|
||||
|
||||
sentence = self._text_aggregator.text
|
||||
includes_inter_frame_spaces = self._aggregated_text_includes_inter_frame_spaces
|
||||
|
||||
# Reset aggregator state
|
||||
await self._text_aggregator.reset()
|
||||
self._processing_text = False
|
||||
await self._push_tts_frames(sentence)
|
||||
self._aggregated_text_includes_inter_frame_spaces = False
|
||||
|
||||
await self._push_tts_frames(
|
||||
sentence, includes_inter_frame_spaces=includes_inter_frame_spaces
|
||||
)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -380,7 +371,8 @@ class TTSService(AIService):
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
# Store if we were processing text or not so we can set it back.
|
||||
processing_text = self._processing_text
|
||||
await self._push_tts_frames(frame.text)
|
||||
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
|
||||
await self._push_tts_frames(frame.text, includes_inter_frame_spaces=False)
|
||||
# We pause processing incoming frames because we are sending data to
|
||||
# the TTS. We pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
@@ -474,11 +466,17 @@ class TTSService(AIService):
|
||||
text = frame.text
|
||||
else:
|
||||
text = await self._text_aggregator.aggregate(frame.text)
|
||||
# Assumption: whether inter-frame spaces are included shouldn't
|
||||
# change during aggregation, so we can just use the latest frame's
|
||||
# value
|
||||
self._aggregated_text_includes_inter_frame_spaces = frame.includes_inter_frame_spaces
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
await self._push_tts_frames(
|
||||
text, includes_inter_frame_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
async def _push_tts_frames(self, text: str, includes_inter_frame_spaces: bool):
|
||||
# Remove leading newlines only
|
||||
text = text.lstrip("\n")
|
||||
|
||||
@@ -508,7 +506,7 @@ class TTSService(AIService):
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
frame = TTSTextFrame(text)
|
||||
frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces
|
||||
frame.includes_inter_frame_spaces = includes_inter_frame_spaces
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _stop_frame_handler(self):
|
||||
@@ -635,6 +633,8 @@ class WordTTSService(TTSService):
|
||||
frame = TTSStoppedFrame()
|
||||
frame.pts = last_pts
|
||||
else:
|
||||
# Assumption: word-by-word text frames don't include spaces, so
|
||||
# we can rely on the default includes_inter_frame_spaces=False
|
||||
frame = TTSTextFrame(word)
|
||||
frame.pts = self._initial_word_timestamp + timestamp
|
||||
if frame:
|
||||
@@ -671,7 +671,7 @@ class WebsocketTTSService(TTSService, WebsocketService):
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
await self.push_error(error)
|
||||
await self.push_error_frame(error)
|
||||
|
||||
|
||||
class InterruptibleTTSService(WebsocketTTSService):
|
||||
@@ -733,7 +733,7 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
|
||||
|
||||
async def _report_error(self, error: ErrorFrame):
|
||||
await self._call_event_handler("on_connection_error", error.error)
|
||||
await self.push_error(error)
|
||||
await self.push_error_frame(error)
|
||||
|
||||
|
||||
class InterruptibleWordTTSService(WebsocketWordTTSService):
|
||||
|
||||
@@ -246,7 +246,7 @@ class UltravoxSTTService(AIService):
|
||||
|
||||
logger.info("Model warm-up completed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Model warm-up failed: {e}")
|
||||
await self.push_error(exception=e)
|
||||
|
||||
def _generate_silent_audio(self, sample_rate=16000, duration_sec=1.0):
|
||||
"""Generate silent audio as a numpy array.
|
||||
@@ -376,7 +376,7 @@ class UltravoxSTTService(AIService):
|
||||
if arr.size > 0: # Check if array is not empty
|
||||
audio_arrays.append(arr)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing bytes audio frame: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
# Handle numpy array data
|
||||
elif isinstance(f.audio, np.ndarray):
|
||||
if f.audio.size > 0: # Check if array is not empty
|
||||
@@ -436,17 +436,11 @@ class UltravoxSTTService(AIService):
|
||||
yield LLMFullResponseEndFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating text from model: {e}")
|
||||
yield ErrorFrame(f"Error generating text: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
else:
|
||||
logger.warning("No model available for text generation")
|
||||
yield ErrorFrame("No model available for text generation")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio buffer: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
yield ErrorFrame(f"Error processing audio: {str(e)}")
|
||||
finally:
|
||||
self._buffer.is_processing = False
|
||||
|
||||
@@ -36,6 +36,7 @@ class WebsocketService(ABC):
|
||||
"""
|
||||
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._reconnect_on_error = reconnect_on_error
|
||||
self._reconnect_in_progress: bool = False # Add this flag
|
||||
|
||||
async def _verify_connection(self) -> bool:
|
||||
"""Verify the websocket connection is active and responsive.
|
||||
@@ -66,6 +67,59 @@ class WebsocketService(ABC):
|
||||
await self._connect_websocket()
|
||||
return await self._verify_connection()
|
||||
|
||||
async def _try_reconnect(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
report_error: Optional[Callable[[ErrorFrame], Awaitable[None]]] = None,
|
||||
) -> bool:
|
||||
# Prevent concurrent reconnection attempts
|
||||
if self._reconnect_in_progress:
|
||||
logger.warning(f"{self} reconnect attempt aborted: already in progress")
|
||||
return False
|
||||
|
||||
self._reconnect_in_progress = True
|
||||
last_exception: Optional[Exception] = None
|
||||
try:
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
logger.warning(f"{self} reconnecting, attempt {attempt}")
|
||||
if await self._reconnect_websocket(attempt):
|
||||
logger.info(f"{self} reconnected successfully on attempt {attempt}")
|
||||
return True
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.error(f"{self} reconnection attempt {attempt} failed: {e}")
|
||||
if report_error:
|
||||
await report_error(
|
||||
ErrorFrame(f"{self} reconnection attempt {attempt} failed: {e}")
|
||||
)
|
||||
wait_time = exponential_backoff_time(attempt)
|
||||
await asyncio.sleep(wait_time)
|
||||
fatal_msg = f"{self} failed to reconnect after {max_retries} attempts"
|
||||
if last_exception:
|
||||
fatal_msg += f": {last_exception}"
|
||||
logger.error(fatal_msg)
|
||||
if report_error:
|
||||
await report_error(ErrorFrame(fatal_msg, fatal=True))
|
||||
return False
|
||||
finally:
|
||||
self._reconnect_in_progress = False
|
||||
|
||||
async def send_with_retry(self, message, report_error: Callable[[ErrorFrame], Awaitable[None]]):
|
||||
"""Attempt to send a message, retrying after reconnect if necessary."""
|
||||
try:
|
||||
await self._websocket.send(message)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} send failed: {e}, will try to reconnect")
|
||||
# Try to reconnect before retrying
|
||||
success = await self._try_reconnect(report_error=report_error)
|
||||
if success:
|
||||
logger.info(f"{self} reconnected successfully, will retry send the message")
|
||||
# trying to send the message one more time
|
||||
await self._websocket.send(message)
|
||||
else:
|
||||
logger.error(f"{self} send failed; unable to reconnect")
|
||||
|
||||
async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]):
|
||||
"""Handle websocket message receiving with automatic retry logic.
|
||||
|
||||
@@ -76,13 +130,9 @@ class WebsocketService(ABC):
|
||||
Args:
|
||||
report_error: Callback function to report connection errors.
|
||||
"""
|
||||
retry_count = 0
|
||||
MAX_RETRIES = 3
|
||||
|
||||
while True:
|
||||
try:
|
||||
await self._receive_messages()
|
||||
retry_count = 0 # Reset counter on successful message receive
|
||||
except ConnectionClosedOK as e:
|
||||
# Normal closure, don't retry
|
||||
logger.debug(f"{self} connection closed normally: {e}")
|
||||
@@ -92,21 +142,9 @@ class WebsocketService(ABC):
|
||||
logger.error(message)
|
||||
|
||||
if self._reconnect_on_error:
|
||||
retry_count += 1
|
||||
if retry_count >= MAX_RETRIES:
|
||||
await report_error(ErrorFrame(message, fatal=True))
|
||||
success = await self._try_reconnect(report_error=report_error)
|
||||
if not success:
|
||||
break
|
||||
|
||||
logger.warning(f"{self} connection error, will retry: {e}")
|
||||
await report_error(ErrorFrame(message))
|
||||
|
||||
try:
|
||||
if await self._reconnect_websocket(retry_count):
|
||||
retry_count = 0 # Reset counter on successful reconnection
|
||||
wait_time = exponential_backoff_time(retry_count)
|
||||
await asyncio.sleep(wait_time)
|
||||
except Exception as reconnect_error:
|
||||
logger.error(f"{self} reconnection failed: {reconnect_error}")
|
||||
else:
|
||||
await report_error(ErrorFrame(message))
|
||||
break
|
||||
|
||||
@@ -226,8 +226,7 @@ class BaseWhisperSTTService(SegmentedSTTService):
|
||||
logger.warning("Received empty transcription from API")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during transcription: {e}")
|
||||
yield ErrorFrame(f"Error during transcription: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
"""Transcribe audio data to text.
|
||||
|
||||
@@ -285,7 +285,6 @@ class WhisperSTTService(SegmentedSTTService):
|
||||
The service will normalize it to float32 in the range [-1, 1].
|
||||
"""
|
||||
if not self._model:
|
||||
logger.error(f"{self} error: Whisper model not available")
|
||||
yield ErrorFrame("Whisper model not available")
|
||||
return
|
||||
|
||||
@@ -428,5 +427,4 @@ class WhisperSTTServiceMLX(WhisperSTTService):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"MLX Whisper transcription error: {e}")
|
||||
yield ErrorFrame(f"MLX Whisper transcription error: {str(e)}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -141,13 +141,8 @@ class XTTSService(TTSService):
|
||||
async with self._aiohttp_session.get(self._settings["base_url"] + "/studio_speakers") as r:
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
logger.error(
|
||||
f"{self} error getting studio speakers (status: {r.status}, error: {text})"
|
||||
)
|
||||
await self.push_error(
|
||||
ErrorFrame(
|
||||
f"Error error getting studio speakers (status: {r.status}, error: {text})"
|
||||
)
|
||||
error_msg=f"{self} error getting studio speakers (status: {r.status}, error: {text})"
|
||||
)
|
||||
return
|
||||
self._studio_speakers = await r.json()
|
||||
@@ -186,8 +181,7 @@ class XTTSService(TTSService):
|
||||
async with self._aiohttp_session.post(url, json=payload) as r:
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
logger.error(f"{self} error getting audio (status: {r.status}, error: {text})")
|
||||
yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {text})")
|
||||
yield ErrorFrame(error=f"Error getting audio (status: {r.status}, error: {text})")
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
@@ -2506,13 +2506,10 @@ class DailyTransport(BaseTransport):
|
||||
async def _on_error(self, error):
|
||||
"""Handle error events and push error frames."""
|
||||
await self._call_event_handler("on_error", error)
|
||||
# Push error frame to notify the pipeline
|
||||
error_frame = ErrorFrame(error)
|
||||
|
||||
if self._input:
|
||||
await self._input.push_error(error_frame)
|
||||
await self._input.push_error(error_msg=error)
|
||||
elif self._output:
|
||||
await self._output.push_error(error_frame)
|
||||
await self._output.push_error(error_msg=error)
|
||||
else:
|
||||
logger.error("Both input and output are None while trying to push error")
|
||||
raise Exception("No valid input or output channel to push error")
|
||||
@@ -2568,7 +2565,7 @@ class DailyTransport(BaseTransport):
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout handling dialin-ready event ({url})")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling dialin-ready event ({url}): {e}")
|
||||
logger.error(f"Error handling dialin-ready event ({url}): {e}")
|
||||
|
||||
async def _on_dialin_connected(self, data):
|
||||
"""Handle dial-in connected events."""
|
||||
|
||||
@@ -316,7 +316,7 @@ class SmallWebRTCConnection(BaseObject):
|
||||
logger.debug("Client not connected. Queuing app-message.")
|
||||
self._pending_app_messages.append(json_message)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error parsing JSON message {message}, {e}")
|
||||
logger.error(f"Error parsing JSON message {message}, {e}")
|
||||
|
||||
# Despite the fact that aiortc provides this listener, they don't have a status for "disconnected"
|
||||
# So, in case we loose connection, this event will not be triggered
|
||||
|
||||
@@ -265,7 +265,7 @@ class TavusTransportClient:
|
||||
try:
|
||||
await self._client.cleanup()
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during cleanup: {e}")
|
||||
logger.error(f"Exception during cleanup: {e}")
|
||||
|
||||
async def _on_joined(self, data):
|
||||
"""Handle joined event."""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user