Compare commits
42 Commits
hush/firew
...
mb/db-rime
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
713d20e4fc | ||
|
|
adc45bd282 | ||
|
|
fc544fa61c | ||
|
|
976fe95304 | ||
|
|
408270b647 | ||
|
|
1dfb75bc9d | ||
|
|
cefc2a1088 | ||
|
|
3b9b9200ea | ||
|
|
d6f29a0f4b | ||
|
|
5b762d11ef | ||
|
|
2f3e2da6b9 | ||
|
|
45058d4a94 | ||
|
|
5b637bd826 | ||
|
|
2d4fd7e903 | ||
|
|
b5662520aa | ||
|
|
af45c170b5 | ||
|
|
65f548b2ec | ||
|
|
b29ab8c608 | ||
|
|
d6dc37f0b6 | ||
|
|
12bce2e8c0 | ||
|
|
4acf7296e0 | ||
|
|
98706d429c | ||
|
|
41720b1a13 | ||
|
|
3ef4245166 | ||
|
|
3bb0797922 | ||
|
|
7c7b4c52af | ||
|
|
01f083b7fc | ||
|
|
91fcaebe25 | ||
|
|
9c5fe5c85e | ||
|
|
7e5e167a4b | ||
|
|
d04c4b36f3 | ||
|
|
a811e53626 | ||
|
|
df57202a05 | ||
|
|
69e6f3fdb7 | ||
|
|
6809254963 | ||
|
|
81093d3bed | ||
|
|
d9a67164f6 | ||
|
|
d0f67fc189 | ||
|
|
6e3f96aa83 | ||
|
|
293677588d | ||
|
|
a5cdd5f1b8 | ||
|
|
5f937b8479 |
41
CHANGELOG.md
41
CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added new log observers `LLMLogObserver` and `TranscriptionLogObserver` that
|
||||
can be useful for debugging your pipelines.
|
||||
|
||||
- Added `room_url` property to `DailyTransport`.
|
||||
|
||||
- Added `addons` argument to `DeepgramSTTService`.
|
||||
@@ -17,6 +20,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- `AnthropicLLMService` now uses `claude-3-7-sonnet-20250219` as the default
|
||||
model.
|
||||
|
||||
- `RimeHttpTTSService` needs an `aiohttp.ClientSession` to be passed to the
|
||||
constructor as all the other HTTP-based services.
|
||||
|
||||
- `RimeHttpTTSService` doesn't use a default voice anymore.
|
||||
|
||||
- `DeepgramSTTService` now uses the new `nova-3` model by default. If you want
|
||||
to use the previous model you can pass `LiveOptions(model="nova-2-general")`.
|
||||
(see https://deepgram.com/learn/introducing-nova-3-speech-to-text-api)
|
||||
@@ -25,8 +36,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))
|
||||
```
|
||||
|
||||
### Removed
|
||||
|
||||
- Remove `TransportParams.audio_out_is_live` since it was not being used at all.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `ElevenLabsTTSService`, `FishAudioTTSService`, `LMNTTTSService` and
|
||||
`PlayHTTTSService` issue that was resulting in audio requested before an
|
||||
interruption being played after an interruption.
|
||||
|
||||
- Fixed `match_endofsentence` support for ellipses.
|
||||
|
||||
- Fixed an issue that would cause undesired interruptions via
|
||||
`EmulateUserStartedSpeakingFrame` when only interim transcriptions (i.e. no
|
||||
final transcriptions) where received.
|
||||
|
||||
- Fixed an issue where `EndTaskFrame` was not triggering
|
||||
`on_client_disconnected` or closing the WebSocket in FastAPI.
|
||||
|
||||
- Fixed an issue in `DeepgramSTTService` where the `sample_rate` passed to the
|
||||
`LiveOptions` was not being used, causing the service to use the default
|
||||
sample rate of pipeline.
|
||||
|
||||
- Fixed a context aggregator issue that would not append the LLM text response
|
||||
to the context if a function call happened in the same LLM turn.
|
||||
|
||||
- Fixed an issue that was causing HTTP TTS services to push `TTSStoppedFrame`
|
||||
more than once.
|
||||
|
||||
- Fixed a `FishAudioTTSService` issue where `TTSStoppedFrame` was not being
|
||||
pushed.
|
||||
|
||||
- Fixed an issue that `start_callback` was not invoked for some LLM services.
|
||||
|
||||
- Fixed an issue that would cause `DeepgramSTTService` to stop working after an
|
||||
|
||||
@@ -18,6 +18,9 @@ AZURE_DALLE_API_KEY=...
|
||||
AZURE_DALLE_ENDPOINT=https://...
|
||||
AZURE_DALLE_MODEL=...
|
||||
|
||||
# Cartesia
|
||||
CARTESIA_API_KEY=...
|
||||
|
||||
# Daily
|
||||
DAILY_API_KEY=...
|
||||
DAILY_SAMPLE_ROOM_URL=https://...
|
||||
|
||||
@@ -38,7 +38,6 @@ async def main():
|
||||
"GStreamer",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
audio_out_is_live=True,
|
||||
camera_out_enabled=True,
|
||||
camera_out_width=1280,
|
||||
camera_out_height=720,
|
||||
|
||||
@@ -18,12 +18,10 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartInterruptionFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.observers.loggers.llm_log_observer import LLMLogObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -73,38 +71,6 @@ class DebugObserver(BaseObserver):
|
||||
logger.info(f"🤖 BOT STOP SPEAKING: {src} {arrow} {dst} at {time_sec:.2f}s")
|
||||
|
||||
|
||||
class LLMLogObserver(BaseObserver):
|
||||
"""Observer to log LLM activity to the console.
|
||||
|
||||
Logs all frame instances of:
|
||||
- LLMFullResponseStartFrame (only from LLM service)
|
||||
- LLMTextFrame
|
||||
- LLMFullResponseEndFrame (only from LLM service)
|
||||
|
||||
This allows you to track when the LLM starts responding, what it generates, and when it finishes.
|
||||
Log format: [LLM EVENT]: [details] at [timestamp]s
|
||||
"""
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
time_sec = timestamp / 1_000_000_000
|
||||
|
||||
# Only log start/end frames from OpenAILLMService
|
||||
if isinstance(frame, (LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
|
||||
if isinstance(src, OpenAILLMService):
|
||||
event = "START" if isinstance(frame, LLMFullResponseStartFrame) else "END"
|
||||
logger.info(f"🧠 LLM {event} RESPONSE at {time_sec:.2f}s")
|
||||
# Log all LLMTextFrames
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
logger.info(f"🧠 LLM GENERATING: {frame.text!r} at {time_sec:.2f}s")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
@@ -33,7 +33,8 @@ dependencies = [
|
||||
"pydantic~=2.10.5",
|
||||
"pyloudnorm~=0.1.1",
|
||||
"resampy~=0.4.3",
|
||||
"soxr~=0.5.0"
|
||||
"soxr~=0.5.0",
|
||||
"openai~=1.59.6"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -44,11 +45,11 @@ Website = "https://pipecat.ai"
|
||||
anthropic = [ "anthropic~=0.45.2" ]
|
||||
assemblyai = [ "assemblyai~=0.36.0" ]
|
||||
aws = [ "boto3~=1.35.99" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0", "openai~=1.59.6" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
canonical = [ "aiofiles~=24.1.0" ]
|
||||
cartesia = [ "cartesia~=1.3.1", "websockets~=13.1" ]
|
||||
cerebras = [ "openai~=1.59.6" ]
|
||||
deepseek = [ "openai~=1.59.6" ]
|
||||
cerebras = []
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.14.2" ]
|
||||
deepgram = [ "deepgram-sdk~=3.8.0" ]
|
||||
elevenlabs = [ "websockets~=13.1" ]
|
||||
@@ -56,10 +57,10 @@ fal = [ "fal-client~=0.5.6" ]
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
|
||||
gladia = [ "websockets~=13.1" ]
|
||||
google = [ "google-cloud-speech~=2.31.0", "google-cloud-texttospeech~=2.25.0", "google-genai~=1.2.0", "google-generativeai~=0.8.4" ]
|
||||
grok = [ "openai~=1.59.6" ]
|
||||
groq = [ "openai~=1.59.6" ]
|
||||
grok = []
|
||||
groq = []
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
fireworks = [ "openai~=1.59.6" ]
|
||||
fireworks = []
|
||||
krisp = [ "pipecat-ai-krisp~=0.3.0" ]
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
langchain = [ "langchain~=0.3.14", "langchain-community~=0.3.14", "langchain-openai~=0.3.0" ]
|
||||
@@ -67,11 +68,11 @@ livekit = [ "livekit~=0.19.1", "livekit-api~=0.8.1", "tenacity~=9.0.0" ]
|
||||
lmnt = [ "websockets~=13.1" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
moondream = [ "einops~=0.8.0", "timm~=1.0.13", "transformers~=4.48.0" ]
|
||||
nim = [ "openai~=1.59.6" ]
|
||||
nim = []
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "openai~=1.59.6", "websockets~=13.1", "python-deepcompare~=2.1.0" ]
|
||||
openai = [ "websockets~=13.1" ]
|
||||
openpipe = [ "openpipe~=4.45.0" ]
|
||||
perplexity = [ "openai~=1.59.6" ]
|
||||
perplexity = []
|
||||
playht = [ "pyht~=0.1.6", "websockets~=13.1" ]
|
||||
rime = [ "websockets~=13.1" ]
|
||||
riva = [ "nvidia-riva-client~=2.18.0" ]
|
||||
@@ -79,10 +80,10 @@ sentry = [ "sentry-sdk~=2.20.0" ]
|
||||
silero = [ "onnxruntime~=1.20.1" ]
|
||||
simli = [ "simli-ai~=0.1.10"]
|
||||
soundfile = [ "soundfile~=0.13.0" ]
|
||||
together = [ "openai~=1.59.6" ]
|
||||
together = []
|
||||
websocket = [ "websockets~=13.1", "fastapi~=0.115.6" ]
|
||||
whisper = [ "faster-whisper~=1.1.1" ]
|
||||
openrouter = [ "openai~=1.59.6" ]
|
||||
openrouter = []
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
# All the following settings are optional:
|
||||
|
||||
0
src/pipecat/observers/loggers/__init__.py
Normal file
0
src/pipecat/observers/loggers/__init__.py
Normal file
85
src/pipecat/observers/loggers/llm_log_observer.py
Normal file
85
src/pipecat/observers/loggers/llm_log_observer.py
Normal file
@@ -0,0 +1,85 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.ai_services import LLMService
|
||||
|
||||
|
||||
class LLMLogObserver(BaseObserver):
|
||||
"""Observer to log LLM activity to the console.
|
||||
|
||||
Logs all frame instances (only from/to LLM service) of:
|
||||
|
||||
- LLMFullResponseStartFrame
|
||||
- LLMFullResponseEndFrame
|
||||
- LLMTextFrame
|
||||
- FunctionCallInProgressFrame
|
||||
- LLMMessagesFrame
|
||||
- OpenAILLMContextFrame
|
||||
|
||||
This allows you to track when the LLM starts responding, what it generates,
|
||||
and when it finishes.
|
||||
|
||||
"""
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
if not isinstance(src, LLMService) and not isinstance(dst, LLMService):
|
||||
return
|
||||
|
||||
time_sec = timestamp / 1_000_000_000
|
||||
|
||||
arrow = "→"
|
||||
|
||||
# Log LLM start/end frames (output)
|
||||
if isinstance(frame, (LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
|
||||
event = "START" if isinstance(frame, LLMFullResponseStartFrame) else "END"
|
||||
logger.debug(f"🧠 {src} {arrow} LLM {event} RESPONSE at {time_sec:.2f}s")
|
||||
# Log all LLMTextFrames (output)
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
logger.debug(f"🧠 {src} {arrow} LLM GENERATING: {frame.text!r} at {time_sec:.2f}s")
|
||||
# Log function calling (output)
|
||||
elif (
|
||||
isinstance(frame, FunctionCallInProgressFrame)
|
||||
and direction != FrameDirection.DOWNSTREAM
|
||||
):
|
||||
logger.debug(
|
||||
f"🧠 {src} {arrow} LLM FUNCTION CALL ({frame.tool_call_id}): {frame.function_name!r}({frame.arguments}) at {time_sec:.2f}s"
|
||||
)
|
||||
# Log LLMMessagesFrame (input)
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
logger.debug(
|
||||
f"🧠 {arrow} {dst} LLM MESSAGES FRAME: {frame.messages} at {time_sec:.2f}s"
|
||||
)
|
||||
# Log OpenAILLMContextFrame (input)
|
||||
elif isinstance(frame, OpenAILLMContextFrame):
|
||||
logger.debug(
|
||||
f"🧠 {arrow} {dst} LLM CONTEXT FRAME: {frame.context.messages} at {time_sec:.2f}s"
|
||||
)
|
||||
# Log function call result (input)
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
logger.debug(
|
||||
f"🧠 {arrow} {src} LLM FUNCTION CALL RESULT ({frame.tool_call_id}): {frame.result} at {time_sec:.2f}s"
|
||||
)
|
||||
54
src/pipecat/observers/loggers/transcription_log_observer.py
Normal file
54
src/pipecat/observers/loggers/transcription_log_observer.py
Normal file
@@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.ai_services import STTService
|
||||
|
||||
|
||||
class TranscriptionLogObserver(BaseObserver):
|
||||
"""Observer to log transcription activity to the console.
|
||||
|
||||
Logs all frame instances (only from STT service) of:
|
||||
|
||||
- TranscriptionFrame
|
||||
- InterimTranscriptionFrame
|
||||
|
||||
This allows you to track when the LLM starts responding, what it generates,
|
||||
and when it finishes.
|
||||
|
||||
"""
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
if not isinstance(src, STTService):
|
||||
return
|
||||
|
||||
time_sec = timestamp / 1_000_000_000
|
||||
|
||||
arrow = "→"
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
logger.debug(
|
||||
f"💬 {src} {arrow} TRANSCRIPTION: {frame.text!r} from {frame.user_id!r} at {time_sec:.2f}s"
|
||||
)
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
logger.debug(
|
||||
f"💬 {src} {arrow} INTERIM TRANSCRIPTION: {frame.text!r} from {frame.user_id!r} at {time_sec:.2f}s"
|
||||
)
|
||||
@@ -145,6 +145,9 @@ class LLMResponseAggregator(BaseLLMResponseAggregator):
|
||||
frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self.reset()
|
||||
|
||||
|
||||
class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
"""This is a base LLM aggregator that uses an LLM context to store the
|
||||
@@ -290,7 +293,13 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
await self.push_aggregation()
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
text = frame.text
|
||||
|
||||
# Make sure we really have some text.
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
self._aggregation += f" {text}" if self._aggregation else text
|
||||
# We just got a final result, so let's reset interim results.
|
||||
self._seen_interim_results = False
|
||||
# Reset aggregation timer.
|
||||
@@ -298,8 +307,6 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
async def _handle_interim_transcription(self, _: InterimTranscriptionFrame):
|
||||
self._seen_interim_results = True
|
||||
# Reset aggregation timer.
|
||||
self._aggregation_event.set()
|
||||
|
||||
def _create_aggregation_task(self):
|
||||
self._aggregation_task = self.create_task(self._aggregation_task_handler())
|
||||
|
||||
@@ -12,6 +12,12 @@ from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolChoiceOptionParam,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -22,20 +28,6 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
try:
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolChoiceOptionParam,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
# JSON custom encoder to handle bytes arrays so that we can log contexts
|
||||
# with images to the console.
|
||||
|
||||
|
||||
@@ -22,9 +22,8 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
class AudioBufferProcessor(FrameProcessor):
|
||||
"""This processor buffers audio raw frames (input and output). The mixed
|
||||
audio can be obtained by calling `get_audio()` (if `buffer_size` is 0) or by
|
||||
registering an "on_audio_data" event handler. The event handler will be
|
||||
called every time `buffer_size` is reached.
|
||||
audio can be obtained by registering an "on_audio_data" event handler.
|
||||
The event handler will be called every time `buffer_size` is reached.
|
||||
|
||||
You can provide the desired output `sample_rate` and incoming audio frames
|
||||
will resampled to match it. Also, you can provide the number of channels, 1
|
||||
|
||||
@@ -15,6 +15,7 @@ from loguru import logger
|
||||
from pipecat.audio.utils import calculate_audio_volume, exp_smoothing
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -40,6 +41,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
@@ -209,7 +211,7 @@ class TTSService(AIService):
|
||||
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
|
||||
push_stop_frames: bool = False,
|
||||
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
|
||||
stop_frame_timeout_s: float = 1.0,
|
||||
stop_frame_timeout_s: float = 2.0,
|
||||
# if True, TTSService will push silence audio frames after TTSStoppedFrame
|
||||
push_silence_after_stop: bool = False,
|
||||
# if push_silence_after_stop is True, send this amount of audio silence
|
||||
@@ -434,6 +436,12 @@ class TTSService(AIService):
|
||||
|
||||
|
||||
class WordTTSService(TTSService):
|
||||
"""This is a base class for TTS services that support word timestamps. Word
|
||||
timestamps are useful to synchronize audio with text of the spoken
|
||||
words. This way only the spoken words are added to the conversation context.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._initial_word_timestamp = -1
|
||||
@@ -503,11 +511,93 @@ class WordTTSService(TTSService):
|
||||
self._words_queue.task_done()
|
||||
|
||||
|
||||
class AudioContextWordTTSService(WordTTSService):
|
||||
"""This services allow us to send multiple TTS request to the services. Each
|
||||
request could be multiple sentences long which are grouped by context. For
|
||||
this to work, the TTS service needs to support handling multiple requests at
|
||||
once (i.e. multiple simultaneous contexts).
|
||||
class WebsocketTTSService(TTSService, WebsocketService):
|
||||
"""This is a base class for websocket-based TTS services."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
TTSService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self)
|
||||
|
||||
|
||||
class InterruptibleTTSService(WebsocketTTSService):
|
||||
"""This is a base class for websocket-based TTS services that don't support
|
||||
word timestamps and that don't offer a way to correlate the generated audio
|
||||
to the requested text.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Indicates if the bot is speaking. If the bot is not speaking we don't
|
||||
# need to reconnect when the user speaks. If the bot is speaking and the
|
||||
# user interrupts we need to reconnect.
|
||||
self._bot_speaking = False
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
if self._bot_speaking:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_speaking = True
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_speaking = False
|
||||
|
||||
|
||||
class WebsocketWordTTSService(WordTTSService, WebsocketService):
|
||||
"""This is a base class for websocket-based TTS services that support word
|
||||
timestamps.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
WordTTSService.__init__(self, **kwargs)
|
||||
WebsocketService.__init__(self)
|
||||
|
||||
|
||||
class InterruptibleWordTTSService(WebsocketWordTTSService):
|
||||
"""This is a base class for websocket-based TTS services that support word
|
||||
timestamps but don't offer a way to correlate the generated audio to the
|
||||
requested text.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Indicates if the bot is speaking. If the bot is not speaking we don't
|
||||
# need to reconnect when the user speaks. If the bot is speaking and the
|
||||
# user interrupts we need to reconnect.
|
||||
self._bot_speaking = False
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
if self._bot_speaking:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_speaking = True
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_speaking = False
|
||||
|
||||
|
||||
class AudioContextWordTTSService(WebsocketWordTTSService):
|
||||
"""This is a base class for websocket-based TTS services that support word
|
||||
timestamps and also allow correlating the generated audio with the requested
|
||||
text.
|
||||
|
||||
Each request could be multiple sentences long which are grouped by
|
||||
context. For this to work, the TTS service needs to support handling
|
||||
multiple requests at once (i.e. multiple simultaneous contexts).
|
||||
|
||||
The audio received from the TTS will be played in context order. That is, if
|
||||
we requested audio for a context "A" and then audio for context "B", the
|
||||
|
||||
@@ -96,7 +96,7 @@ class AnthropicLLMService(LLMService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
model: str = "claude-3-7-sonnet-20250219",
|
||||
params: InputParams = InputParams(),
|
||||
client=None,
|
||||
**kwargs,
|
||||
@@ -743,18 +743,19 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = False
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
aggregation = self._aggregation.strip()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._function_call_result:
|
||||
frame = self._function_call_result
|
||||
properties = frame.properties
|
||||
self._function_call_result = None
|
||||
if frame.result:
|
||||
assistant_message = {"role": "assistant", "content": []}
|
||||
if aggregation:
|
||||
assistant_message["content"].append({"type": "text", "text": aggregation})
|
||||
assistant_message["content"].append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
@@ -782,8 +783,6 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
else:
|
||||
# Default behavior
|
||||
run_llm = True
|
||||
elif aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._pending_image_frame_message:
|
||||
frame = self._pending_image_frame_message
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from openai import AsyncAzureOpenAI
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -48,7 +49,6 @@ try:
|
||||
PushAudioInputStream,
|
||||
)
|
||||
from azure.cognitiveservices.speech.dialog import AudioConfig
|
||||
from openai import AsyncAzureOpenAI
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
|
||||
@@ -7,22 +7,14 @@
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.audio import Transcription
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
|
||||
from pipecat.services.ai_services import SegmentedSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.audio import Transcription
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_whisper_language(language: Language) -> Optional[str]:
|
||||
"""Language support for Whisper API.
|
||||
|
||||
@@ -13,22 +13,18 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for Cartesia configuration needed
|
||||
@@ -75,7 +71,7 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
|
||||
return result
|
||||
|
||||
|
||||
class CartesiaTTSService(AudioContextWordTTSService, WebsocketService):
|
||||
class CartesiaTTSService(AudioContextWordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[Union[str, float]] = ""
|
||||
@@ -105,15 +101,13 @@ class CartesiaTTSService(AudioContextWordTTSService, WebsocketService):
|
||||
# if we're interrupted. Cartesia gives us word-by-word timestamps. We
|
||||
# can use those to generate text frames ourselves aligned with the
|
||||
# playout timing of the audio!
|
||||
AudioContextWordTTSService.__init__(
|
||||
self,
|
||||
super().__init__(
|
||||
aggregate_sentences=True,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
WebsocketService.__init__(self)
|
||||
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
@@ -364,9 +358,6 @@ class CartesiaHttpTTSService(TTSService):
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
try:
|
||||
voice_controls = None
|
||||
if self._settings["speed"] or self._settings["emotion"]:
|
||||
@@ -376,6 +367,8 @@ class CartesiaHttpTTSService(TTSService):
|
||||
if self._settings["emotion"]:
|
||||
voice_controls["emotion"] = self._settings["emotion"]
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
output = await self._client.tts.sse(
|
||||
model_id=self._model_name,
|
||||
transcript=text,
|
||||
@@ -386,14 +379,17 @@ class CartesiaHttpTTSService(TTSService):
|
||||
_experimental_voice_controls=voice_controls,
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=output["audio"], sample_rate=self.sample_rate, num_channels=1
|
||||
)
|
||||
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -7,22 +7,12 @@
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
try:
|
||||
from openai import (
|
||||
AsyncStream,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Cerebras, you need to `pip install pipecat-ai[cerebras]`. Also, set `CEREBRAS_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class CerebrasLLMService(OpenAILLMService):
|
||||
"""A service for interacting with Cerebras's API using the OpenAI-compatible interface.
|
||||
|
||||
@@ -124,6 +124,7 @@ class DeepgramSTTService(STTService):
|
||||
addons: Optional[Dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
sample_rate = sample_rate or (live_options.sample_rate if live_options else None)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
default_options = LiveOptions(
|
||||
|
||||
@@ -8,22 +8,12 @@
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
try:
|
||||
from openai import (
|
||||
AsyncStream,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use DeepSeek, you need to `pip install pipecat-ai[deepseek]`. Also, set `DEEPSEEK_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class DeepSeekLLMService(OpenAILLMService):
|
||||
"""A service for interacting with DeepSeek's API using the OpenAI-compatible interface.
|
||||
|
||||
@@ -14,22 +14,18 @@ from loguru import logger
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService, WordTTSService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.services.ai_services import InterruptibleWordTTSService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for ElevenLabs configuration needed
|
||||
@@ -141,7 +137,7 @@ def calculate_word_times(
|
||||
return word_times
|
||||
|
||||
|
||||
class ElevenLabsTTSService(WordTTSService, WebsocketService):
|
||||
class ElevenLabsTTSService(InterruptibleWordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = None
|
||||
optimize_streaming_latency: Optional[str] = None
|
||||
@@ -186,17 +182,14 @@ class ElevenLabsTTSService(WordTTSService, WebsocketService):
|
||||
# Finally, ElevenLabs doesn't provide information on when the bot stops
|
||||
# speaking for a while, so we want the parent class to send TTSStopFrame
|
||||
# after a short period not receiving any audio.
|
||||
WordTTSService.__init__(
|
||||
self,
|
||||
super().__init__(
|
||||
aggregate_sentences=True,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
stop_frame_timeout_s=2.0,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
WebsocketService.__init__(self)
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
@@ -567,18 +560,16 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
async for chunk in response.content:
|
||||
if chunk:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -8,19 +8,11 @@
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
try:
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Fireworks, you need to `pip install pipecat-ai[fireworks]`. Also, set `FIREWORKS_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class FireworksLLMService(OpenAILLMService):
|
||||
"""A service for interacting with Fireworks AI using the OpenAI-compatible interface.
|
||||
|
||||
@@ -11,22 +11,18 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.services.ai_services import InterruptibleTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
@@ -43,7 +39,7 @@ except ModuleNotFoundError as e:
|
||||
FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]
|
||||
|
||||
|
||||
class FishAudioTTSService(TTSService, WebsocketService):
|
||||
class FishAudioTTSService(InterruptibleTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
latency: Optional[str] = "normal" # "normal" or "balanced"
|
||||
@@ -60,7 +56,12 @@ class FishAudioTTSService(TTSService, WebsocketService):
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(pause_frame_processing=True, sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = "wss://api.fish.audio/v1/tts/live"
|
||||
@@ -108,11 +109,12 @@ class FishAudioTTSService(TTSService, WebsocketService):
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
await self._disconnect_websocket()
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
logger.debug("Connecting to Fish Audio")
|
||||
@@ -147,6 +149,11 @@ class FishAudioTTSService(TTSService, WebsocketService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
self._request_id = None
|
||||
|
||||
async def _receive_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
try:
|
||||
@@ -166,11 +173,6 @@ class FishAudioTTSService(TTSService, WebsocketService):
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
self._request_id = None
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating Fish TTS: [{text}]")
|
||||
try:
|
||||
|
||||
@@ -565,10 +565,15 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
run_llm = False
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
aggregation = self._aggregation.strip()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if aggregation:
|
||||
self._context.add_message(
|
||||
glm.Content(role="model", parts=[glm.Part(text=aggregation)])
|
||||
)
|
||||
|
||||
if self._function_call_result:
|
||||
frame = self._function_call_result
|
||||
properties = frame.properties
|
||||
@@ -608,11 +613,6 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
else:
|
||||
# Default behavior is to run the LLM if there are no function calls in progress
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
else:
|
||||
if aggregation.strip():
|
||||
self._context.add_message(
|
||||
glm.Content(role="model", parts=[glm.Part(text=aggregation)])
|
||||
)
|
||||
|
||||
if self._pending_image_frame_message:
|
||||
frame = self._pending_image_frame_message
|
||||
|
||||
@@ -37,10 +37,13 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
run_llm = False
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
aggregation = self._aggregation.strip()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._function_call_result:
|
||||
frame = self._function_call_result
|
||||
properties = frame.properties
|
||||
@@ -77,9 +80,6 @@ class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
# Default behavior is to run the LLM if there are no function calls in progress
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._pending_image_frame_message:
|
||||
frame = self._pending_image_frame_message
|
||||
self._pending_image_frame_message = None
|
||||
|
||||
@@ -21,8 +21,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.services.ai_services import InterruptibleTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for LMNT configuration needed
|
||||
@@ -60,7 +59,7 @@ def language_to_lmnt_language(language: Language) -> Optional[str]:
|
||||
return result
|
||||
|
||||
|
||||
class LmntTTSService(TTSService, WebsocketService):
|
||||
class LmntTTSService(InterruptibleTTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -70,14 +69,12 @@ class LmntTTSService(TTSService, WebsocketService):
|
||||
language: Language = Language.EN,
|
||||
**kwargs,
|
||||
):
|
||||
TTSService.__init__(
|
||||
self,
|
||||
super().__init__(
|
||||
push_stop_frames=True,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
WebsocketService.__init__(self)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
@@ -116,12 +113,12 @@ class LmntTTSService(TTSService, WebsocketService):
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to LMNT websocket."""
|
||||
try:
|
||||
@@ -153,8 +150,9 @@ class LmntTTSService(TTSService, WebsocketService):
|
||||
|
||||
if self._websocket:
|
||||
logger.debug("Disconnecting from LMNT")
|
||||
# Send EOF message before closing
|
||||
await self._websocket.send(json.dumps({"eof": True}))
|
||||
# NOTE(aleix): sending EOF message before closing is causing
|
||||
# errors on the websocket, so we just skip it for now.
|
||||
# await self._websocket.send(json.dumps({"eof": True}))
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
|
||||
@@ -13,6 +13,14 @@ from typing import Any, AsyncGenerator, Dict, List, Literal, Optional
|
||||
import aiohttp
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from openai import (
|
||||
NOT_GIVEN,
|
||||
AsyncOpenAI,
|
||||
AsyncStream,
|
||||
BadRequestError,
|
||||
DefaultAsyncHttpxClient,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -57,23 +65,6 @@ from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
from openai import (
|
||||
NOT_GIVEN,
|
||||
AsyncOpenAI,
|
||||
AsyncStream,
|
||||
BadRequestError,
|
||||
DefaultAsyncHttpxClient,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
|
||||
|
||||
VALID_VOICES: Dict[str, ValidVoice] = {
|
||||
@@ -631,10 +622,13 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = False
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
aggregation = self._aggregation
|
||||
aggregation = self._aggregation.strip()
|
||||
self.reset()
|
||||
|
||||
try:
|
||||
if aggregation:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._function_call_result:
|
||||
frame = self._function_call_result
|
||||
properties = frame.properties
|
||||
@@ -669,9 +663,6 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
# Default behavior is to run the LLM if there are no function calls in progress
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._pending_image_frame_message:
|
||||
frame = self._pending_image_frame_message
|
||||
self._pending_image_frame_message = None
|
||||
|
||||
@@ -10,9 +10,17 @@ import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
|
||||
@@ -7,12 +7,12 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
try:
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
from openpipe import AsyncOpenAI as OpenPipeAI
|
||||
from openpipe import AsyncStream
|
||||
except ModuleNotFoundError as e:
|
||||
|
||||
@@ -7,24 +7,13 @@
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
from openai import NOT_GIVEN, AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
try:
|
||||
from openai import (
|
||||
NOT_GIVEN,
|
||||
AsyncStream,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Perplexity, you need to `pip install pipecat-ai[perplexity]`. Also, set `PERPLEXITY_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class PerplexityLLMService(OpenAILLMService):
|
||||
"""A service for interacting with Perplexity's API.
|
||||
|
||||
@@ -16,22 +16,18 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.services.ai_services import InterruptibleTTSService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
@@ -100,7 +96,7 @@ def language_to_playht_language(language: Language) -> Optional[str]:
|
||||
return result
|
||||
|
||||
|
||||
class PlayHTTTSService(TTSService, WebsocketService):
|
||||
class PlayHTTTSService(InterruptibleTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[float] = 1.0
|
||||
@@ -118,13 +114,11 @@ class PlayHTTTSService(TTSService, WebsocketService):
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
TTSService.__init__(
|
||||
self,
|
||||
super().__init__(
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
WebsocketService.__init__(self)
|
||||
|
||||
self._api_key = api_key
|
||||
self._user_id = user_id
|
||||
@@ -168,12 +162,12 @@ class PlayHTTTSService(TTSService, WebsocketService):
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self.push_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
await self._disconnect_websocket()
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
logger.debug("Connecting to PlayHT")
|
||||
@@ -397,6 +391,7 @@ class PlayHTHttpTTSService(TTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
async for chunk in playht_gen:
|
||||
# skip the RIFF header.
|
||||
if in_header:
|
||||
@@ -416,6 +411,8 @@ class PlayHTHttpTTSService(TTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
yield TTSStoppedFrame()
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -14,22 +14,18 @@ from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
@@ -58,7 +54,7 @@ def language_to_rime_language(language: Language) -> str:
|
||||
return LANGUAGE_MAP.get(language, "eng")
|
||||
|
||||
|
||||
class RimeTTSService(AudioContextWordTTSService, WebsocketService):
|
||||
class RimeTTSService(AudioContextWordTTSService):
|
||||
"""Text-to-Speech service using Rime's websocket API.
|
||||
|
||||
Uses Rime's websocket JSON API to convert text to speech with word-level timing
|
||||
@@ -95,17 +91,14 @@ class RimeTTSService(AudioContextWordTTSService, WebsocketService):
|
||||
params: Additional configuration parameters.
|
||||
"""
|
||||
# Initialize with parent class settings for proper frame handling
|
||||
AudioContextWordTTSService.__init__(
|
||||
self,
|
||||
super().__init__(
|
||||
aggregate_sentences=True,
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
stop_frame_timeout_s=2.0,
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
WebsocketService.__init__(self)
|
||||
|
||||
# Store service configuration
|
||||
self._api_key = api_key
|
||||
@@ -176,11 +169,12 @@ class RimeTTSService(AudioContextWordTTSService, WebsocketService):
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Close websocket connection and clean up tasks."""
|
||||
await self._disconnect_websocket()
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to Rime websocket API with configured settings."""
|
||||
try:
|
||||
@@ -250,7 +244,9 @@ class RimeTTSService(AudioContextWordTTSService, WebsocketService):
|
||||
async def flush_audio(self):
|
||||
if not self._context_id or not self._websocket:
|
||||
return
|
||||
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
await self._get_websocket().send(json.dumps({"text": " "}))
|
||||
self._context_id = None
|
||||
|
||||
async def _receive_messages(self):
|
||||
@@ -349,7 +345,8 @@ class RimeHttpTTSService(TTSService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str = "eva",
|
||||
voice_id: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
model: str = "mistv2",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
@@ -358,6 +355,7 @@ class RimeHttpTTSService(TTSService):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._session = aiohttp_session
|
||||
self._base_url = "https://users.rime.ai/v1/rime-tts"
|
||||
self._settings = {
|
||||
"speedAlpha": params.speed_alpha,
|
||||
@@ -391,36 +389,31 @@ class RimeHttpTTSService(TTSService):
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
async with self._session.post(
|
||||
self._base_url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_message = f"Rime TTS error: HTTP {response.status}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self._base_url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_message = f"Rime TTS error: HTTP {response.status}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Process the streaming response
|
||||
chunk_size = 8192
|
||||
first_chunk = True
|
||||
yield TTSStartedFrame()
|
||||
|
||||
async for chunk in response.content.iter_chunked(chunk_size):
|
||||
if first_chunk:
|
||||
await self.stop_ttfb_metrics()
|
||||
first_chunk = False
|
||||
|
||||
if chunk:
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
# Process the streaming response
|
||||
chunk_size = 8192
|
||||
|
||||
async for chunk in response.content.iter_chunked(chunk_size):
|
||||
if chunk:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.exception(f"Error generating TTS: {e}")
|
||||
yield ErrorFrame(error=f"Rime TTS error: {str(e)}")
|
||||
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -90,14 +90,32 @@ class WebsocketService(ABC):
|
||||
logger.error(f"{self} reconnection failed: {reconnect_error}")
|
||||
continue
|
||||
|
||||
@abstractmethod
|
||||
async def _connect(self):
|
||||
"""Implement service-specific connection logic. This function will
|
||||
connect to the websocket via _connect_websocket() among other connection
|
||||
logic."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _disconnect(self):
|
||||
"""Implement service-specific disconnection logic. This function will
|
||||
disconnect to the websocket via _connect_websocket() among other
|
||||
connection logic.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _connect_websocket(self):
|
||||
"""Implement service-specific websocket connection logic."""
|
||||
"""Implement service-specific websocket connection logic. This function
|
||||
should only connect to the websocket."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _disconnect_websocket(self):
|
||||
"""Implement service-specific websocket disconnection logic."""
|
||||
"""Implement service-specific websocket disconnection logic. This
|
||||
function should only disconnect from the websocket."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -174,11 +174,9 @@ class BaseInputTransport(FrameProcessor):
|
||||
async def _vad_analyze(self, audio_frame: InputAudioRawFrame) -> VADState:
|
||||
state = VADState.QUIET
|
||||
if self.vad_analyzer:
|
||||
logger.trace(f"{self}: analyzing VAD on {audio_frame}")
|
||||
state = await self.get_event_loop().run_in_executor(
|
||||
self._executor, self.vad_analyzer.analyze_audio, audio_frame.audio
|
||||
)
|
||||
logger.trace(f"{self}: done analyzing VAD on {audio_frame}")
|
||||
return state
|
||||
|
||||
async def _handle_vad(self, audio_frame: InputAudioRawFrame, vad_state: VADState):
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
@@ -30,7 +29,6 @@ class TransportParams(BaseModel):
|
||||
camera_out_framerate: int = 30
|
||||
camera_out_color_format: str = "RGB"
|
||||
audio_out_enabled: bool = False
|
||||
audio_out_is_live: bool = False
|
||||
audio_out_sample_rate: Optional[int] = None
|
||||
audio_out_channels: int = 1
|
||||
audio_out_bitrate: int = 96000
|
||||
|
||||
@@ -55,45 +55,89 @@ class FastAPIWebsocketCallbacks(BaseModel):
|
||||
on_session_timeout: Callable[[WebSocket], Awaitable[None]]
|
||||
|
||||
|
||||
class FastAPIWebsocketClient:
|
||||
def __init__(self, websocket: WebSocket, is_binary: bool, callbacks: FastAPIWebsocketCallbacks):
|
||||
self._websocket = websocket
|
||||
self._closing = False
|
||||
self._is_binary = is_binary
|
||||
self._callbacks = callbacks
|
||||
|
||||
def receive(self) -> typing.AsyncIterator[bytes | str]:
|
||||
return self._websocket.iter_bytes() if self._is_binary else self._websocket.iter_text()
|
||||
|
||||
async def send(self, data: str | bytes):
|
||||
if self._can_send():
|
||||
if self._is_binary:
|
||||
await self._websocket.send_bytes(data)
|
||||
else:
|
||||
await self._websocket.send_text(data)
|
||||
|
||||
async def disconnect(self):
|
||||
if self.is_connected and not self.is_closing:
|
||||
self._closing = True
|
||||
await self._websocket.close()
|
||||
await self.trigger_client_disconnected()
|
||||
|
||||
async def trigger_client_disconnected(self):
|
||||
await self._callbacks.on_client_disconnected(self._websocket)
|
||||
|
||||
async def trigger_client_connected(self):
|
||||
await self._callbacks.on_client_connected(self._websocket)
|
||||
|
||||
async def trigger_client_timout(self):
|
||||
await self._callbacks.on_session_timeout(self._websocket)
|
||||
|
||||
def _can_send(self):
|
||||
return self.is_connected and not self.is_closing
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._websocket.client_state == WebSocketState.CONNECTED
|
||||
|
||||
@property
|
||||
def is_closing(self) -> bool:
|
||||
return self._closing
|
||||
|
||||
|
||||
class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
client: FastAPIWebsocketClient,
|
||||
params: FastAPIWebsocketParams,
|
||||
callbacks: FastAPIWebsocketCallbacks,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._websocket = websocket
|
||||
self._client = client
|
||||
self._params = params
|
||||
self._callbacks = callbacks
|
||||
self._receive_task = None
|
||||
self._monitor_websocket_task = None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._params.serializer.setup(frame)
|
||||
if self._params.session_timeout:
|
||||
self._monitor_websocket_task = self.create_task(self._monitor_websocket())
|
||||
await self._callbacks.on_client_connected(self._websocket)
|
||||
await self._client.trigger_client_connected()
|
||||
self._receive_task = self.create_task(self._receive_messages())
|
||||
|
||||
async def _stop_tasks(self):
|
||||
if self._monitor_websocket_task:
|
||||
await self.cancel_task(self._monitor_websocket_task)
|
||||
await self.cancel_task(self._receive_task)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self.cancel_task(self._receive_task)
|
||||
await self._stop_tasks()
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self.cancel_task(self._receive_task)
|
||||
|
||||
def _iter_data(self) -> typing.AsyncIterator[bytes | str]:
|
||||
if self._params.serializer.type == FrameSerializerType.BINARY:
|
||||
return self._websocket.iter_bytes()
|
||||
else:
|
||||
return self._websocket.iter_text()
|
||||
await self._stop_tasks()
|
||||
await self._client.disconnect()
|
||||
|
||||
async def _receive_messages(self):
|
||||
try:
|
||||
async for message in self._iter_data():
|
||||
async for message in self._client.receive():
|
||||
frame = await self._params.serializer.deserialize(message)
|
||||
|
||||
if not frame:
|
||||
@@ -106,19 +150,23 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
await self._callbacks.on_client_disconnected(self._websocket)
|
||||
await self._client.trigger_client_disconnected()
|
||||
|
||||
async def _monitor_websocket(self):
|
||||
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
|
||||
await asyncio.sleep(self._params.session_timeout)
|
||||
await self._callbacks.on_session_timeout(self._websocket)
|
||||
await self._client.trigger_client_timout()
|
||||
|
||||
|
||||
class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
client: FastAPIWebsocketClient,
|
||||
params: FastAPIWebsocketParams,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(params, **kwargs)
|
||||
|
||||
self._websocket = websocket
|
||||
self._client = client
|
||||
self._params = params
|
||||
|
||||
# write_raw_audio_frames() is called quickly, as soon as we get audio
|
||||
@@ -134,6 +182,14 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
await self._params.serializer.setup(frame)
|
||||
self._send_interval = (self._audio_chunk_size / self.sample_rate) / 2
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -145,7 +201,10 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
await self._write_frame(frame)
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
if self._websocket.client_state != WebSocketState.CONNECTED:
|
||||
if self._client.is_closing:
|
||||
return
|
||||
|
||||
if not self._client.is_connected:
|
||||
# Simulate audio playback with a sleep.
|
||||
await self._write_audio_sleep()
|
||||
return
|
||||
@@ -172,25 +231,17 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
|
||||
await self._write_frame(frame)
|
||||
|
||||
self._websocket_audio_buffer = bytes()
|
||||
|
||||
# Simulate audio playback with a sleep.
|
||||
await self._write_audio_sleep()
|
||||
|
||||
async def _write_frame(self, frame: Frame):
|
||||
try:
|
||||
payload = await self._params.serializer.serialize(frame)
|
||||
if payload and self._websocket.client_state == WebSocketState.CONNECTED:
|
||||
await self._send_data(payload)
|
||||
if payload:
|
||||
await self._client.send(payload)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})")
|
||||
|
||||
def _send_data(self, data: str | bytes):
|
||||
if self._params.serializer.type == FrameSerializerType.BINARY:
|
||||
return self._websocket.send_bytes(data)
|
||||
else:
|
||||
return self._websocket.send_text(data)
|
||||
|
||||
async def _write_audio_sleep(self):
|
||||
# Simulate a clock.
|
||||
current_time = time.monotonic()
|
||||
@@ -219,11 +270,14 @@ class FastAPIWebsocketTransport(BaseTransport):
|
||||
on_session_timeout=self._on_session_timeout,
|
||||
)
|
||||
|
||||
is_binary = self._params.serializer.type == FrameSerializerType.BINARY
|
||||
self._client = FastAPIWebsocketClient(websocket, is_binary, self._callbacks)
|
||||
|
||||
self._input = FastAPIWebsocketInputTransport(
|
||||
websocket, self._params, self._callbacks, name=self._input_name
|
||||
self._client, self._params, name=self._input_name
|
||||
)
|
||||
self._output = FastAPIWebsocketOutputTransport(
|
||||
websocket, self._params, name=self._output_name
|
||||
self._client, self._params, name=self._output_name
|
||||
)
|
||||
|
||||
# Register supported handlers. The user will only be able to register
|
||||
|
||||
@@ -13,8 +13,8 @@ ENDOFSENTENCE_PATTERN_STR = r"""
|
||||
(?<!Mr|Ms|Dr) # Negative lookbehind: not preceded by Mr, Ms, Dr (combined bc. length is the same)
|
||||
(?<!Mrs) # Negative lookbehind: not preceded by "Mrs"
|
||||
(?<!Prof) # Negative lookbehind: not preceded by "Prof"
|
||||
[\.\?\!;]| # Match a period, question mark, exclamation point, or semicolon
|
||||
[。?!;।] # the full-width version (mainly used in East Asian languages such as Chinese, Hindi)
|
||||
(\.\s*\.\s*\.|[\.\?\!;])| # Match a period, question mark, exclamation point, or semicolon
|
||||
(\。\s*\。\s*\。|[。?!;।]) # the full-width version (mainly used in East Asian languages such as Chinese, Hindi)
|
||||
$ # End of string
|
||||
"""
|
||||
ENDOFSENTENCE_PATTERN = re.compile(ENDOFSENTENCE_PATTERN_STR, re.VERBOSE)
|
||||
|
||||
@@ -2,6 +2,7 @@ aiohttp~=3.10.3
|
||||
anthropic~=0.30.0
|
||||
azure-cognitiveservices-speech~=1.40.0
|
||||
boto3~=1.35.27
|
||||
cartesia~=1.3.1
|
||||
daily-python~=0.11.0
|
||||
deepgram-sdk~=3.5.0
|
||||
fal-client~=0.4.1
|
||||
@@ -15,6 +16,7 @@ langchain~=0.2.14
|
||||
livekit~=0.13.1
|
||||
lmnt~=1.1.4
|
||||
loguru~=0.7.2
|
||||
Markdown~=3.7
|
||||
numpy~=1.26.4
|
||||
openai~=1.37.2
|
||||
openpipe~=4.24.0
|
||||
@@ -28,5 +30,4 @@ silero-vad~=5.1
|
||||
soxr~=0.5.0
|
||||
together~=1.2.7
|
||||
transformers~=4.48.0
|
||||
websockets~=13.1
|
||||
Markdown~=3.7
|
||||
websockets~=13.1
|
||||
@@ -11,10 +11,13 @@ from pipecat.utils.string import match_endofsentence
|
||||
|
||||
class TestUtilsString(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_endofsentence(self):
|
||||
assert match_endofsentence("This is a sentence.")
|
||||
assert match_endofsentence("This is a sentence! ")
|
||||
assert match_endofsentence("This is a sentence?")
|
||||
assert match_endofsentence("This is a sentence;")
|
||||
assert match_endofsentence("This is a sentence.") == 19
|
||||
assert match_endofsentence("This is a sentence!") == 19
|
||||
assert match_endofsentence("This is a sentence?") == 19
|
||||
assert match_endofsentence("This is a sentence;") == 19
|
||||
assert match_endofsentence("This is a sentence...") == 21
|
||||
assert match_endofsentence("This is a sentence . . .") == 24
|
||||
assert match_endofsentence("This is a sentence. ..") == 22
|
||||
assert not match_endofsentence("This is not a sentence")
|
||||
assert not match_endofsentence("This is not a sentence,")
|
||||
assert not match_endofsentence("This is not a sentence, ")
|
||||
|
||||
Reference in New Issue
Block a user