Compare commits
197 Commits
khk-greedy
...
aleix/stop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d6b8b035e | ||
|
|
0a15874c12 | ||
|
|
d60e99a043 | ||
|
|
77723b34c7 | ||
|
|
c466d34a06 | ||
|
|
f816897833 | ||
|
|
c1e8a5e522 | ||
|
|
76aca32f2e | ||
|
|
7e31b2a795 | ||
|
|
028e38a86b | ||
|
|
8cf7649855 | ||
|
|
64f5119b08 | ||
|
|
4d606aefb3 | ||
|
|
4bafdaa04d | ||
|
|
5afe1abf82 | ||
|
|
f066d50b98 | ||
|
|
91103e21cc | ||
|
|
f44dabcd65 | ||
|
|
0fd2fca231 | ||
|
|
5bb64098e7 | ||
|
|
3fc85e75e0 | ||
|
|
3f61ea16b7 | ||
|
|
4b393092b5 | ||
|
|
b583f5162b | ||
|
|
060a22f395 | ||
|
|
d3e85355f1 | ||
|
|
83e730b768 | ||
|
|
5fcc96446c | ||
|
|
ad88925154 | ||
|
|
0a6ddbf15c | ||
|
|
08e0722d97 | ||
|
|
05d4fba551 | ||
|
|
f41c2b3c9f | ||
|
|
69f64899fe | ||
|
|
33f0865430 | ||
|
|
ad5b9202ab | ||
|
|
1676693091 | ||
|
|
0852b50b8f | ||
|
|
eb998aa502 | ||
|
|
6dab0e9de7 | ||
|
|
95ff1d141c | ||
|
|
87bc8a9da6 | ||
|
|
087fe9a537 | ||
|
|
c1170260b5 | ||
|
|
65cdf50774 | ||
|
|
9233bb490c | ||
|
|
43932220f7 | ||
|
|
cea4d1894e | ||
|
|
80baa0358d | ||
|
|
5d73db53a0 | ||
|
|
302ea90dce | ||
|
|
37b04ed283 | ||
|
|
be6995cfdf | ||
|
|
dfbc11300c | ||
|
|
82d539d174 | ||
|
|
6e00f31014 | ||
|
|
a46ac3cc92 | ||
|
|
6fbf98d8e2 | ||
|
|
f094c42728 | ||
|
|
13827e1282 | ||
|
|
32170b47d9 | ||
|
|
09c05354c2 | ||
|
|
b0b1475563 | ||
|
|
b85dd7283a | ||
|
|
846ae765e5 | ||
|
|
4c629e538e | ||
|
|
f6e22bb3b9 | ||
|
|
46a048d7f6 | ||
|
|
bd9f4eea06 | ||
|
|
0a672e61e2 | ||
|
|
29a8530221 | ||
|
|
3e738642a7 | ||
|
|
f551f55f03 | ||
|
|
9f012c8002 | ||
|
|
0a69a9e5ef | ||
|
|
194790183a | ||
|
|
2227721173 | ||
|
|
77a53da5f5 | ||
|
|
ab63ff275d | ||
|
|
e5363f65f0 | ||
|
|
ffc157de65 | ||
|
|
f9fdadb4c0 | ||
|
|
4efccb79f2 | ||
|
|
337968199a | ||
|
|
37027f68cb | ||
|
|
d1b62c5495 | ||
|
|
355fe01cb7 | ||
|
|
9d050a16c7 | ||
|
|
fa53c67606 | ||
|
|
5006376fe6 | ||
|
|
2204b8e205 | ||
|
|
270007b17c | ||
|
|
568eb2ef4c | ||
|
|
73ca9184a8 | ||
|
|
5e8e11e16e | ||
|
|
029bbc16f2 | ||
|
|
9e3d87e4f6 | ||
|
|
f1410a1127 | ||
|
|
2b980d16c3 | ||
|
|
b2b97aafb8 | ||
|
|
da2082b025 | ||
|
|
327ea9d547 | ||
|
|
b23db4a202 | ||
|
|
d1a36004ab | ||
|
|
6071920c45 | ||
|
|
5f539e1fba | ||
|
|
8e1539c360 | ||
|
|
065cfb2aca | ||
|
|
3147534e86 | ||
|
|
be5603bf16 | ||
|
|
b9b0bcdcbd | ||
|
|
5bcece56f3 | ||
|
|
d67faef88c | ||
|
|
8f6db5e905 | ||
|
|
82e93a0560 | ||
|
|
a9a82c083b | ||
|
|
974d9c33ed | ||
|
|
c1957ab694 | ||
|
|
b20a10a4bc | ||
|
|
be14ce465d | ||
|
|
d1ca0c5614 | ||
|
|
535514f506 | ||
|
|
933b63cf13 | ||
|
|
d7c3e380a5 | ||
|
|
c5298f78cb | ||
|
|
4f8f7b8d1d | ||
|
|
d7d46919ac | ||
|
|
e5d73d2e2e | ||
|
|
b145e8ec90 | ||
|
|
97ff4a1fb8 | ||
|
|
5018a552c1 | ||
|
|
7f9fd9ffce | ||
|
|
ddd0ca6a8f | ||
|
|
06f817c7e3 | ||
|
|
df4c3e56c4 | ||
|
|
9d5c2b9656 | ||
|
|
7ce59c5e2e | ||
|
|
1c9631fc78 | ||
|
|
efbe7297f7 | ||
|
|
1b45946a61 | ||
|
|
cbf5a6362c | ||
|
|
583b96c341 | ||
|
|
fc0920504d | ||
|
|
abd65a93b2 | ||
|
|
c3244fdd7a | ||
|
|
e8f58938b0 | ||
|
|
602b4f34b1 | ||
|
|
0399c84dfa | ||
|
|
fd5d879bf5 | ||
|
|
8dff460307 | ||
|
|
cce1ddb183 | ||
|
|
8691d14289 | ||
|
|
dd402da9e5 | ||
|
|
2fd04248f1 | ||
|
|
0ac42006f8 | ||
|
|
66e331248d | ||
|
|
4be3e8c87d | ||
|
|
dac033fe61 | ||
|
|
d302cbb114 | ||
|
|
e3b407db28 | ||
|
|
4ef623f09e | ||
|
|
253530a63d | ||
|
|
4f38d989f5 | ||
|
|
84074e90ee | ||
|
|
38aee7d8f2 | ||
|
|
64198313c6 | ||
|
|
d61b6c301c | ||
|
|
83d1931266 | ||
|
|
c31f2ab285 | ||
|
|
0ddc5721b4 | ||
|
|
98bd183bc4 | ||
|
|
aaa154524c | ||
|
|
beced68337 | ||
|
|
94823ab952 | ||
|
|
0b6a19802f | ||
|
|
c4a2d2197c | ||
|
|
269d06aa15 | ||
|
|
dfef1f2c54 | ||
|
|
b62beaba0b | ||
|
|
adf414e40f | ||
|
|
dc64e57f63 | ||
|
|
d3e410b2ac | ||
|
|
c544b2474b | ||
|
|
18243de358 | ||
|
|
6625895d1f | ||
|
|
f9ecce739e | ||
|
|
0075dd8386 | ||
|
|
eef1cde816 | ||
|
|
8d867c30c6 | ||
|
|
42c668b7ae | ||
|
|
b62227b4ae | ||
|
|
25ef0cb87b | ||
|
|
e195941aa5 | ||
|
|
e09eef1dd7 | ||
|
|
7c13663a4e | ||
|
|
5753869e5e | ||
|
|
ba878a19f4 |
7
.github/workflows/publish_test.yaml
vendored
7
.github/workflows/publish_test.yaml
vendored
@@ -1,10 +1,6 @@
|
||||
name: publish-test
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
on: workflow_dispatch
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@@ -14,7 +10,6 @@ jobs:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.gitref }}
|
||||
fetch-tags: true
|
||||
fetch-depth: 100
|
||||
- name: Set up Python
|
||||
|
||||
232
CHANGELOG.md
232
CHANGELOG.md
@@ -9,10 +9,240 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added new `AzureSTTService`. This allows you to use Azure Speech-To-Text.
|
||||
- Added new `BotStartedSpeakingFrame` and `BotStoppedSpeakingFrame` control
|
||||
frames. These frames are pushed upstream and they should wrap
|
||||
`BotSpeakingFrame`.
|
||||
|
||||
- Transports now allow you to register event handlers without decorators.
|
||||
|
||||
### Changed
|
||||
|
||||
- `BotSpeakingFrame` is now a control frame.
|
||||
|
||||
- `StartFrame` is now a control frame similar to `EndFrame`.
|
||||
|
||||
- `DeepgramTTSService` now is more customizable. You can adjust the encoding and
|
||||
sample rate.
|
||||
|
||||
### Fixed
|
||||
|
||||
- RTVI's `bot-ready` message is now sent when the RTVI pipeline is ready and
|
||||
a first participant joins.
|
||||
|
||||
- Fixed a `BaseInputTransport` issue that was causing incoming system frames to
|
||||
be queued instead of being pushed immediately.
|
||||
|
||||
- Fixed a `BaseInputTransport` issue that was causing start/stop interruptions
|
||||
incoming frames to not cancel tasks and be processed properly.
|
||||
|
||||
## [0.0.39] - 2024-07-23
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a regression introduced in 0.0.38 that would cause Daily transcription
|
||||
to stop the Pipeline.
|
||||
|
||||
## [0.0.38] - 2024-07-23
|
||||
|
||||
### Added
|
||||
|
||||
- Added `force_reload`, `skip_validation` and `trust_repo` to `SileroVAD` and
|
||||
`SileroVADAnalyzer`. This allows caching and various GitHub repo validations.
|
||||
|
||||
- Added `send_initial_empty_metrics` flag to `PipelineParams` to request for
|
||||
initial empty metrics (zero values). True by default.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed initial metrics format. It was using the wrong keys name/time instead of
|
||||
processor/value.
|
||||
|
||||
- STT services should be using ISO 8601 time format for transcription frames.
|
||||
|
||||
- Fixed an issue that would cause Daily transport to show a stop transcription
|
||||
error when actually none occurred.
|
||||
|
||||
## [0.0.37] - 2024-07-22
|
||||
|
||||
### Added
|
||||
|
||||
- Added `RTVIProcessor` which implements the RTVI-AI standard.
|
||||
See https://github.com/rtvi-ai
|
||||
|
||||
- Added `BotInterruptionFrame` which allows interrupting the bot while talking.
|
||||
|
||||
- Added `LLMMessagesAppendFrame` which allows appending messages to the current
|
||||
LLM context.
|
||||
|
||||
- Added `LLMMessagesUpdateFrame` which allows changing the LLM context for the
|
||||
one provided in this new frame.
|
||||
|
||||
- Added `LLMModelUpdateFrame` which allows updating the LLM model.
|
||||
|
||||
- Added `TTSSpeakFrame` which causes the bot to say some text. This text will not
|
||||
be part of the LLM context.
|
||||
|
||||
- Added `TTSVoiceUpdateFrame` which allows updating the TTS voice.
|
||||
|
||||
### Removed
|
||||
|
||||
- We remove the `LLMResponseStartFrame` and `LLMResponseEndFrame` frames. These
|
||||
were added in the past to properly handle interruptions for the
|
||||
`LLMAssistantContextAggregator`. But the `LLMContextAggregator` is now based
|
||||
on `LLMResponseAggregator` which handles interruptions properly by just
|
||||
processing the `StartInterruptionFrame`, so there's no need for these extra
|
||||
frames any more.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue with `StatelessTextTransformer` where it was pushing a string
|
||||
instead of a `TextFrame`.
|
||||
|
||||
- `TTSService` end of sentence detection has been improved. It now works with
|
||||
acronyms, numbers, hours and others.
|
||||
|
||||
- Fixed an issue in `TTSService` that would not properly flush the current
|
||||
aggregated sentence if an `LLMFullResponseEndFrame` was found.
|
||||
|
||||
### Performance
|
||||
|
||||
- `CartesiaTTSService` now uses websockets which improves speed. It also
|
||||
leverages the new Cartesia contexts which maintains generated audio prosody
|
||||
when multiple inputs are sent, therefore improving audio quality a lot.
|
||||
|
||||
## [0.0.36] - 2024-07-02
|
||||
|
||||
### Added
|
||||
|
||||
- Added `GladiaSTTService`.
|
||||
See https://docs.gladia.io/chapters/speech-to-text-api/pages/live-speech-recognition
|
||||
|
||||
- Added `XTTSService`. This is a local Text-To-Speech service.
|
||||
See https://github.com/coqui-ai/TTS
|
||||
|
||||
- Added `UserIdleProcessor`. This processor can be used to wait for any
|
||||
interaction with the user. If the user doesn't say anything within a given
|
||||
timeout a provided callback is called.
|
||||
|
||||
- Added `IdleFrameProcessor`. This processor can be used to wait for frames
|
||||
within a given timeout. If no frame is received within the timeout a provided
|
||||
callback is called.
|
||||
|
||||
- Added new frame `BotSpeakingFrame`. This frame will be continuously pushed
|
||||
upstream while the bot is talking.
|
||||
|
||||
- It is now possible to specify a Silero VAD version when using `SileroVADAnalyzer`
|
||||
or `SileroVAD`.
|
||||
|
||||
- Added `AsyncFrameProcessor` and `AsyncAIService`. Some services like
|
||||
`DeepgramSTTService` need to process things asynchronously. For example, audio
|
||||
is sent to Deepgram but transcriptions are not returned immediately. In these
|
||||
cases we still require all frames (except system frames) to be pushed
|
||||
downstream from a single task. That's what `AsyncFrameProcessor` is for. It
|
||||
creates a task and all frames should be pushed from that task. So, whenever a
|
||||
new Deepgram transcription is ready that transcription will also be pushed
|
||||
from this internal task.
|
||||
|
||||
- The `MetricsFrame` now includes processing metrics if metrics are enabled. The
|
||||
processing metrics indicate the time a processor needs to generate all its
|
||||
output. Note that not all processors generate these kinds of metrics.
|
||||
|
||||
### Changed
|
||||
|
||||
- `WhisperSTTService` model can now also be a string.
|
||||
|
||||
- Added missing * keyword separators in services.
|
||||
|
||||
### Fixed
|
||||
|
||||
- `WebsocketServerTransport` doesn't try to send frames anymore if the serializer
|
||||
returns `None`.
|
||||
|
||||
- Fixed an issue where exceptions that occurred inside frame processors were
|
||||
being swallowed and not displayed.
|
||||
|
||||
- Fixed an issue in `FastAPIWebsocketTransport` where it would still try to send
|
||||
data to the websocket after being closed.
|
||||
|
||||
### Other
|
||||
|
||||
- Added Fly.io deployment example in `examples/deployment/flyio-example`.
|
||||
|
||||
- Added new `17-detect-user-idle.py` example that shows how to use the new
|
||||
`UserIdleProcessor`.
|
||||
|
||||
## [0.0.35] - 2024-06-28
|
||||
|
||||
### Changed
|
||||
|
||||
- `FastAPIWebsocketParams` now require a serializer.
|
||||
|
||||
- `TwilioFrameSerializer` now requires a `streamSid`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Silero VAD number of frames needs to be 512 for 16000 sample rate or 256 for
|
||||
8000 sample rate.
|
||||
|
||||
## [0.0.34] - 2024-06-25
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue with asynchronous STT services (Deepgram and Azure) that could
|
||||
cause interruptions to ignore transcriptions.
|
||||
|
||||
- Fixed an issue introduced in 0.0.33 that would cause the LLM to generate
|
||||
shorter output.
|
||||
|
||||
## [0.0.33] - 2024-06-25
|
||||
|
||||
### Changed
|
||||
|
||||
- Upgraded to Cartesia's new Python library 1.0.0. `CartesiaTTSService` now
|
||||
expects a voice ID instead of a voice name (you can get the voice ID from
|
||||
Cartesia's playground). You can also specify the audio `sample_rate` and
|
||||
`encoding` instead of the previous `output_format`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue with asynchronous STT services (Deepgram and Azure) that could
|
||||
cause static audio issues and interruptions to not work properly when dealing
|
||||
with multiple LLM sentences.
|
||||
|
||||
- Fixed an issue that could mix new LLM responses with previous ones when
|
||||
handling interruptions.
|
||||
|
||||
- Fixed a Daily transport blocking situation that occurred while reading audio
|
||||
frames after a participant left the room. Needs daily-python >= 0.10.1.
|
||||
|
||||
## [0.0.32] - 2024-06-22
|
||||
|
||||
### Added
|
||||
|
||||
- Allow specifying a `DeepgramSTTService` url which allows using on-prem
|
||||
Deepgram.
|
||||
|
||||
- Added new `FastAPIWebsocketTransport`. This is a new websocket transport that
|
||||
can be integrated with FastAPI websockets.
|
||||
|
||||
- Added new `TwilioFrameSerializer`. This is a new serializer that knows how to
|
||||
serialize and deserialize audio frames from Twilio.
|
||||
|
||||
- Added Daily transport event: `on_dialout_answered`. See
|
||||
https://reference-python.daily.co/api_reference.html#daily.EventHandler
|
||||
|
||||
- Added new `AzureSTTService`. This allows you to use Azure Speech-To-Text.
|
||||
|
||||
### Performance
|
||||
|
||||
- Convert `BaseInputTransport` and `BaseOutputTransport` to fully use asyncio
|
||||
and remove the use of threads.
|
||||
|
||||
### Other
|
||||
|
||||
- Added `twilio-chatbot`. This is an example that shows how to integrate Twilio
|
||||
phone numbers with a Pipecat bot.
|
||||
|
||||
- Updated `07f-interruptible-azure.py` to use `AzureLLMService`,
|
||||
`AzureSTTService` and `AzureTTSService`.
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ pip install "pipecat-ai[option,...]"
|
||||
|
||||
Your project may or may not need these, so they're made available as optional requirements. Here is a list:
|
||||
|
||||
- **AI services**: `anthropic`, `azure`, `deepgram`, `google`, `fal`, `moondream`, `openai`, `openpipe`, `playht`, `silero`, `whisper`
|
||||
- **AI services**: `anthropic`, `azure`, `deepgram`, `gladia`, `google`, `fal`, `moondream`, `openai`, `openpipe`, `playht`, `silero`, `whisper`, `xtts`
|
||||
- **Transports**: `local`, `websocket`, `daily`
|
||||
|
||||
## Code examples
|
||||
@@ -70,8 +70,8 @@ async def main():
|
||||
transport = DailyTransport(
|
||||
room_url=...,
|
||||
token=...,
|
||||
"Bot Name",
|
||||
DailyParams(audio_out_enabled=True))
|
||||
bot_name="Bot Name",
|
||||
params=DailyParams(audio_out_enabled=True))
|
||||
|
||||
# Use Eleven Labs for Text-to-Speech
|
||||
tts = ElevenLabsTTSService(
|
||||
@@ -125,7 +125,7 @@ Sign up [here](https://dashboard.daily.co/u/signup) and [create a room](https://
|
||||
|
||||
Voice Activity Detection — very important for knowing when a user has finished speaking to your bot. If you are not using press-to-talk, and want Pipecat to detect when the user has finished talking, VAD is an essential component for a natural feeling conversation.
|
||||
|
||||
Pipecast makes use of WebRTC VAD by default when using a WebRTC transport layer. Optionally, you can use Silero VAD for improved accuracy at the cost of higher CPU usage.
|
||||
Pipecat makes use of WebRTC VAD by default when using a WebRTC transport layer. Optionally, you can use Silero VAD for improved accuracy at the cost of higher CPU usage.
|
||||
|
||||
```shell
|
||||
pip install pipecat-ai[silero]
|
||||
|
||||
@@ -2,6 +2,7 @@ autopep8~=2.1.0
|
||||
build~=1.2.1
|
||||
grpcio-tools~=1.62.2
|
||||
pip-tools~=7.4.1
|
||||
pyright~=1.1.367
|
||||
pytest~=8.2.0
|
||||
setuptools~=69.5.1
|
||||
setuptools~=71.1.0
|
||||
setuptools_scm~=8.1.0
|
||||
|
||||
@@ -27,6 +27,9 @@ FAL_KEY=...
|
||||
# Fireworks
|
||||
FIREWORKS_API_KEY=...
|
||||
|
||||
# Gladia
|
||||
GLADIA_API_KEY=...
|
||||
|
||||
# PlayHT
|
||||
PLAY_HT_USER_ID=...
|
||||
PLAY_HT_API_KEY=...
|
||||
|
||||
@@ -32,14 +32,15 @@ Next, follow the steps in the README for each demo.
|
||||
|
||||
## Projects:
|
||||
|
||||
| Project | Description | Services |
|
||||
| -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------- |
|
||||
| [Simple Chatbot](simple-chatbot) | Basic voice-driven conversational bot. A good starting point for learning the flow of the framework. | Deepgram, OpenAI, Daily, Daily Prebuilt UI |
|
||||
| [Storytelling Chatbot](storytelling-chatbot) | Stitches together multiple third-party services to create a collaborative storytime experience. | Deepgram, ElevenLabs, Open AI, Fal, Daily, Custom UI |
|
||||
| [Translation Chatbot](translation-chatbot) | Listens for user speech, then translates that speech to Spanish and speaks the translation back. Demonstrates multi-participant use-cases. | Deepgram, Azure, OpenAI, Daily, Daily Prebuilt UI |
|
||||
| [Moondream Chatbot](moondream-chatbot) | Demonstrates how to add vision capabilities to GPT4. **Note: works best with a GPU** | Deepgram, OpenAI, Moondream, Daily, Daily Prebuilt UI |
|
||||
| Function-calling Chatbot (TBC) | A chatbot that can call functions in response to user input. | Deepgram, OpenAI, Fireworks, Daily, Daily Prebuilt UI |
|
||||
| [Dialin Chatbot](dialin-chatbot) | A chatbot that connects to an incoming phone call from Daily or Twilio. | Deepgram, OpenAI, ElevenLabs, Daily, Twilio |
|
||||
| Project | Description | Services |
|
||||
|----------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------|
|
||||
| [Simple Chatbot](simple-chatbot) | Basic voice-driven conversational bot. A good starting point for learning the flow of the framework. | Deepgram, ElevenLabs, OpenAI, Daily, Daily Prebuilt UI |
|
||||
| [Storytelling Chatbot](storytelling-chatbot) | Stitches together multiple third-party services to create a collaborative storytime experience. | Deepgram, ElevenLabs, OpenAI, Fal, Daily, Custom UI |
|
||||
| [Translation Chatbot](translation-chatbot) | Listens for user speech, then translates that speech to Spanish and speaks the translation back. Demonstrates multi-participant use-cases. | Deepgram, Azure, OpenAI, Daily, Daily Prebuilt UI |
|
||||
| [Moondream Chatbot](moondream-chatbot) | Demonstrates how to add vision capabilities to GPT4. **Note: works best with a GPU** | Deepgram, ElevenLabs, OpenAI, Moondream, Daily, Daily Prebuilt UI |
|
||||
| [Patient intake](patient-intake) | A chatbot that can call functions in response to user input. | Deepgram, ElevenLabs, OpenAI, Daily, Daily Prebuilt UI |
|
||||
| [Dialin Chatbot](dialin-chatbot) | A chatbot that connects to an incoming phone call from Daily or Twilio. | Deepgram, ElevenLabs, OpenAI, Daily, Twilio |
|
||||
| [Twilio Chatbot](twilio-chatbot) | A chatbot that connects to an incoming phone call from Twilio. | Deepgram, ElevenLabs, OpenAI, Daily, Twilio |
|
||||
|
||||
> [!IMPORTANT]
|
||||
> These example projects use Daily as a WebRTC transport and can be joined using their hosted Prebuilt UI.
|
||||
|
||||
16
examples/deployment/flyio-example/Dockerfile
Normal file
16
examples/deployment/flyio-example/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.11-bullseye
|
||||
|
||||
# Open port 7860 for http service
|
||||
ENV FAST_API_PORT=7860
|
||||
EXPOSE 7860
|
||||
|
||||
# Install Python dependencies
|
||||
COPY *.py .
|
||||
COPY ./requirements.txt requirements.txt
|
||||
RUN pip3 install --no-cache-dir --upgrade -r requirements.txt
|
||||
|
||||
# Install models
|
||||
RUN python3 install_deps.py
|
||||
|
||||
# Start the FastAPI server
|
||||
CMD python3 bot_runner.py --port ${FAST_API_PORT}
|
||||
43
examples/deployment/flyio-example/README.md
Normal file
43
examples/deployment/flyio-example/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Fly.io deployment example
|
||||
|
||||
This project modifies the `bot_runner.py` server to launch a new machine for each user session. This is the recommended approach for production, as opposed to running shell processes, because a deployment that spawns shell processes will quickly run out of system resources under load.
|
||||
|
||||
To speed up machine boot times, we also download and cache Silero VAD as part of the Dockerfile (`install_deps.py`). If you are using other custom models, you can add them here too.
|
||||
|
||||
For this example, we are using Daily as a WebRTC transport and provisioning a new room and token for each session. You can use another transport, such as WebSockets, by modifying the `bot.py` and `bot_runner.py` files accordingly.
|
||||
|
||||
## Setting up your fly.io deployment
|
||||
|
||||
### Create your fly.toml file
|
||||
|
||||
You can copy the `example-fly.toml` as a reference. Be sure to change the app name to something unique.
|
||||
|
||||
### Create your .env file
|
||||
|
||||
Copy the base `env.example` to `.env` and enter the necessary API keys.
|
||||
|
||||
`FLY_APP_NAME` should match that in the `fly.toml` file.
|
||||
|
||||
### Launch a new fly.io project
|
||||
|
||||
`fly launch` or `fly launch --org your-org-name`
|
||||
|
||||
### Set the necessary app secrets from your .env
|
||||
|
||||
Note: you can do this manually via the fly.io dashboard under the "secrets" sub-section of your deployment (e.g. "https://fly.io/apps/fly-app-name/secrets") or run the following terminal command:
|
||||
|
||||
`cat .env | tr '\n' ' ' | xargs flyctl secrets set`
|
||||
|
||||
### Deploy your machine
|
||||
|
||||
`fly deploy`
|
||||
|
||||
|
||||
## Connecting to your bot
|
||||
|
||||
Send a post request to your running fly.io instance:
|
||||
|
||||
`curl --location --request POST 'https://YOUR_FLY_APP_NAME/start_bot'`
|
||||
|
||||
This request will wait until the machine enters the `started` state before returning a room URL and token to join.
|
||||
|
||||
0
examples/deployment/flyio-example/__init__.py
Normal file
0
examples/deployment/flyio-example/__init__.py
Normal file
103
examples/deployment/flyio-example/bot.py
Normal file
103
examples/deployment/flyio-example/bot.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator, LLMUserResponseAggregator
|
||||
from pipecat.frames.frames import LLMMessagesFrame, EndFrame
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
daily_api_key = os.getenv("DAILY_API_KEY", "")
|
||||
daily_api_url = os.getenv("DAILY_API_URL", "https://api.daily.co/v1")
|
||||
|
||||
|
||||
async def main(room_url: str, token: str):
    """Run one chatbot session in the given Daily room until the pipeline ends.

    Builds a Daily transport, an ElevenLabs TTS service and an OpenAI LLM
    service, chains them into a pipeline and runs it. The conversation is
    started when the first participant joins, and an ``EndFrame`` is queued
    when a participant leaves or the call state becomes ``"left"``.

    Args:
        room_url: URL of the Daily room to join.
        token: Daily meeting token used to join the room.
    """
    async with aiohttp.ClientSession() as http:
        daily_params = DailyParams(
            api_url=daily_api_url,
            api_key=daily_api_key,
            audio_in_enabled=True,
            audio_out_enabled=True,
            camera_out_enabled=False,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
            transcription_enabled=True,
        )
        transport = DailyTransport(room_url, token, "Chatbot", daily_params)

        tts = ElevenLabsTTSService(
            aiohttp_session=http,
            api_key=os.getenv("ELEVENLABS_API_KEY", ""),
            voice_id=os.getenv("ELEVENLABS_VOICE_ID", ""),
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

        # Shared conversation context: both aggregators receive this same list.
        system_prompt = {
            "role": "system",
            "content": "You are Chatbot, a friendly, helpful robot. Your output will be converted to audio so don't include special characters other than '!' or '?' in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by saying hello.",
        }
        context = [system_prompt]

        user_aggregator = LLMUserResponseAggregator(context)
        assistant_aggregator = LLMAssistantResponseAggregator(context)

        stages = [
            transport.input(),
            user_aggregator,
            llm,
            tts,
            transport.output(),
            assistant_aggregator,
        ]
        task = PipelineTask(Pipeline(stages), PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Start transcribing the new participant, then kick off the
            # conversation by sending the current context to the LLM.
            transport.capture_participant_transcription(participant["id"])
            await task.queue_frames([LLMMessagesFrame(context)])

        @transport.event_handler("on_participant_left")
        async def on_participant_left(transport, participant, reason):
            await task.queue_frame(EndFrame())

        @transport.event_handler("on_call_state_updated")
        async def on_call_state_updated(transport, state):
            if state != "left":
                return
            await task.queue_frame(EndFrame())

        await PipelineRunner().run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: the bot runner starts this script with the
    # room to join (-u) and the meeting token to join it with (-t).
    parser = argparse.ArgumentParser(description="Pipecat Bot")
    # The room URL is mandatory: without it there is nothing to join, so fail
    # with a usage error here instead of passing None on to the transport.
    parser.add_argument("-u", type=str, required=True, help="Room URL")
    parser.add_argument("-t", type=str, help="Token")
    config = parser.parse_args()

    asyncio.run(main(config.u, config.t))
|
||||
199
examples/deployment/flyio-example/bot_runner.py
Normal file
199
examples/deployment/flyio-example/bot_runner.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import os
|
||||
import argparse
|
||||
import subprocess
|
||||
import requests
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomObject, DailyRoomProperties, DailyRoomParams
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# ------------ Configuration ------------ #
|
||||
|
||||
MAX_SESSION_TIME = 5 * 60  # 5 minutes

# Environment variables that must be defined before the runner will start.
REQUIRED_ENV_VARS = [
    "DAILY_API_KEY",
    "OPENAI_API_KEY",
    "ELEVENLABS_API_KEY",
    "ELEVENLABS_VOICE_ID",
    "FLY_API_KEY",
    "FLY_APP_NAME",
]

# Fly.io Machines API settings, overridable through the environment.
FLY_API_HOST = os.getenv("FLY_API_HOST", "https://api.machines.dev/v1")
FLY_APP_NAME = os.getenv("FLY_APP_NAME", "pipecat-fly-example")
FLY_API_KEY = os.getenv("FLY_API_KEY", "")
FLY_HEADERS = {
    "Authorization": f"Bearer {FLY_API_KEY}",
    "Content-Type": "application/json",
}

# Helper used to create/look up Daily rooms and issue meeting tokens.
daily_rest_helper = DailyRESTHelper(
    os.getenv("DAILY_API_KEY", ""),
    os.getenv("DAILY_API_URL", "https://api.daily.co/v1"),
)


# ----------------- API ----------------- #

app = FastAPI()

# CORS is left fully open (any origin, method and header).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# ----------------- Main ----------------- #
|
||||
|
||||
|
||||
def spawn_fly_machine(room_url: str, token: str):
    """Start a new Fly.io machine that runs ``bot.py`` for one session.

    The new machine reuses the image of the first machine already listed for
    ``FLY_APP_NAME`` (i.e. the image the bot runner itself was deployed
    with). The function returns once Fly reports the machine as ``started``.

    Args:
        room_url: Daily room URL handed to the bot via ``-u``.
        token: Daily meeting token handed to the bot via ``-t``.

    Raises:
        Exception: if listing machines, creating the machine, or waiting for
            it to start does not return HTTP 200, or if the app has no
            existing machine to copy the image from.
    """
    machines_url = f"{FLY_API_HOST}/apps/{FLY_APP_NAME}/machines"

    # Without a timeout `requests` waits forever on an unresponsive server.
    request_timeout = 30
    # NOTE(review): the wait endpoint blocks server-side until the machine
    # starts; this client timeout is meant to sit above Fly's own wait
    # timeout -- confirm against the Machines API docs.
    wait_timeout = 120

    # Use the same image as the bot runner
    res = requests.get(machines_url, headers=FLY_HEADERS, timeout=request_timeout)
    if res.status_code != 200:
        raise Exception(f"Unable to get machine info from Fly: {res.text}")

    machines = res.json()
    if not machines:
        # Previously this surfaced as a bare IndexError.
        raise Exception(f"No existing machines found for Fly app: {FLY_APP_NAME}")
    image = machines[0]['config']['image']

    # Machine configuration. The command is built as an argument list so that
    # whitespace inside the URL or token cannot split it into extra arguments.
    cmd = ["python3", "bot.py", "-u", room_url, "-t", token]
    worker_props = {
        "config": {
            "image": image,
            "auto_destroy": True,
            "init": {
                "cmd": cmd
            },
            "restart": {
                "policy": "no"
            },
            "guest": {
                "cpu_kind": "shared",
                "cpus": 1,
                "memory_mb": 1024
            }
        },
    }

    # Spawn a new machine instance
    res = requests.post(
        machines_url,
        headers=FLY_HEADERS,
        json=worker_props,
        timeout=request_timeout)

    if res.status_code != 200:
        raise Exception(f"Problem starting a bot worker: {res.text}")

    # Wait for the machine to enter the started state
    vm_id = res.json()['id']

    res = requests.get(
        f"{machines_url}/{vm_id}/wait?state=started",
        headers=FLY_HEADERS,
        timeout=wait_timeout)

    if res.status_code != 200:
        raise Exception(f"Bot was unable to enter started state: {res.text}")

    print(f"Machine joined room: {room_url}")
|
||||
|
||||
|
||||
@app.post("/start_bot")
async def start_bot(request: Request) -> JSONResponse:
    """Provision a Daily room, start a bot for it and return join details.

    A JSON body containing ``"test"`` is treated as a webhook verification
    request and answered immediately. Otherwise the room named by
    ``DAILY_SAMPLE_ROOM_URL`` is used (or a new room is created), a bot is
    started either as a local subprocess (``RUN_AS_PROCESS``) or on a new
    Fly.io machine, and a second token is issued for the user.

    Args:
        request: Incoming request; its body is optional.

    Returns:
        JSON with ``room_url`` and ``token`` for the user to join with, or
        ``{"test": True}`` for webhook verification requests.

    Raises:
        HTTPException: status 500 if the room cannot be created or found, a
            token cannot be obtained, or the bot cannot be started.
    """
    try:
        data = await request.json()
        # Is this a webhook creation request?
        if "test" in data:
            return JSONResponse({"test": True})
    except Exception:
        # Deliberate best-effort: the body is optional, so a missing or
        # non-JSON body just means this is an ordinary start request.
        pass

    # Use specified room URL, or create a new one if not specified
    room_url = os.getenv("DAILY_SAMPLE_ROOM_URL", "")

    if not room_url:
        params = DailyRoomParams(
            properties=DailyRoomProperties()
        )
        try:
            room: DailyRoomObject = daily_rest_helper.create_room(params=params)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unable to provision room {e}") from e
    else:
        # Check passed room URL exists, we should assume that it already has a sip set up
        try:
            room: DailyRoomObject = daily_rest_helper.get_room_from_url(room_url)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Room not found: {room_url}") from e

    # Validate the room before dereferencing it (the original checked `room`
    # only after `room.url` had already been used).
    if not room:
        raise HTTPException(
            status_code=500, detail=f"Failed to get token for room: {room_url}")

    # Give the agent a token to join the session
    token = daily_rest_helper.get_token(room.url, MAX_SESSION_TIME)

    if not token:
        # Report room.url: room_url is empty when the room was just created.
        raise HTTPException(
            status_code=500, detail=f"Failed to get token for room: {room.url}")

    # Launch a new fly.io machine, or run as a shell process (not recommended).
    # Environment values are strings, so "false"/"0" must be recognised
    # explicitly; any other non-empty value still enables process mode.
    run_as_process = os.getenv("RUN_AS_PROCESS", "").strip().lower() not in (
        "", "0", "false", "no", "off")

    if run_as_process:
        try:
            # Argument list without a shell: the URL and token are passed
            # verbatim and are never interpreted by a shell.
            subprocess.Popen(
                ["python3", "-m", "bot", "-u", room.url, "-t", token],
                bufsize=1,
                cwd=os.path.dirname(os.path.abspath(__file__)))
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to start subprocess: {e}") from e
    else:
        try:
            spawn_fly_machine(room.url, token)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to spawn VM: {e}") from e

    # Grab a token for the user to join with
    user_token = daily_rest_helper.get_token(room.url, MAX_SESSION_TIME)

    return JSONResponse({
        "room_url": room.url,
        "token": user_token,
    })
|
||||
|
||||
if __name__ == "__main__":
    # Refuse to start unless every required variable is defined; the first
    # missing name (in REQUIRED_ENV_VARS order) is the one reported.
    missing = next(
        (name for name in REQUIRED_ENV_VARS if name not in os.environ), None)
    if missing is not None:
        raise Exception(f"Missing environment variable: {missing}.")

    # Command-line options, with HOST / PORT environment fallbacks.
    default_host = os.getenv("HOST", "0.0.0.0")
    default_port = os.getenv("PORT", 7860)

    parser = argparse.ArgumentParser(description="Pipecat Bot Runner")
    parser.add_argument("--host", type=str, default=default_host, help="Host address")
    parser.add_argument("--port", type=int, default=default_port, help="Port number")
    parser.add_argument(
        "--reload", action="store_true", default=False, help="Reload code on change")

    config = parser.parse_args()

    try:
        import uvicorn

        server_options = {
            "host": config.host,
            "port": config.port,
            "reload": config.reload,
        }
        uvicorn.run("bot_runner:app", **server_options)

    except KeyboardInterrupt:
        print("Pipecat runner shutting down...")
|
||||
8
examples/deployment/flyio-example/env.example
Normal file
8
examples/deployment/flyio-example/env.example
Normal file
@@ -0,0 +1,8 @@
|
||||
DAILY_API_KEY=
|
||||
DAILY_SAMPLE_ROOM_URL= # Enter a Daily room URL to use a set room URL each time (useful for local testing)
|
||||
OPENAI_API_KEY=
|
||||
ELEVENLABS_API_KEY=
|
||||
ELEVENLABS_VOICE_ID=
|
||||
FLY_API_KEY=
|
||||
FLY_APP_NAME=
|
||||
RUN_AS_PROCESS= # Set to run each bot as a local process; leave empty to spawn a fly.io machine per session
|
||||
25
examples/deployment/flyio-example/example-fly.toml
Normal file
25
examples/deployment/flyio-example/example-fly.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
# fly.toml app configuration file generated for pipecat-fly-example on 2024-07-01T15:04:53+01:00
|
||||
#
|
||||
# See https://fly.io/docs/reference/configuration/ for information about how to use this file.
|
||||
#
|
||||
|
||||
app = 'pipecat-fly-example'
|
||||
primary_region = 'sjc'
|
||||
|
||||
[build]
|
||||
|
||||
[env]
|
||||
FLY_APP_NAME = 'pipecat-fly-example'
|
||||
|
||||
[http_service]
|
||||
internal_port = 7860
|
||||
force_https = true
|
||||
auto_stop_machines = true
|
||||
auto_start_machines = true
|
||||
min_machines_running = 0
|
||||
processes = ['app']
|
||||
|
||||
[[vm]]
|
||||
memory = 512
|
||||
cpu_kind = 'shared'
|
||||
cpus = 1
|
||||
4
examples/deployment/flyio-example/install_deps.py
Normal file
4
examples/deployment/flyio-example/install_deps.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import torch
|
||||
|
||||
# Fetch the Silero VAD model through torch.hub so it is cached locally.
# force_reload=True requests a fresh download instead of reusing a cached copy.
torch.hub.load(
    repo_or_dir='snakers4/silero-vad',
    model='silero_vad',
    force_reload=True,
)
|
||||
6
examples/deployment/flyio-example/requirements.txt
Normal file
6
examples/deployment/flyio-example/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
pipecat-ai[daily,openai,silero]
|
||||
fastapi
|
||||
uvicorn
|
||||
requests
|
||||
python-dotenv
|
||||
loguru
|
||||
@@ -51,7 +51,7 @@ class ImageSyncAggregator(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if not isinstance(frame, SystemFrame):
|
||||
if not isinstance(frame, SystemFrame) and direction == FrameDirection.DOWNSTREAM:
|
||||
await self.push_frame(ImageRawFrame(image=self._speaking_image_bytes, size=(1024, 1024), format=self._speaking_image_format))
|
||||
await self.push_frame(frame)
|
||||
await self.push_frame(ImageRawFrame(image=self._waiting_image_bytes, size=(1024, 1024), format=self._waiting_image_format))
|
||||
@@ -67,11 +67,12 @@ async def main(room_url: str, token):
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
camera_out_enabled=True,
|
||||
camera_out_width=1024,
|
||||
camera_out_height=1024,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer()
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -116,7 +117,7 @@ async def main(room_url: str, token):
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
participant_name = participant["info"]["userName"] or ''
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
await task.queue_frames([TextFrame(f"Hi, this is {participant_name}.")])
|
||||
await task.queue_frames([TextFrame(f"Hi there {participant_name}!")])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
|
||||
@@ -37,8 +37,8 @@ async def main(room_url: str, token):
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
audio_out_sample_rate=44100,
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer()
|
||||
@@ -47,8 +47,8 @@ async def main(room_url: str, token):
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_name="British Lady",
|
||||
output_format="pcm_44100"
|
||||
voice_id="a0e99841-438c-4a64-b679-ae501e7d6091", # Barbershop Man
|
||||
sample_rate=44100,
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(
|
||||
@@ -70,11 +70,11 @@ async def main(room_url: str, token):
|
||||
tma_in, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
tma_out, # Goes before the transport because cartesia has word-level timestamps!
|
||||
transport.output(), # Transport bot output
|
||||
tma_out # Assistant spoken responses
|
||||
])
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
|
||||
96
examples/foundational/07i-interruptible-xtts.py
Normal file
96
examples/foundational/07i-interruptible-xtts.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.services.deepgram import DeepgramSTTService, DeepgramTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.xtts import XTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main(room_url: str, token):
    """Run an interruptible voice bot in a Daily room, speaking through XTTS.

    Builds a Daily transport, an XTTS text-to-speech service pointed at
    http://localhost:8000, and an OpenAI LLM service, wires them into a single
    pipeline, and runs that pipeline until the runner returns.

    Args:
        room_url: URL of the Daily room to join.
        token: Token passed to the Daily transport along with the room URL.
    """
    async with aiohttp.ClientSession() as session:
        daily_params = DailyParams(
            audio_out_enabled=True,
            transcription_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        tts = XTTSService(
            aiohttp_session=session,
            voice_id="Claribel Dervla",
            language="en",
            base_url="http://localhost:8000",
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

        # Conversation history. The same list object is handed to both
        # aggregators and to the kickoff frame below.
        system_prompt = {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        }
        messages = [system_prompt]

        user_aggregator = LLMUserResponseAggregator(messages)
        assistant_aggregator = LLMAssistantResponseAggregator(messages)

        stages = [
            transport.input(),     # Transport user input
            user_aggregator,       # User responses
            llm,                   # LLM
            tts,                   # TTS
            transport.output(),    # Transport bot output
            assistant_aggregator,  # Assistant spoken responses
        ]
        pipeline = Pipeline(stages)

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            intro_request = {
                "role": "system",
                "content": "Please introduce yourself to the user.",
            }
            messages.append(intro_request)
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()
        await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # configure() returns the (room URL, token) pair that main() expects.
    url, token = configure()
    asyncio.run(main(url, token))
|
||||
101
examples/foundational/07j-interruptible-gladia.py
Normal file
101
examples/foundational/07j-interruptible-gladia.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.services.deepgram import DeepgramSTTService, DeepgramTTSService
|
||||
from pipecat.services.gladia import GladiaSTTService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.xtts import XTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main(room_url: str, token):
    """Run an interruptible voice bot in a Daily room using Gladia for STT.

    Builds a Daily transport with VAD audio passthrough enabled, a Gladia
    speech-to-text service, a Deepgram text-to-speech service, and an OpenAI
    LLM service, wires them into a single pipeline, and runs that pipeline
    until the runner returns.

    Args:
        room_url: URL of the Daily room to join.
        token: Token passed to the Daily transport along with the room URL.
    """
    async with aiohttp.ClientSession() as session:
        daily_params = DailyParams(
            audio_out_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
            vad_audio_passthrough=True,
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        stt = GladiaSTTService(api_key=os.getenv("GLADIA_API_KEY"))

        tts = DeepgramTTSService(
            aiohttp_session=session,
            api_key=os.getenv("DEEPGRAM_API_KEY"),
            voice="aura-helios-en",
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

        # Conversation history. The same list object is handed to both
        # aggregators and to the kickoff frame below.
        system_prompt = {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        }
        messages = [system_prompt]

        user_aggregator = LLMUserResponseAggregator(messages)
        assistant_aggregator = LLMAssistantResponseAggregator(messages)

        stages = [
            transport.input(),     # Transport user input
            stt,                   # STT
            user_aggregator,       # User responses
            llm,                   # LLM
            tts,                   # TTS
            transport.output(),    # Transport bot output
            assistant_aggregator,  # Assistant spoken responses
        ]
        pipeline = Pipeline(stages)

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            intro_request = {
                "role": "system",
                "content": "Please introduce yourself to the user.",
            }
            messages.append(intro_request)
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()
        await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # configure() returns the (room URL, token) pair that main() expects.
    url, token = configure()
    asyncio.run(main(url, token))
|
||||
@@ -66,7 +66,6 @@ async def main(room_url: str, token):
|
||||
"Pipecat",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
audio_out_sample_rate=44100,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer()
|
||||
@@ -75,20 +74,17 @@ async def main(room_url: str, token):
|
||||
|
||||
news_lady = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_name="Newslady",
|
||||
output_format="pcm_44100"
|
||||
voice_id="bf991597-6c13-47e4-8411-91ec2de5c466", # Newslady
|
||||
)
|
||||
|
||||
british_lady = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_name="British Lady",
|
||||
output_format="pcm_44100"
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
barbershop_man = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_name="Barbershop Man",
|
||||
output_format="pcm_44100"
|
||||
voice_id="a0e99841-438c-4a64-b679-ae501e7d6091", # Barbershop Man
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(
|
||||
|
||||
108
examples/foundational/17-detect-user-idle.py
Normal file
108
examples/foundational/17-detect-user-idle.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main(room_url: str, token):
    """Run a voice bot in a Daily room that checks in on an idle user.

    Builds a Daily transport, an ElevenLabs text-to-speech service, an OpenAI
    LLM service, and a UserIdleProcessor configured with a 5.0 timeout whose
    callback asks the LLM to prompt the user. Everything is wired into a
    single pipeline, which runs until the runner returns.

    Args:
        room_url: URL of the Daily room to join.
        token: Token passed to the Daily transport along with the room URL.
    """
    async with aiohttp.ClientSession() as session:
        daily_params = DailyParams(
            audio_out_enabled=True,
            transcription_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        tts = ElevenLabsTTSService(
            aiohttp_session=session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

        # Conversation history. The same list object is handed to both
        # aggregators, the idle callback, and the kickoff frame below.
        system_prompt = {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        }
        messages = [system_prompt]

        user_aggregator = LLMUserResponseAggregator(messages)
        assistant_aggregator = LLMAssistantResponseAggregator(messages)

        async def user_idle_callback(user_idle: UserIdleProcessor):
            # Append a nudge instruction and queue the whole history on the
            # idle processor.
            nudge = {
                "role": "system",
                "content": "Ask the user if they are still there and try to prompt for some input, but be short.",
            }
            messages.append(nudge)
            await user_idle.queue_frame(LLMMessagesFrame(messages))

        user_idle = UserIdleProcessor(callback=user_idle_callback, timeout=5.0)

        stages = [
            transport.input(),     # Transport user input
            user_idle,             # Idle user check-in
            user_aggregator,       # User responses
            llm,                   # LLM
            tts,                   # TTS
            transport.output(),    # Transport bot output
            assistant_aggregator,  # Assistant spoken responses
        ]
        pipeline = Pipeline(stages)

        task_params = PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            report_only_initial_ttfb=True,
        )
        task = PipelineTask(pipeline, task_params)

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            intro_request = {
                "role": "system",
                "content": "Please introduce yourself to the user.",
            }
            messages.append(intro_request)
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()
        await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # configure() returns the (room URL, token) pair that main() expects.
    url, token = configure()
    asyncio.run(main(url, token))
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11-bullseye
|
||||
FROM python:3.11-slim-bookworm
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG USE_PERSISTENT_DATA
|
||||
@@ -51,4 +51,4 @@ COPY --chown=user ./frontend/ frontend/
|
||||
RUN cd frontend && npm install && npm run build
|
||||
|
||||
# Start the FastAPI server
|
||||
CMD python3 src/server.py --port ${FAST_API_PORT}
|
||||
CMD python3 src/bot_runner.py --port ${FAST_API_PORT}
|
||||
@@ -48,6 +48,8 @@ pip install -r requirements.txt
|
||||
mv env.example .env
|
||||
```
|
||||
|
||||
When deploying to production, to ensure only this app can spawn a new bot, set your `ENV` to `production`
|
||||
|
||||
**Build the frontend:**
|
||||
|
||||
This project uses a custom frontend, which needs to be built. Note: this is done automatically as part of the Docker deployment.
|
||||
@@ -64,11 +66,11 @@ The build UI files can be found in `frontend/out`
|
||||
|
||||
Start the API / bot manager:
|
||||
|
||||
`python src/server.py`
|
||||
`python src/bot_runner.py`
|
||||
|
||||
If you'd like to run a custom domain or port:
|
||||
|
||||
`python src/server.py --host somehost --p 7777`
|
||||
`python src/bot_runner.py --host somehost --port someport`
|
||||
|
||||
➡️ Open the host URL in your browser `http://localhost:7860`
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
DAILY_API_KEY=7df...
|
||||
ELEVENLABS_API_KEY=aeb...
|
||||
ELEVENLABS_VOICE_ID=7S...
|
||||
FAL_KEY=8c...
|
||||
OPENAI_API_KEY=sk-PL...
|
||||
DAILY_API_KEY=
|
||||
DAILY_SAMPLE_ROOM_URL=
|
||||
ELEVENLABS_API_KEY=
|
||||
ELEVENLABS_VOICE_ID=
|
||||
FAL_KEY=
|
||||
OPENAI_API_KEY=
|
||||
|
||||
ENV= # dev | production
|
||||
RUN_AS_VM= # Set this to launch each bot in a new VM; leave empty to run bots as local processes
|
||||
@@ -27,14 +27,11 @@ export default function Call() {
|
||||
|
||||
// Create a new room for the story session
|
||||
try {
|
||||
const response = await fetch("/create", {
|
||||
const response = await fetch("/start_bot", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
room_url: process.env.NEXT_PUBLIC_ROOM_URL || null,
|
||||
}),
|
||||
});
|
||||
|
||||
const { room_url, token } = await response.json();
|
||||
@@ -55,21 +52,9 @@ export default function Call() {
|
||||
// Disable local audio, the bot will say hello first
|
||||
daily.setLocalAudio(false);
|
||||
|
||||
// Start the bot
|
||||
const resp = await fetch("/start", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
room_url,
|
||||
}),
|
||||
});
|
||||
|
||||
setState("started");
|
||||
} catch (error) {
|
||||
setState("error");
|
||||
leave();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,7 +64,13 @@ export default function Call() {
|
||||
}
|
||||
|
||||
if (state === "error") {
|
||||
return <div>An Error occured</div>;
|
||||
return (
|
||||
<div className="flex items-center mx-auto">
|
||||
<p className="text-red-500 font-semibold bg-white px-4 py-2 shadow-xl rounded-lg">
|
||||
This demo is currently at capacity. Please try again later.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (state === "started") {
|
||||
|
||||
@@ -108,26 +108,26 @@ export default function DevicePicker({}: Props) {
|
||||
{hasMicError && (
|
||||
<div className="error">
|
||||
{micState === "blocked" ? (
|
||||
<p>
|
||||
<p className="text-red-500">
|
||||
Please check your browser and system permissions. Make sure that
|
||||
this app is allowed to access your microphone.
|
||||
</p>
|
||||
) : micState === "in-use" ? (
|
||||
<p>
|
||||
<p className="text-red-500">
|
||||
Your microphone is being used by another app. Please close any
|
||||
other apps using your microphone and restart this app.
|
||||
</p>
|
||||
) : micState === "not-found" ? (
|
||||
<p>
|
||||
<p className="text-red-500">
|
||||
No microphone seems to be connected. Please connect a microphone.
|
||||
</p>
|
||||
) : micState === "not-supported" ? (
|
||||
<p>
|
||||
<p className="text-red-500">
|
||||
This app is not supported on your device. Please update your
|
||||
software or use a different device.
|
||||
</p>
|
||||
) : (
|
||||
<p>
|
||||
<p className="text-red-500">
|
||||
There seems to be an issue accessing your microphone. Try
|
||||
restarting the app or consult a system administrator.
|
||||
</p>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import React from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import DevicePicker from "@/components/DevicePicker";
|
||||
import { IconEar, IconLoader2 } from "@tabler/icons-react";
|
||||
import { IconAlertCircle, IconEar, IconLoader2 } from "@tabler/icons-react";
|
||||
|
||||
type SetupProps = {
|
||||
handleStart: () => void;
|
||||
@@ -24,7 +24,6 @@ export const Setup: React.FC<SetupProps> = ({ handleStart }) => {
|
||||
<h1 className="text-4xl font-bold text-pretty tracking-tighter mb-4">
|
||||
Welcome to <span className="text-sky-500">Storytime</span>
|
||||
</h1>
|
||||
|
||||
{state === "intro" ? (
|
||||
<>
|
||||
<p className="text-gray-600 leading-relaxed text-pretty">
|
||||
@@ -38,6 +37,9 @@ export const Setup: React.FC<SetupProps> = ({ handleStart }) => {
|
||||
<IconEar size={24} /> For best results, try in a quiet
|
||||
environment!
|
||||
</p>
|
||||
<p className="flex flex-row gap-2 text-gray-600 font-medium text-red-500">
|
||||
<IconAlertCircle size={24} /> This demo expires after 5 minutes.
|
||||
</p>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
@@ -49,7 +51,6 @@ export const Setup: React.FC<SetupProps> = ({ handleStart }) => {
|
||||
<DevicePicker />
|
||||
</>
|
||||
)}
|
||||
|
||||
<hr className="border-gray-150 my-2" />
|
||||
|
||||
<Button
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
NEXT_PUBLIC_ROOM_URL=
|
||||
SITE_URL=
|
||||
@@ -899,11 +899,11 @@ brace-expansion@^2.0.1:
|
||||
balanced-match "^1.0.0"
|
||||
|
||||
braces@^3.0.2, braces@~3.0.2:
|
||||
version "3.0.2"
|
||||
resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107"
|
||||
integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==
|
||||
version "3.0.3"
|
||||
resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.3.tgz#490332f40919452272d55a8480adc0c441358789"
|
||||
integrity "sha1-SQMy9AkZRSJy1VqEgK3AxEE1h4k= sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA=="
|
||||
dependencies:
|
||||
fill-range "^7.0.1"
|
||||
fill-range "^7.1.1"
|
||||
|
||||
browserslist@^4.23.0:
|
||||
version "4.23.0"
|
||||
@@ -1551,10 +1551,10 @@ file-entry-cache@^6.0.1:
|
||||
dependencies:
|
||||
flat-cache "^3.0.4"
|
||||
|
||||
fill-range@^7.0.1:
|
||||
version "7.0.1"
|
||||
resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40"
|
||||
integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==
|
||||
fill-range@^7.1.1:
|
||||
version "7.1.1"
|
||||
resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.1.1.tgz#44265d3cac07e3ea7dc247516380643754a05292"
|
||||
integrity "sha1-RCZdPKwH4+p9wkdRY4BkN1SgUpI= sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg=="
|
||||
dependencies:
|
||||
to-regex-range "^5.0.1"
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame, StopTaskFrame
|
||||
from pipecat.frames.frames import LLMMessagesFrame, StopTaskFrame, EndFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
@@ -139,6 +139,16 @@ async def main(room_url, token=None):
|
||||
|
||||
main_task = PipelineTask(main_pipeline)
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
intro_task.queue_frame(EndFrame())
|
||||
await main_task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_call_state_updated")
|
||||
async def on_call_state_updated(transport, state):
|
||||
if state == "left":
|
||||
await main_task.queue_frame(EndFrame())
|
||||
|
||||
await runner.run(main_task)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
233
examples/storytelling-chatbot/src/bot_runner.py
Normal file
233
examples/storytelling-chatbot/src/bot_runner.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import os
|
||||
import argparse
|
||||
import subprocess
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
|
||||
from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper, DailyRoomObject, DailyRoomProperties, DailyRoomParams
|
||||
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
# ------------ Fast API Config ------------ #
|
||||
|
||||
MAX_SESSION_TIME = 5 * 60  # 5 minutes

# Daily REST helper; the key and base URL come from the environment, with the
# public Daily API endpoint as the default URL.
_DAILY_API_KEY = os.getenv("DAILY_API_KEY", "")
_DAILY_API_URL = os.getenv("DAILY_API_URL", 'https://api.daily.co/v1')
daily_rest_helper = DailyRESTHelper(_DAILY_API_KEY, _DAILY_API_URL)

app = FastAPI()

# CORS is configured to accept any origin, method and header.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Directory served as static files (the built frontend)
STATIC_DIR = "frontend/out"

# ------------ Fast API Routes ------------ #

app.mount(
    "/static",
    StaticFiles(directory=STATIC_DIR, html=True),
    name="static",
)
|
||||
|
||||
|
||||
@app.post("/start_bot")
async def start_bot(request: Request) -> JSONResponse:
    """Provision a Daily room, launch a bot into it, and return join details.

    When ENV is "production", the request's Host header must be one of the
    allowed domains. The room is taken from DAILY_SAMPLE_ROOM_URL when that is
    set, otherwise a new room is created. The bot is started either on a new
    VM (when RUN_AS_VM is set to any non-empty value) or as a local
    subprocess.

    Args:
        request: Incoming request. A JSON body is optional; a body containing
            the key "test" short-circuits with {"test": True}.

    Returns:
        JSONResponse with "room_url" and a "token" for the user to join with.

    Raises:
        HTTPException: 403 when the Host header is not allowed in production;
            500 when the room cannot be created or found, no token is
            returned, or the bot fails to launch.
    """
    if os.getenv("ENV", "dev") == "production":
        # Only allow requests from the specified domain
        host_header = request.headers.get("host")
        allowed_domains = ["storytelling-chatbot.fly.dev", "www.storytelling-chatbot.fly.dev"]
        # Check if the Host header matches the allowed domain
        if host_header not in allowed_domains:
            raise HTTPException(status_code=403, detail="Access denied")

    try:
        data = await request.json()
        # Is this a webhook creation request?
        if "test" in data:
            return JSONResponse({"test": True})
    except Exception:
        # Deliberately best-effort: the body is optional, so a missing or
        # non-JSON body simply means this is a normal start request.
        pass

    # Use specified room URL, or create a new one if not specified
    room_url = os.getenv("DAILY_SAMPLE_ROOM_URL", "")

    if not room_url:
        params = DailyRoomParams(
            properties=DailyRoomProperties()
        )
        try:
            room: DailyRoomObject = daily_rest_helper.create_room(params=params)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Unable to provision room {e}") from e
    else:
        # Check passed room URL exists, we should assume that it already has a sip set up
        try:
            room: DailyRoomObject = daily_rest_helper.get_room_from_url(room_url)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Room not found: {room_url}") from e

    # Give the agent a token to join the session
    token = daily_rest_helper.get_token(room.url, MAX_SESSION_TIME)

    if not room or not token:
        raise HTTPException(
            status_code=500, detail=f"Failed to get token for room: {room_url}")

    # Launch a new VM, or run as a shell process (not recommended)
    if os.getenv("RUN_AS_VM", False):
        try:
            virtualize_bot(room.url, token)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to spawn VM: {e}") from e
    else:
        try:
            # Pass an argv list without a shell so that characters in the URL
            # or token (e.g. "?", "&") are never interpreted by the shell.
            subprocess.Popen(
                ["python3", "-m", "bot", "-u", room.url, "-t", token],
                bufsize=1,
                cwd=os.path.dirname(os.path.abspath(__file__)))
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to start subprocess: {e}") from e

    # Grab a token for the user to join with
    user_token = daily_rest_helper.get_token(room.url, MAX_SESSION_TIME)

    return JSONResponse({
        "room_url": room.url,
        "token": user_token,
    })
|
||||
|
||||
|
||||
@app.get("/{path_name:path}", response_class=FileResponse)
async def catch_all(path_name: Optional[str] = ""):
    """Serve a file from STATIC_DIR for any path no other route handled.

    An empty path serves index.html. Otherwise the path is looked up inside
    STATIC_DIR, first as given and then with its suffix replaced by ".html".

    Args:
        path_name: Request path captured by the route (may contain slashes).

    Returns:
        FileResponse for the matching file.

    Raises:
        HTTPException: 450 when no file matches, or when the path resolves to
            a location outside STATIC_DIR.
    """
    if path_name == "":
        return FileResponse(f"{STATIC_DIR}/index.html")

    static_root = Path(STATIC_DIR).resolve()

    # The path comes straight from the URL, so resolve it and refuse anything
    # that escapes the static directory ("..", absolute paths, symlinks).
    # ValueError also covers paths that cannot be resolved at all.
    try:
        file_path = (static_root / (path_name or "")).resolve()
        file_path.relative_to(static_root)
    except ValueError:
        raise HTTPException(status_code=450, detail="Incorrect API call") from None

    if file_path.is_file():
        return FileResponse(file_path)

    # Skip the ".html" fallback for the root itself: swapping its suffix would
    # name a sibling of STATIC_DIR rather than a file inside it.
    if file_path != static_root:
        html_file_path = file_path.with_suffix(".html")
        if html_file_path.is_file():
            return FileResponse(html_file_path)

    raise HTTPException(status_code=450, detail="Incorrect API call")
|
||||
|
||||
|
||||
# ------------ Virtualization ------------ #
|
||||
|
||||
def virtualize_bot(room_url: str, token: str):
    """
    This is an example of how to virtualize the bot using Fly.io
    You can adapt this method to use whichever cloud provider you prefer.

    Looks up the image of the app's first listed machine, creates a new
    machine from that image whose command runs the bot for the given room,
    and waits for the machine to report the "started" state.

    Args:
        room_url: Daily room URL passed to the bot via "-u".
        token: Daily token passed to the bot via "-t".

    Raises:
        Exception: When a Fly API call returns a non-200 status, or when the
            app has no existing machine to copy the image from.
    """
    FLY_API_HOST = os.getenv("FLY_API_HOST", "https://api.machines.dev/v1")
    FLY_APP_NAME = os.getenv("FLY_APP_NAME", "storytelling-chatbot")
    FLY_API_KEY = os.getenv("FLY_API_KEY", "")
    FLY_HEADERS = {
        'Authorization': f"Bearer {FLY_API_KEY}",
        'Content-Type': 'application/json'
    }
    machines_url = f"{FLY_API_HOST}/apps/{FLY_APP_NAME}/machines"

    # Use the same image as the bot runner
    res = requests.get(machines_url, headers=FLY_HEADERS)
    if res.status_code != 200:
        raise Exception(f"Unable to get machine info from Fly: {res.text}")
    machines = res.json()
    if not machines:
        # Fail with a clear message instead of an IndexError on machines[0].
        raise Exception(f"Unable to get machine info from Fly: no machines found for app {FLY_APP_NAME}")
    image = machines[0]['config']['image']

    # Machine configuration. The command is built as an argv list directly so
    # the URL and token are each passed through as a single argument.
    cmd = ["python3", "src/bot.py", "-u", room_url, "-t", token]
    worker_props = {
        "config": {
            "image": image,
            "auto_destroy": True,
            "init": {
                "cmd": cmd
            },
            "restart": {
                "policy": "no"
            },
            "guest": {
                "cpu_kind": "shared",
                "cpus": 1,
                "memory_mb": 512
            }
        },
    }

    # Spawn a new machine instance
    res = requests.post(machines_url, headers=FLY_HEADERS, json=worker_props)

    if res.status_code != 200:
        raise Exception(f"Problem starting a bot worker: {res.text}")

    # Wait for the machine to enter the started state
    vm_id = res.json()['id']

    res = requests.get(
        f"{machines_url}/{vm_id}/wait?state=started",
        headers=FLY_HEADERS)

    if res.status_code != 200:
        raise Exception(f"Bot was unable to enter started state: {res.text}")

    print(f"Machine joined room: {room_url}")
|
||||
|
||||
|
||||
# ------------ Main ------------ #
|
||||
|
||||
if __name__ == "__main__":
    # Refuse to start when configuration is incomplete. Only the first missing
    # name (in the order listed) is reported.
    required_env_vars = (
        'OPENAI_API_KEY',
        'DAILY_API_KEY',
        'FAL_KEY',
        'ELEVENLABS_VOICE_ID',
        'ELEVENLABS_API_KEY',
    )
    first_missing = next(
        (name for name in required_env_vars if name not in os.environ), None)
    if first_missing is not None:
        raise Exception(f"Missing environment variable: {first_missing}.")

    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # HOST and FAST_API_PORT supply the defaults for the matching flags.
    default_host = os.getenv("HOST", "0.0.0.0")
    default_port = int(os.getenv("FAST_API_PORT", "7860"))

    cli = argparse.ArgumentParser(description="Daily Storyteller FastAPI server")
    cli.add_argument("--host", type=str, default=default_host, help="Host address")
    cli.add_argument("--port", type=int, default=default_port, help="Port number")
    cli.add_argument("--reload", action="store_true", help="Reload code on change")

    config = cli.parse_args()

    server_options = {
        "host": config.host,
        "port": config.port,
        "reload": config.reload,
    }
    uvicorn.run("bot_runner:app", **server_options)
|
||||
@@ -1,175 +0,0 @@
|
||||
import os
|
||||
import argparse
|
||||
import subprocess
|
||||
import atexit
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
|
||||
from utils.daily_helpers import create_room as _create_room, get_token, get_name_from_url
|
||||
|
||||
MAX_BOTS_PER_ROOM = 1
|
||||
|
||||
# Bot sub-process dict for status reporting and concurrency control
|
||||
bot_procs = {}
|
||||
|
||||
|
||||
def cleanup():
    """Terminate every tracked bot subprocess and wait for it to exit.

    Registered with ``atexit`` as a last-resort safety net. ``bot_procs`` maps
    a pid to a ``(process, room_url)`` tuple, so the process handle must be
    unpacked from the tuple before it can be signalled; calling ``terminate()``
    on the tuple itself raises ``AttributeError``.
    """
    for proc, _room_url in bot_procs.values():
        proc.terminate()
        proc.wait()
|
||||
|
||||
|
||||
atexit.register(cleanup)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Mount the static directory
|
||||
STATIC_DIR = "frontend/out"
|
||||
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR, html=True), name="static")
|
||||
|
||||
|
||||
@app.post("/create")
async def create_room(request: Request) -> JSONResponse:
    """Return a room URL, its name and a token for it.

    When the JSON body carries a non-null ``room_url`` that room is reused and
    its name is derived from the URL; otherwise a new room is created.
    """
    payload = await request.json()
    requested_url = payload.get('room_url')

    if requested_url is None:
        room_url, room_name = _create_room()
    else:
        room_url = requested_url
        room_name = get_name_from_url(room_url)

    return JSONResponse({
        "room_url": room_url,
        "room_name": room_name,
        "token": get_token(room_url),
    })
|
||||
|
||||
|
||||
@app.post("/start")
async def start_agent(request: Request) -> JSONResponse:
    """Spawn a bot subprocess for the room named in the request body.

    Request JSON:
        test: if this key is present the request is answered with
            ``{"test": true}`` and nothing is started.
        room_url: URL of the room the bot should join (required otherwise).

    Returns:
        JSON with ``bot_id`` (the pid of the spawned process) and ``room_url``.

    Raises:
        HTTPException: status 500 when ``room_url`` is missing, the room
            already has ``MAX_BOTS_PER_ROOM`` live bots, no token could be
            obtained, or the subprocess could not be started.
    """
    data = await request.json()

    # Is this a webhook creation request?
    if "test" in data:
        return JSONResponse({"test": True})

    # Ensure the room property is present
    room_url = data.get('room_url')
    if not room_url:
        raise HTTPException(
            status_code=500,
            detail="Missing 'room_url' property in request data. Cannot start agent without a target room!")

    # Entries are (process, room_url) tuples; poll() returns None while the
    # process is still alive, so only live bots count against the limit.
    num_bots_in_room = sum(
        1 for proc, url in bot_procs.values() if url == room_url and proc.poll() is None)
    if num_bots_in_room >= MAX_BOTS_PER_ROOM:
        raise HTTPException(
            status_code=500, detail=f"Max bot limited reach for room: {room_url}")

    # Get the token for the room
    token = get_token(room_url)
    if not token:
        raise HTTPException(
            status_code=500, detail=f"Failed to get token for room: {room_url}")

    # Spawn a new agent, and join the user session
    # Note: this is mostly for demonstration purposes (refer to 'deployment' in README)
    try:
        # room_url comes straight from the request body. Pass an argument
        # vector without a shell so it can never be interpreted as shell
        # syntax (the previous shell=True string allowed command injection).
        proc = subprocess.Popen(
            ["python3", "-m", "bot", "-u", str(room_url), "-t", str(token)],
            bufsize=1,
            cwd=os.path.dirname(os.path.abspath(__file__)),
        )
        bot_procs[proc.pid] = (proc, room_url)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to start subprocess: {e}")

    return JSONResponse({"bot_id": proc.pid, "room_url": room_url})
|
||||
|
||||
|
||||
@app.get("/status/{pid}")
def get_status(pid: int):
    """Report whether the bot subprocess registered under ``pid`` is still running.

    Raises:
        HTTPException: status 404 when no entry exists for ``pid``.
    """
    entry = bot_procs.get(pid)
    if not entry:
        raise HTTPException(
            status_code=404, detail=f"Bot with process id: {pid} not found")

    # entry is a (process, room_url) tuple; poll() is None while the process lives.
    process = entry[0]
    still_running = process.poll() is None

    return JSONResponse({"bot_id": pid, "status": "running" if still_running else "finished"})
|
||||
|
||||
|
||||
@app.get("/{path_name:path}", response_class=FileResponse)
async def catch_all(path_name: Optional[str] = ""):
    """Serve a file from ``STATIC_DIR`` for any GET path no other route matched.

    Lookup order: an empty path serves ``index.html``; otherwise the exact file
    is served if it exists, else the same path with its suffix replaced by
    ``.html``.

    Raises:
        HTTPException: status 450 when nothing matches or the requested path
            falls outside ``STATIC_DIR``.
    """
    if path_name == "":
        return FileResponse(f"{STATIC_DIR}/index.html")

    static_root = Path(STATIC_DIR).resolve()
    try:
        # resolve() collapses ".." segments and symlinks. Joining an absolute
        # path_name (e.g. from a request for "//etc/passwd") would otherwise
        # discard static_root entirely.
        requested = (static_root / (path_name or "")).resolve()
    except (OSError, RuntimeError, ValueError):
        # Unresolvable input (embedded NUL, symlink loop, ...) is a miss.
        requested = None

    # Only serve paths strictly inside the static directory. The ".html"
    # variant is a sibling of the requested path, so it is confined as well.
    if requested is not None and static_root in requested.parents:
        if requested.is_file():
            return FileResponse(requested)

        html_variant = requested.with_suffix(".html")
        if html_variant.is_file():
            return FileResponse(html_variant)

    raise HTTPException(status_code=450, detail="Incorrect API call")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fail fast when a required credential is absent from the environment
    # (only presence is checked, not the value).
    needed = ('OPENAI_API_KEY', 'DAILY_API_KEY',
              'FAL_KEY', 'ELEVENLABS_VOICE_ID', 'ELEVENLABS_API_KEY')
    for name in needed:
        if name not in os.environ:
            raise Exception(f"Missing environment variable: {name}.")

    # Deferred import: uvicorn is only needed when run as a script.
    import uvicorn

    # Environment variables supply defaults that the CLI flags may override.
    host_default = os.getenv("HOST", "0.0.0.0")
    port_default = int(os.getenv("FAST_API_PORT", "7860"))

    arg_parser = argparse.ArgumentParser(description="Daily Storyteller FastAPI server")
    arg_parser.add_argument("--host", type=str, default=host_default, help="Host address")
    arg_parser.add_argument("--port", type=int, default=port_default, help="Port number")
    arg_parser.add_argument("--reload", action="store_true", help="Reload code on change")
    args = arg_parser.parse_args()

    uvicorn.run("server:app", host=args.host, port=args.port, reload=args.reload)
|
||||
161
examples/twilio-chatbot/.gitignore
vendored
Normal file
161
examples/twilio-chatbot/.gitignore
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
runpod.toml
|
||||
20
examples/twilio-chatbot/Dockerfile
Normal file
20
examples/twilio-chatbot/Dockerfile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Use an official Python runtime as a parent image
|
||||
FROM python:3.10-bullseye
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /twilio-chatbot
|
||||
|
||||
# Copy the requirements file into the container
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install any needed packages specified in requirements.txt
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy the current directory contents into the container
|
||||
COPY . .
|
||||
|
||||
# Expose the desired port
|
||||
EXPOSE 8765
|
||||
|
||||
# Run the application
|
||||
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8765"]
|
||||
82
examples/twilio-chatbot/README.md
Normal file
82
examples/twilio-chatbot/README.md
Normal file
@@ -0,0 +1,82 @@
|
||||
# Twilio Chatbot
|
||||
|
||||
This project is a FastAPI-based chatbot that integrates with Twilio to handle WebSocket connections and provide real-time communication. The project includes endpoints for starting a call and handling WebSocket connections.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Features](#features)
|
||||
- [Requirements](#requirements)
|
||||
- [Installation](#installation)
|
||||
- [Configure Twilio URLs](#configure-twilio-urls)
|
||||
- [Running the Application](#running-the-application)
|
||||
- [Usage](#usage)
|
||||
|
||||
## Features
|
||||
|
||||
- **FastAPI**: A modern, fast (high-performance), web framework for building APIs with Python 3.6+.
|
||||
- **WebSocket Support**: Real-time communication using WebSockets.
|
||||
- **CORS Middleware**: Allowing cross-origin requests for testing.
|
||||
- **Dockerized**: Easily deployable using Docker.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.10
|
||||
- Docker (for containerized deployment)
|
||||
- ngrok (for tunneling)
|
||||
- Twilio Account
|
||||
|
||||
## Installation
|
||||
|
||||
1. **Set up a virtual environment** (optional but recommended):
|
||||
```sh
|
||||
python -m venv venv
|
||||
source venv/bin/activate # On Windows, use `venv\Scripts\activate`
|
||||
```
|
||||
|
||||
2. **Install dependencies**:
|
||||
```sh
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. **Create .env**:
|
||||
Create a `.env` file based on `env.example` and fill in each key.
|
||||
|
||||
4. **Install ngrok**:
|
||||
Follow the instructions on the [ngrok website](https://ngrok.com/download) to download and install ngrok.
|
||||
|
||||
## Configure Twilio URLs
|
||||
|
||||
1. **Update the Twilio Webhook**:
|
||||
Copy the ngrok URL (printed when you start ngrok, see "Running the Application" below) and update your Twilio phone number webhook URL to `http://<ngrok_url>/start_call`.
|
||||
|
||||
2. **Update the streams.xml**:
|
||||
Copy the ngrok URL and update templates/streams.xml with `wss://<ngrok_url>/ws`.
|
||||
|
||||
## Running the Application
|
||||
|
||||
### Using Python
|
||||
|
||||
1. **Run the FastAPI application**:
|
||||
```sh
|
||||
python server.py
|
||||
```
|
||||
|
||||
2. **Start ngrok**:
|
||||
In a new terminal, start ngrok to tunnel the local server:
|
||||
```sh
|
||||
ngrok http 8765
|
||||
```
|
||||
### Using Docker
|
||||
|
||||
1. **Build the Docker image**:
|
||||
```sh
|
||||
docker build -t twilio-chatbot .
|
||||
```
|
||||
|
||||
2. **Run the Docker container**:
|
||||
```sh
|
||||
docker run -it --rm -p 8765:8765 twilio-chatbot
|
||||
```
|
||||
## Usage
|
||||
|
||||
To start a call, simply make a call to your Twilio phone number. The webhook URL will direct the call to your FastAPI application, which will handle it accordingly.
|
||||
90
examples/twilio-chatbot/bot.py
Normal file
90
examples/twilio-chatbot/bot.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import EndFrame, LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator
|
||||
)
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.deepgram import DeepgramSTTService
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketTransport, FastAPIWebsocketParams
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.serializers.twilio import TwilioFrameSerializer
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def run_bot(websocket_client, stream_sid):
    """Build and run the voice-bot pipeline over an accepted websocket.

    Args:
        websocket_client: websocket handed to ``FastAPIWebsocketTransport``.
        stream_sid: stream identifier passed to ``TwilioFrameSerializer``.

    The coroutine returns when ``PipelineRunner.run`` completes; an
    ``EndFrame`` is queued when the client disconnects.
    """
    async with aiohttp.ClientSession() as http_session:
        transport_params = FastAPIWebsocketParams(
            audio_out_enabled=True,
            add_wav_header=False,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
            vad_audio_passthrough=True,
            serializer=TwilioFrameSerializer(stream_sid),
        )
        transport = FastAPIWebsocketTransport(
            websocket=websocket_client,
            params=transport_params,
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
        stt = DeepgramSTTService(api_key=os.getenv('DEEPGRAM_API_KEY'))
        tts = ElevenLabsTTSService(
            aiohttp_session=http_session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
        )

        # One shared message list: both aggregators and the connect handler
        # below operate on this same object.
        system_prompt = {
            "role": "system",
            "content": "You are a helpful LLM in an audio call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        }
        messages = [system_prompt]

        user_aggregator = LLMUserResponseAggregator(messages)
        assistant_aggregator = LLMAssistantResponseAggregator(messages)

        stages = [
            transport.input(),      # Websocket input from client
            stt,                    # Speech-To-Text
            user_aggregator,        # User responses
            llm,                    # LLM
            tts,                    # Text-To-Speech
            transport.output(),     # Websocket output to client
            assistant_aggregator,   # LLM responses
        ]
        task = PipelineTask(Pipeline(stages), params=PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_client_connected")
        async def on_client_connected(transport, client):
            # Kick off the conversation.
            intro_request = {"role": "system", "content": "Please introduce yourself to the user."}
            messages.append(intro_request)
            await task.queue_frames([LLMMessagesFrame(messages)])

        @transport.event_handler("on_client_disconnected")
        async def on_client_disconnected(transport, client):
            # Ask the pipeline to finish once the caller is gone.
            await task.queue_frames([EndFrame()])

        await PipelineRunner(handle_sigint=False).run(task)
|
||||
4
examples/twilio-chatbot/env.example
Normal file
4
examples/twilio-chatbot/env.example
Normal file
@@ -0,0 +1,4 @@
|
||||
OPENAI_API_KEY=
|
||||
DEEPGRAM_API_KEY=
|
||||
ELEVENLABS_API_KEY=
|
||||
ELEVENLABS_VOICE_ID=
|
||||
5
examples/twilio-chatbot/requirements.txt
Normal file
5
examples/twilio-chatbot/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
pipecat-ai[daily,openai,silero,deepgram]
|
||||
fastapi
|
||||
uvicorn
|
||||
python-dotenv
|
||||
loguru
|
||||
41
examples/twilio-chatbot/server.py
Normal file
41
examples/twilio-chatbot/server.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import json
|
||||
|
||||
import uvicorn
|
||||
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
from bot import run_bot
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allow all origins for testing
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.post('/start_call')
async def start_call():
    """Respond with the TwiML document stored in ``templates/streams.xml``.

    The file is read on every request and returned with media type
    ``application/xml``. The path is relative to the process working directory.
    """
    print("POST TwiML")
    # Use a context manager so the handle is closed deterministically instead
    # of leaking until garbage collection. The template declares UTF-8, so
    # decode it as such rather than with the locale default.
    with open("templates/streams.xml", encoding="utf-8") as twiml_file:
        twiml = twiml_file.read()
    return HTMLResponse(content=twiml, media_type="application/xml")
|
||||
|
||||
|
||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept the websocket, read the stream SID and hand the socket to the bot.

    The first text message is read and discarded. The second is parsed as JSON
    and must contain ``start.streamSid``.
    """
    await websocket.accept()

    incoming = websocket.iter_text()
    await incoming.__anext__()  # first message is ignored
    second_message = await incoming.__anext__()

    call_data = json.loads(second_message)
    print(call_data, flush=True)

    stream_sid = call_data['start']['streamSid']
    print("WebSocket connection accepted")

    await run_bot(websocket, stream_sid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8765)
|
||||
7
examples/twilio-chatbot/templates/streams.xml
Normal file
7
examples/twilio-chatbot/templates/streams.xml
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
<Connect>
|
||||
<Stream url="wss://<your server url>/ws"></Stream>
|
||||
</Connect>
|
||||
<Pause length="40"/>
|
||||
</Response>
|
||||
@@ -4,11 +4,10 @@
|
||||
#
|
||||
# pip-compile --all-extras pyproject.toml
|
||||
#
|
||||
aiofiles==23.2.1
|
||||
aiofiles==24.1.0
|
||||
# via deepgram-sdk
|
||||
aiohttp==3.9.5
|
||||
# via
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -17,7 +16,7 @@ aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anthropic==0.25.9
|
||||
anthropic==0.28.1
|
||||
# via
|
||||
# openpipe
|
||||
# pipecat-ai (pyproject.toml)
|
||||
@@ -26,6 +25,8 @@ anyio==4.4.0
|
||||
# anthropic
|
||||
# httpx
|
||||
# openai
|
||||
# starlette
|
||||
# watchfiles
|
||||
async-timeout==4.0.3
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -34,32 +35,31 @@ attrs==23.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# openpipe
|
||||
av==12.1.0
|
||||
av==12.3.0
|
||||
# via faster-whisper
|
||||
azure-cognitiveservices-speech==1.37.0
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
blinker==1.8.2
|
||||
# via flask
|
||||
cachetools==5.3.3
|
||||
cachetools==5.4.0
|
||||
# via google-auth
|
||||
cartesia==0.1.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
certifi==2024.6.2
|
||||
certifi==2024.7.4
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
cffi==1.16.0
|
||||
# via sounddevice
|
||||
charset-normalizer==3.3.2
|
||||
# via requests
|
||||
click==8.1.7
|
||||
# via flask
|
||||
# via
|
||||
# flask
|
||||
# typer
|
||||
# uvicorn
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
ctranslate2==4.3.1
|
||||
# via faster-whisper
|
||||
daily-python==0.9.1
|
||||
daily-python==0.10.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
dataclasses-json==0.6.7
|
||||
# via
|
||||
@@ -71,17 +71,23 @@ distro==1.9.0
|
||||
# via
|
||||
# anthropic
|
||||
# openai
|
||||
dnspython==2.6.1
|
||||
# via email-validator
|
||||
einops==0.8.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
exceptiongroup==1.2.1
|
||||
# via
|
||||
# anyio
|
||||
# pytest
|
||||
fal-client==0.4.0
|
||||
email-validator==2.2.0
|
||||
# via fastapi
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
fal-client==0.4.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
faster-whisper==1.0.2
|
||||
fastapi==0.111.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
filelock==3.15.1
|
||||
fastapi-cli==0.0.4
|
||||
# via fastapi
|
||||
faster-whisper==1.0.3
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
filelock==3.15.4
|
||||
# via
|
||||
# huggingface-hub
|
||||
# pyht
|
||||
@@ -100,22 +106,22 @@ frozenlist==1.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2024.6.0
|
||||
fsspec==2024.6.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via pyloudnorm
|
||||
google-ai-generativelanguage==0.6.4
|
||||
google-ai-generativelanguage==0.6.6
|
||||
# via google-generativeai
|
||||
google-api-core[grpc]==2.19.0
|
||||
google-api-core[grpc]==2.19.1
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-python-client
|
||||
# google-generativeai
|
||||
google-api-python-client==2.133.0
|
||||
google-api-python-client==2.137.0
|
||||
# via google-generativeai
|
||||
google-auth==2.30.0
|
||||
google-auth==2.32.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
@@ -124,15 +130,15 @@ google-auth==2.30.0
|
||||
# google-generativeai
|
||||
google-auth-httplib2==0.2.0
|
||||
# via google-api-python-client
|
||||
google-generativeai==0.5.4
|
||||
google-generativeai==0.7.2
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
googleapis-common-protos==1.63.1
|
||||
googleapis-common-protos==1.63.2
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
greenlet==3.0.3
|
||||
# via sqlalchemy
|
||||
grpcio==1.64.1
|
||||
grpcio==1.65.1
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
@@ -140,24 +146,28 @@ grpcio==1.64.1
|
||||
grpcio-status==1.62.2
|
||||
# via google-api-core
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
httpcore==1.0.5
|
||||
# via httpx
|
||||
httplib2==0.22.0
|
||||
# via
|
||||
# google-api-python-client
|
||||
# google-auth-httplib2
|
||||
httptools==0.6.1
|
||||
# via uvicorn
|
||||
httpx==0.27.0
|
||||
# via
|
||||
# anthropic
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# fal-client
|
||||
# fastapi
|
||||
# openai
|
||||
# openpipe
|
||||
httpx-sse==0.4.0
|
||||
# via fal-client
|
||||
huggingface-hub==0.23.3
|
||||
huggingface-hub==0.24.1
|
||||
# via
|
||||
# faster-whisper
|
||||
# timm
|
||||
@@ -168,50 +178,58 @@ humanfriendly==10.0
|
||||
idna==3.7
|
||||
# via
|
||||
# anyio
|
||||
# email-validator
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jinja2==3.1.4
|
||||
# via
|
||||
# fastapi
|
||||
# flask
|
||||
# torch
|
||||
jiter==0.5.0
|
||||
# via anthropic
|
||||
jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
langchain==0.2.3
|
||||
langchain==0.2.11
|
||||
# via
|
||||
# langchain-community
|
||||
# pipecat-ai (pyproject.toml)
|
||||
langchain-community==0.2.4
|
||||
langchain-community==0.2.10
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
langchain-core==0.2.5
|
||||
langchain-core==0.2.23
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-openai
|
||||
# langchain-text-splitters
|
||||
langchain-openai==0.1.8
|
||||
langchain-openai==0.1.17
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
langchain-text-splitters==0.2.1
|
||||
langchain-text-splitters==0.2.2
|
||||
# via langchain
|
||||
langsmith==0.1.77
|
||||
langsmith==0.1.93
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-core
|
||||
llvmlite==0.43.0
|
||||
# via numba
|
||||
loguru==0.7.2
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==2.1.5
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
marshmallow==3.21.3
|
||||
# via dataclasses-json
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.0.5
|
||||
@@ -222,14 +240,18 @@ mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
networkx==3.3
|
||||
# via torch
|
||||
numba==0.60.0
|
||||
# via resampy
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# ctranslate2
|
||||
# langchain
|
||||
# langchain-community
|
||||
# numba
|
||||
# onnxruntime
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# pyloudnorm
|
||||
# resampy
|
||||
# scipy
|
||||
# torchvision
|
||||
# transformers
|
||||
@@ -258,38 +280,35 @@ nvidia-cusparse-cu12==12.1.0.106
|
||||
# torch
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.5.40
|
||||
nvidia-nvjitlink-cu12==12.5.82
|
||||
# via
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
onnxruntime==1.18.0
|
||||
onnxruntime==1.18.1
|
||||
# via faster-whisper
|
||||
openai==1.26.0
|
||||
openai==1.35.15
|
||||
# via
|
||||
# langchain-openai
|
||||
# openpipe
|
||||
# pipecat-ai (pyproject.toml)
|
||||
openpipe==4.14.0
|
||||
openpipe==4.18.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
orjson==3.10.4
|
||||
orjson==3.10.6
|
||||
# via langsmith
|
||||
packaging==23.2
|
||||
packaging==24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# langchain-core
|
||||
# marshmallow
|
||||
# onnxruntime
|
||||
# pytest
|
||||
# transformers
|
||||
pillow==10.3.0
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# torchvision
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
proto-plus==1.23.0
|
||||
proto-plus==1.24.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
@@ -312,32 +331,33 @@ pyasn1-modules==0.4.0
|
||||
# via google-auth
|
||||
pyaudio==0.2.14
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.7.4
|
||||
pydantic==2.8.2
|
||||
# via
|
||||
# anthropic
|
||||
# fastapi
|
||||
# google-generativeai
|
||||
# langchain
|
||||
# langchain-core
|
||||
# langsmith
|
||||
# openai
|
||||
pydantic-core==2.18.4
|
||||
pydantic-core==2.20.1
|
||||
# via pydantic
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyht==0.0.28
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pyloudnorm==0.1.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pyparsing==3.1.2
|
||||
# via httplib2
|
||||
pytest==8.2.2
|
||||
# via pytest-asyncio
|
||||
pytest-asyncio==0.23.7
|
||||
# via cartesia
|
||||
python-dateutil==2.9.0.post0
|
||||
# via openpipe
|
||||
python-dotenv==1.0.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# uvicorn
|
||||
python-multipart==0.0.9
|
||||
# via fastapi
|
||||
pyyaml==6.0.1
|
||||
# via
|
||||
# ctranslate2
|
||||
@@ -347,13 +367,13 @@ pyyaml==6.0.1
|
||||
# langchain-core
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
regex==2024.5.15
|
||||
# via
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# cartesia
|
||||
# google-api-core
|
||||
# huggingface-hub
|
||||
# langchain
|
||||
@@ -362,14 +382,20 @@ requests==2.32.3
|
||||
# pyht
|
||||
# tiktoken
|
||||
# transformers
|
||||
resampy==0.4.3
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
rich==13.7.1
|
||||
# via typer
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
safetensors==0.4.3
|
||||
# via
|
||||
# timm
|
||||
# transformers
|
||||
scipy==1.13.1
|
||||
scipy==1.14.0
|
||||
# via pyloudnorm
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
@@ -378,17 +404,17 @@ sniffio==1.3.1
|
||||
# anyio
|
||||
# httpx
|
||||
# openai
|
||||
sounddevice==0.4.7
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
sqlalchemy==2.0.30
|
||||
sqlalchemy==2.0.31
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
sympy==1.12.1
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
sympy==1.13.1
|
||||
# via
|
||||
# onnxruntime
|
||||
# torch
|
||||
tenacity==8.3.0
|
||||
tenacity==8.5.0
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -402,8 +428,6 @@ tokenizers==0.19.1
|
||||
# anthropic
|
||||
# faster-whisper
|
||||
# transformers
|
||||
tomli==2.0.1
|
||||
# via pytest
|
||||
torch==2.3.1
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
@@ -424,11 +448,14 @@ transformers==4.40.2
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
triton==2.3.1
|
||||
# via torch
|
||||
typer==0.12.3
|
||||
# via fastapi-cli
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anthropic
|
||||
# anyio
|
||||
# deepgram-sdk
|
||||
# fastapi
|
||||
# google-generativeai
|
||||
# huggingface-hub
|
||||
# openai
|
||||
@@ -437,20 +464,28 @@ typing-extensions==4.12.2
|
||||
# pydantic-core
|
||||
# sqlalchemy
|
||||
# torch
|
||||
# typer
|
||||
# typing-inspect
|
||||
# uvicorn
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
uritemplate==4.1.1
|
||||
# via google-api-python-client
|
||||
urllib3==2.2.1
|
||||
urllib3==2.2.2
|
||||
# via requests
|
||||
uvicorn[standard]==0.30.3
|
||||
# via fastapi
|
||||
uvloop==0.19.0
|
||||
# via uvicorn
|
||||
verboselogs==1.7
|
||||
# via deepgram-sdk
|
||||
watchfiles==0.22.0
|
||||
# via uvicorn
|
||||
websockets==12.0
|
||||
# via
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# uvicorn
|
||||
werkzeug==3.0.3
|
||||
# via flask
|
||||
yarl==1.9.4
|
||||
|
||||
@@ -4,11 +4,10 @@
|
||||
#
|
||||
# pip-compile --all-extras pyproject.toml
|
||||
#
|
||||
aiofiles==23.2.1
|
||||
aiofiles==24.1.0
|
||||
# via deepgram-sdk
|
||||
aiohttp==3.9.5
|
||||
# via
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -17,7 +16,7 @@ aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anthropic==0.25.9
|
||||
anthropic==0.28.1
|
||||
# via
|
||||
# openpipe
|
||||
# pipecat-ai (pyproject.toml)
|
||||
@@ -26,6 +25,8 @@ anyio==4.4.0
|
||||
# anthropic
|
||||
# httpx
|
||||
# openai
|
||||
# starlette
|
||||
# watchfiles
|
||||
async-timeout==4.0.3
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -34,32 +35,31 @@ attrs==23.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# openpipe
|
||||
av==12.1.0
|
||||
av==12.3.0
|
||||
# via faster-whisper
|
||||
azure-cognitiveservices-speech==1.37.0
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
blinker==1.8.2
|
||||
# via flask
|
||||
cachetools==5.3.3
|
||||
cachetools==5.4.0
|
||||
# via google-auth
|
||||
cartesia==0.1.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
certifi==2024.6.2
|
||||
certifi==2024.7.4
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
cffi==1.16.0
|
||||
# via sounddevice
|
||||
charset-normalizer==3.3.2
|
||||
# via requests
|
||||
click==8.1.7
|
||||
# via flask
|
||||
# via
|
||||
# flask
|
||||
# typer
|
||||
# uvicorn
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
ctranslate2==4.3.1
|
||||
# via faster-whisper
|
||||
daily-python==0.9.1
|
||||
daily-python==0.10.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
dataclasses-json==0.6.7
|
||||
# via
|
||||
@@ -71,17 +71,23 @@ distro==1.9.0
|
||||
# via
|
||||
# anthropic
|
||||
# openai
|
||||
dnspython==2.6.1
|
||||
# via email-validator
|
||||
einops==0.8.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
exceptiongroup==1.2.1
|
||||
# via
|
||||
# anyio
|
||||
# pytest
|
||||
fal-client==0.4.0
|
||||
email-validator==2.2.0
|
||||
# via fastapi
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
fal-client==0.4.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
faster-whisper==1.0.2
|
||||
fastapi==0.111.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
filelock==3.15.1
|
||||
fastapi-cli==0.0.4
|
||||
# via fastapi
|
||||
faster-whisper==1.0.3
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
filelock==3.15.4
|
||||
# via
|
||||
# huggingface-hub
|
||||
# pyht
|
||||
@@ -99,22 +105,22 @@ frozenlist==1.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2024.6.0
|
||||
fsspec==2024.6.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
future==1.0.0
|
||||
# via pyloudnorm
|
||||
google-ai-generativelanguage==0.6.4
|
||||
google-ai-generativelanguage==0.6.6
|
||||
# via google-generativeai
|
||||
google-api-core[grpc]==2.19.0
|
||||
google-api-core[grpc]==2.19.1
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-python-client
|
||||
# google-generativeai
|
||||
google-api-python-client==2.133.0
|
||||
google-api-python-client==2.137.0
|
||||
# via google-generativeai
|
||||
google-auth==2.30.0
|
||||
google-auth==2.32.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
@@ -123,13 +129,13 @@ google-auth==2.30.0
|
||||
# google-generativeai
|
||||
google-auth-httplib2==0.2.0
|
||||
# via google-api-python-client
|
||||
google-generativeai==0.5.4
|
||||
google-generativeai==0.7.2
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
googleapis-common-protos==1.63.1
|
||||
googleapis-common-protos==1.63.2
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
grpcio==1.64.1
|
||||
grpcio==1.65.1
|
||||
# via
|
||||
# google-api-core
|
||||
# grpcio-status
|
||||
@@ -137,24 +143,28 @@ grpcio==1.64.1
|
||||
grpcio-status==1.62.2
|
||||
# via google-api-core
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
httpcore==1.0.5
|
||||
# via httpx
|
||||
httplib2==0.22.0
|
||||
# via
|
||||
# google-api-python-client
|
||||
# google-auth-httplib2
|
||||
httptools==0.6.1
|
||||
# via uvicorn
|
||||
httpx==0.27.0
|
||||
# via
|
||||
# anthropic
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# fal-client
|
||||
# fastapi
|
||||
# openai
|
||||
# openpipe
|
||||
httpx-sse==0.4.0
|
||||
# via fal-client
|
||||
huggingface-hub==0.23.3
|
||||
huggingface-hub==0.24.1
|
||||
# via
|
||||
# faster-whisper
|
||||
# timm
|
||||
@@ -165,50 +175,58 @@ humanfriendly==10.0
|
||||
idna==3.7
|
||||
# via
|
||||
# anyio
|
||||
# email-validator
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jinja2==3.1.4
|
||||
# via
|
||||
# fastapi
|
||||
# flask
|
||||
# torch
|
||||
jiter==0.5.0
|
||||
# via anthropic
|
||||
jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
langchain==0.2.3
|
||||
langchain==0.2.11
|
||||
# via
|
||||
# langchain-community
|
||||
# pipecat-ai (pyproject.toml)
|
||||
langchain-community==0.2.4
|
||||
langchain-community==0.2.10
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
langchain-core==0.2.5
|
||||
langchain-core==0.2.23
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-openai
|
||||
# langchain-text-splitters
|
||||
langchain-openai==0.1.8
|
||||
langchain-openai==0.1.17
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
langchain-text-splitters==0.2.1
|
||||
langchain-text-splitters==0.2.2
|
||||
# via langchain
|
||||
langsmith==0.1.77
|
||||
langsmith==0.1.93
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-core
|
||||
llvmlite==0.43.0
|
||||
# via numba
|
||||
loguru==0.7.2
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==2.1.5
|
||||
# via
|
||||
# jinja2
|
||||
# werkzeug
|
||||
marshmallow==3.21.3
|
||||
# via dataclasses-json
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.0.5
|
||||
@@ -219,43 +237,44 @@ mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
networkx==3.3
|
||||
# via torch
|
||||
numba==0.60.0
|
||||
# via resampy
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# ctranslate2
|
||||
# langchain
|
||||
# langchain-community
|
||||
# numba
|
||||
# onnxruntime
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# pyloudnorm
|
||||
# resampy
|
||||
# scipy
|
||||
# torchvision
|
||||
# transformers
|
||||
onnxruntime==1.18.0
|
||||
onnxruntime==1.18.1
|
||||
# via faster-whisper
|
||||
openai==1.26.0
|
||||
openai==1.35.15
|
||||
# via
|
||||
# langchain-openai
|
||||
# openpipe
|
||||
# pipecat-ai (pyproject.toml)
|
||||
openpipe==4.14.0
|
||||
openpipe==4.18.0
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
orjson==3.10.4
|
||||
orjson==3.10.6
|
||||
# via langsmith
|
||||
packaging==23.2
|
||||
packaging==24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# langchain-core
|
||||
# marshmallow
|
||||
# onnxruntime
|
||||
# pytest
|
||||
# transformers
|
||||
pillow==10.3.0
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# torchvision
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
proto-plus==1.23.0
|
||||
proto-plus==1.24.0
|
||||
# via
|
||||
# google-ai-generativelanguage
|
||||
# google-api-core
|
||||
@@ -278,32 +297,33 @@ pyasn1-modules==0.4.0
|
||||
# via google-auth
|
||||
pyaudio==0.2.14
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.7.4
|
||||
pydantic==2.8.2
|
||||
# via
|
||||
# anthropic
|
||||
# fastapi
|
||||
# google-generativeai
|
||||
# langchain
|
||||
# langchain-core
|
||||
# langsmith
|
||||
# openai
|
||||
pydantic-core==2.18.4
|
||||
pydantic-core==2.20.1
|
||||
# via pydantic
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyht==0.0.28
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pyloudnorm==0.1.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
pyparsing==3.1.2
|
||||
# via httplib2
|
||||
pytest==8.2.2
|
||||
# via pytest-asyncio
|
||||
pytest-asyncio==0.23.7
|
||||
# via cartesia
|
||||
python-dateutil==2.9.0.post0
|
||||
# via openpipe
|
||||
python-dotenv==1.0.1
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# uvicorn
|
||||
python-multipart==0.0.9
|
||||
# via fastapi
|
||||
pyyaml==6.0.1
|
||||
# via
|
||||
# ctranslate2
|
||||
@@ -313,13 +333,13 @@ pyyaml==6.0.1
|
||||
# langchain-core
|
||||
# timm
|
||||
# transformers
|
||||
# uvicorn
|
||||
regex==2024.5.15
|
||||
# via
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# cartesia
|
||||
# google-api-core
|
||||
# huggingface-hub
|
||||
# langchain
|
||||
@@ -328,14 +348,20 @@ requests==2.32.3
|
||||
# pyht
|
||||
# tiktoken
|
||||
# transformers
|
||||
resampy==0.4.3
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
rich==13.7.1
|
||||
# via typer
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
safetensors==0.4.3
|
||||
# via
|
||||
# timm
|
||||
# transformers
|
||||
scipy==1.13.1
|
||||
scipy==1.14.0
|
||||
# via pyloudnorm
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
@@ -344,17 +370,17 @@ sniffio==1.3.1
|
||||
# anyio
|
||||
# httpx
|
||||
# openai
|
||||
sounddevice==0.4.7
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
sqlalchemy==2.0.30
|
||||
sqlalchemy==2.0.31
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
sympy==1.12.1
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
sympy==1.13.1
|
||||
# via
|
||||
# onnxruntime
|
||||
# torch
|
||||
tenacity==8.3.0
|
||||
tenacity==8.5.0
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
@@ -368,8 +394,6 @@ tokenizers==0.19.1
|
||||
# anthropic
|
||||
# faster-whisper
|
||||
# transformers
|
||||
tomli==2.0.1
|
||||
# via pytest
|
||||
torch==2.3.1
|
||||
# via
|
||||
# pipecat-ai (pyproject.toml)
|
||||
@@ -388,11 +412,14 @@ tqdm==4.66.4
|
||||
# transformers
|
||||
transformers==4.40.2
|
||||
# via pipecat-ai (pyproject.toml)
|
||||
typer==0.12.3
|
||||
# via fastapi-cli
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anthropic
|
||||
# anyio
|
||||
# deepgram-sdk
|
||||
# fastapi
|
||||
# google-generativeai
|
||||
# huggingface-hub
|
||||
# openai
|
||||
@@ -401,20 +428,28 @@ typing-extensions==4.12.2
|
||||
# pydantic-core
|
||||
# sqlalchemy
|
||||
# torch
|
||||
# typer
|
||||
# typing-inspect
|
||||
# uvicorn
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
uritemplate==4.1.1
|
||||
# via google-api-python-client
|
||||
urllib3==2.2.1
|
||||
urllib3==2.2.2
|
||||
# via requests
|
||||
uvicorn[standard]==0.30.3
|
||||
# via fastapi
|
||||
uvloop==0.19.0
|
||||
# via uvicorn
|
||||
verboselogs==1.7
|
||||
# via deepgram-sdk
|
||||
watchfiles==0.22.0
|
||||
# via uvicorn
|
||||
websockets==12.0
|
||||
# via
|
||||
# cartesia
|
||||
# deepgram-sdk
|
||||
# pipecat-ai (pyproject.toml)
|
||||
# uvicorn
|
||||
werkzeug==3.0.3
|
||||
# via flask
|
||||
yarl==1.9.4
|
||||
|
||||
@@ -8,7 +8,7 @@ dynamic = ["version"]
|
||||
description = "An open source framework for voice (and multimodal) assistants"
|
||||
license = { text = "BSD 2-Clause License" }
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.7"
|
||||
requires-python = ">=3.10"
|
||||
keywords = ["webrtc", "audio", "video", "ai"]
|
||||
classifiers = [
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
@@ -34,24 +34,26 @@ Source = "https://github.com/pipecat-ai/pipecat"
|
||||
Website = "https://pipecat.ai"
|
||||
|
||||
[project.optional-dependencies]
|
||||
anthropic = [ "anthropic~=0.25.7" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.37.0" ]
|
||||
cartesia = [ "numpy~=1.26.0", "sounddevice", "cartesia" ]
|
||||
daily = [ "daily-python~=0.9.0" ]
|
||||
anthropic = [ "anthropic~=0.28.1" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.38.0" ]
|
||||
cartesia = [ "websockets~=12.0" ]
|
||||
daily = [ "daily-python~=0.10.1" ]
|
||||
deepgram = [ "deepgram-sdk~=3.2.7" ]
|
||||
examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
|
||||
fal = [ "fal-client~=0.4.0" ]
|
||||
google = [ "google-generativeai~=0.5.3" ]
|
||||
fireworks = [ "openai~=1.26.0" ]
|
||||
langchain = [ "langchain~=0.2.1", "langchain-community~=0.2.1", "langchain-openai~=0.1.8" ]
|
||||
fal = [ "fal-client~=0.4.1" ]
|
||||
gladia = [ "websockets~=12.0" ]
|
||||
google = [ "google-generativeai~=0.7.1" ]
|
||||
fireworks = [ "openai~=1.35.0" ]
|
||||
langchain = [ "langchain~=0.2.10", "langchain-community~=0.2.9", "langchain-openai~=0.1.17" ]
|
||||
local = [ "pyaudio~=0.2.0" ]
|
||||
moondream = [ "einops~=0.8.0", "timm~=0.9.16", "transformers~=4.40.2" ]
|
||||
openai = [ "openai~=1.26.0" ]
|
||||
openpipe = [ "openpipe~=4.14.0" ]
|
||||
openai = [ "openai~=1.35.0" ]
|
||||
openpipe = [ "openpipe~=4.18.0" ]
|
||||
playht = [ "pyht~=0.0.28" ]
|
||||
silero = [ "torch~=2.3.0", "torchaudio~=2.3.0" ]
|
||||
websocket = [ "websockets~=12.0" ]
|
||||
whisper = [ "faster-whisper~=1.0.2" ]
|
||||
silero = [ "torch~=2.3.1", "torchaudio~=2.3.1" ]
|
||||
websocket = [ "websockets~=12.0", "fastapi~=0.111.0" ]
|
||||
whisper = [ "faster-whisper~=1.0.3" ]
|
||||
xtts = [ "resampy~=0.4.3" ]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
# All the following settings are optional:
|
||||
|
||||
@@ -158,6 +158,34 @@ class LLMMessagesFrame(DataFrame):
|
||||
messages: List[dict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessagesAppendFrame(DataFrame):
|
||||
"""A frame containing a list of LLM messages that need to be added to the
|
||||
current context.
|
||||
|
||||
"""
|
||||
messages: List[dict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessagesUpdateFrame(DataFrame):
|
||||
"""A frame containing a list of new LLM messages. These messages will
|
||||
replace the current context LLM messages and should generate a new
|
||||
LLMMessagesFrame.
|
||||
|
||||
"""
|
||||
messages: List[dict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSSpeakFrame(DataFrame):
|
||||
"""A frame that contains a text that should be spoken by the TTS in the
|
||||
pipeline (if any).
|
||||
|
||||
"""
|
||||
text: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransportMessageFrame(DataFrame):
|
||||
message: Any
|
||||
@@ -184,14 +212,6 @@ class SystemFrame(Frame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartFrame(SystemFrame):
|
||||
"""This is the first frame that should be pushed down a pipeline."""
|
||||
allow_interruptions: bool = False
|
||||
enable_metrics: bool = False
|
||||
report_only_initial_ttfb: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CancelFrame(SystemFrame):
|
||||
"""Indicates that a pipeline needs to stop right away."""
|
||||
@@ -240,12 +260,22 @@ class StopInterruptionFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(SystemFrame):
|
||||
"""Emitted when the bot should be interrupted. This will mainly cause the
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricsFrame(SystemFrame):
|
||||
"""Emitted by processors that can compute metrics like latencies.
|
||||
"""
|
||||
ttfb: Mapping[str, float]
|
||||
|
||||
ttfb: List[Mapping[str, Any]] | None = None
|
||||
processing: List[Mapping[str, Any]] | None = None
|
||||
|
||||
#
|
||||
# Control frames
|
||||
@@ -257,6 +287,14 @@ class ControlFrame(Frame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartFrame(ControlFrame):
|
||||
"""This is the first frame that should be pushed down a pipeline."""
|
||||
allow_interruptions: bool = False
|
||||
enable_metrics: bool = False
|
||||
report_only_initial_ttfb: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndFrame(ControlFrame):
|
||||
"""Indicates that a pipeline has ended and frame processors and pipelines
|
||||
@@ -271,27 +309,13 @@ class EndFrame(ControlFrame):
|
||||
|
||||
@dataclass
|
||||
class LLMFullResponseStartFrame(ControlFrame):
|
||||
"""Used to indicate the beginning of a full LLM response. Following
|
||||
LLMResponseStartFrame, TextFrame and LLMResponseEndFrame for each sentence
|
||||
until a LLMFullResponseEndFrame."""
|
||||
"""Used to indicate the beginning of an LLM response. Followed by one or
|
||||
more TextFrame and a final LLMFullResponseEndFrame."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMFullResponseEndFrame(ControlFrame):
|
||||
"""Indicates the end of a full LLM response."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponseStartFrame(ControlFrame):
|
||||
"""Used to indicate the beginning of an LLM response. Following TextFrames
|
||||
are part of the LLM response until an LLMResponseEndFrame"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponseEndFrame(ControlFrame):
|
||||
"""Indicates the end of an LLM response."""
|
||||
pass
|
||||
|
||||
@@ -313,6 +337,33 @@ class UserStoppedSpeakingFrame(ControlFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStartedSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs to indicate the bot started speaking.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStoppedSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs to indicate the bot stopped speaking.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotSpeakingFrame(ControlFrame):
|
||||
"""Emitted upstream by transport outputs while the bot is still
|
||||
speaking. This can be used, for example, to detect when a user is idle. That
|
||||
is, while the bot is speaking we don't want to trigger any user idle timeout
|
||||
since the user might be listening.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSStartedFrame(ControlFrame):
|
||||
"""Used to indicate the beginning of a TTS response. Following
|
||||
@@ -338,3 +389,17 @@ class UserImageRequestFrame(ControlFrame):
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}, user: {self.user_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMModelUpdateFrame(ControlFrame):
|
||||
"""A control frame containing a request to update to a new LLM model.
|
||||
"""
|
||||
model: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSVoiceUpdateFrame(ControlFrame):
|
||||
"""A control frame containing a request to update to a new TTS voice.
|
||||
"""
|
||||
voice: str
|
||||
|
||||
@@ -64,7 +64,7 @@ class Pipeline(BasePipeline):
|
||||
services = []
|
||||
for p in self._processors:
|
||||
if isinstance(p, BasePipeline):
|
||||
services += p.processors_with_metrics()
|
||||
services.extend(p.processors_with_metrics())
|
||||
elif p.can_generate_metrics():
|
||||
services.append(p)
|
||||
return services
|
||||
@@ -91,5 +91,7 @@ class Pipeline(BasePipeline):
|
||||
def _link_processors(self):
|
||||
prev = self._processors[0]
|
||||
for curr in self._processors[1:]:
|
||||
prev.set_parent(self)
|
||||
prev.link(curr)
|
||||
prev = curr
|
||||
prev.set_parent(self)
|
||||
|
||||
@@ -15,7 +15,7 @@ from loguru import logger
|
||||
|
||||
class PipelineRunner:
|
||||
|
||||
def __init__(self, name: str | None = None, handle_sigint: bool = True):
|
||||
def __init__(self, *, name: str | None = None, handle_sigint: bool = True):
|
||||
self.id: int = obj_id()
|
||||
self.name: str = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from loguru import logger
|
||||
class PipelineParams(BaseModel):
|
||||
allow_interruptions: bool = False
|
||||
enable_metrics: bool = False
|
||||
send_initial_empty_metrics: bool = True
|
||||
report_only_initial_ttfb: bool = False
|
||||
|
||||
|
||||
@@ -71,6 +72,8 @@ class PipelineTask:
|
||||
await self._source.process_frame(CancelFrame(), FrameDirection.DOWNSTREAM)
|
||||
self._process_down_task.cancel()
|
||||
self._process_up_task.cancel()
|
||||
await self._process_down_task
|
||||
await self._process_up_task
|
||||
|
||||
async def run(self):
|
||||
self._process_up_task = asyncio.create_task(self._process_up_queue())
|
||||
@@ -93,8 +96,9 @@ class PipelineTask:
|
||||
|
||||
def _initial_metrics_frame(self) -> MetricsFrame:
|
||||
processors = self._pipeline.processors_with_metrics()
|
||||
ttfb = dict(zip([p.name for p in processors], [0] * len(processors)))
|
||||
return MetricsFrame(ttfb=ttfb)
|
||||
ttfb = [{"processor": p.name, "value": 0.0} for p in processors]
|
||||
processing = [{"processor": p.name, "value": 0.0} for p in processors]
|
||||
return MetricsFrame(ttfb=ttfb, processing=processing)
|
||||
|
||||
async def _process_down_queue(self):
|
||||
start_frame = StartFrame(
|
||||
@@ -103,7 +107,9 @@ class PipelineTask:
|
||||
report_only_initial_ttfb=self._params.report_only_initial_ttfb
|
||||
)
|
||||
await self._source.process_frame(start_frame, FrameDirection.DOWNSTREAM)
|
||||
await self._source.process_frame(self._initial_metrics_frame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
if self._params.send_initial_empty_metrics:
|
||||
await self._source.process_frame(self._initial_metrics_frame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
running = True
|
||||
should_cleanup = True
|
||||
@@ -122,6 +128,7 @@ class PipelineTask:
|
||||
await self._pipeline.cleanup()
|
||||
# Cancel the upstream task and wait for it to finish.
|
||||
self._process_up_task.cancel()
|
||||
await self._process_up_task
|
||||
|
||||
async def _process_up_queue(self):
|
||||
while True:
|
||||
|
||||
@@ -14,9 +14,9 @@ from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMResponseEndFrame,
|
||||
LLMResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TextFrame,
|
||||
@@ -122,6 +122,19 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
# Reset anyways
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMMessagesAppendFrame):
|
||||
self._messages.extend(frame.messages)
|
||||
messages_frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(messages_frame)
|
||||
elif isinstance(frame, LLMMessagesUpdateFrame):
|
||||
# We push the frame downstream so the assistant aggregator gets
|
||||
# updated as well.
|
||||
await self.push_frame(frame)
|
||||
# We can now reset this one.
|
||||
self._reset()
|
||||
self._messages = frame.messages
|
||||
messages_frame = LLMMessagesFrame(self._messages)
|
||||
await self.push_frame(messages_frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -173,7 +186,7 @@ class LLMUserResponseAggregator(LLMResponseAggregator):
|
||||
|
||||
class LLMFullResponseAggregator(FrameProcessor):
|
||||
"""This class aggregates Text frames until it receives a
|
||||
LLMResponseEndFrame, then emits the concatenated text as
|
||||
LLMFullResponseEndFrame, then emits the concatenated text as
|
||||
a single text frame.
|
||||
|
||||
given the following frames:
|
||||
@@ -182,12 +195,12 @@ class LLMFullResponseAggregator(FrameProcessor):
|
||||
TextFrame(" world.")
|
||||
TextFrame(" I am")
|
||||
TextFrame(" an LLM.")
|
||||
LLMResponseEndFrame()]
|
||||
LLMFullResponseEndFrame()]
|
||||
|
||||
this processor will yield nothing for the first 4 frames, then
|
||||
|
||||
TextFrame("Hello, world. I am an LLM.")
|
||||
LLMResponseEndFrame()
|
||||
LLMFullResponseEndFrame()
|
||||
|
||||
when passed the last frame.
|
||||
|
||||
@@ -203,9 +216,9 @@ class LLMFullResponseAggregator(FrameProcessor):
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
|
||||
>>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
|
||||
>>> asyncio.run(print_frames(aggregator, LLMResponseEndFrame()))
|
||||
>>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
|
||||
Hello, world. I am an LLM.
|
||||
LLMResponseEndFrame
|
||||
LLMFullResponseEndFrame
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -234,6 +247,11 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
async def _push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message({"role": self._role, "content": self._aggregation})
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the task gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -247,9 +265,10 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
messages=[],
|
||||
context=context,
|
||||
role="assistant",
|
||||
start_frame=LLMResponseStartFrame,
|
||||
end_frame=LLMResponseEndFrame,
|
||||
accumulator_frame=TextFrame
|
||||
start_frame=LLMFullResponseStartFrame,
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
64
src/pipecat/processors/async_frame_processor.py
Normal file
64
src/pipecat/processors/async_frame_processor.py
Normal file
@@ -0,0 +1,64 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from pipecat.frames.frames import EndFrame, Frame, StartInterruptionFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class AsyncFrameProcessor(FrameProcessor):
    """Frame processor that pushes queued frames from a background task.

    Frames handed to `queue_frame()` are stored in an internal queue and
    pushed, in order, by a dedicated task. When a StartInterruptionFrame is
    processed, the task is cancelled, the interruption frame is pushed right
    away (out-of-band) and a fresh queue and task are created, so frames that
    were still waiting in the old queue are dropped.

    """

    def __init__(
            self,
            *,
            name: str | None = None,
            loop: asyncio.AbstractEventLoop | None = None,
            **kwargs):
        super().__init__(name=name, loop=loop, **kwargs)

        self._create_push_task()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Run the base processing and react to interruptions."""
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            await self._handle_interruptions(frame)

    async def queue_frame(
            self,
            frame: Frame,
            direction: FrameDirection = FrameDirection.DOWNSTREAM):
        """Enqueue a frame so the push task sends it in arrival order."""
        await self._push_queue.put((frame, direction))

    async def cleanup(self):
        """Stop the push task and wait until it has finished."""
        await self._cancel_push_task()

    async def _handle_interruptions(self, frame: Frame):
        # Cancel the task. This will stop pushing frames downstream.
        await self._cancel_push_task()
        # Push an out-of-band frame (i.e. not using the ordered push
        # frame task).
        await self.push_frame(frame)
        # Create a new queue and task.
        self._create_push_task()

    async def _cancel_push_task(self):
        """Cancel the current push task and wait for it to terminate."""
        task = self._push_frame_task
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # If the task is cancelled before its first step (e.g. two
            # interruptions back to back) the handler never enters its
            # try/except, so the task ends up cancelled and awaiting it
            # raises here. That cancellation is the one we just requested;
            # anything else is re-raised untouched.
            if not task.cancelled():
                raise

    def _create_push_task(self):
        self._push_queue = asyncio.Queue()
        self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())

    async def _push_frame_task_handler(self):
        """Push queued frames one by one until an EndFrame or cancellation."""
        running = True
        while running:
            try:
                (frame, direction) = await self._push_queue.get()
                await self.push_frame(frame, direction)
                # An EndFrame is still pushed, but it is the last one.
                running = not isinstance(frame, EndFrame)
                self._push_queue.task_done()
            except asyncio.CancelledError:
                break
|
||||
@@ -82,5 +82,5 @@ class WakeCheckFilter(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in wake word filter: {e}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
await self.push_error(ErrorFrame(error_msg))
|
||||
|
||||
@@ -9,7 +9,7 @@ import time
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, MetricsFrame, StartFrame, UserStoppedSpeakingFrame
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, MetricsFrame, StartFrame, StartInterruptionFrame, UserStoppedSpeakingFrame
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
from loguru import logger
|
||||
@@ -20,15 +20,59 @@ class FrameDirection(Enum):
|
||||
UPSTREAM = 2
|
||||
|
||||
|
||||
class FrameProcessorMetrics:
    """Measures TTFB and processing time on behalf of a single processor.

    Each measurement is started with a `start_*` call and finished with the
    matching `stop_*` call, which returns a MetricsFrame holding one entry of
    the form `{"processor": <name>, "value": <seconds>}`, or None if no
    measurement was in progress.

    Durations are taken with `time.monotonic()` so they cannot be skewed (or
    become negative) when the system clock is adjusted.

    """

    def __init__(self, name: str):
        self._name = name
        # A start time of 0 means "no measurement in progress".
        self._start_ttfb_time = 0
        self._start_processing_time = 0
        self._should_report_ttfb = True

    async def start_ttfb_metrics(self, report_only_initial_ttfb):
        """Start a TTFB measurement if reporting is still enabled.

        When `report_only_initial_ttfb` is true, further calls are ignored
        after this one until `_should_report_ttfb` is set again.
        """
        if self._should_report_ttfb:
            self._start_ttfb_time = time.monotonic()
            self._should_report_ttfb = not report_only_initial_ttfb

    async def stop_ttfb_metrics(self):
        """Finish the TTFB measurement and return its MetricsFrame, or None."""
        if self._start_ttfb_time == 0:
            return None

        value = time.monotonic() - self._start_ttfb_time
        logger.debug(f"{self._name} TTFB: {value}")
        ttfb = {
            "processor": self._name,
            "value": value
        }
        # Reset so a second stop without a new start reports nothing.
        self._start_ttfb_time = 0
        return MetricsFrame(ttfb=[ttfb])

    async def start_processing_metrics(self):
        """Start (or restart) a processing-time measurement."""
        self._start_processing_time = time.monotonic()

    async def stop_processing_metrics(self):
        """Finish the processing measurement and return its MetricsFrame, or None."""
        if self._start_processing_time == 0:
            return None

        value = time.monotonic() - self._start_processing_time
        logger.debug(f"{self._name} processing time: {value}")
        processing = {
            "processor": self._name,
            "value": value
        }
        # Reset so a second stop without a new start reports nothing.
        self._start_processing_time = 0
        return MetricsFrame(processing=[processing])
|
||||
|
||||
|
||||
class FrameProcessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
**kwargs):
|
||||
self.id: int = obj_id()
|
||||
self.name = name or f"{self.__class__.__name__}#{obj_count(self)}"
|
||||
self._parent: "FrameProcessor" | None = None
|
||||
self._prev: "FrameProcessor" | None = None
|
||||
self._next: "FrameProcessor" | None = None
|
||||
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop()
|
||||
@@ -39,8 +83,7 @@ class FrameProcessor:
|
||||
self._report_only_initial_ttfb = False
|
||||
|
||||
# Metrics
|
||||
self._start_ttfb_time = 0
|
||||
self._should_report_ttfb = True
|
||||
self._metrics = FrameProcessorMetrics(name=self.name)
|
||||
|
||||
@property
|
||||
def interruptions_allowed(self):
|
||||
@@ -58,21 +101,33 @@ class FrameProcessor:
|
||||
return False
|
||||
|
||||
async def start_ttfb_metrics(self):
|
||||
if self.metrics_enabled and self._should_report_ttfb:
|
||||
self._start_ttfb_time = time.time()
|
||||
self._should_report_ttfb = not self._report_only_initial_ttfb
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
await self._metrics.start_ttfb_metrics(self._report_only_initial_ttfb)
|
||||
|
||||
async def stop_ttfb_metrics(self):
|
||||
if self.metrics_enabled and self._start_ttfb_time > 0:
|
||||
ttfb = time.time() - self._start_ttfb_time
|
||||
logger.debug(f"{self.name} TTFB: {ttfb}")
|
||||
await self.push_frame(MetricsFrame(ttfb={self.name: ttfb}))
|
||||
self._start_ttfb_time = 0
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
frame = await self._metrics.stop_ttfb_metrics()
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def start_processing_metrics(self):
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
await self._metrics.start_processing_metrics()
|
||||
|
||||
async def stop_processing_metrics(self):
|
||||
if self.can_generate_metrics() and self.metrics_enabled:
|
||||
frame = await self._metrics.stop_processing_metrics()
|
||||
if frame:
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def stop_all_metrics(self):
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
async def cleanup(self):
|
||||
pass
|
||||
|
||||
def link(self, processor: 'FrameProcessor'):
|
||||
def link(self, processor: "FrameProcessor"):
|
||||
self._next = processor
|
||||
processor._prev = self
|
||||
logger.debug(f"Linking {self} -> {self._next}")
|
||||
@@ -80,11 +135,19 @@ class FrameProcessor:
|
||||
def get_event_loop(self) -> asyncio.AbstractEventLoop:
    """Return the event loop stored on this processor."""
    loop = self._loop
    return loop
|
||||
|
||||
def set_parent(self, parent: "FrameProcessor"):
    """Remember `parent` as the processor that contains this one."""
    self._parent = parent
|
||||
|
||||
def get_parent(self) -> "FrameProcessor":
    """Return the parent previously stored with `set_parent()`."""
    parent = self._parent
    return parent
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
    """Update processor state from the frames every processor must observe.

    This method only records state; it does not push `frame` anywhere.

    Args:
        frame: The incoming frame.
        direction: Direction the frame is travelling in (not used here).
    """
    if isinstance(frame, StartFrame):
        # Pipeline-wide settings arrive on the start frame.
        self._allow_interruptions = frame.allow_interruptions
        self._enable_metrics = frame.enable_metrics
        self._report_only_initial_ttfb = frame.report_only_initial_ttfb
    elif isinstance(frame, StartInterruptionFrame):
        # An interruption ends any measurement currently in progress.
        await self.stop_all_metrics()
    elif isinstance(frame, UserStoppedSpeakingFrame):
        self._should_report_ttfb = True
|
||||
|
||||
@@ -92,12 +155,15 @@ class FrameProcessor:
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
if direction == FrameDirection.DOWNSTREAM and self._next:
|
||||
logger.trace(f"Pushing {frame} from {self} to {self._next}")
|
||||
await self._next.process_frame(frame, direction)
|
||||
elif direction == FrameDirection.UPSTREAM and self._prev:
|
||||
logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}")
|
||||
await self._prev.process_frame(frame, direction)
|
||||
try:
|
||||
if direction == FrameDirection.DOWNSTREAM and self._next:
|
||||
logger.trace(f"Pushing {frame} from {self} to {self._next}")
|
||||
await self._next.process_frame(frame, direction)
|
||||
elif direction == FrameDirection.UPSTREAM and self._prev:
|
||||
logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}")
|
||||
await self._prev.process_frame(frame, direction)
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||
|
||||
def __str__(self):
    """Return the processor name as its string form."""
    label = self.name
    return label
|
||||
|
||||
@@ -11,8 +11,6 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMResponseEndFrame,
|
||||
LLMResponseStartFrame,
|
||||
TextFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
@@ -69,11 +67,10 @@ class LangchainProcessor(FrameProcessor):
|
||||
{self._transcript_key: text},
|
||||
config={"configurable": {"session_id": self._participant_id}},
|
||||
):
|
||||
await self.push_frame(LLMResponseStartFrame())
|
||||
await self.push_frame(TextFrame(self.__get_token_value(token)))
|
||||
await self.push_frame(LLMResponseEndFrame())
|
||||
except GeneratorExit:
|
||||
logger.warning(f"{self} generator was closed prematurely")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} an unknown error occurred: {e}")
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
logger.exception(f"{self} an unknown error occurred: {e}")
|
||||
finally:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
615
src/pipecat/processors/frameworks/rtvi.py
Normal file
615
src/pipecat/processors/frameworks/rtvi.py
Normal file
@@ -0,0 +1,615 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type
|
||||
from pydantic import PrivateAttr, BaseModel, ValidationError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMModelUpdateFrame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSVoiceUpdateFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TransportMessageFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMService, OpenAILLMContext
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class RTVIServiceOption(BaseModel):
    """A named, configurable option of an RTVI service.

    `handler`, when set, is awaited with the RTVI processor and the incoming
    option configuration whenever that option is configured.
    """

    name: str
    handler: Optional[
        Callable[['RTVIProcessor', 'RTVIServiceOptionConfig'], Awaitable[None]]
    ] = None
|
||||
|
||||
|
||||
class RTVIService(BaseModel):
    """A service that can be registered with the RTVI processor.

    Holds the service name, the frame processor class implementing it and the
    options it accepts.
    """

    name: str
    cls: Type[FrameProcessor]
    options: List[RTVIServiceOption]
    # Name -> option index, rebuilt from `options` after model initialisation.
    _options_dict: Dict[str, RTVIServiceOption] = PrivateAttr(default={})

    def model_post_init(self, __context: Any) -> None:
        """Index the options by name (a later duplicate name wins)."""
        self._options_dict = {option.name: option for option in self.options}
        return super().model_post_init(__context)
|
||||
|
||||
#
|
||||
# Client -> Pipecat messages.
|
||||
#
|
||||
|
||||
|
||||
class RTVIServiceOptionConfig(BaseModel):
    """Client-supplied value for a single service option."""

    name: str
    value: Any
|
||||
|
||||
|
||||
class RTVIServiceConfig(BaseModel):
    """Client-supplied option values for one service, addressed by service name."""

    service: str
    options: List[RTVIServiceOptionConfig]
|
||||
|
||||
|
||||
class RTVIConfig(BaseModel):
    """Full RTVI configuration: a list of per-service configurations."""

    config: List[RTVIServiceConfig]
    # Service name -> configuration index, rebuilt after model initialisation.
    _config_dict: Dict[str, RTVIServiceConfig] = PrivateAttr(default={})

    def model_post_init(self, __context: Any) -> None:
        """Index the service configurations by service name (last one wins)."""
        self._config_dict = {entry.service: entry for entry in self.config}
        return super().model_post_init(__context)
|
||||
|
||||
|
||||
class RTVILLMContextData(BaseModel):
    """Payload of the `llm-append-context` and `llm-update-context` messages."""

    messages: List[dict]
|
||||
|
||||
|
||||
class RTVITTSSpeakData(BaseModel):
    """Payload of the `tts-speak` message.

    `interrupt` asks for a TTS interruption before the text is spoken.
    """

    text: str
    interrupt: Optional[bool] = False
|
||||
|
||||
|
||||
class RTVIMessage(BaseModel):
    """Envelope for a client -> Pipecat message.

    `type` selects the action, `id` is echoed back in the response and `data`
    carries the type-specific payload, if any.
    """

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: str
    id: str
    data: Optional[Dict[str, Any]] = None
|
||||
|
||||
#
|
||||
# Pipecat -> Client responses and messages.
|
||||
#
|
||||
|
||||
|
||||
class RTVIResponseData(BaseModel):
    """Outcome of a client request: a success flag plus an optional error text."""

    success: bool
    error: Optional[str] = None
|
||||
|
||||
|
||||
class RTVIResponse(BaseModel):
    """Pipecat -> client response; `id` is the id of the request being answered."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["response"] = "response"
    id: str
    data: RTVIResponseData
|
||||
|
||||
|
||||
class RTVIErrorData(BaseModel):
    """Payload of an `error` message: the error text."""

    message: str
|
||||
|
||||
|
||||
class RTVIError(BaseModel):
    """Pipecat -> client error message that is not tied to a request id."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["error"] = "error"
    data: RTVIErrorData
|
||||
|
||||
|
||||
class RTVILLMContextMessageData(BaseModel):
    """Payload of an `llm-context` message: the context messages."""

    messages: List[dict]
|
||||
|
||||
|
||||
class RTVILLMContextMessage(BaseModel):
    """Pipecat -> client message carrying the LLM context."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["llm-context"] = "llm-context"
    data: RTVILLMContextMessageData
|
||||
|
||||
|
||||
class RTVITTSTextMessageData(BaseModel):
    """Payload of a `tts-text` message: the text of one text frame."""

    text: str
|
||||
|
||||
|
||||
class RTVITTSTextMessage(BaseModel):
    """Pipecat -> client message reporting text seen by the TTS text processor."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["tts-text"] = "tts-text"
    data: RTVITTSTextMessageData
|
||||
|
||||
|
||||
class RTVIBotReady(BaseModel):
    """Pipecat -> client `bot-ready` notification; carries no payload."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["bot-ready"] = "bot-ready"
|
||||
|
||||
|
||||
class RTVITranscriptionMessageData(BaseModel):
    """Payload of a `user-transcription` message.

    `final` is True for a transcription frame and False for an interim one.
    """

    text: str
    user_id: str
    timestamp: str
    final: bool
|
||||
|
||||
|
||||
class RTVITranscriptionMessage(BaseModel):
    """Pipecat -> client message carrying a user transcription."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["user-transcription"] = "user-transcription"
    data: RTVITranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserStartedSpeakingMessage(BaseModel):
    """Pipecat -> client notification that the user started speaking."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["user-started-speaking"] = "user-started-speaking"
|
||||
|
||||
|
||||
class RTVIUserStoppedSpeakingMessage(BaseModel):
    """Pipecat -> client notification that the user stopped speaking."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["user-stopped-speaking"] = "user-stopped-speaking"
|
||||
|
||||
|
||||
class RTVIJSONCompletion(BaseModel):
    """Pipecat -> client message carrying an aggregated JSON completion as text."""

    label: Literal["rtvi-ai"] = "rtvi-ai"
    type: Literal["json-completion"] = "json-completion"
    data: str
|
||||
|
||||
|
||||
class FunctionCaller(FrameProcessor):
    """Diverts JSON-looking LLM responses to the client as `json-completion`.

    After an `LLMFullResponseStartFrame`, the first `TextFrame` decides how the
    response is handled: if its stripped text starts with `{` or a code fence,
    every text frame of the response is aggregated instead of being pushed
    downstream. When the response ends, the aggregation is added to the
    context as an assistant message and sent as an `RTVIJSONCompletion`
    transport message. Any other response is passed through unchanged.

    NOTE(review): for an aggregated response the `LLMFullResponseEndFrame` is
    consumed here and not forwarded, although the start frame was; confirm
    that downstream processors do not rely on receiving it.
    """

    def __init__(self, context):
        """Create the processor.

        Args:
            context: Object exposing `add_message(dict)`; receives each
                aggregated completion as an assistant message.
        """
        super().__init__()
        self._checking = False
        self._aggregating = False
        self._emitted_start = False
        self._aggregation = ""
        self._context = context

        self._callbacks = {}
        self._start_callbacks = {}

    def register_function(self, function_name: str, callback, start_callback=None):
        """Register `callback` (and optionally `start_callback`) for a function name."""
        self._callbacks[function_name] = callback
        if start_callback:
            self._start_callbacks[function_name] = start_callback

    def unregister_function(self, function_name: str):
        """Remove the callbacks registered for `function_name`.

        Raises:
            KeyError: If no callback is registered under `function_name`.
        """
        del self._callbacks[function_name]
        # The start callback is optional, so it may never have been registered;
        # indexing the dict here used to raise KeyError in that case.
        self._start_callbacks.pop(function_name, None)

    def has_function(self, function_name: str):
        """Return True if a callback is registered for `function_name`."""
        return function_name in self._callbacks

    async def call_function(self, function_name: str, args):
        """Await the registered callback with `(self, args)`.

        Returns:
            The callback's result, or None when the name is not registered.
        """
        if function_name in self._callbacks:
            return await self._callbacks[function_name](self, args)
        return None

    async def call_start_function(self, function_name: str):
        """Await the registered start callback with `(self)`, if there is one."""
        if function_name in self._start_callbacks:
            await self._start_callbacks[function_name](self)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Pass frames through, aggregating JSON-looking LLM responses."""
        await super().process_frame(frame, direction)

        if isinstance(frame, LLMFullResponseStartFrame):
            self._checking = True
            await self.push_frame(frame, direction)
        elif isinstance(frame, TextFrame) and self._checking:
            # Only the first text frame of a response is inspected.
            self._checking = False
            self._emitted_start = False
            # TODO-CB: should we expand this to any non-text character to start the completion?
            if frame.text.strip().startswith(("{", "```")):
                self._aggregation = frame.text
                self._aggregating = True
            else:
                self._aggregating = False
                self._aggregation = ""
                await self.push_frame(frame, direction)
        elif isinstance(frame, TextFrame) and self._aggregating:
            self._aggregation += frame.text
            # TODO-CB: We can probably ignore function start I think
            # if not self._emitted_start:
            #     fn = re.search(r'{"function_name":\s*"(.*)",', self._aggregation)
            #     if fn and fn.group(1):
            #         await self.call_start_function(fn.group(1))
            #         self._emitted_start = True
        elif isinstance(frame, LLMFullResponseEndFrame) and self._aggregating:
            try:
                # Drop markdown code fences around the JSON payload.
                self._aggregation = self._aggregation.replace("```json", "").replace("```", "")
                self._context.add_message({"role": "assistant", "content": self._aggregation})
                message = RTVIJSONCompletion(data=self._aggregation)
                msg = message.model_dump(exclude_none=True)
                await self.push_frame(TransportMessageFrame(message=msg))
            except Exception as e:
                # Report through the module logger like the rest of the file.
                logger.error(f"Error parsing function call json: {e}")
                logger.error(f"aggregation was: {self._aggregation}")

            self._aggregating = False
            self._aggregation = ""
            self._emitted_start = False
        else:
            # Everything else, including the end frame of a non-aggregated
            # response, is passed through unchanged.
            await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class RTVITTSTextProcessor(FrameProcessor):
    """Forwards every frame and reports each text frame as a `tts-text` message."""

    def __init__(self):
        super().__init__()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Push `frame` on, then push a `tts-text` message if it is a `TextFrame`."""
        await super().process_frame(frame, direction)

        # The original frame always goes out before the notification.
        await self.push_frame(frame, direction)

        if not isinstance(frame, TextFrame):
            return

        payload = RTVITTSTextMessageData(text=frame.text)
        message = RTVITTSTextMessage(data=payload)
        await self.push_frame(TransportMessageFrame(message=message.model_dump(exclude_none=True)))
|
||||
|
||||
|
||||
async def handle_llm_model_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push an `LLMModelUpdateFrame` built from the option value."""
    await rtvi.push_frame(LLMModelUpdateFrame(option.value))
|
||||
|
||||
|
||||
async def handle_llm_messages_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push an `LLMMessagesUpdateFrame` built from the option value."""
    await rtvi.push_frame(LLMMessagesUpdateFrame(option.value))
|
||||
|
||||
|
||||
async def handle_tts_voice_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push a `TTSVoiceUpdateFrame` built from the option value."""
    await rtvi.push_frame(TTSVoiceUpdateFrame(option.value))
|
||||
|
||||
# Default "llm" service: an OpenAI LLM whose model and messages can be updated.
DEFAULT_LLM_SERVICE = RTVIService(
    name="llm",
    cls=OpenAILLMService,
    options=[
        RTVIServiceOption(name="model", handler=handle_llm_model_update),
        RTVIServiceOption(name="messages", handler=handle_llm_messages_update),
    ],
)
|
||||
|
||||
# Default "tts" service: Cartesia TTS whose voice can be updated.
DEFAULT_TTS_SERVICE = RTVIService(
    name="tts",
    cls=CartesiaTTSService,
    options=[
        RTVIServiceOption(name="voice_id", handler=handle_tts_voice_update),
    ],
)
|
||||
|
||||
|
||||
class RTVIProcessor(FrameProcessor):
|
||||
|
||||
def __init__(self, *, transport: BaseTransport):
|
||||
super().__init__()
|
||||
self._transport = transport
|
||||
self._config: RTVIConfig | None = None
|
||||
self._ctor_args: Dict[str, Any] = {}
|
||||
|
||||
self._start_frame: Frame | None = None
|
||||
self._pipeline: FrameProcessor | None = None
|
||||
self._first_participant_joined: bool = False
|
||||
|
||||
# Register transport event so we can send a `bot-ready` event (and maybe
|
||||
# others) when the participant joins.
|
||||
transport.add_event_handler(
|
||||
"on_first_participant_joined",
|
||||
self._on_first_participant_joined)
|
||||
|
||||
# Register default services.
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
self.register_service(DEFAULT_LLM_SERVICE)
|
||||
self.register_service(DEFAULT_TTS_SERVICE)
|
||||
|
||||
self._frame_handler_task = self.get_event_loop().create_task(self._frame_handler())
|
||||
self._frame_queue = asyncio.Queue()
|
||||
|
||||
def register_service(self, service: RTVIService):
|
||||
self._registered_services[service.name] = service
|
||||
|
||||
def setup_on_start(self, config: RTVIConfig | None, ctor_args: Dict[str, Any]):
|
||||
self._config = config
|
||||
self._ctor_args = ctor_args
|
||||
|
||||
async def update_config(self, config: RTVIConfig):
|
||||
if self._pipeline:
|
||||
await self._handle_config_update(config)
|
||||
self._config = config
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Specific system frames
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self._cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames
|
||||
elif isinstance(frame, StartFrame):
|
||||
await self._start(frame)
|
||||
await self._internal_push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self._stop(frame)
|
||||
# Other frames
|
||||
else:
|
||||
await self._internal_push_frame(frame, direction)
|
||||
|
||||
async def cleanup(self):
|
||||
if self._pipeline:
|
||||
await self._pipeline.cleanup()
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
try:
|
||||
await self._handle_pipeline_setup(frame, self._config)
|
||||
except Exception as e:
|
||||
await self._send_error(f"unable to setup RTVI pipeline: {e}")
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._frame_handler_task
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
self._frame_handler_task.cancel()
|
||||
await self._frame_handler_task
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
frame: Frame | None,
|
||||
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
|
||||
await self._frame_queue.put((frame, direction))
|
||||
|
||||
async def _frame_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._frame_queue.get()
|
||||
await self._handle_frame(frame, direction)
|
||||
self._frame_queue.task_done()
|
||||
running = not isinstance(frame, EndFrame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _handle_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, TransportMessageFrame):
|
||||
await self._handle_message(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame) or isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._handle_transcriptions(frame)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(frame, UserStoppedSpeakingFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
|
||||
async def _handle_transcriptions(self, frame: Frame):
|
||||
# TODO(aleix): Once we add support for using custom piplines, the STTs will
|
||||
# be in the pipeline after this processor. This means the STT will have to
|
||||
# push transcriptions upstream as well.
|
||||
|
||||
message = None
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
message = RTVITranscriptionMessage(
|
||||
data=RTVITranscriptionMessageData(
|
||||
text=frame.text,
|
||||
user_id=frame.user_id,
|
||||
timestamp=frame.timestamp,
|
||||
final=True))
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
message = RTVITranscriptionMessage(
|
||||
data=RTVITranscriptionMessageData(
|
||||
text=frame.text,
|
||||
user_id=frame.user_id,
|
||||
timestamp=frame.timestamp,
|
||||
final=False))
|
||||
|
||||
if message:
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
message = None
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
message = RTVIUserStartedSpeakingMessage()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
message = RTVIUserStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_message(self, frame: TransportMessageFrame):
|
||||
try:
|
||||
message = RTVIMessage.model_validate(frame.message)
|
||||
except ValidationError as e:
|
||||
await self._send_error(f"Invalid incoming message: {e}")
|
||||
logger.warning(f"Invalid incoming message: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
success = True
|
||||
error = None
|
||||
match message.type:
|
||||
case "config-update":
|
||||
await self._handle_config_update(RTVIConfig.model_validate(message.data))
|
||||
case "llm-get-context":
|
||||
await self._handle_llm_get_context()
|
||||
case "llm-append-context":
|
||||
await self._handle_llm_append_context(RTVILLMContextData.model_validate(message.data))
|
||||
case "llm-update-context":
|
||||
await self._handle_llm_update_context(RTVILLMContextData.model_validate(message.data))
|
||||
case "tts-speak":
|
||||
await self._handle_tts_speak(RTVITTSSpeakData.model_validate(message.data))
|
||||
case "tts-interrupt":
|
||||
await self._handle_tts_interrupt()
|
||||
case _:
|
||||
success = False
|
||||
error = f"Unsupported type {message.type}"
|
||||
|
||||
await self._send_response(message.id, success, error)
|
||||
except ValidationError as e:
|
||||
await self._send_response(message.id, False, f"Invalid incoming message: {e}")
|
||||
logger.warning(f"Invalid incoming message: {e}")
|
||||
except Exception as e:
|
||||
await self._send_response(message.id, False, f"Exception processing message: {e}")
|
||||
logger.warning(f"Exception processing message: {e}")
|
||||
|
||||
async def _handle_pipeline_setup(self, start_frame: StartFrame, config: RTVIConfig | None):
|
||||
# TODO(aleix): We shouldn't need to save this in `self._tma_in`.
|
||||
self._tma_in = LLMUserResponseAggregator()
|
||||
tma_out = LLMAssistantResponseAggregator()
|
||||
|
||||
llm_cls = self._registered_services["llm"].cls
|
||||
llm_args = self._ctor_args["llm"]
|
||||
llm = llm_cls(**llm_args)
|
||||
|
||||
tts_cls = self._registered_services["tts"].cls
|
||||
tts_args = self._ctor_args["tts"]
|
||||
tts = tts_cls(**tts_args)
|
||||
|
||||
# TODO-CB: Eventually we'll need to switch the context aggregators to use the
|
||||
# OpenAI context frames instead of message frames
|
||||
context = OpenAILLMContext()
|
||||
fc = FunctionCaller(context)
|
||||
|
||||
tts_text = RTVITTSTextProcessor()
|
||||
|
||||
pipeline = Pipeline([
|
||||
self._tma_in,
|
||||
llm,
|
||||
fc,
|
||||
tts,
|
||||
tts_text,
|
||||
tma_out,
|
||||
self._transport.output(),
|
||||
])
|
||||
|
||||
parent = self.get_parent()
|
||||
if parent:
|
||||
parent.link(pipeline)
|
||||
|
||||
# We need to initialize the new pipeline with the same settings
|
||||
# as the initial one.
|
||||
start_frame = dataclasses.replace(start_frame)
|
||||
await self.push_frame(start_frame)
|
||||
|
||||
# Configure the pipeline
|
||||
if config:
|
||||
await self._handle_config_update(config)
|
||||
|
||||
# Send new initial metrics with the new processors
|
||||
processors = parent.processors_with_metrics()
|
||||
processors.extend(pipeline.processors_with_metrics())
|
||||
ttfb = [{"processor": p.name, "value": 0.0} for p in processors]
|
||||
processing = [{"processor": p.name, "value": 0.0} for p in processors]
|
||||
await self.push_frame(MetricsFrame(ttfb=ttfb, processing=processing))
|
||||
|
||||
self._pipeline = pipeline
|
||||
|
||||
await self._maybe_send_bot_ready()
|
||||
|
||||
async def _handle_config_service(self, config: RTVIServiceConfig):
|
||||
service = self._registered_services[config.service]
|
||||
for option in config.options:
|
||||
handler = service._options_dict[option.name].handler
|
||||
if handler:
|
||||
await handler(self, option)
|
||||
|
||||
async def _handle_config_update(self, data: RTVIConfig):
|
||||
for config in data.config:
|
||||
await self._handle_config_service(config)
|
||||
|
||||
async def _handle_llm_get_context(self):
|
||||
data = RTVILLMContextMessageData(messages=self._tma_in.messages)
|
||||
message = RTVILLMContextMessage(data=data)
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_llm_append_context(self, data: RTVILLMContextData):
|
||||
if data and data.messages:
|
||||
frame = LLMMessagesAppendFrame(data.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_llm_update_context(self, data: RTVILLMContextData):
|
||||
if data and data.messages:
|
||||
frame = LLMMessagesUpdateFrame(data.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_tts_speak(self, data: RTVITTSSpeakData):
|
||||
if data and data.text:
|
||||
if data.interrupt:
|
||||
await self._handle_tts_interrupt()
|
||||
frame = TTSSpeakFrame(text=data.text)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_tts_interrupt(self):
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
async def _on_first_participant_joined(self, transport, participant):
|
||||
self._first_participant_joined = True
|
||||
await self._maybe_send_bot_ready()
|
||||
|
||||
async def _maybe_send_bot_ready(self):
|
||||
if self._pipeline and self._first_participant_joined:
|
||||
message = RTVIBotReady()
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _send_error(self, error: str):
|
||||
message = RTVIError(data=RTVIErrorData(message=error))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _send_response(self, id: str, success: bool, error: str | None = None):
|
||||
# TODO(aleix): This is a bit hacky, but we might get invalid
|
||||
# configuration or something might going wrong during setup and we would
|
||||
# like to send the error to the client. However, if the pipeline is not
|
||||
# setup yet we don't have an output transport and therefore we can't
|
||||
# send any messages. So, we setup a super basic pipeline with just the
|
||||
# output transport so we can send messages.
|
||||
if not self._pipeline:
|
||||
pipeline = Pipeline([self._transport.output()])
|
||||
self._pipeline = pipeline
|
||||
|
||||
parent = self.get_parent()
|
||||
if parent:
|
||||
parent.link(pipeline)
|
||||
|
||||
message = RTVIResponse(id=id, data=RTVIResponseData(success=success, error=error))
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
76
src/pipecat/processors/idle_frame_processor.py
Normal file
76
src/pipecat/processors/idle_frame_processor.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import Awaitable, Callable, List
|
||||
|
||||
from pipecat.frames.frames import Frame, SystemFrame
|
||||
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
|
||||
class IdleFrameProcessor(AsyncFrameProcessor):
    """This class waits to receive any frame or list of desired frames within a
    given timeout. If the timeout is reached before receiving any of those
    frames the provided callback will be called.

    The callback can then be used to push frames downstream by using
    `queue_frame()` (or `push_frame()` for system frames).

    """

    def __init__(
            self,
            *,
            callback: Callable[["IdleFrameProcessor"], Awaitable[None]],
            timeout: float,
            types: List[type] | None = None,
            **kwargs):
        """Create the processor and start its idle watchdog task.

        Args:
            callback: Coroutine function awaited with this processor each time
                the timeout expires.
            timeout: Seconds to wait for a matching frame.
            types: Frame types that reset the timer. When empty or omitted,
                any frame resets it.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(**kwargs)

        self._callback = callback
        self._timeout = timeout
        # A fresh list per instance; a literal `[]` default would be shared.
        self._types = types if types is not None else []

        self._create_idle_task()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Forward the frame and reset the idle timer if the frame qualifies."""
        await super().process_frame(frame, direction)

        if isinstance(frame, SystemFrame):
            await self.push_frame(frame, direction)
        else:
            await self.queue_frame(frame, direction)

        # If we are not waiting for any specific frame set the event, otherwise
        # check if we have received one of the desired frames.
        if not self._types or isinstance(frame, tuple(self._types)):
            self._idle_event.set()

    async def cleanup(self):
        """Cancel the idle task and wait until it has finished."""
        self._idle_task.cancel()
        try:
            await self._idle_task
        except asyncio.CancelledError:
            # The handler only absorbs a cancellation that arrives while it
            # waits on the event. One that lands inside the timeout callback,
            # or before the task first runs, surfaces here.
            pass

    def _create_idle_task(self):
        """Create the idle event and the task that watches it."""
        self._idle_event = asyncio.Event()
        self._idle_task = self.get_event_loop().create_task(self._idle_task_handler())

    async def _idle_task_handler(self):
        """Await the callback whenever the idle event stays unset for the timeout."""
        while True:
            try:
                await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
            except asyncio.TimeoutError:
                await self._callback(self)
            except asyncio.CancelledError:
                break
            finally:
                # Re-arm the timer for the next wait.
                self._idle_event.clear()
|
||||
@@ -33,6 +33,6 @@ class StatelessTextTransformer(FrameProcessor):
|
||||
result = self._transform_fn(frame.text)
|
||||
if isinstance(result, Coroutine):
|
||||
result = await result
|
||||
await self.push_frame(result)
|
||||
await self.push_frame(TextFrame(text=result))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
82
src/pipecat/processors/user_idle_processor.py
Normal file
82
src/pipecat/processors/user_idle_processor.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
Frame,
|
||||
SystemFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame)
|
||||
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
|
||||
class UserIdleProcessor(AsyncFrameProcessor):
    """This class is useful to check if the user is interacting with the bot
    within a given timeout. If the timeout is reached before any interaction
    occurred the provided callback will be called.

    The callback can then be used to push frames downstream by using
    `queue_frame()` (or `push_frame()` for system frames).

    """

    def __init__(
            self,
            *,
            callback: Callable[["UserIdleProcessor"], Awaitable[None]],
            timeout: float,
            **kwargs):
        """Create the processor and start its idle watchdog task.

        Args:
            callback: Coroutine function awaited with this processor when the
                timeout expires while the user is not speaking.
            timeout: Seconds of inactivity before the callback is called.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(**kwargs)

        self._callback = callback
        self._timeout = timeout

        # True between UserStartedSpeakingFrame and UserStoppedSpeakingFrame.
        self._interrupted = False

        self._create_idle_task()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Forward the frame and reset the idle timer on user or bot activity."""
        await super().process_frame(frame, direction)

        if isinstance(frame, SystemFrame):
            await self.push_frame(frame, direction)
        else:
            await self.queue_frame(frame, direction)

        # We shouldn't call the idle callback if the user or the bot are speaking.
        if isinstance(frame, UserStartedSpeakingFrame):
            self._interrupted = True
            self._idle_event.set()
        elif isinstance(frame, UserStoppedSpeakingFrame):
            self._interrupted = False
            self._idle_event.set()
        elif isinstance(frame, BotSpeakingFrame):
            self._idle_event.set()

    async def cleanup(self):
        """Cancel the idle task and wait until it has finished."""
        self._idle_task.cancel()
        try:
            await self._idle_task
        except asyncio.CancelledError:
            # The handler only absorbs a cancellation that arrives while it
            # waits on the event. One that lands inside the timeout callback,
            # or before the task first runs, surfaces here.
            pass

    def _create_idle_task(self):
        """Create the idle event and the task that watches it."""
        self._idle_event = asyncio.Event()
        self._idle_task = self.get_event_loop().create_task(self._idle_task_handler())

    async def _idle_task_handler(self):
        """Await the callback on timeout unless the user is currently speaking."""
        while True:
            try:
                await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
            except asyncio.TimeoutError:
                if not self._interrupted:
                    await self._callback(self)
            except asyncio.CancelledError:
                break
            finally:
                # Re-arm the timer for the next wait.
                self._idle_event.clear()
|
||||
@@ -12,9 +12,9 @@ from pipecat.frames.frames import Frame
|
||||
class FrameSerializer(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def serialize(self, frame: Frame) -> bytes:
|
||||
def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(self, data: bytes) -> Frame | None:
|
||||
def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
pass
|
||||
|
||||
@@ -26,7 +26,7 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def serialize(self, frame: Frame) -> bytes:
|
||||
def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
proto_frame = frame_protos.Frame()
|
||||
if type(frame) not in self.SERIALIZABLE_TYPES:
|
||||
raise ValueError(
|
||||
@@ -41,7 +41,7 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
result = proto_frame.SerializeToString()
|
||||
return result
|
||||
|
||||
def deserialize(self, data: bytes) -> Frame | None:
|
||||
def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Returns a Frame object from a Frame protobuf. Used to convert frames
|
||||
passed over the wire as protobufs to Frame objects used in pipelines
|
||||
and frame processors.
|
||||
|
||||
52
src/pipecat/serializers/twilio.py
Normal file
52
src/pipecat/serializers/twilio.py
Normal file
@@ -0,0 +1,52 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import json
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, Frame
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.utils.audio import ulaw_8000_to_pcm_16000, pcm_16000_to_ulaw_8000
|
||||
|
||||
|
||||
class TwilioFrameSerializer(FrameSerializer):
    """Serializer between ``AudioRawFrame`` objects and JSON "media" messages.

    Outgoing audio is converted with ``pcm_16000_to_ulaw_8000``,
    base64-encoded and wrapped in a JSON message tagged with the stream SID.
    Incoming "media" messages are decoded with ``ulaw_8000_to_pcm_16000``.
    """

    SERIALIZABLE_TYPES = {
        AudioRawFrame: "audio",
    }

    def __init__(self, stream_sid: str):
        """Remember the stream SID attached to every serialized message."""
        self._stream_sid = stream_sid

    def serialize(self, frame: Frame) -> str | bytes | None:
        """Return the JSON "media" message for an audio frame.

        Frames that are not ``AudioRawFrame`` instances yield ``None``.
        """
        if isinstance(frame, AudioRawFrame):
            ulaw_audio = pcm_16000_to_ulaw_8000(frame.audio)
            encoded = base64.b64encode(ulaw_audio).decode("utf-8")
            return json.dumps({
                "event": "media",
                "streamSid": self._stream_sid,
                "media": {"payload": encoded},
            })
        return None

    def deserialize(self, data: str | bytes) -> Frame | None:
        """Build an ``AudioRawFrame`` from a JSON "media" message.

        Messages whose "event" is anything other than "media" yield ``None``.
        """
        message = json.loads(data)
        if message["event"] != "media":
            return None

        ulaw_audio = base64.b64decode(message["media"]["payload"])
        pcm_audio = ulaw_8000_to_pcm_16000(ulaw_audio)
        return AudioRawFrame(audio=pcm_audio, num_channels=1, sample_rate=16000)
|
||||
@@ -16,15 +16,38 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSVoiceUpdateFrame,
|
||||
TextFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.audio import calculate_audio_volume
|
||||
from pipecat.utils.utils import exp_smoothing
|
||||
import re
|
||||
|
||||
|
||||
ENDOFSENTENCE_PATTERN_STR = r"""
(?<![A-Z]) # Negative lookbehind: not preceded by an uppercase letter (e.g., "U.S.A.")
(?<!\d) # Negative lookbehind: not preceded by a digit (e.g., "1. Let's start")
(?<!\d\s[ap]) # Negative lookbehind: not preceded by time (e.g., "3:00 a.m.")
(?<!Mr|Ms|Dr) # Negative lookbehind: not preceded by Mr, Ms, Dr (combined bc. length is the same)
(?<!Mrs) # Negative lookbehind: not preceded by "Mrs"
(?<!Prof) # Negative lookbehind: not preceded by "Prof"
[\.\?\!:] # Match a period, question mark, exclamation point, or colon
$ # End of string
"""
# Compiled once at import time; the pattern relies on verbose mode for its
# inline comments and line breaks.
ENDOFSENTENCE_PATTERN = re.compile(ENDOFSENTENCE_PATTERN_STR, flags=re.VERBOSE)


def match_endofsentence(text: str) -> bool:
    """Return True when *text* ends with sentence-final punctuation.

    Trailing whitespace is ignored. A final ``.``, ``?``, ``!`` or ``:`` only
    counts when it is not directly preceded by one of the exclusions listed
    in the pattern (uppercase letter, digit, common titles, ...).
    """
    trimmed = text.rstrip()
    return ENDOFSENTENCE_PATTERN.search(trimmed) is not None
|
||||
|
||||
|
||||
class AIService(FrameProcessor):
|
||||
@@ -58,6 +81,30 @@ class AIService(FrameProcessor):
|
||||
await self.push_frame(f)
|
||||
|
||||
|
||||
class AsyncAIService(AsyncFrameProcessor):
    """Base class for AI services built on ``AsyncFrameProcessor``.

    ``process_frame`` calls ``start``, ``cancel`` or ``stop`` when it sees a
    StartFrame, CancelFrame or EndFrame. The hooks do nothing here and are
    meant to be overridden.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def start(self, frame: StartFrame):
        """Hook invoked for a StartFrame; does nothing by default."""

    async def stop(self, frame: EndFrame):
        """Hook invoked for an EndFrame; does nothing by default."""

    async def cancel(self, frame: CancelFrame):
        """Hook invoked for a CancelFrame; does nothing by default."""

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Run the base processing, then call the matching lifecycle hook."""
        await super().process_frame(frame, direction)

        # Checked in this order; only the first matching hook is called.
        lifecycle_hooks = (
            (StartFrame, self.start),
            (CancelFrame, self.cancel),
            (EndFrame, self.stop),
        )
        for frame_type, hook in lifecycle_hooks:
            if isinstance(frame, frame_type):
                await hook(frame)
                break
|
||||
|
||||
|
||||
class LLMService(AIService):
|
||||
"""This class is a no-op but serves as a base class for LLM services."""
|
||||
|
||||
@@ -91,11 +138,22 @@ class LLMService(AIService):
|
||||
|
||||
|
||||
class TTSService(AIService):
|
||||
def __init__(self, aggregate_sentences: bool = True, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aggregate_sentences: bool = True,
|
||||
# if True, subclass is responsible for pushing TextFrames and LLMFullResponseEndFrames
|
||||
push_text_frames: bool = True,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._aggregate_sentences: bool = aggregate_sentences
|
||||
self._push_text_frames: bool = push_text_frames
|
||||
self._current_sentence: str = ""
|
||||
|
||||
@abstractmethod
|
||||
async def set_voice(self, voice: str):
|
||||
pass
|
||||
|
||||
# Converts the text to audio.
|
||||
@abstractmethod
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
@@ -104,36 +162,58 @@ class TTSService(AIService):
|
||||
async def say(self, text: str):
|
||||
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
    """Discard the sentence being aggregated and forward the interruption."""
    # Any partially aggregated text is dropped before the frame moves on.
    self._current_sentence = ""
    await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: str | None = None
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
else:
|
||||
self._current_sentence += frame.text
|
||||
if self._current_sentence.strip().endswith((".", "?", "!")):
|
||||
text = self._current_sentence.strip()
|
||||
if match_endofsentence(self._current_sentence):
|
||||
text = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
async def _push_tts_frames(self, text: str, text_passthrough: bool = True):
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_tts(text))
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
await self.push_frame(TextFrame(text))
|
||||
if self._push_text_frames:
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
await self.push_frame(TextFrame(text))
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
if self._current_sentence:
|
||||
await self._push_tts_frames(self._current_sentence)
|
||||
await self.push_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
|
||||
sentence = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
await self._push_tts_frames(frame.text, False)
|
||||
elif isinstance(frame, TTSVoiceUpdateFrame):
|
||||
await self.set_voice(frame.voice)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -142,6 +222,7 @@ class STTService(AIService):
|
||||
"""STTService is a base class for speech-to-text services."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
min_volume: float = 0.6,
|
||||
max_silence_secs: float = 0.3,
|
||||
max_buffer_secs: float = 1.5,
|
||||
@@ -197,17 +278,22 @@ class STTService(AIService):
|
||||
self._silence_num_frames = 0
|
||||
self._wave.close()
|
||||
self._content.seek(0)
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_stt(self._content.read()))
|
||||
await self.stop_processing_metrics()
|
||||
(self._content, self._wave) = self._new_wave()
|
||||
|
||||
async def stop(self, frame: EndFrame):
    """Close the current wave writer when an EndFrame is handled."""
    wave_writer = self._wave
    wave_writer.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
    """Close the current wave writer when a CancelFrame is handled."""
    wave_writer = self._wave
    wave_writer.close()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Processes a frame of audio data, either buffering or transcribing it."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
|
||||
self._wave.close()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
# In this service we accumulate audio internally and at the end we
|
||||
# push a TextFrame. We don't really want to push audio frames down.
|
||||
await self._append_audio(frame)
|
||||
@@ -230,7 +316,9 @@ class ImageGenService(AIService):
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_image_gen(frame.text))
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -250,6 +338,8 @@ class VisionService(AIService):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, VisionImageRawFrame):
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_vision(frame))
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -8,12 +8,11 @@ import base64
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMModelUpdateFrame,
|
||||
TextFrame,
|
||||
VisionImageRawFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMResponseStartFrame,
|
||||
LLMResponseEndFrame,
|
||||
LLMFullResponseEndFrame
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -41,6 +40,7 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "claude-3-opus-20240229",
|
||||
max_tokens: int = 1024):
|
||||
@@ -117,12 +117,10 @@ class AnthropicLLMService(LLMService):
|
||||
async for event in response:
|
||||
# logger.debug(f"Anthropic LLM event: {event}")
|
||||
if (event.type == "content_block_delta"):
|
||||
await self.push_frame(LLMResponseStartFrame())
|
||||
await self.push_frame(TextFrame(event.delta.text))
|
||||
await self.push_frame(LLMResponseEndFrame())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -137,6 +135,9 @@ class AnthropicLLMService(LLMService):
|
||||
context = OpenAILLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMModelUpdateFrame):
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self._model = frame.model
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -12,9 +12,18 @@ import time
|
||||
from PIL import Image
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, ErrorFrame, Frame, StartFrame, SystemFrame, TranscriptionFrame, URLImageRawFrame
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TranscriptionFrame,
|
||||
URLImageRawFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AIService, TTSService, ImageGenService
|
||||
from pipecat.services.ai_services import AsyncAIService, TTSService, ImageGenService
|
||||
from pipecat.services.openai import BaseOpenAILLMService
|
||||
|
||||
from loguru import logger
|
||||
@@ -34,7 +43,7 @@ try:
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Azure TTS, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.")
|
||||
"In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -72,8 +81,12 @@ class AzureTTSService(TTSService):
|
||||
def can_generate_metrics(self) -> bool:
    """Report that this service can generate metrics (always True)."""
    supports_metrics = True
    return supports_metrics
|
||||
|
||||
async def set_voice(self, voice: str):
    """Log the change and store *voice* as the voice used by this service."""
    message = f"Switching TTS voice to: [{voice}]"
    logger.debug(message)
    self._voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: {text}")
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -100,7 +113,7 @@ class AzureTTSService(TTSService):
|
||||
logger.error(f"{self} error: {cancellation_details.error_details}")
|
||||
|
||||
|
||||
class AzureSTTService(AIService):
|
||||
class AzureSTTService(AsyncAIService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -123,8 +136,6 @@ class AzureSTTService(AIService):
|
||||
speech_config=speech_config, audio_config=audio_config)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
|
||||
self._create_push_task()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -136,37 +147,23 @@ class AzureSTTService(AIService):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
await self._push_queue.put((frame, FrameDirection.DOWNSTREAM))
|
||||
await self._push_frame_task
|
||||
self._audio_stream.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
self._push_frame_task.cancel()
|
||||
|
||||
def _create_push_task(self):
|
||||
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
self._audio_stream.close()
|
||||
|
||||
def _on_handle_recognized(self, event):
|
||||
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
|
||||
direction = FrameDirection.DOWNSTREAM
|
||||
frame = TranscriptionFrame(event.result.text, "", int(time.time_ns() / 1000000))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._push_queue.put((frame, direction)), self.get_event_loop())
|
||||
asyncio.run_coroutine_threadsafe(self.queue_frame(frame), self.get_event_loop())
|
||||
|
||||
|
||||
class AzureImageGenServiceREST(ImageGenService):
|
||||
@@ -174,12 +171,12 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
image_size: str,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_version="2023-06-01-preview",
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -187,8 +184,14 @@ class AzureImageGenServiceREST(ImageGenService):
|
||||
self._azure_endpoint = endpoint
|
||||
self._api_version = api_version
|
||||
self._model = model
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._image_size = image_size
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
async def cleanup(self):
    """Run the base cleanup, then close the HTTP session if flagged to.

    The session is closed only when ``self._close_aiohttp_session`` is set.
    """
    await super().cleanup()
    if not self._close_aiohttp_session:
        return
    await self._aiohttp_session.close()
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"
|
||||
|
||||
@@ -4,15 +4,38 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from cartesia.tts import AsyncCartesiaTTS
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
Frame,
|
||||
AudioRawFrame,
|
||||
StartInterruptionFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
TextFrame,
|
||||
LLMFullResponseEndFrame
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Cartesia configuration needed
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`. Also, set `CARTESIA_API_KEY` environment variable.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class CartesiaTTSService(TTSService):
|
||||
|
||||
@@ -20,44 +43,192 @@ class CartesiaTTSService(TTSService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_name: str,
|
||||
model_id: str = "upbeat-moon",
|
||||
output_format: str = "pcm_16000",
|
||||
voice_id: str,
|
||||
cartesia_version: str = "2024-06-10",
|
||||
url: str = "wss://api.cartesia.ai/tts/websocket",
|
||||
model_id: str = "sonic-english",
|
||||
encoding: str = "pcm_s16le",
|
||||
sample_rate: int = 16000,
|
||||
language: str = "en",
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_name = voice_name
|
||||
self._model_id = model_id
|
||||
self._output_format = output_format
|
||||
# Aggregating sentences still gives cleaner-sounding results and fewer
|
||||
# artifacts than streaming one word at a time. On average, waiting for
|
||||
# a full sentence should only "cost" us 15ms or so with GPT-4o or a Llama 3
|
||||
# model, and it's worth it for the better audio quality.
|
||||
self._aggregate_sentences = True
|
||||
|
||||
try:
|
||||
self._client = AsyncCartesiaTTS(api_key=self._api_key)
|
||||
voices = self._client.get_voices()
|
||||
voice_id = voices[self._voice_name]["id"]
|
||||
self._voice = self._client.get_voice_embedding(voice_id=voice_id)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
# we don't want to automatically push LLM response text frames, because the
|
||||
# context aggregators will add them to the LLM context even if we're
|
||||
# interrupted. cartesia gives us word-by-word timestamps. we can use those
|
||||
# to generate text frames ourselves aligned with the playout timing of the audio!
|
||||
self._push_text_frames = False
|
||||
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
self._url = url
|
||||
self._voice_id = voice_id
|
||||
self._model_id = model_id
|
||||
self._output_format = {
|
||||
"container": "raw",
|
||||
"encoding": encoding,
|
||||
"sample_rate": sample_rate,
|
||||
}
|
||||
self._language = language
|
||||
|
||||
self._websocket = None
|
||||
self._context_id = None
|
||||
self._context_id_start_timestamp = None
|
||||
self._timestamped_words_buffer = []
|
||||
self._receive_task = None
|
||||
self._context_appending_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
    """Report that this service can generate metrics (always True)."""
    supports_metrics = True
    return supports_metrics
|
||||
|
||||
async def set_voice(self, voice: str):
    """Log the change and store *voice* as the voice id for later requests."""
    message = f"Switching TTS voice to: [{voice}]"
    logger.debug(message)
    self._voice_id = voice
|
||||
|
||||
async def start(self, frame: StartFrame):
    """Run the base start handling, then open the websocket connection."""
    # Base-class start logic runs first; the connection is opened afterwards.
    await super().start(frame)
    await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
    """Run the base stop handling, then tear down the websocket connection."""
    # Base-class stop logic runs first; the connection is closed afterwards.
    await super().stop(frame)
    await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
    """Run the base cancel handling, then tear down the websocket connection."""
    # Base-class cancel logic runs first; the connection is closed afterwards.
    await super().cancel(frame)
    await self._disconnect()
|
||||
|
||||
async def _connect(self):
    """Open the websocket and start the two background tasks.

    On any failure the exception is logged and ``self._websocket`` is reset
    to ``None``; nothing is re-raised.
    """
    try:
        endpoint = f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
        self._websocket = await websockets.connect(endpoint)

        loop = self.get_event_loop()
        # One task reads messages from the socket, the other emits the
        # buffered timestamped words.
        self._receive_task = loop.create_task(self._receive_task_handler())
        self._context_appending_task = loop.create_task(self._context_appending_task_handler())
    except Exception as e:
        logger.exception(f"{self} initialization error: {e}")
        self._websocket = None
|
||||
|
||||
async def _disconnect(self):
    """Stop metrics, cancel the background tasks and close the websocket.

    The per-context state (context id, start timestamp, buffered words) is
    reset afterwards. Any exception raised on the way is logged and not
    re-raised.
    """
    try:
        await self.stop_all_metrics()

        # Tasks are shut down in this order: context appending, then receive.
        for task_attr in ("_context_appending_task", "_receive_task"):
            task = getattr(self, task_attr)
            if task:
                task.cancel()
                await task
                setattr(self, task_attr, None)

        if self._websocket:
            await self._websocket.close()
            self._websocket = None

        self._context_id = None
        self._context_id_start_timestamp = None
        self._timestamped_words_buffer = []
    except Exception as e:
        logger.exception(f"{self} error closing websocket: {e}")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
    """Handle an interruption by dropping all state of the current context.

    After the base handling, the context id, its start timestamp and the
    buffered timestamped words are reset, metrics are stopped and an
    ``LLMFullResponseEndFrame`` is pushed.
    """
    await super()._handle_interruption(frame, direction)

    # Forget the active context and everything buffered for it.
    self._context_id = self._context_id_start_timestamp = None
    self._timestamped_words_buffer = []

    await self.stop_all_metrics()
    await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _receive_task_handler(self):
    """Consume websocket messages for the current context until cancelled.

    Handles three message types: "done", "timestamps" and "chunk". Messages
    that are empty or whose ``context_id`` differs from the active one are
    skipped. Cancellation ends the loop silently; any other exception is
    logged and ends the loop as well.
    """
    try:
        async for raw_message in self._websocket:
            msg = json.loads(raw_message)
            # Ignore empty payloads and messages from a context that is no
            # longer the active one.
            if not msg or msg["context_id"] != self._context_id:
                continue

            kind = msg["type"]
            if kind == "done":
                await self.stop_ttfb_metrics()
                # Only the context id is cleared. The start timestamp is
                # kept because audio is likely still playing out and the
                # timestamp is needed to time the remaining context frames.
                self._context_id = None
                self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
            elif kind == "timestamps":
                stamps = msg["word_timestamps"]
                # Buffer (word, end timestamp) pairs.
                self._timestamped_words_buffer.extend(zip(stamps["words"], stamps["end"]))
            elif kind == "chunk":
                await self.stop_ttfb_metrics()
                # The first audio chunk marks the start time of the context.
                if not self._context_id_start_timestamp:
                    self._context_id_start_timestamp = time.time()
                await self.push_frame(
                    AudioRawFrame(
                        audio=base64.b64decode(msg["data"]),
                        sample_rate=self._output_format["sample_rate"],
                        num_channels=1,
                    )
                )
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.exception(f"{self} exception: {e}")
|
||||
|
||||
async def _context_appending_task_handler(self):
    """Periodically push frames for buffered words whose time has passed.

    Every 0.1 seconds the elapsed time since the context start timestamp is
    computed, and every buffered entry whose timestamp is not later than
    that is popped and pushed as a ``TextFrame``. The special entry
    ``("LLMFullResponseEndFrame", 0)`` is pushed as an
    ``LLMFullResponseEndFrame`` instead. Cancellation ends the loop
    silently; other exceptions are logged.
    """
    try:
        while True:
            await asyncio.sleep(0.1)

            started_at = self._context_id_start_timestamp
            if not started_at:
                continue
            elapsed_seconds = time.time() - started_at

            # The buffer attribute is re-read on every pass because other
            # methods replace it with a new list while frames are pushed.
            while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
                word, timestamp = self._timestamped_words_buffer.pop(0)
                is_end_marker = word == "LLMFullResponseEndFrame" and timestamp == 0
                frame = LLMFullResponseEndFrame() if is_end_marker else TextFrame(word)
                await self.push_frame(frame)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.exception(f"{self} exception: {e}")
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
if not self._websocket:
|
||||
await self._connect()
|
||||
|
||||
chunk_generator = await self._client.generate(
|
||||
stream=True,
|
||||
transcript=text,
|
||||
voice=self._voice,
|
||||
model_id=self._model_id,
|
||||
output_format=self._output_format,
|
||||
)
|
||||
if not self._context_id:
|
||||
await self.start_ttfb_metrics()
|
||||
self._context_id = str(uuid.uuid4())
|
||||
|
||||
async for chunk in chunk_generator:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield AudioRawFrame(chunk["audio"], chunk["sampling_rate"], 1)
|
||||
msg = {
|
||||
"transcript": text + " ",
|
||||
"continue": True,
|
||||
"context_id": self._context_id,
|
||||
"model_id": self._model_id,
|
||||
"voice": {
|
||||
"mode": "id",
|
||||
"id": self._voice_id
|
||||
},
|
||||
"output_format": self._output_format,
|
||||
"language": self._language,
|
||||
"add_timestamps": True,
|
||||
}
|
||||
try:
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error sending message: {e}")
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
yield None
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
|
||||
@@ -5,8 +5,6 @@
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
@@ -21,43 +19,66 @@ from pipecat.frames.frames import (
|
||||
SystemFrame,
|
||||
TranscriptionFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AIService, TTSService
|
||||
|
||||
from deepgram import (
|
||||
DeepgramClient,
|
||||
DeepgramClientOptions,
|
||||
LiveTranscriptionEvents,
|
||||
LiveOptions,
|
||||
)
|
||||
from pipecat.services.ai_services import AsyncAIService, TTSService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# See .env.example for Deepgram configuration needed
|
||||
try:
|
||||
from deepgram import (
|
||||
DeepgramClient,
|
||||
DeepgramClientOptions,
|
||||
LiveTranscriptionEvents,
|
||||
LiveOptions,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`. Also, set `DEEPGRAM_API_KEY` environment variable.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class DeepgramTTSService(TTSService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
api_key: str,
|
||||
voice: str = "aura-helios-en",
|
||||
base_url: str = "https://api.deepgram.com/v1/speak",
|
||||
sample_rate: int = 16000,
|
||||
encoding: str = "linear16",
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice = voice
|
||||
self._api_key = api_key
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._base_url = base_url
|
||||
self._sample_rate = sample_rate
|
||||
self._encoding = encoding
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
    """Report that this service can generate metrics (always True)."""
    supports_metrics = True
    return supports_metrics
|
||||
|
||||
async def cleanup(self):
    """Run the base cleanup, then close the HTTP session if flagged to.

    The session is closed only when ``self._close_aiohttp_session`` is set.
    """
    await super().cleanup()
    if not self._close_aiohttp_session:
        return
    await self._aiohttp_session.close()
|
||||
|
||||
async def set_voice(self, voice: str):
    """Log the change and store *voice* as the voice used by this service."""
    message = f"Switching TTS voice to: [{voice}]"
    logger.debug(message)
    self._voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
base_url = self._base_url
|
||||
request_url = f"{base_url}?model={self._voice}&encoding=linear16&container=none&sample_rate=16000"
|
||||
request_url = f"{base_url}?model={self._voice}&encoding={self._encoding}&container=none&sample_rate={self._sample_rate}"
|
||||
headers = {"authorization": f"token {self._api_key}"}
|
||||
body = {"text": text}
|
||||
|
||||
@@ -80,15 +101,17 @@ class DeepgramTTSService(TTSService):
|
||||
|
||||
async for data in r.content:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = AudioRawFrame(audio=data, sample_rate=16000, num_channels=1)
|
||||
frame = AudioRawFrame(audio=data, sample_rate=self._sample_rate, num_channels=1)
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
|
||||
|
||||
class DeepgramSTTService(AIService):
|
||||
class DeepgramSTTService(AsyncAIService):
|
||||
def __init__(self,
|
||||
*,
|
||||
api_key: str,
|
||||
url: str = "",
|
||||
live_options: LiveOptions = LiveOptions(
|
||||
encoding="linear16",
|
||||
language="en-US",
|
||||
@@ -104,12 +127,10 @@ class DeepgramSTTService(AIService):
|
||||
self._live_options = live_options
|
||||
|
||||
self._client = DeepgramClient(
|
||||
api_key, config=DeepgramClientOptions(options={"keepalive": "true"}))
|
||||
api_key, config=DeepgramClientOptions(url=url, options={"keepalive": "true"}))
|
||||
self._connection = self._client.listen.asynclive.v("1")
|
||||
self._connection.on(LiveTranscriptionEvents.Transcript, self._on_message)
|
||||
|
||||
self._create_push_task()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -118,36 +139,22 @@ class DeepgramSTTService(AIService):
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
await self._connection.send(frame.audio)
|
||||
else:
|
||||
await self._push_queue.put((frame, direction))
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
if await self._connection.start(self._live_options):
|
||||
logger.debug(f"{self}: Connected to Deepgram")
|
||||
else:
|
||||
logger.error(f"{self}: Unable to connect to Deepgram")
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._connection.finish()
|
||||
await self._push_queue.put((frame, FrameDirection.DOWNSTREAM))
|
||||
await self._push_frame_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._connection.finish()
|
||||
self._push_frame_task.cancel()
|
||||
|
||||
def _create_push_task(self):
|
||||
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _on_message(self, *args, **kwargs):
|
||||
result = kwargs["result"]
|
||||
@@ -155,6 +162,6 @@ class DeepgramSTTService(AIService):
|
||||
transcript = result.channel.alternatives[0].transcript
|
||||
if len(transcript) > 0:
|
||||
if is_final:
|
||||
await self._push_queue.put((TranscriptionFrame(transcript, "", int(time.time_ns() / 1000000)), FrameDirection.DOWNSTREAM))
|
||||
await self.queue_frame(TranscriptionFrame(transcript, "", time_now_iso8601()))
|
||||
else:
|
||||
await self._push_queue.put((InterimTranscriptionFrame(transcript, "", int(time.time_ns() / 1000000)), FrameDirection.DOWNSTREAM))
|
||||
await self.queue_frame(InterimTranscriptionFrame(transcript, "", time_now_iso8601()))
|
||||
|
||||
@@ -19,21 +19,31 @@ class ElevenLabsTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
model: str = "eleven_turbo_v2",
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._model = model
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
|
||||
@@ -39,24 +39,30 @@ class FalImageGenService(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
params: InputParams,
|
||||
model: str = "fal-ai/fast-sdxl",
|
||||
key: str | None = None,
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._model = model
|
||||
self._params = params
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
if key:
|
||||
os.environ["FAL_KEY"] = key
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
|
||||
response = await fal_client.run_async(
|
||||
self._model,
|
||||
arguments={"prompt": prompt, **self._params.model_dump()}
|
||||
arguments={"prompt": prompt, **self._params.model_dump(exclude_none=True)}
|
||||
)
|
||||
|
||||
image_url = response["images"][0]["url"] if response else None
|
||||
|
||||
@@ -19,6 +19,7 @@ except ModuleNotFoundError as e:
|
||||
|
||||
class FireworksLLMService(BaseOpenAILLMService):
    """LLM service built on ``BaseOpenAILLMService`` with Fireworks defaults.

    Args:
        model: Name of the Fireworks model to use.
        base_url: Base URL of the Fireworks inference endpoint.
    """

    def __init__(self,
                 *,
                 model: str = "accounts/fireworks/models/firefunction-v1",
                 base_url: str = "https://api.fireworks.ai/inference/v1"):
        # The base class constructor only accepts keyword arguments after
        # ``self``. Passing ``model`` and ``base_url`` positionally raises a
        # TypeError (and, with a positional signature, would bind ``base_url``
        # to the ``api_key`` parameter), so they are forwarded by name.
        super().__init__(model=model, base_url=base_url)
|
||||
|
||||
118
src/pipecat/services/gladia.py
Normal file
118
src/pipecat/services/gladia.py
Normal file
@@ -0,0 +1,118 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import json
|
||||
|
||||
from typing import Optional
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TranscriptionFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AsyncAIService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Gladia configuration needed
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Gladia, you need to `pip install pipecat-ai[gladia]`. Also, set `GLADIA_API_KEY` environment variable.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GladiaSTTService(AsyncAIService):
    """Streaming speech-to-text service that talks to Gladia over a websocket.

    Incoming ``AudioRawFrame`` audio is base64-encoded and sent to the
    websocket. A background task reads the messages coming back and queues
    ``TranscriptionFrame`` / ``InterimTranscriptionFrame`` instances for
    utterances whose confidence reaches the configured threshold.

    Args:
        api_key: Gladia API key, sent in the initial configuration message.
        url: Websocket URL to connect to.
        confidence: Minimum confidence an utterance needs to be forwarded.
        params: Extra configuration merged into the configuration message.
        **kwargs: Forwarded to the parent class constructor.
    """

    class InputParams(BaseModel):
        """Optional settings merged into the Gladia configuration message.

        Fields left as ``None`` are omitted from the message, because the
        model is dumped with ``exclude_none=True``.
        """

        sample_rate: Optional[int] = 16000
        language: Optional[str] = "english"
        transcription_hint: Optional[str] = None
        endpointing: Optional[int] = 200
        prosody: Optional[bool] = None

    def __init__(self,
                 *,
                 api_key: str,
                 url: str = "wss://api.gladia.io/audio/text/audio-transcription",
                 confidence: float = 0.5,
                 params: InputParams = InputParams(),
                 **kwargs):
        super().__init__(**kwargs)

        self._api_key = api_key
        self._url = url
        # NOTE: the default ``params`` instance is shared between services;
        # this class only reads it (via ``model_dump``) and never mutates it.
        self._params = params
        self._confidence = confidence

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Route a frame: system frames are pushed, audio is sent to Gladia,
        and everything else is queued."""
        await super().process_frame(frame, direction)

        if isinstance(frame, SystemFrame):
            await self.push_frame(frame, direction)
        elif isinstance(frame, AudioRawFrame):
            await self._send_audio(frame)
        else:
            await self.queue_frame(frame, direction)

    async def start(self, frame: StartFrame):
        """Connect the websocket, start the receive task and send the
        configuration message."""
        await super().start(frame)
        self._websocket = await websockets.connect(self._url)
        self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
        await self._setup_gladia()

    async def stop(self, frame: EndFrame):
        """Close the websocket when the pipeline ends."""
        await super().stop(frame)
        # NOTE(review): the receive task is not cancelled here; it is expected
        # to finish once the websocket iteration ends after close() - confirm.
        await self._websocket.close()

    async def cancel(self, frame: CancelFrame):
        """Close the websocket when the pipeline is cancelled."""
        await super().cancel(frame)
        await self._websocket.close()

    async def _setup_gladia(self):
        """Send the initial configuration message (API key, encoding, model
        type, language behaviour and the non-``None`` input params)."""
        configuration = {
            "x_gladia_key": self._api_key,
            "encoding": "WAV/PCM",
            "model_type": "fast",
            "language_behaviour": "manual",
            **self._params.model_dump(exclude_none=True)
        }

        await self._websocket.send(json.dumps(configuration))

    async def _send_audio(self, frame: AudioRawFrame):
        """Send one audio frame to Gladia as a base64-encoded JSON message."""
        message = {
            'frames': base64.b64encode(frame.audio).decode("utf-8")
        }
        await self._websocket.send(json.dumps(message))

    async def _receive_task_handler(self):
        """Read websocket messages and queue transcription frames.

        Messages that cannot be parsed, or that lack the expected fields, are
        skipped (and logged where useful) so that a single unexpected payload
        does not end the receive loop.
        """
        async for raw_message in self._websocket:
            try:
                utterance = json.loads(raw_message)
            except ValueError as e:
                # Covers json.JSONDecodeError and undecodable byte payloads.
                logger.error(f"{self} unable to parse Gladia message: {e}")
                continue

            # Ignore empty payloads and anything that is not a JSON object.
            if not utterance or not isinstance(utterance, dict):
                continue

            if "error" in utterance:
                # Fall back to the whole payload if there is no "message" key.
                error_message = utterance.get("message", utterance)
                logger.error(f"Gladia error: {error_message}")
            elif "confidence" in utterance:
                utterance_type = utterance.get("type")
                confidence = utterance["confidence"]
                transcript = utterance.get("transcription")
                if transcript is None:
                    # Nothing to forward without a transcription text.
                    continue
                if confidence >= self._confidence:
                    if utterance_type == "final":
                        await self.queue_frame(TranscriptionFrame(transcript, "", time_now_iso8601()))
                    else:
                        await self.queue_frame(InterimTranscriptionFrame(transcript, "", time_now_iso8601()))
|
||||
@@ -10,12 +10,11 @@ from typing import List
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMModelUpdateFrame,
|
||||
TextFrame,
|
||||
VisionImageRawFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMResponseStartFrame,
|
||||
LLMResponseEndFrame,
|
||||
LLMFullResponseEndFrame
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -42,14 +41,17 @@ class GoogleLLMService(LLMService):
|
||||
franca for all LLM services, so that it is easy to switch between different LLMs.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "gemini-1.5-flash-latest", **kwargs):
|
||||
def __init__(self, *, api_key: str, model: str = "gemini-1.5-flash-latest", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
gai.configure(api_key=api_key)
|
||||
self._client = gai.GenerativeModel(model)
|
||||
self._create_client(model)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def _create_client(self, model: str):
|
||||
self._client = gai.GenerativeModel(model)
|
||||
|
||||
def _get_messages_from_openai_context(
|
||||
self, context: OpenAILLMContext) -> List[glm.Content]:
|
||||
openai_messages = context.get_messages()
|
||||
@@ -95,19 +97,17 @@ class GoogleLLMService(LLMService):
|
||||
async for chunk in self._async_generator_wrapper(response):
|
||||
try:
|
||||
text = chunk.text
|
||||
await self.push_frame(LLMResponseStartFrame())
|
||||
await self.push_frame(TextFrame(text))
|
||||
await self.push_frame(LLMResponseEndFrame())
|
||||
except Exception as e:
|
||||
# Google LLMs seem to flag safety issues a lot!
|
||||
if chunk.candidates[0].finish_reason == 3:
|
||||
logger.debug(
|
||||
f"LLM refused to generate content for safety reasons - {messages}.")
|
||||
else:
|
||||
logger.error(f"{self} error: {e}")
|
||||
logger.exception(f"{self} error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -122,6 +122,9 @@ class GoogleLLMService(LLMService):
|
||||
context = OpenAILLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMModelUpdateFrame):
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self._create_client(frame.model)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ def detect_device():
|
||||
class MoondreamService(VisionService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model="vikhyatk/moondream2",
|
||||
revision="2024-04-02",
|
||||
use_cpu=False
|
||||
|
||||
@@ -9,5 +9,5 @@ from pipecat.services.openai import BaseOpenAILLMService
|
||||
|
||||
class OLLamaLLMService(BaseOpenAILLMService):
|
||||
|
||||
def __init__(self, model: str = "llama2", base_url: str = "http://localhost:11434/v1"):
|
||||
def __init__(self, *, model: str = "llama2", base_url: str = "http://localhost:11434/v1"):
|
||||
super().__init__(model=model, base_url=base_url, api_key="ollama")
|
||||
|
||||
@@ -8,8 +8,9 @@ import aiohttp
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import httpx
|
||||
|
||||
from typing import Any, AsyncGenerator, List, Literal
|
||||
from typing import AsyncGenerator, List, Literal
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
@@ -21,8 +22,7 @@ from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMResponseEndFrame,
|
||||
LLMResponseStartFrame,
|
||||
LLMModelUpdateFrame,
|
||||
TextFrame,
|
||||
URLImageRawFrame,
|
||||
VisionImageRawFrame
|
||||
@@ -39,7 +39,7 @@ from pipecat.services.ai_services import (
|
||||
)
|
||||
|
||||
try:
|
||||
from openai import AsyncOpenAI, AsyncStream, BadRequestError
|
||||
from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError
|
||||
from openai.types.chat import (
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
@@ -53,7 +53,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class OpenAIUnhandledFunctionException(BaseException):
|
||||
class OpenAIUnhandledFunctionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@@ -67,13 +67,20 @@ class BaseOpenAILLMService(LLMService):
|
||||
calls from the LLM.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, api_key=None, base_url=None, **kwargs):
|
||||
def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._model: str = model
|
||||
self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
|
||||
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
return AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
http_client=DefaultAsyncHttpxClient(
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=100,
|
||||
max_connections=1000,
|
||||
keepalive_expiry=None)))
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -109,10 +116,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
del message["data"]
|
||||
del message["mime_type"]
|
||||
|
||||
try:
|
||||
chunks = await self.get_chat_completions(context, messages)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
chunks = await self.get_chat_completions(context, messages)
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -154,9 +158,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMResponseStartFrame())
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
await self.push_frame(LLMResponseEndFrame())
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
@@ -214,7 +216,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
elif isinstance(result, type(None)):
|
||||
pass
|
||||
else:
|
||||
raise BaseException(f"Unknown return type from function callback: {type(result)}")
|
||||
raise TypeError(f"Unknown return type from function callback: {type(result)}")
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -226,19 +228,24 @@ class BaseOpenAILLMService(LLMService):
|
||||
context = OpenAILLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMModelUpdateFrame):
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self._model = frame.model
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if context:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self._process_context(context)
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
|
||||
class OpenAILLMService(BaseOpenAILLMService):
|
||||
|
||||
def __init__(self, model="gpt-4o", **kwargs):
|
||||
super().__init__(model, **kwargs)
|
||||
def __init__(self, *, model: str = "gpt-4o", **kwargs):
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
|
||||
class OpenAIImageGenService(ImageGenService):
|
||||
@@ -246,16 +253,22 @@ class OpenAIImageGenService(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
api_key: str,
|
||||
model: str = "dall-e-3",
|
||||
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._model = model
|
||||
self._image_size = image_size
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
self._close_aiohttp_session = aiohttp_session is None
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._close_aiohttp_session:
|
||||
await self._aiohttp_session.close()
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
@@ -310,6 +323,10 @@ class OpenAITTSService(TTSService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -334,4 +351,4 @@ class OpenAITTSService(TTSService):
|
||||
frame = AudioRawFrame(chunk, 24_000, 1)
|
||||
yield frame
|
||||
except BadRequestError as e:
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
@@ -25,6 +25,7 @@ class OpenPipeLLMService(BaseOpenAILLMService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
@@ -33,9 +34,9 @@ class OpenPipeLLMService(BaseOpenAILLMService):
|
||||
tags: Dict[str, str] | None = None,
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
model,
|
||||
api_key,
|
||||
base_url,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
openpipe_api_key=openpipe_api_key,
|
||||
openpipe_base_url=openpipe_base_url,
|
||||
**kwargs)
|
||||
|
||||
@@ -80,4 +80,4 @@ class PlayHTTTSService(TTSService):
|
||||
frame = AudioRawFrame(chunk, 16000, 1)
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating TTS: {e}")
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
"""This module implements Whisper transcription with a locally-downloaded model."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from enum import Enum
|
||||
from typing_extensions import AsyncGenerator
|
||||
@@ -16,6 +15,7 @@ import numpy as np
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -42,7 +42,8 @@ class WhisperSTTService(STTService):
|
||||
"""Class to transcribe audio with a locally-downloaded Whisper model"""
|
||||
|
||||
def __init__(self,
|
||||
model: Model = Model.DISTIL_MEDIUM_EN,
|
||||
*,
|
||||
model: str | Model = Model.DISTIL_MEDIUM_EN,
|
||||
device: str = "auto",
|
||||
compute_type: str = "default",
|
||||
no_speech_prob: float = 0.4,
|
||||
@@ -51,7 +52,7 @@ class WhisperSTTService(STTService):
|
||||
super().__init__(**kwargs)
|
||||
self._device: str = device
|
||||
self._compute_type = compute_type
|
||||
self._model_name: Model = model
|
||||
self._model_name: str | Model = model
|
||||
self._no_speech_prob = no_speech_prob
|
||||
self._model: WhisperModel | None = None
|
||||
self._load()
|
||||
@@ -64,7 +65,7 @@ class WhisperSTTService(STTService):
|
||||
this model is being run, it will take time to download."""
|
||||
logger.debug("Loading Whisper model...")
|
||||
self._model = WhisperModel(
|
||||
self._model_name.value,
|
||||
self._model_name.value if isinstance(self._model_name, Enum) else self._model_name,
|
||||
device=self._device,
|
||||
compute_type=self._compute_type)
|
||||
logger.debug("Loaded Whisper model")
|
||||
@@ -90,4 +91,4 @@ class WhisperSTTService(STTService):
|
||||
if text:
|
||||
await self.stop_ttfb_metrics()
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(text, "", int(time.time_ns() / 1000000))
|
||||
yield TranscriptionFrame(text, "", time_now_iso8601())
|
||||
|
||||
122
src/pipecat/services/xtts.py
Normal file
122
src/pipecat/services/xtts.py
Normal file
@@ -0,0 +1,122 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, ErrorFrame, Frame
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
import requests
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import resampy
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use XTTS, you need to `pip install pipecat-ai[xtts]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
# The server below can connect to XTTS through a local running docker
|
||||
#
|
||||
# Docker command: $ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest-cuda121
|
||||
#
|
||||
# You can find more information on the official repo:
|
||||
# https://github.com/coqui-ai/xtts-streaming-server
|
||||
|
||||
|
||||
class XTTSService(TTSService):
    """Text-to-speech service that streams audio from an XTTS streaming server.

    The streamed bytes are interpreted as 16-bit samples at 24000 Hz, resampled
    to 16000 Hz and emitted as single-channel ``AudioRawFrame`` instances.

    Args:
        voice_id: Key of the speaker to use in the server's studio speakers.
        language: Language value sent with every TTS request.
        base_url: Base URL of the XTTS streaming server.
        aiohttp_session: Optional session to use. When omitted, a session is
            created here and closed in ``cleanup()``.
        **kwargs: Forwarded to the parent class constructor.
    """

    # Sample rates the stream is resampled from and to.
    _SOURCE_SAMPLE_RATE = 24000
    _TARGET_SAMPLE_RATE = 16000
    # Size in bytes of one int16 sample.
    _SAMPLE_WIDTH = 2
    # Resampling window: 48000 bytes = 24000 int16 samples, i.e. one second
    # of single-channel audio at 24000 Hz.
    _RESAMPLE_WINDOW_BYTES = 48000

    def __init__(
            self,
            *,
            voice_id: str,
            language: str,
            base_url: str,
            aiohttp_session: aiohttp.ClientSession | None = None,
            **kwargs):
        super().__init__(**kwargs)

        self._voice_id = voice_id
        self._language = language
        self._base_url = base_url
        # NOTE(review): this is a synchronous HTTP request issued from the
        # constructor, without a timeout or status check - confirm acceptable.
        self._studio_speakers = requests.get(self._base_url + "/studio_speakers").json()
        self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
        # Only close the session in cleanup() if it was created here.
        self._close_aiohttp_session = aiohttp_session is None

    def can_generate_metrics(self) -> bool:
        """Report that this service generates metrics."""
        return True

    async def cleanup(self):
        """Close the aiohttp session if this service created it."""
        await super().cleanup()
        if self._close_aiohttp_session:
            await self._aiohttp_session.close()

    async def set_voice(self, voice: str):
        """Switch the speaker used for subsequent TTS requests."""
        logger.debug(f"Switching TTS voice to: [{voice}]")
        self._voice_id = voice

    @classmethod
    def _resample_pcm(cls, pcm: bytes) -> bytes:
        """Resample int16 PCM bytes from the source to the target sample rate.

        ``pcm`` must contain a whole number of int16 samples. Returns empty
        bytes when the input is too short to produce a single output sample.
        """
        samples = np.frombuffer(pcm, dtype=np.int16)
        if len(samples) * cls._TARGET_SAMPLE_RATE // cls._SOURCE_SAMPLE_RATE < 1:
            return b""
        resampled = resampy.resample(
            samples, cls._SOURCE_SAMPLE_RATE, cls._TARGET_SAMPLE_RATE)
        return resampled.astype(np.int16).tobytes()

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech for ``text`` and yield it as audio frames.

        Yields an ``ErrorFrame`` (and stops) if the current voice is unknown to
        the server or if the server does not answer with status 200.
        """
        logger.debug(f"Generating TTS: [{text}]")

        try:
            embeddings = self._studio_speakers[self._voice_id]
        except KeyError:
            # The voice may have been changed through set_voice() to a speaker
            # the server does not provide; report it like other TTS errors.
            logger.error(f"{self} unknown voice: [{self._voice_id}]")
            yield ErrorFrame(f"Unknown voice: [{self._voice_id}]")
            return

        url = self._base_url + "/tts_stream"

        payload = {
            "text": text.replace('.', '').replace('*', ''),
            "language": self._language,
            "speaker_embedding": embeddings["speaker_embedding"],
            "gpt_cond_latent": embeddings["gpt_cond_latent"],
            "add_wav_header": False,
            "stream_chunk_size": 20,
        }

        await self.start_ttfb_metrics()

        async with self._aiohttp_session.post(url, json=payload) as r:
            if r.status != 200:
                error_text = await r.text()
                logger.error(f"{self} error getting audio (status: {r.status}, error: {error_text})")
                yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {error_text})")
                return

            window = self._RESAMPLE_WINDOW_BYTES
            buffer = bytearray()

            async for chunk in r.content.iter_chunked(1024):
                if len(chunk) > 0:
                    await self.stop_ttfb_metrics()
                    # Accumulate data until a full resampling window is available.
                    buffer.extend(chunk)

                    while len(buffer) >= window:
                        # Copy the window out, then drop it from the buffer.
                        pcm = bytes(buffer[:window])
                        del buffer[:window]
                        yield AudioRawFrame(
                            self._resample_pcm(pcm), self._TARGET_SAMPLE_RATE, 1)

            # Flush what is left. A dangling odd byte cannot form an int16
            # sample, so it is dropped instead of making np.frombuffer fail.
            usable = len(buffer) - (len(buffer) % self._SAMPLE_WIDTH)
            if usable > 0:
                tail = self._resample_pcm(bytes(buffer[:usable]))
                if tail:
                    yield AudioRawFrame(tail, self._TARGET_SAMPLE_RATE, 1)
|
||||
@@ -5,19 +5,20 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame)
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
@@ -33,8 +34,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
self._params = params
|
||||
|
||||
self._running = False
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
# Create push frame task. This is the task that will push frames in
|
||||
@@ -42,57 +41,68 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._create_push_task()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Create audio input queue and thread if needed.
|
||||
# Create audio input queue and task if needed.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_queue = queue.Queue()
|
||||
self._audio_thread = self._loop.run_in_executor(
|
||||
self._executor, self._audio_thread_handler)
|
||||
self._audio_in_queue = asyncio.Queue()
|
||||
self._audio_task = self.get_event_loop().create_task(self._audio_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
# This will exit all threads.
|
||||
self._running = False
|
||||
|
||||
# Wait for the threads to finish.
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Cancel and wait for the audio input task to finish.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
await self._audio_thread
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
# Wait for the push frame task to finish. It will finish when the
|
||||
# EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Cancel all the tasks and wait for them to finish.
|
||||
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
def vad_analyzer(self) -> VADAnalyzer | None:
|
||||
return self._params.vad_analyzer
|
||||
|
||||
def push_audio_frame(self, frame: AudioRawFrame):
|
||||
async def push_audio_frame(self, frame: AudioRawFrame):
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_queue.put_nowait(frame)
|
||||
await self._audio_in_queue.put(frame)
|
||||
|
||||
#
|
||||
# Frame processor
|
||||
#
|
||||
|
||||
async def cleanup(self):
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Specific system frames
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self.stop()
|
||||
# We don't queue a CancelFrame since we want to stop ASAP.
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
await self._handle_interruptions(frame, False)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._start_interruption()
|
||||
elif isinstance(frame, StopInterruptionFrame):
|
||||
await self._stop_interruption()
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames
|
||||
elif isinstance(frame, StartFrame):
|
||||
await self.start(frame)
|
||||
await self._internal_push_frame(frame, direction)
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Push EndFrame before stop(), because stop() waits on the task to
|
||||
# finish and the task finishes when EndFrame is processed.
|
||||
await self._internal_push_frame(frame, direction)
|
||||
await self.stop()
|
||||
await self.stop(frame)
|
||||
# Other frames
|
||||
else:
|
||||
await self._internal_push_frame(frame, direction)
|
||||
|
||||
@@ -102,8 +112,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
def _create_push_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
@@ -112,10 +122,13 @@ class BaseInputTransport(FrameProcessor):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
@@ -123,32 +136,56 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Handle interruptions
|
||||
#
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
async def _start_interruption(self):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
# Push an out-of-band frame (i.e. not using the ordered push
|
||||
# frame task) to stop everything, specially at the output
|
||||
# transport.
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
# Create a new queue and task.
|
||||
self._create_push_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
|
||||
await self.push_frame(StopInterruptionFrame())
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame, push_frame: bool):
|
||||
if self.interruptions_allowed:
|
||||
# Make sure we notify about interruptions quickly out-of-band
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
if isinstance(frame, BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
await self._start_interruption()
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
logger.debug("User started speaking")
|
||||
self._push_frame_task.cancel()
|
||||
self._create_push_task()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
await self._start_interruption()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
logger.debug("User stopped speaking")
|
||||
await self.push_frame(StopInterruptionFrame())
|
||||
await self._internal_push_frame(frame)
|
||||
await self._stop_interruption()
|
||||
|
||||
if push_frame:
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
#
|
||||
# Audio input
|
||||
#
|
||||
|
||||
def _vad_analyze(self, audio_frames: bytes) -> VADState:
|
||||
async def _vad_analyze(self, audio_frames: bytes) -> VADState:
|
||||
state = VADState.QUIET
|
||||
vad_analyzer = self.vad_analyzer()
|
||||
if vad_analyzer:
|
||||
state = vad_analyzer.analyze_audio(audio_frames)
|
||||
state = await self.get_event_loop().run_in_executor(
|
||||
self._executor, vad_analyzer.analyze_audio, audio_frames)
|
||||
return state
|
||||
|
||||
def _handle_vad(self, audio_frames: bytes, vad_state: VADState):
|
||||
new_vad_state = self._vad_analyze(audio_frames)
|
||||
async def _handle_vad(self, audio_frames: bytes, vad_state: VADState):
|
||||
new_vad_state = await self._vad_analyze(audio_frames)
|
||||
if new_vad_state != vad_state and new_vad_state != VADState.STARTING and new_vad_state != VADState.STOPPING:
|
||||
frame = None
|
||||
if new_vad_state == VADState.SPEAKING:
|
||||
@@ -157,33 +194,31 @@ class BaseInputTransport(FrameProcessor):
|
||||
frame = UserStoppedSpeakingFrame()
|
||||
|
||||
if frame:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._handle_interruptions(frame), self.get_event_loop())
|
||||
future.result()
|
||||
await self._handle_interruptions(frame, True)
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
|
||||
def _audio_thread_handler(self):
|
||||
async def _audio_task_handler(self):
|
||||
vad_state: VADState = VADState.QUIET
|
||||
while self._running:
|
||||
while True:
|
||||
try:
|
||||
frame: AudioRawFrame = self._audio_in_queue.get(timeout=1)
|
||||
frame: AudioRawFrame = await self._audio_in_queue.get()
|
||||
|
||||
audio_passthrough = True
|
||||
|
||||
# Check VAD and push event if necessary. We just care about
|
||||
# changes from QUIET to SPEAKING and vice versa.
|
||||
if self._params.vad_enabled:
|
||||
vad_state = self._handle_vad(frame.audio, vad_state)
|
||||
vad_state = await self._handle_vad(frame.audio, vad_state)
|
||||
audio_passthrough = self._params.vad_audio_passthrough
|
||||
|
||||
# Push audio downstream if passthrough.
|
||||
if audio_passthrough:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self._loop)
|
||||
future.result()
|
||||
except queue.Empty:
|
||||
pass
|
||||
except BaseException as e:
|
||||
logger.error(f"{self} error reading audio frames: {e}")
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
self._audio_in_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error reading audio frames: {e}")
|
||||
|
||||
@@ -7,11 +7,6 @@
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import queue
|
||||
import time
|
||||
import threading
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from PIL import Image
|
||||
from typing import List
|
||||
@@ -19,6 +14,9 @@ from typing import List
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
MetricsFrame,
|
||||
SpriteFrame,
|
||||
@@ -29,6 +27,8 @@ from pipecat.frames.frames import (
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TransportMessageFrame)
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
@@ -42,166 +42,164 @@ class BaseOutputTransport(FrameProcessor):
|
||||
|
||||
self._params = params
|
||||
|
||||
self._running = False
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
# These are the images that we should send to the camera at our desired
|
||||
# framerate.
|
||||
self._camera_images = None
|
||||
|
||||
# Create media threads queues.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_queue = queue.Queue()
|
||||
self._sink_queue = queue.Queue()
|
||||
self._sink_thread = None
|
||||
|
||||
self._stopped_event = asyncio.Event()
|
||||
self._is_interrupted = threading.Event()
|
||||
|
||||
# We will write 20ms audio at a time. If we receive long audio frames we
|
||||
# will chunk them. This will help with interruption handling.
|
||||
audio_bytes_10ms = int(self._params.audio_out_sample_rate / 100) * \
|
||||
self._params.audio_out_channels * 2
|
||||
self._audio_chunk_size = audio_bytes_10ms * 2
|
||||
|
||||
self._stopped_event = asyncio.Event()
|
||||
|
||||
# Create sink frame task. This is the task that will actually write
|
||||
# audio or video frames. We write audio/video in a task so we can keep
|
||||
# generating frames upstream while, for example, the audio is playing.
|
||||
self._create_sink_task()
|
||||
|
||||
# Create push frame task. This is the task that will push frames in
|
||||
# order. We also guarantee that all frames are pushed in the same task.
|
||||
self._create_push_task()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
loop = self.get_event_loop()
|
||||
|
||||
# Create queues and threads.
|
||||
# Create camera output queue and task if needed.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_thread = loop.run_in_executor(
|
||||
self._executor, self._camera_out_thread_handler)
|
||||
self._camera_out_queue = asyncio.Queue()
|
||||
self._camera_out_task = self.get_event_loop().create_task(self._camera_out_task_handler())
|
||||
|
||||
self._sink_thread = loop.run_in_executor(self._executor, self._sink_thread_handler)
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Cancel and wait for the camera output task to finish.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_task.cancel()
|
||||
await self._camera_out_task
|
||||
|
||||
async def stop(self):
|
||||
if not self._running:
|
||||
return
|
||||
# Wait for the push frame and sink tasks to finish. They will finish when
|
||||
# the EndFrame is actually processed.
|
||||
await self._push_frame_task
|
||||
await self._sink_task
|
||||
|
||||
# This will exit all threads.
|
||||
self._running = False
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
# Cancel all the tasks and wait for them to finish.
|
||||
|
||||
self._stopped_event.set()
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_task.cancel()
|
||||
await self._camera_out_task
|
||||
|
||||
def send_message(self, frame: TransportMessageFrame):
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
pass
|
||||
|
||||
def send_metrics(self, frame: MetricsFrame):
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
pass
|
||||
|
||||
def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
async def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
pass
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
pass
|
||||
|
||||
#
|
||||
# Frame processor
|
||||
#
|
||||
|
||||
async def cleanup(self):
|
||||
# Wait on the threads to finish.
|
||||
if self._params.camera_out_enabled:
|
||||
await self._camera_out_thread
|
||||
|
||||
if self._sink_thread:
|
||||
await self._sink_thread
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
#
|
||||
# Out-of-band frames like (CancelFrame or StartInterruptionFrame) are
|
||||
# pushed immediately. Other frames require order so they are put in the
|
||||
# sink queue.
|
||||
# System frames (like StartInterruptionFrame) are pushed
|
||||
# immediately. Other frames require order so they are put in the sink
|
||||
# queue.
|
||||
#
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.start(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
# EndFrame is managed in the queue handler.
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.stop()
|
||||
if isinstance(frame, CancelFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self.cancel(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
self.send_metrics(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
await self.send_metrics(frame)
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames.
|
||||
elif isinstance(frame, StartFrame):
|
||||
await self._sink_queue.put(frame)
|
||||
await self.start(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._sink_queue.put(frame)
|
||||
await self.stop(frame)
|
||||
# Other frames.
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
await self._handle_audio(frame)
|
||||
else:
|
||||
self._sink_queue.put_nowait(frame)
|
||||
|
||||
# If we are finishing, wait here until we have stopped, otherwise we might
|
||||
# close things too early upstream. We need this event because we don't
|
||||
# know when the internal threads will finish.
|
||||
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
|
||||
await self._stopped_event.wait()
|
||||
await self._sink_queue.put(frame)
|
||||
|
||||
async def _handle_interruptions(self, frame: Frame):
|
||||
if not self.interruptions_allowed:
|
||||
return
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._is_interrupted.set()
|
||||
# Stop sink task.
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
self._create_sink_task()
|
||||
# Stop push task.
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
self._create_push_task()
|
||||
elif isinstance(frame, StopInterruptionFrame):
|
||||
self._is_interrupted.clear()
|
||||
|
||||
async def _handle_audio(self, frame: AudioRawFrame):
|
||||
audio = frame.audio
|
||||
for i in range(0, len(audio), self._audio_chunk_size):
|
||||
chunk = AudioRawFrame(audio[i: i + self._audio_chunk_size],
|
||||
sample_rate=frame.sample_rate, num_channels=frame.num_channels)
|
||||
self._sink_queue.put_nowait(chunk)
|
||||
await self._sink_queue.put(chunk)
|
||||
|
||||
def _sink_thread_handler(self):
|
||||
def _create_sink_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._sink_queue = asyncio.Queue()
|
||||
self._sink_task = loop.create_task(self._sink_task_handler())
|
||||
|
||||
async def _sink_task_handler(self):
|
||||
# Audio accumulation buffer
|
||||
buffer = bytearray()
|
||||
while self._running:
|
||||
try:
|
||||
frame = self._sink_queue.get(timeout=1)
|
||||
if not self._is_interrupted.is_set():
|
||||
if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled:
|
||||
buffer.extend(frame.audio)
|
||||
buffer = self._maybe_send_audio(buffer)
|
||||
elif isinstance(frame, ImageRawFrame) and self._params.camera_out_enabled:
|
||||
self._set_camera_image(frame)
|
||||
elif isinstance(frame, SpriteFrame) and self._params.camera_out_enabled:
|
||||
self._set_camera_images(frame.images)
|
||||
elif isinstance(frame, TransportMessageFrame):
|
||||
self.send_message(frame)
|
||||
else:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self.get_event_loop())
|
||||
future.result()
|
||||
else:
|
||||
# If we get interrupted just clear the output buffer.
|
||||
buffer = bytearray()
|
||||
|
||||
if isinstance(frame, EndFrame):
|
||||
future = asyncio.run_coroutine_threadsafe(self.stop(), self.get_event_loop())
|
||||
future.result()
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
frame = await self._sink_queue.get()
|
||||
if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled:
|
||||
buffer.extend(frame.audio)
|
||||
buffer = await self._maybe_send_audio(buffer)
|
||||
elif isinstance(frame, ImageRawFrame) and self._params.camera_out_enabled:
|
||||
await self._set_camera_image(frame)
|
||||
elif isinstance(frame, SpriteFrame) and self._params.camera_out_enabled:
|
||||
await self._set_camera_images(frame.images)
|
||||
elif isinstance(frame, TransportMessageFrame):
|
||||
await self.send_message(frame)
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
await self._internal_push_frame(BotStartedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
await self._internal_push_frame(frame)
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
await self._internal_push_frame(BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
await self._internal_push_frame(frame)
|
||||
else:
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
running = not isinstance(frame, EndFrame)
|
||||
|
||||
self._sink_queue.task_done()
|
||||
except queue.Empty:
|
||||
pass
|
||||
except BaseException as e:
|
||||
logger.error(f"{self} error processing sink queue: {e}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error processing sink queue: {e}")
|
||||
|
||||
#
|
||||
# Push frames task
|
||||
@@ -209,8 +207,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
|
||||
def _create_push_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
@@ -219,10 +217,13 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._push_queue.put((frame, direction))
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
(frame, direction) = await self._push_queue.get()
|
||||
await self.push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
@@ -233,7 +234,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
async def send_image(self, frame: ImageRawFrame | SpriteFrame):
|
||||
await self.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
def _draw_image(self, frame: ImageRawFrame):
|
||||
async def _draw_image(self, frame: ImageRawFrame):
|
||||
desired_size = (self._params.camera_out_width, self._params.camera_out_height)
|
||||
|
||||
if frame.size != desired_size:
|
||||
@@ -243,34 +244,34 @@ class BaseOutputTransport(FrameProcessor):
|
||||
f"{frame} does not have the expected size {desired_size}, resizing")
|
||||
frame = ImageRawFrame(resized_image.tobytes(), resized_image.size, resized_image.format)
|
||||
|
||||
self.write_frame_to_camera(frame)
|
||||
await self.write_frame_to_camera(frame)
|
||||
|
||||
def _set_camera_image(self, image: ImageRawFrame):
|
||||
async def _set_camera_image(self, image: ImageRawFrame):
|
||||
if self._params.camera_out_is_live:
|
||||
self._camera_out_queue.put_nowait(image)
|
||||
await self._camera_out_queue.put(image)
|
||||
else:
|
||||
self._camera_images = itertools.cycle([image])
|
||||
|
||||
def _set_camera_images(self, images: List[ImageRawFrame]):
|
||||
async def _set_camera_images(self, images: List[ImageRawFrame]):
|
||||
self._camera_images = itertools.cycle(images)
|
||||
|
||||
def _camera_out_thread_handler(self):
|
||||
while self._running:
|
||||
async def _camera_out_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
if self._params.camera_out_is_live:
|
||||
image = self._camera_out_queue.get(timeout=1)
|
||||
self._draw_image(image)
|
||||
image = await self._camera_out_queue.get()
|
||||
await self._draw_image(image)
|
||||
self._camera_out_queue.task_done()
|
||||
elif self._camera_images:
|
||||
image = next(self._camera_images)
|
||||
self._draw_image(image)
|
||||
time.sleep(1.0 / self._params.camera_out_framerate)
|
||||
await self._draw_image(image)
|
||||
await asyncio.sleep(1.0 / self._params.camera_out_framerate)
|
||||
else:
|
||||
time.sleep(1.0 / self._params.camera_out_framerate)
|
||||
except queue.Empty:
|
||||
pass
|
||||
await asyncio.sleep(1.0 / self._params.camera_out_framerate)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error writing to camera: {e}")
|
||||
logger.exception(f"{self} error writing to camera: {e}")
|
||||
|
||||
#
|
||||
# Audio out
|
||||
@@ -279,12 +280,9 @@ class BaseOutputTransport(FrameProcessor):
|
||||
async def send_audio(self, frame: AudioRawFrame):
|
||||
await self.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
def _maybe_send_audio(self, buffer: bytearray) -> bytearray:
|
||||
try:
|
||||
if len(buffer) >= self._audio_chunk_size:
|
||||
self.write_raw_audio_frames(bytes(buffer[:self._audio_chunk_size]))
|
||||
buffer = buffer[self._audio_chunk_size:]
|
||||
return buffer
|
||||
except BaseException as e:
|
||||
logger.error(f"{self} error writing audio frames: {e}")
|
||||
return buffer
|
||||
async def _maybe_send_audio(self, buffer: bytearray) -> bytearray:
|
||||
if len(buffer) >= self._audio_chunk_size:
|
||||
await self.write_raw_audio_frames(bytes(buffer[:self._audio_chunk_size]))
|
||||
buffer = buffer[self._audio_chunk_size:]
|
||||
await self.push_frame(BotSpeakingFrame(), FrameDirection.UPSTREAM)
|
||||
return buffer
|
||||
|
||||
@@ -60,20 +60,20 @@ class BaseTransport(ABC):
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
def decorator(handler):
|
||||
self._add_event_handler(event_name, handler)
|
||||
self.add_event_handler(event_name, handler)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def add_event_handler(self, event_name: str, handler):
|
||||
if event_name not in self._event_handlers:
|
||||
raise Exception(f"Event handler {event_name} not registered")
|
||||
self._event_handlers[event_name].append(handler)
|
||||
|
||||
def _register_event_handler(self, event_name: str):
|
||||
if event_name in self._event_handlers:
|
||||
raise Exception(f"Event handler {event_name} already registered")
|
||||
self._event_handlers[event_name] = []
|
||||
|
||||
def _add_event_handler(self, event_name: str, handler):
|
||||
if event_name not in self._event_handlers:
|
||||
raise Exception(f"Event handler {event_name} not registered")
|
||||
self._event_handlers[event_name].append(handler)
|
||||
|
||||
async def _call_event_handler(self, event_name: str, *args, **kwargs):
|
||||
try:
|
||||
for handler in self._event_handlers[event_name]:
|
||||
@@ -82,5 +82,4 @@ class BaseTransport(ABC):
|
||||
else:
|
||||
handler(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in event handler {event_name}: {e}")
|
||||
raise e
|
||||
logger.exception(f"Exception in event handler {event_name}: {e}")
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
@@ -43,26 +45,20 @@ class LocalAudioInputTransport(BaseInputTransport):
|
||||
await super().start(frame)
|
||||
self._in_stream.start_stream()
|
||||
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
self._in_stream.stop_stream()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
self._in_stream.stop_stream()
|
||||
# This is not very pretty (taken from PyAudio docs).
|
||||
while self._in_stream.is_active():
|
||||
await asyncio.sleep(0.1)
|
||||
self._in_stream.close()
|
||||
|
||||
await super().cleanup()
|
||||
|
||||
def _audio_in_callback(self, in_data, frame_count, time_info, status):
|
||||
if not self._running:
|
||||
return (None, pyaudio.paAbort)
|
||||
|
||||
frame = AudioRawFrame(audio=in_data,
|
||||
sample_rate=self._params.audio_in_sample_rate,
|
||||
num_channels=self._params.audio_in_channels)
|
||||
self.push_audio_frame(frame)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(self.push_audio_frame(frame), self.get_event_loop())
|
||||
|
||||
return (None, pyaudio.paContinue)
|
||||
|
||||
@@ -72,19 +68,29 @@ class LocalAudioOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams):
|
||||
super().__init__(params)
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
self._out_stream = py_audio.open(
|
||||
format=py_audio.get_format_from_width(2),
|
||||
channels=params.audio_out_channels,
|
||||
rate=params.audio_out_sample_rate,
|
||||
output=True)
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
self._out_stream.write(frames)
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._out_stream.start_stream()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
self._out_stream.stop_stream()
|
||||
# This is not very pretty (taken from PyAudio docs).
|
||||
while self._out_stream.is_active():
|
||||
await asyncio.sleep(0.1)
|
||||
self._out_stream.close()
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
await self.get_event_loop().run_in_executor(self._executor, self._out_stream.write, frames)
|
||||
|
||||
|
||||
class LocalAudioTransport(BaseTransport):
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
import tkinter as tk
|
||||
|
||||
@@ -53,25 +55,20 @@ class TkInputTransport(BaseInputTransport):
|
||||
await super().start(frame)
|
||||
self._in_stream.start_stream()
|
||||
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
self._in_stream.stop_stream()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
self._in_stream.stop_stream()
|
||||
# This is not very pretty (taken from PyAudio docs).
|
||||
while self._in_stream.is_active():
|
||||
await asyncio.sleep(0.1)
|
||||
self._in_stream.close()
|
||||
|
||||
def _audio_in_callback(self, in_data, frame_count, time_info, status):
|
||||
if not self._running:
|
||||
return (None, pyaudio.paAbort)
|
||||
|
||||
frame = AudioRawFrame(audio=in_data,
|
||||
sample_rate=self._params.audio_in_sample_rate,
|
||||
num_channels=self._params.audio_in_channels)
|
||||
self.push_audio_frame(frame)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(self.push_audio_frame(frame), self.get_event_loop())
|
||||
|
||||
return (None, pyaudio.paContinue)
|
||||
|
||||
@@ -81,6 +78,8 @@ class TkOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, tk_root: tk.Tk, py_audio: pyaudio.PyAudio, params: TransportParams):
|
||||
super().__init__(params)
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
self._out_stream = py_audio.open(
|
||||
format=py_audio.get_format_from_width(2),
|
||||
channels=params.audio_out_channels,
|
||||
@@ -94,16 +93,24 @@ class TkOutputTransport(BaseOutputTransport):
|
||||
self._image_label = tk.Label(tk_root, image=photo)
|
||||
self._image_label.pack()
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
self._out_stream.write(frames)
|
||||
|
||||
def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
self.get_event_loop().call_soon(self._write_frame_to_tk, frame)
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._out_stream.start_stream()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
self._out_stream.stop_stream()
|
||||
# This is not very pretty (taken from PyAudio docs).
|
||||
while self._out_stream.is_active():
|
||||
await asyncio.sleep(0.1)
|
||||
self._out_stream.close()
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
await self.get_event_loop().run_in_executor(self._executor, self._out_stream.write, frames)
|
||||
|
||||
async def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
self.get_event_loop().call_soon(self._write_frame_to_tk, frame)
|
||||
|
||||
def _write_frame_to_tk(self, frame: ImageRawFrame):
|
||||
width = frame.size[0]
|
||||
height = frame.size[1]
|
||||
|
||||
164
src/pipecat/transports/network/fastapi_websocket.py
Normal file
164
src/pipecat/transports/network/fastapi_websocket.py
Normal file
@@ -0,0 +1,164 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import wave
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use FastAPI websockets, you need to `pip install pipecat-ai[websocket]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class FastAPIWebsocketParams(TransportParams):
    """Transport parameters specific to the FastAPI websocket transport.

    Extends ``TransportParams`` with the options the websocket transports read
    when decoding incoming messages and packaging outgoing audio.
    """

    # When True, each outgoing audio chunk is wrapped in a WAV container
    # before being handed to the serializer.
    add_wav_header: bool = False
    # Number of audio bytes sent per websocket message.
    audio_frame_size: int = 6400  # 200ms
    # Converts between frames and the text payloads exchanged on the socket.
    serializer: FrameSerializer
|
||||
|
||||
|
||||
class FastAPIWebsocketCallbacks(BaseModel):
    """Async callbacks the input transport invokes on connection events."""

    # Awaited with the websocket when the input transport starts.
    on_client_connected: Callable[[WebSocket], Awaitable[None]]
    # Awaited with the websocket once the receive loop has ended.
    on_client_disconnected: Callable[[WebSocket], Awaitable[None]]
|
||||
|
||||
|
||||
class FastAPIWebsocketInputTransport(BaseInputTransport):
    """Input transport that reads serialized frames from a FastAPI websocket.

    Incoming text messages are deserialized with the configured serializer;
    audio frames are forwarded through ``push_audio_frame`` and every other
    message is ignored.
    """

    def __init__(
            self,
            websocket: WebSocket,
            params: FastAPIWebsocketParams,
            callbacks: FastAPIWebsocketCallbacks,
            **kwargs):
        super().__init__(params, **kwargs)

        self._websocket = websocket
        self._params = params
        self._callbacks = callbacks

    async def start(self, frame: StartFrame):
        """Start the base transport, notify the callback and begin receiving."""
        await super().start(frame)
        await self._callbacks.on_client_connected(self._websocket)
        loop = self.get_event_loop()
        self._receive_task = loop.create_task(self._receive_messages())

    async def stop(self, frame: EndFrame):
        """Stop the base transport, then close the websocket if still open."""
        await super().stop(frame)
        await self._close_websocket()

    async def cancel(self, frame: CancelFrame):
        """Cancel the base transport, then close the websocket if still open."""
        await super().cancel(frame)
        await self._close_websocket()

    async def _close_websocket(self):
        """Close the websocket unless the client is already disconnected."""
        if self._websocket.client_state == WebSocketState.DISCONNECTED:
            return
        await self._websocket.close()

    async def _receive_messages(self):
        """Receive loop: deserialize each text message and push audio frames."""
        serializer = self._params.serializer
        async for text in self._websocket.iter_text():
            incoming = serializer.deserialize(text)
            # Messages that do not deserialize to a frame, and non-audio
            # frames, are dropped.
            if incoming and isinstance(incoming, AudioRawFrame):
                await self.push_audio_frame(incoming)

        # The message iterator has finished; report the disconnection.
        await self._callbacks.on_client_disconnected(self._websocket)
|
||||
|
||||
|
||||
class FastAPIWebsocketOutputTransport(BaseOutputTransport):
    """Output transport that sends serialized audio over a FastAPI websocket.

    Raw audio is accumulated in an internal buffer and sent in chunks of
    ``params.audio_frame_size`` bytes, optionally wrapped in a WAV container.
    """

    def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs):
        super().__init__(params, **kwargs)

        self._websocket = websocket
        self._params = params
        # Audio bytes received but not yet sent (less than one full chunk
        # between calls).
        self._audio_buffer = bytes()

    async def write_raw_audio_frames(self, frames: bytes):
        """Buffer ``frames`` and send every complete chunk to the client.

        Args:
            frames: Raw audio bytes to append to the outgoing buffer.
        """
        self._audio_buffer += frames
        while len(self._audio_buffer) >= self._params.audio_frame_size:
            frame_size = self._params.audio_frame_size
            # Take the chunk out of the buffer *before* awaiting the send. If
            # this coroutine is cancelled or the send fails part-way, the chunk
            # must not stay in the buffer, otherwise it would be sent again on
            # the next call.
            chunk = self._audio_buffer[:frame_size]
            self._audio_buffer = self._audio_buffer[frame_size:]

            frame = AudioRawFrame(
                audio=chunk,
                sample_rate=self._params.audio_out_sample_rate,
                num_channels=self._params.audio_out_channels
            )

            if self._params.add_wav_header:
                frame = self._wrap_in_wav(frame)

            payload = self._params.serializer.serialize(frame)
            # Audio is dropped when nothing was serialized or the client is
            # not connected.
            if payload and self._websocket.client_state == WebSocketState.CONNECTED:
                await self._websocket.send_text(payload)

    @staticmethod
    def _wrap_in_wav(frame: AudioRawFrame) -> AudioRawFrame:
        """Return a copy of ``frame`` whose audio is a 16-bit WAV file image."""
        content = io.BytesIO()
        # The context manager finalizes the WAV header and closes the writer
        # even if writing raises; the BytesIO itself stays open.
        with wave.open(content, "wb") as ww:
            ww.setsampwidth(2)
            ww.setnchannels(frame.num_channels)
            ww.setframerate(frame.sample_rate)
            ww.writeframes(frame.audio)
        return AudioRawFrame(
            content.getvalue(),
            sample_rate=frame.sample_rate,
            num_channels=frame.num_channels)
|
||||
|
||||
|
||||
class FastAPIWebsocketTransport(BaseTransport):
    """Transport pairing a websocket input and output processor.

    Exposes the ``on_client_connected`` and ``on_client_disconnected`` events,
    which are fired from the input transport's callbacks.
    """

    def __init__(
            self,
            websocket: WebSocket,
            params: FastAPIWebsocketParams,
            input_name: str | None = None,
            output_name: str | None = None,
            loop: asyncio.AbstractEventLoop | None = None):
        super().__init__(input_name=input_name, output_name=output_name, loop=loop)
        self._params = params

        # Bridge the input transport's callbacks to this transport's events.
        self._callbacks = FastAPIWebsocketCallbacks(
            on_client_connected=self._on_client_connected,
            on_client_disconnected=self._on_client_disconnected,
        )

        self._input = FastAPIWebsocketInputTransport(
            websocket, params, self._callbacks, name=self._input_name)
        self._output = FastAPIWebsocketOutputTransport(
            websocket, params, name=self._output_name)

        # Register supported handlers. The user will only be able to register
        # these handlers.
        for event_name in ("on_client_connected", "on_client_disconnected"):
            self._register_event_handler(event_name)

    def input(self) -> FrameProcessor:
        """Return the input side of the transport."""
        return self._input

    def output(self) -> FrameProcessor:
        """Return the output side of the transport."""
        return self._output

    async def _on_client_connected(self, websocket):
        """Forward the connection callback to the registered event handlers."""
        await self._call_event_handler("on_client_connected", websocket)

    async def _on_client_disconnected(self, websocket):
        """Forward the disconnection callback to the registered event handlers."""
        await self._call_event_handler("on_client_disconnected", websocket)
|
||||
@@ -4,16 +4,14 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import wave
|
||||
import websockets
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, StartFrame
|
||||
from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
from pipecat.serializers.protobuf import ProtobufFrameSerializer
|
||||
@@ -23,6 +21,13 @@ from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use websockets, you need to `pip install pipecat-ai[websocket]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class WebsocketServerParams(TransportParams):
|
||||
add_wav_header: bool = False
|
||||
@@ -59,10 +64,15 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
self._server_task = self.get_event_loop().create_task(self._server_task_handler())
|
||||
await super().start(frame)
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
self._stop_server_event.set()
|
||||
await self._server_task
|
||||
await super().stop()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
    """Handle a cancel frame: run the parent logic, then cancel the server task.

    Args:
        frame: The cancel frame, passed through to the parent class.
    """
    await super().cancel(frame)

    server_task = self._server_task
    server_task.cancel()
    # NOTE(review): awaiting a cancelled task re-raises CancelledError unless
    # the task's coroutine absorbs it; _server_task_handler is not fully
    # visible from here -- confirm it does.
    await server_task
|
||||
|
||||
async def _server_task_handler(self):
|
||||
logger.info(f"Starting websocket server on {self._host}:{self._port}")
|
||||
@@ -88,7 +98,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
continue
|
||||
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
self.push_audio_frame(frame)
|
||||
await self.push_audio_frame(frame)
|
||||
else:
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
@@ -118,7 +128,10 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
logger.warning("Only one client allowed, using new connection")
|
||||
self._websocket = websocket
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
if not self._websocket:
|
||||
return
|
||||
|
||||
self._audio_buffer += frames
|
||||
while len(self._audio_buffer) >= self._params.audio_frame_size:
|
||||
frame = AudioRawFrame(
|
||||
@@ -143,10 +156,8 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
frame = wav_frame
|
||||
|
||||
proto = self._params.serializer.serialize(frame)
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._websocket.send(proto), self.get_event_loop())
|
||||
future.result()
|
||||
if proto:
|
||||
await self._websocket.send(proto)
|
||||
|
||||
self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:]
|
||||
|
||||
|
||||
@@ -6,11 +6,10 @@
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import queue
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Mapping
|
||||
from typing import Any, Awaitable, Callable, Mapping, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from daily import (
|
||||
@@ -24,6 +23,8 @@ from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
ImageRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
@@ -60,8 +61,8 @@ class DailyTransportMessageFrame(TransportMessageFrame):
|
||||
|
||||
class WebRTCVADAnalyzer(VADAnalyzer):
|
||||
|
||||
def __init__(self, sample_rate=16000, num_channels=1, params: VADParams = VADParams()):
|
||||
super().__init__(sample_rate, num_channels, params)
|
||||
def __init__(self, *, sample_rate=16000, num_channels=1, params: VADParams = VADParams()):
|
||||
super().__init__(sample_rate=sample_rate, num_channels=num_channels, params=params)
|
||||
|
||||
self._webrtc_vad = Daily.create_native_vad(
|
||||
reset_period_ms=VAD_RESET_PERIOD_MS,
|
||||
@@ -102,25 +103,40 @@ class DailyTranscriptionSettings(BaseModel):
|
||||
class DailyParams(TransportParams):
|
||||
api_url: str = "https://api.daily.co/v1"
|
||||
api_key: str = ""
|
||||
dialin_settings: DailyDialinSettings | None = None
|
||||
dialin_settings: Optional[DailyDialinSettings] = None
|
||||
transcription_enabled: bool = False
|
||||
transcription_settings: DailyTranscriptionSettings = DailyTranscriptionSettings()
|
||||
|
||||
|
||||
class DailyCallbacks(BaseModel):
|
||||
on_joined: Callable[[Mapping[str, Any]], None]
|
||||
on_left: Callable[[], None]
|
||||
on_error: Callable[[str], None]
|
||||
on_app_message: Callable[[Any, str], None]
|
||||
on_call_state_updated: Callable[[str], None]
|
||||
on_dialin_ready: Callable[[str], None]
|
||||
on_dialout_connected: Callable[[Any], None]
|
||||
on_dialout_stopped: Callable[[Any], None]
|
||||
on_dialout_error: Callable[[Any], None]
|
||||
on_dialout_warning: Callable[[Any], None]
|
||||
on_first_participant_joined: Callable[[Mapping[str, Any]], None]
|
||||
on_participant_joined: Callable[[Mapping[str, Any]], None]
|
||||
on_participant_left: Callable[[Mapping[str, Any], str], None]
|
||||
on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_left: Callable[[], Awaitable[None]]
|
||||
on_error: Callable[[str], Awaitable[None]]
|
||||
on_app_message: Callable[[Any, str], Awaitable[None]]
|
||||
on_call_state_updated: Callable[[str], Awaitable[None]]
|
||||
on_dialin_ready: Callable[[str], Awaitable[None]]
|
||||
on_dialout_answered: Callable[[Any], Awaitable[None]]
|
||||
on_dialout_connected: Callable[[Any], Awaitable[None]]
|
||||
on_dialout_stopped: Callable[[Any], Awaitable[None]]
|
||||
on_dialout_error: Callable[[Any], Awaitable[None]]
|
||||
on_dialout_warning: Callable[[Any], Awaitable[None]]
|
||||
on_first_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_participant_left: Callable[[Mapping[str, Any], str], Awaitable[None]]
|
||||
|
||||
|
||||
def completion_callback(future):
    """Build a completion callable that resolves ``future`` on its own loop.

    The returned callable accepts any number of positional arguments and
    hands the result over with ``call_soon_threadsafe`` on the loop that
    owns ``future``.

    The value stored in the future depends on how many arguments the
    completion receives:

    * none     -> ``None``
    * one      -> that argument
    * several  -> a tuple with all of them

    Args:
        future: An ``asyncio.Future`` that the caller is going to await.

    Returns:
        A callable ``(*args) -> None`` suitable as a completion callback.
    """
    def _callback(*args):
        def set_result(future, *args):
            # Without this zero-argument case, `future.set_result(*args)`
            # raises TypeError inside the loop callback, and the awaiter
            # would wait forever.
            if not args:
                result = None
            elif len(args) == 1:
                result = args[0]
            else:
                result = args
            try:
                future.set_result(result)
            except asyncio.InvalidStateError:
                # The future was already resolved or cancelled (for example
                # the awaiter gave up); a late completion is ignored.
                pass
        future.get_loop().call_soon_threadsafe(set_result, future, *args)
    return _callback
|
||||
|
||||
|
||||
class DailyTransportClient(EventHandler):
|
||||
@@ -160,7 +176,6 @@ class DailyTransportClient(EventHandler):
|
||||
self._joined = False
|
||||
self._joining = False
|
||||
self._leaving = False
|
||||
self._sync_response = {k: queue.Queue() for k in ["join", "leave"]}
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
@@ -173,10 +188,16 @@ class DailyTransportClient(EventHandler):
|
||||
color_format=self._params.camera_out_color_format)
|
||||
|
||||
self._mic: VirtualMicrophoneDevice = Daily.create_microphone_device(
|
||||
"mic", sample_rate=self._params.audio_out_sample_rate, channels=self._params.audio_out_channels)
|
||||
"mic",
|
||||
sample_rate=self._params.audio_out_sample_rate,
|
||||
channels=self._params.audio_out_channels,
|
||||
non_blocking=True)
|
||||
|
||||
self._speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=self._params.audio_in_sample_rate, channels=self._params.audio_in_channels)
|
||||
"speaker",
|
||||
sample_rate=self._params.audio_in_sample_rate,
|
||||
channels=self._params.audio_in_channels,
|
||||
non_blocking=True)
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
@property
|
||||
@@ -186,30 +207,45 @@ class DailyTransportClient(EventHandler):
|
||||
def set_callbacks(self, callbacks: DailyCallbacks):
|
||||
self._callbacks = callbacks
|
||||
|
||||
def send_message(self, frame: DailyTransportMessageFrame):
|
||||
self._client.send_app_message(frame.message, frame.participant_id)
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
def read_next_audio_frame(self) -> AudioRawFrame | None:
|
||||
participant_id = None
|
||||
if isinstance(frame, DailyTransportMessageFrame):
|
||||
participant_id = frame.participant_id
|
||||
|
||||
future = self._loop.create_future()
|
||||
self._client.send_app_message(
|
||||
frame.message,
|
||||
participant_id,
|
||||
completion=completion_callback(future))
|
||||
await future
|
||||
|
||||
async def read_next_audio_frame(self) -> AudioRawFrame | None:
|
||||
sample_rate = self._params.audio_in_sample_rate
|
||||
num_channels = self._params.audio_in_channels
|
||||
num_frames = int(sample_rate / 100) * 2 # 20ms of audio
|
||||
|
||||
if self._other_participant_has_joined:
|
||||
num_frames = int(sample_rate / 100) * 2 # 20ms of audio
|
||||
|
||||
audio = self._speaker.read_frames(num_frames)
|
||||
future = self._loop.create_future()
|
||||
self._speaker.read_frames(num_frames, completion=completion_callback(future))
|
||||
audio = await future
|
||||
|
||||
if len(audio) > 0:
|
||||
return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels)
|
||||
else:
|
||||
# If no one has ever joined the meeting `read_frames()` would block,
|
||||
# instead we just wait a bit. daily-python should probably return
|
||||
# silence instead.
|
||||
time.sleep(0.01)
|
||||
# If we don't read any audio it could be there's no participant
|
||||
# connected. daily-python will return immediately if that's the
|
||||
# case, so let's sleep for a little bit (i.e. busy wait).
|
||||
await asyncio.sleep(0.01)
|
||||
return None
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
self._mic.write_frames(frames)
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
future = self._loop.create_future()
|
||||
self._mic.write_frames(frames, completion=completion_callback(future))
|
||||
await future
|
||||
|
||||
def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
async def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
self._camera.write_frame(frame.image)
|
||||
|
||||
async def join(self):
|
||||
@@ -217,13 +253,10 @@ class DailyTransportClient(EventHandler):
|
||||
if self._joined or self._joining:
|
||||
return
|
||||
|
||||
self._joining = True
|
||||
|
||||
await self._loop.run_in_executor(self._executor, self._join)
|
||||
|
||||
def _join(self):
|
||||
logger.info(f"Joining {self._room_url}")
|
||||
|
||||
self._joining = True
|
||||
|
||||
# For performance reasons, never subscribe to video streams (unless a
|
||||
# video renderer is registered).
|
||||
self._client.update_subscription_profiles({
|
||||
@@ -235,10 +268,47 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
self._client.set_user_name(self._bot_name)
|
||||
|
||||
try:
|
||||
(data, error) = await self._join()
|
||||
|
||||
if not error:
|
||||
self._joined = True
|
||||
self._joining = False
|
||||
|
||||
logger.info(f"Joined {self._room_url}")
|
||||
|
||||
if self._token and self._params.transcription_enabled:
|
||||
await self._start_transcription()
|
||||
|
||||
await self._callbacks.on_joined(data["participants"]["local"])
|
||||
else:
|
||||
error_msg = f"Error joining {self._room_url}: {error}"
|
||||
logger.error(error_msg)
|
||||
await self._callbacks.on_error(error_msg)
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"Time out joining {self._room_url}"
|
||||
logger.error(error_msg)
|
||||
await self._callbacks.on_error(error_msg)
|
||||
|
||||
async def _start_transcription(self):
    """Ask the client to start transcription and wait for its completion.

    The configured transcription settings are dumped with ``exclude_none``
    and passed to the client. A truthy completion value is treated as an
    error and logged; nothing is raised or returned.
    """
    logger.info(f"Enabling transcription with settings {self._params.transcription_settings}")

    settings = self._params.transcription_settings.model_dump(exclude_none=True)
    done = self._loop.create_future()
    self._client.start_transcription(
        settings=settings,
        completion=completion_callback(done),
    )

    error = await done
    if not error:
        return
    logger.error(f"Unable to start transcription: {error}")
|
||||
|
||||
async def _join(self):
|
||||
future = self._loop.create_future()
|
||||
|
||||
self._client.join(
|
||||
self._room_url,
|
||||
self._token,
|
||||
completion=self._call_joined,
|
||||
completion=completion_callback(future),
|
||||
client_settings={
|
||||
"inputs": {
|
||||
"camera": {
|
||||
@@ -274,33 +344,7 @@ class DailyTransportClient(EventHandler):
|
||||
},
|
||||
})
|
||||
|
||||
self._handle_join_response()
|
||||
|
||||
def _handle_join_response(self):
|
||||
try:
|
||||
(data, error) = self._sync_response["join"].get(timeout=10)
|
||||
if not error:
|
||||
self._joined = True
|
||||
self._joining = False
|
||||
|
||||
logger.info(f"Joined {self._room_url}")
|
||||
|
||||
if self._token and self._params.transcription_enabled:
|
||||
logger.info(
|
||||
f"Enabling transcription with settings {self._params.transcription_settings}")
|
||||
self._client.start_transcription(
|
||||
self._params.transcription_settings.model_dump())
|
||||
|
||||
self._callbacks.on_joined(data["participants"]["local"])
|
||||
else:
|
||||
error_msg = f"Error joining {self._room_url}: {error}"
|
||||
logger.error(error_msg)
|
||||
self._callbacks.on_error(error_msg)
|
||||
self._sync_response["join"].task_done()
|
||||
except queue.Empty:
|
||||
error_msg = f"Time out joining {self._room_url}"
|
||||
logger.error(error_msg)
|
||||
self._callbacks.on_error(error_msg)
|
||||
return await asyncio.wait_for(future, timeout=10)
|
||||
|
||||
async def leave(self):
|
||||
# Transport not joined, ignore.
|
||||
@@ -310,34 +354,37 @@ class DailyTransportClient(EventHandler):
|
||||
self._joined = False
|
||||
self._leaving = True
|
||||
|
||||
await self._loop.run_in_executor(self._executor, self._leave)
|
||||
|
||||
def _leave(self):
|
||||
logger.info(f"Leaving {self._room_url}")
|
||||
|
||||
if self._params.transcription_enabled:
|
||||
self._client.stop_transcription()
|
||||
await self._stop_transcription()
|
||||
|
||||
self._client.leave(completion=self._call_left)
|
||||
|
||||
self._handle_leave_response()
|
||||
|
||||
def _handle_leave_response(self):
|
||||
try:
|
||||
error = self._sync_response["leave"].get(timeout=10)
|
||||
error = await self._leave()
|
||||
if not error:
|
||||
self._leaving = False
|
||||
logger.info(f"Left {self._room_url}")
|
||||
self._callbacks.on_left()
|
||||
await self._callbacks.on_left()
|
||||
else:
|
||||
error_msg = f"Error leaving {self._room_url}: {error}"
|
||||
logger.error(error_msg)
|
||||
self._callbacks.on_error(error_msg)
|
||||
self._sync_response["leave"].task_done()
|
||||
except queue.Empty:
|
||||
await self._callbacks.on_error(error_msg)
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"Time out leaving {self._room_url}"
|
||||
logger.error(error_msg)
|
||||
self._callbacks.on_error(error_msg)
|
||||
await self._callbacks.on_error(error_msg)
|
||||
|
||||
async def _stop_transcription(self):
    """Ask the client to stop transcription and wait for its completion.

    A truthy completion value is treated as an error and logged; nothing is
    raised or returned.
    """
    done = self._loop.create_future()
    self._client.stop_transcription(completion=completion_callback(done))

    error = await done
    if not error:
        return
    logger.error(f"Unable to stop transcription: {error}")
|
||||
|
||||
async def _leave(self):
    """Issue the client's ``leave`` call and wait for its completion.

    Returns:
        Whatever the completion delivers.

    Raises:
        asyncio.TimeoutError: If no completion arrives within 10 seconds.
    """
    done = self._loop.create_future()
    self._client.leave(completion=completion_callback(done))
    outcome = await asyncio.wait_for(done, timeout=10)
    return outcome
|
||||
|
||||
async def cleanup(self):
|
||||
await self._loop.run_in_executor(self._executor, self._cleanup)
|
||||
@@ -399,25 +446,28 @@ class DailyTransportClient(EventHandler):
|
||||
#
|
||||
|
||||
def on_app_message(self, message: Any, sender: str):
|
||||
self._callbacks.on_app_message(message, sender)
|
||||
self._call_async_callback(self._callbacks.on_app_message, message, sender)
|
||||
|
||||
def on_call_state_updated(self, state: str):
|
||||
self._callbacks.on_call_state_updated(state)
|
||||
self._call_async_callback(self._callbacks.on_call_state_updated, state)
|
||||
|
||||
def on_dialin_ready(self, sip_endpoint: str):
|
||||
self._callbacks.on_dialin_ready(sip_endpoint)
|
||||
self._call_async_callback(self._callbacks.on_dialin_ready, sip_endpoint)
|
||||
|
||||
def on_dialout_answered(self, data: Any):
|
||||
self._call_async_callback(self._callbacks.on_dialout_answered, data)
|
||||
|
||||
def on_dialout_connected(self, data: Any):
|
||||
self._callbacks.on_dialout_connected(data)
|
||||
self._call_async_callback(self._callbacks.on_dialout_connected, data)
|
||||
|
||||
def on_dialout_stopped(self, data: Any):
|
||||
self._callbacks.on_dialout_stopped(data)
|
||||
self._call_async_callback(self._callbacks.on_dialout_stopped, data)
|
||||
|
||||
def on_dialout_error(self, data: Any):
|
||||
self._callbacks.on_dialout_error(data)
|
||||
self._call_async_callback(self._callbacks.on_dialout_error, data)
|
||||
|
||||
def on_dialout_warning(self, data: Any):
|
||||
self._callbacks.on_dialout_warning(data)
|
||||
self._call_async_callback(self._callbacks.on_dialout_warning, data)
|
||||
|
||||
def on_participant_joined(self, participant):
|
||||
id = participant["id"]
|
||||
@@ -425,15 +475,15 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
if not self._other_participant_has_joined:
|
||||
self._other_participant_has_joined = True
|
||||
self._callbacks.on_first_participant_joined(participant)
|
||||
self._call_async_callback(self._callbacks.on_first_participant_joined, participant)
|
||||
|
||||
self._callbacks.on_participant_joined(participant)
|
||||
self._call_async_callback(self._callbacks.on_participant_joined, participant)
|
||||
|
||||
def on_participant_left(self, participant, reason):
|
||||
id = participant["id"]
|
||||
logger.info(f"Participant left {id}")
|
||||
|
||||
self._callbacks.on_participant_left(participant, reason)
|
||||
self._call_async_callback(self._callbacks.on_participant_left, participant, reason)
|
||||
|
||||
def on_transcription_message(self, message: Mapping[str, Any]):
|
||||
participant_id = ""
|
||||
@@ -442,7 +492,7 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
if participant_id in self._transcription_renderers:
|
||||
callback = self._transcription_renderers[participant_id]
|
||||
callback(participant_id, message)
|
||||
self._call_async_callback(callback, participant_id, message)
|
||||
|
||||
def on_transcription_error(self, message):
|
||||
logger.error(f"Transcription error: {message}")
|
||||
@@ -457,18 +507,19 @@ class DailyTransportClient(EventHandler):
|
||||
# Daily (CallClient callbacks)
|
||||
#
|
||||
|
||||
def _call_joined(self, data, error):
|
||||
self._sync_response["join"].put((data, error))
|
||||
|
||||
def _call_left(self, error):
|
||||
self._sync_response["leave"].put(error)
|
||||
|
||||
def _video_frame_received(self, participant_id, video_frame):
|
||||
callback = self._video_renderers[participant_id]
|
||||
callback(participant_id,
|
||||
video_frame.buffer,
|
||||
(video_frame.width, video_frame.height),
|
||||
video_frame.color_format)
|
||||
self._call_async_callback(
|
||||
callback,
|
||||
participant_id,
|
||||
video_frame.buffer,
|
||||
(video_frame.width,
|
||||
video_frame.height),
|
||||
video_frame.color_format)
|
||||
|
||||
def _call_async_callback(self, callback, *args):
|
||||
future = asyncio.run_coroutine_threadsafe(callback(*args), self._loop)
|
||||
future.result()
|
||||
|
||||
|
||||
class DailyInputTransport(BaseInputTransport):
|
||||
@@ -487,8 +538,6 @@ class DailyInputTransport(BaseInputTransport):
|
||||
num_channels=self._params.audio_in_channels)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
if self._running:
|
||||
return
|
||||
# Parent start.
|
||||
await super().start(frame)
|
||||
# Join the room.
|
||||
@@ -496,19 +545,27 @@ class DailyInputTransport(BaseInputTransport):
|
||||
# Create audio task. It reads audio frames from Daily and push them
|
||||
# internally for VAD processing.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_thread = self._loop.run_in_executor(
|
||||
self._executor, self._audio_in_thread_handler)
|
||||
self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
if not self._running:
|
||||
return
|
||||
# Parent stop. This will set _running to False.
|
||||
await super().stop()
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Parent stop.
|
||||
await super().stop(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
# Stop audio thread.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
await self._audio_in_thread
|
||||
self._audio_in_task.cancel()
|
||||
await self._audio_in_task
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
    """Handle a cancel frame: parent logic, leave the room, stop audio input.

    Args:
        frame: The cancel frame, passed through to the parent class.
    """
    # Let the parent class process the cancellation first.
    await super().cancel(frame)

    # Leave the room.
    await self._client.leave()

    # The audio input task is only touched when audio input or VAD is enabled.
    uses_audio_task = self._params.audio_in_enabled or self._params.vad_enabled
    if not uses_audio_task:
        return

    # Cancel the audio input task and wait for it to finish.
    self._audio_in_task.cancel()
    await self._audio_in_task
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
@@ -531,26 +588,25 @@ class DailyInputTransport(BaseInputTransport):
|
||||
# Frames
|
||||
#
|
||||
|
||||
def push_transcription_frame(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self.get_event_loop())
|
||||
future.result()
|
||||
async def push_transcription_frame(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
def push_app_message(self, message: Any, sender: str):
|
||||
async def push_app_message(self, message: Any, sender: str):
|
||||
frame = DailyTransportMessageFrame(message=message, participant_id=sender)
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self.get_event_loop())
|
||||
future.result()
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
#
|
||||
# Audio in
|
||||
#
|
||||
|
||||
def _audio_in_thread_handler(self):
|
||||
while self._running:
|
||||
frame = self._client.read_next_audio_frame()
|
||||
if frame:
|
||||
self.push_audio_frame(frame)
|
||||
async def _audio_in_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
frame = await self._client.read_next_audio_frame()
|
||||
if frame:
|
||||
await self.push_audio_frame(frame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
#
|
||||
# Camera in
|
||||
@@ -580,7 +636,7 @@ class DailyInputTransport(BaseInputTransport):
|
||||
if participant_id in self._video_renderers:
|
||||
self._video_renderers[participant_id]["render_next_frame"] = True
|
||||
|
||||
def _on_participant_video_frame(self, participant_id: str, buffer, size, format):
|
||||
async def _on_participant_video_frame(self, participant_id: str, buffer, size, format):
|
||||
render_frame = False
|
||||
|
||||
curr_time = time.time()
|
||||
@@ -600,9 +656,7 @@ class DailyInputTransport(BaseInputTransport):
|
||||
image=buffer,
|
||||
size=size,
|
||||
format=format)
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self.get_event_loop())
|
||||
future.result()
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
self._video_renderers[participant_id]["timestamp"] = curr_time
|
||||
|
||||
@@ -615,18 +669,20 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
self._client = client
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
if self._running:
|
||||
return
|
||||
# Parent start.
|
||||
await super().start(frame)
|
||||
# Join the room.
|
||||
await self._client.join()
|
||||
|
||||
async def stop(self):
|
||||
if not self._running:
|
||||
return
|
||||
# Parent stop. This will set _running to False.
|
||||
await super().stop()
|
||||
async def stop(self, frame: EndFrame):
|
||||
# Parent stop.
|
||||
await super().stop(frame)
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
    """Handle a cancel frame: run the parent logic, then leave the room.

    Args:
        frame: The cancel frame, passed through to the parent class.
    """
    # Let the parent class process the cancellation first.
    await super().cancel(frame)

    # Then leave the room through the client.
    client = self._client
    await client.leave()
|
||||
|
||||
@@ -634,24 +690,27 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
await super().cleanup()
|
||||
await self._client.cleanup()
|
||||
|
||||
def send_message(self, frame: DailyTransportMessageFrame):
|
||||
self._client.send_message(frame)
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
await self._client.send_message(frame)
|
||||
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
metrics = {}
|
||||
if frame.ttfb:
|
||||
metrics["ttfb"] = frame.ttfb
|
||||
if frame.processing:
|
||||
metrics["processing"] = frame.processing
|
||||
|
||||
def send_metrics(self, frame: MetricsFrame):
|
||||
ttfb = [{"name": n, "time": t} for n, t in frame.ttfb.items()]
|
||||
message = DailyTransportMessageFrame(message={
|
||||
"type": "pipecat-metrics",
|
||||
"metrics": {
|
||||
"ttfb": ttfb
|
||||
},
|
||||
"metrics": metrics
|
||||
})
|
||||
self._client.send_message(message)
|
||||
await self._client.send_message(message)
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
self._client.write_raw_audio_frames(frames)
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
await self._client.write_raw_audio_frames(frames)
|
||||
|
||||
def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
self._client.write_frame_to_camera(frame)
|
||||
async def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
await self._client.write_frame_to_camera(frame)
|
||||
|
||||
|
||||
class DailyTransport(BaseTransport):
|
||||
@@ -674,6 +733,7 @@ class DailyTransport(BaseTransport):
|
||||
on_app_message=self._on_app_message,
|
||||
on_call_state_updated=self._on_call_state_updated,
|
||||
on_dialin_ready=self._on_dialin_ready,
|
||||
on_dialout_answered=self._on_dialout_answered,
|
||||
on_dialout_connected=self._on_dialout_connected,
|
||||
on_dialout_stopped=self._on_dialout_stopped,
|
||||
on_dialout_error=self._on_dialout_error,
|
||||
@@ -696,6 +756,7 @@ class DailyTransport(BaseTransport):
|
||||
self._register_event_handler("on_app_message")
|
||||
self._register_event_handler("on_call_state_updated")
|
||||
self._register_event_handler("on_dialin_ready")
|
||||
self._register_event_handler("on_dialout_answered")
|
||||
self._register_event_handler("on_dialout_connected")
|
||||
self._register_event_handler("on_dialout_stopped")
|
||||
self._register_event_handler("on_dialout_error")
|
||||
@@ -768,24 +829,24 @@ class DailyTransport(BaseTransport):
|
||||
self._input.capture_participant_video(
|
||||
participant_id, framerate, video_source, color_format)
|
||||
|
||||
def _on_joined(self, participant):
|
||||
self._call_async_event_handler("on_joined", participant)
|
||||
async def _on_joined(self, participant):
|
||||
await self._call_event_handler("on_joined", participant)
|
||||
|
||||
def _on_left(self):
|
||||
self._call_async_event_handler("on_left")
|
||||
async def _on_left(self):
|
||||
await self._call_event_handler("on_left")
|
||||
|
||||
def _on_error(self, error):
|
||||
async def _on_error(self, error):
|
||||
# TODO(aleix): Report error to input/output transports. The one managing
|
||||
# the client should report the error.
|
||||
pass
|
||||
|
||||
def _on_app_message(self, message: Any, sender: str):
|
||||
async def _on_app_message(self, message: Any, sender: str):
|
||||
if self._input:
|
||||
self._input.push_app_message(message, sender)
|
||||
self._call_async_event_handler("on_app_message", message, sender)
|
||||
await self._input.push_app_message(message, sender)
|
||||
await self._call_event_handler("on_app_message", message, sender)
|
||||
|
||||
def _on_call_state_updated(self, state: str):
|
||||
self._call_async_event_handler("on_call_state_updated", state)
|
||||
async def _on_call_state_updated(self, state: str):
|
||||
await self._call_event_handler("on_call_state_updated", state)
|
||||
|
||||
async def _handle_dialin_ready(self, sip_endpoint: str):
|
||||
if not self._params.dialin_settings:
|
||||
@@ -815,36 +876,39 @@ class DailyTransport(BaseTransport):
|
||||
logger.debug("Event dialin-ready was handled successfully")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout handling dialin-ready event ({url})")
|
||||
except BaseException as e:
|
||||
logger.error(f"Error handling dialin-ready event ({url}): {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling dialin-ready event ({url}): {e}")
|
||||
|
||||
def _on_dialin_ready(self, sip_endpoint):
|
||||
async def _on_dialin_ready(self, sip_endpoint):
|
||||
if self._params.dialin_settings:
|
||||
asyncio.run_coroutine_threadsafe(self._handle_dialin_ready(sip_endpoint), self._loop)
|
||||
self._call_async_event_handler("on_dialin_ready", sip_endpoint)
|
||||
await self._handle_dialin_ready(sip_endpoint)
|
||||
await self._call_event_handler("on_dialin_ready", sip_endpoint)
|
||||
|
||||
def _on_dialout_connected(self, data):
|
||||
self._call_async_event_handler("on_dialout_connected", data)
|
||||
async def _on_dialout_answered(self, data):
|
||||
await self._call_event_handler("on_dialout_answered", data)
|
||||
|
||||
def _on_dialout_stopped(self, data):
|
||||
self._call_async_event_handler("on_dialout_stopped", data)
|
||||
async def _on_dialout_connected(self, data):
|
||||
await self._call_event_handler("on_dialout_connected", data)
|
||||
|
||||
def _on_dialout_error(self, data):
|
||||
self._call_async_event_handler("on_dialout_error", data)
|
||||
async def _on_dialout_stopped(self, data):
|
||||
await self._call_event_handler("on_dialout_stopped", data)
|
||||
|
||||
def _on_dialout_warning(self, data):
|
||||
self._call_async_event_handler("on_dialout_warning", data)
|
||||
async def _on_dialout_error(self, data):
|
||||
await self._call_event_handler("on_dialout_error", data)
|
||||
|
||||
def _on_participant_joined(self, participant):
|
||||
self._call_async_event_handler("on_participant_joined", participant)
|
||||
async def _on_dialout_warning(self, data):
|
||||
await self._call_event_handler("on_dialout_warning", data)
|
||||
|
||||
def _on_participant_left(self, participant, reason):
|
||||
self._call_async_event_handler("on_participant_left", participant, reason)
|
||||
async def _on_participant_joined(self, participant):
|
||||
await self._call_event_handler("on_participant_joined", participant)
|
||||
|
||||
def _on_first_participant_joined(self, participant):
|
||||
self._call_async_event_handler("on_first_participant_joined", participant)
|
||||
async def _on_participant_left(self, participant, reason):
|
||||
await self._call_event_handler("on_participant_left", participant, reason)
|
||||
|
||||
def _on_transcription_message(self, participant_id, message):
|
||||
async def _on_first_participant_joined(self, participant):
|
||||
await self._call_event_handler("on_first_participant_joined", participant)
|
||||
|
||||
async def _on_transcription_message(self, participant_id, message):
|
||||
text = message["text"]
|
||||
timestamp = message["timestamp"]
|
||||
is_final = message["rawResponse"]["is_final"]
|
||||
@@ -855,9 +919,4 @@ class DailyTransport(BaseTransport):
|
||||
frame = InterimTranscriptionFrame(text, participant_id, timestamp)
|
||||
|
||||
if self._input:
|
||||
self._input.push_transcription_frame(frame)
|
||||
|
||||
def _call_async_event_handler(self, event_name: str, *args, **kwargs):
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._call_event_handler(event_name, *args, **kwargs), self._loop)
|
||||
future.result()
|
||||
await self._input.push_transcription_frame(frame)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import audioop
|
||||
import numpy as np
|
||||
import pyloudnorm as pyln
|
||||
|
||||
@@ -31,3 +32,23 @@ def calculate_audio_volume(audio: bytes, sample_rate: int) -> float:
|
||||
|
||||
def exp_smoothing(value: float, prev_value: float, factor: float) -> float:
|
||||
return prev_value + factor * (value - prev_value)
|
||||
|
||||
|
||||
def ulaw_8000_to_pcm_16000(ulaw_8000_bytes):
|
||||
# Convert μ-law to PCM
|
||||
pcm_8000_bytes = audioop.ulaw2lin(ulaw_8000_bytes, 2)
|
||||
|
||||
# Resample from 8000 Hz to 16000 Hz
|
||||
pcm_16000_bytes = audioop.ratecv(pcm_8000_bytes, 2, 1, 8000, 16000, None)[0]
|
||||
|
||||
return pcm_16000_bytes
|
||||
|
||||
|
||||
def pcm_16000_to_ulaw_8000(pcm_16000_bytes):
|
||||
# Resample from 16000 Hz to 8000 Hz
|
||||
pcm_8000_bytes = audioop.ratecv(pcm_16000_bytes, 2, 1, 16000, 8000, None)[0]
|
||||
|
||||
# Convert PCM to μ-law
|
||||
ulaw_8000_bytes = audioop.lin2ulaw(pcm_8000_bytes, 2)
|
||||
|
||||
return ulaw_8000_bytes
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
|
||||
class TestException(BaseException):
|
||||
class TestException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
11
src/pipecat/utils/time.py
Normal file
11
src/pipecat/utils/time.py
Normal file
@@ -0,0 +1,11 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
def time_now_iso8601() -> str:
|
||||
return datetime.datetime.now(datetime.timezone.utc).isoformat(timespec="milliseconds")
|
||||
@@ -33,14 +33,27 @@ _MODEL_RESET_STATES_TIME = 5.0
|
||||
|
||||
class SileroVADAnalyzer(VADAnalyzer):
|
||||
|
||||
def __init__(self, sample_rate=16000, params: VADParams = VADParams()):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sample_rate: int = 16000,
|
||||
version: str = "v5.0",
|
||||
force_reload: bool = False,
|
||||
skip_validation: bool = True,
|
||||
trust_repo: bool = True,
|
||||
params: VADParams = VADParams()):
|
||||
super().__init__(sample_rate=sample_rate, num_channels=1, params=params)
|
||||
|
||||
if sample_rate != 16000 and sample_rate != 8000:
|
||||
raise ValueError("Silero VAD sample rate needs to be 16000 or 8000")
|
||||
|
||||
logger.debug("Loading Silero VAD model...")
|
||||
|
||||
(self._model, utils) = torch.hub.load(
|
||||
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
|
||||
)
|
||||
(self._model, _) = torch.hub.load(repo_or_dir=f"snakers4/silero-vad:{version}",
|
||||
model="silero_vad",
|
||||
force_reload=force_reload,
|
||||
skip_validation=skip_validation,
|
||||
trust_repo=trust_repo)
|
||||
|
||||
self._last_reset_time = 0
|
||||
|
||||
@@ -51,7 +64,7 @@ class SileroVADAnalyzer(VADAnalyzer):
|
||||
#
|
||||
|
||||
def num_frames_required(self) -> int:
|
||||
return int(self.sample_rate / 100) * 4 # 40ms
|
||||
return 512 if self.sample_rate == 16000 else 256
|
||||
|
||||
def voice_confidence(self, buffer) -> float:
|
||||
try:
|
||||
@@ -69,9 +82,9 @@ class SileroVADAnalyzer(VADAnalyzer):
|
||||
self._last_reset_time = curr_time
|
||||
|
||||
return new_confidence
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
# This comes from an empty audio array
|
||||
logger.error(f"Error analyzing audio with Silero VAD: {e}")
|
||||
logger.exception(f"Error analyzing audio with Silero VAD: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
@@ -79,12 +92,23 @@ class SileroVAD(FrameProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sample_rate: int = 16000,
|
||||
version: str = "v5.0",
|
||||
force_reload: bool = False,
|
||||
skip_validation: bool = True,
|
||||
trust_repo: bool = True,
|
||||
vad_params: VADParams = VADParams(),
|
||||
audio_passthrough: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self._vad_analyzer = SileroVADAnalyzer(sample_rate=sample_rate, params=vad_params)
|
||||
self._vad_analyzer = SileroVADAnalyzer(
|
||||
sample_rate=sample_rate,
|
||||
version=version,
|
||||
force_reload=force_reload,
|
||||
skip_validation=skip_validation,
|
||||
trust_repo=trust_repo,
|
||||
params=vad_params)
|
||||
self._audio_passthrough = audio_passthrough
|
||||
|
||||
self._processor_vad_state: VADState = VADState.QUIET
|
||||
|
||||
@@ -28,7 +28,7 @@ class VADParams(BaseModel):
|
||||
|
||||
class VADAnalyzer:
|
||||
|
||||
def __init__(self, sample_rate: int, num_channels: int, params: VADParams):
|
||||
def __init__(self, *, sample_rate: int, num_channels: int, params: VADParams):
|
||||
self._sample_rate = sample_rate
|
||||
self._num_channels = num_channels
|
||||
self._params = params
|
||||
|
||||
@@ -8,8 +8,6 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import (
|
||||
LLMFullResponseStartFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMResponseEndFrame,
|
||||
LLMResponseStartFrame,
|
||||
TextFrame
|
||||
)
|
||||
from pipecat.utils.test_frame_processor import TestFrameProcessor
|
||||
@@ -64,7 +62,7 @@ if __name__ == "__main__":
|
||||
llm.register_function("get_current_weather", get_weather_from_api)
|
||||
t = TestFrameProcessor([
|
||||
LLMFullResponseStartFrame,
|
||||
[LLMResponseStartFrame, TextFrame, LLMResponseEndFrame],
|
||||
TextFrame,
|
||||
LLMFullResponseEndFrame
|
||||
])
|
||||
llm.link(t)
|
||||
@@ -98,7 +96,7 @@ if __name__ == "__main__":
|
||||
llm.register_function("get_current_weather", get_weather_from_api)
|
||||
t = TestFrameProcessor([
|
||||
LLMFullResponseStartFrame,
|
||||
[LLMResponseStartFrame, TextFrame, LLMResponseEndFrame],
|
||||
TextFrame,
|
||||
LLMFullResponseEndFrame
|
||||
])
|
||||
llm.link(t)
|
||||
@@ -121,7 +119,7 @@ if __name__ == "__main__":
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
t = TestFrameProcessor([
|
||||
LLMFullResponseStartFrame,
|
||||
[LLMResponseStartFrame, TextFrame, LLMResponseEndFrame],
|
||||
TextFrame,
|
||||
LLMFullResponseEndFrame
|
||||
])
|
||||
llm = OpenAILLMService(
|
||||
|
||||
@@ -2,8 +2,8 @@ import unittest
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.services.ai_services import AIService
|
||||
from pipecat.pipeline.frames import EndFrame, Frame, TextFrame
|
||||
from pipecat.services.ai_services import AIService, match_endofsentence
|
||||
from pipecat.frames.frames import EndFrame, Frame, TextFrame
|
||||
|
||||
|
||||
class SimpleAIService(AIService):
|
||||
@@ -27,6 +27,22 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
self.assertEqual(input_frames, output_frames)
|
||||
|
||||
async def test_endofsentence(self):
|
||||
assert match_endofsentence("This is a sentence.")
|
||||
assert match_endofsentence("This is a sentence! ")
|
||||
assert match_endofsentence("This is a sentence?")
|
||||
assert match_endofsentence("This is a sentence:")
|
||||
assert not match_endofsentence("This is not a sentence")
|
||||
assert not match_endofsentence("This is not a sentence,")
|
||||
assert not match_endofsentence("This is not a sentence, ")
|
||||
assert not match_endofsentence("Ok, Mr. Smith let's ")
|
||||
assert not match_endofsentence("Dr. Walker, I presume ")
|
||||
assert not match_endofsentence("Prof. Walker, I presume ")
|
||||
assert not match_endofsentence("zweitens, und 3.")
|
||||
assert not match_endofsentence("Heute ist Dienstag, der 3.") # 3. Juli 2024
|
||||
assert not match_endofsentence("America, or the U.") # U.S.A.
|
||||
assert not match_endofsentence("It still early, it's 3:00 a.") # 3:00 a.m.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user