Compare commits

..

23 Commits

Author SHA1 Message Date
mattie ruth backman
e8640d84ae test fix now that we send an aggregated text frame for non word-by-word tts services 2025-11-14 17:13:08 -05:00
mattie ruth backman
23e4e29999 CHANGELOG fixes 2025-11-14 13:57:49 -05:00
mattie ruth backman
713b488bb6 Final PR Feedback changes 2025-11-14 13:54:20 -05:00
mattie ruth backman
71b87fd420 add transformers to initialization args 2025-11-14 13:54:20 -05:00
mattie ruth backman
3f269f9834 Add backwards compatibility for add_pattern_pair 2025-11-14 13:54:20 -05:00
mattie ruth backman
4c698777f3 PR Feedback 2025-11-14 13:54:20 -05:00
mattie ruth backman
5ca04ad741 CHANGELOG updates 2025-11-14 13:54:20 -05:00
mattie ruth backman
9a3902a82c Introducing a new processor: LLMTextProcessor
This new processor wraps an aggregator that can be overridden for the purposes
of customizing how the llm output gets categorized and handled in the pipeline.

Along with this, we are deprecating the ability to override the default
aggregator in the TTS to encourage use of the LLMTextProcessor in cases where
custome aggregation is needed.

This PR also:
- Introduces TTSService.transform_aggregation_type():
  This function provides the ability to provide callbacks to the TTS to
  transform text based on its aggregated type prior to sending the text to the
  underlying TTS service. This makes it possible to do things like introduce
  TTS-specific tags for spelling or emotion or change the pronunciation of
  something on the fly.
- Introduces to the RTVIObserver:
  - new init field skip_aggregator_types: A way to provide a list of aggregation
    types that should not be included in bot-output (or tts-text) messages
  - transform_aggregation_type(): Same as with TTSService, this allows you
    to provide a callback to transform text being sent as bot-output before
    it gets sent.
2025-11-14 13:54:20 -05:00
mattie ruth backman
8ab0c92681 Rename AggregatedLLMTextFrame to AggregatedTextFrame and made built-in types an enum 2025-11-14 13:54:20 -05:00
mattie ruth backman
124f147a37 CHANGELOG improvements 2025-11-14 13:54:18 -05:00
mattie ruth backman
ed808a9246 Fix new test and str version of PatternMatch 2025-11-14 13:53:23 -05:00
mattie ruth backman
e9de9daf8c Update PatternPairAggregator patterns to replace pattern_id with type to simplify the API 2025-11-14 13:53:23 -05:00
mattie ruth backman
82b9c4f0b6 various PR Review fixes:
1. Added support for turning off bot-output messages with the bot_output_enabled flag
2. Cleaned up logic and comments around TTSService:_push_tts_frames to hopefully make
   it easier to understand
3. Other minor cleanup
2025-11-14 13:53:23 -05:00
mattie ruth backman
5dfe20be91 Update Changelog 2025-11-14 13:53:22 -05:00
mattie ruth backman
0d2c5286fa Support customization over the way the assistant aggregator aggregates LLMTextFrames when tts_skip is on 2025-11-14 13:51:45 -05:00
mattie ruth backman
29417ba44d Move aggregation logic when skip_tts is on to the assistant aggregator 2025-11-14 13:51:45 -05:00
mattie ruth backman
bc6a9cac26 Add append_to_context boolean field to TextFrames
This allows any given TextFrame to be marked in a way such that it does not get
added to the context.

Specifically, this fixes a problem with the new AggregatedTextFrames where we
need to send LLM text both in an aggregated form as well as word-by-word but
avoid duplicating the text in the context.
2025-11-14 13:51:45 -05:00
mattie ruth backman
8a90decbc0 codepilot review fixes 2025-11-14 13:51:45 -05:00
mattie ruth backman
ccca6e8d81 Make the PatternPair action an Enum 2025-11-14 13:51:45 -05:00
mattie ruth backman
e6dc1a510d Introduce AggregatedLLMTextFrame to allow a separation of TTSTextFrame, indicating a spoken frame vs other aggregated, non-spoken frames 2025-11-14 13:51:45 -05:00
mattie ruth backman
69945c5e0d Various fixes:
1. Fixed pattern_pair_aggregator to support various ways of handling
   pattern matches (remove, keep and just trigger a callback, or
   aggregate
2. Fixed ivr_navigator use of pattern_pair_aggregator
3. Test fixes -- Tests now pass
2025-11-14 13:51:45 -05:00
mattie ruth backman
5c8635570d test fixes 2025-11-14 13:51:45 -05:00
mattie ruth backman
fe9aa3383e Adding support for new bot-output RTVI Message:
1. TTSTextFrames now include metadata about whether the text was spoken
   or not along with a type string to describe what the text represents:
   ex. "sentence", "word", "custom aggregation"
2. Expanded how aggregators work so that the aggregate method returns
   aggregated text along with the type of aggregation used to create it
3. Deprecated the RTVI bot-transcription event in lieu of...
4. Introduced support for a new bot-output event. This event is meant
   to be the one stop shop for communicating what the bot actually "says".
   It is based off TTSTextFrames to communicate both sentence by sentence
   (or whatever aggregation is used) as well as word by word. In addition,
   it will include LLMTextFrames, aggregated by sentence when tts is
   turned off (i.e. skip_tts is true).

Resolves pipecat-ai/pipecat-client-web#158
2025-11-14 13:51:45 -05:00
82 changed files with 1509 additions and 1429 deletions

View File

@@ -5,54 +5,98 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.0.95] - 2025-11-18
## [Unreleased]
### Added
- Added ai-coustics integrated VAD (`AICVADAnalyzer`) with `AICFilter` factory and
example wiring; leverages the enhancement model for robust detection with no
ONNX dependency or added processing complexity.
- Added a watchdog to `DeepgramFluxSTTService` to prevent dangling tasks in case the
user was speaking and we stop receiving audio.
- Introduced a minimum confidence parameter in `DeepgramFluxSTTService` to avoid
generating transcriptions below a defined threshold.
- Added `ElevenLabsRealtimeSTTService` which implements the Realtime STT
service from ElevenLabs.
- Added word-level timestamps support to Hume TTS service
- Added a `TTSService.includes_inter_frame_spaces` property getter, so that TTS
services that subclass `TTSService` can indicate whether the text in the
`TTSTextFrame`s they push already contain any necessary inter-frame spaces.
- Added optional speaking rate control to `InworldTTSService`.
- Introduced new `AggregatedTextFrame` type to support representing a best effort of
the perceived llm output whether or not it is processed by the TTS. This new frame
type includes the field `aggregated_by` to represent the conceptual format by which
the given text is aggregated. `TTSTextFrame`s now inherit from `AggregatedTextFrame`.
With this inheritance, an observer can watch for `AggregatedTextFrame`s to accumlate
the perceived output and determine whether or not the text was spoken based on if that
frame is also a `TTSTextFrame`. (See bullet below on new `bot-output` which takes
advantage of this)
- Introduced `LLMTextProcessor`: A new processor meant to allow customization for how
LLMTextFrames should be aggregated and considered. It's purpose is to turn
`LLMTextFrame`s into `AggregatedTextFrame`s. By default, a TTSService will still
aggregate `LLMTextFrame`s by sentence for the service to consume. However, if you
wish to override how the llm text is aggregated, you should no longer override the
TTS's internal aggregator, but instead, insert this processor between your LLM and
TTS in the pipeline.
- New `bot-output` RTVI message to represent what the bot actually "says".
- The `RTVIObserver` now emits `bot-output` messages based off the new `AggregatedTextFrame`s
(`bot-tts-text` and `bot-llm-text` are still supported and generated, but `bot-transcript` is
now deprecated in lieu of this new, more thorough, message).
- The new `RTVIBotOutputMessage` includes the fields:
- `spoken`: A boolean indicating whether the text was spoken by TTS
- `aggregated_by`: A string representing how the text was aggregated ("sentence", "word",
"my custom aggregation")
- Introduced new fields to `RTVIObserver` to support the new `bot-output` messaging:
- `bot_output_enabled`: Defaults to True. Set to false to disable bot-output messages.
- `skip_aggregator_types`: Defaults to `None`. Set to a list of strings that match
aggregation types that should not be included in bot-output messages. (Ex. `credit_card`)
- Introduced new methods, `add_text_transformer()` and `remove_text_transformer()`, to `RTVIObserver` to support providing (and subsequently removing)
callbacks for various types of aggregations (or all aggregations with `*`) that can modify the
text before being sent as a `bot-output` or `tts-text` message. (Think obscuring the credit card
or inserting extra detail the client might want that the context doesn't need.)
- Updated the base aggregator type:
- Introduced a new `Aggregation` dataclass to represent both the aggregated `text` and
a string identifying the `type` of aggregation (ex. "sentence", "word", "my custom
aggregation")
- **BREAKING**: `BaseTextAggregator.text` now returns an `Aggregation` (instead of `str`).
To update: `aggregated_text = myAggregator.text` -> `aggregated_text = myAggregator.text.text`
- **BREAKING**: `BaseTextAggregator.aggregate()` now returns `Optional[Aggregation]`
(instead of `Optional[str]`). To update:
```
aggregation = myAggregator.aggregate(text)
if (aggregation):
print(f"successfully aggregated text: {aggregation.text}") // instead of {aggregation}
```
- `SimpleTextAggregator`, `SkipTagsAggregator`, `PatternPairAggregator` updated to
produce/consume `Aggregation` objects.
- Augmented the `PatternPairAggregator`:
- Introduced a new, preferred version of `add_pattern` to support a new option for treating a
match as a separate aggregation returned from `aggregate()`. This replaces the now
deprecated `add_pattern_pair` method and you provide a `MatchAction` in lieu of the `remove_match` field.
- `MatchAction` enum: `REMOVE`, `KEEP`, `AGGREGATE`, allowing customization for how
a match should be handled.
- `REMOVE`: The text along with its delimiters will be removed from the streaming text.
Sentence aggregation will continue on as if this text did not exist.
- `KEEP`: The delimiters will be removed, but the content between them will be kept.
Sentence aggregation will continue on with the internal text included.
- `AGGREGATE`: The delimiters will be removed and the content between will be treated
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
will be returned. Then the aggregation will continue on sentence matching after
the closing delimiter is found. The content between the delimiters is not
aggregated by sentence. It is aggregated as one single block of text.
- `PatternMatch` now extends `Aggregation` and provides richer info to handlers.
- **BREAKING**: The `PatternMatch` type returned to handlers registered via `on_pattern_match`
has been updated to subclass from the new `Aggregation` type, which means that `content`
has been replaced with `text` and `pattern_id` has been replaced with `type`:
```
async dev on_match_tag(match: PatternMatch):
pattern = match.type # instead of match.pattern_id
text = match.text # instead of match.content
```
### Changed
- ⚠️ Breaking change: `LLMContext.create_image_message()`,
`LLMContext.create_audio_message()`, `LLMContext.add_image_frame_message()`
and `LLMContext.add_audio_frames_message()` are now async methods. This fixes
an issue where the asyncio event loop would be blocked while encoding audio or
images.
- `ConsumerProcessor` now queues frames from the producer internally instead of
pushing them directly. This allows us to subclass consumer processors and
manipulate frames before they are pushed.
- `BaseTextFilter` only require subclasses to implement the `filter()` method.
- Extracted the logic for retrying connections, and create a new `send_with_retry`
method inside `WebSocketService`.
- Refactored `DeepgramFluxSTTService` to automatically reconnect if sending a
message fails.
- Updated all STT and TTS services to use consistent error handling pattern with
`push_error()` method for better pipeline error event integration.
- Added support for `maybe_capture_participant_camera()` and
`maybe_capture_participant_screen()` for `SmallWebRTCTransport` in the runner
utils.
- Added Hindi support for Rime TTS services.
- Updated `GeminiTTSService` to use Google Cloud Text-to-Speech streaming API
@@ -65,18 +109,44 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated language mappings for the Google and Gemini TTS services to match
official documentation.
- `TextFrame` new field `append_to_context` used to indicate if the encompassing
text should be added to the LLM context (by the LLM assistant aggregator). It
defaults to `True`.
- TTS flow respects aggregation metadata
- `TTSService` accepts a new `skip_aggregator_types` to avoid speaking certain aggregation types
(now determined/returned by the aggregator)
- TTS services push `AggregatedTextFrame` in addition to `TTSTextFrame`s when either an
aggregation occurs that should not be spoken or when the TTS service supports word-by-word
timestamping. In the latter case, the `TTSService` preliminarily generates an
`AggregatedTextFrame`, aggregated by sentence to generate the full sentence content as early
as possible.
- Introduced a new methods, `add_text_transformer()` and `remove_text_transformer()`:
These functions introduce the ability to provide (and subsequently remove) callbacks to the TTS to transform text based on
its aggregated type prior to sending the text to the underlying TTS service. This makes it
possible to do things like introduce TTS-specific tags for spelling or emotion or change the
pronunciation of something on the fly.
### Deprecated
- The `api_key` parameter in `GeminiTTSService` is deprecated. Use
`credentials` or `credentials_path` instead for Google Cloud authentication.
- The RTVI `bot-transcription` event is deprecated in favor of the new `bot-output`
message which is the canonical representation of bot output (spoken or not). The code
still emits a transcription message for backwards compatibility while transition occurs.
- The TTS constructor field, `text_aggregator` is deprecated in favor of the new
`LLMTextProcessor`. TTSServices still have an internal aggregator for support of default
behavior, but if you want to override the aggregation behavior, you should use the new
processor.
- Deprecated `add_pattern_pair` in the `PatternPairAggregator` which takes a `pattern_id`
and `remove_match` field in favor of the new `add_pattern` method which takes a `type` and an
`action`
### Fixed
- Fixed a `SimliVideoService` connection issue.
- Fixed an issue in the `Runner` where, when using `SmallWebRTCTransport`, the
`request_data` was not being passed to the `SmallWebRTCRunnerArguments` body.
- Fixed subtle issue of assistant context messages ending up with double spaces
between words or sentences.
@@ -91,8 +161,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Prevented `HeyGenVideoService` from automatically disconnecting after 5 minutes.
- Fixed `InworldTTSService` audio config payload to use camelCase keys expected
by the Inworld API.
### Added
- Added ai-coustics integrated VAD (`AICVADAnalyzer`) with `AICFilter` factory and
example wiring; leverages the enhancement model for robust detection with no
ONNX dependency or added processing complexity.
## [0.0.94] - 2025-11-10

View File

@@ -13,29 +13,24 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame, TTSTextFrame
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.hume.tts import HUME_SAMPLE_RATE, HumeTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
@@ -93,7 +88,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
stt,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS (HumeTTSService with word timestamps)
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
@@ -107,14 +102,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
audio_out_sample_rate=HUME_SAMPLE_RATE,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
observers=[
RTVIObserver(rtvi),
DebugLogObserver(
frame_types={
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
}
),
],
observers=[RTVIObserver(rtvi)],
)
@rtvi.event_handler("on_client_ready")
@@ -124,9 +112,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
logger.info(
"💡 Word timestamps are enabled! Watch the console for TTSTextFrame logs showing each word with its PTS."
)
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])

View File

@@ -52,10 +52,7 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = DeepgramFluxSTTService(
api_key=os.getenv("DEEPGRAM_API_KEY"),
params=DeepgramFluxSTTService.InputParams(min_confidence=0.3),
)
stt = DeepgramFluxSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")

View File

@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Kick off the conversation.
image = Image.open(image_path)
message = await LLMContext.create_image_message(
message = LLMContext.create_image_message(
image=image.tobytes(),
format="RGB",
size=image.size,

View File

@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Kick off the conversation.
image = Image.open(image_path)
message = await LLMContext.create_image_message(
message = LLMContext.create_image_message(
image=image.tobytes(),
format="RGB",
size=image.size,

View File

@@ -117,7 +117,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Kick off the conversation.
image = Image.open(image_path)
message = await LLMContext.create_image_message(
message = LLMContext.create_image_message(
image=image.tobytes(),
format="RGB",
size=image.size,

View File

@@ -110,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# Kick off the conversation.
image = Image.open(image_path)
message = await LLMContext.create_image_message(
message = LLMContext.create_image_message(
image=image.tobytes(),
format="RGB",
size=image.size,

View File

@@ -15,21 +15,14 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
Frame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMRunFrame,
TextFrame,
UserImageRequestFrame,
)
from pipecat.frames.frames import LLMRunFrame, UserImageRequestFrame
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frame_processor import FrameDirection
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import (
create_transport,
@@ -73,27 +66,6 @@ async def fetch_user_image(params: FunctionCallParams):
# await params.result_callback({"result": "Image is being captured."})
class MoondreamTextFrameWrapper(FrameProcessor):
"""Wraps Moondream-provided TextFrames with LLM response start/end frames.
This processor detects TextFrames and automatically wraps them with
LLMFullResponseStartFrame and LLMFullResponseEndFrame to provide proper
response boundaries for downstream processors.
"""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# If we receive a TextFrame, wrap it with response start/end frames
if isinstance(frame, TextFrame):
await self.push_frame(LLMFullResponseStartFrame(), direction)
await self.push_frame(frame, direction)
await self.push_frame(LLMFullResponseEndFrame(), direction)
else:
# For all other frames, just pass them through
await self.push_frame(frame, direction)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
@@ -158,12 +130,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
# If you run into weird description, try with use_cpu=True
moondream = MoondreamService()
# Wrap TextFrames with LLM response start/end frames, which makes Moondream
# output be treated like LLM responses for the purpose of context
# aggregation. Without this, the assistant context aggregator would ignore
# Moondream output (if the TTS service is disabled).
moondream_text_wrapper = MoondreamTextFrameWrapper()
pipeline = Pipeline(
[
transport.input(), # Transport user input
@@ -171,7 +137,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
context_aggregator.user(), # User responses
ParallelPipeline(
[llm], # LLM
[moondream, moondream_text_wrapper],
[moondream],
),
tts, # TTS
transport.output(), # Transport bot output

View File

@@ -391,7 +391,7 @@ class AudioAccumulator(FrameProcessor):
)
self._user_speaking = False
context = LLMContext()
await context.add_audio_frames_message(audio_frames=self._audio_frames)
context.add_audio_frames_message(audio_frames=self._audio_frames)
await self.push_frame(LLMContextFrame(context=context))
elif isinstance(frame, InputAudioRawFrame):
# Append the audio frame to our buffer. Treat the buffer as a ring buffer, dropping the oldest

View File

@@ -150,7 +150,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
LLMLogObserver(),
DebugLogObserver(
frame_types={
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
TTSTextFrame: (BaseOutputTransport, FrameEndpoint.DESTINATION),
UserStartedSpeakingFrame: (BaseInputTransport, FrameEndpoint.SOURCE),
EndFrame: None,
}

View File

@@ -62,7 +62,11 @@ from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
load_dotenv(override=True)
@@ -106,16 +110,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
pattern_aggregator = PatternPairAggregator()
# Add pattern for voice switching
pattern_aggregator.add_pattern_pair(
pattern_id="voice_tag",
pattern_aggregator.add_pattern(
type="voice",
start_pattern="<voice>",
end_pattern="</voice>",
remove_match=True,
action=MatchAction.REMOVE, # Remove tags from final text
)
# Register handler for voice switching
async def on_voice_tag(match: PatternMatch):
voice_name = match.content.strip().lower()
voice_name = match.text.strip().lower()
if voice_name in VOICE_IDS:
# First flush any existing audio to finish the current context
await tts.flush_audio()
@@ -125,7 +129,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
else:
logger.warning(f"Unknown voice: {voice_name}")
pattern_aggregator.on_pattern_match("voice_tag", on_voice_tag)
pattern_aggregator.on_pattern_match("voice", on_voice_tag)
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

View File

@@ -155,7 +155,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
You are a helpful LLM in a WebRTC call.
Your goal is to demonstrate your capabilities in a succinct way.
You have access to tools to search the Rijksmuseum collection.
Offer, for example, to show a floral still life, use the `search_artwork` tool.
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.

View File

@@ -4,11 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
from dotenv import load_dotenv
from loguru import logger
from turn_detector_observer import TurnDetectorObserver
from mcp.client.session_group import SseServerParameters
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
@@ -20,12 +21,12 @@ from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import OpenAITTSService
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.mcp_service import MCPClient
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -60,55 +61,59 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
### STT ###
stt = OpenAISTTService(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o-transcribe",
prompt="Expect normal helpful conversation.",
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
### LLM ###
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
### TTS ###
tts = OpenAITTSService(
api_key=os.getenv("OPENAI_API_KEY"),
voice="ballad",
params=OpenAITTSService.InputParams(
instructions="Please speak clearly and at a moderate pace."
),
llm = AnthropicLLMService(
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-7-sonnet-latest"
)
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
try:
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
mcp = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
except Exception as e:
logger.error(f"error setting up mcp")
logger.exception("error trace:")
context = LLMContext(messages)
tools = {}
try:
tools = await mcp.register_tools(llm)
except Exception as e:
logger.error(f"error registering tools")
logger.exception("error trace:")
system = f"""
You are a helpful LLM in a WebRTC call.
Your goal is to demonstrate your capabilities in a succinct way.
You have access to a number of tools provided by mcp.run. Use any and all tools to help users.
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
Respond to what the user said in a creative and helpful way.
When asked for today's date, use 'https://www.datetoday.net/'.
Don't overexplain what you are doing.
Just respond with short sentences when you are carrying out tool calls.
"""
messages = [{"role": "system", "content": system}]
context = LLMContext(messages, tools)
context_aggregator = LLMContextAggregatorPair(context)
# RTVI events for detecting bot aggregation
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
### PIPELINE ###
pipeline = Pipeline(
[
transport.input(), # Transport user input
rtvi,
stt,
context_aggregator.user(), # User responses
context_aggregator.user(), # User spoken responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
context_aggregator.assistant(), # Assistant spoken responses and tool context
]
)
### TASK ###
turn_detector = TurnDetectorObserver()
task = PipelineTask(
pipeline,
params=PipelineParams(
@@ -116,16 +121,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
observers=[turn_detector, RTVIObserver(rtvi)],
)
turn_detector.set_turn_observer_event_handlers(task.turn_tracking_observer)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
logger.info(f"Client connected: {client}")
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
@@ -133,11 +134,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Client disconnected")
await task.cancel()
# not sure this is needed...
@rtvi.event_handler("on_client_ready")
async def on_client_ready(rtvi):
await rtvi.set_bot_ready()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
@@ -150,6 +146,14 @@ async def bot(runner_args: RunnerArguments):
if __name__ == "__main__":
if not os.getenv("MCP_RUN_SSE_URL"):
logger.error(
f"Please set MCP_RUN_SSE_URL environment variable for this example. See https://mcp.run"
)
import sys
sys.exit(1)
from pipecat.runner.run import main
main()

View File

@@ -7,7 +7,6 @@
import asyncio
import io
import json
import os
import re
import shutil
@@ -16,7 +15,7 @@ import aiohttp
from dotenv import load_dotenv
from loguru import logger
from mcp import StdioServerParameters
from mcp.client.session_group import StreamableHttpParameters
from mcp.client.session_group import SseServerParameters
from PIL import Image
from pipecat.adapters.schemas.tools_schema import ToolsSchema
@@ -67,12 +66,10 @@ class UrlToImageProcessor(FrameProcessor):
await self.push_frame(frame, direction)
def extract_url(self, text: str):
data = json.loads(text)
if "artObject" in data:
return data["artObject"]["webImage"]["url"]
if "artworks" in data and len(data["artworks"]):
return data["artworks"][0]["webImage"]["url"]
pattern = r"!\[[^\]]*\]\((https?://[^)]+\.(png|jpg|jpeg|PNG|JPG|JPEG|gif))\)"
match = re.search(pattern, text)
if match:
return match.group(1)
return None
async def run_image_process(self, image_url: str):
@@ -135,11 +132,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
system = f"""
You are a helpful LLM in a WebRTC call.
Your goal is to demonstrate your capabilities in a succinct way.
You have access to tools to search the Rijksmuseum collection and the user's GitHub repositories and account.
Offer, for example, to show a floral still life, use the `search_artwork` tool.
You have access to tools to search the Rijksmuseum collection.
Offer, for example, to show the earliest Rembrandt work from the museum. Use the `search_artwork` tool.
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
You can also offer to answer users questions about their GitHub repositories and account.
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
Respond to what the user said in a creative and helpful way.
Don't overexplain what you are doing.
@@ -149,11 +145,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
messages = [{"role": "system", "content": system}]
try:
rijksmuseum_mcp = MCPClient(
mcp = MCPClient(
server_params=StdioServerParameters(
command=shutil.which("npx"),
# https://github.com/r-huijts/rijksmuseum-mcp
args=["-y", "mcp-server-rijksmuseum"],
args=["-y", "mcp-server-error setting up mcp"],
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
)
)
@@ -161,32 +157,24 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.error(f"error setting up rijksmuseum mcp")
logger.exception("error trace:")
try:
# Github MCP docs: https://github.com/github/github-mcp-server
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
github_mcp = MCPClient(
server_params=StreamableHttpParameters(
url="https://api.githubcopilot.com/mcp/",
headers={
"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"
},
)
)
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
# ie. "https://www.mcp.run/api/mcp/sse?..."
# ensure the profile has a tool or few installed
mcp_run = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
except Exception as e:
logger.error(f"error setting up mcp.run")
logger.exception("error trace:")
rijksmuseum_tools = {}
github_tools = {}
tools = {}
run_tools = {}
try:
rijksmuseum_tools = await rijksmuseum_mcp.register_tools(llm)
github_tools = await github_mcp.register_tools(llm)
tools = await mcp.register_tools(llm)
run_tools = await mcp_run.register_tools(llm)
except Exception as e:
logger.error(f"error registering tools")
logger.exception("error trace:")
all_standard_tools = rijksmuseum_tools.standard_tools + github_tools.standard_tools
all_standard_tools = run_tools.standard_tools + tools.standard_tools
all_tools = ToolsSchema(standard_tools=all_standard_tools)
context = LLMContext(messages, all_tools)
@@ -238,9 +226,9 @@ async def bot(runner_args: RunnerArguments):
if __name__ == "__main__":
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("GITHUB_PERSONAL_ACCESS_TOKEN"):
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("MCP_RUN_SSE_URL"):
logger.error(
f"Please set `RIJKSMUSEUM_API_KEY` and `GITHUB_PERSONAL_ACCESS_TOKEN` environment variables. See https://github.com/r-huijts/rijksmuseum-mcp."
f"Please set RIJKSMUSEUM_API_KEY and MCP_RUN_SSE_URL environment variables. See https://github.com/r-huijts/rijksmuseum-mcp and https://mcp.run"
)
import sys

View File

@@ -1,161 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
from dotenv import load_dotenv
from loguru import logger
from turn_detector_observer import TurnDetectorObserver
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import OpenAITTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
session_properties = SessionProperties(
audio=AudioConfiguration(
input=AudioInput(
transcription=InputAudioTranscription(),
# Set openai TurnDetection parameters. Not setting this at all will turn it
# on by default
turn_detection=SemanticTurnDetection(),
# Or set to False to disable openai turn detection and use transport VAD
# turn_detection=False,
noise_reduction=InputAudioNoiseReduction(type="near_field"),
)
),
# In this example we provide tools through the context, but you could
# alternatively provide them here.
# tools=tools,
instructions="""You are a helpful and friendly AI.
Act like a human, but remember that you aren't a human and that you can't do human
things in the real world. Your voice and personality should be warm and engaging, with a lively and
playful tone.
If interacting in a non-English language, start by using the standard accent or dialect familiar to
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
even if you're asked about them.
You are participating in a voice conversation. Keep your responses concise, short, and to the point
unless specifically asked to elaborate on a topic.
Remember, your responses should be short. Just one or two sentences, usually. Respond in English.""",
)
### LLM ###
llm = OpenAIRealtimeLLMService(
api_key=os.getenv("OPENAI_API_KEY"),
session_properties=session_properties,
start_audio_paused=False,
)
# Create a standard OpenAI LLM context object using the normal messages format. The
# OpenAIRealtimeLLMService will convert this internally to messages that the
# openai WebSocket API can understand.
context = LLMContext(
[{"role": "user", "content": "Say hello!"}],
tools,
)
context_aggregator = LLMContextAggregatorPair(context)
### PIPELINE ###
pipeline = Pipeline(
[
transport.input(), # Transport user input
context_aggregator.user(),
llm, # LLM
transport.output(), # Transport bot output
context_aggregator.assistant(),
]
)
### TASK ###
turn_detector = TurnDetectorObserver()
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
observers=[turn_detector],
)
turn_detector.set_turn_observer_event_handlers(task.turn_tracking_observer)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -1,188 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
from dotenv import load_dotenv
from loguru import logger
from turn_detector_observer import TurnDetectorObserver
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMRunFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import OpenAITTSService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
async def fetch_weather_from_api(params: FunctionCallParams):
await params.result_callback({"conditions": "nice", "temperature": "75"})
async def fetch_restaurant_recommendation(params: FunctionCallParams):
await params.result_callback({"name": "The Golden Dragon"})
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
),
}
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
### STT ###
stt = OpenAISTTService(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o-transcribe",
prompt="Expect normal helpful conversation.",
)
### LLM ###
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
### TTS ###
tts = OpenAITTSService(
api_key=os.getenv("OPENAI_API_KEY"),
voice="ballad",
params=OpenAITTSService.InputParams(instructions="Please speak clearly and at a moderate pace."),
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_function("get_current_weather", fetch_weather_from_api)
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
@llm.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls):
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
required=["location", "format"],
)
restaurant_function = FunctionSchema(
name="get_restaurant_recommendation",
description="Get a restaurant recommendation",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
required=["location"],
)
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = LLMContext(messages, tools)
context_aggregator = LLMContextAggregatorPair(context)
### PIPELINE ###
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
### TASK ###
turn_detector = TurnDetectorObserver()
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
observers=[turn_detector],
)
turn_detector.set_turn_observer_event_handlers(task.turn_tracking_observer)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
await task.queue_frames([LLMRunFrame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
await runner.run(task)
async def bot(runner_args: RunnerArguments):
"""Main bot entry point compatible with Pipecat Cloud."""
transport = await create_transport(runner_args, transport_params)
await run_bot(transport, runner_args)
if __name__ == "__main__":
from pipecat.runner.run import main
main()

View File

@@ -1,11 +0,0 @@
```bash
uv sync
uv pip install -e '.[cartesia,daily,elevenlabs,local-smart-turn-v3,openai,runner,webrtc]'
```
```bash
python examples/foundational/trace/001-trace.py
```
- open [http://localhost:7860](http://localhost:7860)
- click `connect` button in top right

View File

@@ -1,5 +0,0 @@
OPENAI_API_KEY=...
ELEVENLABS_API_KEY=...
ELEVENLABS_VOICE_ID=...
CARTESIA_API_KEY=...

View File

@@ -1,181 +0,0 @@
import time
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
EndFrame,
FunctionCallResultFrame,
FunctionCallsStartedFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
StartFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai.base_llm import LLMService
from pipecat.transports.base_output import BaseOutputTransport
class TurnDetectorObserver(BaseObserver):
"""Observer ... of turns."""
def __init__(self):
super().__init__()
self._turn_observer = None
self._arrow = ""
self._turn_number = 1
self._endframe_queued = False
def init(self):
"""
Set ...
"""
pass
def set_turn_observer_event_handlers(self, turn_observer):
self._turn_observer = turn_observer
self.set_turn_observer_event_handlers(self._turn_observer)
def get_turn_observer(self):
return self._turn_observer
def set_turn_observer_event_handlers(self, turn_observer):
"""Sets the Turn Observer event handlers `on_turn_started` and `on_turn_ended`.
Args:
turn_observer: The turn tracking observer of the pipeline task
"""
@turn_observer.event_handler("on_turn_started")
async def on_turn_started(observer, turn_number):
self._turn_number = turn_number
current_time = time.time()
logger.info(f"🔄 Turn {turn_number} started")
# 🫆🫆🫆🫆
# code to start conversation turn here
# 🫆🫆🫆🫆
# 🫆🫆🫆🫆
# 🫆🫆🫆🫆
@turn_observer.event_handler("on_turn_ended")
async def on_turn_ended(observer, turn_number, duration, was_interrupted):
current_time = time.time()
if was_interrupted:
logger.info(f"🔄 Turn {turn_number} interrupted after {duration:.2f}s")
else:
logger.info(f"🏁 Turn {turn_number} completed in {duration:.2f}s")
# 🫆🫆🫆🫆
# code to end conversation turn here
# 🫆🫆🫆🫆
# 🫆🫆🫆🫆
# 🫆🫆🫆🫆
########
# everything past here isn't needed, just nice to have logging
########
async def on_push_frame(self, data: FramePushed):
"""Runs when any frame is pushed through pipeline.
Determines based on what type of frame and where it came from
what metrics to update.
Args:
data: the pushed frame
"""
src = data.source
dst = data.destination
frame = data.frame
direction = data.direction
timestamp = data.timestamp
# Convert timestamp to milliseconds for readability
time_sec = timestamp / 1_000_000
# Convert timestamp to seconds for readability
# time_sec = timestamp / 1_000_000_000
# only log downstream frames
if direction == FrameDirection.UPSTREAM:
return
if isinstance(src, Pipeline) or isinstance(dst, Pipeline):
if isinstance(frame, StartFrame):
self._handle_StartFrame(src, dst, frame, time_sec)
elif isinstance(frame, EndFrame):
self._handle_EndFrame(src, dst, frame, time_sec)
if isinstance(src, BaseOutputTransport):
if isinstance(frame, BotStartedSpeakingFrame):
self._handle_BotStartedSpeakingFrame(src, dst, frame, time_sec)
elif isinstance(frame, BotStoppedSpeakingFrame):
self._handle_BotStoppedSpeakingFrame(src, dst, frame, time_sec)
elif isinstance(frame, UserStartedSpeakingFrame):
self._handle_UserStartedSpeakingFrame(src, dst, frame, time_sec)
elif isinstance(frame, UserStoppedSpeakingFrame):
self._handle_UserStoppedSpeakingFrame(src, dst, frame, time_sec)
if isinstance(src, LLMService):
if isinstance(frame, LLMFullResponseStartFrame):
self._handle_LLMFullResponseStartFrame(src, dst, frame, time_sec)
elif isinstance(frame, LLMFullResponseEndFrame):
self._handle_LLMFullResponseEndFrame(src, dst, frame, time_sec)
elif isinstance(frame, FunctionCallsStartedFrame):
self._handle_FunctionCallsStartedFrame(src, dst, frame, time_sec)
elif isinstance(frame, FunctionCallResultFrame):
self._handle_FunctionCallResultFrame(src, dst, frame, time_sec)
# ------------ FRAME HANDLERS ------------
def _handle_StartFrame(self, src, dst, frame, time_sec):
if isinstance(dst, Pipeline):
logger.info(f"🟢🟢🟢 StartFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
def _handle_EndFrame(self, src, dst, frame, time_sec):
if isinstance(dst, Pipeline):
logger.info(f"Queueing 🔴🔴🔴 EndFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
self._endframe_queued = True
if isinstance(src, Pipeline):
logger.info(f"🔴🔴🔴 EndFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
current_time = time.time()
end_state_info = {
"turn_number": self._turn_number,
}
def _handle_BotStartedSpeakingFrame(self, src, dst, frame, time_sec):
logger.info(f"🤖🟢 BotStartedSpeakingFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
def _handle_BotStoppedSpeakingFrame(self, src, dst, frame, time_sec):
logger.info(f"🤖🔴 BotStoppedSpeakingFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
def _handle_LLMFullResponseStartFrame(self, src, dst, frame, time_sec):
logger.info(f"🧠🟢 LLMFullResponseStartFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
def _handle_LLMFullResponseEndFrame(self, src, dst, frame, time_sec):
logger.info(f"🧠🔴 LLMFullResponseEndFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
def _handle_UserStartedSpeakingFrame(self, src, dst, frame, time_sec):
logger.info(f"🙂🟢 UserStartedSpeakingFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
def _handle_UserStoppedSpeakingFrame(self, src, dst, frame, time_sec):
logger.info(f"🙂🔴 UserStoppedSpeakingFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s")
def _handle_FunctionCallsStartedFrame(self, src, dst, frame, time_sec):
logger.info(
f"📐🟢 {frame.function_calls[0].function_name} FunctionCallsStartedFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s"
)
def _handle_FunctionCallResultFrame(self, src, dst, frame, time_sec):
logger.info(
f"📐🔴 {frame.function_name} FunctionCallResultFrame: {src} {self._arrow} {dst} at {time_sec:.2f}s"
)

View File

@@ -99,7 +99,7 @@ local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "tor
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]
remote-smart-turn = []
silero = [ "onnxruntime>=1.20.1,<2" ]
simli = [ "simli-ai~=1.0.3"]
simli = [ "simli-ai~=0.1.25"]
soniox = [ "pipecat-ai[websockets-base]" ]
soundfile = [ "soundfile~=0.13.1" ]
speechmatics = [ "speechmatics-rt>=0.5.0" ]

View File

@@ -30,8 +30,8 @@ EVAL_SIMPLE_MATH = EvalConfig(
)
EVAL_WEATHER = EvalConfig(
prompt="What's the weather in San Francisco (in farhenheit or celsius)?",
eval="The user says something specific about the current weather in San Francisco, including the degrees (in farhenheit or celsius).",
prompt="What's the weather in San Francisco?",
eval="The user says something specific about the current weather in San Francisco, including the degrees.",
)
EVAL_ONLINE_SEARCH = EvalConfig(
@@ -70,7 +70,7 @@ EVAL_VOICEMAIL = EvalConfig(
EVAL_CONVERSATION = EvalConfig(
prompt="Hello, this is Mark.",
eval="The user acknowledges the greeting.",
eval="The user replies with a greeting.",
eval_speaks_first=True,
)

View File

@@ -31,7 +31,11 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import LLMService
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
class IVRStatus(Enum):
@@ -114,15 +118,15 @@ class IVRProcessor(FrameProcessor):
def _setup_xml_patterns(self):
"""Set up XML pattern detection and handlers."""
# Register DTMF pattern
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", remove_match=True)
self._aggregator.add_pattern("dtmf", "<dtmf>", "</dtmf>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
# Register mode pattern
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", remove_match=True)
self._aggregator.add_pattern("mode", "<mode>", "</mode>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
# Register IVR pattern
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", remove_match=True)
self._aggregator.add_pattern("ivr", "<ivr>", "</ivr>", action=MatchAction.REMOVE)
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -159,7 +163,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing DTMF content.
"""
value = match.content
value = match.text
logger.debug(f"DTMF detected: {value}")
try:
@@ -180,7 +184,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing IVR status content.
"""
status = match.content
status = match.text
logger.trace(f"IVR status detected: {status}")
# Convert string to enum, with validation
@@ -211,7 +215,7 @@ class IVRProcessor(FrameProcessor):
Args:
match: The pattern match containing mode content.
"""
mode = match.content
mode = match.text
logger.debug(f"Mode detected: {mode}")
if mode == "conversation":
await self._handle_conversation()

View File

@@ -12,6 +12,7 @@ and LLM processing.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
@@ -337,11 +338,14 @@ class TextFrame(DataFrame):
# mandatory fields of theirs to have defaults to preserve
# non-default-before-default argument order)
includes_inter_frame_spaces: bool = field(init=False)
# Whether this text frame should be appended to the LLM context.
append_to_context: bool = field(init=False)
def __post_init__(self):
super().__post_init__()
self.skip_tts = False
self.includes_inter_frame_spaces = False
self.append_to_context = True
def __str__(self):
pts = format_pts(self.pts)
@@ -352,14 +356,35 @@ class TextFrame(DataFrame):
class LLMTextFrame(TextFrame):
"""Text frame generated by LLM services."""
def __post_init__(self):
super().__post_init__()
# LLM services send text frames with all necessary spaces included
self.includes_inter_frame_spaces = True
pass
class AggregationType(str, Enum):
"""Built-in aggregation strings."""
SENTENCE = "sentence"
WORD = "word"
def __str__(self):
return self.value
@dataclass
class TTSTextFrame(TextFrame):
class AggregatedTextFrame(TextFrame):
"""Text frame representing an aggregation of TextFrames.
This frame contains multiple TextFrames aggregated together for processing
or output along with a field to indicate how they are aggregated.
Parameters:
aggregated_by: Method used to aggregate the text frames.
"""
aggregated_by: AggregationType | str
@dataclass
class TTSTextFrame(AggregatedTextFrame):
"""Text frame generated by Text-to-Speech services."""
pass

View File

@@ -14,7 +14,6 @@ translation from this universal context into whatever format it needs, using a
service-specific adapter.
"""
import asyncio
import base64
import io
import wave
@@ -138,7 +137,7 @@ class LLMContext:
return {"role": role, "content": content}
@staticmethod
async def create_image_message(
def create_image_message(
*,
role: str = "user",
format: str,
@@ -155,21 +154,15 @@ class LLMContext:
image: Raw image bytes.
text: Optional text to include with the image.
"""
def encode_image():
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
return encoded_image
encoded_image = await asyncio.to_thread(encode_image)
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
url = f"data:image/jpeg;base64,{encoded_image}"
return LLMContext.create_image_url_message(role=role, url=url, text=text)
@staticmethod
async def create_audio_message(
def create_audio_message(
*, role: str = "user", audio_frames: list[AudioRawFrame], text: str = "Audio follows"
) -> LLMContextMessage:
"""Create a context message containing audio.
@@ -179,26 +172,21 @@ class LLMContext:
audio_frames: List of audio frame objects to include.
text: Optional text to include with the audio.
"""
sample_rate = audio_frames[0].sample_rate
num_channels = audio_frames[0].num_channels
async def encode_audio():
sample_rate = audio_frames[0].sample_rate
num_channels = audio_frames[0].num_channels
content = []
content.append({"type": "text", "text": text})
data = b"".join(frame.audio for frame in audio_frames)
content = []
content.append({"type": "text", "text": text})
data = b"".join(frame.audio for frame in audio_frames)
with io.BytesIO() as buffer:
with wave.open(buffer, "wb") as wf:
wf.setsampwidth(2)
wf.setnchannels(num_channels)
wf.setframerate(sample_rate)
wf.writeframes(data)
with io.BytesIO() as buffer:
with wave.open(buffer, "wb") as wf:
wf.setsampwidth(2)
wf.setnchannels(num_channels)
wf.setframerate(sample_rate)
wf.writeframes(data)
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
return encoded_audio
encoded_audio = await asyncio.to_thread(encode_audio)
encoded_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append(
{
@@ -333,7 +321,7 @@ class LLMContext:
"""
self._tool_choice = tool_choice
async def add_image_frame_message(
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
):
"""Add a message containing an image frame.
@@ -344,12 +332,10 @@ class LLMContext:
image: Raw image bytes.
text: Optional text to include with the image.
"""
message = await LLMContext.create_image_message(
format=format, size=size, image=image, text=text
)
message = LLMContext.create_image_message(format=format, size=size, image=image, text=text)
self.add_message(message)
async def add_audio_frames_message(
def add_audio_frames_message(
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
):
"""Add a message containing audio frames.
@@ -358,7 +344,7 @@ class LLMContext:
audio_frames: List of audio frame objects to include.
text: Optional text to include with the audio.
"""
message = await LLMContext.create_audio_message(audio_frames=audio_frames, text=text)
message = LLMContext.create_audio_message(audio_frames=audio_frames, text=text)
self.add_message(message)
@staticmethod

View File

@@ -1001,7 +1001,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started:
if not self._started or not frame.append_to_context:
return
if self._params.expect_stripped_words:

View File

@@ -66,7 +66,7 @@ from pipecat.processors.aggregators.llm_response import (
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
from pipecat.utils.string import concatenate_aggregated_text
from pipecat.utils.time import time_now_iso8601
@@ -90,7 +90,15 @@ class LLMContextAggregator(FrameProcessor):
self._context = context
self._role = role
self._aggregation: List[TextPartForConcatenation] = []
self._aggregation: List[str] = []
# Whether to add spaces between text parts.
# (Currently only used by LLMAssistantAggregator, but could be expanded
# to LLMUserAggregator in the future if needed; that would require
# additional work since LLMUserAggregator currently trims spaces from
# incoming frames before determining whether it "really" received any
# text).
self._add_spaces = True
@property
def messages(self) -> List[LLMContextMessage]:
@@ -183,7 +191,7 @@ class LLMContextAggregator(FrameProcessor):
Returns:
The concatenated aggregation string.
"""
return concatenate_aggregated_text(self._aggregation)
return concatenate_aggregated_text(self._aggregation, self._add_spaces)
class LLMUserAggregator(LLMContextAggregator):
@@ -433,12 +441,7 @@ class LLMUserAggregator(LLMContextAggregator):
if not text.strip():
return
# Transcriptions never include inter-part spaces (so far).
self._aggregation.append(
TextPartForConcatenation(
text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
)
)
self._aggregation.append(text)
# We just got a final result, so let's reset interim results.
self._seen_interim_results = False
# Reset aggregation timer.
@@ -793,7 +796,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
logger.debug(f"{self} Appending UserImageRawFrame to LLM context (size: {frame.size})")
await self._context.add_image_frame_message(
self._context.add_image_frame_message(
format=frame.format,
size=frame.size,
image=frame.image,
@@ -811,18 +814,18 @@ class LLMAssistantAggregator(LLMContextAggregator):
await self.push_aggregation()
async def _handle_text(self, frame: TextFrame):
if not self._started:
if not self._started or not frame.append_to_context:
return
# Make sure we really have text (spaces count, too!)
if len(frame.text) == 0:
return
self._aggregation.append(
TextPartForConcatenation(
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
)
)
# Track whether we need to add spaces between text parts
# Assumption: we can just keep track of the latest frame's value
self._add_spaces = not frame.includes_inter_frame_spaces
self._aggregation.append(frame.text)
def _context_updated_task_finished(self, task: asyncio.Task):
self._context_updated_tasks.discard(task)

View File

@@ -0,0 +1,106 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""LLM text processor module for processing and aggregating raw LLM output text.
This processor will convert LLMTextFrames into AggregatedTextFrames based on the
configured text aggregator. Using the customizable aggregator, it provides
functionality to handle or manipulate LLM text frames before they are sent to other
components such as TTS services or context aggregators. It can be used to pre-aggregate
and categorize, modify, or filter direct output tokens from the LLM.
"""
from typing import Optional
from pipecat.frames.frames import (
AggregatedTextFrame,
EndFrame,
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMTextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
class LLMTextProcessor(FrameProcessor):
"""A processor for handling or manipulating LLM text frames before they are processed further.
This processor will convert LLMTextFrames into AggregatedTextFrames based on the configured
text aggregator. Using the customizable aggregator, it provides functionality to handle or
manipulate LLM text frames before they are sent to other components such as TTS services or
context aggregators. It can be used to pre-aggregate and categorize, modify, or filter direct
output tokens from the LLM.
"""
def __init__(self, *, text_aggregator: Optional[BaseTextAggregator] = None, **kwargs):
"""Initialize the LLM text processor.
Args:
text_aggregator: An optional text aggregator to use for processing LLM text frames. By
default, a SimpleTextAggregator aggregating by sentence will be used.
**kwargs: Additional arguments passed to parent class.
TODO: Allow transformations per aggregation type or all (and deprecate the TTS filters).
"""
super().__init__(**kwargs)
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process an LLMTextFrames using the aggregator to generate AggregatedTextFrames.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
await self._handle_interruption(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, LLMTextFrame):
await self._handle_llm_text(frame)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_llm_end(frame.skip_tts)
await self.push_frame(frame, direction)
elif isinstance(frame, EndFrame):
await self._handle_llm_end()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
async def _handle_interruption(self, _):
"""Handle interruptions by resetting the text aggregator."""
await self._text_aggregator.handle_interruption()
async def reset(self):
"""Reset the internal state of the text processor and its aggregator."""
await self._text_aggregator.reset()
async def _handle_llm_text(self, in_frame: LLMTextFrame):
aggregation = await self._text_aggregator.aggregate(in_frame.text)
if aggregation:
out_frame = AggregatedTextFrame(
text=aggregation.text,
aggregated_by=aggregation.type,
)
out_frame.skip_tts = in_frame.skip_tts
await self.push_frame(out_frame)
async def _handle_llm_end(self, skip_tts: bool = False):
# Flush any remaining aggregated text at the end of the LLM response
aggregation = self._text_aggregator.text
await self._text_aggregator.reset()
text = aggregation.text.strip()
if text:
out_frame = AggregatedTextFrame(
text=text,
aggregated_by=aggregation.type,
)
out_frame.skip_tts = skip_tts
await self.push_frame(out_frame)

View File

@@ -83,4 +83,4 @@ class ConsumerProcessor(FrameProcessor):
while True:
frame = await self._queue.get()
new_frame = await self._transformer(frame)
await self.queue_frame(new_frame, self._direction)
await self.push_frame(new_frame, self._direction)

View File

@@ -24,6 +24,7 @@ from typing import (
Literal,
Mapping,
Optional,
Tuple,
Union,
)
@@ -32,6 +33,8 @@ from pydantic import BaseModel, Field, PrivateAttr, ValidationError
from pipecat.audio.utils import calculate_audio_volume
from pipecat.frames.frames import (
AggregatedTextFrame,
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -704,6 +707,29 @@ class RTVITextMessageData(BaseModel):
text: str
class RTVIBotOutputMessageData(RTVITextMessageData):
"""Data for bot output RTVI messages.
Extends RTVITextMessageData to include metadata about the output.
"""
spoken: bool = False # Indicates if the text has been spoken by TTS
aggregated_by: AggregationType | str
# Indicates what form the text is in (e.g., by word, sentence, etc.)
class RTVIBotOutputMessage(BaseModel):
"""Message containing bot output text.
An event meant to holistically represent what the bot is outputting,
along with metadata about the output and if it has been spoken.
"""
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
type: Literal["bot-output"] = "bot-output"
data: RTVIBotOutputMessageData
class RTVIBotTranscriptionMessage(BaseModel):
"""Message containing bot transcription text.
@@ -896,6 +922,7 @@ class RTVIObserverParams:
Parameter `errors_enabled` is deprecated. Error messages are always enabled.
Parameters:
bot_output_enabled: Indicates if bot output messages should be sent.
bot_llm_enabled: Indicates if the bot's LLM messages should be sent.
bot_tts_enabled: Indicates if the bot's TTS messages should be sent.
bot_speaking_enabled: Indicates if the bot's started/stopped speaking messages should be sent.
@@ -907,9 +934,17 @@ class RTVIObserverParams:
metrics_enabled: Indicates if metrics messages should be sent.
system_logs_enabled: Indicates if system logs should be sent.
errors_enabled: [Deprecated] Indicates if errors messages should be sent.
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
Note: if using this to avoid sending secure information, be sure to also disable
bot_llm_enabled to avoid leaking through LLM messages.
bot_output_transforms: A list of callables to transform text before just before sending it
to TTS. Each callable takes the aggregated text and its type, and returns the
transformed text. To register, provide a list of tuples of
(aggregation_type | '*', transform_function).
audio_level_period_secs: How often audio levels should be sent if enabled.
"""
bot_output_enabled: bool = True
bot_llm_enabled: bool = True
bot_tts_enabled: bool = True
bot_speaking_enabled: bool = True
@@ -921,6 +956,15 @@ class RTVIObserverParams:
metrics_enabled: bool = True
system_logs_enabled: bool = False
errors_enabled: Optional[bool] = None
skip_aggregator_types: Optional[List[AggregationType | str]] = None
bot_output_transforms: Optional[
List[
Tuple[
AggregationType | str,
Callable[[str, AggregationType | str], Awaitable[str]],
]
]
] = None
audio_level_period_secs: float = 0.15
@@ -973,8 +1017,45 @@ class RTVIObserver(BaseObserver):
DeprecationWarning,
)
self._aggregation_transforms: List[
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
] = self._params.bot_output_transforms or []
def add_bot_output_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Transform text for a specific aggregation type before sending as Bot Output or TTS.
Args:
transform_function: The function to apply for transformation. This function should take
the text and aggregation type as input and return the transformed text.
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
aggregation_type: The type of aggregation to transform. This value defaults to "*" to
handle all text before sending to the client.
"""
self._aggregation_transforms.append((aggregation_type, transform_function))
def remove_bot_output_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Remove a text transformer for a specific aggregation type.
Args:
transform_function: The function to remove.
aggregation_type: The type of aggregation to remove the transformer for.
"""
self._aggregation_transforms = [
(agg_type, func)
for agg_type, func in self._aggregation_transforms
if not (agg_type == aggregation_type and func == transform_function)
]
async def _logger_sink(self, message):
"""Logger sink so we cna send system logs to RTVI clients."""
"""Logger sink so we can send system logs to RTVI clients."""
message = RTVISystemLogMessage(data=RTVITextMessageData(text=message))
await self.send_rtvi_message(message)
@@ -1048,12 +1129,15 @@ class RTVIObserver(BaseObserver):
await self.send_rtvi_message(RTVIBotTTSStartedMessage())
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
await self.send_rtvi_message(RTVIBotTTSStoppedMessage())
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
if isinstance(src, BaseOutputTransport):
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
await self.send_rtvi_message(message)
else:
elif isinstance(frame, AggregatedTextFrame) and (
self._params.bot_output_enabled or self._params.bot_tts_enabled
):
if isinstance(frame, TTSTextFrame) and not isinstance(src, BaseOutputTransport):
# This check is to make sure we handle the frame when it has gone
# through the transport and has correct timing.
mark_as_seen = False
else:
await self._handle_aggregated_llm_text(frame)
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
await self._handle_metrics(frame)
elif isinstance(frame, RTVIServerMessageFrame):
@@ -1084,15 +1168,6 @@ class RTVIObserver(BaseObserver):
if mark_as_seen:
self._frames_seen.add(frame.id)
async def _push_bot_transcription(self):
"""Push accumulated bot transcription as a message."""
if len(self._bot_transcription) > 0:
message = RTVIBotTranscriptionMessage(
data=RTVITextMessageData(text=self._bot_transcription)
)
await self.send_rtvi_message(message)
self._bot_transcription = ""
async def _handle_interruptions(self, frame: Frame):
"""Handle user speaking interruption frames."""
message = None
@@ -1115,14 +1190,45 @@ class RTVIObserver(BaseObserver):
if message:
await self.send_rtvi_message(message)
async def _handle_aggregated_llm_text(self, frame: AggregatedTextFrame):
"""Handle aggregated LLM text output frames."""
# Skip certain aggregator types if configured to do so.
if (
self._params.skip_aggregator_types
and frame.aggregated_by in self._params.skip_aggregator_types
):
return
text = frame.text
type = frame.aggregated_by
for aggregation_type, transform in self._aggregation_transforms:
if aggregation_type == type or aggregation_type == "*":
text = await transform(text, type)
isTTS = isinstance(frame, TTSTextFrame)
if self._params.bot_output_enabled:
message = RTVIBotOutputMessage(
data=RTVIBotOutputMessageData(text=text, spoken=isTTS, aggregated_by=type)
)
await self.send_rtvi_message(message)
if isTTS and self._params.bot_tts_enabled:
tts_message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=text))
await self.send_rtvi_message(tts_message)
async def _handle_llm_text_frame(self, frame: LLMTextFrame):
"""Handle LLM text output frames."""
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
await self.send_rtvi_message(message)
# TODO (mrkb): Remove all this logic when we fully deprecate bot-transcription messages.
self._bot_transcription += frame.text
if match_endofsentence(self._bot_transcription):
await self._push_bot_transcription()
if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0:
await self.send_rtvi_message(
RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription))
)
self._bot_transcription = ""
async def _handle_user_transcriptions(self, frame: Frame):
"""Handle user transcription frames."""
@@ -1248,7 +1354,7 @@ class RTVIProcessor(FrameProcessor):
# Default to 0.3.0 which is the last version before actually having a
# "client-version".
self._client_version = [0, 3, 0]
self._skip_tts: bool = False # Keep in sync with llm_service.py
self._llm_skip_tts: bool = False # Keep in sync with llm_service.py's configuration.
self._registered_actions: Dict[str, RTVIAction] = {}
self._registered_services: Dict[str, RTVIService] = {}
@@ -1441,7 +1547,7 @@ class RTVIProcessor(FrameProcessor):
elif isinstance(frame, RTVIActionFrame):
await self._action_queue.put(frame)
elif isinstance(frame, LLMConfigureOutputFrame):
self._skip_tts = frame.skip_tts
self._llm_skip_tts = frame.skip_tts
await self.push_frame(frame, direction)
# Other frames
else:
@@ -1697,9 +1803,9 @@ class RTVIProcessor(FrameProcessor):
opts = data.options if data.options is not None else RTVISendTextOptions()
if opts.run_immediately:
await self.interrupt_bot()
cur_skip_tts = self._skip_tts
cur_llm_skip_tts = self._llm_skip_tts
should_skip_tts = not opts.audio_response
toggle_skip_tts = cur_skip_tts != should_skip_tts
toggle_skip_tts = cur_llm_skip_tts != should_skip_tts
if toggle_skip_tts:
output_frame = LLMConfigureOutputFrame(skip_tts=should_skip_tts)
await self.push_frame(output_frame)
@@ -1709,7 +1815,7 @@ class RTVIProcessor(FrameProcessor):
)
await self.push_frame(text_frame)
if toggle_skip_tts:
output_frame = LLMConfigureOutputFrame(skip_tts=cur_skip_tts)
output_frame = LLMConfigureOutputFrame(skip_tts=cur_llm_skip_tts)
await self.push_frame(output_frame)
async def _handle_update_context(self, data: RTVIAppendToContextData):

View File

@@ -26,7 +26,7 @@ from pipecat.frames.frames import (
TTSTextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
from pipecat.utils.string import concatenate_aggregated_text
from pipecat.utils.time import time_now_iso8601
@@ -98,9 +98,15 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
self._current_text_parts: List[TextPartForConcatenation] = []
self._current_text_parts: List[str] = []
self._aggregation_start_time: Optional[str] = None
# Whether to add spaces between text parts.
# (The use of this could be expanded to the UserTranscriptProcessor in
# the future if needed; currently the UserTranscriptProcessor assumes
# that user transcription frames do not need aggregation).
self._add_spaces = True
async def _emit_aggregated_text(self):
"""Aggregates and emits text fragments as a transcript message.
@@ -141,7 +147,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
Result: "Hello there how are you"
"""
if self._current_text_parts and self._aggregation_start_time:
content = concatenate_aggregated_text(self._current_text_parts)
content = concatenate_aggregated_text(self._current_text_parts, self._add_spaces)
if content:
logger.trace(f"Emitting aggregated assistant message: {content}")
message = TranscriptionMessage(
@@ -185,11 +191,11 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
if not self._aggregation_start_time:
self._aggregation_start_time = time_now_iso8601()
self._current_text_parts.append(
TextPartForConcatenation(
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
)
)
# Track whether we need to add spaces between text parts
# Assumption: we can just keep track of the latest frame's value
self._add_spaces = not frame.includes_inter_frame_spaces
self._current_text_parts.append(frame.text)
# Push frame.
await self.push_frame(frame, direction)

View File

@@ -264,10 +264,7 @@ def _setup_webrtc_routes(
# Prepare runner arguments with the callback to run your bot
async def webrtc_connection_callback(connection):
bot_module = _get_bot_module()
runner_args = SmallWebRTCRunnerArguments(
webrtc_connection=connection, body=request.request_data
)
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
background_tasks.add_task(bot_module.bot, runner_args)
# Delegate handling to SmallWebRTCRequestHandler
@@ -329,8 +326,7 @@ def _setup_webrtc_routes(
type=request_data["type"],
pc_id=request_data.get("pc_id"),
restart_pc=request_data.get("restart_pc"),
request_data=request_data.get("request_data")
or request_data.get("requestData"),
request_data=request_data,
)
return await offer(webrtc_request, background_tasks)
elif request.method == HTTPMethod.PATCH.value:

View File

@@ -281,14 +281,6 @@ async def maybe_capture_participant_camera(
except ImportError:
pass
try:
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
if isinstance(transport, SmallWebRTCTransport):
await transport.capture_participant_video(video_source="camera")
except ImportError:
pass
async def maybe_capture_participant_screen(
transport: BaseTransport, client: Any, framerate: int = 0
@@ -311,14 +303,6 @@ async def maybe_capture_participant_screen(
except ImportError:
pass
try:
from pipecat.transports.smallwebrtc.transport import SmallWebRTCTransport
if isinstance(transport, SmallWebRTCTransport):
await transport.capture_participant_video(video_source="screenVideo")
except ImportError:
pass
def _smallwebrtc_sdp_cleanup_ice_candidates(text: str, pattern: str) -> str:
"""Clean up ICE candidates in SDP text for SmallWebRTC.

View File

@@ -373,7 +373,9 @@ class AnthropicLLMService(LLMService):
if event.type == "content_block_delta":
if hasattr(event.delta, "text"):
await self.push_frame(LLMTextFrame(event.delta.text))
frame = LLMTextFrame(event.delta.text)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
completion_tokens_estimate += self._estimate_tokens(event.delta.text)
elif hasattr(event.delta, "partial_json") and tool_use_block:
json_accumulator += event.delta.partial_json

View File

@@ -146,6 +146,15 @@ class AsyncAITTSService(InterruptibleTTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that AsyncAI TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that AsyncAI's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Async language format.
@@ -424,6 +433,15 @@ class AsyncAIHttpTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that AsyncAI TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that AsyncAI's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Async language format.

View File

@@ -1078,7 +1078,9 @@ class AWSBedrockLLMService(LLMService):
if "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
if "text" in delta:
await self.push_frame(LLMTextFrame(delta["text"]))
frame = LLMTextFrame(delta["text"])
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
completion_tokens_estimate += self._estimate_tokens(delta["text"])
elif "toolUse" in delta and "input" in delta["toolUse"]:
# Handle partial JSON for tool use

View File

@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter, Role
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -1027,7 +1028,7 @@ class AWSNovaSonicLLMService(LLMService):
logger.debug(f"Assistant response text added: {text}")
# Report the text of the assistant response.
frame = TTSTextFrame(text)
frame = TTSTextFrame(text, aggregated_by=AggregationType.SENTENCE)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
@@ -1062,7 +1063,9 @@ class AWSNovaSonicLLMService(LLMService):
# TTSTextFrame would be ignored otherwise (the interruption frame
# would have cleared the assistant aggregator state).
await self.push_frame(LLMFullResponseStartFrame())
frame = TTSTextFrame(self._assistant_text_buffer)
frame = TTSTextFrame(
self._assistant_text_buffer, aggregated_by=AggregationType.SENTENCE
)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
self._may_need_repush_assistant_text = False

View File

@@ -209,6 +209,15 @@ class AWSPollyTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that AWS TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that AWS's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to AWS Polly language format.

View File

@@ -151,6 +151,15 @@ class AzureBaseTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Azure TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Azure's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Azure language format.

View File

@@ -10,7 +10,8 @@ import base64
import json
import uuid
import warnings
from typing import AsyncGenerator, List, Literal, Optional, Union
from enum import Enum
from typing import AsyncGenerator, List, Literal, Optional
from loguru import logger
from pydantic import BaseModel, Field
@@ -125,6 +126,72 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
class CartesiaEmotion(str, Enum):
"""Predefined Emotions supported by Cartesia."""
# Primary emotions supported by Cartesia
NEUTRAL = "neutral"
ANGRY = "angry"
EXCITED = "excited"
CONTENT = "content"
SAD = "sad"
SCARED = "scared"
# Additional emotions supported by Cartesia
HAPPY = "happy"
ENTHUSIASTIC = "enthusiastic"
ELATED = "elated"
EUPHORIC = "euphoric"
TRIUMPHANT = "triumphant"
AMAZED = "amazed"
SURPRISED = "surprised"
FLIRTATIOUS = "flirtatious"
JOKING_COMEDIC = "joking/comedic"
CURIOUS = "curious"
PEACEFUL = "peaceful"
SERENE = "serene"
CALM = "calm"
GRATEFUL = "grateful"
AFFECTIONATE = "affectionate"
TRUST = "trust"
SYMPATHETIC = "sympathetic"
ANTICIPATION = "anticipation"
MYSTERIOUS = "mysterious"
MAD = "mad"
OUTRAGED = "outraged"
FRUSTRATED = "frustrated"
AGITATED = "agitated"
THREATENED = "threatened"
DISGUSTED = "disgusted"
CONTEMPT = "contempt"
ENVIOUS = "envious"
SARCASTIC = "sarcastic"
IRONIC = "ironic"
DEJECTED = "dejected"
MELANCHOLIC = "melancholic"
DISAPPOINTED = "disappointed"
HURT = "hurt"
GUILTY = "guilty"
BORED = "bored"
TIRED = "tired"
REJECTED = "rejected"
NOSTALGIC = "nostalgic"
WISTFUL = "wistful"
APOLOGETIC = "apologetic"
HESITANT = "hesitant"
INSECURE = "insecure"
CONFUSED = "confused"
RESIGNED = "resigned"
ANXIOUS = "anxious"
PANICKED = "panicked"
ALARMED = "alarmed"
PROUD = "proud"
CONFIDENT = "confident"
DISTANT = "distant"
SKEPTICAL = "skeptical"
CONTEMPLATIVE = "contemplative"
DETERMINED = "determined"
class CartesiaTTSService(AudioContextWordTTSService):
"""Cartesia TTS service with WebSocket streaming and word timestamps.
@@ -182,6 +249,10 @@ class CartesiaTTSService(AudioContextWordTTSService):
container: Audio container format.
params: Additional input parameters for voice customization.
text_aggregator: Custom text aggregator for processing input text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
aggregate_sentences: Whether to aggregate sentences within the TTSService.
**kwargs: Additional arguments passed to the parent service.
"""
@@ -200,10 +271,18 @@ class CartesiaTTSService(AudioContextWordTTSService):
push_text_frames=False,
pause_frame_processing=True,
sample_rate=sample_rate,
text_aggregator=text_aggregator or SkipTagsAggregator([("<spell>", "</spell>")]),
text_aggregator=text_aggregator,
**kwargs,
)
if not text_aggregator:
# Always skip tags added for spelled-out text
# Note: This is primarily to support backwards compatibility.
# The preferred way of taking advantage of Cartesia SSML Tags is
# to use an LLMTextProcessor and/or a text_transformer to identify
# and insert these tags for the purpose of the TTS service alone.
self._text_aggregator = SkipTagsAggregator([("<spell>", "</spell>")])
params = params or CartesiaTTSService.InputParams()
self._api_key = api_key
@@ -257,6 +336,27 @@ class CartesiaTTSService(AudioContextWordTTSService):
"""
return language_to_cartesia_language(language)
# A set of Cartesia-specific helpers for text transformations
def SPELL(text: str) -> str:
"""Wrap text in Cartesia spell tag."""
return f"<spell>{text}</spell>"
def EMOTION_TAG(emotion: CartesiaEmotion) -> str:
"""Convenience method to create an emotion tag."""
return f'<emotion value="{emotion}" />'
def PAUSE_TAG(seconds: float) -> str:
"""Convenience method to create a pause tag."""
return f'<break time="{seconds}s" />'
def VOLUME_TAG(volume: float) -> str:
"""Convenience method to create a volume tag."""
return f'<volume ratio="{volume}" />'
def SPEED_TAG(speed: float) -> str:
"""Convenience method to create a speed tag."""
return f'<speed ratio="{speed}" />'
def _is_cjk_language(self, language: str) -> bool:
"""Check if the given language is CJK (Chinese, Japanese, Korean).

View File

@@ -6,9 +6,7 @@
"""Deepgram Flux speech-to-text service implementation."""
import asyncio
import json
import time
from enum import Enum
from typing import Any, AsyncGenerator, Dict, Optional
from urllib.parse import urlencode
@@ -96,7 +94,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
mip_opt_out: Optional. Opts out requests from the Deepgram Model Improvement Program
(default False).
tag: List of tags to label requests for identification during usage reporting.
min_confidence: Optional. Minimum confidence required confidence to create a TranscriptionFrame
"""
eager_eot_threshold: Optional[float] = None
@@ -105,7 +102,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
keyterm: list = []
mip_opt_out: Optional[bool] = None
tag: list = []
min_confidence: Optional[float] = None # New parameter
def __init__(
self,
@@ -167,13 +163,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
self._register_event_handler("on_end_of_turn")
self._register_event_handler("on_eager_end_of_turn")
self._register_event_handler("on_update")
self._connection_established_event = asyncio.Event()
# Watchdog task to prevent dangling tasks
# If we stop sending audio to Flux after we have received that the User has started speaking
# we never receive the user stopped speaking event unless we resume sending audio to it.
self._last_stt_time = None
self._watchdog_task = None
self._user_is_speaking = False
async def _connect(self):
"""Connect to WebSocket and start background tasks.
@@ -183,6 +172,9 @@ class DeepgramFluxSTTService(WebsocketSTTService):
"""
await self._connect_websocket()
if self._websocket and not self._receive_task:
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
"""Disconnect from WebSocket and clean up tasks.
@@ -190,7 +182,14 @@ class DeepgramFluxSTTService(WebsocketSTTService):
and cleans up resources to prevent memory leaks.
"""
try:
# Cancel background tasks BEFORE closing websocket
if self._receive_task:
await self.cancel_task(self._receive_task, timeout=2.0)
self._receive_task = None
# Now close the websocket
await self._disconnect_websocket()
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
@@ -198,25 +197,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
# Reset state only after everything is cleaned up
self._websocket = None
async def _send_silence(self, duration_secs: float = 0.5):
"""Send a block of silence of the specified duration (default 500 ms)."""
sample_width = 2 # bytes per sample for 16-bit PCM
num_channels = 1 # mono
num_samples = int(self.sample_rate * duration_secs)
silence = b"\x00" * (num_samples * sample_width * num_channels)
await self._websocket.send(silence)
async def _watchdog_task_handler(self):
while self._websocket and self._websocket.state is State.OPEN:
now = time.monotonic()
# More than 500 ms without sending new audio to Flux
if self._user_is_speaking and self._last_stt_time and now - self._last_stt_time > 0.5:
logger.warning("Sending silence to Flux to prevent dangling task")
await self._send_silence()
self._last_stt_time = time.monotonic()
# check every 100ms
await asyncio.sleep(0.1)
async def _connect_websocket(self):
"""Establish WebSocket connection to API.
@@ -228,26 +208,10 @@ class DeepgramFluxSTTService(WebsocketSTTService):
if self._websocket and self._websocket.state is State.OPEN:
return
self._connection_established_event.clear()
self._user_is_speaking = False
self._websocket = await websocket_connect(
self._websocket_url,
additional_headers={"Authorization": f"Token {self._api_key}"},
)
# Creating the receiver task
if not self._receive_task:
self._receive_task = self.create_task(
self._receive_task_handler(self._report_error)
)
# Creating the watchdog task
if not self._watchdog_task:
self._watchdog_task = self.create_task(self._watchdog_task_handler())
# Now wait for the connection established event
logger.debug("WebSocket connected, waiting for server confirmation...")
await self._connection_established_event.wait()
logger.debug("Connected to Deepgram Flux Websocket")
await self._call_event_handler("on_connected")
except Exception as e:
@@ -263,16 +227,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
metrics collection. Handles disconnection errors gracefully.
"""
try:
# Cancel background tasks BEFORE closing websocket
if self._receive_task:
await self.cancel_task(self._receive_task, timeout=2.0)
self._receive_task = None
if self._watchdog_task:
await self.cancel_task(self._watchdog_task, timeout=2.0)
self._watchdog_task = None
self._last_stt_time = None
self._connection_established_event.clear()
await self.stop_all_metrics()
if self._websocket:
@@ -386,8 +340,7 @@ class DeepgramFluxSTTService(WebsocketSTTService):
return
try:
self._last_stt_time = time.monotonic()
await self.send_with_retry(audio, self._report_error)
await self._websocket.send(audio)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
@@ -510,8 +463,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
transcription processing.
"""
logger.info("Connected to Flux - ready to stream audio")
# Notify connection is established
self._connection_established_event.set()
async def _handle_fatal_error(self, data: Dict[str, Any]):
"""Handle fatal error messages from Deepgram Flux.
@@ -579,7 +530,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
transcript: maybe the first few words of the turn.
"""
logger.debug("User started speaking")
self._user_is_speaking = True
await self.push_interruption_task_frame_and_wait()
await self.broadcast_frame(UserStartedSpeakingFrame)
await self.start_metrics()
@@ -600,22 +550,6 @@ class DeepgramFluxSTTService(WebsocketSTTService):
logger.trace(f"Received event TurnResumed: {event}")
await self._call_event_handler("on_turn_resumed")
def _calculate_average_confidence(self, transcript_data) -> Optional[float]:
"""Calculate the average confidence from transcript data.
Return None if the data is missing or invalid.
"""
# Example: Assume transcript_data has a list of words with confidence
words = transcript_data.get("words")
if not words or not isinstance(words, list):
return None
confidences = [
w.get("confidence") for w in words if isinstance(w.get("confidence"), (float, int))
]
if not confidences:
return None
return sum(confidences) / len(confidences)
async def _handle_end_of_turn(self, transcript: str, data: Dict[str, Any]):
"""Handle EndOfTurn events from Deepgram Flux.
@@ -635,26 +569,16 @@ class DeepgramFluxSTTService(WebsocketSTTService):
data: The TurnInfo message data containing event type, transcript and some extra metadata.
"""
logger.debug("User stopped speaking")
self._user_is_speaking = False
# Compute the average confidence
average_confidence = self._calculate_average_confidence(data)
if not self._params.min_confidence or average_confidence > self._params.min_confidence:
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language,
result=data,
)
await self.push_frame(
TranscriptionFrame(
transcript,
self._user_id,
time_now_iso8601(),
self._language,
result=data,
)
else:
logger.warning(
f"Transcription confidence below min_confidence threshold: {average_confidence}"
)
)
await self._handle_transcription(transcript, True, self._language)
await self.stop_processing_metrics()
await self.push_frame(UserStoppedSpeakingFrame(), FrameDirection.DOWNSTREAM)

View File

@@ -79,6 +79,15 @@ class DeepgramTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Deepgram TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Deepgram's text frames include necessary inter-frame spaces.
"""
return True
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Deepgram's TTS API.
@@ -168,6 +177,15 @@ class DeepgramHttpTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Deepgram TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Deepgram's text frames include necessary inter-frame spaces.
"""
return True
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Deepgram's TTS API.

View File

@@ -159,6 +159,15 @@ class FishAudioTTSService(InterruptibleTTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Fish Audio TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Fish Audio's text frames include necessary inter-frame spaces.
"""
return True
async def set_model(self, model: str):
"""Set the TTS model and reconnect.

View File

@@ -27,6 +27,7 @@ from pydantic import BaseModel, Field
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -1452,6 +1453,8 @@ class GeminiLiveLLMService(LLMService):
self._bot_text_buffer += text
self._search_result_buffer += text # Also accumulate for grounding
frame = LLMTextFrame(text=text)
# Gemini Live text already includes any necessary inter-chunk spaces
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
# Check for grounding metadata in server content
@@ -1644,7 +1647,7 @@ class GeminiLiveLLMService(LLMService):
await self.push_frame(TTSStartedFrame())
await self.push_frame(LLMFullResponseStartFrame())
frame = TTSTextFrame(text=text)
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.SENTENCE)
# Gemini Live text already includes any necessary inter-chunk spaces
frame.includes_inter_frame_spaces = True

View File

@@ -920,7 +920,9 @@ class GoogleLLMService(LLMService):
for part in candidate.content.parts:
if not part.thought and part.text:
search_result += part.text
await self.push_frame(LLMTextFrame(part.text))
frame = LLMTextFrame(part.text)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
elif part.function_call:
function_call = part.function_call
id = function_call.id or str(uuid.uuid4())

View File

@@ -596,6 +596,15 @@ class GoogleHttpTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Google TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Google's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Google TTS language format.
@@ -794,6 +803,15 @@ class GoogleBaseTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Google and Gemini TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Google's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Google TTS language format.

View File

@@ -111,6 +111,15 @@ class GroqTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Groq TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Groq's text frames include necessary inter-frame spaces.
"""
return True
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Groq's TTS API.

View File

@@ -14,14 +14,12 @@ from pydantic import BaseModel
from pipecat.frames.frames import (
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import WordTTSService
from pipecat.services.tts_service import TTSService
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -31,7 +29,6 @@ try:
PostedUtterance,
PostedUtteranceVoiceWithId,
)
from hume.tts.types import TimestampMessage
except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
logger.error(f"Exception: {e}")
logger.error("In order to use Hume, you need to `pip install pipecat-ai[hume]`.")
@@ -41,7 +38,7 @@ except ModuleNotFoundError as e: # pragma: no cover - import-time guidance
HUME_SAMPLE_RATE = 48_000 # Hume TTS streams at 48 kHz
class HumeTTSService(WordTTSService):
class HumeTTSService(TTSService):
"""Hume Octave Text-to-Speech service.
Streams PCM audio via Hume's HTTP output streaming (JSON chunks) endpoint
@@ -51,7 +48,6 @@ class HumeTTSService(WordTTSService):
- Generates speech from text using Hume TTS.
- Streams PCM audio.
- Supports word-level timestamps for precise audio-text synchronization.
- Supports dynamic updates of voice and synthesis parameters at runtime.
- Provides metrics for Time To First Byte (TTFB) and TTS usage.
"""
@@ -96,13 +92,7 @@ class HumeTTSService(WordTTSService):
f"Hume TTS streams at {HUME_SAMPLE_RATE} Hz; configured sample_rate={sample_rate}"
)
# WordTTSService sets push_text_frames=False by default, which we want
super().__init__(
sample_rate=sample_rate,
push_text_frames=False,
push_stop_frames=True,
**kwargs,
)
super().__init__(sample_rate=sample_rate, **kwargs)
self._client = AsyncHumeClient(api_key=api_key)
self._params = params or HumeTTSService.InputParams()
@@ -112,10 +102,6 @@ class HumeTTSService(WordTTSService):
self._audio_bytes = b""
# Track cumulative time for word timestamps across utterances
self._cumulative_time = 0.0
self._started = False
def can_generate_metrics(self) -> bool:
"""Can generate metrics.
@@ -124,6 +110,15 @@ class HumeTTSService(WordTTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Hume TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Hume's text frames include necessary inter-frame spaces.
"""
return True
async def start(self, frame: StartFrame) -> None:
"""Start the service.
@@ -131,27 +126,6 @@ class HumeTTSService(WordTTSService):
frame: The start frame.
"""
await super().start(frame)
self._reset_state()
def _reset_state(self):
"""Reset internal state variables."""
self._cumulative_time = 0.0
self._started = False
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame and handle state changes.
Args:
frame: The frame to push.
direction: The direction to push the frame.
"""
await super().push_frame(frame, direction)
if isinstance(frame, (InterruptionFrame, TTSStoppedFrame)):
# Reset timing on interruption or stop
self._reset_state()
if isinstance(frame, TTSStoppedFrame):
await self.add_word_timestamps([("Reset", 0)])
async def update_setting(self, key: str, value: Any) -> None:
"""Runtime updates via `TTSUpdateSettingsFrame`.
@@ -168,7 +142,7 @@ class HumeTTSService(WordTTSService):
if key_l == "voice_id":
self.set_voice(str(value))
logger.debug(f"HumeTTSService voice_id set to: {self.voice}")
logger.info(f"HumeTTSService voice_id set to: {self.voice}")
elif key_l == "description":
self._params.description = None if value is None else str(value)
elif key_l == "speed":
@@ -181,7 +155,7 @@ class HumeTTSService(WordTTSService):
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Hume TTS with word timestamps.
"""Generate speech from text using Hume TTS.
Args:
text: The text to be synthesized.
@@ -212,12 +186,7 @@ class HumeTTSService(WordTTSService):
await self.start_ttfb_metrics()
await self.start_tts_usage_metrics(text)
# Start TTS sequence if not already started
if not self._started:
self.start_word_timestamps()
yield TTSStartedFrame()
self._started = True
yield TTSStartedFrame()
try:
# Instant mode is always enabled here (not user-configurable)
@@ -228,50 +197,23 @@ class HumeTTSService(WordTTSService):
# Use version "2" by default if no description is provided
# Version "1" is needed when description is used
version = "1" if self._params.description is not None else "2"
# Track the duration of this utterance based on the last timestamp
utterance_duration = 0.0
async for chunk in self._client.tts.synthesize_json_streaming(
utterances=[utterance],
format=pcm_fmt,
instant_mode=True,
version=version,
include_timestamp_types=["word"], # Request word-level timestamps
):
# Process audio chunks
audio_b64 = getattr(chunk, "audio", None)
if audio_b64:
await self.stop_ttfb_metrics()
pcm_bytes = base64.b64decode(audio_b64)
self._audio_bytes += pcm_bytes
if not audio_b64:
continue
# Buffer audio until we have enough to avoid glitches
if len(self._audio_bytes) >= self.chunk_size:
frame = TTSAudioRawFrame(
audio=self._audio_bytes,
sample_rate=self.sample_rate,
num_channels=1,
)
yield frame
self._audio_bytes = b""
pcm_bytes = base64.b64decode(audio_b64)
self._audio_bytes += pcm_bytes
# Process timestamp messages
if isinstance(chunk, TimestampMessage):
timestamp = chunk.timestamp
if timestamp.type == "word":
# Convert milliseconds to seconds and add cumulative offset
word_start_time = self._cumulative_time + (timestamp.time.begin / 1000.0)
word_end_time = self._cumulative_time + (timestamp.time.end / 1000.0)
# Buffer audio until we have enough to avoid glitches
if len(self._audio_bytes) < self.chunk_size:
continue
# Track the maximum end time for this utterance
utterance_duration = max(utterance_duration, word_end_time)
# Add word timestamp
await self.add_word_timestamps([(timestamp.text, word_start_time)])
# Flush any remaining audio bytes
if self._audio_bytes:
frame = TTSAudioRawFrame(
audio=self._audio_bytes,
sample_rate=self.sample_rate,
@@ -282,14 +224,10 @@ class HumeTTSService(WordTTSService):
self._audio_bytes = b""
# Update cumulative time for next utterance
if utterance_duration > 0:
self._cumulative_time = utterance_duration
except Exception as e:
logger.error(f"{self} exception: {e}")
await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
finally:
# Ensure TTFB timer is stopped even on early failures
await self.stop_ttfb_metrics()
# Let the parent class handle TTSStoppedFrame via push_stop_frames
yield TTSStoppedFrame()

View File

@@ -146,8 +146,6 @@ class InworldTTSService(TTSService):
Parameters:
temperature: Voice temperature control for synthesis variability (e.g., 1.1).
Valid range: [0, 2]. Higher values increase variability.
speaking_rate: Speaking speed control (range: [0.5, 1.5]). Defaults to 1.0 when
unset.
Note:
Language is automatically inferred from the input text by Inworld's TTS models,
@@ -155,7 +153,6 @@ class InworldTTSService(TTSService):
"""
temperature: Optional[float] = None # optional temperature control (range: [0, 2])
speaking_rate: Optional[float] = None # optional speaking rate control (range: [0.5, 1.5])
def __init__(
self,
@@ -201,7 +198,6 @@ class InworldTTSService(TTSService):
- Other formats as supported by Inworld API
params: Optional input parameters for additional configuration. Use this to specify:
- temperature: Voice temperature control for variability (range: [0, 2], e.g., 1.1, optional)
- speaking_rate: Set desired speaking speed (range: [0.5, 1.5], optional)
Language is automatically inferred from input text.
**kwargs: Additional arguments passed to the parent TTSService class.
@@ -232,18 +228,15 @@ class InworldTTSService(TTSService):
self._settings = {
"voiceId": voice_id, # Voice selection from direct parameter
"modelId": model, # TTS model selection from direct parameter
"audioConfig": { # Audio format configuration
"audioEncoding": encoding, # Format: LINEAR16, MP3, etc.
"sampleRateHertz": 0, # Will be set in start() from parent service
"audio_config": { # Audio format configuration
"audio_encoding": encoding, # Format: LINEAR16, MP3, etc.
"sample_rate_hertz": 0, # Will be set in start() from parent service
},
}
# Add optional temperature parameter if provided (valid range: [0, 2])
if params and params.temperature is not None:
self._settings["temperature"] = params.temperature
# Add optional speaking rate if provided (valid range: [0.5, 1.5])
if params and params.speaking_rate is not None:
self._settings["audioConfig"]["speakingRate"] = params.speaking_rate
# Register voice and model with parent service for metrics and tracking
self.set_voice(voice_id) # Used for logging and metrics
@@ -257,6 +250,15 @@ class InworldTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Inworld TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Inworld's text frames include necessary inter-frame spaces.
"""
return True
async def start(self, frame: StartFrame):
"""Start the Inworld TTS service.
@@ -264,7 +266,7 @@ class InworldTTSService(TTSService):
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._settings["audioConfig"]["sampleRateHertz"] = self.sample_rate
self._settings["audio_config"]["sample_rate_hertz"] = self.sample_rate
async def stop(self, frame: EndFrame):
"""Stop the Inworld TTS service.
@@ -330,7 +332,9 @@ class InworldTTSService(TTSService):
"text": text, # Text to synthesize
"voiceId": self._settings["voiceId"], # Voice selection (Ashley, Hades, etc.)
"modelId": self._settings["modelId"], # TTS model (inworld-tts-1)
"audioConfig": self._settings["audioConfig"], # Audio format settings (LINEAR16, 48kHz)
"audio_config": self._settings[
"audio_config"
], # Audio format settings (LINEAR16, 48kHz)
}
# Add optional temperature parameter if configured (valid range: [0, 2])

View File

@@ -124,6 +124,15 @@ class LmntTTSService(InterruptibleTTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that LMNT TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that LMNT's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to LMNT service language format.

View File

@@ -194,6 +194,15 @@ class MiniMaxHttpTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that MiniMax TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that MiniMax's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to MiniMax service language format.

View File

@@ -151,6 +151,15 @@ class NeuphonicTTSService(InterruptibleTTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Neuphonic TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Neuphonic's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Neuphonic service language format.
@@ -440,6 +449,15 @@ class NeuphonicHttpTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Neuphonic TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Neuphonic's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Neuphonic service language format.

View File

@@ -390,7 +390,9 @@ class BaseOpenAILLMService(LLMService):
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
frame = LLMTextFrame(chunk.choices[0].delta.content)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
# we need to get LLMTextFrame for the transcript

View File

@@ -19,6 +19,7 @@ from pipecat.adapters.services.open_ai_realtime_adapter import (
OpenAIRealtimeLLMAdapter,
)
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -678,13 +679,15 @@ class OpenAIRealtimeLLMService(LLMService):
# the output modality is "text"
if evt.delta:
frame = LLMTextFrame(evt.delta)
# OpenAI Realtime text already includes any necessary inter-chunk spaces
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
async def _handle_evt_audio_transcript_delta(self, evt):
# We receive audio transcript deltas (as opposed to text deltas) when
# the output modality is "audio" (the default)
if evt.delta:
frame = TTSTextFrame(evt.delta)
frame = TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE)
# OpenAI Realtime text already includes any necessary inter-chunk spaces
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)

View File

@@ -131,6 +131,15 @@ class OpenAITTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that OpenAI TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that OpenAI's text frames include necessary inter-frame spaces.
"""
return True
async def set_model(self, model: str):
"""Set the TTS model to use.

View File

@@ -17,6 +17,7 @@ from loguru import logger
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.frames.frames import (
AggregationType,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
@@ -652,7 +653,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
async def _handle_evt_audio_transcript_delta(self, evt):
if evt.delta:
await self.push_frame(LLMTextFrame(evt.delta))
await self.push_frame(TTSTextFrame(evt.delta))
await self.push_frame(TTSTextFrame(evt.delta, aggregated_by=AggregationType.SENTENCE))
async def _handle_evt_speech_started(self, evt):
await self._truncate_current_audio_response()

View File

@@ -66,6 +66,15 @@ class PiperTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Piper TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Piper's text frames include necessary inter-frame spaces.
"""
return True
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Piper's HTTP API.

View File

@@ -113,6 +113,10 @@ class RimeTTSService(AudioContextWordTTSService):
sample_rate: Audio sample rate in Hz.
params: Additional configuration parameters.
text_aggregator: Custom text aggregator for processing input text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
aggregate_sentences: Whether to aggregate sentences within the TTSService.
**kwargs: Additional arguments passed to parent class.
"""
@@ -123,10 +127,17 @@ class RimeTTSService(AudioContextWordTTSService):
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=sample_rate,
text_aggregator=text_aggregator or SkipTagsAggregator([("spell(", ")")]),
**kwargs,
)
if not text_aggregator:
# Always skip tags added for spelled-out text
# Note: This is primarily to support backwards compatibility.
# The preferred way of taking advantage of Rime spelling is
# to use an LLMTextProcessor and/or a text_transformer to identify
# and insert these tags for the purpose of the TTS service alone.
self._text_aggregator = SkipTagsAggregator([("spell(", ")")])
params = params or RimeTTSService.InputParams()
# Store service configuration
@@ -152,6 +163,7 @@ class RimeTTSService(AudioContextWordTTSService):
self._context_id = None # Tracks current turn
self._receive_task = None
self._cumulative_time = 0 # Accumulates time across messages
self._extra_msg_fields = {} # Extra fields for next message
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
@@ -181,6 +193,31 @@ class RimeTTSService(AudioContextWordTTSService):
self._model = model
await super().set_model(model)
# A set of Rime-specific helpers for text transformations
def SPELL(text: str) -> str:
"""Wrap text in Rime spell function."""
return f"spell({text})"
def PAUSE_TAG(seconds: float) -> str:
"""Convenience method to create a pause tag."""
return f"<{seconds * 1000}>"
def PRONOUNCE(self, text: str, word: str, phoneme: str) -> str:
"""Convenience method to support Rime's custom pronunciations feature.
https://docs.rime.ai/api-reference/custom-pronunciation
"""
self._extra_msg_fields["phonemizeBetweenBrackets"] = True
return text.replace(word, f"{phoneme}")
def INLINE_SPEED(self, text: str, speed: float) -> str:
"""Convenience method to support inline speeds."""
if not self._extra_msg_fields:
self._extra_msg_fields = {}
speed_vals = self._extra_msg_fields.get("inlineSpeedAlpha", "").split(",")
self._extra_msg_fields["inlineSpeedAlpha"] = ",".join(speed_vals + [str(speed)])
return f"[{text}]"
async def _update_settings(self, settings: Mapping[str, Any]):
"""Update service settings and reconnect if voice changed."""
prev_voice = self._voice_id
@@ -193,7 +230,11 @@ class RimeTTSService(AudioContextWordTTSService):
def _build_msg(self, text: str = "") -> dict:
"""Build JSON message for Rime API."""
return {"text": text, "contextId": self._context_id}
msg = {"text": text, "contextId": self._context_id}
if self._extra_msg_fields:
msg |= self._extra_msg_fields
self._extra_msg_fields = {}
return msg
def _build_clear_msg(self) -> dict:
"""Build clear operation message."""
@@ -501,6 +542,15 @@ class RimeHttpTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Rime TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Rime's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> str | None:
"""Convert pipecat language to Rime language code.

View File

@@ -113,6 +113,15 @@ class RivaTTSService(TTSService):
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
)
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Riva TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Riva's text frames include necessary inter-frame spaces.
"""
return True
async def set_model(self, model: str):
"""Attempt to set the TTS model.
@@ -157,6 +166,7 @@ class RivaTTSService(TTSService):
add_response(None)
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(error=f"{self} error: {e}")
add_response(None)
await self.start_ttfb_metrics()
@@ -181,7 +191,6 @@ class RivaTTSService(TTSService):
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
except asyncio.TimeoutError:
logger.error(f"{self} timeout waiting for audio response")
yield ErrorFrame(error=f"{self} error: {e}")
await self.start_tts_usage_metrics(text)
yield TTSStoppedFrame()

View File

@@ -176,7 +176,9 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
frame = LLMTextFrame(chunk.choices[0].delta.content)
frame.includes_inter_frame_spaces = True
await self.push_frame(frame)
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
# we need to get LLMTextFrame for the transcript

View File

@@ -195,6 +195,15 @@ class SarvamHttpTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Sarvam TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Sarvam's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Sarvam AI language format.
@@ -458,6 +467,15 @@ class SarvamTTSService(InterruptibleTTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Sarvam TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Sarvam's text frames include necessary inter-frame spaces.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Sarvam AI language format.

View File

@@ -84,10 +84,6 @@ class SimliVideoService(FrameProcessor):
Please use 'api_key' and 'face_id' parameters instead.
use_turn_server: Whether to use TURN server for connection. Defaults to False.
.. deprecated:: 0.0.95
The 'use_turn_server' parameter is deprecated and will be removed in a future version.
latency_interval: Latency interval setting for sending health checks to check
the latency to Simli Servers. Defaults to 0.
simli_url: URL of the simli servers. Can be changed for custom deployments
@@ -139,20 +135,14 @@ class SimliVideoService(FrameProcessor):
config = SimliConfig(**config_kwargs)
if use_turn_server:
warnings.warn(
"The 'use_turn_server' parameter is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
self._initialized = False
# Add buffer time to session limits
config.maxIdleTime += 5
config.maxSessionLength += 5
self._simli_client = SimliClient(
config=config,
latencyInterval=latency_interval,
config,
use_turn_server,
latency_interval,
simliURL=simli_url,
)

View File

@@ -105,6 +105,15 @@ class SpeechmaticsTTSService(TTSService):
"""
return True
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates that Speechmatics TTSTextFrames include necessary inter-frame spaces.
Returns:
True, indicating that Speechmatics's text frames include necessary inter-frame spaces.
"""
return True
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Speechmatics' HTTP API.

View File

@@ -38,7 +38,7 @@ class STTService(AIService):
Event handlers:
on_connected: Called when connected to the STT service.
on_disconnected: Called when disconnected from the STT service.
on_connected: Called when disconnected from the STT service.
on_connection_error: Called when a connection to the STT service error occurs.
Example::

View File

@@ -12,6 +12,8 @@ from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Mapping,
@@ -23,6 +25,8 @@ from typing import (
from loguru import logger
from pipecat.frames.frames import (
AggregatedTextFrame,
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -101,6 +105,16 @@ class TTSService(AIService):
sample_rate: Optional[int] = None,
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
text_aggregator: Optional[BaseTextAggregator] = None,
# Types of text aggregations that should not be spoken.
skip_aggregator_types: Optional[List[str]] = [],
# A list of callables to transform text before just before sending it to TTS.
# Each callable takes the aggregated text and its type, and returns the transformed text.
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
text_transforms: Optional[
List[
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
]
] = None,
# Text filter executed after text has been aggregated.
text_filters: Optional[Sequence[BaseTextFilter]] = None,
text_filter: Optional[BaseTextFilter] = None,
@@ -120,6 +134,16 @@ class TTSService(AIService):
pause_frame_processing: Whether to pause frame processing during audio generation.
sample_rate: Output sample rate for generated audio.
text_aggregator: Custom text aggregator for processing incoming text.
.. deprecated:: 0.0.95
Use an LLMTextProcessor before the TTSService for custom text aggregation.
skip_aggregator_types: List of aggregation types that should not be spoken.
text_transforms: A list of callables to transform text before just before sending it
to TTS. Each callable takes the aggregated text and its type, and returns the
transformed text. To register, provide a list of tuples of
(aggregation_type | '*', transform_function).
text_filters: Sequence of text filters to apply after aggregation.
text_filter: Single text filter (deprecated, use text_filters).
@@ -142,7 +166,21 @@ class TTSService(AIService):
self._voice_id: str = ""
self._settings: Dict[str, Any] = {}
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
self._aggregated_text_includes_inter_frame_spaces: bool = False
if text_aggregator:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Parameter 'text_aggregator' is deprecated. Use an LLMTextProcessor before the TTSService for custom text aggregation.",
DeprecationWarning,
)
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
self._text_transforms: List[
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
] = text_transforms or []
# TODO: Deprecate _text_filters when added to LLMTextProcessor
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
self._transport_destination: Optional[str] = transport_destination
self._tracing_enabled: bool = False
@@ -193,6 +231,23 @@ class TTSService(AIService):
CHUNK_SECONDS = 0.5
return int(self.sample_rate * CHUNK_SECONDS * 2) # 2 bytes/sample
@property
def includes_inter_frame_spaces(self) -> bool:
"""Indicates whether TTSTextFrames include necesary inter-frame spaces.
When True, the TTSTextFrame objects pushed by this service already
include all necessary spaces between subsequent frames. When False,
downstream processors (like the assistant context aggregator) may need
to add spacing.
Subclasses should override this property to return True if their text
generation process already includes necessary inter-frame spaces.
Returns:
False by default. Subclasses can override to return True.
"""
return False
async def set_model(self, model: str):
"""Set the TTS model to use.
@@ -282,6 +337,39 @@ class TTSService(AIService):
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
def add_text_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Transform text for a specific aggregation type.
Args:
transform_function: The function to apply for transformation. This function should take
the text and aggregation type as input and return the transformed text.
Ex.: async def my_transform(text: str, aggregation_type: str) -> str:
aggregation_type: The type of aggregation to transform. This value defaults to "*" indicating
the function should handle all text before sending to TTS.
"""
self._text_transforms.append((aggregation_type, transform_function))
def remove_text_transformer(
self,
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
aggregation_type: AggregationType | str = "*",
):
"""Remove a text transformer for a specific aggregation type.
Args:
transform_function: The function to remove.
aggregation_type: The type of aggregation to remove the transformer for.
"""
self._text_transforms = [
(agg_type, func)
for agg_type, func in self._text_transforms
if not (agg_type == aggregation_type and func == transform_function)
]
async def _update_settings(self, settings: Mapping[str, Any]):
for key, value in settings.items():
if key in self._settings:
@@ -337,6 +425,8 @@ class TTSService(AIService):
and frame.skip_tts
):
await self.push_frame(frame, direction)
elif isinstance(frame, AggregatedTextFrame):
await self._push_tts_frames(frame)
elif (
isinstance(frame, TextFrame)
and not isinstance(frame, InterimTranscriptionFrame)
@@ -352,17 +442,10 @@ class TTSService(AIService):
# pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
sentence = self._text_aggregator.text
includes_inter_frame_spaces = self._aggregated_text_includes_inter_frame_spaces
# Reset aggregator state
aggregate = self._text_aggregator.text
await self._text_aggregator.reset()
self._processing_text = False
self._aggregated_text_includes_inter_frame_spaces = False
await self._push_tts_frames(
sentence, includes_inter_frame_spaces=includes_inter_frame_spaces
)
await self._push_tts_frames(AggregatedTextFrame(aggregate.text, aggregate.type))
if isinstance(frame, LLMFullResponseEndFrame):
if self._push_text_frames:
await self.push_frame(frame, direction)
@@ -371,8 +454,7 @@ class TTSService(AIService):
elif isinstance(frame, TTSSpeakFrame):
# Store if we were processing text or not so we can set it back.
processing_text = self._processing_text
# Assumption: text in TTSSpeakFrame does not include inter-frame spaces
await self._push_tts_frames(frame.text, includes_inter_frame_spaces=False)
await self._push_tts_frames(AggregatedTextFrame(frame.text, AggregationType.SENTENCE))
# We pause processing incoming frames because we are sending data to
# the TTS. We pause to avoid audio overlapping.
await self._maybe_pause_frame_processing()
@@ -464,19 +546,24 @@ class TTSService(AIService):
text: Optional[str] = None
if not self._aggregate_sentences:
text = frame.text
aggregated_by = "token"
else:
text = await self._text_aggregator.aggregate(frame.text)
# Assumption: whether inter-frame spaces are included shouldn't
# change during aggregation, so we can just use the latest frame's
# value
self._aggregated_text_includes_inter_frame_spaces = frame.includes_inter_frame_spaces
aggregate = await self._text_aggregator.aggregate(frame.text)
if aggregate:
text = aggregate.text
aggregated_by = aggregate.type
if text:
await self._push_tts_frames(
text, includes_inter_frame_spaces=frame.includes_inter_frame_spaces
)
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
await self._push_tts_frames(AggregatedTextFrame(text, aggregated_by))
async def _push_tts_frames(self, src_frame: AggregatedTextFrame):
type = src_frame.aggregated_by
text = src_frame.text
if type in self._skip_aggregator_types:
await self.push_frame(src_frame)
return
async def _push_tts_frames(self, text: str, includes_inter_frame_spaces: bool):
# Remove leading newlines only
text = text.lstrip("\n")
@@ -497,16 +584,40 @@ class TTSService(AIService):
await filter.reset_interruption()
text = await filter.filter(text)
if text:
await self.process_generator(self.run_tts(text))
if not text.strip():
await self.stop_processing_metrics()
return
# To support use cases that may want to know the text before it's spoken, we
# push the AggregatedTextFrame version before transforming and sending to TTS.
# However, we do not want to add this text to the assistant context until it
# is spoken, so we set append_to_context to False.
src_frame.append_to_context = False
await self.push_frame(src_frame)
# Note: Text transformations are meant to only affect the text sent to the TTS for
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
# services that support word-level timestamps, this CAN affect the resulting context
# since the TTSTextFrames are generated from the TTS output stream
transformed_text = text
for aggregation_type, transform in self._text_transforms:
if aggregation_type == type or aggregation_type == "*":
transformed_text = await transform(transformed_text, type)
await self.process_generator(self.run_tts(transformed_text))
await self.stop_processing_metrics()
if self._push_text_frames:
# We send the original text after the audio. This way, if we are
# interrupted, the text is not added to the assistant context.
frame = TTSTextFrame(text)
frame.includes_inter_frame_spaces = includes_inter_frame_spaces
# In TTS services that support word timestamps, the TTSTextFrames
# are pushed as words are spoken. However, in the case where the TTS service
# does not support word timestamps (i.e. _push_text_frames is True), we send
# the original (non-transformed) text after the TTS generation has completed.
# This way, if we are interrupted, the text is not added to the assistant
# context and the context that IS added does not include TTS-specific tags
# or transformations.
frame = TTSTextFrame(text, aggregated_by=type)
frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces
await self.push_frame(frame)
async def _stop_frame_handler(self):
@@ -633,9 +744,7 @@ class WordTTSService(TTSService):
frame = TTSStoppedFrame()
frame.pts = last_pts
else:
# Assumption: word-by-word text frames don't include spaces, so
# we can rely on the default includes_inter_frame_spaces=False
frame = TTSTextFrame(word)
frame = TTSTextFrame(word, aggregated_by=AggregationType.WORD)
frame.pts = self._initial_word_timestamp + timestamp
if frame:
last_pts = frame.pts

View File

@@ -36,7 +36,6 @@ class WebsocketService(ABC):
"""
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
self._reconnect_on_error = reconnect_on_error
self._reconnect_in_progress: bool = False # Add this flag
async def _verify_connection(self) -> bool:
"""Verify the websocket connection is active and responsive.
@@ -67,59 +66,6 @@ class WebsocketService(ABC):
await self._connect_websocket()
return await self._verify_connection()
async def _try_reconnect(
self,
max_retries: int = 3,
report_error: Optional[Callable[[ErrorFrame], Awaitable[None]]] = None,
) -> bool:
# Prevent concurrent reconnection attempts
if self._reconnect_in_progress:
logger.warning(f"{self} reconnect attempt aborted: already in progress")
return False
self._reconnect_in_progress = True
last_exception: Optional[Exception] = None
try:
for attempt in range(1, max_retries + 1):
try:
logger.warning(f"{self} reconnecting, attempt {attempt}")
if await self._reconnect_websocket(attempt):
logger.info(f"{self} reconnected successfully on attempt {attempt}")
return True
except Exception as e:
last_exception = e
logger.error(f"{self} reconnection attempt {attempt} failed: {e}")
if report_error:
await report_error(
ErrorFrame(f"{self} reconnection attempt {attempt} failed: {e}")
)
wait_time = exponential_backoff_time(attempt)
await asyncio.sleep(wait_time)
fatal_msg = f"{self} failed to reconnect after {max_retries} attempts"
if last_exception:
fatal_msg += f": {last_exception}"
logger.error(fatal_msg)
if report_error:
await report_error(ErrorFrame(fatal_msg, fatal=True))
return False
finally:
self._reconnect_in_progress = False
async def send_with_retry(self, message, report_error: Callable[[ErrorFrame], Awaitable[None]]):
"""Attempt to send a message, retrying after reconnect if necessary."""
try:
await self._websocket.send(message)
except Exception as e:
logger.error(f"{self} send failed: {e}, will try to reconnect")
# Try to reconnect before retrying
success = await self._try_reconnect(report_error=report_error)
if success:
logger.info(f"{self} reconnected successfully, will retry send the message")
# trying to send the message one more time
await self._websocket.send(message)
else:
logger.error(f"{self} send failed; unable to reconnect")
async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]):
"""Handle websocket message receiving with automatic retry logic.
@@ -130,9 +76,13 @@ class WebsocketService(ABC):
Args:
report_error: Callback function to report connection errors.
"""
retry_count = 0
MAX_RETRIES = 3
while True:
try:
await self._receive_messages()
retry_count = 0 # Reset counter on successful message receive
except ConnectionClosedOK as e:
# Normal closure, don't retry
logger.debug(f"{self} connection closed normally: {e}")
@@ -142,9 +92,21 @@ class WebsocketService(ABC):
logger.error(message)
if self._reconnect_on_error:
success = await self._try_reconnect(report_error=report_error)
if not success:
retry_count += 1
if retry_count >= MAX_RETRIES:
await report_error(ErrorFrame(message))
break
logger.warning(f"{self} connection error, will retry: {e}")
await report_error(ErrorFrame(message))
try:
if await self._reconnect_websocket(retry_count):
retry_count = 0 # Reset counter on successful reconnection
wait_time = exponential_backoff_time(retry_count)
await asyncio.sleep(wait_time)
except Exception as reconnect_error:
logger.error(f"{self} reconnection failed: {reconnect_error}")
else:
await report_error(ErrorFrame(message))
break

View File

@@ -203,8 +203,16 @@ async def run_test(
if not isinstance(frame, EndFrame) or not send_end_frame:
received_down_frames.append(frame)
print("received DOWN frames =", received_down_frames)
print("expected DOWN frames =", expected_down_frames)
down_frames_printed = "["
for frame in received_down_frames:
down_frames_printed += f"{frame.__class__.__name__}, "
down_frames_printed += "]"
expected_frames_printed = "["
for frame in expected_down_frames:
expected_frames_printed += f"{frame.__name__}, "
expected_frames_printed += "]"
print("received DOWN frames =", down_frames_printed)
print("expected DOWN frames =", expected_frames_printed)
assert len(received_down_frames) == len(expected_down_frames)

View File

@@ -18,7 +18,6 @@ Dependencies:
"""
import re
from dataclasses import dataclass
from typing import FrozenSet, List, Optional, Sequence, Tuple
import nltk
@@ -199,24 +198,7 @@ def parse_start_end_tags(
return (None, current_tag_index)
@dataclass
class TextPartForConcatenation:
"""Class representing a part of text for concatenation with concatenate_aggregated_text.
Attributes:
text: The text content.
includes_inter_part_spaces: Whether any necessary inter-frame
(leading/trailing) spaces are already included in the text.
"""
text: str
includes_inter_part_spaces: bool
def __str__(self):
return f"{self.name}(text: [{self.text}], includes_inter_part_spaces: {self.includes_inter_part_spaces})"
def concatenate_aggregated_text(text_parts: List[TextPartForConcatenation]) -> str:
def concatenate_aggregated_text(text_parts: List[str], add_spaces: bool) -> str:
"""Concatenate a list of text parts into a single string.
This function joins the provided list of text parts into a single string,
@@ -226,55 +208,15 @@ def concatenate_aggregated_text(text_parts: List[TextPartForConcatenation]) -> s
transcription services.
Args:
text_parts: A list of text parts to concatenate.
text_parts: A list of strings representing parts of text to concatenate.
add_spaces: Whether to add spaces between text parts during concatenation.
Returns:
A single concatenated string.
"""
result = ""
last_includes_inter_part_spaces = False
if not text_parts:
return result
def append_part(part: TextPartForConcatenation):
nonlocal result
nonlocal last_includes_inter_part_spaces
result += part.text
last_includes_inter_part_spaces = part.includes_inter_part_spaces
for part in text_parts:
# Part is empty.
# Skip.
if not part.text:
continue
# Result is as yet empty.
# Just append.
if not result:
append_part(part)
continue
if part.includes_inter_part_spaces and last_includes_inter_part_spaces:
# This part is part of an ongoing run that has spaces already included.
# Just append.
append_part(part)
elif not part.includes_inter_part_spaces and not last_includes_inter_part_spaces:
# This part is part of an ongoing run that has no spaces included.
# Add a space before appending.
result += " "
append_part(part)
else:
# This part represents a transition to a new run (spaces -> no spaces, or vice versa).
# Add a space if needed, before appending.
if not result[-1].isspace() and not part.text[0].isspace():
result += " "
append_part(part)
# NOTE: the above logic assumes that runs of text parts with
# includes_inter_part_spaces=True are well-formed, i.e. they're not
# actually multiple separate runs with a space-less boundary, like
# "hello ", "world.", "goodnight ", "moon."
# Concatenate text parts with or without spaces based on the flag
separator = " " if add_spaces else ""
result = separator.join(text_parts)
# Clean up any excessive whitespace
result = result.strip()

View File

@@ -12,9 +12,46 @@ aggregated text should be sent for speech synthesis.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class AggregationType(str, Enum):
"""Built-in aggregation strings."""
SENTENCE = "sentence"
WORD = "word"
def __str__(self):
return self.value
@dataclass
class Aggregation:
"""Data class representing aggregated text and its type.
An Aggregation object is created whenever a stream of text is aggregated by
a text aggregator. It contains the aggregated text and a type indicating
the nature of the aggregation.
Parameters:
text: The aggregated text content.
type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
"""
text: str
type: str
def __str__(self) -> str:
"""Return a string representation of the aggregation.
Returns:
A descriptive string showing the type and text of the aggregation.
"""
return f"Aggregation by {self.type}: {self.text}"
class BaseTextAggregator(ABC):
"""Base class for text aggregators in the Pipecat framework.
@@ -30,7 +67,7 @@ class BaseTextAggregator(ABC):
@property
@abstractmethod
def text(self) -> str:
def text(self) -> Aggregation:
"""Get the currently aggregated text.
Subclasses must implement this property to return the text that has
@@ -42,12 +79,13 @@ class BaseTextAggregator(ABC):
pass
@abstractmethod
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate the specified text with the currently accumulated text.
This method should be implemented to define how the new text contributes
to the aggregation process. It returns the updated aggregated text if
it's ready to be processed, or None otherwise.
to the aggregation process. It returns the aggregated text and a string
describing how it was aggregated if it's ready to be processed,
or None otherwise.
Subclasses should implement their specific logic for:

View File

@@ -26,6 +26,7 @@ class BaseTextFilter(ABC):
behavior, settings management, and interruption handling logic.
"""
@abstractmethod
async def update_settings(self, settings: Mapping[str, Any]):
"""Update the filter's configuration settings.
@@ -52,6 +53,7 @@ class BaseTextFilter(ABC):
"""
pass
@abstractmethod
async def handle_interruption(self):
"""Handle interruption events in the processing pipeline.
@@ -60,6 +62,7 @@ class BaseTextFilter(ABC):
"""
pass
@abstractmethod
async def reset_interruption(self):
"""Reset the filter state after an interruption has been handled.

View File

@@ -8,19 +8,41 @@
This module provides an aggregator that identifies and processes content between
pattern pairs (like XML tags or custom delimiters) in streaming text, with
support for custom handlers and configurable pattern removal.
support for custom handlers and configurable actions for when a pattern is found.
"""
import re
from typing import Awaitable, Callable, Optional, Tuple
from enum import Enum
from typing import Awaitable, Callable, List, Optional, Tuple
from loguru import logger
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class PatternMatch:
class MatchAction(Enum):
"""Actions to take when a pattern pair is matched.
Parameters:
REMOVE: The text along with its delimiters will be removed from the streaming text.
Sentence aggregation will continue on as if this text did not exist.
KEEP: The delimiters will be removed, but the content between them will be kept.
Sentence aggregation will continue on with the internal text included.
AGGREGATE: The delimiters will be removed and the content between will be treated
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
will be returned. Then the aggregation will continue on sentence matching after
the closing delimiter is found. The content between the delimiters is not
aggregated by sentence. It is aggregated as one single block of text.
"""
REMOVE = "remove"
KEEP = "keep"
AGGREGATE = "aggregate"
class PatternMatch(Aggregation):
"""Represents a matched pattern pair with its content.
A PatternMatch object is created when a complete pattern pair is found
@@ -29,25 +51,25 @@ class PatternMatch:
content between the patterns.
"""
def __init__(self, pattern_id: str, full_match: str, content: str):
def __init__(self, content: str, type: str, full_match: str):
"""Initialize a pattern match.
Args:
pattern_id: The identifier of the matched pattern pair.
type: The type of the matched pattern pair. It should be representative
of the content type (e.g., 'sentence', 'code', 'speaker', 'custom').
full_match: The complete text including start and end patterns.
content: The text content between the start and end patterns.
"""
self.pattern_id = pattern_id
super().__init__(text=content, type=type)
self.full_match = full_match
self.content = content
def __str__(self) -> str:
"""Return a string representation of the pattern match.
Returns:
A descriptive string showing the pattern ID and content.
A descriptive string showing the pattern type and content.
"""
return f"PatternMatch(id={self.pattern_id}, content={self.content})"
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
class PatternPairAggregator(BaseTextAggregator):
@@ -55,16 +77,21 @@ class PatternPairAggregator(BaseTextAggregator):
This aggregator buffers text until it can identify complete pattern pairs
(defined by start and end patterns), processes the content between these
patterns using registered handlers, and returns text at sentence boundaries.
It's particularly useful for processing structured content in streaming text,
such as XML tags, markdown formatting, or custom delimiters.
patterns using registered handlers. By default, its aggregation method
returns text at sentence boundaries, and remove the content found between
any matched patterns. However, matched patterns can also be configured to
returned as a separate aggregation object containing the content between
their start and end patterns or left in, so that only the delimiters are
removed and a callback can be triggered.
This aggregator is particularly useful for processing structured content in
streaming text, such as XML tags, markdown formatting, or custom delimiters.
The aggregator ensures that patterns spanning multiple text chunks are
correctly identified and handles cases where patterns contain sentence
boundaries.
correctly identified.
"""
def __init__(self):
def __init__(self, **kwargs):
"""Initialize the pattern pair aggregator.
Creates an empty aggregator with no patterns or handlers registered.
@@ -75,16 +102,23 @@ class PatternPairAggregator(BaseTextAggregator):
self._handlers = {}
@property
def text(self) -> str:
"""Get the currently buffered text.
def text(self) -> Aggregation:
"""Get the currently aggregated text.
Returns:
The current text buffer content that hasn't been processed yet.
The text that has been accumulated in the buffer.
"""
return self._text
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start:
return Aggregation(self._text, pattern_start[1].get("type", AggregationType.SENTENCE))
return Aggregation(self._text, AggregationType.SENTENCE)
def add_pattern_pair(
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
def add_pattern(
self,
type: str,
start_pattern: str,
end_pattern: str,
action: MatchAction = MatchAction.REMOVE,
) -> "PatternPairAggregator":
"""Add a pattern pair to detect in the text.
@@ -93,41 +127,94 @@ class PatternPairAggregator(BaseTextAggregator):
the end pattern, and treat the content between them as a match.
Args:
pattern_id: Unique identifier for this pattern pair.
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
those are reserved for the default behavior.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
remove_match: Whether to remove the matched content from the text.
action: What to do when a complete pattern is matched:
- MatchAction.REMOVE: Remove the matched pattern from the text.
- MatchAction.KEEP: Keep the matched pattern in the text and treat it as
normal text. This allows you to register handlers for
the pattern without affecting the aggregation logic.
- MatchAction.AGGREGATE: Return the matched pattern as a separate
aggregation object.
Returns:
Self for method chaining.
"""
self._patterns[pattern_id] = {
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
raise ValueError(
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
)
self._patterns[type] = {
"start": start_pattern,
"end": end_pattern,
"remove_match": remove_match,
"type": type,
"action": action,
}
return self
def add_pattern_pair(
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
):
"""Add a pattern pair to detect in the text.
.. deprecated:: 0.0.95
This function is deprecated and will be removed in a future version.
Use `add_pattern` with a type and MatchAction instead.
This method calls `add_pattern` setting type with the provided pattern_id and action
to either MatchAction.REMOVE or MatchAction.KEEP based on `remove_match`.
Args:
pattern_id: Identifier for this pattern pair. Should be unique and ideally descriptive.
(e.g., 'code', 'speaker', 'custom'). pattern_id can not be 'sentence' or 'word'
as those arereserved for the default behavior.
start_pattern: Pattern that marks the beginning of content.
end_pattern: Pattern that marks the end of content.
remove_match: If True, the matched pattern will be removed from the text. (Same as MatchAction.REMOVE)
If False, it will be kept and treated as normal text. (Same as MatchAction.KEEP)
"""
import warnings
with warnings.catch_warnings():
warnings.simplefilter("once")
warnings.warn(
"add_pattern_pair with a pattern_id or remove_match is deprecated and will be"
" removed in a future version. Use add_pattern with a type and MatchAction instead",
DeprecationWarning,
stacklevel=2,
)
action = MatchAction.REMOVE if remove_match else MatchAction.KEEP
return self.add_pattern(
type=pattern_id,
start_pattern=start_pattern,
end_pattern=end_pattern,
action=action,
)
def on_pattern_match(
self, pattern_id: str, handler: Callable[[PatternMatch], Awaitable[None]]
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
) -> "PatternPairAggregator":
"""Register a handler for when a pattern pair is matched.
The handler will be called whenever a complete match for the
specified pattern ID is found in the text.
specified type is found in the text.
Args:
pattern_id: ID of the pattern pair to match.
type: The type of the pattern pair to trigger the handler.
handler: Async function to call when pattern is matched.
The function should accept a PatternMatch object.
Returns:
Self for method chaining.
"""
self._handlers[pattern_id] = handler
self._handlers[type] = handler
return self
async def _process_complete_patterns(self, text: str) -> Tuple[str, bool]:
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
"""Process all complete pattern pairs in the text.
Searches for all complete pattern pairs in the text, calls the
@@ -137,19 +224,19 @@ class PatternPairAggregator(BaseTextAggregator):
text: The text to process for pattern matches.
Returns:
Tuple of (processed_text, was_modified) where:
Tuple of (all_matches, processed_text) where:
- processed_text is the text after processing patterns
- was_modified indicates whether any changes were made
- all_matches is a list of all pattern matches found. Note: There really should only ever be 1.
- processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text.
"""
all_matches = []
processed_text = text
modified = False
for pattern_id, pattern_info in self._patterns.items():
for type, pattern_info in self._patterns.items():
# Escape special regex characters in the patterns
start = re.escape(pattern_info["start"])
end = re.escape(pattern_info["end"])
remove_match = pattern_info["remove_match"]
action = pattern_info["action"]
# Create regex to match from start pattern to end pattern
# The .*? is non-greedy to handle nested patterns
@@ -164,25 +251,24 @@ class PatternPairAggregator(BaseTextAggregator):
full_match = match.group(0) # Full match including patterns
# Create pattern match object
pattern_match = PatternMatch(
pattern_id=pattern_id, full_match=full_match, content=content
)
pattern_match = PatternMatch(content=content, type=type, full_match=full_match)
# Call the appropriate handler if registered
if pattern_id in self._handlers:
if type in self._handlers:
try:
await self._handlers[pattern_id](pattern_match)
await self._handlers[type](pattern_match)
except Exception as e:
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
logger.error(f"Error in pattern handler for {type}: {e}")
# Remove the pattern from the text if configured
if remove_match:
if action == MatchAction.REMOVE:
processed_text = processed_text.replace(full_match, "", 1)
modified = True
else:
all_matches.append(pattern_match)
return processed_text, modified
return all_matches, processed_text
def _has_incomplete_patterns(self, text: str) -> bool:
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
"""Check if text contains incomplete pattern pairs.
Determines whether the text contains any start patterns without
@@ -192,9 +278,10 @@ class PatternPairAggregator(BaseTextAggregator):
text: The text to check for incomplete patterns.
Returns:
True if there are incomplete patterns, False otherwise.
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
or None if no patterns are found or all patterns are complete.
"""
for pattern_id, pattern_info in self._patterns.items():
for type, pattern_info in self._patterns.items():
start = pattern_info["start"]
end = pattern_info["end"]
@@ -203,12 +290,16 @@ class PatternPairAggregator(BaseTextAggregator):
end_count = text.count(end)
# If there are more starts than ends, we have incomplete patterns
# Again, this is written generically but there only ever should
# be one pattern active at a time, so the counts should be 0 or 1.
# Which is why we base the return on the first found.
if start_count > end_count:
return True
start_index = text.find(start)
return [start_index, pattern_info]
return False
return None
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[PatternMatch]:
"""Aggregate text and process pattern pairs.
This method adds the new text to the buffer, processes any complete pattern
@@ -227,16 +318,34 @@ class PatternPairAggregator(BaseTextAggregator):
self._text += text
# Process any complete patterns in the buffer
processed_text, modified = await self._process_complete_patterns(self._text)
patterns, processed_text = await self._process_complete_patterns(self._text)
# Only update the buffer if modifications were made
if modified:
self._text = processed_text
self._text = processed_text
if len(patterns) > 0:
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
return patterns[0]
# Check if we have incomplete patterns
if self._has_incomplete_patterns(self._text):
# Still waiting for complete patterns
return None
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, return None
if (
pattern_start[0] == 0
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
):
return None
# Otherwise, strip the text up to the start pattern and return it
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
# Find sentence boundary if no incomplete patterns
eos_marker = match_endofsentence(self._text)
@@ -244,7 +353,7 @@ class PatternPairAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return result
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
# No complete sentence found yet
return None

View File

@@ -14,7 +14,7 @@ text processing scenarios.
from typing import Optional
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class SimpleTextAggregator(BaseTextAggregator):
@@ -33,15 +33,15 @@ class SimpleTextAggregator(BaseTextAggregator):
self._text = ""
@property
def text(self) -> str:
def text(self) -> Aggregation:
"""Get the currently aggregated text.
Returns:
The text that has been accumulated in the buffer.
"""
return self._text
return Aggregation(self._text, AggregationType.SENTENCE)
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate text and return completed sentences.
Adds the new text to the buffer and checks for end-of-sentence markers.
@@ -64,7 +64,7 @@ class SimpleTextAggregator(BaseTextAggregator):
result = self._text[:eos_end_marker]
self._text = self._text[eos_end_marker:]
return result
return Aggregation(result, AggregationType.SENTENCE) if result else None
async def handle_interruption(self):
"""Handle interruptions by clearing the text buffer.

View File

@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
from typing import Optional, Sequence
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
class SkipTagsAggregator(BaseTextAggregator):
@@ -49,9 +49,9 @@ class SkipTagsAggregator(BaseTextAggregator):
Returns:
The current text buffer content that hasn't been processed yet.
"""
return self._text
return Aggregation(self._text, AggregationType.SENTENCE)
async def aggregate(self, text: str) -> Optional[str]:
async def aggregate(self, text: str) -> Optional[Aggregation]:
"""Aggregate text while respecting tag boundaries.
This method adds the new text to the buffer, processes any complete
@@ -80,7 +80,7 @@ class SkipTagsAggregator(BaseTextAggregator):
# Extract text up to the sentence boundary
result = self._text[:eos_marker]
self._text = self._text[eos_marker:]
return result
return Aggregation(result, AggregationType.SENTENCE)
# No complete sentence found yet
return None

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
from opentelemetry import context as context_api
from opentelemetry import trace
from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.utils.tracing.service_attributes import (
add_gemini_live_span_attributes,
@@ -399,6 +399,11 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
if hasattr(self, "get_llm_adapter"):
adapter = self.get_llm_adapter()
messages = adapter.get_messages_for_logging(context)
elif hasattr(context, "get_messages"):
# Fallback for unknown context types
messages = context.get_messages()
elif hasattr(context, "messages"):
messages = context.messages
# Serialize messages if available
if messages:
@@ -419,10 +424,15 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
if hasattr(self, "get_llm_adapter") and hasattr(context, "tools"):
adapter = self.get_llm_adapter()
tools = adapter.from_standard_tools(context.tools)
elif hasattr(context, "tools"):
# Fallback for unknown context types
tools = context.tools
# Serialize and count tools if available
# Check if tools is not None and not NOT_GIVEN
if tools is not None and tools is not NOT_GIVEN:
# Check if tools is not None and not NOT_GIVEN (using attribute check as fallback)
if tools is not None and not (
hasattr(tools, "__name__") and tools.__name__ == "NOT_GIVEN"
):
serialized_tools = json.dumps(tools)
tool_count = len(tools) if isinstance(tools, list) else 1

View File

@@ -1005,53 +1005,3 @@ class TestLLMAssistantAggregator(
) -> Optional[LLMAssistantAggregatorParams]:
kwargs.pop("expect_stripped_words", None)
return LLMAssistantAggregatorParams(**kwargs) if kwargs else None
async def test_multiple_text_mixed(self):
assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass"
assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass"
context = self.CONTEXT_CLASS()
aggregator = self.AGGREGATOR_CLASS(
context, params=self.create_assistant_aggregator_params(expect_stripped_words=False)
)
# The newer LLMAssistantAggregator expects TextFrames to declare
# when they include inter-frame spaces.
def make_text_frame(text: str, includes_spaces: bool) -> TextFrame:
frame = TextFrame(text=text)
frame.includes_inter_frame_spaces = includes_spaces
return frame
frames_to_send = [
LLMFullResponseStartFrame(),
make_text_frame("Hello ", includes_spaces=True),
make_text_frame("Pipecat. ", includes_spaces=True),
make_text_frame("Here's some", includes_spaces=True),
make_text_frame(
" code:", includes_spaces=True
), # Validates ending includes_inter_frame_spaces run with no space
make_text_frame("```python\nprint('Hello, World!')\n```", includes_spaces=False),
make_text_frame(
"```javascript\nconsole.log('Hello, World!');\n```", includes_spaces=False
),
make_text_frame(
" And some more: ", includes_spaces=True
), # Validates starting includes_inter_frame_spaces run with a space and ending it with no space
make_text_frame("```html\n<div>Hello, World!</div>\n```", includes_spaces=False),
make_text_frame(
"Hope that ", includes_spaces=True
), # Validates starting includes_inter_frame_spaces run with no space
make_text_frame("helps!", includes_spaces=True),
LLMFullResponseEndFrame(),
]
expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES]
await run_test(
aggregator,
frames_to_send=frames_to_send,
expected_down_frames=expected_down_frames,
)
self.check_message_content(
context,
0,
"Hello Pipecat. Here's some code: ```python\nprint('Hello, World!')\n``` ```javascript\nconsole.log('Hello, World!');\n``` And some more: ```html\n<div>Hello, World!</div>\n``` Hope that helps!",
)

View File

@@ -7,30 +7,42 @@
import unittest
from unittest.mock import AsyncMock
from pipecat.utils.text.pattern_pair_aggregator import PatternMatch, PatternPairAggregator
from pipecat.utils.text.pattern_pair_aggregator import (
MatchAction,
PatternMatch,
PatternPairAggregator,
)
class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.aggregator = PatternPairAggregator()
self.test_handler = AsyncMock()
self.code_handler = AsyncMock()
# Add a test pattern
self.aggregator.add_pattern_pair(
pattern_id="test_pattern",
start_pattern="<test>",
end_pattern="</test>",
remove_match=True,
)
self.aggregator.add_pattern(
type="code_pattern",
start_pattern="<code>",
end_pattern="</code>",
action=MatchAction.AGGREGATE,
)
# Register the mock handler
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
async def test_pattern_match_and_removal(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Hello <test>pattern")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
self.assertEqual(self.aggregator.text.type, "test_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</test>!")
@@ -39,20 +51,49 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.pattern_id, "test_pattern")
self.assertEqual(call_args.type, "test_pattern")
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
self.assertEqual(call_args.content, "pattern content")
self.assertEqual(call_args.text, "pattern content")
# The exclamation point should be treated as a sentence boundary,
# so the result should include just text up to and including "!"
self.assertEqual(result, "Hello !")
self.assertEqual(result.text, "Hello !")
self.assertEqual(result.type, "sentence")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
self.assertEqual(result, " This is another sentence.")
self.assertEqual(result.text, " This is another sentence.")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_pattern_match_and_aggregate(self):
# First part doesn't complete the pattern
result = await self.aggregator.aggregate("Here is code <code>pattern")
self.assertEqual(result.text, "Here is code ")
self.assertEqual(self.aggregator.text.text, "<code>pattern")
self.assertEqual(self.aggregator.text.type, "code_pattern")
# Second part completes the pattern and includes an exclamation point
result = await self.aggregator.aggregate(" content</code>")
# Verify the handler was called with correct PatternMatch object
self.code_handler.assert_called_once()
call_args = self.code_handler.call_args[0][0]
self.assertIsInstance(call_args, PatternMatch)
self.assertEqual(call_args.type, "code_pattern")
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
self.assertEqual(call_args.text, "pattern content")
self.assertEqual(result.text, "pattern content")
self.assertEqual(result.type, "code_pattern")
# Next sentence should be processed separately
result = await self.aggregator.aggregate(" This is another sentence.")
self.assertEqual(result.text, " This is another sentence.")
self.assertEqual(result.type, "sentence")
# Buffer should be empty after returning a complete sentence
self.assertEqual(self.aggregator.text.text, "")
async def test_incomplete_pattern(self):
# Add text with incomplete pattern
@@ -65,26 +106,30 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
self.test_handler.assert_not_called()
# Buffer should contain the incomplete text
self.assertEqual(self.aggregator.text, "Hello <test>pattern content")
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern content")
self.assertEqual(self.aggregator.text.type, "test_pattern")
# Reset and confirm buffer is cleared
await self.aggregator.reset()
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_multiple_patterns(self):
# Set up multiple patterns and handlers
voice_handler = AsyncMock()
emphasis_handler = AsyncMock()
self.aggregator.add_pattern_pair(
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
self.aggregator.add_pattern(
type="voice",
start_pattern="<voice>",
end_pattern="</voice>",
action=MatchAction.REMOVE,
)
self.aggregator.add_pattern_pair(
pattern_id="emphasis",
self.aggregator.add_pattern(
type="emphasis",
start_pattern="<em>",
end_pattern="</em>",
remove_match=False, # Keep emphasis tags
action=MatchAction.KEEP, # Keep emphasis tags
)
self.aggregator.on_pattern_match("voice", voice_handler)
@@ -97,19 +142,19 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Both handlers should be called with correct data
voice_handler.assert_called_once()
voice_match = voice_handler.call_args[0][0]
self.assertEqual(voice_match.pattern_id, "voice")
self.assertEqual(voice_match.content, "female")
self.assertEqual(voice_match.type, "voice")
self.assertEqual(voice_match.text, "female")
emphasis_handler.assert_called_once()
emphasis_match = emphasis_handler.call_args[0][0]
self.assertEqual(emphasis_match.pattern_id, "emphasis")
self.assertEqual(emphasis_match.content, "very")
self.assertEqual(emphasis_match.type, "emphasis")
self.assertEqual(emphasis_match.text, "very")
# Voice pattern should be removed, emphasis pattern should remain
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
# Buffer should be empty
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
async def test_handle_interruption(self):
# Start with incomplete pattern
@@ -120,7 +165,7 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
await self.aggregator.handle_interruption()
# Buffer should be cleared
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")
# Handler should not have been called
self.test_handler.assert_not_called()
@@ -138,10 +183,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
# Handler should be called with entire content
self.test_handler.assert_called_once()
call_args = self.test_handler.call_args[0][0]
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
# Pattern should be removed, resulting in text with sentences merged
self.assertEqual(result, "Hello Final sentence.")
self.assertEqual(result.text, "Hello Final sentence.")
# Buffer should be empty
self.assertEqual(self.aggregator.text, "")
self.assertEqual(self.aggregator.text.text, "")

View File

@@ -13,6 +13,7 @@ import pytest
from aiohttp import web
from pipecat.frames.frames import (
AggregatedTextFrame,
ErrorFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
@@ -74,6 +75,7 @@ async def test_run_piper_tts_success(aiohttp_client):
]
expected_returned_frames = [
AggregatedTextFrame,
TTSStartedFrame,
TTSAudioRawFrame,
TTSAudioRawFrame,
@@ -121,7 +123,7 @@ async def test_run_piper_tts_error(aiohttp_client):
TTSSpeakFrame(text="Error case."),
]
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame]
expected_up_frames = [ErrorFrame]

View File

@@ -15,15 +15,20 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reset_aggregations(self):
assert await self.aggregator.aggregate("Hello ") == None
assert self.aggregator.text == "Hello "
assert self.aggregator.text.text == "Hello "
await self.aggregator.reset()
assert self.aggregator.text == ""
assert self.aggregator.text.text == ""
async def test_simple_sentence(self):
assert await self.aggregator.aggregate("Hello ") == None
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
assert self.aggregator.text == ""
aggregate = await self.aggregator.aggregate("Pipecat!")
assert aggregate.text == "Hello Pipecat!"
assert aggregate.type == "sentence"
assert self.aggregator.text.text == ""
async def test_multiple_sentences(self):
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
assert await self.aggregator.aggregate("you?") == " How are you?"
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
assert aggregate.text == "Hello Pipecat!"
assert self.aggregator.text.text == " How are "
aggregate = await self.aggregator.aggregate("you?")
assert aggregate.text == " How are you?"

View File

@@ -18,16 +18,18 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
# No tags involved, aggregate at end of sentence.
result = await self.aggregator.aggregate("Hello Pipecat!")
self.assertEqual(result, "Hello Pipecat!")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "Hello Pipecat!")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
async def test_basic_tags(self):
await self.aggregator.reset()
# Tags involved, avoid aggregation during tags.
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
self.assertEqual(result.type, "sentence")
self.assertEqual(self.aggregator.text.text, "")
async def test_streaming_tags(self):
await self.aggregator.reset()
@@ -35,20 +37,22 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
# Tags involved, stream small chunk of texts.
result = await self.aggregator.aggregate("My email is <sp")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <sp")
self.assertEqual(self.aggregator.text.text, "My email is <sp")
result = await self.aggregator.aggregate("ell>foo.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
result = await self.aggregator.aggregate("bar@pipecat.")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
result = await self.aggregator.aggregate("ai</spe")
self.assertIsNone(result)
self.assertEqual(self.aggregator.text, "My email is <spell>foo.bar@pipecat.ai</spe")
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
self.assertEqual(self.aggregator.text.type, "sentence")
result = await self.aggregator.aggregate("ll>.")
self.assertEqual(result, "My email is <spell>foo.bar@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text, "")
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
self.assertEqual(self.aggregator.text.text, "")
self.assertEqual(self.aggregator.text.type, "sentence")

View File

@@ -11,6 +11,7 @@ from datetime import datetime, timezone
from typing import List, Tuple, cast
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
@@ -130,11 +131,11 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(), # Wait for StartedSpeaking to process
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world!"),
TTSTextFrame(text="How"),
TTSTextFrame(text="are"),
TTSTextFrame(text="you?"),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="How", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="are", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="you?", aggregated_by=AggregationType.WORD),
SleepFrame(), # Wait for text frames to queue
BotStoppedSpeakingFrame(),
]
@@ -195,9 +196,9 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text=""), # Empty text
TTSTextFrame(text=" "), # Just whitespace
TTSTextFrame(text="\n"), # Just newline
TTSTextFrame(text="", aggregated_by=AggregationType.WORD), # Empty text
TTSTextFrame(text=" ", aggregated_by=AggregationType.WORD), # Just whitespace
TTSTextFrame(text="\n", aggregated_by=AggregationType.WORD), # Just newline
BotStoppedSpeakingFrame(),
# Pipeline ends here; run_test will automatically send EndFrame
]
@@ -235,14 +236,14 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world!"),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
SleepFrame(),
InterruptionFrame(), # User interrupts here
SleepFrame(),
BotStartedSpeakingFrame(),
TTSTextFrame(text="New"),
TTSTextFrame(text="response"),
TTSTextFrame(text="New", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="response", aggregated_by=AggregationType.WORD),
SleepFrame(),
BotStoppedSpeakingFrame(),
]
@@ -299,8 +300,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world"),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
# Pipeline ends here; run_test will automatically send EndFrame
]
@@ -338,8 +339,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Hello"),
TTSTextFrame(text="world"),
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
SleepFrame(), # Ensure messages are processed
CancelFrame(),
]
@@ -401,8 +402,8 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
frames_to_send = [
BotStartedSpeakingFrame(),
SleepFrame(),
TTSTextFrame(text="Assistant"),
TTSTextFrame(text="message"),
TTSTextFrame(text="Assistant", aggregated_by=AggregationType.WORD),
TTSTextFrame(text="message", aggregated_by=AggregationType.WORD),
BotStoppedSpeakingFrame(),
]
@@ -439,7 +440,7 @@ class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
# Test the specific pattern shared
def make_tts_text_frame(text: str) -> TTSTextFrame:
frame = TTSTextFrame(text=text)
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.WORD)
frame.includes_inter_frame_spaces = True
return frame

15
uv.lock generated
View File

@@ -36,12 +36,12 @@ wheels = [
[[package]]
name = "aic-sdk"
version = "1.1.0"
version = "1.0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/99/83/bf38b95d98c67b8ebc574fb4a4f23c07a3740b51992d7522976173d30b98/aic_sdk-1.1.0.tar.gz", hash = "sha256:04e08df695581c8cb4db8acca20e73815e9f449e7bd08e0162fd55518c727963", size = 34954, upload-time = "2025-11-11T20:45:24.25Z" }
sdist = { url = "https://files.pythonhosted.org/packages/51/90/b02e853e863c303f8456c689b42ac24ad403b781adc9642d0a91ed4bed7e/aic_sdk-1.0.2.tar.gz", hash = "sha256:239097dd3aaa8a8a0fd7542b75d2510cb34144caec796370639b7c636acbc56e", size = 32059, upload-time = "2025-08-24T09:20:03.9Z" }
[[package]]
name = "aioboto3"
@@ -4647,7 +4647,7 @@ docs = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'moondream'", specifier = "~=1.10.0" },
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.1.0" },
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.0.1" },
{ name = "aioboto3", marker = "extra == 'aws'", specifier = "~=15.0.0" },
{ name = "aiofiles", specifier = ">=24.1.0,<25" },
{ name = "aiohttp", specifier = ">=3.11.12,<4" },
@@ -4727,7 +4727,7 @@ requires-dist = [
{ name = "resampy", specifier = "~=0.4.3" },
{ name = "sarvamai", marker = "extra == 'sarvam'", specifier = "==0.1.21" },
{ name = "sentry-sdk", marker = "extra == 'sentry'", specifier = ">=2.28.0,<3" },
{ name = "simli-ai", marker = "extra == 'simli'", specifier = "~=1.0.3" },
{ name = "simli-ai", marker = "extra == 'simli'", specifier = "~=0.1.25" },
{ name = "soundfile", marker = "extra == 'soundfile'", specifier = "~=0.13.1" },
{ name = "soxr", specifier = "~=0.5.0" },
{ name = "speechmatics-rt", marker = "extra == 'speechmatics'", specifier = ">=0.5.0" },
@@ -6496,19 +6496,18 @@ wheels = [
[[package]]
name = "simli-ai"
version = "1.0.3"
version = "0.1.25"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiortc" },
{ name = "av" },
{ name = "httpx" },
{ name = "livekit" },
{ name = "numpy" },
{ name = "websockets" },
]
sdist = { url = "https://files.pythonhosted.org/packages/81/03/b0b3e12c68fd3f9c57f6afeee67841349e4866b88760f413357af3043ae4/simli_ai-1.0.3.tar.gz", hash = "sha256:e96b0621a1dbd9582b2ae3d51eefd4995983b49c1f1061eb9239707b15a1ee27", size = 13350, upload-time = "2025-11-13T12:22:32.514Z" }
sdist = { url = "https://files.pythonhosted.org/packages/64/6a/b28f90baf76f6a60865985f6233ff44abc72d45b66b76658bff3961e20a7/simli_ai-0.1.25.tar.gz", hash = "sha256:7a00b3426dc26a6a421641072c3e49014b7950c621cf4544152f35c58d13fcff", size = 13182, upload-time = "2025-11-06T16:27:08.862Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5e/d1/dc382ba529de0d2d51f35e9bfd20b41d8f5c96404a3aa24bae97a5a5e51f/simli_ai-1.0.3-py3-none-any.whl", hash = "sha256:ffafa7540aa28833e207be8f3b199367c7f500dac1a8ba0108395bfb7d8362bc", size = 13863, upload-time = "2025-11-13T12:22:31.218Z" },
{ url = "https://files.pythonhosted.org/packages/ac/57/ae1032fd88214ea4ee6d3028c817c12a999eb90a67766bbab31e9819385a/simli_ai-0.1.25-py3-none-any.whl", hash = "sha256:7d01f65321dc9052f25e15d0463af6a20a86c6d37d9a7b3a2c4b01cbec0a54ed", size = 13651, upload-time = "2025-11-06T16:27:07.765Z" },
]
[[package]]