Compare commits
193 Commits
async-reba
...
v0.0.45
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4075b19f7c | ||
|
|
bb14918a33 | ||
|
|
2aee8a12f8 | ||
|
|
5760fadb44 | ||
|
|
af5a7e9092 | ||
|
|
8d9a7486d1 | ||
|
|
00d0f9ae48 | ||
|
|
d255b7d1b2 | ||
|
|
4eb2c95b63 | ||
|
|
3910aeb4de | ||
|
|
713dcb7a4d | ||
|
|
04da51c7d8 | ||
|
|
e52d18e42d | ||
|
|
0c4a513ca2 | ||
|
|
4a71eacac3 | ||
|
|
f0d89e57ad | ||
|
|
79b52d4301 | ||
|
|
bb00dbefbc | ||
|
|
0c250c0603 | ||
|
|
7bbaf4dfe9 | ||
|
|
3a3bf3fe34 | ||
|
|
616aa54f75 | ||
|
|
164f06415c | ||
|
|
51bc4839d1 | ||
|
|
6d778e0491 | ||
|
|
fc4fa2faaa | ||
|
|
90b7f65545 | ||
|
|
f7b7f0d680 | ||
|
|
5431c44e51 | ||
|
|
40b3e50815 | ||
|
|
2f6232fac9 | ||
|
|
b4f2525c76 | ||
|
|
8e956a4e88 | ||
|
|
7b9712daad | ||
|
|
d4269acd67 | ||
|
|
d2ae82fb38 | ||
|
|
270949e6cd | ||
|
|
cfada94c13 | ||
|
|
68fd6f7c44 | ||
|
|
96bfcc3dca | ||
|
|
b0890b1f75 | ||
|
|
802b3e42c4 | ||
|
|
bd134839ff | ||
|
|
428ce63e17 | ||
|
|
46d6cde383 | ||
|
|
6de82b3c11 | ||
|
|
ec0bc7a057 | ||
|
|
c62156a4c3 | ||
|
|
e8618a07d0 | ||
|
|
0ba99514a9 | ||
|
|
837c8dad27 | ||
|
|
6f2a464451 | ||
|
|
ac4c5ab369 | ||
|
|
9e95419301 | ||
|
|
f390ec9608 | ||
|
|
ce8a83efba | ||
|
|
e5a2bf9564 | ||
|
|
7838018686 | ||
|
|
31916ed9fd | ||
|
|
3a2fbc2b19 | ||
|
|
43520b44da | ||
|
|
ab4a8d791a | ||
|
|
40dc546b81 | ||
|
|
5426891feb | ||
|
|
1c5ccd3406 | ||
|
|
3a745bfa3f | ||
|
|
ac4e39991e | ||
|
|
c870832da6 | ||
|
|
e782016c57 | ||
|
|
00badaf98e | ||
|
|
7dfac0163b | ||
|
|
09a3c2a82d | ||
|
|
c32c65014b | ||
|
|
f082eb10a2 | ||
|
|
b8898e449e | ||
|
|
d1f6d229ca | ||
|
|
4fa0318005 | ||
|
|
93ebb9d541 | ||
|
|
16101c79c5 | ||
|
|
c866b3f2c9 | ||
|
|
c26a45721f | ||
|
|
d9c900f872 | ||
|
|
73becbad29 | ||
|
|
f1df3de263 | ||
|
|
3bc5c8cda7 | ||
|
|
7b3b1058b2 | ||
|
|
87473f857f | ||
|
|
a96209185c | ||
|
|
34cc2ed1a1 | ||
|
|
667aa0c25a | ||
|
|
12707f4ff7 | ||
|
|
53451899a7 | ||
|
|
dc73b20c0b | ||
|
|
4330374ba4 | ||
|
|
79c8aa2c4a | ||
|
|
083d221dd2 | ||
|
|
74d47b725f | ||
|
|
917e482876 | ||
|
|
522d931950 | ||
|
|
d10c7ac7ce | ||
|
|
84705427c5 | ||
|
|
66a76af341 | ||
|
|
d402d91c2f | ||
|
|
b05130a089 | ||
|
|
b3cc0779f0 | ||
|
|
cbecae40a9 | ||
|
|
5b8753c8b6 | ||
|
|
3c5f9457f1 | ||
|
|
e32e56d0bc | ||
|
|
788aec665b | ||
|
|
3cada03a92 | ||
|
|
e21fb520f9 | ||
|
|
864f4d385f | ||
|
|
26ac2878ae | ||
|
|
cac63f5565 | ||
|
|
aadffd6199 | ||
|
|
3403197a90 | ||
|
|
8cdb9ab1ad | ||
|
|
5dbf26d283 | ||
|
|
8001bab9b0 | ||
|
|
12d0686adc | ||
|
|
a28a5e954a | ||
|
|
bb966a89d2 | ||
|
|
4a74eb3321 | ||
|
|
1f54ee6991 | ||
|
|
86143f79a1 | ||
|
|
b373bc82b5 | ||
|
|
ea2a05a04b | ||
|
|
5692ca586c | ||
|
|
a11ad81f02 | ||
|
|
805efdb144 | ||
|
|
c49b31e6ad | ||
|
|
7796a272ce | ||
|
|
678e87fd31 | ||
|
|
4d81a2ebfe | ||
|
|
2d82702e04 | ||
|
|
27dcf83f37 | ||
|
|
72db83528d | ||
|
|
45c7d36b2e | ||
|
|
65eeb0f1f6 | ||
|
|
1d7d0bb1ea | ||
|
|
598936bc53 | ||
|
|
b1bf6f7733 | ||
|
|
75d27aeb9f | ||
|
|
0a37caf4b4 | ||
|
|
6db65f4335 | ||
|
|
3648874301 | ||
|
|
8bcb5d7fd2 | ||
|
|
8c01a900cd | ||
|
|
d378e699d2 | ||
|
|
c25c375c41 | ||
|
|
70c3ff31fd | ||
|
|
cd2e29f285 | ||
|
|
6d4d7d763d | ||
|
|
6c1851eef8 | ||
|
|
096a15eef6 | ||
|
|
3d642df2b0 | ||
|
|
d75a02dc51 | ||
|
|
28643b453d | ||
|
|
d5635de5f6 | ||
|
|
88cca7bf68 | ||
|
|
a397b859fe | ||
|
|
8aae4e9856 | ||
|
|
92d8b37229 | ||
|
|
0801fc578b | ||
|
|
0d5cb84531 | ||
|
|
47b943a117 | ||
|
|
128355add5 | ||
|
|
0499fe41e4 | ||
|
|
6ad3437fd2 | ||
|
|
a5c73ec829 | ||
|
|
def04ac0ce | ||
|
|
5d63615b1b | ||
|
|
90ee284fe0 | ||
|
|
539e0b66fb | ||
|
|
fef393dcac | ||
|
|
ed607d5c4b | ||
|
|
37da7e44cd | ||
|
|
69c7edd60c | ||
|
|
392f210371 | ||
|
|
9a63df1ea1 | ||
|
|
f8a75cede9 | ||
|
|
4d1e370e02 | ||
|
|
d080a31a5c | ||
|
|
a90ebdfe7c | ||
|
|
c8995b82e5 | ||
|
|
6b7f924af6 | ||
|
|
51580e5349 | ||
|
|
ed49cebf2c | ||
|
|
387a36dd8a | ||
|
|
2e02ab740d | ||
|
|
b4eff2028f | ||
|
|
f411bf33fd |
146
CHANGELOG.md
146
CHANGELOG.md
@@ -1,20 +1,115 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to **pipecat** will be documented in this file.
|
||||
All notable changes to **Pipecat** will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
## [0.0.45] - 2024-10-16
|
||||
|
||||
### Changed
|
||||
|
||||
- Metrics messages have moved out from the transport's base output into RTVI.
|
||||
|
||||
## [0.0.44] - 2024-10-15
|
||||
|
||||
### Added
|
||||
|
||||
- Added Google TTS service and corresponding foundational example `07n-interruptible-google.py`
|
||||
- Added support for OpenAI Realtime API with the new
|
||||
`OpenAILLMServiceRealtimeBeta` processor.
|
||||
(see https://platform.openai.com/docs/guides/realtime/overview)
|
||||
|
||||
- Added `RTVIBotTranscriptionProcessor` which will send the RTVI
|
||||
`bot-transcription` protocol message. These are TTS text aggregated (into
|
||||
sentences) messages.
|
||||
|
||||
- Added new input params to the `MarkdownTextFilter` utility. You can set
|
||||
`filter_code` to filter code from text and `filter_tables` to filter tables
|
||||
from text.
|
||||
|
||||
- Added `CanonicalMetricsService`. This processor uses the new
|
||||
`AudioBufferProcessor` to capture conversation audio and later send it to
|
||||
Canonical AI.
|
||||
(see https://canonical.chat/)
|
||||
|
||||
- Added `AudioBufferProcessor`. This processor can be used to buffer mixed user and
|
||||
bot audio. This can later be saved into an audio file or processed by some
|
||||
audio analyzer.
|
||||
|
||||
- Added `on_first_participant_joined` event to `LiveKitTransport`.
|
||||
|
||||
### Changed
|
||||
|
||||
- LLM text responses are now logged properly as unicode characters.
|
||||
|
||||
- `UserStartedSpeakingFrame`, `UserStoppedSpeakingFrame`,
|
||||
`BotStartedSpeakingFrame`, `BotStoppedSpeakingFrame`, `BotSpeakingFrame` and
|
||||
`UserImageRequestFrame` are now based from `SystemFrame`
|
||||
|
||||
### Fixed
|
||||
|
||||
- Merge `RTVIBotLLMProcessor`/`RTVIBotLLMTextProcessor` and
|
||||
`RTVIBotTTSProcessor`/`RTVIBotTTSTextProcessor` to avoid out of order issues.
|
||||
|
||||
- Fixed an issue in RTVI protocol that could cause a `bot-llm-stopped` or
|
||||
`bot-tts-stopped` message to be sent before a `bot-llm-text` or `bot-tts-text`
|
||||
message.
|
||||
|
||||
- Fixed `DeepgramSTTService` constructor settings not being merged with default
|
||||
ones.
|
||||
|
||||
- Fixed an issue in Daily transport that would cause tasks to be hanging if
|
||||
urgent transport messages were being sent from a transport event handler.
|
||||
|
||||
- Fixed an issue in `BaseOutputTransport` that would cause `EndFrame` to be
|
||||
pushed downed too early and call `FrameProcessor.cleanup()` before letting the
|
||||
transport stop properly.
|
||||
|
||||
## [0.0.43] - 2024-10-10
|
||||
|
||||
### Added
|
||||
|
||||
- Added a new util called `MarkdownTextFilter` which is a subclass of a new
|
||||
base class called `BaseTextFilter`. This is a configurable utility which
|
||||
is intended to filter text received by TTS services.
|
||||
|
||||
- Added new `RTVIUserLLMTextProcessor`. This processor will send an RTVI
|
||||
`user-llm-text` message with the user content's that was sent to the LLM.
|
||||
|
||||
### Changed
|
||||
|
||||
- `TransportMessageFrame` doesn't have an `urgent` field anymore, instead
|
||||
there's now a `TransportMessageUrgentFrame` which is a `SystemFrame` and
|
||||
therefore skip all internal queuing.
|
||||
|
||||
- For TTS services, convert inputted languages to match each service's language
|
||||
format
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where changing a language with the Deepgram STT service
|
||||
wouldn't apply the change. This was fixed by disconnecting and reconnecting
|
||||
when the language changes.
|
||||
|
||||
## [0.0.42] - 2024-10-02
|
||||
|
||||
### Added
|
||||
|
||||
- `SentryMetrics` has been added to report frame processor metrics to
|
||||
Sentry. This is now possible because `FrameProcessorMetrics` can now be passed
|
||||
to `FrameProcessor`.
|
||||
|
||||
- Added Google TTS service and corresponding foundational example
|
||||
`07n-interruptible-google.py`
|
||||
|
||||
- Added AWS Polly TTS support and `07m-interruptible-aws.py` as an example.
|
||||
|
||||
- Added InputParams to Azure TTS service.
|
||||
|
||||
- Added `LivekitTransport` (audio-only for now).
|
||||
|
||||
- RTVI 0.2.0 is now supported.
|
||||
|
||||
- All `FrameProcessors` can now register event handlers.
|
||||
|
||||
```
|
||||
@@ -48,15 +143,10 @@ async def on_connected(processor):
|
||||
frames. To achieve that, each frame processor should only output frames from a
|
||||
single task.
|
||||
|
||||
In this version we introduce synchronous and asynchronous frame
|
||||
processors. The synchronous processors push output frames from the same task
|
||||
that they receive input frames, and therefore only pushing frames from one
|
||||
task. Asynchronous frame processors can have internal tasks to perform things
|
||||
asynchronously (e.g. receiving data from a websocket) but they also have a
|
||||
single task where they push frames from.
|
||||
|
||||
By default, frame processors are synchronous. To change a frame processor to
|
||||
asynchronous you only need to pass `sync=False` to the base class constructor.
|
||||
In this version all the frame processors have their own task to push
|
||||
frames. That is, when `push_frame()` is called the given frame will be put
|
||||
into an internal queue (with the exception of system frames) and a frame
|
||||
processor task will push it out.
|
||||
|
||||
- Added pipeline clocks. A pipeline clock is used by the output transport to
|
||||
know when a frame needs to be presented. For that, all frames now have an
|
||||
@@ -68,9 +158,7 @@ async def on_connected(processor):
|
||||
`SystemClock`). This clock will be passed to each frame processor via the
|
||||
`StartFrame`.
|
||||
|
||||
- Added `CartesiaHttpTTSService`. This is a synchronous frame processor
|
||||
(i.e. given an input text frame it will wait for the whole output before
|
||||
returning).
|
||||
- Added `CartesiaHttpTTSService`.
|
||||
|
||||
- `DailyTransport` now supports setting the audio bitrate to improve audio
|
||||
quality through the `DailyParams.audio_out_bitrate` parameter. The new
|
||||
@@ -93,8 +181,12 @@ async def on_connected(processor):
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated individual update settings frame classes into a single UpdateSettingsFrame
|
||||
class for STT, LLM, and TTS.
|
||||
- Context frames are now pushed downstream from assistant context aggregators.
|
||||
|
||||
- Removed Silero VAD torch dependency.
|
||||
|
||||
- Updated individual update settings frame classes into a single
|
||||
`ServiceUpdateSettingsFrame` class.
|
||||
|
||||
- We now distinguish between input and output audio and image frames. We
|
||||
introduce `InputAudioRawFrame`, `OutputAudioRawFrame`, `InputImageRawFrame`
|
||||
@@ -110,12 +202,13 @@ async def on_connected(processor):
|
||||
pipelines to be executed concurrently. The difference between a
|
||||
`SyncParallelPipeline` and a `ParallelPipeline` is that, given an input frame,
|
||||
the `SyncParallelPipeline` will wait for all the internal pipelines to
|
||||
complete. This is achieved by ensuring all the processors in each of the
|
||||
internal pipelines are synchronous.
|
||||
complete. This is achieved by making sure the last processor in each of the
|
||||
pipelines is synchronous (e.g. an HTTP-based service that waits for the
|
||||
response).
|
||||
|
||||
- `StartFrame` is back a system frame so we make sure it's processed immediately
|
||||
by all processors. `EndFrame` stays a control frame since it needs to be
|
||||
ordered allowing the frames in the pipeline to be processed.
|
||||
- `StartFrame` is back a system frame to make sure it's processed immediately by
|
||||
all processors. `EndFrame` stays a control frame since it needs to be ordered
|
||||
allowing the frames in the pipeline to be processed.
|
||||
|
||||
- Updated `MoondreamService` revision to `2024-08-26`.
|
||||
|
||||
@@ -139,6 +232,11 @@ async def on_connected(processor):
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI multiple function calls.
|
||||
|
||||
- Fixed a Cartesia TTS issue that would cause audio to be truncated in some
|
||||
cases.
|
||||
|
||||
- Fixed a `BaseOutputTransport` issue that would stop audio and video rendering
|
||||
tasks (after receiving and `EndFrame`) before the internal queue was emptied,
|
||||
causing the pipeline to finish prematurely.
|
||||
@@ -152,6 +250,10 @@ async def on_connected(processor):
|
||||
- `obj_id()` and `obj_count()` now use `itertools.count` avoiding the need of
|
||||
`threading.Lock`.
|
||||
|
||||
### Other
|
||||
|
||||
- Pipecat now uses Ruff as its formatter (https://github.com/astral-sh/ruff).
|
||||
|
||||
## [0.0.41] - 2024-08-22
|
||||
|
||||
### Added
|
||||
|
||||
@@ -128,8 +128,6 @@ Pipecat makes use of WebRTC VAD by default when using a WebRTC transport layer.
|
||||
pip install pipecat-ai[silero]
|
||||
```
|
||||
|
||||
The first time your run your bot with Silero, startup may take a while whilst it downloads and caches the model in the background. You can check the progress of this in the console.
|
||||
|
||||
## Hacking on the framework itself
|
||||
|
||||
_Note that you may need to set up a virtual environment before following the instructions below. For instance, you might need to run the following from the root of the repo:_
|
||||
|
||||
161
examples/canonical-metrics/.gitignore
vendored
Normal file
161
examples/canonical-metrics/.gitignore
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
recordings/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
runpod.toml
|
||||
16
examples/canonical-metrics/Dockerfile
Normal file
16
examples/canonical-metrics/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.10-bullseye
|
||||
|
||||
RUN mkdir /app
|
||||
RUN mkdir /app/assets
|
||||
RUN mkdir /app/utils
|
||||
COPY *.py /app/
|
||||
COPY requirements.txt /app/
|
||||
copy assets/* /app/assets/
|
||||
copy utils/* /app/utils/
|
||||
|
||||
WORKDIR /app
|
||||
RUN pip3 install -r requirements.txt
|
||||
|
||||
EXPOSE 7860
|
||||
|
||||
CMD ["python3", "server.py"]
|
||||
37
examples/canonical-metrics/README.md
Normal file
37
examples/canonical-metrics/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Simple Chatbot
|
||||
|
||||
<img src="image.png" width="420px">
|
||||
|
||||
This app connects you to a chatbot powered by GPT-4, complete with animations generated by Stable Video Diffusion.
|
||||
|
||||
See a video of it in action: https://x.com/kwindla/status/1778628911817183509
|
||||
|
||||
And a quick video walkthrough of the code: https://www.loom.com/share/13df1967161f4d24ade054e7f8753416
|
||||
|
||||
ℹ️ The first time, things might take extra time to get started since VAD (Voice Activity Detection) model needs to be downloaded.
|
||||
|
||||
## Get started
|
||||
|
||||
```python
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
cp env.example .env # and add your credentials
|
||||
|
||||
```
|
||||
|
||||
## Run the server
|
||||
|
||||
```bash
|
||||
python server.py
|
||||
```
|
||||
|
||||
Then, visit `http://localhost:7860/start` in your browser to start a chatbot session.
|
||||
|
||||
## Build and test the Docker image
|
||||
|
||||
```
|
||||
docker build -t chatbot .
|
||||
docker run --env-file .env -p 7860:7860 chatbot
|
||||
```
|
||||
149
examples/canonical-metrics/bot.py
Normal file
149
examples/canonical-metrics/bot.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import EndFrame, LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.services.canonical import CanonicalMetricsService
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Chatbot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
audio_in_enabled=True,
|
||||
camera_out_enabled=False,
|
||||
vad_enabled=True,
|
||||
vad_audio_passthrough=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
transcription_enabled=True,
|
||||
#
|
||||
# Spanish
|
||||
#
|
||||
# transcription_settings=DailyTranscriptionSettings(
|
||||
# language="es",
|
||||
# tier="nova",
|
||||
# model="2-general"
|
||||
# )
|
||||
),
|
||||
)
|
||||
|
||||
tts = ElevenLabsTTSService(
|
||||
api_key=os.getenv("ELEVENLABS_API_KEY"),
|
||||
#
|
||||
# English
|
||||
#
|
||||
voice_id="cgSgspJ2msm6clMCkdW9",
|
||||
aiohttp_session=session,
|
||||
#
|
||||
# Spanish
|
||||
#
|
||||
# model="eleven_multilingual_v2",
|
||||
# voice_id="gD1IexrzCvsXPHUuT0s3",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
#
|
||||
# English
|
||||
#
|
||||
"content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself. Keep all your responses to 12 words or fewer.",
|
||||
#
|
||||
# Spanish
|
||||
#
|
||||
# "content": "Eres Chatbot, un amigable y útil robot. Tu objetivo es demostrar tus capacidades de una manera breve. Tus respuestas se convertiran a audio así que nunca no debes incluir caracteres especiales. Contesta a lo que el usuario pregunte de una manera creativa, útil y breve. Empieza por presentarte a ti mismo.",
|
||||
},
|
||||
]
|
||||
|
||||
user_response = LLMUserResponseAggregator()
|
||||
assistant_response = LLMAssistantResponseAggregator()
|
||||
|
||||
"""
|
||||
CanonicalMetrics uses AudioBufferProcessor under the hood to buffer the audio. On
|
||||
call completion, CanonicalMetrics will send the audio buffer to Canonical for
|
||||
analysis. Visit https://voice.canonical.chat to learn more.
|
||||
"""
|
||||
audio_buffer_processor = AudioBufferProcessor()
|
||||
canonical = CanonicalMetricsService(
|
||||
audio_buffer_processor=audio_buffer_processor,
|
||||
aiohttp_session=session,
|
||||
api_key=os.getenv("CANONICAL_API_KEY"),
|
||||
api_url=os.getenv("CANONICAL_API_URL"),
|
||||
call_id=str(uuid.uuid4()),
|
||||
assistant="pipecat-chatbot",
|
||||
assistant_speaks_first=True,
|
||||
)
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # microphone
|
||||
user_response,
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
audio_buffer_processor, # captures audio into a buffer
|
||||
canonical, # uploads audio buffer to Canonical AI for metrics
|
||||
assistant_response,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
print(f"Participant left: {participant}")
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_call_state_updated")
|
||||
async def on_call_state_updated(transport, state):
|
||||
if state == "left":
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
5
examples/canonical-metrics/env.example
Normal file
5
examples/canonical-metrics/env.example
Normal file
@@ -0,0 +1,5 @@
|
||||
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev)
|
||||
DAILY_API_KEY=7df...
|
||||
OPENAI_API_KEY=sk-PL...
|
||||
ELEVENLABS_API_KEY=aeb...
|
||||
CANONICAL_API_KEY=can...
|
||||
5
examples/canonical-metrics/requirements.txt
Normal file
5
examples/canonical-metrics/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
python-dotenv
|
||||
fastapi[all]
|
||||
uvicorn
|
||||
pipecat-ai[daily,openai,silero,elevenlabs,canonical]
|
||||
|
||||
56
examples/canonical-metrics/runner.py
Normal file
56
examples/canonical-metrics/runner.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper
|
||||
|
||||
|
||||
async def configure(aiohttp_session: aiohttp.ClientSession):
|
||||
parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")
|
||||
parser.add_argument(
|
||||
"-u", "--url", type=str, required=False, help="URL of the Daily room to join"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
"--apikey",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Daily API Key (needed to create an owner token for the room)",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
url = args.url or os.getenv("DAILY_SAMPLE_ROOM_URL")
|
||||
key = args.apikey or os.getenv("DAILY_API_KEY")
|
||||
|
||||
if not url:
|
||||
raise Exception(
|
||||
"No Daily room specified. use the -u/--url option from the command line, or set DAILY_SAMPLE_ROOM_URL in your environment to specify a Daily room URL."
|
||||
)
|
||||
|
||||
if not key:
|
||||
raise Exception(
|
||||
"No Daily API key specified. use the -k/--apikey option from the command line, or set DAILY_API_KEY in your environment to specify a Daily API key, available from https://dashboard.daily.co/developers."
|
||||
)
|
||||
|
||||
daily_rest_helper = DailyRESTHelper(
|
||||
daily_api_key=key,
|
||||
daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
|
||||
aiohttp_session=aiohttp_session,
|
||||
)
|
||||
|
||||
# Create a meeting token for the given room with an expiration 1 hour in
|
||||
# the future.
|
||||
expiry_time: float = 60 * 60
|
||||
|
||||
token = await daily_rest_helper.get_token(url, expiry_time)
|
||||
|
||||
return (url, token)
|
||||
return (url, token)
|
||||
139
examples/canonical-metrics/server.py
Normal file
139
examples/canonical-metrics/server.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
|
||||
|
||||
MAX_BOTS_PER_ROOM = 1
|
||||
|
||||
# Bot sub-process dict for status reporting and concurrency control
|
||||
bot_procs = {}
|
||||
|
||||
daily_helpers = {}
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
def cleanup():
|
||||
# Clean up function, just to be extra safe
|
||||
for entry in bot_procs.values():
|
||||
proc = entry[0]
|
||||
proc.terminate()
|
||||
proc.wait()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
aiohttp_session = aiohttp.ClientSession()
|
||||
daily_helpers["rest"] = DailyRESTHelper(
|
||||
daily_api_key=os.getenv("DAILY_API_KEY", ""),
|
||||
daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
|
||||
aiohttp_session=aiohttp_session,
|
||||
)
|
||||
yield
|
||||
await aiohttp_session.close()
|
||||
cleanup()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/start")
|
||||
async def start_agent(request: Request):
|
||||
print(f"!!! Creating room")
|
||||
room = await daily_helpers["rest"].create_room(DailyRoomParams())
|
||||
print(f"!!! Room URL: {room.url}")
|
||||
# Ensure the room property is present
|
||||
if not room.url:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Missing 'room' property in request data. Cannot start agent without a target room!",
|
||||
)
|
||||
|
||||
# Check if there is already an existing process running in this room
|
||||
num_bots_in_room = sum(
|
||||
1 for proc in bot_procs.values() if proc[1] == room.url and proc[0].poll() is None
|
||||
)
|
||||
if num_bots_in_room >= MAX_BOTS_PER_ROOM:
|
||||
raise HTTPException(status_code=500, detail=f"Max bot limited reach for room: {room.url}")
|
||||
|
||||
# Get the token for the room
|
||||
token = await daily_helpers["rest"].get_token(room.url)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get token for room: {room.url}")
|
||||
|
||||
# Spawn a new agent, and join the user session
|
||||
# Note: this is mostly for demonstration purposes (refer to 'deployment' in README)
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
[f"python3 -m bot -u {room.url} -t {token}"],
|
||||
shell=True,
|
||||
bufsize=1,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
)
|
||||
bot_procs[proc.pid] = (proc, room.url)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start subprocess: {e}")
|
||||
|
||||
return RedirectResponse(room.url)
|
||||
|
||||
|
||||
@app.get("/status/{pid}")
|
||||
def get_status(pid: int):
|
||||
# Look up the subprocess
|
||||
proc = bot_procs.get(pid)
|
||||
|
||||
# If the subprocess doesn't exist, return an error
|
||||
if not proc:
|
||||
raise HTTPException(status_code=404, detail=f"Bot with process id: {pid} not found")
|
||||
|
||||
# Check the status of the subprocess
|
||||
if proc[0].poll() is None:
|
||||
status = "running"
|
||||
else:
|
||||
status = "finished"
|
||||
|
||||
return JSONResponse({"bot_id": pid, "status": status})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
default_host = os.getenv("HOST", "0.0.0.0")
|
||||
default_port = int(os.getenv("FAST_API_PORT", "7860"))
|
||||
|
||||
parser = argparse.ArgumentParser(description="Daily Storyteller FastAPI server")
|
||||
parser.add_argument("--host", type=str, default=default_host, help="Host address")
|
||||
parser.add_argument("--port", type=int, default=default_port, help="Port number")
|
||||
parser.add_argument("--reload", action="store_true", help="Reload code on change")
|
||||
|
||||
config = parser.parse_args()
|
||||
|
||||
uvicorn.run(
|
||||
"server:app",
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
reload=config.reload,
|
||||
)
|
||||
161
examples/chatbot-audio-recording/.gitignore
vendored
Normal file
161
examples/chatbot-audio-recording/.gitignore
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
runpod.toml
|
||||
15
examples/chatbot-audio-recording/Dockerfile
Normal file
15
examples/chatbot-audio-recording/Dockerfile
Normal file
@@ -0,0 +1,15 @@
|
||||
FROM python:3.10-bullseye
|
||||
|
||||
RUN mkdir /app
|
||||
RUN mkdir /app/assets
|
||||
RUN mkdir /app/utils
|
||||
COPY *.py /app/
|
||||
COPY requirements.txt /app/
|
||||
|
||||
|
||||
WORKDIR /app
|
||||
RUN pip3 install -r requirements.txt
|
||||
|
||||
EXPOSE 7860
|
||||
|
||||
CMD ["python3", "server.py"]
|
||||
37
examples/chatbot-audio-recording/README.md
Normal file
37
examples/chatbot-audio-recording/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Simple Chatbot
|
||||
|
||||
<img src="image.png" width="420px">
|
||||
|
||||
This app connects you to a chatbot powered by GPT-4, complete with animations generated by Stable Video Diffusion.
|
||||
|
||||
See a video of it in action: https://x.com/kwindla/status/1778628911817183509
|
||||
|
||||
And a quick video walkthrough of the code: https://www.loom.com/share/13df1967161f4d24ade054e7f8753416
|
||||
|
||||
ℹ️ The first time, things might take extra time to get started since VAD (Voice Activity Detection) model needs to be downloaded.
|
||||
|
||||
## Get started
|
||||
|
||||
```python
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
cp env.example .env # and add your credentials
|
||||
|
||||
```
|
||||
|
||||
## Run the server
|
||||
|
||||
```bash
|
||||
python server.py
|
||||
```
|
||||
|
||||
Then, visit `http://localhost:7860/start` in your browser to start a chatbot session.
|
||||
|
||||
## Build and test the Docker image
|
||||
|
||||
```
|
||||
docker build -t chatbot .
|
||||
docker run --env-file .env -p 7860:7860 chatbot
|
||||
```
|
||||
132
examples/chatbot-audio-recording/bot.py
Normal file
132
examples/chatbot-audio-recording/bot.py
Normal file
@@ -0,0 +1,132 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import EndFrame, LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Chatbot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
audio_in_enabled=True,
|
||||
camera_out_enabled=False,
|
||||
vad_enabled=True,
|
||||
vad_audio_passthrough=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
transcription_enabled=True,
|
||||
#
|
||||
# Spanish
|
||||
#
|
||||
# transcription_settings=DailyTranscriptionSettings(
|
||||
# language="es",
|
||||
# tier="nova",
|
||||
# model="2-general"
|
||||
# )
|
||||
),
|
||||
)
|
||||
|
||||
tts = ElevenLabsTTSService(
|
||||
api_key=os.getenv("ELEVENLABS_API_KEY"),
|
||||
#
|
||||
# English
|
||||
#
|
||||
voice_id="cgSgspJ2msm6clMCkdW9",
|
||||
aiohttp_session=session,
|
||||
#
|
||||
# Spanish
|
||||
#
|
||||
# model="eleven_multilingual_v2",
|
||||
# voice_id="gD1IexrzCvsXPHUuT0s3",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
#
|
||||
# English
|
||||
#
|
||||
"content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself. Keep all your response to 12 words or fewer.",
|
||||
#
|
||||
# Spanish
|
||||
#
|
||||
# "content": "Eres Chatbot, un amigable y útil robot. Tu objetivo es demostrar tus capacidades de una manera breve. Tus respuestas se convertiran a audio así que nunca no debes incluir caracteres especiales. Contesta a lo que el usuario pregunte de una manera creativa, útil y breve. Empieza por presentarte a ti mismo.",
|
||||
},
|
||||
]
|
||||
|
||||
user_response = LLMUserResponseAggregator()
|
||||
assistant_response = LLMAssistantResponseAggregator()
|
||||
|
||||
audiobuffer = AudioBufferProcessor()
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # microphone
|
||||
user_response,
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
audiobuffer, # used to buffer the audio in the pipeline
|
||||
assistant_response,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
print(f"Participant left: {participant}")
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_call_state_updated")
|
||||
async def on_call_state_updated(transport, state):
|
||||
if state == "left":
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
4
examples/chatbot-audio-recording/env.example
Normal file
4
examples/chatbot-audio-recording/env.example
Normal file
@@ -0,0 +1,4 @@
|
||||
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev)
|
||||
DAILY_API_KEY=7df...
|
||||
OPENAI_API_KEY=sk-PL...
|
||||
ELEVENLABS_API_KEY=aeb...
|
||||
4
examples/chatbot-audio-recording/requirements.txt
Normal file
4
examples/chatbot-audio-recording/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
python-dotenv
|
||||
fastapi[all]
|
||||
uvicorn
|
||||
pipecat-ai[daily,openai,silero,elevenlabs]
|
||||
56
examples/chatbot-audio-recording/runner.py
Normal file
56
examples/chatbot-audio-recording/runner.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper
|
||||
|
||||
|
||||
async def configure(aiohttp_session: aiohttp.ClientSession):
|
||||
parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")
|
||||
parser.add_argument(
|
||||
"-u", "--url", type=str, required=False, help="URL of the Daily room to join"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
"--apikey",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Daily API Key (needed to create an owner token for the room)",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
url = args.url or os.getenv("DAILY_SAMPLE_ROOM_URL")
|
||||
key = args.apikey or os.getenv("DAILY_API_KEY")
|
||||
|
||||
if not url:
|
||||
raise Exception(
|
||||
"No Daily room specified. use the -u/--url option from the command line, or set DAILY_SAMPLE_ROOM_URL in your environment to specify a Daily room URL."
|
||||
)
|
||||
|
||||
if not key:
|
||||
raise Exception(
|
||||
"No Daily API key specified. use the -k/--apikey option from the command line, or set DAILY_API_KEY in your environment to specify a Daily API key, available from https://dashboard.daily.co/developers."
|
||||
)
|
||||
|
||||
daily_rest_helper = DailyRESTHelper(
|
||||
daily_api_key=key,
|
||||
daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
|
||||
aiohttp_session=aiohttp_session,
|
||||
)
|
||||
|
||||
# Create a meeting token for the given room with an expiration 1 hour in
|
||||
# the future.
|
||||
expiry_time: float = 60 * 60
|
||||
|
||||
token = await daily_rest_helper.get_token(url, expiry_time)
|
||||
|
||||
return (url, token)
|
||||
return (url, token)
|
||||
139
examples/chatbot-audio-recording/server.py
Normal file
139
examples/chatbot-audio-recording/server.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomParams
|
||||
|
||||
MAX_BOTS_PER_ROOM = 1
|
||||
|
||||
# Bot sub-process dict for status reporting and concurrency control
|
||||
bot_procs = {}
|
||||
|
||||
daily_helpers = {}
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
def cleanup():
|
||||
# Clean up function, just to be extra safe
|
||||
for entry in bot_procs.values():
|
||||
proc = entry[0]
|
||||
proc.terminate()
|
||||
proc.wait()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
aiohttp_session = aiohttp.ClientSession()
|
||||
daily_helpers["rest"] = DailyRESTHelper(
|
||||
daily_api_key=os.getenv("DAILY_API_KEY", ""),
|
||||
daily_api_url=os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
|
||||
aiohttp_session=aiohttp_session,
|
||||
)
|
||||
yield
|
||||
await aiohttp_session.close()
|
||||
cleanup()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/start")
|
||||
async def start_agent(request: Request):
|
||||
print(f"!!! Creating room")
|
||||
room = await daily_helpers["rest"].create_room(DailyRoomParams())
|
||||
print(f"!!! Room URL: {room.url}")
|
||||
# Ensure the room property is present
|
||||
if not room.url:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Missing 'room' property in request data. Cannot start agent without a target room!",
|
||||
)
|
||||
|
||||
# Check if there is already an existing process running in this room
|
||||
num_bots_in_room = sum(
|
||||
1 for proc in bot_procs.values() if proc[1] == room.url and proc[0].poll() is None
|
||||
)
|
||||
if num_bots_in_room >= MAX_BOTS_PER_ROOM:
|
||||
raise HTTPException(status_code=500, detail=f"Max bot limited reach for room: {room.url}")
|
||||
|
||||
# Get the token for the room
|
||||
token = await daily_helpers["rest"].get_token(room.url)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get token for room: {room.url}")
|
||||
|
||||
# Spawn a new agent, and join the user session
|
||||
# Note: this is mostly for demonstration purposes (refer to 'deployment' in README)
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
[f"python3 -m bot -u {room.url} -t {token}"],
|
||||
shell=True,
|
||||
bufsize=1,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
)
|
||||
bot_procs[proc.pid] = (proc, room.url)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start subprocess: {e}")
|
||||
|
||||
return RedirectResponse(room.url)
|
||||
|
||||
|
||||
@app.get("/status/{pid}")
|
||||
def get_status(pid: int):
|
||||
# Look up the subprocess
|
||||
proc = bot_procs.get(pid)
|
||||
|
||||
# If the subprocess doesn't exist, return an error
|
||||
if not proc:
|
||||
raise HTTPException(status_code=404, detail=f"Bot with process id: {pid} not found")
|
||||
|
||||
# Check the status of the subprocess
|
||||
if proc[0].poll() is None:
|
||||
status = "running"
|
||||
else:
|
||||
status = "finished"
|
||||
|
||||
return JSONResponse({"bot_id": pid, "status": status})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
default_host = os.getenv("HOST", "0.0.0.0")
|
||||
default_port = int(os.getenv("FAST_API_PORT", "7860"))
|
||||
|
||||
parser = argparse.ArgumentParser(description="Daily Storyteller FastAPI server")
|
||||
parser.add_argument("--host", type=str, default=default_host, help="Host address")
|
||||
parser.add_argument("--port", type=int, default=default_port, help="Port number")
|
||||
parser.add_argument("--reload", action="store_true", help="Reload code on change")
|
||||
|
||||
config = parser.parse_args()
|
||||
|
||||
uvicorn.run(
|
||||
"server:app",
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
reload=config.reload,
|
||||
)
|
||||
@@ -86,13 +86,13 @@ async def main():
|
||||
),
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
tts = CartesiaHttpTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
imagegen = FalImageGenService(
|
||||
params=FalImageGenService.InputParams(image_size="square_hd"),
|
||||
aiohttp_session=session,
|
||||
@@ -107,8 +107,10 @@ async def main():
|
||||
# that, each pipeline runs concurrently and `SyncParallelPipeline` will
|
||||
# wait for the input frame to be processed.
|
||||
#
|
||||
# Note that `SyncParallelPipeline` requires all processors in it to be
|
||||
# synchronous (which is the default for most processors).
|
||||
# Note that `SyncParallelPipeline` requires the last processor in each
|
||||
# of the pipelines to be synchronous. In this case, we use
|
||||
# `CartesiaHttpTTSService` and `FalImageGenService` which make HTTP
|
||||
# requests and wait for the response.
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm, # LLM
|
||||
|
||||
@@ -82,6 +82,7 @@ async def main():
|
||||
self.frame = OutputAudioRawFrame(
|
||||
bytes(self.audio), frame.sample_rate, frame.num_channels
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
class ImageGrabber(FrameProcessor):
|
||||
def __init__(self):
|
||||
@@ -93,6 +94,7 @@ async def main():
|
||||
|
||||
if isinstance(frame, URLImageRawFrame):
|
||||
self.frame = frame
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
@@ -121,8 +123,10 @@ async def main():
|
||||
# `SyncParallelPipeline` will wait for the input frame to be
|
||||
# processed.
|
||||
#
|
||||
# Note that `SyncParallelPipeline` requires all processors in it to
|
||||
# be synchronous (which is the default for most processors).
|
||||
# Note that `SyncParallelPipeline` requires the last processor in
|
||||
# each of the pipelines to be synchronous. In this case, we use
|
||||
# `CartesiaHttpTTSService` and `FalImageGenService` which make HTTP
|
||||
# requests and wait for the response.
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm, # LLM
|
||||
|
||||
@@ -5,29 +5,24 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.anthropic import AnthropicLLMService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -69,17 +64,17 @@ async def main():
|
||||
},
|
||||
]
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
tma_in, # User responses
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
tma_out, # Assistant spoken responses
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -4,11 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -17,17 +21,11 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.playht import PlayHTTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.playht import PlayHTTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
@@ -4,11 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -17,17 +21,10 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.openai import OpenAITTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai import OpenAILLMService, OpenAITTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
@@ -5,29 +5,24 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.ai_services import OpenAILLMContext
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -72,25 +67,32 @@ async def main():
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond in plain language. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
user_aggregator = context_aggregator.user()
|
||||
assistant_aggregator = context_aggregator.assistant()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
tma_in, # User responses
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
tma_out, # Assistant spoken responses
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True, enable_metrics=True, enable_usage_metrics=True
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
|
||||
@@ -53,7 +53,6 @@ async def main():
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = GoogleTTSService(
|
||||
credentials=os.getenv("GOOGLE_CREDENTIALS"),
|
||||
voice_id="en-US-Neural2-J",
|
||||
params=GoogleTTSService.InputParams(language="en-US", rate="1.05"),
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.deepgram import DeepgramSTTService
|
||||
from pipecat.services.deepgram import DeepgramSTTService, LiveOptions, Language
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
from runner import configure
|
||||
@@ -45,7 +45,10 @@ async def main():
|
||||
room_url, None, "Transcription bot", DailyParams(audio_in_enabled=True)
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = DeepgramSTTService(
|
||||
api_key=os.getenv("DEEPGRAM_API_KEY"),
|
||||
# live_options=LiveOptions(language=Language.FR),
|
||||
)
|
||||
|
||||
tl = TranscriptionLogger()
|
||||
|
||||
|
||||
@@ -9,11 +9,9 @@ import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.logger import FrameLogger
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
@@ -34,7 +32,12 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def start_fetch_weather(function_name, llm, context):
|
||||
await llm.push_frame(TextFrame("Let me check on that."))
|
||||
# note: we can't push a frame to the LLM here. the bot
|
||||
# can interrupt itself and/or cause audio overlapping glitches.
|
||||
# possible question for Aleix and Chad about what the right way
|
||||
# to trigger speech is, now, with the new queues/async/sync refactors.
|
||||
# await llm.push_frame(TextFrame("Let me check on that."))
|
||||
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
@@ -67,9 +70,6 @@ async def main():
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
|
||||
|
||||
fl_in = FrameLogger("Inner")
|
||||
fl_out = FrameLogger("Outer")
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
@@ -106,24 +106,30 @@ async def main():
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
fl_in,
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
fl_out,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await tts.say("Hi! Ask me about the weather in San Francisco.")
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
|
||||
136
examples/foundational/14c-function-calling-together.py
Normal file
136
examples/foundational/14c-function-calling-together.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def start_fetch_weather(function_name, llm, context):
|
||||
# note: we can't push a frame to the LLM here. the bot
|
||||
# can interrupt itself and/or cause audio overlapping glitches.
|
||||
# possible question for Aleix and Chad about what the right way
|
||||
# to trigger speech is, now, with the new queues/async/sync refactors.
|
||||
# await llm.push_frame(TextFrame("Let me check on that."))
|
||||
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
await result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = TogetherLLMService(
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
)
|
||||
# Register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
# await tts.say("Hi! Ask me about the weather in San Francisco.")
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
167
examples/foundational/14d-function-calling-video.py
Normal file
167
examples/foundational/14d-function-calling-video.py
Normal file
@@ -0,0 +1,167 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
video_participant_id = None
|
||||
|
||||
|
||||
async def get_weather(function_name, tool_call_id, arguments, llm, context, result_callback):
|
||||
location = arguments["location"]
|
||||
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback):
|
||||
logger.debug(f"!!! IN get_image {video_participant_id}, {arguments}")
|
||||
question = arguments["question"]
|
||||
await llm.request_image_frame(user_id=video_participant_id, text_content=question)
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
),
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_image",
|
||||
"description": "Get an image from the video stream.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to ask the AI to generate an image of",
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
system_prompt = """\
|
||||
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
|
||||
|
||||
Your response will be turned into speech so use only simple words and punctuation.
|
||||
|
||||
You have access to two tools: get_weather and get_image.
|
||||
|
||||
You can respond to questions about the weather using the get_weather tool.
|
||||
|
||||
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
|
||||
indicate you should use the get_image tool are:
|
||||
- What do you see?
|
||||
- What's in the video?
|
||||
- Can you describe the video?
|
||||
- Tell me about what you see.
|
||||
- Tell me something interesting about what you see.
|
||||
- What's happening in the video?
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
global video_participant_id
|
||||
video_participant_id = participant["id"]
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
transport.capture_participant_video(video_participant_id, framerate=0)
|
||||
# Kick off the conversation.
|
||||
await tts.say("Hi! Ask me about the weather in San Francisco.")
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -5,10 +5,14 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -26,12 +30,6 @@ from pipecat.transports.services.daily import (
|
||||
)
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
164
examples/foundational/19-openai-realtime-beta.py
Normal file
164
examples/foundational/19-openai-realtime-beta.py
Normal file
@@ -0,0 +1,164 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
)
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioTranscription,
|
||||
OpenAILLMServiceRealtimeBeta,
|
||||
SessionProperties,
|
||||
TurnDetection,
|
||||
)
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.vad.vad_analyzer import VADParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
temperature = 75 if args["format"] == "fahrenheit" else 24
|
||||
await result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": args["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_in_sample_rate=24000,
|
||||
audio_out_enabled=True,
|
||||
audio_out_sample_rate=24000,
|
||||
transcription_enabled=False,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
|
||||
vad_audio_passthrough=True,
|
||||
),
|
||||
)
|
||||
|
||||
session_properties = SessionProperties(
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
# tools=tools,
|
||||
instructions="""Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
-
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually.""",
|
||||
)
|
||||
|
||||
llm = OpenAILLMServiceRealtimeBeta(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
start_audio_paused=False,
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
context = OpenAILLMContext([{"role": "user", "content": "Say hello!"}], tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
context_aggregator.assistant(),
|
||||
transport.output(), # Transport bot output
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
# report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,137 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def get_current_weather(
|
||||
function_name, tool_call_id, arguments, llm, context, result_callback
|
||||
):
|
||||
logger.debug("IN get_current_weather")
|
||||
location = arguments["location"]
|
||||
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = TogetherLLMService(
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
model=os.getenv("TOGETHER_MODEL"),
|
||||
)
|
||||
llm.register_function("get_current_weather", get_current_weather)
|
||||
|
||||
weatherTool = {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
|
||||
system_prompt = f"""\
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
|
||||
{json.dumps(weatherTool)}
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
|
||||
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "Wait for the user to say something."},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(), # User speech to text
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
236
examples/foundational/20a-persistent-context-openai.py
Normal file
236
examples/foundational/20a-persistent-context-openai.py
Normal file
@@ -0,0 +1,236 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
)
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.vad.vad_analyzer import VADParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
BASE_FILENAME = "/tmp/pipecat_conversation_"
|
||||
tts = None
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
temperature = 75 if args["format"] == "fahrenheit" else 24
|
||||
await result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": args["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_saved_conversation_filenames(
|
||||
function_name, tool_call_id, args, llm, context, result_callback
|
||||
):
|
||||
# Construct the full pattern including the BASE_FILENAME
|
||||
full_pattern = f"{BASE_FILENAME}*.json"
|
||||
|
||||
# Use glob to find all matching files
|
||||
matching_files = glob.glob(full_pattern)
|
||||
logger.debug(f"matching files: {matching_files}")
|
||||
|
||||
await result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(f"writing conversation to {filename}\n{json.dumps(context.messages, indent=4)}")
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = context.get_messages_for_persistent_storage()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
await result_callback({"success": True})
|
||||
except Exception as e:
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
async def load_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
global tts
|
||||
filename = args["filename"]
|
||||
logger.debug(f"loading conversation from {filename}")
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
context.set_messages(json.load(file))
|
||||
logger.debug(
|
||||
f"loaded conversation from {filename}\n{json.dumps(context.messages, indent=4)}"
|
||||
)
|
||||
await tts.say("Ok, I've loaded that conversation.")
|
||||
except Exception as e:
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "save_conversation",
|
||||
"description": "Save the current conversatione. Use this function to persist the current conversation to external storage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_saved_conversation_filenames",
|
||||
"description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "load_conversation",
|
||||
"description": "Load a conversation history. Use this function to load a conversation history into the current session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The filename of the conversation history to load.",
|
||||
}
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
global tts
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("save_conversation", save_conversation)
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
tts,
|
||||
context_aggregator.assistant(),
|
||||
transport.output(), # Transport bot output
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
# report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
262
examples/foundational/20b-persistent-context-openai-realtime.py
Normal file
262
examples/foundational/20b-persistent-context-openai-realtime.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
)
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioTranscription,
|
||||
OpenAILLMServiceRealtimeBeta,
|
||||
SessionProperties,
|
||||
TurnDetection,
|
||||
)
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.vad.vad_analyzer import VADParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
BASE_FILENAME = "/tmp/pipecat_conversation_"
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
temperature = 75 if args["format"] == "fahrenheit" else 24
|
||||
await result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": args["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_saved_conversation_filenames(
|
||||
function_name, tool_call_id, args, llm, context, result_callback
|
||||
):
|
||||
# Construct the full pattern including the BASE_FILENAME
|
||||
full_pattern = f"{BASE_FILENAME}*.json"
|
||||
|
||||
# Use glob to find all matching files
|
||||
matching_files = glob.glob(full_pattern)
|
||||
logger.debug(f"matching files: {matching_files}")
|
||||
|
||||
await result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
# async def get_saved_conversation_filenames(
|
||||
# function_name, tool_call_id, args, llm, context, result_callback
|
||||
# ):
|
||||
# pattern = re.compile(re.escape(BASE_FILENAME) + "\\d{8}_\\d{6}\\.json$")
|
||||
# matching_files = []
|
||||
|
||||
# for filename in os.listdir("."):
|
||||
# if pattern.match(filename):
|
||||
# matching_files.append(filename)
|
||||
|
||||
# await result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(f"writing conversation to {filename}\n{json.dumps(context.messages, indent=4)}")
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = context.get_messages_for_persistent_storage()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
await result_callback({"success": True})
|
||||
except Exception as e:
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
async def load_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
async def _reset():
|
||||
filename = args["filename"]
|
||||
logger.debug(f"loading conversation from {filename}")
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
context.set_messages(json.load(file))
|
||||
await llm.reset_conversation()
|
||||
await llm._create_response()
|
||||
except Exception as e:
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
asyncio.create_task(_reset())
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "save_conversation",
|
||||
"description": "Save the current conversatione. Use this function to persist the current conversation to external storage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_saved_conversation_filenames",
|
||||
"description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "load_conversation",
|
||||
"description": "Load a conversation history. Use this function to load a conversation history into the current session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The filename of the conversation history to load.",
|
||||
}
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_in_sample_rate=24000,
|
||||
audio_out_enabled=True,
|
||||
audio_out_sample_rate=24000,
|
||||
transcription_enabled=False,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
|
||||
vad_audio_passthrough=True,
|
||||
),
|
||||
)
|
||||
|
||||
session_properties = SessionProperties(
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
# tools=tools,
|
||||
instructions="""Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
-
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually.""",
|
||||
)
|
||||
|
||||
llm = OpenAILLMServiceRealtimeBeta(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
start_audio_paused=False,
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("save_conversation", save_conversation)
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext([], tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
context_aggregator.assistant(),
|
||||
transport.output(), # Transport bot output
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
# report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
232
examples/foundational/20c-persistent-context-anthropic.py
Normal file
232
examples/foundational/20c-persistent-context-anthropic.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
)
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.anthropic import AnthropicLLMService
|
||||
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.vad.vad_analyzer import VADParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
BASE_FILENAME = "/tmp/pipecat_conversation_"
|
||||
tts = None
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
temperature = 75 if args["format"] == "fahrenheit" else 24
|
||||
await result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": args["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_saved_conversation_filenames(
|
||||
function_name, tool_call_id, args, llm, context, result_callback
|
||||
):
|
||||
# Construct the full pattern including the BASE_FILENAME
|
||||
full_pattern = f"{BASE_FILENAME}*.json"
|
||||
|
||||
# Use glob to find all matching files
|
||||
matching_files = glob.glob(full_pattern)
|
||||
logger.debug(f"matching files: {matching_files}")
|
||||
|
||||
await result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
async def save_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(f"writing conversation to {filename}\n{json.dumps(context.messages, indent=4)}")
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
# todo: extract 'system' into the first message in the list
|
||||
messages = context.get_messages_for_persistent_storage()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
await result_callback({"success": True})
|
||||
except Exception as e:
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
async def load_conversation(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
global tts
|
||||
filename = args["filename"]
|
||||
logger.debug(f"loading conversation from {filename}")
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
context.set_messages(json.load(file))
|
||||
logger.debug(
|
||||
f"loaded conversation from {filename}\n{json.dumps(context.messages, indent=4)}"
|
||||
)
|
||||
await tts.say("Ok, I've loaded that conversation.")
|
||||
except Exception as e:
|
||||
await result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
# Test message munging ...
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
{"role": "user", "content": ""},
|
||||
{"role": "assistant", "content": []},
|
||||
{"role": "user", "content": "Tell me"},
|
||||
{"role": "user", "content": "a joke"},
|
||||
]
|
||||
tools = [
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "save_conversation",
|
||||
"description": "Save the current conversation. Use this function to persist the current conversation to external storage.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get_saved_conversation_filenames",
|
||||
"description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "load_conversation",
|
||||
"description": "Load a conversation history. Use this function to load a conversation history into the current session.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The filename of the conversation history to load.",
|
||||
}
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
global tts
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.8)),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-5-sonnet-20240620"
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("save_conversation", save_conversation)
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
tts,
|
||||
context_aggregator.assistant(),
|
||||
transport.output(), # Transport bot output
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
# report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,4 +1,4 @@
|
||||
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev)
|
||||
DAILY_API_KEY=7df...
|
||||
OPENAI_API_KEY=sk-PL...
|
||||
ELEVENLABS_API_KEY=aeb...
|
||||
CARTESIA_API_KEY=your_cartesia_api_key_here
|
||||
|
||||
@@ -1,12 +1,39 @@
|
||||
# Simple Chatbot
|
||||
# Patient-intake chatbot
|
||||
|
||||
<img src="image.png" width="420px">
|
||||
|
||||
This app connects you to a chatbot powered by GPT-4, complete with animations generated by Stable Video Diffusion.
|
||||
This project implements an AI-powered chatbot designed to streamline the medical intake process for Tri-County Health Services. The chatbot, named Jessica, interacts with patients to collect essential information before their doctor's visit, enhancing efficiency and improving the patient experience.
|
||||
|
||||
See a video of it in action: https://x.com/kwindla/status/1778628911817183509
|
||||
## Features
|
||||
|
||||
And a quick video walkthrough of the code: https://www.loom.com/share/13df1967161f4d24ade054e7f8753416
|
||||
Identity Verification: Confirms patient identity by verifying their date of birth.
|
||||
Prescription Information: Collects details about current medications and dosages.
|
||||
Allergy Documentation: Records patient allergies.
|
||||
Medical Conditions: Gathers information about existing medical conditions.
|
||||
Reason for Visit: Asks patients about the purpose of their current doctor's visit.
|
||||
|
||||
## Technical Stack
|
||||
|
||||
Language: Python
|
||||
AI Model: OpenAI's GPT-4
|
||||
Text-to-Speech: Cartesia TTS Service
|
||||
Audio Processing: Silero VAD (Voice Activity Detection)
|
||||
Real-time Communication: Daily.co API
|
||||
|
||||
## Key Components
|
||||
|
||||
IntakeProcessor: Manages the conversation flow and information gathering process.
|
||||
DailyTransport: Handles real-time audio communication.
|
||||
CartesiaTTSService: Converts text responses to speech.
|
||||
OpenAILLMService: Processes natural language and generates appropriate responses.
|
||||
Pipeline: Orchestrates the flow of information between different components.
|
||||
|
||||
How It Works
|
||||
|
||||
The chatbot introduces itself and verifies the patient's identity.
|
||||
It systematically collects information about prescriptions, allergies, medical conditions, and the reason for the visit.
|
||||
The conversation is guided by a series of function calls that transition between different stages of the intake process.
|
||||
All collected information is logged for later use by medical professionals.
|
||||
|
||||
ℹ️ The first time, things might take extra time to get started since VAD (Voice Activity Detection) model needs to be downloaded.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev)
|
||||
DAILY_API_KEY=7df...
|
||||
OPENAI_API_KEY=sk-PL...
|
||||
ELEVENLABS_API_KEY=aeb...
|
||||
CARTESIA_API_KEY=your_cartesia_api_key_here
|
||||
|
||||
@@ -122,7 +122,7 @@ if __name__ == "__main__":
|
||||
default_host = os.getenv("HOST", "0.0.0.0")
|
||||
default_port = int(os.getenv("FAST_API_PORT", "7860"))
|
||||
|
||||
parser = argparse.ArgumentParser(description="Daily Storyteller FastAPI server")
|
||||
parser = argparse.ArgumentParser(description="Daily patient-intake FastAPI server")
|
||||
parser.add_argument("--host", type=str, default=default_host, help="Host address")
|
||||
parser.add_argument("--port", type=int, default=default_port, help="Port number")
|
||||
parser.add_argument("--reload", action="store_true", help="Reload code on change")
|
||||
|
||||
2497
examples/storytelling-chatbot/frontend/package-lock.json
generated
2497
examples/storytelling-chatbot/frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -11,28 +11,28 @@
|
||||
"dependencies": {
|
||||
"@daily-co/daily-js": "^0.62.0",
|
||||
"@daily-co/daily-react": "^0.18.0",
|
||||
"@radix-ui/react-select": "^2.0.0",
|
||||
"@radix-ui/react-select": "^2.1.2",
|
||||
"@radix-ui/react-slot": "^1.0.2",
|
||||
"@tabler/icons-react": "^3.1.0",
|
||||
"@tabler/icons-react": "^3.19.0",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"clsx": "^2.1.0",
|
||||
"framer-motion": "^11.0.27",
|
||||
"next": "14.1.4",
|
||||
"react": "^18",
|
||||
"react-dom": "^18",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^11.9.0",
|
||||
"next": "^14.2.14",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"recoil": "^0.7.7",
|
||||
"tailwind-merge": "^2.2.2",
|
||||
"tailwind-merge": "^2.5.2",
|
||||
"tailwindcss-animate": "^1.0.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^18",
|
||||
"@types/react-dom": "^18",
|
||||
"autoprefixer": "^10.0.1",
|
||||
"eslint": "^8",
|
||||
"@types/node": "^20.16.10",
|
||||
"@types/react": "^18.3.11",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"eslint": "^8.57.1",
|
||||
"eslint-config-next": "14.1.4",
|
||||
"postcss": "^8",
|
||||
"tailwindcss": "^3.4.3",
|
||||
"typescript": "^5"
|
||||
"postcss": "^8.4.47",
|
||||
"tailwindcss": "^3.4.13",
|
||||
"typescript": "^5.6.2"
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -143,7 +143,7 @@ async def main(room_url, token=None):
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
intro_task.queue_frame(EndFrame())
|
||||
await intro_task.queue_frame(EndFrame())
|
||||
await main_task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_call_state_updated")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
python-dotenv
|
||||
fastapi[all]
|
||||
pipecat-ai[daily,openai,azure]
|
||||
aiohttp
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
OPENAI_API_KEY=
|
||||
DEEPGRAM_API_KEY=
|
||||
ELEVENLABS_API_KEY=
|
||||
ELEVENLABS_VOICE_ID=
|
||||
CARTESIA_API_KEY=
|
||||
|
||||
15
examples/websocket-server/Dockerfile
Normal file
15
examples/websocket-server/Dockerfile
Normal file
@@ -0,0 +1,15 @@
|
||||
FROM python:3.10-bullseye
|
||||
|
||||
RUN mkdir /app
|
||||
|
||||
COPY *.py /app/
|
||||
COPY requirements.txt /app/
|
||||
COPY .env /app/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN pip3 install -r requirements.txt
|
||||
|
||||
EXPOSE 7860
|
||||
|
||||
CMD ["python3", "bot.py"]
|
||||
@@ -8,6 +8,7 @@ This is an example that shows how to use `WebsocketServerTransport` to communica
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
cp env.example .env # and add your credentials
|
||||
```
|
||||
|
||||
## Run the bot
|
||||
|
||||
8
examples/websocket-server/env.example
Normal file
8
examples/websocket-server/env.example
Normal file
@@ -0,0 +1,8 @@
|
||||
# OpenAI API Key
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
|
||||
# Deepgram API Key
|
||||
DEEPGRAM_API_KEY=your_deepgram_api_key_here
|
||||
|
||||
# Cartesia API Key
|
||||
CARTESIA_API_KEY=your_cartesia_api_key_here
|
||||
@@ -1,2 +1,2 @@
|
||||
python-dotenv
|
||||
pipecat-ai[cartesia,openai,silero,websocket,whisper]
|
||||
pipecat-ai[cartesia,openai,silero,websocket,deepgram]
|
||||
|
||||
@@ -21,6 +21,7 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"aiohttp~=3.10.3",
|
||||
"Markdown~=3.7",
|
||||
"numpy~=1.26.4",
|
||||
"loguru~=0.7.2",
|
||||
"Pillow~=10.4.0",
|
||||
@@ -37,9 +38,10 @@ Website = "https://pipecat.ai"
|
||||
anthropic = [ "anthropic~=0.34.0" ]
|
||||
aws = [ "boto3~=1.35.27" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.40.0" ]
|
||||
canonical = [ "aiofiles~=24.1.0" ]
|
||||
cartesia = [ "cartesia~=1.0.13", "websockets~=12.0" ]
|
||||
daily = [ "daily-python~=0.10.1" ]
|
||||
deepgram = [ "deepgram-sdk~=3.5.0" ]
|
||||
daily = [ "daily-python~=0.11.0" ]
|
||||
deepgram = [ "deepgram-sdk~=3.7.3" ]
|
||||
elevenlabs = [ "websockets~=12.0" ]
|
||||
examples = [ "python-dotenv~=1.0.1", "flask~=3.0.3", "flask_cors~=4.0.1" ]
|
||||
fal = [ "fal-client~=0.4.1" ]
|
||||
@@ -52,11 +54,11 @@ livekit = [ "livekit~=0.13.1", "tenacity~=9.0.0" ]
|
||||
lmnt = [ "lmnt~=1.1.4" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
moondream = [ "einops~=0.8.0", "timm~=1.0.8", "transformers~=4.44.0" ]
|
||||
openai = [ "openai~=1.37.2" ]
|
||||
openai = [ "openai~=1.50.2", "websockets~=12.0", "python-deepcompare~=1.0.1" ]
|
||||
openpipe = [ "openpipe~=4.24.0" ]
|
||||
playht = [ "pyht~=0.0.28" ]
|
||||
silero = [ "onnxruntime>=1.16.1" ]
|
||||
together = [ "together~=1.2.7" ]
|
||||
together = [ "openai~=1.50.2" ]
|
||||
websocket = [ "websockets~=12.0", "fastapi~=0.115.0" ]
|
||||
whisper = [ "faster-whisper~=1.0.3" ]
|
||||
xtts = [ "resampy~=0.4.3" ]
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
@@ -269,12 +269,22 @@ class TTSSpeakFrame(DataFrame):
|
||||
@dataclass
|
||||
class TransportMessageFrame(DataFrame):
|
||||
message: Any
|
||||
urgent: bool = False
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(message: {self.message})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallResultFrame(DataFrame):
|
||||
"""A frame containing the result of an LLM function (tool) call."""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: str
|
||||
result: Any
|
||||
run_llm: bool = True
|
||||
|
||||
|
||||
#
|
||||
# App frames. Application user-defined frames.
|
||||
#
|
||||
@@ -394,6 +404,25 @@ class StopInterruptionFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserStartedSpeakingFrame(SystemFrame):
|
||||
"""Emitted by VAD to indicate that a user has started speaking. This can be
|
||||
used for interruptions or other times when detecting that someone is
|
||||
speaking is more important than knowing what they're saying (as you will
|
||||
with a TranscriptionFrame)
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserStoppedSpeakingFrame(SystemFrame):
|
||||
"""Emitted by the VAD to indicate that a user stopped speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(SystemFrame):
|
||||
"""Emitted by when the bot should be interrupted. This will mainly cause the
|
||||
@@ -405,6 +434,60 @@ class BotInterruptionFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStartedSpeakingFrame(SystemFrame):
|
||||
"""Emitted upstream by transport outputs to indicate the bot started speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStoppedSpeakingFrame(SystemFrame):
|
||||
"""Emitted upstream by transport outputs to indicate the bot stopped speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotSpeakingFrame(SystemFrame):
|
||||
"""Emitted upstream by transport outputs while the bot is still
|
||||
speaking. This can be used, for example, to detect when a user is idle. That
|
||||
is, while the bot is speaking we don't want to trigger any user idle timeout
|
||||
since the user might be listening.
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserImageRequestFrame(SystemFrame):
|
||||
"""A frame user to request an image from the given user."""
|
||||
|
||||
user_id: str
|
||||
context: Optional[Any] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}, user: {self.user_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallInProgressFrame(SystemFrame):
|
||||
"""A frame signaling that a function call is in progress."""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransportMessageUrgentFrame(SystemFrame):
|
||||
message: Any
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(message: {self.message})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricsFrame(SystemFrame):
|
||||
"""Emitted by processor that can compute metrics like latencies."""
|
||||
@@ -450,51 +533,6 @@ class LLMFullResponseEndFrame(ControlFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserStartedSpeakingFrame(ControlFrame):
|
||||
"""Emitted by VAD to indicate that a user has started speaking. This can be
|
||||
used for interruptions or other times when detecting that someone is
|
||||
speaking is more important than knowing what they're saying (as you will
|
||||
with a TranscriptionFrame)
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserStoppedSpeakingFrame(ControlFrame):
|
||||
"""Emitted by the VAD to indicate that a user stopped speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStartedSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs to indicate the bot started speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStoppedSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs to indicate the bot stopped speaking."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs while the bot is still
|
||||
speaking. This can be used, for example, to detect when a user is idle. That
|
||||
is, while the bot is speaking we don't want to trigger any user idle timeout
|
||||
since the user might be listening.
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSStartedFrame(ControlFrame):
|
||||
"""Used to indicate the beginning of a TTS response. Following
|
||||
@@ -516,75 +554,25 @@ class TTSStoppedFrame(ControlFrame):
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserImageRequestFrame(ControlFrame):
|
||||
"""A frame user to request an image from the given user."""
|
||||
class ServiceUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update service settings."""
|
||||
|
||||
user_id: str
|
||||
context: Optional[Any] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}, user: {self.user_id}"
|
||||
settings: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update LLM settings."""
|
||||
|
||||
model: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
extra: dict = field(default_factory=dict)
|
||||
class LLMUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update TTS settings."""
|
||||
|
||||
model: Optional[str] = None
|
||||
voice: Optional[str] = None
|
||||
language: Optional[Language] = None
|
||||
speed: Optional[Union[str, float]] = None
|
||||
emotion: Optional[List[str]] = None
|
||||
engine: Optional[str] = None
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
emphasis: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
style_degree: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update STT settings."""
|
||||
|
||||
model: Optional[str] = None
|
||||
language: Optional[Language] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallInProgressFrame(SystemFrame):
|
||||
"""A frame signaling that a function call is in progress."""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallResultFrame(DataFrame):
|
||||
"""A frame containing the result of an LLM function (tool) call."""
|
||||
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: str
|
||||
result: Any
|
||||
class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -120,7 +120,7 @@ class ParallelPipeline(BasePipeline):
|
||||
|
||||
# If we get an EndFrame we stop our queue processing tasks and wait on
|
||||
# all the pipelines to finish.
|
||||
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
|
||||
if isinstance(frame, (CancelFrame, EndFrame)):
|
||||
# Use None to indicate when queues should be done processing.
|
||||
await self._up_queue.put(None)
|
||||
await self._down_queue.put(None)
|
||||
|
||||
@@ -6,17 +6,25 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import ControlFrame, EndFrame, Frame, SystemFrame
|
||||
from pipecat.pipeline.base_pipeline import BasePipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import Frame
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncFrame(ControlFrame):
|
||||
"""This frame is used to know when the internal pipelines have finished."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Source(FrameProcessor):
|
||||
def __init__(self, upstream_queue: asyncio.Queue):
|
||||
super().__init__()
|
||||
@@ -67,13 +75,16 @@ class SyncParallelPipeline(BasePipeline):
|
||||
raise TypeError(f"SyncParallelPipeline argument {processors} is not a list")
|
||||
|
||||
# We add a source at the beginning of the pipeline and a sink at the end.
|
||||
source = Source(self._up_queue)
|
||||
sink = Sink(self._down_queue)
|
||||
up_queue = asyncio.Queue()
|
||||
down_queue = asyncio.Queue()
|
||||
source = Source(up_queue)
|
||||
sink = Sink(down_queue)
|
||||
processors: List[FrameProcessor] = [source] + processors + [sink]
|
||||
|
||||
# Keep track of sources and sinks.
|
||||
self._sources.append(source)
|
||||
self._sinks.append(sink)
|
||||
# Keep track of sources and sinks. We also keep the output queue of
|
||||
# the source and the sinks so we can use it later.
|
||||
self._sources.append({"processor": source, "queue": down_queue})
|
||||
self._sinks.append({"processor": sink, "queue": up_queue})
|
||||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline(processors)
|
||||
@@ -94,17 +105,52 @@ class SyncParallelPipeline(BasePipeline):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# The last processor of each pipeline needs to be synchronous otherwise
|
||||
# this element won't work. Since, we know it should be synchronous we
|
||||
# push a SyncFrame. Since frames are ordered we know this frame will be
|
||||
# pushed after the synchronous processor has pushed its data allowing us
|
||||
# to synchrnonize all the internal pipelines by waiting for the
|
||||
# SyncFrame in all of them.
|
||||
async def wait_for_sync(
|
||||
obj, main_queue: asyncio.Queue, frame: Frame, direction: FrameDirection
|
||||
):
|
||||
processor = obj["processor"]
|
||||
queue = obj["queue"]
|
||||
|
||||
await processor.process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (SystemFrame, EndFrame)):
|
||||
new_frame = await queue.get()
|
||||
if isinstance(new_frame, (SystemFrame, EndFrame)):
|
||||
await main_queue.put(new_frame)
|
||||
else:
|
||||
while not isinstance(new_frame, (SystemFrame, EndFrame)):
|
||||
await main_queue.put(new_frame)
|
||||
queue.task_done()
|
||||
new_frame = await queue.get()
|
||||
else:
|
||||
await processor.process_frame(SyncFrame(), direction)
|
||||
new_frame = await queue.get()
|
||||
while not isinstance(new_frame, SyncFrame):
|
||||
await main_queue.put(new_frame)
|
||||
queue.task_done()
|
||||
new_frame = await queue.get()
|
||||
|
||||
if direction == FrameDirection.UPSTREAM:
|
||||
# If we get an upstream frame we process it in each sink.
|
||||
await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks])
|
||||
await asyncio.gather(
|
||||
*[wait_for_sync(s, self._up_queue, frame, direction) for s in self._sinks]
|
||||
)
|
||||
elif direction == FrameDirection.DOWNSTREAM:
|
||||
# If we get a downstream frame we process it in each source.
|
||||
await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sources])
|
||||
await asyncio.gather(
|
||||
*[wait_for_sync(s, self._down_queue, frame, direction) for s in self._sources]
|
||||
)
|
||||
|
||||
seen_ids = set()
|
||||
while not self._up_queue.empty():
|
||||
frame = await self._up_queue.get()
|
||||
if frame and frame.id not in seen_ids:
|
||||
if frame.id not in seen_ids:
|
||||
await self.push_frame(frame, FrameDirection.UPSTREAM)
|
||||
seen_ids.add(frame.id)
|
||||
self._up_queue.task_done()
|
||||
@@ -112,7 +158,7 @@ class SyncParallelPipeline(BasePipeline):
|
||||
seen_ids = set()
|
||||
while not self._down_queue.empty():
|
||||
frame = await self._down_queue.get()
|
||||
if frame and frame.id not in seen_ids:
|
||||
if frame.id not in seen_ids:
|
||||
await self.push_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
seen_ids.add(frame.id)
|
||||
self._down_queue.task_done()
|
||||
|
||||
@@ -69,6 +69,19 @@ class Source(FrameProcessor):
|
||||
await self._up_queue.put(StopTaskFrame())
|
||||
|
||||
|
||||
class Sink(FrameProcessor):
|
||||
def __init__(self, down_queue: asyncio.Queue):
|
||||
super().__init__()
|
||||
self._down_queue = down_queue
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# We really just want to know when the EndFrame reached the sink.
|
||||
if isinstance(frame, EndFrame):
|
||||
await self._down_queue.put(frame)
|
||||
|
||||
|
||||
class PipelineTask:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,12 +97,16 @@ class PipelineTask:
|
||||
self._params = params
|
||||
self._finished = False
|
||||
|
||||
self._down_queue = asyncio.Queue()
|
||||
self._up_queue = asyncio.Queue()
|
||||
self._down_queue = asyncio.Queue()
|
||||
self._push_queue = asyncio.Queue()
|
||||
|
||||
self._source = Source(self._up_queue)
|
||||
self._source.link(pipeline)
|
||||
|
||||
self._sink = Sink(self._down_queue)
|
||||
pipeline.link(self._sink)
|
||||
|
||||
def has_finished(self):
|
||||
return self._finished
|
||||
|
||||
@@ -103,19 +120,19 @@ class PipelineTask:
|
||||
# out-of-band from the main streaming task which is what we want since
|
||||
# we want to cancel right away.
|
||||
await self._source.push_frame(CancelFrame())
|
||||
self._process_down_task.cancel()
|
||||
self._process_push_task.cancel()
|
||||
self._process_up_task.cancel()
|
||||
await self._process_down_task
|
||||
await self._process_push_task
|
||||
await self._process_up_task
|
||||
|
||||
async def run(self):
|
||||
self._process_up_task = asyncio.create_task(self._process_up_queue())
|
||||
self._process_down_task = asyncio.create_task(self._process_down_queue())
|
||||
await asyncio.gather(self._process_up_task, self._process_down_task)
|
||||
self._process_push_task = asyncio.create_task(self._process_push_queue())
|
||||
await asyncio.gather(self._process_up_task, self._process_push_task)
|
||||
self._finished = True
|
||||
|
||||
async def queue_frame(self, frame: Frame):
|
||||
await self._down_queue.put(frame)
|
||||
await self._push_queue.put(frame)
|
||||
|
||||
async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
|
||||
if isinstance(frames, AsyncIterable):
|
||||
@@ -133,7 +150,7 @@ class PipelineTask:
|
||||
data.append(ProcessingMetricsData(processor=p.name, value=0.0))
|
||||
return MetricsFrame(data=data)
|
||||
|
||||
async def _process_down_queue(self):
|
||||
async def _process_push_queue(self):
|
||||
self._clock.start()
|
||||
|
||||
start_frame = StartFrame(
|
||||
@@ -154,11 +171,13 @@ class PipelineTask:
|
||||
should_cleanup = True
|
||||
while running:
|
||||
try:
|
||||
frame = await self._down_queue.get()
|
||||
frame = await self._push_queue.get()
|
||||
await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame))
|
||||
if isinstance(frame, EndFrame):
|
||||
await self._wait_for_endframe()
|
||||
running = not isinstance(frame, (StopTaskFrame, EndFrame))
|
||||
should_cleanup = not isinstance(frame, StopTaskFrame)
|
||||
self._down_queue.task_done()
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
# Cleanup only if we need to.
|
||||
@@ -169,6 +188,12 @@ class PipelineTask:
|
||||
self._process_up_task.cancel()
|
||||
await self._process_up_task
|
||||
|
||||
async def _wait_for_endframe(self):
|
||||
# NOTE(aleix): the Sink element just pushes EndFrames to the down queue,
|
||||
# so just wait for it. In the future we might do something else here,
|
||||
# but for now this is fine.
|
||||
await self._down_queue.get()
|
||||
|
||||
async def _process_up_queue(self):
|
||||
while True:
|
||||
try:
|
||||
|
||||
@@ -6,12 +6,6 @@
|
||||
|
||||
from typing import List, Type
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContext,
|
||||
)
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
@@ -22,11 +16,16 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class LLMResponseAggregator(FrameProcessor):
|
||||
@@ -40,6 +39,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
accumulator_frame: Type[TextFrame],
|
||||
interim_accumulator_frame: Type[TextFrame] | None = None,
|
||||
handle_interruptions: bool = False,
|
||||
expect_stripped_words: bool = True, # if True, need to add spaces between words
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -50,6 +50,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._handle_interruptions = handle_interruptions
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
@@ -111,7 +112,10 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._accumulator_frame):
|
||||
if self._aggregating:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
# We have recevied a complete sentence, so if we have seen the
|
||||
# end frame and we were still aggregating, it means we should
|
||||
# send the aggregation.
|
||||
@@ -290,7 +294,7 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
@@ -299,6 +303,7 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
expect_stripped_words=expect_stripped_words,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
|
||||
@@ -60,6 +62,7 @@ class OpenAILLMContext:
|
||||
self._messages: List[ChatCompletionMessageParam] = messages if messages else []
|
||||
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
|
||||
self._tools: List[ChatCompletionToolParam] | NotGiven = tools
|
||||
self._user_image_request_context = {}
|
||||
|
||||
@staticmethod
|
||||
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
|
||||
@@ -112,7 +115,39 @@ class OpenAILLMContext:
|
||||
return self._messages
|
||||
|
||||
def get_messages_json(self) -> str:
|
||||
return json.dumps(self._messages, cls=CustomEncoder)
|
||||
return json.dumps(self._messages, cls=CustomEncoder, ensure_ascii=False, indent=2)
|
||||
|
||||
def get_messages_for_logging(self) -> str:
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image_url":
|
||||
if item["image_url"]["url"].startswith("data:image/"):
|
||||
item["image_url"]["url"] = "data:image/..."
|
||||
if "mime_type" in msg and msg["mime_type"].startswith("image/"):
|
||||
msg["data"] = "..."
|
||||
msgs.append(msg)
|
||||
return json.dumps(msgs)
|
||||
|
||||
def from_standard_message(self, message):
|
||||
return message
|
||||
|
||||
# convert a message in this LLM's format to one or more messages in OpenAI format
|
||||
def to_standard_messages(self, obj) -> list:
|
||||
return [obj]
|
||||
|
||||
def get_messages_for_initializing_history(self):
|
||||
return self._messages
|
||||
|
||||
def get_messages_for_persistent_storage(self):
|
||||
messages = []
|
||||
for m in self._messages:
|
||||
standard_messages = self.to_standard_messages(m)
|
||||
messages.extend(standard_messages)
|
||||
return messages
|
||||
|
||||
def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven):
|
||||
self._tool_choice = tool_choice
|
||||
@@ -122,6 +157,21 @@ class OpenAILLMContext:
|
||||
tools = NOT_GIVEN
|
||||
self._tools = tools
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
content = [
|
||||
{"type": "text", "text": text},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
|
||||
]
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
async def call_function(
|
||||
self,
|
||||
f: Callable[
|
||||
@@ -133,7 +183,9 @@ class OpenAILLMContext:
|
||||
tool_call_id: str,
|
||||
arguments: str,
|
||||
llm: FrameProcessor,
|
||||
run_llm: bool = True,
|
||||
) -> None:
|
||||
logger.debug(f"Calling function {function_name} with arguments {arguments}")
|
||||
# Push a SystemFrame downstream. This frame will let our assistant context aggregator
|
||||
# know that we are in the middle of a function call. Some contexts/aggregators may
|
||||
# not need this. But some definitely do (Anthropic, for example).
|
||||
@@ -153,6 +205,7 @@ class OpenAILLMContext:
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
result=result,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
100
src/pipecat/processors/audio/audio_buffer_processor.py
Normal file
100
src/pipecat/processors/audio/audio_buffer_processor.py
Normal file
@@ -0,0 +1,100 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import wave
|
||||
from io import BytesIO
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class AudioBufferProcessor(FrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initialize the AudioBufferProcessor.
|
||||
|
||||
This constructor sets up the initial state for audio processing:
|
||||
- audio_buffer: A bytearray to store incoming audio data.
|
||||
- num_channels: The number of audio channels (initialized as None).
|
||||
- sample_rate: The sample rate of the audio (initialized as None).
|
||||
|
||||
The num_channels and sample_rate are set to None initially and will be
|
||||
populated when the first audio frame is processed.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._user_audio_buffer = bytearray()
|
||||
self._assistant_audio_buffer = bytearray()
|
||||
self._num_channels = None
|
||||
self._sample_rate = None
|
||||
|
||||
def _buffer_has_audio(self, buffer: bytearray):
|
||||
return buffer is not None and len(buffer) > 0
|
||||
|
||||
def has_audio(self):
|
||||
return (
|
||||
self._buffer_has_audio(self._user_audio_buffer)
|
||||
and self._buffer_has_audio(self._assistant_audio_buffer)
|
||||
and self._sample_rate is not None
|
||||
)
|
||||
|
||||
def reset_audio_buffer(self):
|
||||
self._user_audio_buffer = bytearray()
|
||||
self._assistant_audio_buffer = bytearray()
|
||||
|
||||
def merge_audio_buffers(self):
|
||||
with BytesIO() as buffer:
|
||||
with wave.open(buffer, "wb") as wf:
|
||||
wf.setnchannels(2)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(self._sample_rate)
|
||||
# Interleave the two audio streams
|
||||
max_length = max(len(self._user_audio_buffer), len(self._assistant_audio_buffer))
|
||||
interleaved = bytearray(max_length * 2)
|
||||
|
||||
for i in range(0, max_length, 2):
|
||||
if i < len(self._user_audio_buffer):
|
||||
interleaved[i * 2] = self._user_audio_buffer[i]
|
||||
interleaved[i * 2 + 1] = self._user_audio_buffer[i + 1]
|
||||
else:
|
||||
interleaved[i * 2] = 0
|
||||
interleaved[i * 2 + 1] = 0
|
||||
|
||||
if i < len(self._assistant_audio_buffer):
|
||||
interleaved[i * 2 + 2] = self._assistant_audio_buffer[i]
|
||||
interleaved[i * 2 + 3] = self._assistant_audio_buffer[i + 1]
|
||||
else:
|
||||
interleaved[i * 2 + 2] = 0
|
||||
interleaved[i * 2 + 3] = 0
|
||||
|
||||
wf.writeframes(interleaved)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, AudioRawFrame) and self._sample_rate is None:
|
||||
self._sample_rate = frame.sample_rate
|
||||
|
||||
# include all audio from the user
|
||||
if isinstance(frame, InputAudioRawFrame):
|
||||
self._user_audio_buffer.extend(frame.audio)
|
||||
# Sync the assistant's buffer to the user's buffer by adding silence if needed
|
||||
if len(self._user_audio_buffer) > len(self._assistant_audio_buffer):
|
||||
silence_length = len(self._user_audio_buffer) - len(self._assistant_audio_buffer)
|
||||
silence = b"\x00" * silence_length
|
||||
self._assistant_audio_buffer.extend(silence)
|
||||
|
||||
# if the assistant is speaking, include all audio from the assistant,
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
self._assistant_audio_buffer.extend(frame.audio)
|
||||
|
||||
# do not push the user's audio frame, doing so will result in echo
|
||||
if not isinstance(frame, InputAudioRawFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -37,7 +37,6 @@ class FrameProcessor:
|
||||
*,
|
||||
name: str | None = None,
|
||||
metrics: FrameProcessorMetrics | None = None,
|
||||
sync: bool = True,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -47,7 +46,6 @@ class FrameProcessor:
|
||||
self._prev: "FrameProcessor" | None = None
|
||||
self._next: "FrameProcessor" | None = None
|
||||
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop()
|
||||
self._sync = sync
|
||||
|
||||
self._event_handlers: dict = {}
|
||||
|
||||
@@ -66,11 +64,8 @@ class FrameProcessor:
|
||||
|
||||
# Every processor in Pipecat should only output frames from a single
|
||||
# task. This avoid problems like audio overlapping. System frames are
|
||||
# the exception to this rule.
|
||||
#
|
||||
# This create this task.
|
||||
if not self._sync:
|
||||
self.__create_push_task()
|
||||
# the exception to this rule. This create this task.
|
||||
self.__create_push_task()
|
||||
|
||||
@property
|
||||
def interruptions_allowed(self):
|
||||
@@ -167,7 +162,7 @@ class FrameProcessor:
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
if self._sync or isinstance(frame, SystemFrame):
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
else:
|
||||
await self.__push_queue.put((frame, direction))
|
||||
@@ -194,13 +189,12 @@ class FrameProcessor:
|
||||
#
|
||||
|
||||
async def _start_interruption(self):
|
||||
if not self._sync:
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
|
||||
# Create a new queue and task.
|
||||
self.__create_push_task()
|
||||
# Create a new queue and task.
|
||||
self.__create_push_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
# Nothing to do right now.
|
||||
|
||||
@@ -5,11 +5,21 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
@@ -20,26 +30,34 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OutputAudioRawFrame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
FunctionCallResultFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.metrics.metrics import (
|
||||
LLMUsageMetricsData,
|
||||
ProcessingMetricsData,
|
||||
TTFBMetricsData,
|
||||
TTSUsageMetricsData,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
|
||||
RTVI_PROTOCOL_VERSION = "0.2"
|
||||
|
||||
@@ -273,6 +291,12 @@ class RTVITextMessageData(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-transcription"] = "bot-transcription"
|
||||
data: RTVITextMessageData
|
||||
|
||||
|
||||
class RTVIBotLLMTextMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-llm-text"] = "bot-llm-text"
|
||||
@@ -291,22 +315,12 @@ class RTVIAudioMessageData(BaseModel):
|
||||
num_channels: int
|
||||
|
||||
|
||||
class RTVIBotAudioMessage(BaseModel):
|
||||
class RTVIBotTTSAudioMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-audio"] = "bot-audio"
|
||||
type: Literal["bot-tts-audio"] = "bot-tts-audio"
|
||||
data: RTVIAudioMessageData
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessageData(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-transcription"] = "bot-transcription"
|
||||
data: RTVIBotTranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserTranscriptionMessageData(BaseModel):
|
||||
text: str
|
||||
user_id: str
|
||||
@@ -320,6 +334,12 @@ class RTVIUserTranscriptionMessage(BaseModel):
|
||||
data: RTVIUserTranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserLLMTextMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-llm-text"] = "user-llm-text"
|
||||
data: RTVITextMessageData
|
||||
|
||||
|
||||
class RTVIUserStartedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-started-speaking"] = "user-started-speaking"
|
||||
@@ -340,6 +360,12 @@ class RTVIBotStoppedSpeakingMessage(BaseModel):
|
||||
type: Literal["bot-stopped-speaking"] = "bot-stopped-speaking"
|
||||
|
||||
|
||||
class RTVIMetricsMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["metrics"] = "metrics"
|
||||
data: Mapping[str, Any]
|
||||
|
||||
|
||||
class RTVIProcessorParams(BaseModel):
|
||||
send_bot_ready: bool = True
|
||||
|
||||
@@ -349,10 +375,8 @@ class RTVIFrameProcessor(FrameProcessor):
|
||||
super().__init__(**kwargs)
|
||||
self._direction = direction
|
||||
|
||||
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageFrame(
|
||||
message=model.model_dump(exclude_none=exclude_none), urgent=True
|
||||
)
|
||||
async def _push_transport_message_urgent(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageUrgentFrame(message=model.model_dump(exclude_none=exclude_none))
|
||||
await self.push_frame(frame, self._direction)
|
||||
|
||||
|
||||
@@ -378,7 +402,7 @@ class RTVISpeakingProcessor(RTVIFrameProcessor):
|
||||
message = RTVIUserStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
async def _handle_bot_speaking(self, frame: Frame):
|
||||
message = None
|
||||
@@ -388,7 +412,7 @@ class RTVISpeakingProcessor(RTVIFrameProcessor):
|
||||
message = RTVIBotStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIUserTranscriptionProcessor(RTVIFrameProcessor):
|
||||
@@ -419,7 +443,57 @@ class RTVIUserTranscriptionProcessor(RTVIFrameProcessor):
|
||||
)
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIUserLLMTextProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
await self._handle_context(frame)
|
||||
|
||||
async def _handle_context(self, frame: OpenAILLMContextFrame):
|
||||
messages = frame.context.messages
|
||||
if len(messages) > 0:
|
||||
message = messages[-1]
|
||||
if message["role"] == "user":
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
text = " ".join(item["text"] for item in content if "text" in item)
|
||||
else:
|
||||
text = content
|
||||
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self._push_transport_message_urgent(rtvi_message)
|
||||
|
||||
|
||||
class RTVIBotTranscriptionProcessor(RTVIFrameProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._aggregation = ""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._push_aggregation()
|
||||
elif isinstance(frame, TextFrame):
|
||||
self._aggregation += frame.text
|
||||
if match_endofsentence(self._aggregation):
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
message = RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._aggregation))
|
||||
await self._push_transport_message_urgent(message)
|
||||
self._aggregation = ""
|
||||
|
||||
|
||||
class RTVIBotLLMProcessor(RTVIFrameProcessor):
|
||||
@@ -432,9 +506,12 @@ class RTVIBotLLMProcessor(RTVIFrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._push_transport_message(RTVIBotLLMStartedMessage())
|
||||
await self._push_transport_message_urgent(RTVIBotLLMStartedMessage())
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._push_transport_message(RTVIBotLLMStoppedMessage())
|
||||
await self._push_transport_message_urgent(RTVIBotLLMStoppedMessage())
|
||||
elif type(frame) is TextFrame:
|
||||
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIBotTTSProcessor(RTVIFrameProcessor):
|
||||
@@ -447,12 +524,15 @@ class RTVIBotTTSProcessor(RTVIFrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
await self._push_transport_message(RTVIBotTTSStartedMessage())
|
||||
await self._push_transport_message_urgent(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
await self._push_transport_message(RTVIBotTTSStoppedMessage())
|
||||
await self._push_transport_message_urgent(RTVIBotTTSStoppedMessage())
|
||||
elif type(frame) is TextFrame:
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIBotLLMTextProcessor(RTVIFrameProcessor):
|
||||
class RTVIMetricsProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -461,51 +541,31 @@ class RTVIBotLLMTextProcessor(RTVIFrameProcessor):
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
if isinstance(frame, MetricsFrame):
|
||||
await self._handle_metrics(frame)
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self._push_transport_message(message)
|
||||
async def _handle_metrics(self, frame: MetricsFrame):
|
||||
metrics = {}
|
||||
for d in frame.data:
|
||||
if isinstance(d, TTFBMetricsData):
|
||||
if "ttfb" not in metrics:
|
||||
metrics["ttfb"] = []
|
||||
metrics["ttfb"].append(d.model_dump(exclude_none=True))
|
||||
elif isinstance(d, ProcessingMetricsData):
|
||||
if "processing" not in metrics:
|
||||
metrics["processing"] = []
|
||||
metrics["processing"].append(d.model_dump(exclude_none=True))
|
||||
elif isinstance(d, LLMUsageMetricsData):
|
||||
if "tokens" not in metrics:
|
||||
metrics["tokens"] = []
|
||||
metrics["tokens"].append(d.value.model_dump(exclude_none=True))
|
||||
elif isinstance(d, TTSUsageMetricsData):
|
||||
if "characters" not in metrics:
|
||||
metrics["characters"] = []
|
||||
metrics["characters"].append(d.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
class RTVIBotTTSTextProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._handle_text(frame)
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
|
||||
class RTVIBotAudioProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
await self._handle_audio(frame)
|
||||
|
||||
async def _handle_audio(self, frame: OutputAudioRawFrame):
|
||||
encoded = base64.b64encode(frame.audio).decode("utf-8")
|
||||
message = RTVIBotAudioMessage(
|
||||
data=RTVIAudioMessageData(
|
||||
audio=encoded, sample_rate=frame.sample_rate, num_channels=frame.num_channels
|
||||
)
|
||||
)
|
||||
await self._push_transport_message(message)
|
||||
message = RTVIMetricsMessage(data=metrics)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIProcessor(FrameProcessor):
|
||||
@@ -516,7 +576,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
params: RTVIProcessorParams = RTVIProcessorParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
self._config = config
|
||||
self._params = params
|
||||
|
||||
@@ -647,9 +707,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._message_task = None
|
||||
|
||||
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageFrame(
|
||||
message=model.model_dump(exclude_none=exclude_none), urgent=True
|
||||
)
|
||||
frame = TransportMessageUrgentFrame(message=model.model_dump(exclude_none=exclude_none))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _action_task_handler(self):
|
||||
|
||||
@@ -44,7 +44,7 @@ class GStreamerPipelineSource(FrameProcessor):
|
||||
clock_sync: bool = True
|
||||
|
||||
def __init__(self, *, pipeline: str, out_params: OutputParams = OutputParams(), **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._out_params = out_params
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class IdleFrameProcessor(FrameProcessor):
|
||||
types: List[type] = [],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._callback = callback
|
||||
self._timeout = timeout
|
||||
|
||||
@@ -31,7 +31,7 @@ class UserIdleProcessor(FrameProcessor):
|
||||
timeout: float,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._callback = callback
|
||||
self._timeout = timeout
|
||||
|
||||
@@ -8,7 +8,7 @@ import asyncio
|
||||
import io
|
||||
import wave
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -37,6 +37,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.audio import calculate_audio_volume
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
from pipecat.utils.time import seconds_to_nanoseconds
|
||||
from pipecat.utils.utils import exp_smoothing
|
||||
|
||||
@@ -45,6 +46,8 @@ class AIService(FrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._model_name: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._session_properties: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
@@ -63,6 +66,49 @@ class AIService(FrameProcessor):
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
pass
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
from pipecat.services.openai_realtime_beta.events import (
|
||||
SessionProperties,
|
||||
)
|
||||
|
||||
for key, value in settings.items():
|
||||
print("Update request for:", key, value)
|
||||
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating LLM setting {key} to: [{value}]")
|
||||
self._settings[key] = value
|
||||
elif key in SessionProperties.model_fields:
|
||||
print("Attempting to update", key, value)
|
||||
|
||||
try:
|
||||
from pipecat.services.openai_realtime_beta.events import (
|
||||
TurnDetection,
|
||||
)
|
||||
|
||||
if isinstance(self._session_properties, SessionProperties):
|
||||
current_properties = self._session_properties
|
||||
else:
|
||||
current_properties = SessionProperties(**self._session_properties)
|
||||
|
||||
if key == "turn_detection" and isinstance(value, dict):
|
||||
turn_detection = TurnDetection(**value)
|
||||
setattr(current_properties, key, turn_detection)
|
||||
else:
|
||||
setattr(current_properties, key, value)
|
||||
|
||||
validated_properties = SessionProperties.model_validate(
|
||||
current_properties.model_dump()
|
||||
)
|
||||
logger.debug(f"Updating LLM setting {key} to: [{value}]")
|
||||
self._session_properties = validated_properties.model_dump()
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected error updating session property {key}: {e}")
|
||||
elif key == "model":
|
||||
logger.debug(f"Updating LLM setting {key} to: [{value}]")
|
||||
self.set_model_name(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for {self.name} service: {key}")
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -110,7 +156,13 @@ class LLMService(AIService):
|
||||
return function_name in self._callbacks.keys()
|
||||
|
||||
async def call_function(
|
||||
self, *, context: OpenAILLMContext, tool_call_id: str, function_name: str, arguments: str
|
||||
self,
|
||||
*,
|
||||
context: OpenAILLMContext,
|
||||
tool_call_id: str,
|
||||
function_name: str,
|
||||
arguments: str,
|
||||
run_llm: bool = True,
|
||||
) -> None:
|
||||
f = None
|
||||
if function_name in self._callbacks.keys():
|
||||
@@ -120,7 +172,12 @@ class LLMService(AIService):
|
||||
else:
|
||||
return None
|
||||
await context.call_function(
|
||||
f, function_name=function_name, tool_call_id=tool_call_id, arguments=arguments, llm=self
|
||||
f,
|
||||
function_name=function_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
llm=self,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
|
||||
# QUESTION FOR CB: maybe this isn't needed anymore?
|
||||
@@ -144,15 +201,29 @@ class TTSService(AIService):
|
||||
# if True, TTSService will push TextFrames and LLMFullResponseEndFrames,
|
||||
# otherwise subclass must do it
|
||||
push_text_frames: bool = True,
|
||||
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
|
||||
push_stop_frames: bool = False,
|
||||
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
|
||||
stop_frame_timeout_s: float = 1.0,
|
||||
# TTS output sample rate
|
||||
sample_rate: int = 16000,
|
||||
text_filter: Optional[BaseTextFilter] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._aggregate_sentences: bool = aggregate_sentences
|
||||
self._push_text_frames: bool = push_text_frames
|
||||
self._current_sentence: str = ""
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
self._sample_rate: int = sample_rate
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._text_filter: Optional[BaseTextFilter] = text_filter
|
||||
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
self._current_sentence: str = ""
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
@@ -163,165 +234,20 @@ class TTSService(AIService):
|
||||
self.set_model_name(model)
|
||||
|
||||
@abstractmethod
|
||||
async def set_voice(self, voice: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_language(self, language: Language):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_speed(self, speed: Union[str, float]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_emotion(self, emotion: List[str]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_engine(self, engine: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_pitch(self, pitch: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_rate(self, rate: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_volume(self, volume: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_emphasis(self, emphasis: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_style(self, style: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_style_degree(self, style_degree: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_role(self, role: str):
|
||||
pass
|
||||
|
||||
# Converts the text to audio.
|
||||
@abstractmethod
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
pass
|
||||
|
||||
async def say(self, text: str):
|
||||
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
self._current_sentence = ""
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: str | None = None
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
else:
|
||||
self._current_sentence += frame.text
|
||||
if match_endofsentence(self._current_sentence):
|
||||
text = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_tts(text))
|
||||
await self.stop_processing_metrics()
|
||||
if self._push_text_frames:
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
await self.push_frame(TextFrame(text))
|
||||
|
||||
async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
await self.set_model(frame.model)
|
||||
if frame.voice is not None:
|
||||
await self.set_voice(frame.voice)
|
||||
if frame.language is not None:
|
||||
await self.set_language(frame.language)
|
||||
if frame.speed is not None:
|
||||
await self.set_speed(frame.speed)
|
||||
if frame.emotion is not None:
|
||||
await self.set_emotion(frame.emotion)
|
||||
if frame.engine is not None:
|
||||
await self.set_engine(frame.engine)
|
||||
if frame.pitch is not None:
|
||||
await self.set_pitch(frame.pitch)
|
||||
if frame.rate is not None:
|
||||
await self.set_rate(frame.rate)
|
||||
if frame.volume is not None:
|
||||
await self.set_volume(frame.volume)
|
||||
if frame.emphasis is not None:
|
||||
await self.set_emphasis(frame.emphasis)
|
||||
if frame.style is not None:
|
||||
await self.set_style(frame.style)
|
||||
if frame.style_degree is not None:
|
||||
await self.set_style_degree(frame.style_degree)
|
||||
if frame.role is not None:
|
||||
await self.set_role(frame.role)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
|
||||
sentence = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
await self._push_tts_frames(frame.text)
|
||||
elif isinstance(frame, TTSUpdateSettingsFrame):
|
||||
await self._update_tts_settings(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class AsyncTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
|
||||
push_stop_frames: bool = False,
|
||||
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
|
||||
stop_frame_timeout_s: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
def set_voice(self, voice: str):
|
||||
self._voice_id = voice
|
||||
|
||||
@abstractmethod
|
||||
async def flush_audio(self):
|
||||
pass
|
||||
|
||||
async def say(self, text: str):
|
||||
await super().say(text)
|
||||
await self.flush_audio()
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
return Language(language)
|
||||
|
||||
# Converts the text to audio.
|
||||
@abstractmethod
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
pass
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
@@ -342,10 +268,52 @@ class AsyncTTSService(TTSService):
|
||||
await self._stop_frame_task
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating TTS setting {key} to: [{value}]")
|
||||
self._settings[key] = value
|
||||
if key == "language":
|
||||
self._settings[key] = self.language_to_service_language(value)
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
elif key == "voice":
|
||||
self.set_voice(value)
|
||||
elif key == "text_filter" and self._text_filter:
|
||||
self._text_filter.update_settings(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for TTS service: {key}")
|
||||
|
||||
async def say(self, text: str):
|
||||
aggregate_sentences = self._aggregate_sentences
|
||||
self._aggregate_sentences = False
|
||||
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)
|
||||
self._aggregate_sentences = aggregate_sentences
|
||||
await self.flush_audio()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
sentence = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
await self._push_tts_frames(frame.text)
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, TTSUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
await super().push_frame(frame, direction)
|
||||
@@ -358,6 +326,43 @@ class AsyncTTSService(TTSService):
|
||||
):
|
||||
await self._stop_frame_queue.put(frame)
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
self._current_sentence = ""
|
||||
if self._text_filter:
|
||||
self._text_filter.handle_interruption()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: str | None = None
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
else:
|
||||
self._current_sentence += frame.text
|
||||
eos_end_marker = match_endofsentence(self._current_sentence)
|
||||
if eos_end_marker:
|
||||
text = self._current_sentence[:eos_end_marker]
|
||||
self._current_sentence = self._current_sentence[eos_end_marker:]
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
# Don't send only whitespace. This causes problems for some TTS models. But also don't
|
||||
# strip all whitespace, as whitespace can influence prosody.
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
await self.start_processing_metrics()
|
||||
if self._text_filter:
|
||||
self._text_filter.reset_interruption()
|
||||
text = self._text_filter.filter(text)
|
||||
await self.process_generator(self.run_tts(text))
|
||||
await self.stop_processing_metrics()
|
||||
if self._push_text_frames:
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
await self.push_frame(TextFrame(text))
|
||||
|
||||
async def _stop_frame_handler(self):
|
||||
try:
|
||||
has_started = False
|
||||
@@ -378,7 +383,7 @@ class AsyncTTSService(TTSService):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncWordTTSService(AsyncTTSService):
|
||||
class WordTTSService(TTSService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._initial_word_timestamp = -1
|
||||
@@ -408,7 +413,7 @@ class AsyncWordTTSService(AsyncTTSService):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
@@ -422,15 +427,21 @@ class AsyncWordTTSService(AsyncTTSService):
|
||||
self._words_task = None
|
||||
|
||||
async def _words_task_handler(self):
|
||||
last_pts = 0
|
||||
while True:
|
||||
try:
|
||||
(word, timestamp) = await self._words_queue.get()
|
||||
if word == "LLMFullResponseEndFrame" and timestamp == 0:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
frame = LLMFullResponseEndFrame()
|
||||
frame.pts = last_pts
|
||||
elif word == "TTSStoppedFrame" and timestamp == 0:
|
||||
frame = TTSStoppedFrame()
|
||||
frame.pts = last_pts
|
||||
else:
|
||||
frame = TextFrame(word)
|
||||
frame.pts = self._initial_word_timestamp + timestamp
|
||||
await self.push_frame(frame)
|
||||
last_pts = frame.pts
|
||||
await self.push_frame(frame)
|
||||
self._words_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
@@ -443,6 +454,7 @@ class STTService(AIService):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._settings: Dict[str, Any] = {}
|
||||
|
||||
@abstractmethod
|
||||
async def set_model(self, model: str):
|
||||
@@ -457,11 +469,18 @@ class STTService(AIService):
|
||||
"""Returns transcript as a string"""
|
||||
pass
|
||||
|
||||
async def _update_stt_settings(self, frame: STTUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
await self.set_model(frame.model)
|
||||
if frame.language is not None:
|
||||
await self.set_language(frame.language)
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
logger.debug(f"Updating STT settings: {self._settings}")
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating STT setting {key} to: [{value}]")
|
||||
self._settings[key] = value
|
||||
if key == "language":
|
||||
await self.set_language(value)
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for STT service: {key}")
|
||||
|
||||
async def process_audio_frame(self, frame: AudioRawFrame):
|
||||
await self.process_generator(self.run_stt(frame.audio))
|
||||
@@ -475,7 +494,7 @@ class STTService(AIService):
|
||||
# push a TextFrame. We don't really want to push audio frames down.
|
||||
await self.process_audio_frame(frame)
|
||||
elif isinstance(frame, STTUpdateSettingsFrame):
|
||||
await self._update_stt_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
# internal use only -- todo: refactor
|
||||
@dataclass
|
||||
class AnthropicImageMessageFrame(Frame):
|
||||
user_image_raw_frame: UserImageRawFrame
|
||||
@@ -95,12 +96,14 @@ class AnthropicLLMService(LLMService):
|
||||
super().__init__(**kwargs)
|
||||
self._client = AsyncAnthropic(api_key=api_key)
|
||||
self.set_model_name(model)
|
||||
self._max_tokens = params.max_tokens
|
||||
self._enable_prompt_caching_beta: bool = params.enable_prompt_caching_beta or False
|
||||
self._temperature = params.temperature
|
||||
self._top_k = params.top_k
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"enable_prompt_caching_beta": params.enable_prompt_caching_beta or False,
|
||||
"temperature": params.temperature,
|
||||
"top_k": params.top_k,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -110,35 +113,15 @@ class AnthropicLLMService(LLMService):
|
||||
return self._enable_prompt_caching_beta
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> AnthropicContextAggregatorPair:
|
||||
user = AnthropicUserContextAggregator(context)
|
||||
assistant = AnthropicAssistantContextAggregator(user)
|
||||
assistant = AnthropicAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
|
||||
logger.debug(f"Switching LLM enable_prompt_caching_beta to: [{enable_prompt_caching_beta}]")
|
||||
self._enable_prompt_caching_beta = enable_prompt_caching_beta
|
||||
|
||||
async def set_max_tokens(self, max_tokens: int):
|
||||
logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
|
||||
self._max_tokens = max_tokens
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_k(self, top_k: float):
|
||||
logger.debug(f"Switching LLM top_k to: [{top_k}]")
|
||||
self._top_k = top_k
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
|
||||
# completion_tokens. We also estimate the completion tokens from output text
|
||||
@@ -160,11 +143,11 @@ class AnthropicLLMService(LLMService):
|
||||
)
|
||||
|
||||
messages = context.messages
|
||||
if self._enable_prompt_caching_beta:
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
messages = context.get_messages_with_cache_control_markers()
|
||||
|
||||
api_call = self._client.messages.create
|
||||
if self._enable_prompt_caching_beta:
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
api_call = self._client.beta.prompt_caching.messages.create
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
@@ -174,14 +157,14 @@ class AnthropicLLMService(LLMService):
|
||||
"system": context.system,
|
||||
"messages": messages,
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._max_tokens,
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"stream": True,
|
||||
"temperature": self._temperature,
|
||||
"top_k": self._top_k,
|
||||
"top_p": self._top_p,
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
response = await api_call(**params)
|
||||
|
||||
@@ -279,27 +262,12 @@ class AnthropicLLMService(LLMService):
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
)
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.max_tokens is not None:
|
||||
await self.set_max_tokens(frame.max_tokens)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_k is not None:
|
||||
await self.set_top_k(frame.top_k)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = frame.context
|
||||
context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context)
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = AnthropicLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
@@ -309,10 +277,10 @@ class AnthropicLLMService(LLMService):
|
||||
# to the context.
|
||||
context = AnthropicLLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, LLMEnablePromptCachingFrame):
|
||||
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
|
||||
self._enable_prompt_caching_beta = frame.enable
|
||||
self._settings["enable_prompt_caching_beta"] = frame.enable
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -355,7 +323,6 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
system: str | NotGiven = NOT_GIVEN,
|
||||
):
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self._user_image_request_context = {}
|
||||
|
||||
# For beta prompt caching. This is a counter that tracks the number of turns
|
||||
# we've seen above the cache threshold. We reset this when we reset the
|
||||
@@ -365,6 +332,14 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
|
||||
self.system = system
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_anthropic(obj: OpenAILLMContext) -> "AnthropicLLMContext":
|
||||
logger.debug(f"Upgrading to Anthropic: {obj}")
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AnthropicLLMContext):
|
||||
obj.__class__ = AnthropicLLMContext
|
||||
obj._restructure_from_openai_messages()
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def from_openai_context(cls, openai_context: OpenAILLMContext):
|
||||
self = cls(
|
||||
@@ -394,6 +369,100 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
self._messages[:] = messages
|
||||
self._restructure_from_openai_messages()
|
||||
|
||||
# convert a message in Anthropic format into one or more messages in OpenAI format
|
||||
def to_standard_messages(self, obj):
|
||||
# todo: image format (?)
|
||||
# tool_use
|
||||
role = obj.get("role")
|
||||
content = obj.get("content")
|
||||
if role == "assistant":
|
||||
if isinstance(content, str):
|
||||
return [{"role": role, "content": [{"type": "text", "text": content}]}]
|
||||
elif isinstance(content, list):
|
||||
text_items = []
|
||||
tool_items = []
|
||||
for item in content:
|
||||
if item["type"] == "text":
|
||||
text_items.append({"type": "text", "text": item["text"]})
|
||||
elif item["type"] == "tool_use":
|
||||
tool_items.append(
|
||||
{
|
||||
"type": "function",
|
||||
"id": item["id"],
|
||||
"function": {
|
||||
"name": item["name"],
|
||||
"arguments": json.dumps(item["input"]),
|
||||
},
|
||||
}
|
||||
)
|
||||
messages = []
|
||||
if text_items:
|
||||
messages.append({"role": role, "content": text_items})
|
||||
if tool_items:
|
||||
messages.append({"role": role, "tool_calls": tool_items})
|
||||
return messages
|
||||
elif role == "user":
|
||||
if isinstance(content, str):
|
||||
return [{"role": role, "content": [{"type": "text", "text": content}]}]
|
||||
elif isinstance(content, list):
|
||||
text_items = []
|
||||
tool_items = []
|
||||
for item in content:
|
||||
if item["type"] == "text":
|
||||
text_items.append({"type": "text", "text": item["text"]})
|
||||
elif item["type"] == "tool_result":
|
||||
tool_items.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item["tool_use_id"],
|
||||
"content": item["content"],
|
||||
}
|
||||
)
|
||||
messages = []
|
||||
if text_items:
|
||||
messages.append({"role": role, "content": text_items})
|
||||
messages.extend(tool_items)
|
||||
return messages
|
||||
|
||||
def from_standard_message(self, message):
|
||||
# todo: image messages (?)
|
||||
if message["role"] == "tool":
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
],
|
||||
}
|
||||
if message.get("tool_calls"):
|
||||
tc = message["tool_calls"]
|
||||
ret = {"role": "assistant", "content": []}
|
||||
for tool_call in tc:
|
||||
function = tool_call["function"]
|
||||
arguments = json.loads(function["arguments"])
|
||||
new_tool_use = {
|
||||
"type": "tool_use",
|
||||
"id": tool_call["id"],
|
||||
"name": function["name"],
|
||||
"input": arguments,
|
||||
}
|
||||
ret["content"].append(new_tool_use)
|
||||
return ret
|
||||
# check for empty text strings
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
if content == "":
|
||||
content = "(empty)"
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if item["type"] == "text" and item["text"] == "":
|
||||
item["text"] = "(empty)"
|
||||
|
||||
return message
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
@@ -462,6 +531,12 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
return self.messages
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
# first, map across self._messages calling self.from_standard_message(m) to modify messages in place
|
||||
try:
|
||||
self._messages[:] = [self.from_standard_message(m) for m in self._messages]
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
|
||||
# See if we should pull the system message out of our context.messages list. (For
|
||||
# compatibility with Open AI messages format.)
|
||||
if self.messages and self.messages[0]["role"] == "system":
|
||||
@@ -475,6 +550,39 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
self.system = self.messages[0]["content"]
|
||||
self.messages.pop(0)
|
||||
|
||||
# Merge consecutive messages with the same role.
|
||||
i = 0
|
||||
while i < len(self.messages) - 1:
|
||||
current_message = self.messages[i]
|
||||
next_message = self.messages[i + 1]
|
||||
if current_message["role"] == next_message["role"]:
|
||||
# Convert content to list of dictionaries if it's a string
|
||||
if isinstance(current_message["content"], str):
|
||||
current_message["content"] = [
|
||||
{"type": "text", "text": current_message["content"]}
|
||||
]
|
||||
if isinstance(next_message["content"], str):
|
||||
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
|
||||
# Concatenate the content
|
||||
current_message["content"].extend(next_message["content"])
|
||||
# Remove the next message from the list
|
||||
self.messages.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Avoid empty content in messages
|
||||
for message in self.messages:
|
||||
if isinstance(message["content"], str) and message["content"] == "":
|
||||
message["content"] = "(empty)"
|
||||
elif isinstance(message["content"], list) and len(message["content"]) == 0:
|
||||
message["content"] = [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
def get_messages_for_persistent_storage(self):
|
||||
messages = super().get_messages_for_persistent_storage()
|
||||
if self.system:
|
||||
messages.insert(0, {"role": "system", "content": self.system})
|
||||
return messages
|
||||
|
||||
def get_messages_for_logging(self) -> str:
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
@@ -541,8 +649,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
@@ -579,7 +687,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
self._reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -630,5 +738,8 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -16,8 +17,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import boto3
|
||||
@@ -33,7 +33,7 @@ except ModuleNotFoundError as e:
|
||||
class AWSTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
engine: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
@@ -57,28 +57,95 @@ class AWSTTSService(TTSService):
|
||||
aws_secret_access_key=api_key,
|
||||
region_name=region,
|
||||
)
|
||||
self._voice_id = voice_id
|
||||
self._sample_rate = sample_rate
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"engine": params.engine,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self.set_voice(voice_id)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "cmn-CN"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.EN | Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_NZ:
|
||||
return "en-NZ"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
return None
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
ssml = "<speak>"
|
||||
|
||||
if self._params.language:
|
||||
ssml += f"<lang xml:lang='{self._params.language}'>"
|
||||
language = self._settings["language"]
|
||||
ssml += f"<lang xml:lang='{language}'>"
|
||||
|
||||
prosody_attrs = []
|
||||
# Prosody tags are only supported for standard and neural engines
|
||||
if self._params.engine != "generative":
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["engine"] != "generative":
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
@@ -90,41 +157,12 @@ class AWSTTSService(TTSService):
|
||||
if prosody_attrs:
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._params.language:
|
||||
ssml += "</lang>"
|
||||
ssml += "</lang>"
|
||||
|
||||
ssml += "</speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_engine(self, engine: str):
|
||||
logger.debug(f"Switching TTS engine to: [{engine}]")
|
||||
self._params.engine = engine
|
||||
|
||||
async def set_language(self, language: str):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str):
|
||||
logger.debug(f"Switching TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str):
|
||||
logger.debug(f"Switching TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_volume(self, volume: str):
|
||||
logger.debug(f"Switching TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_params(self, params: InputParams):
|
||||
logger.debug(f"Switching TTS params to: [{params}]")
|
||||
self._params = params
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -139,8 +177,8 @@ class AWSTTSService(TTSService):
|
||||
"TextType": "ssml",
|
||||
"OutputFormat": "pcm",
|
||||
"VoiceId": self._voice_id,
|
||||
"Engine": self._params.engine,
|
||||
"SampleRate": str(self._sample_rate),
|
||||
"Engine": self._settings["engine"],
|
||||
"SampleRate": str(self._settings["sample_rate"]),
|
||||
}
|
||||
|
||||
# Filter out None values
|
||||
@@ -150,7 +188,7 @@ class AWSTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
if "AudioStream" in response:
|
||||
with response["AudioStream"] as stream:
|
||||
@@ -160,10 +198,10 @@ class AWSTTSService(TTSService):
|
||||
chunk = audio_data[i : i + chunk_size]
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self._sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except (BotoCoreError, ClientError) as error:
|
||||
logger.exception(f"{self} error generating TTS: {error}")
|
||||
@@ -171,4 +209,4 @@ class AWSTTSService(TTSService):
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
finally:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -4,12 +4,13 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import io
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -26,12 +27,9 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.services.ai_services import ImageGenService, STTService, TTSService
|
||||
from pipecat.services.openai import BaseOpenAILLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Azure configuration needed
|
||||
try:
|
||||
from azure.cognitiveservices.speech import (
|
||||
@@ -76,7 +74,7 @@ class AzureLLMService(BaseOpenAILLMService):
|
||||
class AzureTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
emphasis: Optional[str] = None
|
||||
language: Optional[str] = "en-US"
|
||||
language: Optional[Language] = Language.EN_US
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = "1.05"
|
||||
role: Optional[str] = None
|
||||
@@ -99,114 +97,158 @@ class AzureTTSService(TTSService):
|
||||
speech_config = SpeechConfig(subscription=api_key, region=region)
|
||||
self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
|
||||
|
||||
self._voice = voice
|
||||
self._sample_rate = sample_rate
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"emphasis": params.emphasis,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN_US,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"role": params.role,
|
||||
"style": params.style,
|
||||
"style_degree": params.style_degree,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self.set_voice(voice)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg-BG"
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "zh-CN"
|
||||
case Language.ZH_TW:
|
||||
return "zh-TW"
|
||||
case Language.CS:
|
||||
return "cs-CZ"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.EN | Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_NZ:
|
||||
return "en-NZ"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.ET:
|
||||
return "et-EE"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.DE_CH:
|
||||
return "de-CH"
|
||||
case Language.EL:
|
||||
return "el-GR"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.HU:
|
||||
return "hu-HU"
|
||||
case Language.ID:
|
||||
return "id-ID"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.LV:
|
||||
return "lv-LV"
|
||||
case Language.LT:
|
||||
return "lt-LT"
|
||||
case Language.MS:
|
||||
return "ms-MY"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.SK:
|
||||
return "sk-SK"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TH:
|
||||
return "th-TH"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
case Language.UK:
|
||||
return "uk-UA"
|
||||
case Language.VI:
|
||||
return "vi-VN"
|
||||
return None
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
language = self._settings["language"]
|
||||
ssml = (
|
||||
f"<speak version='1.0' xml:lang='{self._params.language}' "
|
||||
f"<speak version='1.0' xml:lang='{language}' "
|
||||
"xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>"
|
||||
f"<voice name='{self._voice}'>"
|
||||
f"<voice name='{self._voice_id}'>"
|
||||
"<mstts:silence type='Sentenceboundary' value='20ms' />"
|
||||
)
|
||||
|
||||
if self._params.style:
|
||||
ssml += f"<mstts:express-as style='{self._params.style}'"
|
||||
if self._params.style_degree:
|
||||
ssml += f" styledegree='{self._params.style_degree}'"
|
||||
if self._params.role:
|
||||
ssml += f" role='{self._params.role}'"
|
||||
if self._settings["style"]:
|
||||
ssml += f"<mstts:express-as style='{self._settings['style']}'"
|
||||
if self._settings["style_degree"]:
|
||||
ssml += f" styledegree='{self._settings['style_degree']}'"
|
||||
if self._settings["role"]:
|
||||
ssml += f" role='{self._settings['role']}'"
|
||||
ssml += ">"
|
||||
|
||||
prosody_attrs = []
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
if self._params.emphasis:
|
||||
ssml += f"<emphasis level='{self._params.emphasis}'>"
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
if self._params.emphasis:
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._params.style:
|
||||
if self._settings["style"]:
|
||||
ssml += "</mstts:express-as>"
|
||||
|
||||
ssml += "</voice></speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
|
||||
async def set_emphasis(self, emphasis: str):
|
||||
logger.debug(f"Setting TTS emphasis to: [{emphasis}]")
|
||||
self._params.emphasis = emphasis
|
||||
|
||||
async def set_language(self, language: str):
|
||||
logger.debug(f"Setting TTS language code to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str):
|
||||
logger.debug(f"Setting TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str):
|
||||
logger.debug(f"Setting TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_role(self, role: str):
|
||||
logger.debug(f"Setting TTS role to: [{role}]")
|
||||
self._params.role = role
|
||||
|
||||
async def set_style(self, style: str):
|
||||
logger.debug(f"Setting TTS style to: [{style}]")
|
||||
self._params.style = style
|
||||
|
||||
async def set_style_degree(self, style_degree: str):
|
||||
logger.debug(f"Setting TTS style degree to: [{style_degree}]")
|
||||
self._params.style_degree = style_degree
|
||||
|
||||
async def set_volume(self, volume: str):
|
||||
logger.debug(f"Setting TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_params(self, **kwargs):
|
||||
valid_params = {
|
||||
"voice": self.set_voice,
|
||||
"emphasis": self.set_emphasis,
|
||||
"language_code": self.set_language,
|
||||
"pitch": self.set_pitch,
|
||||
"rate": self.set_rate,
|
||||
"role": self.set_role,
|
||||
"style": self.set_style,
|
||||
"style_degree": self.set_style_degree,
|
||||
"volume": self.set_volume,
|
||||
}
|
||||
|
||||
for param, value in kwargs.items():
|
||||
if param in valid_params:
|
||||
await valid_params[param](value)
|
||||
else:
|
||||
logger.warning(f"Ignoring unknown parameter: {param}")
|
||||
|
||||
logger.debug(f"Updated TTS parameters: {', '.join(kwargs.keys())}")
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -219,12 +261,14 @@ class AzureTTSService(TTSService):
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
# Azure always sends a 44-byte header. Strip it off.
|
||||
yield TTSAudioRawFrame(
|
||||
audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1
|
||||
audio=result.audio_data[44:],
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
elif result.reason == ResultReason.Canceled:
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
@@ -238,7 +282,7 @@ class AzureSTTService(STTService):
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
language="en-US",
|
||||
language=Language.EN_US,
|
||||
sample_rate=16000,
|
||||
channels=1,
|
||||
**kwargs,
|
||||
|
||||
190
src/pipecat/services/canonical.py
Normal file
190
src/pipecat/services/canonical.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AIService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Canonical Metrics, you need to `pip install pipecat-ai[canonical]`. "
|
||||
+ "Also, set the `CANONICAL_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
# Multipart upload part size in bytes, cannot be smaller than 5MB
|
||||
PART_SIZE = 1024 * 1024 * 5
|
||||
|
||||
|
||||
class CanonicalMetricsService(AIService):
|
||||
"""Initialize a CanonicalAudioProcessor instance.
|
||||
|
||||
This class uses an AudioBufferProcessor to get the conversation audio and
|
||||
uploads it to Canonical Voice API for audio processing.
|
||||
|
||||
Args:
|
||||
|
||||
call_id (str): Your unique identifier for the call. This is used to match the call in the Canonical Voice system to the call in your system.
|
||||
assistant (str): Identifier for the AI assistant. This can be whatever you want, it's intended for you convenience so you can distinguish
|
||||
between different assistants and a grouping mechanism for calls.
|
||||
assistant_speaks_first (bool, optional): Indicates if the assistant speaks first in the conversation. Defaults to True.
|
||||
output_dir (str, optional): Directory to save temporary audio files. Defaults to "recordings".
|
||||
|
||||
Attributes:
|
||||
call_id (str): Stores the unique call identifier.
|
||||
assistant (str): Stores the assistant identifier.
|
||||
assistant_speaks_first (bool): Indicates whether the assistant speaks first.
|
||||
output_dir (str): Directory path for saving temporary audio files.
|
||||
|
||||
The constructor also ensures that the output directory exists.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
audio_buffer_processor: AudioBufferProcessor,
|
||||
call_id: str,
|
||||
assistant: str,
|
||||
api_key: str,
|
||||
api_url: str = "https://voiceapp.canonical.chat/api/v1",
|
||||
assistant_speaks_first: bool = True,
|
||||
output_dir: str = "recordings",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._audio_buffer_processor = audio_buffer_processor
|
||||
self._api_key = api_key
|
||||
self._api_url = api_url
|
||||
self._call_id = call_id
|
||||
self._assistant = assistant
|
||||
self._assistant_speaks_first = assistant_speaks_first
|
||||
self._output_dir = output_dir
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await self._process_audio()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await self._process_audio()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_audio(self):
|
||||
pipeline = self._audio_buffer_processor
|
||||
if pipeline.has_audio():
|
||||
os.makedirs(self._output_dir, exist_ok=True)
|
||||
filename = self._get_output_filename()
|
||||
wave_data = pipeline.merge_audio_buffers()
|
||||
|
||||
async with aiofiles.open(filename, "wb") as file:
|
||||
await file.write(wave_data)
|
||||
|
||||
try:
|
||||
await self._multipart_upload(filename)
|
||||
pipeline.reset_audio_buffer()
|
||||
await aiofiles.os.remove(filename)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload recording: {e}")
|
||||
|
||||
def _get_output_filename(self):
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return f"{self._output_dir}/{timestamp}-{uuid.uuid4().hex}.wav"
|
||||
|
||||
def _request_headers(self):
|
||||
return {"Content-Type": "application/json", "X-Canonical-Api-Key": self._api_key}
|
||||
|
||||
async def _multipart_upload(self, file_path: str):
|
||||
upload_request, upload_response = await self._request_upload(file_path)
|
||||
if upload_request is None or upload_response is None:
|
||||
return
|
||||
parts = await self._upload_parts(file_path, upload_response)
|
||||
if parts is None:
|
||||
return
|
||||
await self._upload_complete(parts, upload_request, upload_response)
|
||||
|
||||
async def _request_upload(self, file_path: str) -> Tuple[Dict, Dict]:
|
||||
filename = os.path.basename(file_path)
|
||||
filesize = os.path.getsize(file_path)
|
||||
numparts = int((filesize + PART_SIZE - 1) / PART_SIZE)
|
||||
|
||||
params = {
|
||||
"filename": filename,
|
||||
"parts": numparts,
|
||||
"callId": self._call_id,
|
||||
"assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
|
||||
}
|
||||
logger.debug(f"Requesting presigned URLs for {numparts} parts")
|
||||
response = await self._aiohttp_session.post(
|
||||
f"{self._api_url}/recording/uploadRequest", headers=self._request_headers(), json=params
|
||||
)
|
||||
if not response.ok:
|
||||
logger.error(f"Failed to get presigned URLs: {await response.text()}")
|
||||
return None, None
|
||||
response_json = await response.json()
|
||||
return params, response_json
|
||||
|
||||
async def _upload_parts(self, file_path: str, upload_response: Dict) -> List[Dict]:
|
||||
urls = upload_response["urls"]
|
||||
parts = []
|
||||
try:
|
||||
async with aiofiles.open(file_path, "rb") as file:
|
||||
for partnum, upload_url in enumerate(urls, start=1):
|
||||
data = await file.read(PART_SIZE)
|
||||
if not data:
|
||||
break
|
||||
|
||||
response = await self._aiohttp_session.put(upload_url, data=data)
|
||||
if not response.ok:
|
||||
logger.error(f"Failed to upload part {partnum}: {await response.text()}")
|
||||
return None
|
||||
|
||||
etag = response.headers["ETag"]
|
||||
parts.append({"partnum": str(partnum), "etag": etag})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Multipart upload aborted, an error occurred: {str(e)}")
|
||||
return parts
|
||||
|
||||
async def _upload_complete(
|
||||
self, parts: List[Dict], upload_request: Dict, upload_response: Dict
|
||||
):
|
||||
params = {
|
||||
"filename": upload_request["filename"],
|
||||
"parts": parts,
|
||||
"slug": upload_response["slug"],
|
||||
"callId": self._call_id,
|
||||
"assistant": {"id": self._assistant, "speaksFirst": self._assistant_speaks_first},
|
||||
}
|
||||
logger.debug(f"Completing upload for {params['filename']}")
|
||||
logger.debug(f"Slug: {params['slug']}")
|
||||
response = await self._aiohttp_session.post(
|
||||
f"{self._api_url}/recording/uploadComplete",
|
||||
headers=self._request_headers(),
|
||||
json=params,
|
||||
)
|
||||
if not response.ok:
|
||||
logger.error(f"Failed to complete upload: {await response.text()}")
|
||||
return
|
||||
@@ -4,36 +4,35 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from typing import AsyncGenerator, Optional, Union, List
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService, WordTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.services.ai_services import AsyncWordTTSService, TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Cartesia configuration needed
|
||||
try:
|
||||
from cartesia import AsyncCartesia
|
||||
import websockets
|
||||
from cartesia import AsyncCartesia
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -46,27 +45,34 @@ def language_to_cartesia_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.DE:
|
||||
return "de"
|
||||
case Language.EN:
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_GB
|
||||
| Language.EN_AU
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.FR:
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "fr"
|
||||
case Language.JA:
|
||||
return "ja"
|
||||
case Language.PT:
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "pt"
|
||||
case Language.ZH:
|
||||
case Language.ZH | Language.ZH_TW:
|
||||
return "zh"
|
||||
return None
|
||||
|
||||
|
||||
class CartesiaTTSService(AsyncWordTTSService):
|
||||
class CartesiaTTSService(WordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
encoding: Optional[str] = "pcm_s16le"
|
||||
sample_rate: Optional[int] = 16000
|
||||
container: Optional[str] = "raw"
|
||||
language: Optional[str] = "en"
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[Union[str, float]] = ""
|
||||
emotion: Optional[List[str]] = []
|
||||
|
||||
@@ -77,7 +83,7 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
voice_id: str,
|
||||
cartesia_version: str = "2024-06-10",
|
||||
url: str = "wss://api.cartesia.ai/tts/websocket",
|
||||
model_id: str = "sonic-english",
|
||||
model: str = "sonic-english",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
@@ -101,17 +107,20 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
self._url = url
|
||||
self._voice_id = voice_id
|
||||
self._model_id = model_id
|
||||
self.set_model_name(model_id)
|
||||
self._output_format = {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"speed": params.speed,
|
||||
"emotion": params.emotion,
|
||||
}
|
||||
self._language = params.language
|
||||
self._speed = params.speed
|
||||
self._emotion = params.emotion
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
|
||||
self._websocket = None
|
||||
self._context_id = None
|
||||
@@ -125,42 +134,31 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
await super().set_model(model)
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_speed(self, speed: str):
|
||||
logger.debug(f"Switching TTS speed to: [{speed}]")
|
||||
self._speed = speed
|
||||
|
||||
async def set_emotion(self, emotion: list[str]):
|
||||
logger.debug(f"Switching TTS emotion to: [{emotion}]")
|
||||
self._emotion = emotion
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._language = language_to_cartesia_language(language)
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
return language_to_cartesia_language(language)
|
||||
|
||||
def _build_msg(
|
||||
self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True
|
||||
):
|
||||
voice_config = {"mode": "id", "id": self._voice_id}
|
||||
voice_config = {}
|
||||
voice_config["mode"] = "id"
|
||||
voice_config["id"] = self._voice_id
|
||||
|
||||
if self._speed or self._emotion:
|
||||
if self._settings["speed"] or self._settings["emotion"]:
|
||||
voice_config["__experimental_controls"] = {}
|
||||
if self._speed:
|
||||
voice_config["__experimental_controls"]["speed"] = self._speed
|
||||
if self._emotion:
|
||||
voice_config["__experimental_controls"]["emotion"] = self._emotion
|
||||
if self._settings["speed"]:
|
||||
voice_config["__experimental_controls"]["speed"] = self._settings["speed"]
|
||||
if self._settings["emotion"]:
|
||||
voice_config["__experimental_controls"]["emotion"] = self._settings["emotion"]
|
||||
|
||||
msg = {
|
||||
"transcript": text,
|
||||
"transcript": text or " ", # Text must contain at least one character
|
||||
"continue": continue_transcript,
|
||||
"context_id": self._context_id,
|
||||
"model_id": self._model_name,
|
||||
"model_id": self.model_name,
|
||||
"voice": voice_config,
|
||||
"output_format": self._output_format,
|
||||
"language": self._language,
|
||||
"output_format": self._settings["output_format"],
|
||||
"language": self._settings["language"],
|
||||
"add_timestamps": add_timestamps,
|
||||
}
|
||||
return json.dumps(msg)
|
||||
@@ -212,7 +210,6 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._context_id = None
|
||||
|
||||
async def flush_audio(self):
|
||||
@@ -230,12 +227,13 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
continue
|
||||
if msg["type"] == "done":
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
# Unset _context_id but not the _context_id_start_timestamp
|
||||
# because we are likely still playing out audio and need the
|
||||
# timestamp to set send context frames.
|
||||
self._context_id = None
|
||||
await self.add_word_timestamps([("LLMFullResponseEndFrame", 0)])
|
||||
await self.add_word_timestamps(
|
||||
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0)]
|
||||
)
|
||||
elif msg["type"] == "timestamps":
|
||||
await self.add_word_timestamps(
|
||||
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]))
|
||||
@@ -245,7 +243,7 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
@@ -269,18 +267,18 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
await self._connect()
|
||||
|
||||
if not self._context_id:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._context_id = str(uuid.uuid4())
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
msg = self._build_msg(text=text or " ") # Text must contain at least one character
|
||||
|
||||
try:
|
||||
await self._get_websocket().send(msg)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
@@ -294,7 +292,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
encoding: Optional[str] = "pcm_s16le"
|
||||
sample_rate: Optional[int] = 16000
|
||||
container: Optional[str] = "raw"
|
||||
language: Optional[str] = "en"
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[Union[str, float]] = ""
|
||||
emotion: Optional[List[str]] = []
|
||||
|
||||
@@ -303,7 +301,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
model_id: str = "sonic-english",
|
||||
model: str = "sonic-english",
|
||||
base_url: str = "https://api.cartesia.ai",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
@@ -311,43 +309,28 @@ class CartesiaHttpTTSService(TTSService):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._model_id = model_id
|
||||
self.set_model_name(model_id)
|
||||
self._output_format = {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"speed": params.speed,
|
||||
"emotion": params.emotion,
|
||||
}
|
||||
self._language = params.language
|
||||
self._speed = params.speed
|
||||
self._emotion = params.emotion
|
||||
self.set_voice(voice_id)
|
||||
self.set_model_name(model)
|
||||
|
||||
self._client = AsyncCartesia(api_key=api_key, base_url=base_url)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
self._model_id = model
|
||||
await super().set_model(model)
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_speed(self, speed: str):
|
||||
logger.debug(f"Switching TTS speed to: [{speed}]")
|
||||
self._speed = speed
|
||||
|
||||
async def set_emotion(self, emotion: list[str]):
|
||||
logger.debug(f"Switching TTS emotion to: [{emotion}]")
|
||||
self._emotion = emotion
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._language = language_to_cartesia_language(language)
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
return language_to_cartesia_language(language)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
@@ -360,24 +343,24 @@ class CartesiaHttpTTSService(TTSService):
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
try:
|
||||
voice_controls = None
|
||||
if self._speed or self._emotion:
|
||||
if self._settings["speed"] or self._settings["emotion"]:
|
||||
voice_controls = {}
|
||||
if self._speed:
|
||||
voice_controls["speed"] = self._speed
|
||||
if self._emotion:
|
||||
voice_controls["emotion"] = self._emotion
|
||||
if self._settings["speed"]:
|
||||
voice_controls["speed"] = self._settings["speed"]
|
||||
if self._settings["emotion"]:
|
||||
voice_controls["emotion"] = self._settings["emotion"]
|
||||
|
||||
output = await self._client.tts.sse(
|
||||
model_id=self._model_id,
|
||||
model_id=self._model_name,
|
||||
transcript=text,
|
||||
voice_id=self._voice_id,
|
||||
output_format=self._output_format,
|
||||
language=self._language,
|
||||
output_format=self._settings["output_format"],
|
||||
language=self._settings["language"],
|
||||
stream=False,
|
||||
_experimental_voice_controls=voice_controls,
|
||||
)
|
||||
@@ -386,7 +369,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=output["audio"],
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
@@ -394,4 +377,4 @@ class CartesiaHttpTTSService(TTSService):
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,9 +5,10 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -24,8 +25,6 @@ from pipecat.services.ai_services import STTService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Deepgram configuration needed
|
||||
try:
|
||||
from deepgram import (
|
||||
@@ -36,6 +35,7 @@ try:
|
||||
LiveResultResponse,
|
||||
LiveTranscriptionEvents,
|
||||
SpeakOptions,
|
||||
logging,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -57,25 +57,23 @@ class DeepgramTTSService(TTSService):
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice = voice
|
||||
self._sample_rate = sample_rate
|
||||
self._encoding = encoding
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"encoding": encoding,
|
||||
}
|
||||
self.set_voice(voice)
|
||||
self._deepgram_client = DeepgramClient(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
options = SpeakOptions(
|
||||
model=self._voice,
|
||||
encoding=self._encoding,
|
||||
sample_rate=self._sample_rate,
|
||||
model=self._voice_id,
|
||||
encoding=self._settings["encoding"],
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
container="none",
|
||||
)
|
||||
|
||||
@@ -87,7 +85,7 @@ class DeepgramTTSService(TTSService):
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# The response.stream_memory is already a BytesIO object
|
||||
audio_buffer = response.stream_memory
|
||||
@@ -103,10 +101,12 @@ class DeepgramTTSService(TTSService):
|
||||
chunk = audio_buffer.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
frame = TTSAudioRawFrame(audio=chunk, sample_rate=self._sample_rate, num_channels=1)
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=chunk, sample_rate=self._settings["sample_rate"], num_channels=1
|
||||
)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
@@ -119,9 +119,13 @@ class DeepgramSTTService(STTService):
|
||||
*,
|
||||
api_key: str,
|
||||
url: str = "",
|
||||
live_options: LiveOptions = LiveOptions(
|
||||
live_options: LiveOptions = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
default_options = LiveOptions(
|
||||
encoding="linear16",
|
||||
language="en-US",
|
||||
language=Language.EN,
|
||||
model="nova-2-conversationalai",
|
||||
sample_rate=16000,
|
||||
channels=1,
|
||||
@@ -130,15 +134,19 @@ class DeepgramSTTService(STTService):
|
||||
punctuate=True,
|
||||
profanity_filter=True,
|
||||
vad_events=False,
|
||||
),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
)
|
||||
|
||||
self._live_options = live_options
|
||||
merged_options = default_options
|
||||
if live_options:
|
||||
merged_options = LiveOptions(**{**default_options.to_dict(), **live_options.to_dict()})
|
||||
self._settings = merged_options.to_dict()
|
||||
|
||||
self._client = DeepgramClient(
|
||||
api_key, config=DeepgramClientOptions(url=url, options={"keepalive": "true"})
|
||||
api_key,
|
||||
config=DeepgramClientOptions(
|
||||
url=url,
|
||||
options={"keepalive": "true"}, # verbose=logging.DEBUG
|
||||
),
|
||||
)
|
||||
self._connection: AsyncListenWebSocketClient = self._client.listen.asyncwebsocket.v("1")
|
||||
self._connection.on(LiveTranscriptionEvents.Transcript, self._on_message)
|
||||
@@ -147,7 +155,7 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
@property
|
||||
def vad_enabled(self):
|
||||
return self._live_options.vad_events
|
||||
return self._settings["vad_events"]
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return self.vad_enabled
|
||||
@@ -155,13 +163,13 @@ class DeepgramSTTService(STTService):
|
||||
async def set_model(self, model: str):
|
||||
await super().set_model(model)
|
||||
logger.debug(f"Switching STT model to: [{model}]")
|
||||
self._live_options.model = model
|
||||
self._settings["model"] = model
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching STT language to: [{language}]")
|
||||
self._live_options.language = language
|
||||
self._settings["language"] = language
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
@@ -182,7 +190,7 @@ class DeepgramSTTService(STTService):
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
if await self._connection.start(self._live_options):
|
||||
if await self._connection.start(self._settings):
|
||||
logger.debug(f"{self}: Connected to Deepgram")
|
||||
else:
|
||||
logger.error(f"{self}: Unable to connect to Deepgram")
|
||||
|
||||
@@ -23,7 +23,8 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AsyncWordTTSService
|
||||
from pipecat.services.ai_services import WordTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for ElevenLabs configuration needed
|
||||
try:
|
||||
@@ -70,9 +71,9 @@ def calculate_word_times(
|
||||
return word_times
|
||||
|
||||
|
||||
class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
class ElevenLabsTTSService(WordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
output_format: Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"] = "pcm_16000"
|
||||
optimize_streaming_latency: Optional[str] = None
|
||||
stability: Optional[float] = None
|
||||
@@ -124,10 +125,21 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self.set_model_name(model)
|
||||
self._url = url
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate_from_output_format(params.output_format),
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"output_format": params.output_format,
|
||||
"optimize_streaming_latency": params.optimize_streaming_latency,
|
||||
"stability": params.stability,
|
||||
"similarity_boost": params.similarity_boost,
|
||||
"style": params.style,
|
||||
"use_speaker_boost": params.use_speaker_boost,
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
self._voice_settings = self._set_voice_settings()
|
||||
|
||||
# Websocket connection to ElevenLabs.
|
||||
@@ -140,21 +152,93 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg"
|
||||
case Language.ZH:
|
||||
return "zh"
|
||||
case Language.CS:
|
||||
return "cs"
|
||||
case Language.DA:
|
||||
return "da"
|
||||
case Language.NL:
|
||||
return "nl"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.FI:
|
||||
return "fi"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "fr"
|
||||
case Language.DE | Language.DE_CH:
|
||||
return "de"
|
||||
case Language.EL:
|
||||
return "el"
|
||||
case Language.HI:
|
||||
return "hi"
|
||||
case Language.HU:
|
||||
return "hu"
|
||||
case Language.ID:
|
||||
return "id"
|
||||
case Language.IT:
|
||||
return "it"
|
||||
case Language.JA:
|
||||
return "ja"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
case Language.MS:
|
||||
return "ms"
|
||||
case Language.NO:
|
||||
return "no"
|
||||
case Language.PL:
|
||||
return "pl"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro"
|
||||
case Language.RU:
|
||||
return "ru"
|
||||
case Language.SK:
|
||||
return "sk"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.SV:
|
||||
return "sv"
|
||||
case Language.TR:
|
||||
return "tr"
|
||||
case Language.UK:
|
||||
return "uk"
|
||||
case Language.VI:
|
||||
return "vi"
|
||||
return None
|
||||
|
||||
def _set_voice_settings(self):
|
||||
voice_settings = {}
|
||||
if self._params.stability is not None and self._params.similarity_boost is not None:
|
||||
voice_settings["stability"] = self._params.stability
|
||||
voice_settings["similarity_boost"] = self._params.similarity_boost
|
||||
if self._params.style is not None:
|
||||
voice_settings["style"] = self._params.style
|
||||
if self._params.use_speaker_boost is not None:
|
||||
voice_settings["use_speaker_boost"] = self._params.use_speaker_boost
|
||||
if (
|
||||
self._settings["stability"] is not None
|
||||
and self._settings["similarity_boost"] is not None
|
||||
):
|
||||
voice_settings["stability"] = self._settings["stability"]
|
||||
voice_settings["similarity_boost"] = self._settings["similarity_boost"]
|
||||
if self._settings["style"] is not None:
|
||||
voice_settings["style"] = self._settings["style"]
|
||||
if self._settings["use_speaker_boost"] is not None:
|
||||
voice_settings["use_speaker_boost"] = self._settings["use_speaker_boost"]
|
||||
else:
|
||||
if self._params.style is not None:
|
||||
if self._settings["style"] is not None:
|
||||
logger.warning(
|
||||
"'style' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
if self._params.use_speaker_boost is not None:
|
||||
if self._settings["use_speaker_boost"] is not None:
|
||||
logger.warning(
|
||||
"'use_speaker_boost' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
@@ -167,33 +251,13 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_voice_settings(
|
||||
self,
|
||||
stability: Optional[float] = None,
|
||||
similarity_boost: Optional[float] = None,
|
||||
style: Optional[float] = None,
|
||||
use_speaker_boost: Optional[bool] = None,
|
||||
):
|
||||
self._params.stability = stability if stability is not None else self._params.stability
|
||||
self._params.similarity_boost = (
|
||||
similarity_boost if similarity_boost is not None else self._params.similarity_boost
|
||||
)
|
||||
self._params.style = style if style is not None else self._params.style
|
||||
self._params.use_speaker_boost = (
|
||||
use_speaker_boost if use_speaker_boost is not None else self._params.use_speaker_boost
|
||||
)
|
||||
|
||||
self._set_voice_settings()
|
||||
|
||||
if self._websocket:
|
||||
msg = {"voice_settings": self._voice_settings}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
prev_voice = self._voice_id
|
||||
await super()._update_settings(settings)
|
||||
if not prev_voice == self._voice_id:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
logger.debug(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
@@ -223,20 +287,20 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
try:
|
||||
voice_id = self._voice_id
|
||||
model = self.model_name
|
||||
output_format = self._params.output_format
|
||||
output_format = self._settings["output_format"]
|
||||
url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}"
|
||||
|
||||
if self._params.optimize_streaming_latency:
|
||||
url += f"&optimize_streaming_latency={self._params.optimize_streaming_latency}"
|
||||
if self._settings["optimize_streaming_latency"]:
|
||||
url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}"
|
||||
|
||||
# language can only be used with the 'eleven_turbo_v2_5' model
|
||||
if self._params.language:
|
||||
if model == "eleven_turbo_v2_5":
|
||||
url += f"&language_code={self._params.language}"
|
||||
else:
|
||||
logger.debug(
|
||||
f"Language code [{self._params.language}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model."
|
||||
)
|
||||
# Language can only be used with the 'eleven_turbo_v2_5' model
|
||||
language = self._settings["language"]
|
||||
if model == "eleven_turbo_v2_5":
|
||||
url += f"&language_code={language}"
|
||||
else:
|
||||
logger.debug(
|
||||
f"Language code [{language}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model."
|
||||
)
|
||||
|
||||
self._websocket = await websockets.connect(url)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
@@ -286,7 +350,7 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
self.start_word_timestamps()
|
||||
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
|
||||
await self.push_frame(frame)
|
||||
|
||||
if msg.get("alignment"):
|
||||
@@ -322,8 +386,8 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
|
||||
try:
|
||||
if not self._started:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
self._cumulative_time = 0
|
||||
|
||||
@@ -331,7 +395,7 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
|
||||
@@ -6,8 +6,9 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -19,10 +20,9 @@ from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Gladia configuration needed
|
||||
try:
|
||||
import websockets
|
||||
@@ -37,7 +37,7 @@ except ModuleNotFoundError as e:
|
||||
class GladiaSTTService(STTService):
|
||||
class InputParams(BaseModel):
|
||||
sample_rate: Optional[int] = 16000
|
||||
language: Optional[str] = "english"
|
||||
language: Optional[Language] = Language.EN
|
||||
transcription_hint: Optional[str] = None
|
||||
endpointing: Optional[int] = 200
|
||||
prosody: Optional[bool] = None
|
||||
@@ -51,13 +51,98 @@ class GladiaSTTService(STTService):
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": params.sample_rate,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"transcription_hint": params.transcription_hint,
|
||||
"endpointing": params.endpointing,
|
||||
"prosody": params.prosody,
|
||||
}
|
||||
self._confidence = confidence
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bulgarian"
|
||||
case Language.CA:
|
||||
return "catalan"
|
||||
case Language.ZH:
|
||||
return "chinese"
|
||||
case Language.CS:
|
||||
return "czech"
|
||||
case Language.DA:
|
||||
return "danish"
|
||||
case Language.NL:
|
||||
return "dutch"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "english"
|
||||
case Language.ET:
|
||||
return "estonian"
|
||||
case Language.FI:
|
||||
return "finnish"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "french"
|
||||
case Language.DE | Language.DE_CH:
|
||||
return "german"
|
||||
case Language.EL:
|
||||
return "greek"
|
||||
case Language.HI:
|
||||
return "hindi"
|
||||
case Language.HU:
|
||||
return "hungarian"
|
||||
case Language.ID:
|
||||
return "indonesian"
|
||||
case Language.IT:
|
||||
return "italian"
|
||||
case Language.JA:
|
||||
return "japanese"
|
||||
case Language.KO:
|
||||
return "korean"
|
||||
case Language.LV:
|
||||
return "latvian"
|
||||
case Language.LT:
|
||||
return "lithuanian"
|
||||
case Language.MS:
|
||||
return "malay"
|
||||
case Language.NO:
|
||||
return "norwegian"
|
||||
case Language.PL:
|
||||
return "polish"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "portuguese"
|
||||
case Language.RO:
|
||||
return "romanian"
|
||||
case Language.RU:
|
||||
return "russian"
|
||||
case Language.SK:
|
||||
return "slovak"
|
||||
case Language.ES:
|
||||
return "spanish"
|
||||
case Language.SV:
|
||||
return "slovenian"
|
||||
case Language.TH:
|
||||
return "thai"
|
||||
case Language.TR:
|
||||
return "turkish"
|
||||
case Language.UK:
|
||||
return "ukrainian"
|
||||
case Language.VI:
|
||||
return "vietnamese"
|
||||
return None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._websocket = await websockets.connect(self._url)
|
||||
@@ -84,7 +169,11 @@ class GladiaSTTService(STTService):
|
||||
"encoding": "WAV/PCM",
|
||||
"model_type": "fast",
|
||||
"language_behaviour": "manual",
|
||||
**self._params.model_dump(exclude_none=True),
|
||||
"sample_rate": self._settings["sample_rate"],
|
||||
"language": self._settings["language"],
|
||||
"transcription_hint": self._settings["transcription_hint"],
|
||||
"endpointing": self._settings["endpointing"],
|
||||
"prosody": self._settings["prosody"],
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(configuration))
|
||||
|
||||
@@ -30,6 +30,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import google.ai.generativelanguage as glm
|
||||
@@ -39,7 +40,7 @@ try:
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_API_KEY` environment variable."
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set the environment variable GOOGLE_API_KEY for the GoogleLLMService and GOOGLE_APPLICATION_CREDENTIALS for the GoogleTTSService`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
@@ -137,9 +138,7 @@ class GoogleLLMService(LLMService):
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -153,7 +152,7 @@ class GoogleTTSService(TTSService):
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
emphasis: Optional[Literal["strong", "moderate", "reduced", "none"]] = None
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
gender: Optional[Literal["male", "female", "neutral"]] = None
|
||||
google_style: Optional[Literal["apologetic", "calm", "empathetic", "firm", "lively"]] = None
|
||||
|
||||
@@ -169,8 +168,19 @@ class GoogleTTSService(TTSService):
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._voice_id: str = voice_id
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"volume": params.volume,
|
||||
"emphasis": params.emphasis,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"gender": params.gender,
|
||||
"google_style": params.google_style,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
|
||||
credentials, credentials_path
|
||||
)
|
||||
@@ -190,51 +200,135 @@ class GoogleTTSService(TTSService):
|
||||
elif credentials_path:
|
||||
# Use service account JSON file if provided
|
||||
creds = service_account.Credentials.from_service_account_file(credentials_path)
|
||||
else:
|
||||
raise ValueError("Either 'credentials' or 'credentials_path' must be provided.")
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg-BG"
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "cmn-CN"
|
||||
case Language.ZH_TW:
|
||||
return "cmn-TW"
|
||||
case Language.CS:
|
||||
return "cs-CZ"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.EN | Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.ET:
|
||||
return "et-EE"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.EL:
|
||||
return "el-GR"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.HU:
|
||||
return "hu-HU"
|
||||
case Language.ID:
|
||||
return "id-ID"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.LV:
|
||||
return "lv-LV"
|
||||
case Language.LT:
|
||||
return "lt-LT"
|
||||
case Language.MS:
|
||||
return "ms-MY"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.SK:
|
||||
return "sk-SK"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TH:
|
||||
return "th-TH"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
case Language.UK:
|
||||
return "uk-UA"
|
||||
case Language.VI:
|
||||
return "vi-VN"
|
||||
return None
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
ssml = "<speak>"
|
||||
|
||||
# Voice tag
|
||||
voice_attrs = [f"name='{self._voice_id}'"]
|
||||
if self._params.language:
|
||||
voice_attrs.append(f"language='{self._params.language}'")
|
||||
if self._params.gender:
|
||||
voice_attrs.append(f"gender='{self._params.gender}'")
|
||||
|
||||
language = self._settings["language"]
|
||||
voice_attrs.append(f"language='{language}'")
|
||||
|
||||
if self._settings["gender"]:
|
||||
voice_attrs.append(f"gender='{self._settings['gender']}'")
|
||||
ssml += f"<voice {' '.join(voice_attrs)}>"
|
||||
|
||||
# Prosody tag
|
||||
prosody_attrs = []
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
# Emphasis tag
|
||||
if self._params.emphasis:
|
||||
ssml += f"<emphasis level='{self._params.emphasis}'>"
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
# Google style tag
|
||||
if self._params.google_style:
|
||||
ssml += f"<google:style name='{self._params.google_style}'>"
|
||||
if self._settings["google_style"]:
|
||||
ssml += f"<google:style name='{self._settings['google_style']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
# Close tags
|
||||
if self._params.google_style:
|
||||
if self._settings["google_style"]:
|
||||
ssml += "</google:style>"
|
||||
if self._params.emphasis:
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
if prosody_attrs:
|
||||
ssml += "</prosody>"
|
||||
@@ -242,46 +336,6 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str) -> None:
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_language(self, language: str) -> None:
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str) -> None:
|
||||
logger.debug(f"Switching TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str) -> None:
|
||||
logger.debug(f"Switching TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_volume(self, volume: str) -> None:
|
||||
logger.debug(f"Switching TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_emphasis(
|
||||
self, emphasis: Literal["strong", "moderate", "reduced", "none"]
|
||||
) -> None:
|
||||
logger.debug(f"Switching TTS emphasis to: [{emphasis}]")
|
||||
self._params.emphasis = emphasis
|
||||
|
||||
async def set_gender(self, gender: Literal["male", "female", "neutral"]) -> None:
|
||||
logger.debug(f"Switch TTS gender to [{gender}]")
|
||||
self._params.gender = gender
|
||||
|
||||
async def google_style(
|
||||
self, google_style: Literal["apologetic", "calm", "empathetic", "firm", "lively"]
|
||||
) -> None:
|
||||
logger.debug(f"Switching TTS google style to: [{google_style}]")
|
||||
self._params.google_style = google_style
|
||||
|
||||
async def set_params(self, params: InputParams) -> None:
|
||||
logger.debug(f"Switching TTS params to: [{params}]")
|
||||
self._params = params
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -291,11 +345,11 @@ class GoogleTTSService(TTSService):
|
||||
ssml = self._construct_ssml(text)
|
||||
synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml)
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._params.language, name=self._voice_id
|
||||
language_code=self._settings["language"], name=self._voice_id
|
||||
)
|
||||
audio_config = texttospeech_v1.AudioConfig(
|
||||
audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
sample_rate_hertz=self._settings["sample_rate"],
|
||||
)
|
||||
|
||||
request = texttospeech_v1.SynthesizeSpeechRequest(
|
||||
@@ -306,7 +360,7 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# Skip the first 44 bytes to remove the WAV header
|
||||
audio_content = response.audio_content[44:]
|
||||
@@ -318,15 +372,15 @@ class GoogleTTSService(TTSService):
|
||||
if not chunk:
|
||||
break
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await asyncio.sleep(0) # Allow other tasks to run
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
finally:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -20,9 +20,9 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import AsyncTTSService
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for LMNT configuration needed
|
||||
try:
|
||||
@@ -35,28 +35,31 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class LmntTTSService(AsyncTTSService):
|
||||
class LmntTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
sample_rate: int = 24000,
|
||||
language: str = "en",
|
||||
language: Language = Language.EN,
|
||||
**kwargs,
|
||||
):
|
||||
# Let TTSService produce TTSStoppedFrames after a short delay of
|
||||
# no activity.
|
||||
super().__init__(sync=False, push_stop_frames=True, sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._output_format = {
|
||||
"container": "raw",
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate": sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": "raw",
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate": sample_rate,
|
||||
},
|
||||
"language": self.language_to_service_language(language),
|
||||
}
|
||||
self._language = language
|
||||
|
||||
self.set_voice(voice_id)
|
||||
|
||||
self._speech = None
|
||||
self._connection = None
|
||||
@@ -68,9 +71,30 @@ class LmntTTSService(AsyncTTSService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.DE:
|
||||
return "de"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "fr"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "pt"
|
||||
case Language.ZH | Language.ZH_TW:
|
||||
return "zh"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
return None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
@@ -93,7 +117,10 @@ class LmntTTSService(AsyncTTSService):
|
||||
try:
|
||||
self._speech = Speech()
|
||||
self._connection = await self._speech.synthesize_streaming(
|
||||
self._voice_id, format="raw", sample_rate=self._output_format["sample_rate"]
|
||||
self._voice_id,
|
||||
format="raw",
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
language=self._settings["language"],
|
||||
)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
@@ -130,7 +157,7 @@ class LmntTTSService(AsyncTTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=msg["audio"],
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
@@ -149,8 +176,8 @@ class LmntTTSService(AsyncTTSService):
|
||||
await self._connect()
|
||||
|
||||
if not self._started:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
|
||||
try:
|
||||
@@ -159,7 +186,7 @@ class LmntTTSService(AsyncTTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
|
||||
@@ -31,6 +31,8 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
URLImageRawFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
@@ -61,6 +63,7 @@ except ModuleNotFoundError as e:
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
|
||||
|
||||
VALID_VOICES: Dict[str, ValidVoice] = {
|
||||
@@ -109,14 +112,16 @@ class BaseOpenAILLMService(LLMService):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._settings = {
|
||||
"frequency_penalty": params.frequency_penalty,
|
||||
"presence_penalty": params.presence_penalty,
|
||||
"seed": params.seed,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
|
||||
self._frequency_penalty = params.frequency_penalty
|
||||
self._presence_penalty = params.presence_penalty
|
||||
self._seed = params.seed
|
||||
self._temperature = params.temperature
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
return AsyncOpenAI(
|
||||
@@ -132,30 +137,6 @@ class BaseOpenAILLMService(LLMService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_frequency_penalty(self, frequency_penalty: float):
|
||||
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
|
||||
self._frequency_penalty = frequency_penalty
|
||||
|
||||
async def set_presence_penalty(self, presence_penalty: float):
|
||||
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
|
||||
self._presence_penalty = presence_penalty
|
||||
|
||||
async def set_seed(self, seed: int):
|
||||
logger.debug(f"Switching LLM seed to: [{seed}]")
|
||||
self._seed = seed
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def get_chat_completions(
|
||||
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
@@ -166,14 +147,14 @@ class BaseOpenAILLMService(LLMService):
|
||||
"tools": context.tools,
|
||||
"tool_choice": context.tool_choice,
|
||||
"stream_options": {"include_usage": True},
|
||||
"frequency_penalty": self._frequency_penalty,
|
||||
"presence_penalty": self._presence_penalty,
|
||||
"seed": self._seed,
|
||||
"temperature": self._temperature,
|
||||
"top_p": self._top_p,
|
||||
"frequency_penalty": self._settings["frequency_penalty"],
|
||||
"presence_penalty": self._settings["presence_penalty"],
|
||||
"seed": self._settings["seed"],
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
chunks = await self._client.chat.completions.create(**params)
|
||||
return chunks
|
||||
@@ -181,7 +162,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
async def _stream_chat_completions(
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
logger.debug(f"Generating chat: {context.get_messages_json()}")
|
||||
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
|
||||
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
|
||||
@@ -205,6 +186,10 @@ class BaseOpenAILLMService(LLMService):
|
||||
return chunks
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
@@ -242,6 +227,14 @@ class BaseOpenAILLMService(LLMService):
|
||||
# yield a frame containing the function name and the arguments.
|
||||
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
@@ -257,38 +250,28 @@ class BaseOpenAILLMService(LLMService):
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
if self.has_function(function_name):
|
||||
await self._handle_function_call(context, tool_call_id, function_name, arguments)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
async def _handle_function_call(self, context, tool_call_id, function_name, arguments):
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=tool_call_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.frequency_penalty is not None:
|
||||
await self.set_frequency_penalty(frame.frequency_penalty)
|
||||
if frame.presence_penalty is not None:
|
||||
await self.set_presence_penalty(frame.presence_penalty)
|
||||
if frame.seed is not None:
|
||||
await self.set_seed(frame.seed)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list), start=1
|
||||
):
|
||||
if self.has_function(function_name):
|
||||
run_llm = False
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -301,7 +284,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -336,9 +319,13 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
super().__init__(model=model, params=params, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
user = OpenAIUserContextAggregator(context)
|
||||
assistant = OpenAIAssistantContextAggregator(user)
|
||||
assistant = OpenAIAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
@@ -401,22 +388,20 @@ class OpenAITTSService(TTSService):
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._voice: ValidVoice = VALID_VOICES.get(voice, "alloy")
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self._sample_rate = sample_rate
|
||||
self.set_voice(voice)
|
||||
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = VALID_VOICES.get(voice, self._voice)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
self._model = model
|
||||
self.set_model_name(model)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
@@ -424,9 +409,9 @@ class OpenAITTSService(TTSService):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
async with self._client.audio.speech.with_streaming_response.create(
|
||||
input=text,
|
||||
input=text or " ", # Text must contain at least one character
|
||||
model=self.model_name,
|
||||
voice=self._voice,
|
||||
voice=VALID_VOICES[self._voice_id],
|
||||
response_format="pcm",
|
||||
) as r:
|
||||
if r.status_code != 200:
|
||||
@@ -441,61 +426,104 @@ class OpenAITTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in r.iter_bytes(8192):
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
except BadRequestError as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
|
||||
# internal use only -- todo: refactor
|
||||
@dataclass
|
||||
class OpenAIImageMessageFrame(Frame):
|
||||
user_image_raw_frame: UserImageRawFrame
|
||||
text: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
super().__init__(context=context)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
# flow here and we don't need to call push_frame() ourselves.
|
||||
try:
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
|
||||
# that frame so we can use it when we assemble the image message in the assistant
|
||||
# context aggregator.
|
||||
if frame.context:
|
||||
if isinstance(frame.context, str):
|
||||
self._context._user_image_request_context[frame.user_id] = frame.context
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
|
||||
)
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
else:
|
||||
if frame.user_id in self._context._user_image_request_context:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
elif isinstance(frame, UserImageRawFrame):
|
||||
# Push a new OpenAIImageMessageFrame with the text context we cached
|
||||
# downstream to be handled by our assistant context aggregator. This is
|
||||
# necessary so that we add the message to the context in the right order.
|
||||
text = self._context._user_image_request_context.get(frame.user_id) or ""
|
||||
if text:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
frame = OpenAIImageMessageFrame(user_image_raw_frame=frame, text=text)
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_calls_in_progress = {}
|
||||
self._function_call_result = None
|
||||
self._pending_image_frame_message = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# See note above about not calling push_frame() here.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._function_call_in_progress = None
|
||||
self._function_calls_in_progress.clear()
|
||||
self._function_call_finished = None
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
self._function_call_in_progress = frame
|
||||
logger.debug(f"FunctionCallInProgressFrame: {frame}")
|
||||
self._function_calls_in_progress[frame.tool_call_id] = frame
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
if (
|
||||
self._function_call_in_progress
|
||||
and self._function_call_in_progress.tool_call_id == frame.tool_call_id
|
||||
):
|
||||
self._function_call_in_progress = None
|
||||
logger.debug(f"FunctionCallResultFrame: {frame}")
|
||||
if frame.tool_call_id in self._function_calls_in_progress:
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
self._function_call_result = frame
|
||||
# TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
|
||||
await self._push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
|
||||
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
|
||||
)
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
elif isinstance(frame, OpenAIImageMessageFrame):
|
||||
self._pending_image_frame_message = frame
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if not (self._aggregation or self._function_call_result):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
return
|
||||
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
self._reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -524,12 +552,27 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
run_llm = True
|
||||
# Only run the LLM if there are no more function calls in progress.
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._pending_image_frame_message:
|
||||
frame = self._pending_image_frame_message
|
||||
self._pending_image_frame_message = None
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.user_image_raw_frame.format,
|
||||
size=frame.user_image_raw_frame.size,
|
||||
image=frame.user_image_raw_frame.image,
|
||||
text=frame.text,
|
||||
)
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
2
src/pipecat/services/openai_realtime_beta/__init__.py
Normal file
2
src/pipecat/services/openai_realtime_beta/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .events import InputAudioTranscription, SessionProperties, TurnDetection
|
||||
from .llm_and_context import OpenAILLMServiceRealtimeBeta
|
||||
433
src/pipecat/services/openai_realtime_beta/events.py
Normal file
433
src/pipecat/services/openai_realtime_beta/events.py
Normal file
@@ -0,0 +1,433 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
#
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
#
|
||||
# session properties
|
||||
#
|
||||
|
||||
|
||||
class InputAudioTranscription(BaseModel):
|
||||
model: Optional[str] = "whisper-1"
|
||||
|
||||
|
||||
class TurnDetection(BaseModel):
|
||||
type: Optional[Literal["server_vad"]] = "server_vad"
|
||||
threshold: Optional[float] = 0.5
|
||||
prefix_padding_ms: Optional[int] = 300
|
||||
silence_duration_ms: Optional[int] = 800
|
||||
|
||||
|
||||
class SessionProperties(BaseModel):
|
||||
modalities: Optional[List[Literal["text", "audio"]]] = None
|
||||
instructions: Optional[str] = None
|
||||
voice: Optional[str] = None
|
||||
input_audio_format: Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]] = None
|
||||
output_audio_format: Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]] = None
|
||||
input_audio_transcription: Optional[InputAudioTranscription] = None
|
||||
# set turn_detection to False to disable turn detection
|
||||
turn_detection: Optional[Union[TurnDetection, bool]] = Field(default=None)
|
||||
tools: Optional[List[Dict]] = None
|
||||
tool_choice: Optional[Literal["auto", "none", "required"]] = None
|
||||
temperature: Optional[float] = None
|
||||
max_response_output_tokens: Optional[Union[int, Literal["inf"]]] = None
|
||||
|
||||
|
||||
#
|
||||
# context
|
||||
#
|
||||
|
||||
|
||||
class ItemContent(BaseModel):
|
||||
type: Literal["text", "audio", "input_text", "input_audio"]
|
||||
text: Optional[str] = None
|
||||
audio: Optional[str] = None # base64-encoded audio
|
||||
transcript: Optional[str] = None
|
||||
|
||||
|
||||
class ConversationItem(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4().hex))
|
||||
object: Optional[Literal["realtime.item"]] = None
|
||||
type: Literal["message", "function_call", "function_call_output"]
|
||||
status: Optional[Literal["completed", "in_progress", "incomplete"]] = None
|
||||
# role and content are present for message items
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[List[ItemContent]] = None
|
||||
# these four fields are present for function_call items
|
||||
call_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
output: Optional[str] = None
|
||||
|
||||
|
||||
class RealtimeConversation(BaseModel):
|
||||
id: str
|
||||
object: Literal["realtime.conversation"]
|
||||
|
||||
|
||||
class ResponseProperties(BaseModel):
|
||||
modalities: Optional[List[Literal["text", "audio"]]] = ["audio", "text"]
|
||||
instructions: Optional[str] = None
|
||||
voice: Optional[str] = None
|
||||
output_audio_format: Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]] = None
|
||||
tools: Optional[List[Dict]] = []
|
||||
tool_choice: Optional[Literal["auto", "none", "required"]] = None
|
||||
temperature: Optional[float] = None
|
||||
max_response_output_tokens: Optional[Union[int, Literal["inf"]]] = None
|
||||
|
||||
|
||||
#
|
||||
# error class
|
||||
#
|
||||
class RealtimeError(BaseModel):
|
||||
type: str
|
||||
code: Optional[str] = ""
|
||||
message: str
|
||||
param: Optional[str] = None
|
||||
|
||||
|
||||
#
|
||||
# client events
|
||||
#
|
||||
|
||||
|
||||
class ClientEvent(BaseModel):
|
||||
event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
|
||||
class SessionUpdateEvent(ClientEvent):
|
||||
type: Literal["session.update"] = "session.update"
|
||||
session: SessionProperties
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
dump = super().model_dump(*args, **kwargs)
|
||||
|
||||
# Handle turn_detection so that False is serialized as null
|
||||
if "turn_detection" in dump["session"]:
|
||||
if dump["session"]["turn_detection"] is False:
|
||||
dump["session"]["turn_detection"] = None
|
||||
|
||||
return dump
|
||||
|
||||
|
||||
class InputAudioBufferAppendEvent(ClientEvent):
|
||||
type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append"
|
||||
audio: str # base64-encoded audio
|
||||
|
||||
|
||||
class InputAudioBufferCommitEvent(ClientEvent):
|
||||
type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
|
||||
|
||||
|
||||
class InputAudioBufferClearEvent(ClientEvent):
|
||||
type: Literal["input_audio_buffer.clear"] = "input_audio_buffer.clear"
|
||||
|
||||
|
||||
class ConversationItemCreateEvent(ClientEvent):
|
||||
type: Literal["conversation.item.create"] = "conversation.item.create"
|
||||
previous_item_id: Optional[str] = None
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ConversationItemTruncateEvent(ClientEvent):
|
||||
type: Literal["conversation.item.truncate"] = "conversation.item.truncate"
|
||||
item_id: str
|
||||
content_index: int
|
||||
audio_end_ms: int
|
||||
|
||||
|
||||
class ConversationItemDeleteEvent(ClientEvent):
|
||||
type: Literal["conversation.item.delete"] = "conversation.item.delete"
|
||||
item_id: str
|
||||
|
||||
|
||||
class ResponseCreateEvent(ClientEvent):
|
||||
type: Literal["response.create"] = "response.create"
|
||||
response: Optional[ResponseProperties] = None
|
||||
|
||||
|
||||
class ResponseCancelEvent(ClientEvent):
|
||||
type: Literal["response.cancel"] = "response.cancel"
|
||||
|
||||
|
||||
#
|
||||
# server events
|
||||
#
|
||||
|
||||
|
||||
class ServerEvent(BaseModel):
|
||||
event_id: str
|
||||
type: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class SessionCreatedEvent(ServerEvent):
|
||||
type: Literal["session.created"]
|
||||
session: SessionProperties
|
||||
|
||||
|
||||
class SessionUpdatedEvent(ServerEvent):
|
||||
type: Literal["session.updated"]
|
||||
session: SessionProperties
|
||||
|
||||
|
||||
class ConversationCreated(ServerEvent):
|
||||
type: Literal["conversation.created"]
|
||||
conversation: RealtimeConversation
|
||||
|
||||
|
||||
class ConversationItemCreated(ServerEvent):
|
||||
type: Literal["conversation.item.created"]
|
||||
previous_item_id: Optional[str] = None
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ConversationItemInputAudioTranscriptionCompleted(ServerEvent):
|
||||
type: Literal["conversation.item.input_audio_transcription.completed"]
|
||||
item_id: str
|
||||
content_index: int
|
||||
transcript: str
|
||||
|
||||
|
||||
class ConversationItemInputAudioTranscriptionFailed(ServerEvent):
|
||||
type: Literal["conversation.item.input_audio_transcription.failed"]
|
||||
item_id: str
|
||||
content_index: int
|
||||
error: RealtimeError
|
||||
|
||||
|
||||
class ConversationItemTruncated(ServerEvent):
|
||||
type: Literal["conversation.item.truncated"]
|
||||
item_id: str
|
||||
content_index: int
|
||||
audio_end_ms: int
|
||||
|
||||
|
||||
class ConversationItemDeleted(ServerEvent):
|
||||
type: Literal["conversation.item.deleted"]
|
||||
item_id: str
|
||||
|
||||
|
||||
class ResponseCreated(ServerEvent):
|
||||
type: Literal["response.created"]
|
||||
response: "Response"
|
||||
|
||||
|
||||
class ResponseDone(ServerEvent):
|
||||
type: Literal["response.done"]
|
||||
response: "Response"
|
||||
|
||||
|
||||
class ResponseOutputItemAdded(ServerEvent):
|
||||
type: Literal["response.output_item.added"]
|
||||
response_id: str
|
||||
output_index: int
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ResponseOutputItemDone(ServerEvent):
|
||||
type: Literal["response.output_item.done"]
|
||||
response_id: str
|
||||
output_index: int
|
||||
item: ConversationItem
|
||||
|
||||
|
||||
class ResponseContentPartAdded(ServerEvent):
|
||||
type: Literal["response.content_part.added"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ItemContent
|
||||
|
||||
|
||||
class ResponseContentPartDone(ServerEvent):
|
||||
type: Literal["response.content_part.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
part: ItemContent
|
||||
|
||||
|
||||
class ResponseTextDelta(ServerEvent):
|
||||
type: Literal["response.text.delta"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseTextDone(ServerEvent):
|
||||
type: Literal["response.text.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseAudioTranscriptDelta(ServerEvent):
|
||||
type: Literal["response.audio_transcript.delta"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseAudioTranscriptDone(ServerEvent):
|
||||
type: Literal["response.audio_transcript.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
transcript: str
|
||||
|
||||
|
||||
class ResponseAudioDelta(ServerEvent):
|
||||
type: Literal["response.audio.delta"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
delta: str # base64-encoded audio
|
||||
|
||||
|
||||
class ResponseAudioDone(ServerEvent):
|
||||
type: Literal["response.audio.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
content_index: int
|
||||
|
||||
|
||||
class ResponseFunctionCallArgumentsDelta(ServerEvent):
|
||||
type: Literal["response.function_call_arguments.delta"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
call_id: str
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseFunctionCallArgumentsDone(ServerEvent):
|
||||
type: Literal["response.function_call_arguments.done"]
|
||||
response_id: str
|
||||
item_id: str
|
||||
output_index: int
|
||||
call_id: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class InputAudioBufferSpeechStarted(ServerEvent):
|
||||
type: Literal["input_audio_buffer.speech_started"]
|
||||
audio_start_ms: int
|
||||
item_id: str
|
||||
|
||||
|
||||
class InputAudioBufferSpeechStopped(ServerEvent):
|
||||
type: Literal["input_audio_buffer.speech_stopped"]
|
||||
audio_end_ms: int
|
||||
item_id: str
|
||||
|
||||
|
||||
class InputAudioBufferCommitted(ServerEvent):
|
||||
type: Literal["input_audio_buffer.committed"]
|
||||
previous_item_id: Optional[str] = None
|
||||
item_id: str
|
||||
|
||||
|
||||
class InputAudioBufferCleared(ServerEvent):
|
||||
type: Literal["input_audio_buffer.cleared"]
|
||||
|
||||
|
||||
class ErrorEvent(ServerEvent):
|
||||
type: Literal["error"]
|
||||
error: RealtimeError
|
||||
|
||||
|
||||
class RateLimitsUpdated(ServerEvent):
|
||||
type: Literal["rate_limits.updated"]
|
||||
rate_limits: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class TokenDetails(BaseModel):
|
||||
cached_tokens: Optional[int] = 0
|
||||
text_tokens: Optional[int] = 0
|
||||
audio_tokens: Optional[int] = 0
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
total_tokens: int
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
input_token_details: TokenDetails
|
||||
output_token_details: TokenDetails
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
id: str
|
||||
object: Literal["realtime.response"]
|
||||
status: Literal["completed", "in_progress", "incomplete", "cancelled", "failed"]
|
||||
status_details: Any
|
||||
output: List[ConversationItem]
|
||||
usage: Optional[Usage] = None
|
||||
|
||||
|
||||
_server_event_types = {
|
||||
"error": ErrorEvent,
|
||||
"session.created": SessionCreatedEvent,
|
||||
"session.updated": SessionUpdatedEvent,
|
||||
"conversation.created": ConversationCreated,
|
||||
"input_audio_buffer.committed": InputAudioBufferCommitted,
|
||||
"input_audio_buffer.cleared": InputAudioBufferCleared,
|
||||
"input_audio_buffer.speech_started": InputAudioBufferSpeechStarted,
|
||||
"input_audio_buffer.speech_stopped": InputAudioBufferSpeechStopped,
|
||||
"conversation.item.created": ConversationItemCreated,
|
||||
"conversation.item.input_audio_transcription.completed": ConversationItemInputAudioTranscriptionCompleted,
|
||||
"conversation.item.input_audio_transcription.failed": ConversationItemInputAudioTranscriptionFailed,
|
||||
"conversation.item.truncated": ConversationItemTruncated,
|
||||
"conversation.item.deleted": ConversationItemDeleted,
|
||||
"response.created": ResponseCreated,
|
||||
"response.done": ResponseDone,
|
||||
"response.output_item.added": ResponseOutputItemAdded,
|
||||
"response.output_item.done": ResponseOutputItemDone,
|
||||
"response.content_part.added": ResponseContentPartAdded,
|
||||
"response.content_part.done": ResponseContentPartDone,
|
||||
"response.text.delta": ResponseTextDelta,
|
||||
"response.text.done": ResponseTextDone,
|
||||
"response.audio_transcript.delta": ResponseAudioTranscriptDelta,
|
||||
"response.audio_transcript.done": ResponseAudioTranscriptDone,
|
||||
"response.audio.delta": ResponseAudioDelta,
|
||||
"response.audio.done": ResponseAudioDone,
|
||||
"response.function_call_arguments.delta": ResponseFunctionCallArgumentsDelta,
|
||||
"response.function_call_arguments.done": ResponseFunctionCallArgumentsDone,
|
||||
"rate_limits.updated": RateLimitsUpdated,
|
||||
}
|
||||
|
||||
|
||||
def parse_server_event(str):
|
||||
try:
|
||||
event = json.loads(str)
|
||||
event_type = event["type"]
|
||||
if event_type not in _server_event_types:
|
||||
raise Exception(f"Unimplemented server event type: {event_type}")
|
||||
return _server_event_types[event_type].model_validate(event)
|
||||
except Exception as e:
|
||||
raise Exception(f"{e} \n\n{str}")
|
||||
754
src/pipecat/services/openai_realtime_beta/llm_and_context.py
Normal file
754
src/pipecat/services/openai_realtime_beta/llm_and_context.py
Normal file
@@ -0,0 +1,754 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
DataFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.openai import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIContextAggregatorPair,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from . import events
|
||||
from .events import SessionProperties
|
||||
|
||||
# websocket logger -- in case needed for debugging send/recv
|
||||
# import logging
|
||||
# logging.basicConfig(
|
||||
# format="%(message)s",
|
||||
# level=logging.DEBUG,
|
||||
# )
|
||||
|
||||
|
||||
@dataclass
|
||||
class _InternalMessagesUpdateFrame(DataFrame):
|
||||
context: "OpenAIRealtimeLLMContext"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _InternalFunctionCallResultFrame(DataFrame):
|
||||
result_frame: FunctionCallResultFrame
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CurrentAudioResponse:
|
||||
item_id: str
|
||||
content_index: int
|
||||
start_time_ms: int
|
||||
total_size: int = 0
|
||||
|
||||
|
||||
class OpenAIUnhandledFunctionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtimeLLMContext(OpenAILLMContext):
|
||||
def __init__(self, messages=None, tools=None, **kwargs):
|
||||
super().__init__(messages=messages, tools=tools, **kwargs)
|
||||
self.__setup_local()
|
||||
|
||||
def __setup_local(self):
|
||||
self.llm_needs_settings_update = True
|
||||
self.llm_needs_initial_messages = True
|
||||
self._session_instructions = ""
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext):
|
||||
obj.__class__ = OpenAIRealtimeLLMContext
|
||||
obj.__setup_local()
|
||||
return obj
|
||||
|
||||
# todo
|
||||
# - finish implementing all frames
|
||||
# - add message conversion functions to OpenAILLMContext base class
|
||||
|
||||
def from_standard_message(self, message):
|
||||
if message.get("role") == "assistant" and message.get("tool_calls"):
|
||||
tc = message.get("tool_calls")[0]
|
||||
return events.ConversationItem(
|
||||
type="function_call",
|
||||
call_id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
)
|
||||
logger.error(f"Unhandled message type in from_standard_message: {message}")
|
||||
|
||||
def get_messages_for_initializing_history(self):
|
||||
# We can't load a long conversation history into the openai realtime api yet. (The API/model
|
||||
# forgets that it can do audio, if you do a series of `conversation.item.create` calls.) So
|
||||
# our general strategy until this is fixed is just to put everything into a first "user"
|
||||
# message as a single input.
|
||||
if not self.messages:
|
||||
return []
|
||||
|
||||
messages = copy.deepcopy(self.messages)
|
||||
|
||||
# If we have a "system" message as our first message, let's pull that out into session
|
||||
# "instructions"
|
||||
if messages[0].get("role") == "system":
|
||||
self.llm_needs_settings_update = True
|
||||
system = messages.pop(0)
|
||||
content = system.get("content")
|
||||
if isinstance(content, str):
|
||||
self._session_instructions = content
|
||||
elif isinstance(content, list):
|
||||
self._session_instructions = content[0].get("text")
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# If we have just a single "user" item, we can just send it normally
|
||||
if len(messages) == 1 and messages[0].get("role") == "user":
|
||||
return messages
|
||||
|
||||
# Otherwise, let's pack everything into a single "user" message with a bit of
|
||||
# explanation for the LLM
|
||||
intro_text = """
|
||||
This is a previously saved conversation. Please treat this conversation history as a
|
||||
starting point for the current conversation."""
|
||||
|
||||
trailing_text = """
|
||||
This is the end of the previously saved conversation. Please continue the conversation
|
||||
from here. If the last message is a user instruction or question, act on that instruction
|
||||
or answer the question. If the last message is an assistant response, simple say that you
|
||||
are ready to continue the conversation."""
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "\n\n".join(
|
||||
[intro_text, json.dumps(messages, indent=2), trailing_text]
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def add_user_content_item_as_message(self, item):
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": item.content[0].transcript}],
|
||||
}
|
||||
self.add_message(message)
|
||||
|
||||
def add_assistant_content_item_as_message(self, item):
|
||||
message = {"role": "assistant", "content": []}
|
||||
for content in item.content:
|
||||
if content.type == "audio":
|
||||
message["content"].append({"type": "text", "text": content.transcript})
|
||||
else:
|
||||
logger.error(f"Unhandled content type in assistant item: {content.type} - {item}")
|
||||
self.add_message(message)
|
||||
|
||||
|
||||
class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def process_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
await super().process_frame(frame, direction)
|
||||
# Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline,
|
||||
# messages are only processed by the user context aggregator, which is generally what we want. But
|
||||
# we also need to send new messages over the websocket, so the openai realtime API has them
|
||||
# in its context.
|
||||
if isinstance(frame, LLMMessagesUpdateFrame):
|
||||
await self.push_frame(_InternalMessagesUpdateFrame(context=self._context))
|
||||
|
||||
# Parent also doesn't push the LLMSetToolsFrame.
|
||||
if isinstance(frame, LLMSetToolsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
# for the moment, ignore all user input coming into the pipeline.
|
||||
# todo: think about whether/how to fix this to allow for text input from
|
||||
# upstream (transport/transcription, or other sources)
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def _push_aggregation(self):
|
||||
# the only thing we implement here is function calling. in all other cases, messages
|
||||
# are added to the context when we receive openai realtime api events
|
||||
if not self._function_call_result:
|
||||
return
|
||||
|
||||
self._reset()
|
||||
try:
|
||||
frame = self._function_call_result
|
||||
self._function_call_result = None
|
||||
if frame.result:
|
||||
# The "tool_call" message from the LLM that triggered the function call
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
# The result of the function call. Need to add this both to our context here and to
|
||||
# the openai realtime api context.
|
||||
result_message = {
|
||||
"role": "tool",
|
||||
"content": json.dumps(frame.result),
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
|
||||
self._context.add_message(result_message)
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
|
||||
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
|
||||
# special frame to do that.
|
||||
await self._user_context_aggregator.push_frame(
|
||||
_InternalFunctionCallResultFrame(result_frame=frame)
|
||||
)
|
||||
run_llm = frame.run_llm
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
class OpenAILLMServiceRealtimeBeta(LLMService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url="wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01",
|
||||
session_properties: events.SessionProperties = events.SessionProperties(),
|
||||
start_audio_paused: bool = False,
|
||||
send_transcription_frames: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(base_url=base_url, **kwargs)
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
self._session_properties: events.SessionProperties = session_properties
|
||||
self._audio_input_paused = start_audio_paused
|
||||
self._send_transcription_frames = send_transcription_frames
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._context = None
|
||||
|
||||
self._disconnecting = False
|
||||
self._api_session_ready = False
|
||||
self._run_llm_when_api_session_ready = False
|
||||
|
||||
self._current_assistant_response = None
|
||||
self._current_audio_response = None
|
||||
|
||||
self._messages_added_manually = {}
|
||||
self._user_and_response_message_tuple = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def set_audio_input_paused(self, paused: bool):
|
||||
self._audio_input_paused = paused
|
||||
|
||||
#
|
||||
# standard AIService frame handling
|
||||
#
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
#
|
||||
# speech and interruption handling
|
||||
#
|
||||
|
||||
async def _handle_interruption(self):
|
||||
if self._session_properties.turn_detection is None:
|
||||
await self.send_client_event(events.InputAudioBufferClearEvent())
|
||||
await self.send_client_event(events.ResponseCancelEvent())
|
||||
await self._truncate_current_audio_response()
|
||||
await self.stop_all_metrics()
|
||||
if self._current_assistant_response:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
|
||||
async def _handle_user_started_speaking(self, frame):
|
||||
if self._session_properties.turn_detection is None:
|
||||
await self._handle_interruption()
|
||||
|
||||
async def _handle_user_stopped_speaking(self, frame):
|
||||
if self._session_properties.turn_detection is None:
|
||||
await self.send_client_event(events.InputAudioBufferCommitEvent())
|
||||
await self.send_client_event(events.ResponseCreateEvent())
|
||||
|
||||
async def _handle_bot_stopped_speaking(self):
|
||||
self._current_audio_response = None
|
||||
|
||||
async def _truncate_current_audio_response(self):
|
||||
# if the bot is still speaking, truncate the last message
|
||||
if self._current_audio_response:
|
||||
current = self._current_audio_response
|
||||
self._current_audio_response = None
|
||||
elapsed_ms = int(time.time() * 1000 - current.start_time_ms)
|
||||
await self.send_client_event(
|
||||
events.ConversationItemTruncateEvent(
|
||||
item_id=current.item_id,
|
||||
content_index=current.content_index,
|
||||
audio_end_ms=elapsed_ms,
|
||||
)
|
||||
)
|
||||
|
||||
#
|
||||
# frame processing
|
||||
#
|
||||
# StartFrame, StopFrame, CancelFrame implemented in base class
|
||||
#
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
pass
|
||||
elif isinstance(frame, OpenAILLMContextFrame):
|
||||
context: OpenAIRealtimeLLMContext = OpenAIRealtimeLLMContext.upgrade_to_realtime(
|
||||
frame.context
|
||||
)
|
||||
if not self._context:
|
||||
self._context = context
|
||||
elif frame.context is not self._context:
|
||||
# If the context has changed, reset the conversation
|
||||
self._context = context
|
||||
await self.reset_conversation()
|
||||
# Run the LLM at next opportunity
|
||||
await self._create_response()
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
if not self._audio_input_paused:
|
||||
await self._send_user_audio(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruption()
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._handle_user_started_speaking(frame)
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
await self._handle_user_stopped_speaking(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._handle_bot_stopped_speaking()
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_messages_append(frame)
|
||||
elif isinstance(frame, _InternalMessagesUpdateFrame):
|
||||
self._context = frame.context
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
self._session_properties = SessionProperties(**frame.settings)
|
||||
await self._update_settings()
|
||||
elif isinstance(frame, LLMSetToolsFrame):
|
||||
await self._update_settings()
|
||||
elif isinstance(frame, _InternalFunctionCallResultFrame):
|
||||
await self._handle_function_call_result(frame.result_frame)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_messages_append(self, frame):
|
||||
logger.error("!!! NEED TO IMPLEMENT MESSAGES APPEND")
|
||||
|
||||
async def _handle_function_call_result(self, frame):
|
||||
item = events.ConversationItem(
|
||||
type="function_call_output",
|
||||
call_id=frame.tool_call_id,
|
||||
output=json.dumps(frame.result),
|
||||
)
|
||||
await self.send_client_event(events.ConversationItemCreateEvent(item=item))
|
||||
|
||||
#
|
||||
# websocket communication
|
||||
#
|
||||
|
||||
async def send_client_event(self, event: events.ClientEvent):
|
||||
await self._ws_send(event.model_dump(exclude_none=True))
|
||||
|
||||
async def _connect(self):
|
||||
try:
|
||||
if self._websocket:
|
||||
# Here we assume that if we have a websocket, we are connected. We
|
||||
# handle disconnections in the send/recv code paths.
|
||||
return
|
||||
self._websocket = await websockets.connect(
|
||||
uri=self.base_url,
|
||||
extra_headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
},
|
||||
)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect(self):
|
||||
try:
|
||||
self._disconnecting = True
|
||||
self._api_session_ready = False
|
||||
await self.stop_all_metrics()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self._receive_task, timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timed out waiting for receive task to finish")
|
||||
self._receive_task = None
|
||||
self._disconnecting = False
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error disconnecting: {e}")
|
||||
|
||||
async def _ws_send(self, realtime_message):
|
||||
try:
|
||||
if self._websocket:
|
||||
await self._websocket.send(json.dumps(realtime_message))
|
||||
except Exception as e:
|
||||
if self._disconnecting:
|
||||
return
|
||||
logger.error(f"Error sending message to websocket: {e}")
|
||||
# In server-to-server contexts, a WebSocket error should be quite rare. Given how hard
|
||||
# it is to recover from a send-side error with proper state management, and that exponential
|
||||
# backoff for retries can have cost/stability implications for a service cluster, let's just
|
||||
# treat a send-side error as fatal.
|
||||
await self.push_error(ErrorFrame(error=f"Error sending client event: {e}", fatal=True))
|
||||
|
||||
async def _update_settings(self):
|
||||
settings = self._session_properties
|
||||
# tools given in the context override the tools in the session properties
|
||||
if self._context and self._context.tools:
|
||||
settings.tools = self._context.tools
|
||||
# instructions in the context come from an initial "system" message in the
|
||||
# messages list, and override instructions in the session properties
|
||||
if self._context and self._context._session_instructions:
|
||||
settings.instructions = self._context._session_instructions
|
||||
await self.send_client_event(events.SessionUpdateEvent(session=settings))
|
||||
|
||||
#
|
||||
# inbound server event handling
|
||||
# https://platform.openai.com/docs/api-reference/realtime-server-events
|
||||
#
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
evt = events.parse_server_event(message)
|
||||
if evt.type == "session.created":
|
||||
await self._handle_evt_session_created(evt)
|
||||
elif evt.type == "session.updated":
|
||||
await self._handle_evt_session_updated(evt)
|
||||
elif evt.type == "response.audio.delta":
|
||||
await self._handle_evt_audio_delta(evt)
|
||||
elif evt.type == "response.audio.done":
|
||||
await self._handle_evt_audio_done(evt)
|
||||
elif evt.type == "conversation.item.created":
|
||||
await self._handle_evt_conversation_item_created(evt)
|
||||
elif evt.type == "conversation.item.input_audio_transcription.completed":
|
||||
await self.handle_evt_input_audio_transcription_completed(evt)
|
||||
elif evt.type == "response.done":
|
||||
await self._handle_evt_response_done(evt)
|
||||
elif evt.type == "input_audio_buffer.speech_started":
|
||||
await self._handle_evt_speech_started(evt)
|
||||
elif evt.type == "input_audio_buffer.speech_stopped":
|
||||
await self._handle_evt_speech_stopped(evt)
|
||||
elif evt.type == "response.audio_transcript.delta":
|
||||
await self._handle_evt_audio_transcript_delta(evt)
|
||||
elif evt.type == "error":
|
||||
await self._handle_evt_error(evt)
|
||||
# errors are fatal, so exit the receive loop
|
||||
return
|
||||
|
||||
else:
|
||||
# logger.debug(f"!!! Unhandled event: {evt}")
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("websocket receive task cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
async def _handle_evt_session_created(self, evt):
|
||||
# session.created is received right after connecting. Send a message
|
||||
# to configure the session properties.
|
||||
await self._update_settings()
|
||||
|
||||
async def _handle_evt_session_updated(self, evt):
|
||||
# If this is our first context frame, run the LLM
|
||||
self._api_session_ready = True
|
||||
# Now that we've configured the session, we can run the LLM if we need to.
|
||||
if self._run_llm_when_api_session_ready:
|
||||
self._run_llm_when_api_session_ready = False
|
||||
await self._create_response()
|
||||
|
||||
async def _handle_evt_audio_delta(self, evt):
|
||||
# note: ttfb is faster by 1/2 RTT than ttfb as measured for other services, since we're getting
|
||||
# this event from the server
|
||||
await self.stop_ttfb_metrics()
|
||||
if not self._current_audio_response:
|
||||
self._current_audio_response = _CurrentAudioResponse(
|
||||
item_id=evt.item_id,
|
||||
content_index=evt.content_index,
|
||||
start_time_ms=int(time.time() * 1000),
|
||||
)
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
audio = base64.b64decode(evt.delta)
|
||||
self._current_audio_response.total_size += len(audio)
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=audio,
|
||||
sample_rate=24000,
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_evt_audio_done(self, evt):
|
||||
if self._current_audio_response:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
# Don't clear the self._current_audio_response here. We need to wait until we
|
||||
# receive a BotStoppedSpeakingFrame from the output transport.
|
||||
|
||||
async def _handle_evt_conversation_item_created(self, evt):
|
||||
# This will get sent from the server every time a new "message" is added
|
||||
# to the server's conversation state, whether we create it via the API
|
||||
# or the server creates it from LLM output.
|
||||
if self._messages_added_manually.get(evt.item.id):
|
||||
del self._messages_added_manually[evt.item.id]
|
||||
return
|
||||
|
||||
if evt.item.role == "user":
|
||||
# We need to wait for completion of both user message and response message. Then we'll
|
||||
# add both to the context. User message is complete when we have a "transcript" field
|
||||
# that is not None. Response message is complete when we get a "response.done" event.
|
||||
self._user_and_response_message_tuple = (evt.item, {"done": False, "output": []})
|
||||
elif evt.item.role == "assistant":
|
||||
self._current_assistant_response = evt.item
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
async def handle_evt_input_audio_transcription_completed(self, evt):
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
# no way to get a language code?
|
||||
TranscriptionFrame(evt.transcript, "", time_now_iso8601())
|
||||
)
|
||||
pair = self._user_and_response_message_tuple
|
||||
if pair:
|
||||
user, assistant = pair
|
||||
user.content[0].transcript = evt.transcript
|
||||
if assistant["done"]:
|
||||
self._user_and_response_message_tuple = None
|
||||
self._context.add_user_content_item_as_message(user)
|
||||
await self._handle_assistant_output(assistant["output"])
|
||||
else:
|
||||
# User message without preceding conversation.item.created. Bug?
|
||||
logger.warning(f"Transcript for unknown user message: {evt}")
|
||||
|
||||
async def _handle_evt_response_done(self, evt):
|
||||
# todo: figure out whether there's anything we need to do for "cancelled" events
|
||||
# usage metrics
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=evt.response.usage.input_tokens,
|
||||
completion_tokens=evt.response.usage.output_tokens,
|
||||
total_tokens=evt.response.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._current_assistant_response = None
|
||||
# response content
|
||||
pair = self._user_and_response_message_tuple
|
||||
if pair:
|
||||
user, assistant = pair
|
||||
assistant["done"] = True
|
||||
assistant["output"] = evt.response.output
|
||||
if user.content[0].transcript is not None:
|
||||
self._user_and_response_message_tuple = None
|
||||
self._context.add_user_content_item_as_message(user)
|
||||
await self._handle_assistant_output(assistant["output"])
|
||||
else:
|
||||
# Response message without preceding user message. Add it to the context.
|
||||
await self._handle_assistant_output(evt.response.output)
|
||||
|
||||
async def _handle_evt_audio_transcript_delta(self, evt):
|
||||
if evt.delta:
|
||||
await self.push_frame(TextFrame(evt.delta))
|
||||
|
||||
async def _handle_evt_speech_started(self, evt):
|
||||
await self._truncate_current_audio_response()
|
||||
# todo: might need to guard sending these when we fully support using either openai
|
||||
# turn detection of Pipecat turn detection
|
||||
await self._start_interruption() # cancels this processor task
|
||||
await self.push_frame(StartInterruptionFrame()) # cancels downstream tasks
|
||||
await self.push_frame(UserStartedSpeakingFrame())
|
||||
|
||||
async def _handle_evt_speech_stopped(self, evt):
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._stop_interruption()
|
||||
await self.push_frame(StopInterruptionFrame())
|
||||
await self.push_frame(UserStoppedSpeakingFrame())
|
||||
|
||||
async def _handle_evt_error(self, evt):
|
||||
# Errors are fatal to this connection. Send an ErrorFrame.
|
||||
await self.push_error(ErrorFrame(error=f"Error: {evt}", fatal=True))
|
||||
|
||||
async def _handle_assistant_output(self, output):
|
||||
# logger.debug(f"!!! HANDLE Assistant output: {output}")
|
||||
# We haven't seen intermixed audio and function_call items in the same response. But let's
|
||||
# try to write logic that handles that, if it does happen.
|
||||
messages = [item for item in output if item.type == "message"]
|
||||
function_calls = [item for item in output if item.type == "function_call"]
|
||||
for item in messages:
|
||||
self._context.add_assistant_content_item_as_message(item)
|
||||
await self._handle_function_call_items(function_calls)
|
||||
|
||||
async def _handle_function_call_items(self, items):
|
||||
total_items = len(items)
|
||||
for index, item in enumerate(items):
|
||||
function_name = item.name
|
||||
tool_id = item.call_id
|
||||
arguments = json.loads(item.arguments)
|
||||
if self.has_function(function_name):
|
||||
run_llm = index == total_items - 1
|
||||
if function_name in self._callbacks.keys():
|
||||
await self.call_function(
|
||||
context=self._context,
|
||||
tool_call_id=tool_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
elif None in self._callbacks.keys():
|
||||
await self.call_function(
|
||||
context=self._context,
|
||||
tool_call_id=tool_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
#
|
||||
# state and client events for the current conversation
|
||||
# https://platform.openai.com/docs/api-reference/realtime-client-events
|
||||
#
|
||||
|
||||
async def reset_conversation(self):
|
||||
# Disconnect/reconnect is the safest way to start a new conversation.
|
||||
# Note that this will fail if called from the receive task.
|
||||
logger.debug("Resetting conversation")
|
||||
await self._disconnect()
|
||||
if self._context:
|
||||
self._context.llm_needs_settings_update = True
|
||||
self._context.llm_needs_initial_messages = True
|
||||
await self._connect()
|
||||
|
||||
async def _create_response(self):
|
||||
if not self._api_session_ready:
|
||||
self._run_llm_when_api_session_ready = True
|
||||
return
|
||||
|
||||
if self._context.llm_needs_initial_messages:
|
||||
messages = self._context.get_messages_for_initializing_history()
|
||||
for item in messages:
|
||||
evt = events.ConversationItemCreateEvent(item=item)
|
||||
self._messages_added_manually[evt.item.id] = True
|
||||
await self.send_client_event(evt)
|
||||
self._context.llm_needs_initial_messages = False
|
||||
|
||||
if self._context.llm_needs_settings_update:
|
||||
await self._update_settings()
|
||||
self._context.llm_needs_settings_update = False
|
||||
|
||||
logger.debug(f"Creating response: {self._context.get_messages_for_logging()}")
|
||||
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
await self.send_client_event(
|
||||
events.ResponseCreateEvent(
|
||||
response=events.ResponseProperties(modalities=["audio", "text"])
|
||||
)
|
||||
)
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
payload = base64.b64encode(frame.audio).decode("utf-8")
|
||||
await self.send_client_event(events.InputAudioBufferAppendEvent(audio=payload))
|
||||
|
||||
def create_context_aggregator(
|
||||
self, context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = False
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
OpenAIRealtimeLLMContext.upgrade_to_realtime(context)
|
||||
user = OpenAIRealtimeUserContextAggregator(context)
|
||||
assistant = OpenAIRealtimeAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
@@ -6,17 +6,21 @@
|
||||
|
||||
import io
|
||||
import struct
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
try:
|
||||
from pyht.client import TTSOptions
|
||||
from pyht.async_client import AsyncClient
|
||||
from pyht.client import TTSOptions
|
||||
from pyht.protos.api_pb2 import Format
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -39,17 +43,23 @@ class PlayHTTTSService(TTSService):
|
||||
user_id=self._user_id,
|
||||
api_key=self._speech_key,
|
||||
)
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"quality": "higher",
|
||||
"format": Format.FORMAT_WAV,
|
||||
"voice_engine": "PlayHT2.0-turbo",
|
||||
}
|
||||
self.set_voice(voice_url)
|
||||
self._options = TTSOptions(
|
||||
voice=voice_url, sample_rate=sample_rate, quality="higher", format=Format.FORMAT_WAV
|
||||
voice=self._voice_id,
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
quality=self._settings["quality"],
|
||||
format=self._settings["format"],
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._options.voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -60,12 +70,12 @@ class PlayHTTTSService(TTSService):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
playht_gen = self._client.tts(
|
||||
text, voice_engine="PlayHT2.0-turbo", options=self._options
|
||||
text, voice_engine=self._settings["voice_engine"], options=self._options
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in playht_gen:
|
||||
# skip the RIFF header.
|
||||
if in_header:
|
||||
@@ -83,8 +93,8 @@ class PlayHTTTSService(TTSService):
|
||||
else:
|
||||
if len(chunk):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, 16000, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
@@ -4,42 +4,18 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from asyncio import CancelledError
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
try:
|
||||
from together import AsyncTogether
|
||||
# Together.ai is recommending OpenAI-compatible function calling, so we've switched over
|
||||
# to using the OpenAI client library here rather than the Together Python client library.
|
||||
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -48,19 +24,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TogetherContextAggregatorPair:
|
||||
_user: "TogetherUserContextAggregator"
|
||||
_assistant: "TogetherAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "TogetherUserContextAggregator":
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "TogetherAssistantContextAggregator":
|
||||
return self._assistant
|
||||
|
||||
|
||||
class TogetherLLMService(LLMService):
|
||||
class TogetherLLMService(OpenAILLMService):
|
||||
"""This class implements inference with Together's Llama 3.1 models"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -68,327 +32,45 @@ class TogetherLLMService(LLMService):
|
||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||
presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
|
||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
# Note: top_k is currently not supported by the OpenAI client library,
|
||||
# so top_k is ignore right now.
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
seed: Optional[int] = Field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.together.xyz/v1",
|
||||
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._client = AsyncTogether(api_key=api_key)
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, params=params, **kwargs)
|
||||
self.set_model_name(model)
|
||||
self._max_tokens = params.max_tokens
|
||||
self._frequency_penalty = params.frequency_penalty
|
||||
self._presence_penalty = params.presence_penalty
|
||||
self._temperature = params.temperature
|
||||
self._top_k = params.top_k
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"frequency_penalty": params.frequency_penalty,
|
||||
"presence_penalty": params.presence_penalty,
|
||||
"seed": params.seed,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> TogetherContextAggregatorPair:
|
||||
user = TogetherUserContextAggregator(context)
|
||||
assistant = TogetherAssistantContextAggregator(user)
|
||||
return TogetherContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def set_frequency_penalty(self, frequency_penalty: float):
|
||||
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
|
||||
self._frequency_penalty = frequency_penalty
|
||||
|
||||
async def set_max_tokens(self, max_tokens: int):
|
||||
logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
|
||||
self._max_tokens = max_tokens
|
||||
|
||||
async def set_presence_penalty(self, presence_penalty: float):
|
||||
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
|
||||
self._presence_penalty = presence_penalty
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_k(self, top_k: float):
|
||||
logger.debug(f"Switching LLM top_k to: [{top_k}]")
|
||||
self._top_k = top_k
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.frequency_penalty is not None:
|
||||
await self.set_frequency_penalty(frame.frequency_penalty)
|
||||
if frame.max_tokens is not None:
|
||||
await self.set_max_tokens(frame.max_tokens)
|
||||
if frame.presence_penalty is not None:
|
||||
await self.set_presence_penalty(frame.presence_penalty)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_k is not None:
|
||||
await self.set_top_k(frame.top_k)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
|
||||
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
params = {
|
||||
"messages": context.messages,
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._max_tokens,
|
||||
"stream": True,
|
||||
"frequency_penalty": self._frequency_penalty,
|
||||
"presence_penalty": self._presence_penalty,
|
||||
"temperature": self._temperature,
|
||||
"top_k": self._top_k,
|
||||
"top_p": self._top_p,
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
|
||||
stream = await self._client.chat.completions.create(**params)
|
||||
|
||||
# Function calling
|
||||
got_first_chunk = False
|
||||
accumulating_function_call = False
|
||||
function_call_accumulator = ""
|
||||
|
||||
async for chunk in stream:
|
||||
# logger.debug(f"Together LLM event: {chunk}")
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
if not got_first_chunk:
|
||||
await self.stop_ttfb_metrics()
|
||||
if chunk.choices[0].delta.content:
|
||||
got_first_chunk = True
|
||||
if chunk.choices[0].delta.content[0] == "<":
|
||||
accumulating_function_call = True
|
||||
|
||||
if chunk.choices[0].delta.content:
|
||||
if accumulating_function_call:
|
||||
function_call_accumulator += chunk.choices[0].delta.content
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
if chunk.choices[0].finish_reason == "eos" and accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
|
||||
except CancelledError:
|
||||
# todo: implement token counting estimates for use when the user interrupts a long generation
|
||||
# we do this in the anthropic.py service
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = TogetherLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if context:
|
||||
await self._process_context(context)
|
||||
|
||||
async def _extract_function_call(self, context, function_call_accumulator):
|
||||
context.add_message({"role": "assistant", "content": function_call_accumulator})
|
||||
|
||||
function_regex = r"<function=(\w+)>(.*?)</function>"
|
||||
match = re.search(function_regex, function_call_accumulator)
|
||||
if match:
|
||||
function_name, args_string = match.groups()
|
||||
try:
|
||||
arguments = json.loads(args_string)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
logger.debug(f"Creating Together.ai client with api {base_url}")
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
http_client=DefaultAsyncHttpxClient(
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
|
||||
)
|
||||
return
|
||||
except json.JSONDecodeError as error:
|
||||
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
|
||||
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
|
||||
# Should we do anything more than log a warning?
|
||||
logger.debug(f"Error parsing function arguments: {error}")
|
||||
|
||||
|
||||
class TogetherLLMContext(OpenAILLMContext):
|
||||
def __init__(
|
||||
self,
|
||||
messages: list[dict] | None = None,
|
||||
):
|
||||
super().__init__(messages=messages)
|
||||
|
||||
@classmethod
|
||||
def from_openai_context(cls, openai_context: OpenAILLMContext):
|
||||
self = cls(
|
||||
messages=openai_context.messages,
|
||||
),
|
||||
)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[dict]) -> "TogetherLLMContext":
|
||||
return cls(messages=messages)
|
||||
|
||||
def add_message(self, message):
|
||||
try:
|
||||
self.messages.append(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message: {e}")
|
||||
|
||||
def get_messages_for_logging(self) -> str:
|
||||
return json.dumps(self.messages)
|
||||
|
||||
|
||||
class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext | TogetherLLMContext):
|
||||
super().__init__(context=context)
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
self._context = TogetherLLMContext.from_openai_context(context)
|
||||
|
||||
async def push_messages_frame(self):
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
# flow here and we don't need to call push_frame() ourselves. Possibly something
|
||||
# to talk through (tagging @aleix). At some point we might need to refactor these
|
||||
# context aggregators.
|
||||
try:
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
|
||||
# that frame so we can use it when we assemble the image message in the assistant
|
||||
# context aggregator.
|
||||
if frame.context:
|
||||
if isinstance(frame.context, str):
|
||||
self._context._user_image_request_context[frame.user_id] = frame.context
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
|
||||
)
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
else:
|
||||
if frame.user_id in self._context._user_image_request_context:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
#
|
||||
# Claude returns a text content block along with a tool use content block. This works quite nicely
|
||||
# with streaming. We get the text first, so we can start streaming it right away. Then we get the
|
||||
# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
|
||||
#
|
||||
# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
|
||||
# chattiness about it's tool thinking.
|
||||
#
|
||||
|
||||
|
||||
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# See note above about not calling push_frame() here.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_finished = None
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
self._function_call_in_progress = frame
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
if (
|
||||
self._function_call_in_progress
|
||||
and self._function_call_in_progress.tool_call_id == frame.tool_call_id
|
||||
):
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = frame
|
||||
await self._push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
|
||||
)
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
def add_message(self, message):
|
||||
self._user_context_aggregator.add_message(message)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if not (self._aggregation or self._function_call_result):
|
||||
return
|
||||
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
frame = self._function_call_result
|
||||
self._function_call_result = None
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
# Together expects the content here to be a string, so stringify it
|
||||
"content": str(frame.result),
|
||||
}
|
||||
)
|
||||
run_llm = True
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_messages_frame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -4,10 +4,12 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
@@ -17,10 +19,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
import numpy as np
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import resampy
|
||||
@@ -43,25 +42,70 @@ class XTTSService(TTSService):
|
||||
self,
|
||||
*,
|
||||
voice_id: str,
|
||||
language: str,
|
||||
language: Language,
|
||||
base_url: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice_id = voice_id
|
||||
self._language = language
|
||||
self._base_url = base_url
|
||||
self._settings = {
|
||||
"language": self.language_to_service_language(language),
|
||||
"base_url": base_url,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._studio_speakers: Dict[str, Any] | None = None
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.CS:
|
||||
return "cs"
|
||||
case Language.DE:
|
||||
return "de"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.FR:
|
||||
return "fr"
|
||||
case Language.HI:
|
||||
return "hi"
|
||||
case Language.HU:
|
||||
return "hu"
|
||||
case Language.IT:
|
||||
return "it"
|
||||
case Language.JA:
|
||||
return "ja"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
case Language.NL:
|
||||
return "nl"
|
||||
case Language.PL:
|
||||
return "pl"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "pt"
|
||||
case Language.RU:
|
||||
return "ru"
|
||||
case Language.TR:
|
||||
return "tr"
|
||||
case Language.ZH:
|
||||
return "zh-cn"
|
||||
return None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
async with self._aiohttp_session.get(self._base_url + "/studio_speakers") as r:
|
||||
async with self._aiohttp_session.get(self._settings["base_url"] + "/studio_speakers") as r:
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
logger.error(
|
||||
@@ -75,10 +119,6 @@ class XTTSService(TTSService):
|
||||
return
|
||||
self._studio_speakers = await r.json()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -88,11 +128,11 @@ class XTTSService(TTSService):
|
||||
|
||||
embeddings = self._studio_speakers[self._voice_id]
|
||||
|
||||
url = self._base_url + "/tts_stream"
|
||||
url = self._settings["base_url"] + "/tts_stream"
|
||||
|
||||
payload = {
|
||||
"text": text.replace(".", "").replace("*", ""),
|
||||
"language": self._language,
|
||||
"language": self._settings["language"],
|
||||
"speaker_embedding": embeddings["speaker_embedding"],
|
||||
"gpt_cond_latent": embeddings["gpt_cond_latent"],
|
||||
"add_wav_header": False,
|
||||
@@ -110,7 +150,7 @@ class XTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
buffer = bytearray()
|
||||
async for chunk in r.content.iter_chunked(1024):
|
||||
@@ -146,4 +186,4 @@ class XTTSService(TTSService):
|
||||
frame = TTSAudioRawFrame(resampled_audio_bytes, 16000, 1)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,17 +5,17 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
InputAudioRawFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
@@ -23,15 +23,14 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
VADParamsUpdateFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class BaseInputTransport(FrameProcessor):
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._params = params
|
||||
|
||||
@@ -87,6 +86,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -1,49 +1,45 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import time
|
||||
import sys
|
||||
|
||||
from PIL import Image
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
MetricsFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
OutputAudioRawFrame,
|
||||
OutputImageRawFrame,
|
||||
SpriteFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TextFrame,
|
||||
TransportMessageFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.time import nanoseconds_to_seconds
|
||||
|
||||
|
||||
class BaseOutputTransport(FrameProcessor):
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._params = params
|
||||
|
||||
@@ -96,15 +92,6 @@ class BaseOutputTransport(FrameProcessor):
|
||||
self._audio_out_task = self.get_event_loop().create_task(self._audio_out_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
# At this point we have enqueued an EndFrame and we need to wait for
|
||||
# that EndFrame to be processed by the sink tasks. We also need to wait
|
||||
# for these tasks before cancelling the camera and audio tasks below
|
||||
# because they might be still rendering.
|
||||
if self._sink_task:
|
||||
await self._sink_task
|
||||
if self._sink_clock_task:
|
||||
await self._sink_clock_task
|
||||
|
||||
# Cancel and wait for the camera output task to finish.
|
||||
if self._camera_out_task and self._params.camera_out_enabled:
|
||||
self._camera_out_task.cancel()
|
||||
@@ -148,10 +135,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._audio_out_task
|
||||
self._audio_out_task = None
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
pass
|
||||
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
pass
|
||||
|
||||
async def write_frame_to_camera(self, frame: OutputImageRawFrame):
|
||||
@@ -180,32 +164,46 @@ class BaseOutputTransport(FrameProcessor):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
|
||||
elif isinstance(frame, (StartInterruptionFrame, StopInterruptionFrame)):
|
||||
await self.push_frame(frame, direction)
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self.send_metrics(frame)
|
||||
elif isinstance(frame, TransportMessageUrgentFrame):
|
||||
await self.send_message(frame)
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames.
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._sink_clock_queue.put((sys.maxsize, frame.id, frame))
|
||||
await self._sink_queue.put(frame)
|
||||
# Process sink tasks.
|
||||
await self._stop_sink_tasks(frame)
|
||||
# Now we can stop.
|
||||
await self.stop(frame)
|
||||
# We finally push EndFrame down so PipelineTask stops nicely.
|
||||
await self.push_frame(frame, direction)
|
||||
# Other frames.
|
||||
elif isinstance(frame, OutputAudioRawFrame):
|
||||
await self._handle_audio(frame)
|
||||
elif isinstance(frame, OutputImageRawFrame) or isinstance(frame, SpriteFrame):
|
||||
elif isinstance(frame, (OutputImageRawFrame, SpriteFrame)):
|
||||
await self._handle_image(frame)
|
||||
elif isinstance(frame, TransportMessageFrame) and frame.urgent:
|
||||
await self.send_message(frame)
|
||||
# TODO(aleix): Images and audio should support presentation timestamps.
|
||||
elif frame.pts:
|
||||
await self._sink_clock_queue.put((frame.pts, frame.id, frame))
|
||||
else:
|
||||
await self._sink_queue.put(frame)
|
||||
|
||||
async def _stop_sink_tasks(self, frame: EndFrame):
|
||||
# Let the sink tasks process the queue until they reach this EndFrame.
|
||||
await self._sink_clock_queue.put((sys.maxsize, frame.id, frame))
|
||||
await self._sink_queue.put(frame)
|
||||
|
||||
# At this point we have enqueued an EndFrame and we need to wait for
|
||||
# that EndFrame to be processed by the sink tasks. We also need to wait
|
||||
# for these tasks before cancelling the camera and audio tasks below
|
||||
# because they might be still rendering.
|
||||
if self._sink_task:
|
||||
await self._sink_task
|
||||
if self._sink_clock_task:
|
||||
await self._sink_clock_task
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
@@ -279,7 +277,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
await self._bot_stopped_speaking()
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
# We will push EndFrame later.
|
||||
elif not isinstance(frame, EndFrame):
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _sink_task_handler(self):
|
||||
@@ -295,12 +294,6 @@ class BaseOutputTransport(FrameProcessor):
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error processing sink queue: {e}")
|
||||
|
||||
async def _sink_clock_frame_handler(self, frame: Frame):
|
||||
# TODO(aleix): For now we just process TextFrame. But we should process
|
||||
# audio and video as well.
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _sink_clock_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
@@ -315,12 +308,10 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# time to process it.
|
||||
if running:
|
||||
current_time = self.get_clock().get_time()
|
||||
if timestamp <= current_time:
|
||||
await self._sink_clock_frame_handler(frame)
|
||||
else:
|
||||
if timestamp > current_time:
|
||||
wait_time = nanoseconds_to_seconds(timestamp - current_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
await self._sink_frame_handler(frame)
|
||||
await self._sink_frame_handler(frame)
|
||||
|
||||
self._sink_clock_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -4,14 +4,13 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Mapping, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import aiohttp
|
||||
from daily import (
|
||||
CallClient,
|
||||
Daily,
|
||||
@@ -20,6 +19,7 @@ from daily import (
|
||||
VirtualMicrophoneDevice,
|
||||
VirtualSpeakerDevice,
|
||||
)
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -28,33 +28,25 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
MetricsFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputImageRawFrame,
|
||||
SpriteFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import (
|
||||
LLMUsageMetricsData,
|
||||
ProcessingMetricsData,
|
||||
TTFBMetricsData,
|
||||
TTSUsageMetricsData,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from daily import EventHandler, CallClient, Daily
|
||||
from daily import CallClient, Daily, EventHandler
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -70,6 +62,11 @@ class DailyTransportMessageFrame(TransportMessageFrame):
|
||||
participant_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailyTransportMessageUrgentFrame(TransportMessageUrgentFrame):
|
||||
participant_id: str | None = None
|
||||
|
||||
|
||||
class WebRTCVADAnalyzer(VADAnalyzer):
|
||||
def __init__(self, *, sample_rate=16000, num_channels=1, params: VADParams = VADParams()):
|
||||
super().__init__(sample_rate=sample_rate, num_channels=num_channels, params=params)
|
||||
@@ -234,12 +231,12 @@ class DailyTransportClient(EventHandler):
|
||||
def set_callbacks(self, callbacks: DailyCallbacks):
|
||||
self._callbacks = callbacks
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
if not self._client:
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
if not self._joined or self._leaving:
|
||||
return
|
||||
|
||||
participant_id = None
|
||||
if isinstance(frame, DailyTransportMessageFrame):
|
||||
if isinstance(frame, (DailyTransportMessageFrame, DailyTransportMessageUrgentFrame)):
|
||||
participant_id = frame.participant_id
|
||||
|
||||
future = self._loop.create_future()
|
||||
@@ -714,21 +711,37 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
|
||||
self._client = client
|
||||
|
||||
# Task to process outgoing messages.
|
||||
self._messages_task = None
|
||||
self._messages_queue = asyncio.Queue()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
# Parent start.
|
||||
await super().start(frame)
|
||||
# Join the room.
|
||||
await self._client.join()
|
||||
# Start messages task
|
||||
self._messages_task = self.get_event_loop().create_task(self._messages_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Parent stop.
|
||||
await super().stop(frame)
|
||||
# Cancel messages task
|
||||
if self._messages_task:
|
||||
self._messages_task.cancel()
|
||||
await self._messages_task
|
||||
self._messages_task = None
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Parent stop.
|
||||
await super().cancel(frame)
|
||||
# Cancel messages task
|
||||
if self._messages_task:
|
||||
self._messages_task.cancel()
|
||||
await self._messages_task
|
||||
self._messages_task = None
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
@@ -736,33 +749,8 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
await super().cleanup()
|
||||
await self._client.cleanup()
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
await self._client.send_message(frame)
|
||||
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
metrics = {}
|
||||
for d in frame.data:
|
||||
if isinstance(d, TTFBMetricsData):
|
||||
if "ttfb" not in metrics:
|
||||
metrics["ttfb"] = []
|
||||
metrics["ttfb"].append(d.model_dump(exclude_none=True))
|
||||
elif isinstance(d, ProcessingMetricsData):
|
||||
if "processing" not in metrics:
|
||||
metrics["processing"] = []
|
||||
metrics["processing"].append(d.model_dump(exclude_none=True))
|
||||
elif isinstance(d, LLMUsageMetricsData):
|
||||
if "tokens" not in metrics:
|
||||
metrics["tokens"] = []
|
||||
metrics["tokens"].append(d.value.model_dump(exclude_none=True))
|
||||
elif isinstance(d, TTSUsageMetricsData):
|
||||
if "characters" not in metrics:
|
||||
metrics["characters"] = []
|
||||
metrics["characters"].append(d.model_dump(exclude_none=True))
|
||||
|
||||
message = DailyTransportMessageFrame(
|
||||
message={"type": "pipecat-metrics", "metrics": metrics}
|
||||
)
|
||||
await self._client.send_message(message)
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
await self._messages_queue.put(frame)
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
await self._client.write_raw_audio_frames(frames)
|
||||
@@ -770,6 +758,17 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
async def write_frame_to_camera(self, frame: OutputImageRawFrame):
|
||||
await self._client.write_frame_to_camera(frame)
|
||||
|
||||
async def _messages_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
message = await self._messages_queue.get()
|
||||
await self._client.send_message(message)
|
||||
self._messages_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error processing message queue: {e}")
|
||||
|
||||
|
||||
class DailyTransport(BaseTransport):
|
||||
def __init__(
|
||||
|
||||
@@ -5,13 +5,12 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from scipy import signal
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -19,24 +18,18 @@ from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
MetricsFrame,
|
||||
InputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
StartFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import (
|
||||
LLMUsageMetricsData,
|
||||
ProcessingMetricsData,
|
||||
TTFBMetricsData,
|
||||
TTSUsageMetricsData,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.vad.vad_analyzer import VADAnalyzer
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from livekit import rtc
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
@@ -51,6 +44,11 @@ class LiveKitTransportMessageFrame(TransportMessageFrame):
|
||||
participant_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LiveKitTransportMessageUrgentFrame(TransportMessageUrgentFrame):
|
||||
participant_id: str | None = None
|
||||
|
||||
|
||||
class LiveKitParams(TransportParams):
|
||||
audio_out_sample_rate: int = 48000
|
||||
audio_out_channels: int = 1
|
||||
@@ -67,6 +65,7 @@ class LiveKitCallbacks(BaseModel):
|
||||
on_audio_track_subscribed: Callable[[str], Awaitable[None]]
|
||||
on_audio_track_unsubscribed: Callable[[str], Awaitable[None]]
|
||||
on_data_received: Callable[[bytes, str], Awaitable[None]]
|
||||
on_first_participant_joined: Callable[[str], Awaitable[None]]
|
||||
|
||||
|
||||
class LiveKitTransportClient:
|
||||
@@ -92,6 +91,7 @@ class LiveKitTransportClient:
|
||||
self._audio_track: rtc.LocalAudioTrack | None = None
|
||||
self._audio_tracks = {}
|
||||
self._audio_queue = asyncio.Queue()
|
||||
self._other_participant_has_joined = False
|
||||
|
||||
# Set up room event handlers
|
||||
self._room.on("participant_connected")(self._on_participant_connected_wrapper)
|
||||
@@ -135,6 +135,12 @@ class LiveKitTransportClient:
|
||||
await self._room.local_participant.publish_track(self._audio_track, options)
|
||||
|
||||
await self._callbacks.on_connected()
|
||||
|
||||
# Check if there are already participants in the room
|
||||
participants = self.get_participants()
|
||||
if participants and not self._other_participant_has_joined:
|
||||
self._other_participant_has_joined = True
|
||||
await self._callbacks.on_first_participant_joined(participants[0])
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to {self._room_name}: {e}")
|
||||
raise
|
||||
@@ -239,10 +245,15 @@ class LiveKitTransportClient:
|
||||
async def _async_on_participant_connected(self, participant: rtc.RemoteParticipant):
|
||||
logger.info(f"Participant connected: {participant.identity}")
|
||||
await self._callbacks.on_participant_connected(participant.sid)
|
||||
if not self._other_participant_has_joined:
|
||||
self._other_participant_has_joined = True
|
||||
await self._callbacks.on_first_participant_joined(participant.sid)
|
||||
|
||||
async def _async_on_participant_disconnected(self, participant: rtc.RemoteParticipant):
|
||||
logger.info(f"Participant disconnected: {participant.identity}")
|
||||
await self._callbacks.on_participant_disconnected(participant.sid)
|
||||
if len(self.get_participants()) == 0:
|
||||
self._other_participant_has_joined = False
|
||||
|
||||
async def _async_on_track_subscribed(
|
||||
self,
|
||||
@@ -351,10 +362,15 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
if audio_data:
|
||||
audio_frame_event, participant_id = audio_data
|
||||
pipecat_audio_frame = self._convert_livekit_audio_to_pipecat(audio_frame_event)
|
||||
await self.push_audio_frame(pipecat_audio_frame)
|
||||
input_audio_frame = InputAudioRawFrame(
|
||||
audio=pipecat_audio_frame.audio,
|
||||
sample_rate=pipecat_audio_frame.sample_rate,
|
||||
num_channels=pipecat_audio_frame.num_channels,
|
||||
)
|
||||
await self.push_frame(
|
||||
pipecat_audio_frame
|
||||
) # TODO: ensure audio frames are pushed with the default BaseInputTransport.push_audio_frame()
|
||||
await self.push_audio_frame(input_audio_frame)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Audio input task cancelled")
|
||||
break
|
||||
@@ -377,9 +393,11 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
|
||||
if sample_rate != self._current_sample_rate:
|
||||
self._current_sample_rate = sample_rate
|
||||
self._vad_analyzer = VADAnalyzer(
|
||||
sample_rate=self._current_sample_rate, num_channels=self._params.audio_in_channels
|
||||
)
|
||||
if self._params.vad_enabled:
|
||||
self._vad_analyzer = VADAnalyzer(
|
||||
sample_rate=self._current_sample_rate,
|
||||
num_channels=self._params.audio_in_channels,
|
||||
)
|
||||
|
||||
return AudioRawFrame(
|
||||
audio=audio_data.tobytes(),
|
||||
@@ -420,37 +438,12 @@ class LiveKitOutputTransport(BaseOutputTransport):
|
||||
await super().cancel(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
if isinstance(frame, LiveKitTransportMessageFrame):
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
if isinstance(frame, (LiveKitTransportMessageFrame, LiveKitTransportMessageUrgentFrame)):
|
||||
await self._client.send_data(frame.message.encode(), frame.participant_id)
|
||||
else:
|
||||
await self._client.send_data(frame.message.encode())
|
||||
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
metrics = {}
|
||||
for d in frame.data:
|
||||
if isinstance(d, TTFBMetricsData):
|
||||
if "ttfb" not in metrics:
|
||||
metrics["ttfb"] = []
|
||||
metrics["ttfb"].append(d.model_dump())
|
||||
elif isinstance(d, ProcessingMetricsData):
|
||||
if "processing" not in metrics:
|
||||
metrics["processing"] = []
|
||||
metrics["processing"].append(d.model_dump())
|
||||
elif isinstance(d, LLMUsageMetricsData):
|
||||
if "tokens" not in metrics:
|
||||
metrics["tokens"] = []
|
||||
metrics["tokens"].append(d.value.model_dump(exclude_none=True))
|
||||
elif isinstance(d, TTSUsageMetricsData):
|
||||
if "characters" not in metrics:
|
||||
metrics["characters"] = []
|
||||
metrics["characters"].append(d.model_dump())
|
||||
|
||||
message = LiveKitTransportMessageFrame(
|
||||
message={"type": "pipecat-metrics", "metrics": metrics}
|
||||
)
|
||||
await self._client.send_data(str(message.message).encode())
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
livekit_audio = self._convert_pipecat_audio_to_livekit(frames)
|
||||
await self._client.publish_audio(livekit_audio)
|
||||
@@ -481,13 +474,20 @@ class LiveKitTransport(BaseTransport):
|
||||
):
|
||||
super().__init__(input_name=input_name, output_name=output_name, loop=loop)
|
||||
|
||||
self._url = url
|
||||
self._token = token
|
||||
self._room_name = room_name
|
||||
callbacks = LiveKitCallbacks(
|
||||
on_connected=self._on_connected,
|
||||
on_disconnected=self._on_disconnected,
|
||||
on_participant_connected=self._on_participant_connected,
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
on_audio_track_subscribed=self._on_audio_track_subscribed,
|
||||
on_audio_track_unsubscribed=self._on_audio_track_unsubscribed,
|
||||
on_data_received=self._on_data_received,
|
||||
on_first_participant_joined=self._on_first_participant_joined,
|
||||
)
|
||||
self._params = params
|
||||
|
||||
self._client = LiveKitTransportClient(
|
||||
url, token, room_name, self._params, self._create_callbacks(), self._loop
|
||||
url, token, room_name, self._params, callbacks, self._loop
|
||||
)
|
||||
self._input: LiveKitInputTransport | None = None
|
||||
self._output: LiveKitOutputTransport | None = None
|
||||
@@ -503,23 +503,12 @@ class LiveKitTransport(BaseTransport):
|
||||
self._register_event_handler("on_participant_left")
|
||||
self._register_event_handler("on_call_state_updated")
|
||||
|
||||
def _create_callbacks(self) -> LiveKitCallbacks:
|
||||
return LiveKitCallbacks(
|
||||
on_connected=self._on_connected,
|
||||
on_disconnected=self._on_disconnected,
|
||||
on_participant_connected=self._on_participant_connected,
|
||||
on_participant_disconnected=self._on_participant_disconnected,
|
||||
on_audio_track_subscribed=self._on_audio_track_subscribed,
|
||||
on_audio_track_unsubscribed=self._on_audio_track_unsubscribed,
|
||||
on_data_received=self._on_data_received,
|
||||
)
|
||||
|
||||
def input(self) -> FrameProcessor:
|
||||
def input(self) -> LiveKitInputTransport:
|
||||
if not self._input:
|
||||
self._input = LiveKitInputTransport(self._client, self._params, name=self._input_name)
|
||||
return self._input
|
||||
|
||||
def output(self) -> FrameProcessor:
|
||||
def output(self) -> LiveKitOutputTransport:
|
||||
if not self._output:
|
||||
self._output = LiveKitOutputTransport(
|
||||
self._client, self._params, name=self._output_name
|
||||
@@ -530,7 +519,7 @@ class LiveKitTransport(BaseTransport):
|
||||
def participant_id(self) -> str:
|
||||
return self._client.participant_id
|
||||
|
||||
async def send_audio(self, frame: AudioRawFrame):
|
||||
async def send_audio(self, frame: OutputAudioRawFrame):
|
||||
if self._output:
|
||||
await self._output.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
@@ -563,8 +552,6 @@ class LiveKitTransport(BaseTransport):
|
||||
|
||||
async def _on_participant_connected(self, participant_id: str):
|
||||
await self._call_event_handler("on_participant_connected", participant_id)
|
||||
if len(self.get_participants()) == 1:
|
||||
await self._call_event_handler("on_first_participant_joined", participant_id)
|
||||
|
||||
async def _on_participant_disconnected(self, participant_id: str):
|
||||
await self._call_event_handler("on_participant_disconnected", participant_id)
|
||||
@@ -596,6 +583,13 @@ class LiveKitTransport(BaseTransport):
|
||||
frame = LiveKitTransportMessageFrame(message=message, participant_id=participant_id)
|
||||
await self._output.send_message(frame)
|
||||
|
||||
async def send_message_urgent(self, message: str, participant_id: str | None = None):
|
||||
if self._output:
|
||||
frame = LiveKitTransportMessageUrgentFrame(
|
||||
message=message, participant_id=participant_id
|
||||
)
|
||||
await self._output.send_message(frame)
|
||||
|
||||
async def cleanup(self):
|
||||
if self._input:
|
||||
await self._input.cleanup()
|
||||
@@ -617,3 +611,6 @@ class LiveKitTransport(BaseTransport):
|
||||
|
||||
async def _on_call_state_updated(self, state: str):
|
||||
await self._call_event_handler("on_call_state_updated", self, state)
|
||||
|
||||
async def _on_first_participant_joined(self, participant_id: str):
|
||||
await self._call_event_handler("on_first_participant_joined", participant_id)
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
import re
|
||||
|
||||
|
||||
ENDOFSENTENCE_PATTERN_STR = r"""
|
||||
(?<![A-Z]) # Negative lookbehind: not preceded by an uppercase letter (e.g., "U.S.A.")
|
||||
(?<!\d) # Negative lookbehind: not preceded by a digit (e.g., "1. Let's start")
|
||||
@@ -21,5 +20,6 @@ ENDOFSENTENCE_PATTERN_STR = r"""
|
||||
ENDOFSENTENCE_PATTERN = re.compile(ENDOFSENTENCE_PATTERN_STR, re.VERBOSE)
|
||||
|
||||
|
||||
def match_endofsentence(text: str) -> bool:
|
||||
return ENDOFSENTENCE_PATTERN.search(text.rstrip()) is not None
|
||||
def match_endofsentence(text: str) -> int:
|
||||
match = ENDOFSENTENCE_PATTERN.search(text.rstrip())
|
||||
return match.end() if match else 0
|
||||
|
||||
0
src/pipecat/utils/text/__init__.py
Normal file
0
src/pipecat/utils/text/__init__.py
Normal file
26
src/pipecat/utils/text/base_text_filter.py
Normal file
26
src/pipecat/utils/text/base_text_filter.py
Normal file
@@ -0,0 +1,26 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
class BaseTextFilter(ABC):
|
||||
@abstractmethod
|
||||
def update_settings(self, settings: Mapping[str, Any]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def filter(self, text: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_interruption(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_interruption(self):
|
||||
pass
|
||||
216
src/pipecat/utils/text/markdown_text_filter.py
Normal file
216
src/pipecat/utils/text/markdown_text_filter.py
Normal file
@@ -0,0 +1,216 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import re
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from markdown import Markdown
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
|
||||
|
||||
class MarkdownTextFilter(BaseTextFilter):
|
||||
"""Removes Markdown formatting from text in TextFrames.
|
||||
|
||||
Converts Markdown to plain text while preserving the overall structure,
|
||||
including leading and trailing spaces. Handles special cases like
|
||||
asterisks and table formatting.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
enable_text_filter: Optional[bool] = True
|
||||
filter_code: Optional[bool] = False
|
||||
filter_tables: Optional[bool] = False
|
||||
|
||||
def __init__(self, params: InputParams = InputParams(), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._settings = params
|
||||
self._in_code_block = False
|
||||
self._in_table = False
|
||||
self._interrupted = False
|
||||
|
||||
def update_settings(self, settings: Mapping[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if hasattr(self._settings, key):
|
||||
setattr(self._settings, key, value)
|
||||
|
||||
def filter(self, text: str) -> str:
|
||||
if self._settings.enable_text_filter:
|
||||
# Remove newlines and replace with a space only when there's no text before or after
|
||||
filtered_text = re.sub(r"^\s*\n", " ", text, flags=re.MULTILINE)
|
||||
|
||||
# Remove backticks from inline code, but not from code blocks
|
||||
filtered_text = re.sub(r"(?<!`)`([^`\n]+)`(?!`)", r"\1", filtered_text)
|
||||
|
||||
# Remove repeated sequences of 5 or more characters
|
||||
filtered_text = re.sub(r"(\S)(\1{4,})", "", filtered_text)
|
||||
|
||||
# Preserve numbered list items with a unique marker, §NUM§
|
||||
filtered_text = re.sub(r"^(\d+\.)\s", r"§NUM§\1 ", filtered_text)
|
||||
|
||||
# Preserve leading/trailing spaces with a unique marker, §
|
||||
# Critical for word-by-word streaming in bot-tts-text
|
||||
filtered_text = re.sub(
|
||||
r"^( +)|\s+$", lambda m: "§" * len(m.group(0)), filtered_text, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# Remove space placeholders before tables, so that tables are converted to HTML
|
||||
# correctly
|
||||
filtered_text = re.sub(r"§\| ", "| ", filtered_text)
|
||||
|
||||
# Convert markdown to HTML
|
||||
extension = ["tables"] if self._settings.filter_tables else []
|
||||
md = Markdown(extensions=extension)
|
||||
filtered_text = md.convert(filtered_text)
|
||||
|
||||
# Remove tables
|
||||
if self._settings.filter_tables:
|
||||
filtered_text = self.remove_tables(filtered_text)
|
||||
|
||||
# Remove HTML tags
|
||||
filtered_text = re.sub("<[^<]+?>", "", filtered_text)
|
||||
|
||||
# Replace HTML entities
|
||||
filtered_text = filtered_text.replace(" ", " ")
|
||||
filtered_text = filtered_text.replace("<", "<")
|
||||
filtered_text = filtered_text.replace(">", ">")
|
||||
filtered_text = filtered_text.replace("&", "&")
|
||||
|
||||
# Remove double asterisks (consecutive without any exceptions)
|
||||
filtered_text = re.sub(r"\*\*", "", filtered_text)
|
||||
|
||||
# Remove single asterisks at the start or end of words
|
||||
filtered_text = re.sub(r"(^|\s)\*|\*($|\s)", r"\1\2", filtered_text)
|
||||
|
||||
# Remove Markdown table formatting
|
||||
filtered_text = re.sub(r"\|", "", filtered_text)
|
||||
filtered_text = re.sub(r"^\s*[-:]+\s*$", "", filtered_text, flags=re.MULTILINE)
|
||||
|
||||
# Remove code blocks
|
||||
if self._settings.filter_code:
|
||||
filtered_text = self._remove_code_blocks(filtered_text)
|
||||
|
||||
# Restore numbered list items
|
||||
filtered_text = filtered_text.replace("§NUM§", "")
|
||||
|
||||
# Restore leading and trailing spaces
|
||||
filtered_text = re.sub("§", " ", filtered_text)
|
||||
|
||||
return filtered_text
|
||||
else:
|
||||
return text
|
||||
|
||||
def handle_interruption(self):
|
||||
self._interrupted = True
|
||||
self._in_code_block = False
|
||||
self._in_table = False
|
||||
|
||||
def reset_interruption(self):
|
||||
self._interrupted = False
|
||||
|
||||
#
|
||||
# Filter code
|
||||
#
|
||||
|
||||
def _remove_code_blocks(self, text: str) -> str:
|
||||
"""
|
||||
Main method to remove code blocks from the input text.
|
||||
Handles interruptions and delegates to specific methods based on the current state.
|
||||
"""
|
||||
if self._interrupted:
|
||||
self._in_code_block = False
|
||||
return text
|
||||
|
||||
# Pattern to match three consecutive backticks (code block delimiter)
|
||||
code_block_pattern = r"```"
|
||||
match = re.search(code_block_pattern, text)
|
||||
|
||||
if self._in_code_block:
|
||||
return self._handle_in_code_block(match, text)
|
||||
|
||||
return self._handle_not_in_code_block(match, text, code_block_pattern)
|
||||
|
||||
def _handle_in_code_block(self, match, text):
|
||||
"""
|
||||
Handle text when we're currently inside a code block.
|
||||
If we find the end of the block, return text after it. Otherwise, skip the content.
|
||||
"""
|
||||
if match:
|
||||
self._in_code_block = False
|
||||
end_index = match.end()
|
||||
return text[end_index:].strip()
|
||||
return "" # Skip content inside code block
|
||||
|
||||
def _handle_not_in_code_block(self, match, text, code_block_pattern):
|
||||
"""
|
||||
Handle text when we're not currently inside a code block.
|
||||
Delegate to specific methods based on whether we find a code block delimiter.
|
||||
"""
|
||||
if not match:
|
||||
return text # No code block found, return original text
|
||||
|
||||
start_index = match.start()
|
||||
if start_index == 0 or text[:start_index].isspace():
|
||||
return self._handle_start_of_code_block(text, start_index)
|
||||
return self._handle_code_block_within_text(text, code_block_pattern)
|
||||
|
||||
def _handle_start_of_code_block(self, text, start_index):
|
||||
"""
|
||||
Handle the case where we find the start of a code block.
|
||||
Return any text before the code block and set the state to inside a code block.
|
||||
"""
|
||||
self._in_code_block = True
|
||||
return text[:start_index].strip()
|
||||
|
||||
def _handle_code_block_within_text(self, text, code_block_pattern):
|
||||
"""
|
||||
Handle the case where we find a code block within the text.
|
||||
If it's a complete code block, remove it and return surrounding text.
|
||||
If it's the start of a code block, return text before it and set state.
|
||||
"""
|
||||
parts = re.split(code_block_pattern, text)
|
||||
if len(parts) > 2:
|
||||
return (parts[0] + " " + parts[-1]).strip()
|
||||
self._in_code_block = True
|
||||
return parts[0].strip()
|
||||
|
||||
#
|
||||
# Filter tables
|
||||
#
|
||||
def remove_tables(self, text: str) -> str:
|
||||
"""
|
||||
Remove tables from the input text, handling cases where
|
||||
both start and end tags are in the same input.
|
||||
"""
|
||||
if self._interrupted:
|
||||
self._in_table = False
|
||||
return text
|
||||
|
||||
# Pattern to match entire table or parts of it
|
||||
table_pattern = r"<table>.*?</table>"
|
||||
partial_table_start = r"<table>.*"
|
||||
partial_table_end = r".*</table>"
|
||||
|
||||
# Remove complete tables
|
||||
text = re.sub(table_pattern, "", text, flags=re.DOTALL | re.IGNORECASE)
|
||||
|
||||
# Handle partial tables at the start
|
||||
if self._in_table:
|
||||
match = re.match(partial_table_end, text, re.DOTALL | re.IGNORECASE)
|
||||
if match:
|
||||
self._in_table = False
|
||||
return text[match.end() :].strip()
|
||||
else:
|
||||
return "" # Still inside a table, remove all content
|
||||
|
||||
# Handle partial tables at the end
|
||||
match = re.search(partial_table_start, text, re.DOTALL | re.IGNORECASE)
|
||||
if match:
|
||||
self._in_table = True
|
||||
return text[: match.start()].strip()
|
||||
|
||||
return text.strip()
|
||||
@@ -2,10 +2,10 @@ aiohttp~=3.10.3
|
||||
anthropic~=0.30.0
|
||||
azure-cognitiveservices-speech~=1.40.0
|
||||
boto3~=1.35.27
|
||||
daily-python~=0.10.1
|
||||
daily-python~=0.11.0
|
||||
deepgram-sdk~=3.5.0
|
||||
fal-client~=0.4.1
|
||||
fastapi~=0.112.1
|
||||
fastapi~=0.115.0
|
||||
faster-whisper~=1.0.3
|
||||
google-cloud-texttospeech~=2.17.2
|
||||
google-generativeai~=0.7.2
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StopTaskFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -32,6 +32,7 @@ from langchain_core.language_models import FakeStreamingListLLM
|
||||
class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
class MockProcessor(FrameProcessor):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.token: list[str] = []
|
||||
# Start collecting tokens when we see the start frame
|
||||
@@ -55,13 +56,13 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.expected_response = "Hello dear human"
|
||||
self.fake_llm = FakeStreamingListLLM(responses=[self.expected_response])
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
async def test_langchain(self):
|
||||
messages = [("system", "Say hello to {name}"), ("human", "{input}")]
|
||||
prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas")
|
||||
chain = prompt | self.fake_llm
|
||||
proc = LangchainProcessor(chain=chain)
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
@@ -81,7 +82,7 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
|
||||
UserStoppedSpeakingFrame(),
|
||||
StopTaskFrame(),
|
||||
EndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user