Compare commits
69 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
89f0ff17c0 | ||
|
|
c024eb7b8c | ||
|
|
608570e89d | ||
|
|
3ad61a8a04 | ||
|
|
4c4bae2db6 | ||
|
|
901b6b5913 | ||
|
|
71cd0f1c87 | ||
|
|
a2a419e6db | ||
|
|
bbbbdc459a | ||
|
|
d203528dad | ||
|
|
4bcca7956e | ||
|
|
68a4cf4c68 | ||
|
|
0508ddddfb | ||
|
|
54a4d8a9f8 | ||
|
|
38af514d95 | ||
|
|
6aa80c0b8e | ||
|
|
e720573e60 | ||
|
|
541a43905b | ||
|
|
707df913cd | ||
|
|
3f3d757581 | ||
|
|
7c781ce816 | ||
|
|
f3efc9da00 | ||
|
|
827a70104d | ||
|
|
a40327305c | ||
|
|
168af44429 | ||
|
|
5f8433476c | ||
|
|
6a6fea74f5 | ||
|
|
91b557ecbf | ||
|
|
be85291414 | ||
|
|
09f171b69d | ||
|
|
929fd98958 | ||
|
|
1cfbfcaf11 | ||
|
|
cd5a3c13bd | ||
|
|
9b871b0cc5 | ||
|
|
0d499a8aa3 | ||
|
|
45292ab13d | ||
|
|
be6ea0dbf6 | ||
|
|
fb18ae174e | ||
|
|
c4506523ab | ||
|
|
b360cb31dc | ||
|
|
07f104199c | ||
|
|
bc1949b4bf | ||
|
|
2035dd8b39 | ||
|
|
24c8189327 | ||
|
|
998ac32627 | ||
|
|
50645c1c4f | ||
|
|
8ce29ee8f2 | ||
|
|
7b8aeef4cc | ||
|
|
6a24457f0e | ||
|
|
2c01c2b5b3 | ||
|
|
1c2e114fa2 | ||
|
|
0f137e36c2 | ||
|
|
b7f12a96f1 | ||
|
|
3331f71e17 | ||
|
|
55d200e2d1 | ||
|
|
3fae00e067 | ||
|
|
78cdefd191 | ||
|
|
42502a4f3b | ||
|
|
fc67cc3302 | ||
|
|
241ab19228 | ||
|
|
c08e8ec8fb | ||
|
|
eb9bc9644e | ||
|
|
3a306dae90 | ||
|
|
e503ea7466 | ||
|
|
c42cc8254f | ||
|
|
a8e21f7d5d | ||
|
|
c6ef8de578 | ||
|
|
fc571fba42 | ||
|
|
248206e234 |
42
.github/workflows/update-lockfile.yaml
vendored
42
.github/workflows/update-lockfile.yaml
vendored
@@ -1,42 +0,0 @@
|
||||
name: Update lockfile
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'pyproject.toml'
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch: # Allows manual triggering from GitHub UI
|
||||
|
||||
jobs:
|
||||
update-lockfile:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
# This gives the workflow permission to push back to the repo
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v1
|
||||
|
||||
- name: Update lockfile
|
||||
run: uv lock
|
||||
|
||||
- name: Check for changes
|
||||
id: verify-changed-files
|
||||
run: |
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
echo "changed=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "changed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Commit lockfile
|
||||
if: steps.verify-changed-files.outputs.changed == 'true'
|
||||
run: |
|
||||
git config --local user.email "action@github.com"
|
||||
git config --local user.name "GitHub Action"
|
||||
git add uv.lock
|
||||
git commit -m "chore: update uv.lock after dependency changes"
|
||||
git push
|
||||
99
CHANGELOG.md
99
CHANGELOG.md
@@ -5,6 +5,105 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.0.80] - 2025-08-13
|
||||
|
||||
### Added
|
||||
|
||||
- Added `GeminiTTSService` which uses Google Gemini to generate TTS output. The
|
||||
Gemini model can be prompted to insert styled speech to control the TTS
|
||||
output.
|
||||
|
||||
- Added Exotel support to Pipecat's development runner. You can now connect
|
||||
using the runner with `uv run bot.py -t exotel` and an ngrok connection to
|
||||
HTTP port 7860.
|
||||
|
||||
- Added `enable_direct_mode` argument to `FrameProcessor`. The direct mode is
|
||||
for processors which require very little I/O or compute resources, that is
|
||||
processors that can perform their task almost immediately. These type of
|
||||
processors don't need any of the internal tasks and queues usually created by
|
||||
frame processors which means overall application performance might be slightly
|
||||
increased. Use with care.
|
||||
|
||||
- Added TTFB metrics for `HeyGenVideoService` and `TavusVideoService`.
|
||||
|
||||
- Added `endpoint_id` parameter to `AzureSTTService`. ([Custom EndpointId](https://docs.azure.cn/en-us/ai-services/speech-service/how-to-recognize-speech?pivots=programming-language-python#use-a-custom-endpoint))
|
||||
|
||||
### Changed
|
||||
|
||||
- `WatchdogPriorityQueue` now requires the items to be inserted to always be
|
||||
tuples and the size of the tuple needs to be specified in the constructor when
|
||||
creating the queue with the `tuple_size` argument.
|
||||
|
||||
- Updated Moondream to revision `2025-01-09`.
|
||||
|
||||
- Updated `PlayHTHttpTTSService` to no longer use the `pyht` client to remove
|
||||
compatibility issues with other packages. Now you can use the PlayHT HTTP
|
||||
service with other services, like GoogleLLMService.
|
||||
|
||||
- Updated `pyproject.toml` to once again pin `numba` to `>=0.61.2` in order to
|
||||
resolve package versioning issues.
|
||||
|
||||
- Updated the `STTMuteFilter` to include `VADUserStartedSpeakingFrame` and
|
||||
`VADUserStoppedSpeakingFrame` in the list of frames to filter when the
|
||||
filtering is on.
|
||||
|
||||
### Performance
|
||||
|
||||
- Improving the latency of the `HeyGenVideoService`.
|
||||
|
||||
- Improved some frame processors performance by using the new frame processor
|
||||
direct mode. In direct mode a frame processor will process frames right away
|
||||
avoiding the need for internal queues and tasks. This is useful for some
|
||||
simple processors. For example, in processors that wrap other processors
|
||||
(e.g. `Pipeline`, `ParallelPipeline`), we add one processor before and one
|
||||
after the wrapped processors (internally, you will see them as sources and
|
||||
sinks). These sources and sinks don't do any special processing and they
|
||||
basically forward frames. So, for these simple processors we now enable the
|
||||
new direct mode which avoids creating any internal tasks (and queues) and
|
||||
therefore improves performance.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue with the `BaseWhisperSTTService` where the language was
|
||||
specified as an enum and not a string.
|
||||
|
||||
- Fixed an issue where `SmallWebRTCTransport` ended before TTS finished.
|
||||
|
||||
- Fixed an issue in `OpenAIRealtimeBetaLLMService` where specifying a `text`
|
||||
`modalities` didn't result in text being outputted from the model.
|
||||
|
||||
- Added SSML reserved character escaping to `AzureBaseTTSService` to properly
|
||||
handle special characters in text sent to Azure TTS. This fixes an issue
|
||||
where characters like `&`, `<`, `>`, `"`, and `'` in LLM-generated text would
|
||||
cause TTS failures.
|
||||
|
||||
- Fixed a `WatchdogPriorityQueue` issue that could cause an exception when
|
||||
compating watchdog cancel sentinel items with other items in the queue.
|
||||
|
||||
- Fixed an issue that would cause system frames to not be processed with higher
|
||||
priority than other frames. This could cause slower interruption times.
|
||||
|
||||
- Fixed an issue where retrying a websocket connection error would result in an
|
||||
error.
|
||||
|
||||
### Other
|
||||
|
||||
- Add foundation example `19b-openai-realtime-beta-text.py`, showing how to use
|
||||
`OpenAIRealtimeBetaLLMService` to output text to a TTS service.
|
||||
|
||||
- Add vision support to release evals so we can run the foundational examples 12
|
||||
series.
|
||||
|
||||
- Added foundational example `15a-switch-languages.py` to release evals. It is
|
||||
able to detect if we switched the language properly.
|
||||
|
||||
- Updated foundational examples to show how to enclose complex logic
|
||||
(e.g. `ParallelPipeline`) into a single processor so the main pipeline becomes
|
||||
simpler.
|
||||
|
||||
- Added `07n-interruptible-gemini.py`, demonstrating how to use
|
||||
`GeminiTTSService`.
|
||||
|
||||
## [0.0.79] - 2025-08-07
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -31,6 +31,23 @@ git push origin your-branch-name
|
||||
|
||||
Our maintainers will review your PR, and once everything is good, your contributions will be merged!
|
||||
|
||||
## Dependency Management
|
||||
|
||||
This project uses [uv](https://docs.astral.sh/uv/) for dependency management. The `uv.lock` file is committed to ensure reproducible builds.
|
||||
|
||||
### Adding or Updating Dependencies
|
||||
|
||||
1. Edit `pyproject.toml` to add/update dependencies
|
||||
2. Run `uv lock` to update the lockfile with new dependency resolution
|
||||
3. Run `uv sync` to install the updated dependencies locally
|
||||
4. Always commit both files together:
|
||||
```bash
|
||||
git add pyproject.toml uv.lock
|
||||
git commit -m "feat: add new dependency for feature X"
|
||||
```
|
||||
|
||||
**Important:** Never manually edit `uv.lock`. It's auto-generated by `uv lock`.
|
||||
|
||||
## Code Style and Documentation
|
||||
|
||||
### Python Code Style
|
||||
|
||||
26
README.md
26
README.md
@@ -114,7 +114,8 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
|
||||
### Prerequisites
|
||||
|
||||
**Python Version:** 3.10+
|
||||
**Minimum Python Version:** 3.10
|
||||
**Recommended Python Version:** 3.11-3.12
|
||||
|
||||
### Setup Steps
|
||||
|
||||
@@ -128,7 +129,7 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
2. Install development and testing dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --group dev --all-extras --no-extra krisp
|
||||
uv sync --group dev --all-extras --no-extra gstreamer --no-extra krisp --no-extra local
|
||||
```
|
||||
|
||||
3. Install the git pre-commit hooks:
|
||||
@@ -137,18 +138,25 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
### Python 3.13+ Note
|
||||
### Python 3.13+ Compatibility
|
||||
|
||||
Some features require PyTorch (not yet available on Python 3.13+):
|
||||
|
||||
- `ultravox`, `local-smart-turn`, `moondream`, `mlx-whisper`
|
||||
|
||||
**For full compatibility:** Use Python 3.12
|
||||
Some features require PyTorch, which doesn't yet support Python 3.13+. Install using:
|
||||
|
||||
```bash
|
||||
uv python pin 3.12 && uv sync --group dev --all-extras --no-extra krisp
|
||||
uv sync --group dev --all-extras \
|
||||
--no-extra gstreamer \
|
||||
--no-extra krisp \
|
||||
--no-extra local \
|
||||
--no-extra local-smart-turn \
|
||||
--no-extra mlx-whisper \
|
||||
--no-extra moondream \
|
||||
--no-extra ultravox
|
||||
```
|
||||
|
||||
> **Tip:** For full compatibility, use Python 3.12: `uv python pin 3.12`
|
||||
|
||||
> **Note**: Some extras (local, gstreamer) require system dependencies. See documentation if you encounter build errors.
|
||||
|
||||
### Running tests
|
||||
|
||||
To run all tests, from the root directory:
|
||||
|
||||
163
examples/foundational/07n-interruptible-gemini.py
Normal file
163
examples/foundational/07n-interruptible-gemini.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""
|
||||
A conversational AI bot using Gemini for both LLM and TTS.
|
||||
|
||||
This example demonstrates how to use Gemini's TTS capabilities with the new
|
||||
GeminiTTSService, which uses Gemini's TTS-specific models instead of Google Cloud TTS.
|
||||
|
||||
Features showcased:
|
||||
- Gemini LLM for conversation
|
||||
- Gemini TTS with natural voice control
|
||||
- Support for different voice personalities
|
||||
- Style and tone control through natural language prompts
|
||||
|
||||
Run with:
|
||||
python examples/foundational/gemini-tts.py
|
||||
|
||||
Make sure to set your environment variables:
|
||||
export GOOGLE_API_KEY=your_api_key_here
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.google.stt import GoogleSTTService
|
||||
from pipecat.services.google.tts import GeminiTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot with Gemini TTS")
|
||||
|
||||
stt = GoogleSTTService(
|
||||
params=GoogleSTTService.InputParams(languages=Language.EN_US),
|
||||
credentials=os.getenv("GOOGLE_TEST_CREDENTIALS"),
|
||||
)
|
||||
|
||||
tts = GeminiTTSService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash-preview-tts", # TTS-specific model
|
||||
voice_id="Charon",
|
||||
params=GeminiTTSService.InputParams(language=Language.EN_US),
|
||||
)
|
||||
|
||||
llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
model="gemini-2.5-flash",
|
||||
)
|
||||
|
||||
# System message that instructs the AI on how to speak
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are a helpful AI assistant in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way.
|
||||
|
||||
IMPORTANT: Since you're using Gemini TTS which supports natural voice control, you can include speaking instructions in your responses. For example:
|
||||
- "Say cheerfully: Welcome to our conversation!"
|
||||
- "Read this in a calm, professional tone: Here are the details you requested."
|
||||
- "Speak in an excited whisper: I have some great news to share!"
|
||||
- "Say slowly and clearly: Let me explain this step by step."
|
||||
|
||||
Feel free to use natural language instructions to control your voice style, tone, pace, and emotion. The TTS system will interpret these instructions and adjust the speech accordingly.
|
||||
|
||||
Your output will be converted to audio, so avoid special characters in your answers. Respond to what the user said in a creative and helpful way.""",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # Gemini TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation with a styled introduction
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Say cheerfully and warmly: Hello! I'm your AI assistant powered by Gemini's new TTS technology. I can speak with different voices, tones, and styles. How can I help you today?",
|
||||
}
|
||||
)
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -12,6 +12,7 @@ from loguru import logger
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -31,29 +32,54 @@ from pipecat.transports.services.daily import DailyParams
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
current_voice = "News Lady"
|
||||
class SwitchVoices(ParallelPipeline):
|
||||
def __init__(self):
|
||||
self._current_voice = "News Lady"
|
||||
|
||||
news_lady = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="bf991597-6c13-47e4-8411-91ec2de5c466", # Newslady
|
||||
)
|
||||
|
||||
async def switch_voice(params: FunctionCallParams):
|
||||
global current_voice
|
||||
current_voice = params.arguments["voice"]
|
||||
await params.result_callback(
|
||||
{
|
||||
"voice": f"You are now using your {current_voice} voice. Your responses should now be as if you were a {current_voice}."
|
||||
}
|
||||
)
|
||||
british_lady = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
barbershop_man = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="a0e99841-438c-4a64-b679-ae501e7d6091", # Barbershop Man
|
||||
)
|
||||
|
||||
async def news_lady_filter(frame) -> bool:
|
||||
return current_voice == "News Lady"
|
||||
super().__init__(
|
||||
# News Lady voice
|
||||
[FunctionFilter(self.news_lady_filter), news_lady],
|
||||
# British Reading Lady voice
|
||||
[FunctionFilter(self.british_lady_filter), british_lady],
|
||||
# Barbershop Man voice
|
||||
[FunctionFilter(self.barbershop_man_filter), barbershop_man],
|
||||
)
|
||||
|
||||
@property
|
||||
def current_voice(self):
|
||||
return self._current_voice
|
||||
|
||||
async def british_lady_filter(frame) -> bool:
|
||||
return current_voice == "British Lady"
|
||||
async def switch_voice(self, params: FunctionCallParams):
|
||||
self._current_voice = params.arguments["voice"]
|
||||
await params.result_callback(
|
||||
{
|
||||
"voice": f"You are now using your {self.current_voice} voice. Your responses should now be as if you were a {self.current_voice}."
|
||||
}
|
||||
)
|
||||
|
||||
async def news_lady_filter(self, _: Frame) -> bool:
|
||||
return self.current_voice == "News Lady"
|
||||
|
||||
async def barbershop_man_filter(frame) -> bool:
|
||||
return current_voice == "Barbershop Man"
|
||||
async def british_lady_filter(self, _: Frame) -> bool:
|
||||
return self.current_voice == "British Lady"
|
||||
|
||||
async def barbershop_man_filter(self, _: Frame) -> bool:
|
||||
return self.current_voice == "Barbershop Man"
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
@@ -83,23 +109,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
news_lady = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="bf991597-6c13-47e4-8411-91ec2de5c466", # Newslady
|
||||
)
|
||||
|
||||
british_lady = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
barbershop_man = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="a0e99841-438c-4a64-b679-ae501e7d6091", # Barbershop Man
|
||||
)
|
||||
tts = SwitchVoices()
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm.register_function("switch_voice", switch_voice)
|
||||
llm.register_function("switch_voice", tts.switch_voice)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
@@ -136,14 +149,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
ParallelPipeline( # TTS (one of the following vocies)
|
||||
[FunctionFilter(news_lady_filter), news_lady], # News Lady voice
|
||||
[
|
||||
FunctionFilter(british_lady_filter),
|
||||
british_lady,
|
||||
], # British Reading Lady voice
|
||||
[FunctionFilter(barbershop_man_filter), barbershop_man], # Barbershop Man voice
|
||||
),
|
||||
tts, # TTS with switch voice functionality
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
@@ -165,7 +171,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Please introduce yourself to the user and let them know the voices you can do. Your initial responses should be as if you were a {current_voice}.",
|
||||
"content": f"Please introduce yourself to the user and let them know the voices you can do. Your initial responses should be as if you were a {tts.current_voice}.",
|
||||
}
|
||||
)
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@@ -13,6 +13,7 @@ from loguru import logger
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -32,23 +33,42 @@ from pipecat.transports.services.daily import DailyParams
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
current_language = "English"
|
||||
class SwitchLanguage(ParallelPipeline):
|
||||
def __init__(self):
|
||||
self._current_language = "English"
|
||||
|
||||
english_tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
async def switch_language(params: FunctionCallParams):
|
||||
global current_language
|
||||
current_language = params.arguments["language"]
|
||||
await params.result_callback(
|
||||
{"voice": f"Your answers from now on should be in {current_language}."}
|
||||
)
|
||||
spanish_tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="d4db5fb9-f44b-4bd1-85fa-192e0f0d75f9", # Spanish-speaking Lady
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
# English
|
||||
[FunctionFilter(self.english_filter), english_tts],
|
||||
# Spanish
|
||||
[FunctionFilter(self.spanish_filter), spanish_tts],
|
||||
)
|
||||
|
||||
async def english_filter(frame) -> bool:
|
||||
return current_language == "English"
|
||||
@property
|
||||
def current_language(self):
|
||||
return self._current_language
|
||||
|
||||
async def switch_language(self, params: FunctionCallParams):
|
||||
self._current_language = params.arguments["language"]
|
||||
await params.result_callback(
|
||||
{"voice": f"Your answers from now on should be in {self.current_language}."}
|
||||
)
|
||||
|
||||
async def spanish_filter(frame) -> bool:
|
||||
return current_language == "Spanish"
|
||||
async def english_filter(self, _: Frame) -> bool:
|
||||
return self.current_language == "English"
|
||||
|
||||
async def spanish_filter(self, _: Frame) -> bool:
|
||||
return self.current_language == "Spanish"
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
@@ -80,18 +100,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY"), live_options=LiveOptions(language="multi")
|
||||
)
|
||||
|
||||
english_tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
spanish_tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="d4db5fb9-f44b-4bd1-85fa-192e0f0d75f9", # Spanish-speaking Lady
|
||||
)
|
||||
tts = SwitchLanguage()
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm.register_function("switch_language", switch_language)
|
||||
llm.register_function("switch_language", tts.switch_language)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
@@ -128,10 +140,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
ParallelPipeline( # TTS (bot will speak the chosen language)
|
||||
[FunctionFilter(english_filter), english_tts], # English
|
||||
[FunctionFilter(spanish_filter), spanish_tts], # Spanish
|
||||
),
|
||||
tts, # TTS (bot will speak the chosen language)
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
@@ -153,7 +162,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Please introduce yourself to the user and let them know the languages you speak. Your initial responses should be in {current_language}.",
|
||||
"content": f"Please introduce yourself to the user and let them know the languages you speak. Your initial responses should be in {tts.current_language}.",
|
||||
}
|
||||
)
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@@ -158,16 +158,6 @@ Remember, your responses should be short. Just one or two sentences, usually."""
|
||||
# openai WebSocket API can understand.
|
||||
context = OpenAILLMContext(
|
||||
[{"role": "user", "content": "Say hello!"}],
|
||||
# [{"role": "user", "content": [{"type": "text", "text": "Say hello!"}]}],
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": "Say"},
|
||||
# {"type": "text", "text": "yo what's up!"},
|
||||
# ],
|
||||
# }
|
||||
# ],
|
||||
tools,
|
||||
)
|
||||
|
||||
|
||||
229
examples/foundational/19b-openai-realtime-beta-text.py
Normal file
229
examples/foundational/19b-openai-realtime-beta-text.py
Normal file
@@ -0,0 +1,229 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import TranscriptionMessage
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeBetaLLMService,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
session_properties = SessionProperties(
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
modalities=["text"],
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=SemanticTurnDetection(),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
input_audio_noise_reduction=InputAudioNoiseReduction(type="near_field"),
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
You have access to the following tools:
|
||||
- get_current_weather: Get the current weather for a given location.
|
||||
- get_restaurant_recommendation: Get a restaurant recommendation for a given location.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually.""",
|
||||
)
|
||||
|
||||
llm = OpenAIRealtimeBetaLLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
start_audio_paused=False,
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
transcript = TranscriptProcessor()
|
||||
|
||||
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||
# OpenAIRealtimeBetaLLMService will convert this internally to messages that the
|
||||
# openai WebSocket API can understand.
|
||||
context = OpenAILLMContext(
|
||||
[{"role": "user", "content": "Say hello!"}],
|
||||
tools,
|
||||
)
|
||||
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transcript.user(), # Placed after the LLM, as LLM pushes TranscriptionFrames downstream
|
||||
transport.output(), # Transport bot output
|
||||
transcript.assistant(), # After the transcript output, to time with the audio output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
# Register event handler for transcript updates
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
for msg in frame.messages:
|
||||
if isinstance(msg, TranscriptionMessage):
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
line = f"{timestamp}{msg.role}: {msg.content}"
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -25,7 +25,8 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair, OpenAILLMService
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
@@ -34,6 +35,76 @@ from pipecat.transports.services.daily import DailyParams
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class TurnDetectionLLM(Pipeline):
|
||||
def __init__(self, llm: LLMService, context_aggregator: OpenAIContextAggregatorPair):
|
||||
# This is the LLM that will be used to detect if the user has finished a
|
||||
# statement. This doesn't really need to be an LLM, we could use NLP
|
||||
# libraries for that, but it was easier as an example because we
|
||||
# leverage the context aggregators.
|
||||
statement_llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
statement_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Determine if the user's statement is a complete sentence or question, ending in a natural pause or punctuation. Return 'YES' if it is complete and 'NO' if it seems to leave a thought unfinished.",
|
||||
},
|
||||
]
|
||||
|
||||
statement_context = OpenAILLMContext(statement_messages)
|
||||
statement_context_aggregator = statement_llm.create_context_aggregator(statement_context)
|
||||
|
||||
# We have instructed the LLM to return 'YES' if it thinks the user
|
||||
# completed a sentence. So, if it's 'YES' we will return true in this
|
||||
# predicate which will wake up the notifier.
|
||||
async def wake_check_filter(frame):
|
||||
logger.debug(f"Completeness check frame: {frame}")
|
||||
return frame.text == "YES"
|
||||
|
||||
# This is a notifier that we use to synchronize the two LLMs.
|
||||
notifier = EventNotifier()
|
||||
|
||||
# This a filter that will wake up the notifier if the given predicate
|
||||
# (wake_check_filter) returns true.
|
||||
completness_check = WakeNotifierFilter(
|
||||
notifier, types=(TextFrame,), filter=wake_check_filter
|
||||
)
|
||||
|
||||
# This processor keeps the last context and will let it through once the
|
||||
# notifier is woken up. We start with the gate open because we send an
|
||||
# initial context frame to start the conversation.
|
||||
gated_context_aggregator = GatedOpenAILLMContextAggregator(
|
||||
notifier=notifier, start_open=True
|
||||
)
|
||||
|
||||
# Notify if the user hasn't said anything.
|
||||
async def user_idle_notifier(frame):
|
||||
await notifier.notify()
|
||||
|
||||
# Sometimes the LLM will fail detecting if a user has completed a
|
||||
# sentence, this will wake up the notifier if that happens.
|
||||
user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=3.0)
|
||||
|
||||
# The ParallePipeline input are the user transcripts. We have two
|
||||
# contexts. The first one will be used to determine if the user finished
|
||||
# a statement and if so the notifier will be woken up. The second
|
||||
# context is simply the regular context but it's gated waiting for the
|
||||
# notifier to be woken up.
|
||||
super().__init__(
|
||||
[
|
||||
ParallelPipeline(
|
||||
[
|
||||
statement_context_aggregator.user(),
|
||||
statement_llm,
|
||||
completness_check,
|
||||
NullFilter(),
|
||||
],
|
||||
[context_aggregator.user(), gated_context_aggregator, llm],
|
||||
),
|
||||
user_idle,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -66,24 +137,8 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# This is the LLM that will be used to detect if the user has finished a
|
||||
# statement. This doesn't really need to be an LLM, we could use NLP
|
||||
# libraries for that, but it was easier as an example because we
|
||||
# leverage the context aggregators.
|
||||
statement_llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
statement_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Determine if the user's statement is a complete sentence or question, ending in a natural pause or punctuation. Return 'YES' if it is complete and 'NO' if it seems to leave a thought unfinished.",
|
||||
},
|
||||
]
|
||||
|
||||
statement_context = OpenAILLMContext(statement_messages)
|
||||
statement_context_aggregator = statement_llm.create_context_aggregator(statement_context)
|
||||
|
||||
# This is the regular LLM.
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm_main = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -93,53 +148,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
context_aggregator = llm_main.create_context_aggregator(context)
|
||||
|
||||
# We have instructed the LLM to return 'YES' if it thinks the user
|
||||
# completed a sentence. So, if it's 'YES' we will return true in this
|
||||
# predicate which will wake up the notifier.
|
||||
async def wake_check_filter(frame):
|
||||
return frame.text == "YES"
|
||||
# LLM + turn detection (with an extra LLM as a judge)
|
||||
llm = TurnDetectionLLM(llm_main, context_aggregator)
|
||||
|
||||
# This is a notifier that we use to synchronize the two LLMs.
|
||||
notifier = EventNotifier()
|
||||
|
||||
# This a filter that will wake up the notifier if the given predicate
|
||||
# (wake_check_filter) returns true.
|
||||
completness_check = WakeNotifierFilter(notifier, types=(TextFrame,), filter=wake_check_filter)
|
||||
|
||||
# This processor keeps the last context and will let it through once the
|
||||
# notifier is woken up. We start with the gate open because we send an
|
||||
# initial context frame to start the conversation.
|
||||
gated_context_aggregator = GatedOpenAILLMContextAggregator(notifier=notifier, start_open=True)
|
||||
|
||||
# Notify if the user hasn't said anything.
|
||||
async def user_idle_notifier(frame):
|
||||
await notifier.notify()
|
||||
|
||||
# Sometimes the LLM will fail detecting if a user has completed a
|
||||
# sentence, this will wake up the notifier if that happens.
|
||||
user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=3.0)
|
||||
|
||||
# The ParallePipeline input are the user transcripts. We have two
|
||||
# contexts. The first one will be used to determine if the user finished
|
||||
# a statement and if so the notifier will be woken up. The second
|
||||
# context is simply the regular context but it's gated waiting for the
|
||||
# notifier to be woken up.
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
ParallelPipeline(
|
||||
[
|
||||
statement_context_aggregator.user(),
|
||||
statement_llm,
|
||||
completness_check,
|
||||
NullFilter(),
|
||||
],
|
||||
[context_aggregator.user(), gated_context_aggregator, llm],
|
||||
),
|
||||
user_idle,
|
||||
stt, # STT
|
||||
llm, # LLM with turn detection
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
@@ -44,13 +43,14 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -192,6 +192,75 @@ async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
class TurnDetectionLLM(Pipeline):
|
||||
def __init__(self, llm: LLMService):
|
||||
# This is the LLM that will be used to detect if the user has finished a
|
||||
# statement. This doesn't really need to be an LLM, we could use NLP
|
||||
# libraries for that, but we have the machinery to use an LLM, so we
|
||||
# might as well!
|
||||
statement_llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# We have instructed the LLM to return 'YES' if it thinks the user
|
||||
# completed a sentence. So, if it's 'YES' we will return true in this
|
||||
# predicate which will wake up the notifier.
|
||||
async def wake_check_filter(frame):
|
||||
logger.debug(f"Completeness check frame: {frame}")
|
||||
return frame.text == "YES"
|
||||
|
||||
# This is a notifier that we use to synchronize the two LLMs.
|
||||
notifier = EventNotifier()
|
||||
|
||||
# This turns the LLM context into an inference request to classify the user's speech
|
||||
# as complete or incomplete.
|
||||
statement_judge_context_filter = StatementJudgeContextFilter()
|
||||
|
||||
# This sends a UserStoppedSpeakingFrame and triggers the notifier event
|
||||
completeness_check = CompletenessCheck(notifier=notifier)
|
||||
|
||||
# # Notify if the user hasn't said anything.
|
||||
async def user_idle_notifier(frame):
|
||||
await notifier.notify()
|
||||
|
||||
# Sometimes the LLM will fail detecting if a user has completed a
|
||||
# sentence, this will wake up the notifier if that happens.
|
||||
user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=5.0)
|
||||
|
||||
# We start with the gate open because we send an initial context frame
|
||||
# to start the conversation.
|
||||
bot_output_gate = OutputGate(notifier=notifier, start_open=True)
|
||||
|
||||
async def pass_only_llm_trigger_frames(frame):
|
||||
return (
|
||||
isinstance(frame, OpenAILLMContextFrame)
|
||||
or isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, StopInterruptionFrame)
|
||||
or isinstance(frame, FunctionCallInProgressFrame)
|
||||
or isinstance(frame, FunctionCallResultFrame)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
ParallelPipeline(
|
||||
[
|
||||
# Ignore everything except an OpenAILLMContextFrame. Pass a specially constructed
|
||||
# simplified context frame to the statement classifier LLM. The only frame this
|
||||
# sub-pipeline will output is a UserStoppedSpeakingFrame.
|
||||
statement_judge_context_filter,
|
||||
statement_llm,
|
||||
completeness_check,
|
||||
],
|
||||
[
|
||||
# Block everything except frames that trigger LLM inference.
|
||||
FunctionFilter(filter=pass_only_llm_trigger_frames),
|
||||
llm,
|
||||
bot_output_gate, # Buffer all llm/tts output until notified.
|
||||
],
|
||||
),
|
||||
user_idle,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -224,18 +293,13 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# This is the LLM that will be used to detect if the user has finished a
|
||||
# statement. This doesn't really need to be an LLM, we could use NLP
|
||||
# libraries for that, but we have the machinery to use an LLM, so we might as well!
|
||||
statement_llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# This is the regular LLM.
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm_main = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm_main.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
@llm.event_handler("on_function_calls_started")
|
||||
@llm_main.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls):
|
||||
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
|
||||
|
||||
@@ -272,69 +336,18 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
context_aggregator = llm_main.create_context_aggregator(context)
|
||||
|
||||
# We have instructed the LLM to return 'YES' if it thinks the user
|
||||
# completed a sentence. So, if it's 'YES' we will return true in this
|
||||
# predicate which will wake up the notifier.
|
||||
async def wake_check_filter(frame):
|
||||
logger.debug(f"Completeness check frame: {frame}")
|
||||
return frame.text == "YES"
|
||||
|
||||
# This is a notifier that we use to synchronize the two LLMs.
|
||||
notifier = EventNotifier()
|
||||
|
||||
# This turns the LLM context into an inference request to classify the user's speech
|
||||
# as complete or incomplete.
|
||||
statement_judge_context_filter = StatementJudgeContextFilter()
|
||||
|
||||
# This sends a UserStoppedSpeakingFrame and triggers the notifier event
|
||||
completeness_check = CompletenessCheck(notifier=notifier)
|
||||
|
||||
# # Notify if the user hasn't said anything.
|
||||
async def user_idle_notifier(frame):
|
||||
await notifier.notify()
|
||||
|
||||
# Sometimes the LLM will fail detecting if a user has completed a
|
||||
# sentence, this will wake up the notifier if that happens.
|
||||
user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=5.0)
|
||||
|
||||
# We start with the gate open because we send an initial context frame
|
||||
# to start the conversation.
|
||||
bot_output_gate = OutputGate(notifier=notifier, start_open=True)
|
||||
|
||||
async def pass_only_llm_trigger_frames(frame):
|
||||
return (
|
||||
isinstance(frame, OpenAILLMContextFrame)
|
||||
or isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, StopInterruptionFrame)
|
||||
or isinstance(frame, FunctionCallInProgressFrame)
|
||||
or isinstance(frame, FunctionCallResultFrame)
|
||||
)
|
||||
# LLM + turn detection (with an extra LLM as a judge)
|
||||
llm = TurnDetectionLLM(llm_main)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
ParallelPipeline(
|
||||
[
|
||||
# Ignore everything except an OpenAILLMContextFrame. Pass a specially constructed
|
||||
# simplified context frame to the statement classifier LLM. The only frame this
|
||||
# sub-pipeline will output is a UserStoppedSpeakingFrame.
|
||||
statement_judge_context_filter,
|
||||
statement_llm,
|
||||
completeness_check,
|
||||
],
|
||||
[
|
||||
# Block everything except frames that trigger LLM inference.
|
||||
FunctionFilter(filter=pass_only_llm_trigger_frames),
|
||||
llm,
|
||||
bot_output_gate, # Buffer all llm/tts output until notified.
|
||||
],
|
||||
),
|
||||
llm,
|
||||
tts,
|
||||
user_idle,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
@@ -365,7 +378,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
await task.queue_frames(
|
||||
[
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(user_id="", timestamp=time.time(), text=message["message"]),
|
||||
TranscriptionFrame(
|
||||
user_id="", timestamp=time_now_iso8601(), text=message["message"]
|
||||
),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
@@ -45,13 +44,14 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -391,6 +391,75 @@ class OutputGate(FrameProcessor):
|
||||
break
|
||||
|
||||
|
||||
class TurnDetectionLLM(Pipeline):
|
||||
def __init__(self, llm: LLMService):
|
||||
# This is the LLM that will be used to detect if the user has finished a
|
||||
# statement. This doesn't really need to be an LLM, we could use NLP
|
||||
# libraries for that, but we have the machinery to use an LLM, so we might as well!
|
||||
statement_llm = AnthropicLLMService(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
||||
|
||||
# This is a notifier that we use to synchronize the two LLMs.
|
||||
notifier = EventNotifier()
|
||||
|
||||
# This turns the LLM context into an inference request to classify the user's speech
|
||||
# as complete or incomplete.
|
||||
statement_judge_context_filter = StatementJudgeContextFilter()
|
||||
|
||||
# This sends a UserStoppedSpeakingFrame and triggers the notifier event
|
||||
completeness_check = CompletenessCheck(notifier=notifier)
|
||||
|
||||
# # Notify if the user hasn't said anything.
|
||||
async def user_idle_notifier(frame):
|
||||
await notifier.notify()
|
||||
|
||||
# Sometimes the LLM will fail detecting if a user has completed a
|
||||
# sentence, this will wake up the notifier if that happens.
|
||||
user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=5.0)
|
||||
|
||||
# We start with the gate open because we send an initial context frame
|
||||
# to start the conversation.
|
||||
bot_output_gate = OutputGate(notifier=notifier, start_open=True)
|
||||
|
||||
async def block_user_stopped_speaking(frame):
|
||||
return not isinstance(frame, UserStoppedSpeakingFrame)
|
||||
|
||||
async def pass_only_llm_trigger_frames(frame):
|
||||
return (
|
||||
isinstance(frame, OpenAILLMContextFrame)
|
||||
or isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, StopInterruptionFrame)
|
||||
or isinstance(frame, FunctionCallInProgressFrame)
|
||||
or isinstance(frame, FunctionCallResultFrame)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
ParallelPipeline(
|
||||
[
|
||||
# Pass everything except UserStoppedSpeaking to the elements after
|
||||
# this ParallelPipeline
|
||||
FunctionFilter(filter=block_user_stopped_speaking),
|
||||
],
|
||||
[
|
||||
# Ignore everything except an OpenAILLMContextFrame. Pass a specially constructed
|
||||
# simplified context frame to the statement classifier LLM. The only frame this
|
||||
# sub-pipeline will output is a UserStoppedSpeakingFrame.
|
||||
statement_judge_context_filter,
|
||||
statement_llm,
|
||||
completeness_check,
|
||||
],
|
||||
[
|
||||
# Block everything except frames that trigger LLM inference.
|
||||
FunctionFilter(filter=pass_only_llm_trigger_frames),
|
||||
llm,
|
||||
bot_output_gate, # Buffer all llm/tts output until notified.
|
||||
],
|
||||
),
|
||||
user_idle,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
@@ -427,18 +496,13 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# This is the LLM that will be used to detect if the user has finished a
|
||||
# statement. This doesn't really need to be an LLM, we could use NLP
|
||||
# libraries for that, but we have the machinery to use an LLM, so we might as well!
|
||||
statement_llm = AnthropicLLMService(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
||||
|
||||
# This is the regular LLM.
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm_main = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
# Register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm_main.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
@llm.event_handler("on_function_calls_started")
|
||||
@llm_main.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls):
|
||||
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
|
||||
|
||||
@@ -475,76 +539,18 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
context_aggregator = llm_main.create_context_aggregator(context)
|
||||
|
||||
# We have instructed the LLM to return 'YES' if it thinks the user
|
||||
# completed a sentence. So, if it's 'YES' we will return true in this
|
||||
# predicate which will wake up the notifier.
|
||||
async def wake_check_filter(frame):
|
||||
return frame.text == "YES"
|
||||
|
||||
# This is a notifier that we use to synchronize the two LLMs.
|
||||
notifier = EventNotifier()
|
||||
|
||||
# This turns the LLM context into an inference request to classify the user's speech
|
||||
# as complete or incomplete.
|
||||
statement_judge_context_filter = StatementJudgeContextFilter()
|
||||
|
||||
# This sends a UserStoppedSpeakingFrame and triggers the notifier event
|
||||
completeness_check = CompletenessCheck(notifier=notifier)
|
||||
|
||||
# # Notify if the user hasn't said anything.
|
||||
async def user_idle_notifier(frame):
|
||||
await notifier.notify()
|
||||
|
||||
# Sometimes the LLM will fail detecting if a user has completed a
|
||||
# sentence, this will wake up the notifier if that happens.
|
||||
user_idle = UserIdleProcessor(callback=user_idle_notifier, timeout=5.0)
|
||||
|
||||
# We start with the gate open because we send an initial context frame
|
||||
# to start the conversation.
|
||||
bot_output_gate = OutputGate(notifier=notifier, start_open=True)
|
||||
|
||||
async def block_user_stopped_speaking(frame):
|
||||
return not isinstance(frame, UserStoppedSpeakingFrame)
|
||||
|
||||
async def pass_only_llm_trigger_frames(frame):
|
||||
return (
|
||||
isinstance(frame, OpenAILLMContextFrame)
|
||||
or isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, StopInterruptionFrame)
|
||||
or isinstance(frame, FunctionCallInProgressFrame)
|
||||
or isinstance(frame, FunctionCallResultFrame)
|
||||
)
|
||||
# LLM + turn detection (with an extra LLM as a judge)
|
||||
llm = TurnDetectionLLM(llm_main)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
ParallelPipeline(
|
||||
[
|
||||
# Pass everything except UserStoppedSpeaking to the elements after
|
||||
# this ParallelPipeline
|
||||
FunctionFilter(filter=block_user_stopped_speaking),
|
||||
],
|
||||
[
|
||||
# Ignore everything except an OpenAILLMContextFrame. Pass a specially constructed
|
||||
# simplified context frame to the statement classifier LLM. The only frame this
|
||||
# sub-pipeline will output is a UserStoppedSpeakingFrame.
|
||||
statement_judge_context_filter,
|
||||
statement_llm,
|
||||
completeness_check,
|
||||
],
|
||||
[
|
||||
# Block everything except frames that trigger LLM inference.
|
||||
FunctionFilter(filter=pass_only_llm_trigger_frames),
|
||||
llm,
|
||||
bot_output_gate, # Buffer all llm/tts output until notified.
|
||||
],
|
||||
),
|
||||
llm,
|
||||
tts,
|
||||
user_idle,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
@@ -580,7 +586,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
await task.queue_frames(
|
||||
[
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(user_id="", timestamp=time.time(), text=message["message"]),
|
||||
TranscriptionFrame(
|
||||
user_id="", timestamp=time_now_iso8601(), text=message["message"]
|
||||
),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -47,11 +47,13 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMContext, GoogleLLMService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -607,23 +609,90 @@ class OutputGate(FrameProcessor):
|
||||
self._gate_task = None
|
||||
|
||||
async def _gate_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
await self._notifier.wait()
|
||||
await self._notifier.wait()
|
||||
|
||||
transcription = await self._transcription_buffer.wait_for_transcription() or "-"
|
||||
self._context.add_message(Content(role="user", parts=[Part(text=transcription)]))
|
||||
transcription = await self._transcription_buffer.wait_for_transcription() or "-"
|
||||
self._context.add_message(Content(role="user", parts=[Part(text=transcription)]))
|
||||
|
||||
self.open_gate()
|
||||
for frame, direction in self._frames_buffer:
|
||||
await self.push_frame(frame, direction)
|
||||
self._frames_buffer = []
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"OutputGate error: {e}")
|
||||
raise e
|
||||
break
|
||||
self.open_gate()
|
||||
for frame, direction in self._frames_buffer:
|
||||
await self.push_frame(frame, direction)
|
||||
self._frames_buffer = []
|
||||
|
||||
|
||||
class TurnDetectionLLM(Pipeline):
|
||||
def __init__(self, llm: LLMService, context: OpenAILLMContext):
|
||||
# This is the LLM that will transcribe user speech.
|
||||
tx_llm = GoogleLLMService(
|
||||
name="Transcriber",
|
||||
model=TRANSCRIBER_MODEL,
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
temperature=0.0,
|
||||
system_instruction=transcriber_system_instruction,
|
||||
)
|
||||
|
||||
# This is the LLM that will classify user speech as complete or incomplete.
|
||||
classifier_llm = GoogleLLMService(
|
||||
name="Classifier",
|
||||
model=CLASSIFIER_MODEL,
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
temperature=0.0,
|
||||
system_instruction=classifier_system_instruction,
|
||||
)
|
||||
|
||||
# This is a notifier that we use to synchronize the two LLMs.
|
||||
notifier = EventNotifier()
|
||||
|
||||
# This turns the LLM context into an inference request to classify the user's speech
|
||||
# as complete or incomplete.
|
||||
# statement_judge_context_filter = StatementJudgeAudioContextAccumulator(notifier=notifier)
|
||||
|
||||
audio_accumulater = AudioAccumulator()
|
||||
# This sends a UserStoppedSpeakingFrame and triggers the notifier event
|
||||
completeness_check = CompletenessCheck(
|
||||
notifier=notifier, audio_accumulator=audio_accumulater
|
||||
)
|
||||
|
||||
async def block_user_stopped_speaking(frame):
|
||||
return not isinstance(frame, UserStoppedSpeakingFrame)
|
||||
|
||||
conversation_audio_context_assembler = ConversationAudioContextAssembler(context=context)
|
||||
|
||||
llm_aggregator_buffer = LLMAggregatorBuffer()
|
||||
|
||||
bot_output_gate = OutputGate(
|
||||
notifier=notifier, context=context, llm_transcription_buffer=llm_aggregator_buffer
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
audio_accumulater,
|
||||
ParallelPipeline(
|
||||
[
|
||||
# Pass everything except UserStoppedSpeaking to the elements after
|
||||
# this ParallelPipeline
|
||||
FunctionFilter(filter=block_user_stopped_speaking),
|
||||
],
|
||||
[
|
||||
ParallelPipeline(
|
||||
[
|
||||
classifier_llm,
|
||||
completeness_check,
|
||||
],
|
||||
[
|
||||
tx_llm,
|
||||
llm_aggregator_buffer,
|
||||
],
|
||||
)
|
||||
],
|
||||
[
|
||||
conversation_audio_context_assembler,
|
||||
llm,
|
||||
bot_output_gate, # buffer output until notified, then flush frames and update context
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
@@ -656,24 +725,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# This is the LLM that will transcribe user speech.
|
||||
tx_llm = GoogleLLMService(
|
||||
name="Transcriber",
|
||||
model=TRANSCRIBER_MODEL,
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
temperature=0.0,
|
||||
system_instruction=transcriber_system_instruction,
|
||||
)
|
||||
|
||||
# This is the LLM that will classify user speech as complete or incomplete.
|
||||
classifier_llm = GoogleLLMService(
|
||||
name="Classifier",
|
||||
model=CLASSIFIER_MODEL,
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
temperature=0.0,
|
||||
system_instruction=classifier_system_instruction,
|
||||
)
|
||||
|
||||
# This is the regular LLM that responds conversationally.
|
||||
conversation_llm = GoogleLLMService(
|
||||
name="Conversation",
|
||||
@@ -685,57 +736,12 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context = OpenAILLMContext()
|
||||
context_aggregator = conversation_llm.create_context_aggregator(context)
|
||||
|
||||
# This is a notifier that we use to synchronize the two LLMs.
|
||||
notifier = EventNotifier()
|
||||
|
||||
# This turns the LLM context into an inference request to classify the user's speech
|
||||
# as complete or incomplete.
|
||||
# statement_judge_context_filter = StatementJudgeAudioContextAccumulator(notifier=notifier)
|
||||
|
||||
audio_accumulater = AudioAccumulator()
|
||||
# This sends a UserStoppedSpeakingFrame and triggers the notifier event
|
||||
completeness_check = CompletenessCheck(notifier=notifier, audio_accumulator=audio_accumulater)
|
||||
|
||||
async def block_user_stopped_speaking(frame):
|
||||
return not isinstance(frame, UserStoppedSpeakingFrame)
|
||||
|
||||
conversation_audio_context_assembler = ConversationAudioContextAssembler(context=context)
|
||||
|
||||
llm_aggregator_buffer = LLMAggregatorBuffer()
|
||||
|
||||
bot_output_gate = OutputGate(
|
||||
notifier=notifier, context=context, llm_transcription_buffer=llm_aggregator_buffer
|
||||
)
|
||||
llm = TurnDetectionLLM(conversation_llm, context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
audio_accumulater,
|
||||
ParallelPipeline(
|
||||
[
|
||||
# Pass everything except UserStoppedSpeaking to the elements after
|
||||
# this ParallelPipeline
|
||||
FunctionFilter(filter=block_user_stopped_speaking),
|
||||
],
|
||||
[
|
||||
ParallelPipeline(
|
||||
[
|
||||
classifier_llm,
|
||||
completeness_check,
|
||||
],
|
||||
[
|
||||
tx_llm,
|
||||
llm_aggregator_buffer,
|
||||
],
|
||||
)
|
||||
],
|
||||
[
|
||||
conversation_audio_context_assembler,
|
||||
conversation_llm,
|
||||
bot_output_gate, # buffer output until notified, then flush frames and update context
|
||||
# TempPrinter(),
|
||||
],
|
||||
),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
@@ -766,7 +772,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
await task.queue_frames(
|
||||
[
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(user_id="", timestamp=time.time(), text=message["message"]),
|
||||
TranscriptionFrame(
|
||||
user_id="", timestamp=time_now_iso8601(), text=message["message"]
|
||||
),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -34,6 +34,8 @@ dependencies = [
|
||||
"resampy~=0.4.3",
|
||||
"soxr~=0.5.0",
|
||||
"openai>=1.74.0,<=1.99.1",
|
||||
# Pinning numba to resolve package dependencies
|
||||
"numba==0.61.2",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -72,7 +74,7 @@ local = [ "pyaudio~=0.2.14" ]
|
||||
mcp = [ "mcp[cli]~=1.9.4" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "einops~=0.8.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
nim = []
|
||||
neuphonic = [ "websockets>=13.1,<15.0" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
@@ -80,7 +82,7 @@ openai = [ "websockets>=13.1,<15.0" ]
|
||||
openpipe = [ "openpipe~=4.50.0" ]
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "pyht>=0.1.6", "websockets>=13.1,<15.0" ]
|
||||
playht = [ "websockets>=13.1,<15.0" ]
|
||||
qwen = []
|
||||
rime = [ "websockets>=13.1,<15.0" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
@@ -122,7 +124,7 @@ dev = [
|
||||
docs = [
|
||||
"sphinx>=8.1.3",
|
||||
"sphinx-rtd-theme",
|
||||
"sphinx-markdown-builder",
|
||||
"sphinx-markdown-builder",
|
||||
"sphinx-autodoc-typehints",
|
||||
"toml",
|
||||
]
|
||||
|
||||
BIN
scripts/evals/assets/cat.jpg
Normal file
BIN
scripts/evals/assets/cat.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 63 KiB |
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
@@ -13,11 +12,12 @@ import time
|
||||
import wave
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import aiofiles
|
||||
from deepgram import LiveOptions
|
||||
from loguru import logger
|
||||
from PIL.ImageFile import ImageFile
|
||||
from utils import (
|
||||
EvalResult,
|
||||
load_module_from_path,
|
||||
@@ -30,7 +30,7 @@ from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import EndTaskFrame
|
||||
from pipecat.frames.frames import EndTaskFrame, OutputImageRawFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -49,6 +49,8 @@ SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
PIPELINE_IDLE_TIMEOUT_SECS = 60
|
||||
EVAL_TIMEOUT_SECS = 90
|
||||
|
||||
EvalPrompt = str | Tuple[str, ImageFile]
|
||||
|
||||
|
||||
class EvalRunner:
|
||||
def __init__(
|
||||
@@ -87,7 +89,7 @@ class EvalRunner:
|
||||
async def assert_eval_false(self):
|
||||
await self._queue.put(False)
|
||||
|
||||
async def run_eval(self, example_file: str, prompt: str, eval: Optional[str] = None):
|
||||
async def run_eval(self, example_file: str, prompt: EvalPrompt, eval: Optional[str] = None):
|
||||
if not re.match(self._pattern, example_file):
|
||||
return
|
||||
|
||||
@@ -178,6 +180,7 @@ async def run_example_pipeline(script_path: Path):
|
||||
DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
@@ -189,7 +192,10 @@ async def run_example_pipeline(script_path: Path):
|
||||
|
||||
|
||||
async def run_eval_pipeline(
|
||||
eval_runner: EvalRunner, example_file: str, prompt: str, eval: Optional[str]
|
||||
eval_runner: EvalRunner,
|
||||
example_file: str,
|
||||
prompt: EvalPrompt,
|
||||
eval: Optional[str],
|
||||
):
|
||||
logger.info(f"Starting eval bot")
|
||||
|
||||
@@ -202,6 +208,7 @@ async def run_eval_pipeline(
|
||||
DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=2.0)),
|
||||
),
|
||||
)
|
||||
@@ -210,7 +217,10 @@ async def run_eval_pipeline(
|
||||
# 5" (in audio) this can be converted to "32 is 5".
|
||||
stt = DeepgramSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY"),
|
||||
live_options=LiveOptions(smart_format=False),
|
||||
live_options=LiveOptions(
|
||||
language="multi",
|
||||
smart_format=False,
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
@@ -239,6 +249,14 @@ async def run_eval_pipeline(
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[eval_function])
|
||||
|
||||
# Load example prompt depending on image.
|
||||
example_prompt = ""
|
||||
example_image: Optional[ImageFile] = None
|
||||
if isinstance(prompt, str):
|
||||
example_prompt = prompt
|
||||
elif isinstance(prompt, tuple):
|
||||
example_prompt, example_image = prompt
|
||||
|
||||
# See if we need to include an eval prompt.
|
||||
eval_prompt = ""
|
||||
if eval:
|
||||
@@ -247,7 +265,7 @@ async def run_eval_pipeline(
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are an LLM eval, be extremly brief. Your goal is to only ask one question: {prompt}. Call the eval function only if the user answers the question and check if the answer is correct (words as numbers are valid). {eval_prompt}",
|
||||
"content": f"You are an LLM eval, be extremly brief. Your goal is to only ask one question: {example_prompt}. Call the eval function only if the user answers the question and check if the answer is correct (words as numbers are valid). {eval_prompt}",
|
||||
},
|
||||
]
|
||||
|
||||
@@ -285,6 +303,14 @@ async def run_eval_pipeline(
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
if example_image:
|
||||
await task.queue_frame(
|
||||
OutputImageRawFrame(
|
||||
image=example_image.tobytes(),
|
||||
size=example_image.size,
|
||||
format="RGB",
|
||||
)
|
||||
)
|
||||
await audio_buffer.start_recording()
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
|
||||
@@ -13,12 +13,15 @@ from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
from eval import EvalRunner
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from utils import check_env_variables
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
|
||||
ASSETS_DIR = SCRIPT_DIR / "assets"
|
||||
|
||||
FOUNDATIONAL_DIR = SCRIPT_DIR.parent.parent / "examples" / "foundational"
|
||||
|
||||
|
||||
@@ -35,6 +38,14 @@ EVAL_WEATHER = (
|
||||
PROMPT_ONLINE_SEARCH = "What's the date right now in London?"
|
||||
EVAL_ONLINE_SEARCH = f"Today is {datetime.now(timezone.utc).strftime('%B %d, %Y')}."
|
||||
|
||||
# Switch language
|
||||
PROMPT_SWITCH_LANGUAGE = "Say something in Spanish."
|
||||
EVAL_SWITCH_LANGUAGE = "Check if the user is now talking in Spanish."
|
||||
|
||||
# Vision
|
||||
PROMPT_VISION = ("What do you see?", Image.open(ASSETS_DIR / "cat.jpg"))
|
||||
EVAL_VISION = "A cat description."
|
||||
|
||||
TESTS_07 = [
|
||||
# 07 series
|
||||
("07-interruptible.py", PROMPT_SIMPLE_MATH, None),
|
||||
@@ -57,6 +68,7 @@ TESTS_07 = [
|
||||
("07k-interruptible-lmnt.py", PROMPT_SIMPLE_MATH, None),
|
||||
("07l-interruptible-groq.py", PROMPT_SIMPLE_MATH, None),
|
||||
("07m-interruptible-aws.py", PROMPT_SIMPLE_MATH, None),
|
||||
("07n-interruptible-gemini.py", PROMPT_SIMPLE_MATH, None),
|
||||
("07n-interruptible-google.py", PROMPT_SIMPLE_MATH, None),
|
||||
("07o-interruptible-assemblyai.py", PROMPT_SIMPLE_MATH, None),
|
||||
("07q-interruptible-rime.py", PROMPT_SIMPLE_MATH, None),
|
||||
@@ -77,6 +89,13 @@ TESTS_07 = [
|
||||
# ("07u-interruptible-ultravox.py", PROMPT_SIMPLE_MATH, None),
|
||||
]
|
||||
|
||||
TESTS_12 = [
|
||||
("12-describe-video.py", PROMPT_VISION, EVAL_VISION),
|
||||
("12a-describe-video-gemini-flash.py", PROMPT_VISION, EVAL_VISION),
|
||||
("12b-describe-video-gpt-4o.py", PROMPT_VISION, EVAL_VISION),
|
||||
("12c-describe-video-anthropic.py", PROMPT_VISION, EVAL_VISION),
|
||||
]
|
||||
|
||||
TESTS_14 = [
|
||||
("14-function-calling.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
("14a-function-calling-anthropic.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
@@ -93,7 +112,7 @@ TESTS_14 = [
|
||||
("14p-function-calling-gemini-vertex-ai.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
("14q-function-calling-qwen.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
("14r-function-calling-aws.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
("14v-function-calling-openai.py.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
("14v-function-calling-openai.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
# Currently not working.
|
||||
# ("14c-function-calling-together.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
# ("14k-function-calling-cerebras.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
@@ -101,9 +120,14 @@ TESTS_14 = [
|
||||
# ("14o-function-calling-gemini-openai-format.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
]
|
||||
|
||||
TESTS_15 = [
|
||||
("15a-switch-languages.py", PROMPT_SWITCH_LANGUAGE, EVAL_SWITCH_LANGUAGE),
|
||||
]
|
||||
|
||||
TESTS_19 = [
|
||||
("19-openai-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
("19a-azure-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
("19b-openai-realtime-beta-text.py", PROMPT_WEATHER, EVAL_WEATHER),
|
||||
]
|
||||
|
||||
TESTS_21 = [
|
||||
@@ -134,7 +158,9 @@ TESTS_43 = [
|
||||
|
||||
TESTS = [
|
||||
*TESTS_07,
|
||||
*TESTS_12,
|
||||
*TESTS_14,
|
||||
*TESTS_15,
|
||||
*TESTS_19,
|
||||
*TESTS_21,
|
||||
*TESTS_26,
|
||||
|
||||
@@ -49,7 +49,7 @@ class ParallelPipelineSource(FrameProcessor):
|
||||
upstream_queue: Queue for collecting upstream frames from this branch.
|
||||
push_frame_func: Function to push frames to the parent parallel pipeline.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(enable_direct_mode=True)
|
||||
self._up_queue = upstream_queue
|
||||
self._push_frame_func = push_frame_func
|
||||
|
||||
@@ -90,7 +90,7 @@ class ParallelPipelineSink(FrameProcessor):
|
||||
downstream_queue: Queue for collecting downstream frames from this branch.
|
||||
push_frame_func: Function to push frames to the parent parallel pipeline.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(enable_direct_mode=True)
|
||||
self._down_queue = downstream_queue
|
||||
self._push_frame_func = push_frame_func
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class PipelineSource(FrameProcessor):
|
||||
Args:
|
||||
upstream_push_frame: Coroutine function to handle upstream frames.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(enable_direct_mode=True)
|
||||
self._upstream_push_frame = upstream_push_frame
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -65,7 +65,7 @@ class PipelineSink(FrameProcessor):
|
||||
Args:
|
||||
downstream_push_frame: Coroutine function to handle downstream frames.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(enable_direct_mode=True)
|
||||
self._downstream_push_frame = downstream_push_frame
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -49,7 +49,7 @@ class SyncParallelPipelineSource(FrameProcessor):
|
||||
Args:
|
||||
upstream_queue: Queue for collecting upstream frames from the pipeline.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(enable_direct_mode=True)
|
||||
self._up_queue = upstream_queue
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -81,7 +81,7 @@ class SyncParallelPipelineSink(FrameProcessor):
|
||||
Args:
|
||||
downstream_queue: Queue for collecting downstream frames from the pipeline.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(enable_direct_mode=True)
|
||||
self._down_queue = downstream_queue
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -110,14 +110,14 @@ class PipelineTaskSource(FrameProcessor):
|
||||
pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, up_queue: asyncio.Queue, **kwargs):
|
||||
def __init__(self, up_queue: asyncio.Queue):
|
||||
"""Initialize the pipeline task source.
|
||||
|
||||
Args:
|
||||
up_queue: Queue for upstream frame processing.
|
||||
**kwargs: Additional arguments passed to the parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(enable_direct_mode=True)
|
||||
self._up_queue = up_queue
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -144,14 +144,14 @@ class PipelineTaskSink(FrameProcessor):
|
||||
act on them, for example, waiting to receive an EndFrame.
|
||||
"""
|
||||
|
||||
def __init__(self, down_queue: asyncio.Queue, **kwargs):
|
||||
def __init__(self, down_queue: asyncio.Queue):
|
||||
"""Initialize the pipeline task sink.
|
||||
|
||||
Args:
|
||||
down_queue: Queue for downstream frame processing.
|
||||
**kwargs: Additional arguments passed to the parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(enable_direct_mode=True)
|
||||
self._down_queue = down_queue
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -32,6 +32,8 @@ from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
@@ -205,6 +207,8 @@ class STTMuteFilter(FrameProcessor):
|
||||
(
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
|
||||
@@ -14,7 +14,7 @@ management, and frame flow control mechanisms.
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Sequence
|
||||
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Sequence, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -38,6 +38,10 @@ from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics
|
||||
from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.asyncio.watchdog_event import WatchdogEvent
|
||||
from pipecat.utils.asyncio.watchdog_priority_queue import (
|
||||
WatchdogPriorityCancelSentinel,
|
||||
WatchdogPriorityQueue,
|
||||
)
|
||||
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
@@ -54,6 +58,9 @@ class FrameDirection(Enum):
|
||||
UPSTREAM = 2
|
||||
|
||||
|
||||
FrameCallback = Callable[["FrameProcessor", Frame, FrameDirection], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameProcessorSetup:
|
||||
"""Configuration parameters for frame processor initialization.
|
||||
@@ -71,23 +78,18 @@ class FrameProcessorSetup:
|
||||
watchdog_timers_enabled: bool = False
|
||||
|
||||
|
||||
class FrameProcessorQueue(WatchdogQueue):
|
||||
class FrameProcessorQueue(WatchdogPriorityQueue):
|
||||
"""A priority queue for systems frames and other frames.
|
||||
|
||||
This is a specialized queue for frame processors that separates and
|
||||
prioritizes system frames over other frames.
|
||||
|
||||
This queue uses two internal `WatchdogQueue` instances:
|
||||
- One for system-level frames (`SystemFrame`)
|
||||
- One for regular frames
|
||||
|
||||
It ensures that `SystemFrame` objects are processed before any other
|
||||
frames. Additionally, it uses an `asyncio.Event` to signal when new items
|
||||
have been added to either queue, allowing consumers to wait efficiently when
|
||||
the queue is empty.
|
||||
prioritizes system frames over other frames. It ensures that `SystemFrame`
|
||||
objects are processed before any other frames by using a priority queue.
|
||||
|
||||
"""
|
||||
|
||||
HIGH_PRIORITY = 1
|
||||
LOW_PRIORITY = 2
|
||||
|
||||
def __init__(self, manager: BaseTaskManager):
|
||||
"""Initialize the FrameProcessorQueue.
|
||||
|
||||
@@ -95,26 +97,28 @@ class FrameProcessorQueue(WatchdogQueue):
|
||||
manager (BaseTaskManager): The task manager used by the internal watchdog queues.
|
||||
|
||||
"""
|
||||
super().__init__(manager)
|
||||
self.__event = WatchdogEvent(manager)
|
||||
self.__main_queue = WatchdogQueue(manager)
|
||||
self.__system_queue = WatchdogQueue(manager)
|
||||
super().__init__(manager, tuple_size=3)
|
||||
self.__high_counter = 0
|
||||
self.__low_counter = 0
|
||||
|
||||
async def put(self, item: Any):
|
||||
"""Put an item into the appropriate queue.
|
||||
async def put(self, item: Tuple[Frame, FrameDirection, FrameCallback]):
|
||||
"""Put an item into the priority queue.
|
||||
|
||||
System frames (`SystemFrame`) are placed into the system queue and all others
|
||||
into the regular queue. Signals the event to wake up any waiting consumers.
|
||||
System frames (`SystemFrame`) have higher priority than any other
|
||||
frames. If a non-frame item (e.g. a watchdog cancellation sentinel) is
|
||||
provided it will have the highest priority.
|
||||
|
||||
Args:
|
||||
item (Any): The item to enqueue.
|
||||
|
||||
"""
|
||||
if isinstance(item, SystemFrame):
|
||||
await self.__system_queue.put(item)
|
||||
frame, _, _ = item
|
||||
if isinstance(frame, SystemFrame):
|
||||
self.__high_counter += 1
|
||||
await super().put((self.HIGH_PRIORITY, self.__high_counter, item))
|
||||
else:
|
||||
await self.__main_queue.put(item)
|
||||
self.__event.set()
|
||||
self.__low_counter += 1
|
||||
await super().put((self.LOW_PRIORITY, self.__low_counter, item))
|
||||
|
||||
async def get(self) -> Any:
|
||||
"""Retrieve the next item from the queue.
|
||||
@@ -126,38 +130,9 @@ class FrameProcessorQueue(WatchdogQueue):
|
||||
Any: The next item from the system or main queue.
|
||||
|
||||
"""
|
||||
# Wait for an item in any of the queues if they are empty.
|
||||
if self.__main_queue.empty() and self.__system_queue.empty():
|
||||
await self.__event.wait()
|
||||
|
||||
# Prioritize system frames.
|
||||
if self.__system_queue.qsize() > 0:
|
||||
item = await self.__system_queue.get()
|
||||
self.__system_queue.task_done()
|
||||
else:
|
||||
item = await self.__main_queue.get()
|
||||
self.__main_queue.task_done()
|
||||
|
||||
# Clear the event only if all queues are empty.
|
||||
if self.__main_queue.empty() and self.__system_queue.empty():
|
||||
self.__event.clear()
|
||||
|
||||
_, _, item = await super().get()
|
||||
return item
|
||||
|
||||
def cancel(self):
|
||||
"""Cancel both internal queues.
|
||||
|
||||
This method is used to stop processing and release any pending tasks
|
||||
in both the system and main queues. Typically used during shutdown
|
||||
or cleanup to prevent further processing of frames.
|
||||
|
||||
"""
|
||||
self.__main_queue.cancel()
|
||||
self.__system_queue.cancel()
|
||||
|
||||
|
||||
FrameCallback = Callable[["FrameProcessor", Frame, FrameDirection], Awaitable[None]]
|
||||
|
||||
|
||||
class FrameProcessor(BaseObject):
|
||||
"""Base class for all frame processors in the pipeline.
|
||||
@@ -175,6 +150,7 @@ class FrameProcessor(BaseObject):
|
||||
self,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
enable_direct_mode: bool = False,
|
||||
enable_watchdog_logging: Optional[bool] = None,
|
||||
enable_watchdog_timers: Optional[bool] = None,
|
||||
metrics: Optional[FrameProcessorMetrics] = None,
|
||||
@@ -185,6 +161,7 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
Args:
|
||||
name: Optional name for this processor instance.
|
||||
enable_direct_mode: Whether to process frames immediately or use internal queues.
|
||||
enable_watchdog_logging: Whether to enable watchdog logging for tasks.
|
||||
enable_watchdog_timers: Whether to enable watchdog timers for tasks.
|
||||
metrics: Optional metrics collector for this processor.
|
||||
@@ -196,6 +173,9 @@ class FrameProcessor(BaseObject):
|
||||
self._prev: Optional["FrameProcessor"] = None
|
||||
self._next: Optional["FrameProcessor"] = None
|
||||
|
||||
# Enable direct mode to skip queues and process frames right away.
|
||||
self._enable_direct_mode = enable_direct_mode
|
||||
|
||||
# Enable watchdog timers for all tasks created by this frame processor.
|
||||
self._enable_watchdog_timers = enable_watchdog_timers
|
||||
|
||||
@@ -254,9 +234,7 @@ class FrameProcessor(BaseObject):
|
||||
# called. To resume processing frames we need to call
|
||||
# `resume_processing_frames()` which will wake up the event.
|
||||
self.__should_block_frames = False
|
||||
self.__process_event = None
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
self.__process_queue = None
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -558,7 +536,10 @@ class FrameProcessor(BaseObject):
|
||||
if self._cancelling:
|
||||
return
|
||||
|
||||
await self.__input_queue.put((frame, direction, callback))
|
||||
if self._enable_direct_mode:
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
else:
|
||||
await self.__input_queue.put((frame, direction, callback))
|
||||
|
||||
async def pause_processing_frames(self):
|
||||
"""Pause processing of queued frames."""
|
||||
@@ -730,6 +711,9 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
def __create_input_task(self):
|
||||
"""Create the frame input processing task."""
|
||||
if self._enable_direct_mode:
|
||||
return
|
||||
|
||||
if not self.__input_frame_task:
|
||||
self.__input_queue = FrameProcessorQueue(self.task_manager)
|
||||
self.__input_frame_task = self.create_task(self.__input_frame_task_handler())
|
||||
@@ -743,11 +727,12 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
def __create_process_task(self):
|
||||
"""Create the non-system frame processing task."""
|
||||
if self._enable_direct_mode:
|
||||
return
|
||||
|
||||
if not self.__process_frame_task:
|
||||
self.__should_block_frames = False
|
||||
if not self.__process_event:
|
||||
self.__process_event = WatchdogEvent(self.task_manager)
|
||||
self.__process_event.clear()
|
||||
self.__process_event = WatchdogEvent(self.task_manager)
|
||||
self.__process_queue = WatchdogQueue(self.task_manager)
|
||||
self.__process_frame_task = self.create_task(self.__process_frame_task_handler())
|
||||
|
||||
@@ -759,7 +744,7 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_frame_task = None
|
||||
|
||||
async def __process_frame(
|
||||
self, frame: Frame, direction: FrameDirection, callback: FrameCallback
|
||||
self, frame: Frame, direction: FrameDirection, callback: Optional[FrameCallback]
|
||||
):
|
||||
try:
|
||||
# Process the frame.
|
||||
@@ -790,10 +775,12 @@ class FrameProcessor(BaseObject):
|
||||
f"{self}: __process_queue is None when processing frame {frame.name}"
|
||||
)
|
||||
|
||||
self.__input_queue.task_done()
|
||||
|
||||
async def __process_frame_task_handler(self):
|
||||
"""Handle non-system frames from the process queue."""
|
||||
while True:
|
||||
if self.__should_block_frames and self.__process_event:
|
||||
if self.__should_block_frames:
|
||||
logger.trace(f"{self}: frame processing paused")
|
||||
await self.__process_event.wait()
|
||||
self.__process_event.clear()
|
||||
@@ -803,3 +790,5 @@ class FrameProcessor(BaseObject):
|
||||
(frame, direction, callback) = await self.__process_queue.get()
|
||||
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
|
||||
self.__process_queue.task_done()
|
||||
|
||||
@@ -53,7 +53,7 @@ Supported transports:
|
||||
|
||||
- Daily - Creates rooms and tokens, runs bot as participant
|
||||
- WebRTC - Provides local WebRTC interface with prebuilt UI
|
||||
- Telephony - Handles webhook and WebSocket connections for Twilio, Telnyx, Plivo
|
||||
- Telephony - Handles webhook and WebSocket connections for Twilio, Telnyx, Plivo, Exotel
|
||||
|
||||
To run locally:
|
||||
|
||||
@@ -62,6 +62,7 @@ To run locally:
|
||||
- Daily (server): `python bot.py -t daily`
|
||||
- Daily (direct, testing only): `python bot.py -d`
|
||||
- Telephony: `python bot.py -t twilio -x your_username.ngrok.io`
|
||||
- Exotel: `python bot.py -t exotel` (no proxy needed, but ngrok connection to HTTP 7860 is required)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -145,7 +146,6 @@ async def _run_telephony_bot(websocket: WebSocket):
|
||||
|
||||
# Just pass the WebSocket - let the bot handle parsing
|
||||
runner_args = WebSocketRunnerArguments(websocket=websocket)
|
||||
runner_args.handle_sigint = False
|
||||
|
||||
await bot_module.bot(runner_args)
|
||||
|
||||
@@ -169,7 +169,7 @@ def _create_server_app(
|
||||
_setup_webrtc_routes(app, esp32_mode=esp32_mode, host=host)
|
||||
elif transport_type == "daily":
|
||||
_setup_daily_routes(app)
|
||||
elif transport_type in ["twilio", "telnyx", "plivo"]:
|
||||
elif transport_type in ["twilio", "telnyx", "plivo", "exotel"]:
|
||||
_setup_telephony_routes(app, transport_type, proxy)
|
||||
else:
|
||||
logger.warning(f"Unknown transport type: {transport_type}")
|
||||
@@ -223,7 +223,6 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=pipecat_connection)
|
||||
runner_args.handle_sigint = False
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
answer = pipecat_connection.get_answer()
|
||||
@@ -266,7 +265,6 @@ def _setup_daily_routes(app: FastAPI):
|
||||
# Start the bot in the background with empty body for GET requests
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = DailyRunnerArguments(room_url=room_url, token=token)
|
||||
runner_args.handle_sigint = False
|
||||
asyncio.create_task(bot_module.bot(runner_args))
|
||||
return RedirectResponse(room_url)
|
||||
|
||||
@@ -311,7 +309,6 @@ def _setup_daily_routes(app: FastAPI):
|
||||
# Start the bot in the background with extracted body data
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = DailyRunnerArguments(room_url=room_url, token=token, body=bot_body)
|
||||
runner_args.handle_sigint = False
|
||||
asyncio.create_task(bot_module.bot(runner_args))
|
||||
# Match PCC /start endpoint response format:
|
||||
return {"dailyRoom": room_url, "dailyToken": token}
|
||||
@@ -337,7 +334,7 @@ def _setup_daily_routes(app: FastAPI):
|
||||
|
||||
def _setup_telephony_routes(app: FastAPI, transport_type: str, proxy: str):
|
||||
"""Set up telephony-specific routes."""
|
||||
# XML response templates
|
||||
# XML response templates (Exotel doesn't use XML webhooks)
|
||||
XML_TEMPLATES = {
|
||||
"twilio": f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
@@ -362,9 +359,18 @@ def _setup_telephony_routes(app: FastAPI, transport_type: str, proxy: str):
|
||||
@app.post("/")
|
||||
async def start_call():
|
||||
"""Handle telephony webhook and return XML response."""
|
||||
logger.debug(f"POST {transport_type.upper()} XML")
|
||||
xml_content = XML_TEMPLATES.get(transport_type, "<Response></Response>")
|
||||
return HTMLResponse(content=xml_content, media_type="application/xml")
|
||||
if transport_type == "exotel":
|
||||
# Exotel doesn't use POST webhooks - redirect to proper documentation
|
||||
logger.debug("POST Exotel endpoint - not used")
|
||||
return {
|
||||
"error": "Exotel doesn't use POST webhooks",
|
||||
"websocket_url": f"wss://{proxy}/ws",
|
||||
"note": "Configure the WebSocket URL above in your Exotel App Bazaar Voicebot Applet",
|
||||
}
|
||||
else:
|
||||
logger.debug(f"POST {transport_type.upper()} XML")
|
||||
xml_content = XML_TEMPLATES.get(transport_type, "<Response></Response>")
|
||||
return HTMLResponse(content=xml_content, media_type="application/xml")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
@@ -440,7 +446,7 @@ def main():
|
||||
Args:
|
||||
--host: Server host address (default: localhost)
|
||||
--port: Server port (default: 7860)
|
||||
-t/--transport: Transport type (daily, webrtc, twilio, telnyx, plivo)
|
||||
-t/--transport: Transport type (daily, webrtc, twilio, telnyx, plivo, exotel)
|
||||
-x/--proxy: Public proxy hostname for telephony webhooks
|
||||
--esp32: Enable SDP munging for ESP32 compatibility (requires --host with IP address)
|
||||
-d/--direct: Connect directly to Daily room (automatically sets transport to daily)
|
||||
@@ -455,7 +461,7 @@ def main():
|
||||
"-t",
|
||||
"--transport",
|
||||
type=str,
|
||||
choices=["daily", "webrtc", "twilio", "telnyx", "plivo"],
|
||||
choices=["daily", "webrtc", "twilio", "telnyx", "plivo", "exotel"],
|
||||
default="webrtc",
|
||||
help="Transport type",
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ class RunnerArguments:
|
||||
pipeline_idle_timeout_secs: int = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.handle_sigint = True
|
||||
self.handle_sigint = False
|
||||
self.handle_sigterm = False
|
||||
self.pipeline_idle_timeout_secs = 300
|
||||
|
||||
|
||||
@@ -77,6 +77,17 @@ def _detect_transport_type_from_message(message_data: dict) -> str:
|
||||
logger.trace("Auto-detected: PLIVO")
|
||||
return "plivo"
|
||||
|
||||
# Exotel detection
|
||||
if (
|
||||
message_data.get("event") == "start"
|
||||
and "start" in message_data
|
||||
and "stream_sid" in message_data.get("start", {})
|
||||
and "call_sid" in message_data.get("start", {})
|
||||
and "account_sid" in message_data.get("start", {})
|
||||
):
|
||||
logger.trace("Auto-detected: EXOTEL")
|
||||
return "exotel"
|
||||
|
||||
logger.trace("Auto-detection failed - unknown format")
|
||||
return "unknown"
|
||||
|
||||
@@ -91,6 +102,7 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
- Twilio: {"stream_id": str, "call_id": str}
|
||||
- Telnyx: {"stream_id": str, "call_control_id": str, "outbound_encoding": str}
|
||||
- Plivo: {"stream_id": str, "call_id": str}
|
||||
- Exotel: {"stream_id": str, "call_id": str, "account_sid": str}
|
||||
|
||||
Example usage::
|
||||
|
||||
@@ -160,6 +172,14 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
"call_id": start_data.get("callId"),
|
||||
}
|
||||
|
||||
elif transport_type == "exotel":
|
||||
start_data = call_data_raw.get("start", {})
|
||||
call_data = {
|
||||
"stream_id": start_data.get("stream_sid"),
|
||||
"call_id": start_data.get("call_sid"),
|
||||
"account_sid": start_data.get("account_sid"),
|
||||
}
|
||||
|
||||
else:
|
||||
call_data = {}
|
||||
|
||||
@@ -379,10 +399,17 @@ async def _create_telephony_transport(
|
||||
auth_id=os.getenv("PLIVO_AUTH_ID", ""),
|
||||
auth_token=os.getenv("PLIVO_AUTH_TOKEN", ""),
|
||||
)
|
||||
elif transport_type == "exotel":
|
||||
from pipecat.serializers.exotel import ExotelFrameSerializer
|
||||
|
||||
params.serializer = ExotelFrameSerializer(
|
||||
stream_sid=call_data["stream_id"],
|
||||
call_sid=call_data["call_id"],
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported telephony provider: {transport_type}. "
|
||||
f"Supported providers: twilio, telnyx, plivo"
|
||||
f"Supported providers: twilio, telnyx, plivo, exotel"
|
||||
)
|
||||
|
||||
return FastAPIWebsocketTransport(websocket=websocket, params=params)
|
||||
@@ -399,7 +426,7 @@ async def create_transport(
|
||||
Args:
|
||||
runner_args: Arguments from the runner.
|
||||
transport_params: Dict mapping transport names to parameter factory functions.
|
||||
Keys should be: "daily", "webrtc", "twilio", "telnyx", "plivo"
|
||||
Keys should be: "daily", "webrtc", "twilio", "telnyx", "plivo", "exotel"
|
||||
Values should be functions that return transport parameters when called.
|
||||
|
||||
Returns:
|
||||
@@ -440,6 +467,12 @@ async def create_transport(
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
# add_wav_header and serializer will be set automatically
|
||||
),
|
||||
"exotel": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
# add_wav_header and serializer will be set automatically
|
||||
),
|
||||
}
|
||||
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
|
||||
@@ -60,6 +60,7 @@ class AzureSTTService(STTService):
|
||||
region: str,
|
||||
language: Language = Language.EN_US,
|
||||
sample_rate: Optional[int] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Azure STT service.
|
||||
@@ -69,6 +70,7 @@ class AzureSTTService(STTService):
|
||||
region: Azure region for the Speech service (e.g., 'eastus').
|
||||
language: Language for speech recognition. Defaults to English (US).
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
endpoint_id: Custom model endpoint id.
|
||||
**kwargs: Additional arguments passed to parent STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
@@ -79,6 +81,9 @@ class AzureSTTService(STTService):
|
||||
speech_recognition_language=language_to_azure_language(language),
|
||||
)
|
||||
|
||||
if endpoint_id:
|
||||
self._speech_config.endpoint_id = endpoint_id
|
||||
|
||||
self._audio_stream = None
|
||||
self._speech_recognizer = None
|
||||
self._settings = {
|
||||
|
||||
@@ -68,6 +68,16 @@ class AzureBaseTTSService(TTSService):
|
||||
construction, voice configuration, and parameter management.
|
||||
"""
|
||||
|
||||
# Define SSML escape mappings based on SSML reserved characters
|
||||
# See - https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-structure
|
||||
SSML_ESCAPE_CHARS = {
|
||||
"&": "&",
|
||||
"<": "<",
|
||||
">": ">",
|
||||
'"': """,
|
||||
"'": "'",
|
||||
}
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Azure TTS voice configuration.
|
||||
|
||||
@@ -154,6 +164,10 @@ class AzureBaseTTSService(TTSService):
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
language = self._settings["language"]
|
||||
|
||||
# Escape special characters
|
||||
escaped_text = self._escape_text(text)
|
||||
|
||||
ssml = (
|
||||
f"<speak version='1.0' xml:lang='{language}' "
|
||||
"xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
@@ -183,7 +197,7 @@ class AzureBaseTTSService(TTSService):
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
ssml += text
|
||||
ssml += escaped_text
|
||||
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
@@ -197,6 +211,27 @@ class AzureBaseTTSService(TTSService):
|
||||
|
||||
return ssml
|
||||
|
||||
def _escape_text(self, text: str) -> str:
|
||||
"""Escapes XML/SSML reserved characters according to Microsoft documentation.
|
||||
|
||||
This method escapes the following characters:
|
||||
- & becomes &
|
||||
- < becomes <
|
||||
- > becomes >
|
||||
- " becomes "
|
||||
- ' becomes '
|
||||
|
||||
Args:
|
||||
text: The text to escape.
|
||||
|
||||
Returns:
|
||||
The escaped text.
|
||||
"""
|
||||
escaped_text = text
|
||||
for char, escape_code in AzureBaseTTSService.SSML_ESCAPE_CHARS.items():
|
||||
escaped_text = escaped_text.replace(char, escape_code)
|
||||
return escaped_text
|
||||
|
||||
|
||||
class AzureTTSService(AzureBaseTTSService):
|
||||
"""Azure Cognitive Services streaming TTS service.
|
||||
|
||||
@@ -204,7 +204,7 @@ class GladiaSTTService(STTService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: Optional[Literal["us-west", "eu-west"]] = "eu-west",
|
||||
region: Literal["us-west", "eu-west"] | None = None,
|
||||
url: str = "https://api.gladia.io/v2/live",
|
||||
confidence: float = 0.5,
|
||||
sample_rate: Optional[int] = None,
|
||||
@@ -341,13 +341,6 @@ class GladiaSTTService(STTService):
|
||||
|
||||
return settings
|
||||
|
||||
def _get_endpoint_url(self) -> str:
|
||||
query_params = dict()
|
||||
query_params["region"] = self._region or "eu-west"
|
||||
query = urlencode(query_params)
|
||||
|
||||
return f"{self._url}?{query}"
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Gladia STT websocket connection.
|
||||
|
||||
@@ -495,14 +488,16 @@ class GladiaSTTService(STTService):
|
||||
|
||||
async def _setup_gladia(self, settings: Dict[str, Any]):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
params = {}
|
||||
if self._region:
|
||||
params["region"] = self._region
|
||||
async with session.post(
|
||||
self._get_endpoint_url(),
|
||||
headers={"X-Gladia-Key": self._api_key, "Content-Type": "application/json"},
|
||||
self._url,
|
||||
headers={"X-Gladia-Key": self._api_key},
|
||||
json=settings,
|
||||
params=params,
|
||||
) as response:
|
||||
if response.ok:
|
||||
response_text = await response.json()
|
||||
logger.error(f"Gladia response: {response_text}")
|
||||
return await response.json()
|
||||
else:
|
||||
error_text = await response.text()
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
This module provides integration with Google Cloud Text-to-Speech API,
|
||||
offering both HTTP-based synthesis with SSML support and streaming synthesis
|
||||
for real-time applications.
|
||||
|
||||
It also includes GeminiTTSService which uses Gemini's TTS-specific models
|
||||
for natural voice control and multi-speaker conversations.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -19,7 +22,7 @@ from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from typing import AsyncGenerator, Literal, Optional
|
||||
from typing import AsyncGenerator, List, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
@@ -27,6 +30,7 @@ from pydantic import BaseModel
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -47,6 +51,15 @@ except ModuleNotFoundError as e:
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Gemini TTS, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_google_tts_language(language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Google TTS language code.
|
||||
@@ -642,3 +655,252 @@ class GoogleTTSService(TTSService):
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
|
||||
class GeminiTTSService(TTSService):
|
||||
"""Gemini Text-to-Speech service using Gemini TTS models.
|
||||
|
||||
Provides text-to-speech synthesis using Gemini's TTS-specific models
|
||||
(gemini-2.5-flash-preview-tts and gemini-2.5-pro-preview-tts) with
|
||||
support for natural voice control, multiple speakers, and voice styles.
|
||||
|
||||
Note:
|
||||
Requires Google AI API key. This uses the Gemini API, not Google Cloud TTS.
|
||||
Audio-out is currently a preview feature.
|
||||
|
||||
Example::
|
||||
|
||||
tts = GeminiTTSService(
|
||||
api_key="your-google-ai-api-key",
|
||||
model="gemini-2.5-flash-preview-tts",
|
||||
voice_id="Kore",
|
||||
params=GeminiTTSService.InputParams(
|
||||
language=Language.EN_US,
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
GOOGLE_SAMPLE_RATE = 24000 # Google TTS always outputs at 24kHz
|
||||
|
||||
# List of available Gemini TTS voices
|
||||
AVAILABLE_VOICES = [
|
||||
"Zephyr",
|
||||
"Puck",
|
||||
"Charon",
|
||||
"Kore",
|
||||
"Fenrir",
|
||||
"Leda",
|
||||
"Orus",
|
||||
"Aoede",
|
||||
"Callirhoe",
|
||||
"Autonoe",
|
||||
"Enceladus",
|
||||
"Iapetus",
|
||||
"Umbriel",
|
||||
"Algieba",
|
||||
"Despina",
|
||||
"Erinome",
|
||||
"Algenib",
|
||||
"Rasalgethi",
|
||||
"Laomedeia",
|
||||
"Achernar",
|
||||
"Alnilam",
|
||||
"Schedar",
|
||||
"Gacrux",
|
||||
"Pulcherrima",
|
||||
"Achird",
|
||||
"Zubenelgenubi",
|
||||
"Vindemiatrix",
|
||||
"Sadachbia",
|
||||
"Sadaltager",
|
||||
"Sulafar",
|
||||
]
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Gemini TTS configuration.
|
||||
|
||||
Parameters:
|
||||
language: Language for synthesis. Defaults to English.
|
||||
multi_speaker: Whether to enable multi-speaker support.
|
||||
speaker_configs: List of speaker configurations for multi-speaker mode.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
multi_speaker: bool = False
|
||||
speaker_configs: Optional[List[dict]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "gemini-2.5-flash-preview-tts",
|
||||
voice_id: str = "Kore",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes the Gemini TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Google AI API key for authentication.
|
||||
model: Gemini TTS model to use. Must be a TTS model like
|
||||
"gemini-2.5-flash-preview-tts" or "gemini-2.5-pro-preview-tts".
|
||||
voice_id: Voice name from the available Gemini voices.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses Google's default 24kHz.
|
||||
params: TTS configuration parameters.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
if sample_rate and sample_rate != self.GOOGLE_SAMPLE_RATE:
|
||||
logger.warning(
|
||||
f"Google TTS only supports {self.GOOGLE_SAMPLE_RATE}Hz sample rate. "
|
||||
f"Current rate of {sample_rate}Hz may cause issues."
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or GeminiTTSService.InputParams()
|
||||
|
||||
if voice_id not in self.AVAILABLE_VOICES:
|
||||
logger.warning(f"Voice '{voice_id}' not in known voices list. Using anyway.")
|
||||
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
self._voice_id = voice_id
|
||||
self._settings = {
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-US",
|
||||
"multi_speaker": params.multi_speaker,
|
||||
"speaker_configs": params.speaker_configs,
|
||||
}
|
||||
|
||||
self._client = genai.Client(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as Gemini TTS service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Gemini TTS language format.
|
||||
|
||||
Args:
|
||||
language: The language to convert.
|
||||
|
||||
Returns:
|
||||
The Gemini TTS-specific language code, or None if not supported.
|
||||
"""
|
||||
return language_to_google_tts_language(language)
|
||||
|
||||
def set_voice(self, voice_id: str):
|
||||
"""Set the voice for TTS generation.
|
||||
|
||||
Args:
|
||||
voice_id: Name of the voice to use from AVAILABLE_VOICES.
|
||||
"""
|
||||
if voice_id not in self.AVAILABLE_VOICES:
|
||||
logger.warning(f"Voice '{voice_id}' not in known voices list. Using anyway.")
|
||||
self._voice_id = voice_id
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Gemini TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
if self.sample_rate != self.GOOGLE_SAMPLE_RATE:
|
||||
logger.warning(
|
||||
f"Google TTS requires {self.GOOGLE_SAMPLE_RATE}Hz sample rate. "
|
||||
f"Current rate of {self.sample_rate}Hz may cause issues."
|
||||
)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Gemini TTS models.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech. Can include natural language
|
||||
instructions for style, tone, etc.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech.
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Build the speech config
|
||||
if self._settings["multi_speaker"] and self._settings["speaker_configs"]:
|
||||
# Multi-speaker mode
|
||||
speaker_voice_configs = []
|
||||
for speaker_config in self._settings["speaker_configs"]:
|
||||
speaker_voice_configs.append(
|
||||
types.SpeakerVoiceConfig(
|
||||
speaker=speaker_config["speaker"],
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
||||
voice_name=speaker_config.get("voice_id", self._voice_id)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
speech_config = types.SpeechConfig(
|
||||
multi_speaker_voice_config=types.MultiSpeakerVoiceConfig(
|
||||
speaker_voice_configs=speaker_voice_configs
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Single speaker mode
|
||||
speech_config = types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=self._voice_id)
|
||||
)
|
||||
)
|
||||
|
||||
# Create the generation config
|
||||
generation_config = types.GenerateContentConfig(
|
||||
response_modalities=["AUDIO"],
|
||||
speech_config=speech_config,
|
||||
)
|
||||
|
||||
# Generate the content
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self._model,
|
||||
contents=text,
|
||||
config=generation_config,
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# Extract audio data from response
|
||||
if response.candidates and len(response.candidates) > 0:
|
||||
candidate = response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if part.inline_data and part.inline_data.mime_type.startswith("audio/"):
|
||||
audio_data = part.inline_data.data
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
# Gemini TTS returns PCM audio data, chunk it appropriately
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
for i in range(0, len(audio_data), CHUNK_SIZE):
|
||||
chunk = audio_data[i : i + CHUNK_SIZE]
|
||||
if not chunk:
|
||||
break
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"Gemini TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
@@ -362,7 +362,7 @@ class HeyGenClient:
|
||||
"""Simulate audio playback timing with appropriate delays."""
|
||||
# Only sleep after we've sent the first second of audio
|
||||
# This appears to reduce the latency to receive the answer from HeyGen
|
||||
if self._audio_seconds_sent < 1.0:
|
||||
if self._audio_seconds_sent < 3.0:
|
||||
self._audio_seconds_sent += self._send_interval
|
||||
self._next_send_time = time.monotonic() + self._send_interval
|
||||
return
|
||||
|
||||
@@ -20,6 +20,7 @@ from loguru import logger
|
||||
from pipecat.audio.utils import create_stream_resampler
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
@@ -30,6 +31,7 @@ from pipecat.frames.frames import (
|
||||
SpeechOutputAudioRawFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -232,9 +234,24 @@ class HeyGenVideoService(AIService):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSAudioRawFrame):
|
||||
await self._handle_audio_frame(frame)
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
await self.start_ttfb_metrics()
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
# We constantly receive audio through WebRTC, but most of the time it is silence.
|
||||
# As soon as we receive actual audio, the base output transport will create a
|
||||
# BotStartedSpeakingFrame, which we can use as a signal for the TTFB metrics.
|
||||
await self.stop_ttfb_metrics()
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate metrics.
|
||||
|
||||
Returns:
|
||||
True if metrics generation is supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def _handle_user_started_speaking(self):
|
||||
"""Handle the event when a user starts speaking.
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class MoondreamService(VisionService):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, model="vikhyatk/moondream2", revision="2024-08-26", use_cpu=False, **kwargs
|
||||
self, *, model="vikhyatk/moondream2", revision="2025-01-09", use_cpu=False, **kwargs
|
||||
):
|
||||
"""Initialize the Moondream service.
|
||||
|
||||
@@ -82,14 +82,15 @@ class MoondreamService(VisionService):
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(model, revision=revision)
|
||||
|
||||
logger.debug("Loading Moondream model...")
|
||||
|
||||
self._model = AutoModelForCausalLM.from_pretrained(
|
||||
model, trust_remote_code=True, revision=revision
|
||||
).to(device=device, dtype=dtype)
|
||||
self._model.eval()
|
||||
model,
|
||||
trust_remote_code=True,
|
||||
revision=revision,
|
||||
device_map={"": device},
|
||||
torch_dtype=dtype,
|
||||
).eval()
|
||||
|
||||
logger.debug("Loaded Moondream model")
|
||||
|
||||
@@ -121,9 +122,7 @@ class MoondreamService(VisionService):
|
||||
"""
|
||||
image = Image.frombytes(frame.format, frame.size, frame.image)
|
||||
image_embeds = self._model.encode_image(image)
|
||||
description = self._model.answer_question(
|
||||
image_embeds=image_embeds, question=frame.text, tokenizer=self._tokenizer
|
||||
)
|
||||
description = self._model.query(image_embeds, frame.text)["answer"]
|
||||
return description
|
||||
|
||||
description = await asyncio.to_thread(get_image_description, frame)
|
||||
|
||||
@@ -171,6 +171,15 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
"""
|
||||
self._audio_input_paused = paused
|
||||
|
||||
def _is_modality_enabled(self, modality: str) -> bool:
|
||||
"""Check if a specific modality is enabled, "text" or "audio"."""
|
||||
modalities = self._session_properties.modalities or ["audio", "text"]
|
||||
return modality in modalities
|
||||
|
||||
def _get_enabled_modalities(self) -> list[str]:
|
||||
"""Get the list of enabled modalities."""
|
||||
return self._session_properties.modalities or ["audio", "text"]
|
||||
|
||||
async def retrieve_conversation_item(self, item_id: str):
|
||||
"""Retrieve a conversation item by ID from the server.
|
||||
|
||||
@@ -243,7 +252,9 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
await self.stop_all_metrics()
|
||||
if self._current_assistant_response:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
# Only push TTSStoppedFrame if audio modality is enabled
|
||||
if self._is_modality_enabled("audio"):
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
|
||||
async def _handle_user_started_speaking(self, frame):
|
||||
pass
|
||||
@@ -469,6 +480,8 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
await self._handle_evt_speech_started(evt)
|
||||
elif evt.type == "input_audio_buffer.speech_stopped":
|
||||
await self._handle_evt_speech_stopped(evt)
|
||||
elif evt.type == "response.text.delta":
|
||||
await self._handle_evt_text_delta(evt)
|
||||
elif evt.type == "response.audio_transcript.delta":
|
||||
await self._handle_evt_audio_transcript_delta(evt)
|
||||
elif evt.type == "error":
|
||||
@@ -617,6 +630,10 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
# Response message without preceding user message. Add it to the context.
|
||||
await self._handle_assistant_output(evt.response.output)
|
||||
|
||||
async def _handle_evt_text_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(LLMTextFrame(evt.delta))
|
||||
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(LLMTextFrame(evt.delta))
|
||||
@@ -723,7 +740,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.send_client_event(
|
||||
events.ResponseCreateEvent(
|
||||
response=events.ResponseProperties(modalities=["audio", "text"])
|
||||
response=events.ResponseProperties(modalities=self._get_enabled_modalities())
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import io
|
||||
import json
|
||||
import struct
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -37,14 +38,11 @@ from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
from pyht.async_client import AsyncClient
|
||||
from pyht.client import Format, TTSOptions
|
||||
from pyht.client import Language as PlayHTLanguage
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use PlayHT, you need to `pip install pipecat-ai[playht]`.")
|
||||
logger.error("In order to use PlayHTTTSService, you need to `pip install pipecat-ai[playht]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -429,7 +427,8 @@ class PlayHTHttpTTSService(TTSService):
|
||||
user_id: str,
|
||||
voice_url: str,
|
||||
voice_engine: str = "Play3.0-mini",
|
||||
protocol: str = "http", # Options: http, ws
|
||||
protocol: Optional[str] = None,
|
||||
output_format: str = "wav",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
@@ -441,40 +440,46 @@ class PlayHTHttpTTSService(TTSService):
|
||||
user_id: PlayHT user ID for authentication.
|
||||
voice_url: URL of the voice to use for synthesis.
|
||||
voice_engine: Voice engine to use. Defaults to "Play3.0-mini".
|
||||
protocol: Protocol to use ("http" or "ws"). Defaults to "http".
|
||||
protocol: Protocol to use ("http" or "ws").
|
||||
|
||||
.. deprecated:: 0.0.80
|
||||
This parameter no longer has any effect and will be removed in a future version.
|
||||
Use PlayHTTTSService for WebSocket or PlayHTHttpTTSService for HTTP.
|
||||
|
||||
output_format: Audio output format. Defaults to "wav".
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
params: Additional input parameters for voice customization.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
# Warn about deprecated protocol parameter if explicitly provided
|
||||
if protocol:
|
||||
warnings.warn(
|
||||
"The 'protocol' parameter is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
params = params or PlayHTHttpTTSService.InputParams()
|
||||
|
||||
self._user_id = user_id
|
||||
self._api_key = api_key
|
||||
|
||||
self._client = AsyncClient(
|
||||
user_id=self._user_id,
|
||||
api_key=self._api_key,
|
||||
)
|
||||
|
||||
# Check if voice_engine contains protocol information (backward compatibility)
|
||||
if "-http" in voice_engine:
|
||||
# Extract the base engine name
|
||||
voice_engine = voice_engine.replace("-http", "")
|
||||
protocol = "http"
|
||||
elif "-ws" in voice_engine:
|
||||
# Extract the base engine name
|
||||
voice_engine = voice_engine.replace("-ws", "")
|
||||
protocol = "ws"
|
||||
|
||||
self._settings = {
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "english",
|
||||
"format": Format.FORMAT_WAV,
|
||||
"output_format": output_format,
|
||||
"voice_engine": voice_engine,
|
||||
"protocol": protocol,
|
||||
"speed": params.speed,
|
||||
"seed": params.seed,
|
||||
}
|
||||
@@ -490,26 +495,6 @@ class PlayHTHttpTTSService(TTSService):
|
||||
await super().start(frame)
|
||||
self._settings["sample_rate"] = self.sample_rate
|
||||
|
||||
def _create_options(self) -> TTSOptions:
|
||||
"""Create TTSOptions object from current settings."""
|
||||
language_str = self._settings["language"]
|
||||
playht_language = None
|
||||
if language_str:
|
||||
# Convert string to PlayHT Language enum
|
||||
for lang in PlayHTLanguage:
|
||||
if lang.value == language_str:
|
||||
playht_language = lang
|
||||
break
|
||||
|
||||
return TTSOptions(
|
||||
voice=self._voice_id,
|
||||
language=playht_language,
|
||||
sample_rate=self.sample_rate,
|
||||
format=self._settings["format"],
|
||||
speed=self._settings["speed"],
|
||||
seed=self._settings["seed"],
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
@@ -542,41 +527,78 @@ class PlayHTHttpTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
options = self._create_options()
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
playht_gen = self._client.tts(
|
||||
text,
|
||||
voice_engine=self._settings["voice_engine"],
|
||||
protocol=self._settings["protocol"],
|
||||
options=options,
|
||||
)
|
||||
# Prepare the request payload
|
||||
payload = {
|
||||
"text": text,
|
||||
"voice": self._voice_id,
|
||||
"voice_engine": self._settings["voice_engine"],
|
||||
"output_format": self._settings["output_format"],
|
||||
"sample_rate": self.sample_rate,
|
||||
"language": self._settings["language"],
|
||||
}
|
||||
|
||||
# Add optional parameters if they exist
|
||||
if self._settings["speed"] is not None:
|
||||
payload["speed"] = self._settings["speed"]
|
||||
if self._settings["seed"] is not None:
|
||||
payload["seed"] = self._settings["seed"]
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"X-User-Id": self._user_id,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
b = bytearray()
|
||||
in_header = True
|
||||
async for chunk in playht_gen:
|
||||
# skip the RIFF header.
|
||||
if in_header:
|
||||
b.extend(chunk)
|
||||
if len(b) <= 36:
|
||||
continue
|
||||
else:
|
||||
fh = io.BytesIO(b)
|
||||
fh.seek(36)
|
||||
(data, size) = struct.unpack("<4sI", fh.read(8))
|
||||
while data != b"data":
|
||||
fh.read(size)
|
||||
(data, size) = struct.unpack("<4sI", fh.read(8))
|
||||
in_header = False
|
||||
elif len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.play.ht/api/v2/tts/stream",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
) as response:
|
||||
if response.status not in (200, 201):
|
||||
error_text = await response.text()
|
||||
raise Exception(f"PlayHT API error {response.status}: {error_text}")
|
||||
|
||||
in_header = True
|
||||
buffer = b""
|
||||
|
||||
CHUNK_SIZE = self.chunk_size
|
||||
|
||||
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
|
||||
if len(chunk) == 0:
|
||||
continue
|
||||
|
||||
# Skip the RIFF header
|
||||
if in_header:
|
||||
buffer += chunk
|
||||
if len(buffer) <= 36:
|
||||
continue
|
||||
else:
|
||||
fh = io.BytesIO(buffer)
|
||||
fh.seek(36)
|
||||
(data, size) = struct.unpack("<4sI", fh.read(8))
|
||||
while data != b"data":
|
||||
fh.read(size)
|
||||
(data, size) = struct.unpack("<4sI", fh.read(8))
|
||||
# Extract audio data after header
|
||||
audio_data = buffer[fh.tell() :]
|
||||
if len(audio_data) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(audio_data, self.sample_rate, 1)
|
||||
yield frame
|
||||
in_header = False
|
||||
elif len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
finally:
|
||||
|
||||
@@ -19,6 +19,7 @@ from loguru import logger
|
||||
|
||||
from pipecat.audio.utils import create_stream_resampler
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
@@ -29,6 +30,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
|
||||
from pipecat.services.ai_service import AIService
|
||||
@@ -229,6 +231,13 @@ class TavusVideoService(AIService):
|
||||
elif isinstance(frame, OutputTransportReadyFrame):
|
||||
self._transport_ready = True
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
await self.start_ttfb_metrics()
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
# We constantly receive audio through WebRTC, but most of the time it is silence.
|
||||
# As soon as we receive actual audio, the base output transport will create a
|
||||
# BotStartedSpeakingFrame, which we can use as a signal for the TTFB metrics.
|
||||
await self.stop_ttfb_metrics()
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Awaitable, Callable, Optional
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
from websockets.protocol import State
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame
|
||||
@@ -82,12 +83,10 @@ class WebsocketService(ABC):
|
||||
try:
|
||||
await self._receive_messages()
|
||||
retry_count = 0 # Reset counter on successful message receive
|
||||
if self._websocket and self._websocket.state is State.CLOSED:
|
||||
raise websockets.ConnectionClosedOK(
|
||||
self._websocket.close_rcvd,
|
||||
self._websocket.close_sent,
|
||||
self._websocket.close_rcvd_then_sent,
|
||||
)
|
||||
except ConnectionClosedOK as e:
|
||||
# Normal closure, don't retry
|
||||
logger.debug(f"{self} connection closed normally: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
message = f"{self} error receiving messages: {e}"
|
||||
logger.error(message)
|
||||
|
||||
@@ -186,7 +186,7 @@ class BaseWhisperSTTService(SegmentedSTTService):
|
||||
language: The Language enum value to use for transcription.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._language = language
|
||||
self._language = self.language_to_service_language(language)
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
|
||||
@@ -12,7 +12,6 @@ output processing, including frame buffering, mixing, timing, and media streamin
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional
|
||||
@@ -429,7 +428,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
frame: The end frame signaling sender shutdown.
|
||||
"""
|
||||
# Let the sink tasks process the queue until they reach this EndFrame.
|
||||
await self._clock_queue.put((sys.maxsize, frame.id, frame))
|
||||
await self._clock_queue.put((float("inf"), frame.id, frame))
|
||||
await self._audio_queue.put(frame)
|
||||
|
||||
# At this point we have enqueued an EndFrame and we need to wait for
|
||||
@@ -828,7 +827,9 @@ class BaseOutputTransport(FrameProcessor):
|
||||
def _create_clock_task(self):
|
||||
"""Create the clock/timing processing task."""
|
||||
if not self._clock_task:
|
||||
self._clock_queue = WatchdogPriorityQueue(self._transport.task_manager)
|
||||
self._clock_queue = WatchdogPriorityQueue(
|
||||
self._transport.task_manager, tuple_size=3
|
||||
)
|
||||
self._clock_task = self._transport.create_task(self._clock_task_handler())
|
||||
|
||||
async def _cancel_clock_task(self):
|
||||
|
||||
@@ -226,6 +226,7 @@ class SmallWebRTCClient:
|
||||
self._audio_in_channels = None
|
||||
self._in_sample_rate = None
|
||||
self._out_sample_rate = None
|
||||
self._leave_counter = 0
|
||||
|
||||
# We are always resampling it for 16000 if the sample_rate that we receive is bigger than that.
|
||||
# otherwise we face issues with Silero VAD
|
||||
@@ -395,6 +396,7 @@ class SmallWebRTCClient:
|
||||
self._in_sample_rate = _params.audio_in_sample_rate or frame.audio_in_sample_rate
|
||||
self._out_sample_rate = _params.audio_out_sample_rate or frame.audio_out_sample_rate
|
||||
self._params = _params
|
||||
self._leave_counter += 1
|
||||
|
||||
async def connect(self):
|
||||
"""Establish the WebRTC connection."""
|
||||
@@ -407,6 +409,10 @@ class SmallWebRTCClient:
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from the WebRTC peer."""
|
||||
self._leave_counter -= 1
|
||||
if self._leave_counter > 0:
|
||||
return
|
||||
|
||||
if self.is_connected and not self.is_closing:
|
||||
logger.info(f"Disconnecting to Small WebRTC")
|
||||
self._closing = True
|
||||
|
||||
@@ -560,7 +560,7 @@ class DailyTransportClient(EventHandler):
|
||||
self._out_sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate
|
||||
|
||||
if self._params.audio_in_enabled:
|
||||
if self._params.audio_in_user_tracks and not self._audio_task:
|
||||
if self._params.audio_in_user_tracks and not self._audio_task and self._task_manager:
|
||||
self._audio_queue = WatchdogQueue(self._task_manager)
|
||||
self._audio_task = self._task_manager.create_task(
|
||||
self._callback_task_handler(self._audio_queue),
|
||||
|
||||
@@ -20,23 +20,42 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WatchdogPriorityCancelSentinel:
|
||||
def __lt__(self, other):
|
||||
return True
|
||||
class WatchdogPriorityCancelSentinel:
|
||||
"""Sentinel object used in priority queues to force cancellation.
|
||||
|
||||
An instance of this class is typically inserted into a
|
||||
`WatchdogPriorityQueue` to act as a high-priority marker asyncio task
|
||||
cancellation.
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WatchdogPriorityQueue(asyncio.PriorityQueue):
|
||||
"""Watchdog-enabled asyncio PriorityQueue.
|
||||
"""Class for watchdog-enabled asyncio PriorityQueue.
|
||||
|
||||
An asynchronous priority queue that resets the current task watchdog
|
||||
timer. This is necessary to avoid task watchdog timers to expire while we
|
||||
are waiting to get an item from the queue.
|
||||
|
||||
This queue expects items to be tuples, with the actual payload stored
|
||||
in the last element. All preceding elements are treated as numeric
|
||||
priority fields. For example:
|
||||
|
||||
(0, 1, "foo")
|
||||
|
||||
The tuple length must be specified at creation time so the queue can
|
||||
correctly construct special items, such as the watchdog cancel sentinel,
|
||||
with the proper tuple structure.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: BaseTaskManager,
|
||||
*,
|
||||
tuple_size: int,
|
||||
maxsize: int = 0,
|
||||
timeout: float = 2.0,
|
||||
) -> None:
|
||||
@@ -44,12 +63,14 @@ class WatchdogPriorityQueue(asyncio.PriorityQueue):
|
||||
|
||||
Args:
|
||||
manager: The task manager for watchdog timer control.
|
||||
tuple_size: The number of values in each inserted tuple.
|
||||
maxsize: Maximum queue size. 0 means unlimited.
|
||||
timeout: Timeout in seconds between watchdog resets while waiting.
|
||||
"""
|
||||
super().__init__(maxsize)
|
||||
self._manager = manager
|
||||
self._timeout = timeout
|
||||
self._tuple_size = tuple_size
|
||||
|
||||
async def get(self):
|
||||
"""Get an item from the queue with watchdog monitoring.
|
||||
@@ -62,7 +83,10 @@ class WatchdogPriorityQueue(asyncio.PriorityQueue):
|
||||
else:
|
||||
get_result = await super().get()
|
||||
|
||||
if isinstance(get_result, _WatchdogPriorityCancelSentinel):
|
||||
# Value is always at the end of the tuple.
|
||||
item = get_result[-1]
|
||||
|
||||
if isinstance(item, WatchdogPriorityCancelSentinel):
|
||||
logger.trace(
|
||||
"Received WatchdogPriorityCancelSentinel, throwing CancelledError to force cancelling"
|
||||
)
|
||||
@@ -91,7 +115,10 @@ class WatchdogPriorityQueue(asyncio.PriorityQueue):
|
||||
forces the task to raise CancelledError when consumed, ensuring proper
|
||||
task termination.
|
||||
"""
|
||||
super().put_nowait(_WatchdogPriorityCancelSentinel())
|
||||
item = [float("-inf")] * self._tuple_size
|
||||
# Values go always at the end.
|
||||
item[-1] = WatchdogPriorityCancelSentinel()
|
||||
super().put_nowait(tuple(item))
|
||||
|
||||
async def _watchdog_get(self):
|
||||
"""Get item from queue while periodically resetting watchdog timer."""
|
||||
|
||||
@@ -20,7 +20,14 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WatchdogQueueCancelSentinel:
|
||||
class WatchdogQueueCancelSentinel:
|
||||
"""Sentinel object used in queues to force cancellation.
|
||||
|
||||
An instance of this class is typically inserted into a `WatchdogQueue` to
|
||||
act as a marker for asyncio task cancellation.
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -61,7 +68,7 @@ class WatchdogQueue(asyncio.Queue):
|
||||
else:
|
||||
get_result = await super().get()
|
||||
|
||||
if isinstance(get_result, _WatchdogQueueCancelSentinel):
|
||||
if isinstance(get_result, WatchdogQueueCancelSentinel):
|
||||
logger.trace(
|
||||
"Received WatchdogQueueCancelFrame, throwing CancelledError to force cancelling"
|
||||
)
|
||||
@@ -90,7 +97,7 @@ class WatchdogQueue(asyncio.Queue):
|
||||
forces the task to raise CancelledError when consumed, ensuring proper
|
||||
task termination.
|
||||
"""
|
||||
super().put_nowait(_WatchdogQueueCancelSentinel())
|
||||
super().put_nowait(WatchdogQueueCancelSentinel())
|
||||
|
||||
async def _watchdog_get(self):
|
||||
"""Get item from queue while periodically resetting watchdog timer."""
|
||||
|
||||
@@ -17,6 +17,8 @@ from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
@@ -28,15 +30,19 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(), # First bot speech starts
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First bot speech ends
|
||||
BotStartedSpeakingFrame(), # Second bot speech
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
@@ -47,8 +53,10 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
BotStartedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame, # Now passes through
|
||||
UserStartedSpeakingFrame, # Now passes through
|
||||
InputAudioRawFrame, # Now passes through
|
||||
VADUserStoppedSpeakingFrame, # Now passes through
|
||||
UserStoppedSpeakingFrame, # Now passes through
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
@@ -64,20 +72,26 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(), # First speech starts
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First speech ends
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Second speech starts
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed again
|
||||
UserStartedSpeakingFrame(), # Should be suppressed again
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed again
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed again
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed again
|
||||
BotStoppedSpeakingFrame(), # Second speech ends
|
||||
]
|
||||
@@ -87,8 +101,10 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
@@ -146,14 +162,18 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
|
||||
|
||||
# frames_to_send = [
|
||||
# VADUserStartedSpeakingFrame(), # Should pass through initially
|
||||
# UserStartedSpeakingFrame(), # Should pass through initially
|
||||
# VADUserStoppedSpeakingFrame(),
|
||||
# UserStoppedSpeakingFrame(),
|
||||
# FunctionCallInProgressFrame(
|
||||
# function_name="get_weather",
|
||||
# tool_call_id="call_123",
|
||||
# arguments='{"location": "San Francisco"}',
|
||||
# ), # Start function call
|
||||
# VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
# UserStartedSpeakingFrame(), # Should be suppressed
|
||||
# VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
# UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
# FunctionCallResultFrame(
|
||||
# function_name="get_weather",
|
||||
@@ -161,18 +181,24 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
# arguments='{"location": "San Francisco"}',
|
||||
# result={"temperature": 22},
|
||||
# ), # End function call
|
||||
# VADUserStartedSpeakingFrame(), # Should pass through again
|
||||
# UserStartedSpeakingFrame(), # Should pass through again
|
||||
# VADUserStoppedSpeakingFrame(),
|
||||
# UserStoppedSpeakingFrame(),
|
||||
# ]
|
||||
|
||||
# expected_returned_frames = [
|
||||
# VADUserStartedSpeakingFrame,
|
||||
# UserStartedSpeakingFrame,
|
||||
# VADUserStoppedSpeakingFrame,
|
||||
# UserStoppedSpeakingFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# STTMuteFrame, # mute=True
|
||||
# FunctionCallResultFrame,
|
||||
# STTMuteFrame, # mute=False
|
||||
# VADUserStartedSpeakingFrame,
|
||||
# UserStartedSpeakingFrame,
|
||||
# VADUserStoppedSpeakingFrame,
|
||||
# UserStoppedSpeakingFrame,
|
||||
# ]
|
||||
|
||||
@@ -188,24 +214,32 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed (starts muted)
|
||||
UserStartedSpeakingFrame(), # Should be suppressed (starts muted)
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStartedSpeakingFrame(), # First bot speech
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First speech ends, unmutes
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Second speech
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
@@ -215,12 +249,16 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False after first speech
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
@@ -254,31 +292,41 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Bot starts speaking
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # Bot stops speaking
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
|
||||
65
tests/test_watchdog_queue.py
Normal file
65
tests/test_watchdog_queue.py
Normal file
@@ -0,0 +1,65 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager
|
||||
from pipecat.utils.asyncio.watchdog_priority_queue import WatchdogPriorityQueue
|
||||
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
|
||||
|
||||
|
||||
class TestWatchdogQueue(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_simple_item(self):
|
||||
queue = WatchdogQueue(TaskManager())
|
||||
await queue.put(1)
|
||||
await queue.put(2)
|
||||
await queue.put(3)
|
||||
self.assertEqual(await queue.get(), 1)
|
||||
queue.task_done()
|
||||
self.assertEqual(await queue.get(), 2)
|
||||
queue.task_done()
|
||||
self.assertEqual(await queue.get(), 3)
|
||||
queue.task_done()
|
||||
|
||||
async def test_watchdog_sentinel(self):
|
||||
queue = WatchdogQueue(TaskManager())
|
||||
await queue.put(1)
|
||||
self.assertEqual(await queue.get(), 1)
|
||||
queue.task_done()
|
||||
# The get should throw an exception.
|
||||
queue.cancel()
|
||||
try:
|
||||
await queue.get()
|
||||
assert False
|
||||
except asyncio.CancelledError:
|
||||
assert True
|
||||
|
||||
|
||||
class TestWatchdogPriorityQueue(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_simple_item(self):
|
||||
queue = WatchdogPriorityQueue(TaskManager(), tuple_size=2)
|
||||
await queue.put((3, 1))
|
||||
await queue.put((2, 1))
|
||||
await queue.put((1, 1))
|
||||
self.assertEqual(await queue.get(), (1, 1))
|
||||
queue.task_done()
|
||||
self.assertEqual(await queue.get(), (2, 1))
|
||||
queue.task_done()
|
||||
self.assertEqual(await queue.get(), (3, 1))
|
||||
queue.task_done()
|
||||
|
||||
async def test_watchdog_sentinel(self):
|
||||
queue = WatchdogPriorityQueue(TaskManager(), tuple_size=2)
|
||||
await queue.put((0, 1))
|
||||
# The get should throw an exception because the watchdog sentinel has
|
||||
# higher priority.
|
||||
queue.cancel()
|
||||
try:
|
||||
await queue.get()
|
||||
assert False
|
||||
except asyncio.CancelledError:
|
||||
assert True
|
||||
Reference in New Issue
Block a user