Compare commits
202 Commits
v0.0.83
...
cb/frame-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05863ded53 | ||
|
|
6ab4a48d8f | ||
|
|
89e0092159 | ||
|
|
0de31dab79 | ||
|
|
10ff93307d | ||
|
|
414c9e3bc8 | ||
|
|
ba64f126a3 | ||
|
|
d28e3881a7 | ||
|
|
7df7395dd1 | ||
|
|
0885bc9cdf | ||
|
|
0204f6a95d | ||
|
|
b0bf653f04 | ||
|
|
e8a676eb36 | ||
|
|
ca96eef1f3 | ||
|
|
8e1637d6c7 | ||
|
|
367200c0ad | ||
|
|
766e1948a6 | ||
|
|
f369683b8b | ||
|
|
461025d1cc | ||
|
|
ac88706f38 | ||
|
|
93a89449b8 | ||
|
|
199bf72945 | ||
|
|
d20e4125f6 | ||
|
|
c1baed642e | ||
|
|
33ef68573f | ||
|
|
3c1b41df13 | ||
|
|
fca4ecc73c | ||
|
|
cfa333508b | ||
|
|
9e7260393a | ||
|
|
073b585c52 | ||
|
|
81c2e51bec | ||
|
|
42344125b1 | ||
|
|
db5bcfaa51 | ||
|
|
615239b7d2 | ||
|
|
27f1e9dd69 | ||
|
|
bd760deff2 | ||
|
|
8bc3c89140 | ||
|
|
2cd2567a37 | ||
|
|
5b55988846 | ||
|
|
a12392182c | ||
|
|
b814b70e1e | ||
|
|
a1f84e1b50 | ||
|
|
0839b48da8 | ||
|
|
de51637b77 | ||
|
|
e1b1dc16ec | ||
|
|
1fe27eb0a2 | ||
|
|
d7e1389497 | ||
|
|
8c7230aa8f | ||
|
|
2cf71239b0 | ||
|
|
ec2c62e32b | ||
|
|
38ce85e9a0 | ||
|
|
2279e5a899 | ||
|
|
cce6eb5d87 | ||
|
|
c2b98ae557 | ||
|
|
727eb12b16 | ||
|
|
ba96bd05d3 | ||
|
|
8ead309f8d | ||
|
|
fad0e55c64 | ||
|
|
74b1af56a0 | ||
|
|
6924850ec4 | ||
|
|
dfe7815dc5 | ||
|
|
69f0a75882 | ||
|
|
cca90791c4 | ||
|
|
f2a5d408de | ||
|
|
044c6eba46 | ||
|
|
db71089f5e | ||
|
|
f861f5066f | ||
|
|
81cede2c60 | ||
|
|
7603203230 | ||
|
|
8569b61598 | ||
|
|
fe42187dc1 | ||
|
|
999e88c942 | ||
|
|
c04df2f28b | ||
|
|
100ef0ab5c | ||
|
|
42886d7105 | ||
|
|
22cbba002a | ||
|
|
0a043154f2 | ||
|
|
5e322eba9e | ||
|
|
11d0c3d46d | ||
|
|
c873798ce5 | ||
|
|
95f72f6dce | ||
|
|
d8cd28bb8b | ||
|
|
c2df6c8aee | ||
|
|
82478be861 | ||
|
|
0f2b7bc01b | ||
|
|
1b2a5df017 | ||
|
|
2f496ac74f | ||
|
|
22633a63b0 | ||
|
|
e5ed0424e4 | ||
|
|
786387722a | ||
|
|
9f82c6b4a4 | ||
|
|
99cfcb1d4e | ||
|
|
d595676436 | ||
|
|
0190812ee8 | ||
|
|
2a24061bbb | ||
|
|
89f7e7d199 | ||
|
|
384814e640 | ||
|
|
ab4364b833 | ||
|
|
fafdadad3c | ||
|
|
05dc2fa916 | ||
|
|
0c30cc6ea6 | ||
|
|
c26d336e34 | ||
|
|
37b6198787 | ||
|
|
3c271da94c | ||
|
|
be28d3f93b | ||
|
|
d2f210e960 | ||
|
|
57add41971 | ||
|
|
74b38b59d6 | ||
|
|
dac58deffc | ||
|
|
aff11f5121 | ||
|
|
a4023d3915 | ||
|
|
d6543d244d | ||
|
|
fafcd79870 | ||
|
|
6a717fbbd1 | ||
|
|
9b3f6927c2 | ||
|
|
0b21f8a6bd | ||
|
|
8249b014f0 | ||
|
|
9d9f10ae0e | ||
|
|
e27b23694d | ||
|
|
66ce5fe6bd | ||
|
|
a9b53dc800 | ||
|
|
818352a300 | ||
|
|
3e9fc7be19 | ||
|
|
a2e76bcad8 | ||
|
|
8e8e42717b | ||
|
|
b31322e38e | ||
|
|
908325484d | ||
|
|
dd6ff789c7 | ||
|
|
f4938e0fad | ||
|
|
e8f60c7c6f | ||
|
|
fedb8a201f | ||
|
|
8ccd220a60 | ||
|
|
fe79de8f27 | ||
|
|
176573c342 | ||
|
|
75f9914f49 | ||
|
|
f4d6715e32 | ||
|
|
38f6e33f97 | ||
|
|
1c3e4e34e5 | ||
|
|
623c660027 | ||
|
|
a3e65ab3b5 | ||
|
|
f3a4b416df | ||
|
|
aa471a4ef5 | ||
|
|
d55133a44f | ||
|
|
0f1cf81691 | ||
|
|
ac4d335799 | ||
|
|
e65385c151 | ||
|
|
0bb7df7a6b | ||
|
|
daee1ddf3b | ||
|
|
1cccb97ccf | ||
|
|
d7794abf21 | ||
|
|
6a6a63a532 | ||
|
|
6edb6fed41 | ||
|
|
a537382816 | ||
|
|
46deaada70 | ||
|
|
7366b1aee0 | ||
|
|
dbc52bc6b0 | ||
|
|
d6432589f6 | ||
|
|
13b73d4406 | ||
|
|
85d8282f7e | ||
|
|
070690ec64 | ||
|
|
b9c96fd623 | ||
|
|
f8b2ab6331 | ||
|
|
ea3f7e3c34 | ||
|
|
2f44f88b08 | ||
|
|
25747a001b | ||
|
|
fbe4338440 | ||
|
|
64b4c65728 | ||
|
|
29442969a9 | ||
|
|
dc2e1d4ad3 | ||
|
|
5477dfcbea | ||
|
|
516f0e08ab | ||
|
|
246f9f3325 | ||
|
|
4699ee8d86 | ||
|
|
3d850e8cc5 | ||
|
|
6e734a37f9 | ||
|
|
f72ca2fd7d | ||
|
|
0826d72f74 | ||
|
|
ba5ebfa0ec | ||
|
|
dc3412b2df | ||
|
|
b2e9fd9341 | ||
|
|
c11b207c97 | ||
|
|
d6205027cf | ||
|
|
986160c077 | ||
|
|
b56ff86fee | ||
|
|
5c574eaad9 | ||
|
|
2df231143a | ||
|
|
e3597801d4 | ||
|
|
65298ab792 | ||
|
|
b609b02614 | ||
|
|
f2b50c14d2 | ||
|
|
ee3b023986 | ||
|
|
0d9e1190d7 | ||
|
|
595a7c7fbe | ||
|
|
586586f743 | ||
|
|
a1c6ad539d | ||
|
|
daf7fed8b3 | ||
|
|
a26647c433 | ||
|
|
83f64ecd3b | ||
|
|
0a3e98857e | ||
|
|
2ee481d541 | ||
|
|
48b3ad8f8f | ||
|
|
8bbdc7c8d1 |
233
CHANGELOG.md
233
CHANGELOG.md
@@ -5,14 +5,239 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added `on_before_disconnect` synchronous event to `DailyTransport` and
|
||||
`LiveKitTransport`.
|
||||
|
||||
- It is now possible to register synchronous event handlers. By default, all
|
||||
event handlers are executed in a separate task. However, in some cases we want
|
||||
to guarantee order of execution, for example, executing something before
|
||||
disconnecting a transport.
|
||||
|
||||
```python
|
||||
self._register_event_handler("on_event_name", sync=True)
|
||||
```
|
||||
|
||||
- Added support for global location in `GoogleVertexLLMService`. The service now
|
||||
supports both regional locations (e.g., "us-east4") and the "global" location
|
||||
for Vertex AI endpoints. When using "global" location, the service will use
|
||||
`aiplatform.googleapis.com` as the API host instead of the regional format.
|
||||
|
||||
- Added `on_pipeline_finished` event to `PipelineTask`. This event will get
|
||||
fired when the pipeline is done running. This can be the result of a
|
||||
`StopFrame`, `CancelFrame` or `EndFrame`.
|
||||
|
||||
```python
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task: PipelineTask, frame: Frame):
|
||||
...
|
||||
```
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated Silero VAD model to v6.
|
||||
|
||||
- Updated `livekit` to 1.0.13.
|
||||
|
||||
- `torch` and `torchaudio` are no longer required for running Smart Turn
|
||||
locally. This avoids gigabytes of dependencies being installed.
|
||||
|
||||
- Updated `websockets` dependency to support version 15.0. Removed deprecated
|
||||
usage of `ConnectionClosed.code` and `ConnectionClosed.reason` attributes in
|
||||
`AWSTranscribeSTTService` for compatibility.
|
||||
|
||||
- Refactored `pyproject.toml` to reduce websockets dependency repetition using
|
||||
self-referencing extras. All websockets-dependent services now reference a
|
||||
shared `websockets-base` extra.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `GladiaSTTService`'s `confidence` arg is deprecated. `confidence` is no
|
||||
longer needed to determine which transcription or translation frames to
|
||||
emit.
|
||||
|
||||
- `PipelineTask` events `on_pipeline_stopped`, `on_pipeline_ended` and
|
||||
`on_pipeline_cancelled` are now deprecated. Use `on_pipeline_finished`
|
||||
instead.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where multiple handlers for an event would not run in parallel.
|
||||
|
||||
- Fixed `DailyTransport.sip_call_transfer()` to automatically use the session
|
||||
ID from the `on_dialin_connected` event, when not explicitly provided. Now
|
||||
supports cold transfers (from incoming dial-in calls) by automatically
|
||||
tracking session IDs from connection events.
|
||||
|
||||
- Fixed a memory leak in `SmallWebRTCTransport`. In `aiortc`, when you receive
|
||||
a `MediaStreamTrack` (audio or video), frames are produced asynchronously. If
|
||||
the code never consumes these frames, they are queued in memory, causing a
|
||||
memory leak.
|
||||
|
||||
- Fixed an issue in `AsyncAITTSService`, where `TTSTextFrames` were not being
|
||||
pushed.
|
||||
|
||||
- Fixed an issue that would cause `push_interruption_task_frame_and_wait()` to
|
||||
not wait if a previous interruption had already happened.
|
||||
|
||||
- Fixed a couple of bugs in `ServiceSwitcher`:
|
||||
|
||||
- Using multiple `ServiceSwitcher`s in a pipeline would result in an error.
|
||||
- `ServiceSwitcherFrame`s (such as `ManuallySwitchServiceFrame`s) were having
|
||||
an effect too early, essentially "jumping the queue" in terms of pipeline
|
||||
frame ordering.
|
||||
|
||||
- Fixed a self-cancellation deadlock in `UserIdleProcessor` when returning
|
||||
`False` from an idle callback. The task now terminates naturally instead of
|
||||
attempting to cancel itself.
|
||||
|
||||
- Fixed an issue in `AudioBufferProcessor` where a recording is not created
|
||||
when a bot speaks and user input is blocked.
|
||||
|
||||
- Fixed a `FastAPIWebsocketTransport` and `SmallWebRTCTransport` issue where
|
||||
`on_client_disconnected` would be triggered when the bot ends the
|
||||
conversation. That is, `on_client_disconnected` should only be triggered when
|
||||
the remote client actually disconnects.
|
||||
|
||||
- Fixed an issue in `HeyGenVideoService` where the `BotStartedSpeakingFrame`
|
||||
was blocked from moving through the Pipeline.
|
||||
|
||||
## [0.0.85] - 2025-09-12
|
||||
|
||||
### Added
|
||||
|
||||
- `AzureSTTService` now pushes interim transcriptions.
|
||||
|
||||
- Added `voice_cloning_key` to `GoogleTTSService` to support custom cloned
|
||||
voices.
|
||||
|
||||
- Added `speaking_rate` to `GoogleTTSService.InputParams` to control the
|
||||
speaking rate.
|
||||
|
||||
- Added a `speed` arg to `OpenAITTSService` to control the speed of the voice
|
||||
response.
|
||||
|
||||
- Added `FrameProcessor.push_interruption_task_frame_and_wait()`. Use this
|
||||
method to programatically interrupt the bot from any part of the
|
||||
pipeline. This guarantees that all the processors in the pipeline are
|
||||
interrupted in order (from upstream to downstream). Internally, this works by
|
||||
first pushing an `InterruptionTaskFrame` upstream until it reaches the
|
||||
pipeline task. The pipeline task then generates an `InterruptionFrame`, which
|
||||
flows downstream through all processors. Once the `InterruptionFrame` has
|
||||
reaches the processor waiting for the interruption, the function returns and
|
||||
execution continues after the call. Think of it as sending an upstream request
|
||||
for interruption and waiting until the acknowledgment flows back downstream.
|
||||
|
||||
- Added new base `TaskFrame` (which is a system frame). This is the base class
|
||||
for all task frames (`EndTaskFrame`, `CancelTaskFrame`, etc.) that are meant
|
||||
to be pushed upstream to reach the pipeline task.
|
||||
|
||||
- Expanded support for universal `LLMContext` to the AWS Bedrock LLM service.
|
||||
Using the universal `LLMContext` and associated `LLMContextAggregatorPair` is
|
||||
a pre-requisite for using `LLMSwitcher` to switch between LLMs at runtime.
|
||||
|
||||
- Added new fields to the development runner's `parse_telephony_websocket`
|
||||
method in support of providing dynamic data to a bot.
|
||||
|
||||
- Twilio: Added a new `body` parameter, which parses the websocket message
|
||||
for `customParameters`. Provide data via the `Parameter` nouns in your
|
||||
TwiML to use this feature.
|
||||
- Telnyx & Exotel: Both providers make the `to` and `from` phone numbers
|
||||
available in the websocket messages. You can now access these numbers as
|
||||
`call_data["to"]` and `call_data["from"]`.
|
||||
|
||||
Note: Each telephony provider offers different features. Refer to the
|
||||
corresponding example in `pipecat-examples` to see how to pass custom data
|
||||
to your bot.
|
||||
|
||||
- Added `body` to the `WebsocketRunnerArguments` as an optional parameter.
|
||||
Custom `body` information can be passed from the server into the bot file via
|
||||
the `bot()` method using this new parameter.
|
||||
|
||||
- Added video streaming support to `LiveKitTransport`.
|
||||
|
||||
- Added `OpenAIRealtimeLLMService` and `AzureRealtimeLLMService` which provide
|
||||
access to OpenAI Realtime.
|
||||
|
||||
### Changed
|
||||
|
||||
- `pipeline.tests.utils.run_test()` now allows passing `PipelineParams` instead
|
||||
of individual parameters.
|
||||
|
||||
### Removed
|
||||
|
||||
- Remove `VisionImageRawFrame` in favor of context frames (`LLMContextFrame` or
|
||||
`OpenAILLMContextFrame`).
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `BotInterruptionFrame` is now deprecated, use `InterruptionTaskFrame` instead.
|
||||
|
||||
- `StartInterruptionFrame` is now deprected, use `InterruptionFrame` instead.
|
||||
|
||||
- Deprecate `VisionImageFrameAggregator` because `VisionImageRawFrame` has been
|
||||
removed. See the `12*` examples for the new recommended replacement pattern.
|
||||
|
||||
- `NoisereduceFilter` is now deprecated and will be removed in a future
|
||||
version. Use other audio filters like `KrispFilter` or `AICFilter`.
|
||||
|
||||
- Deprecated `OpenAIRealtimeBetaLLMService` and `AzureRealtimeBetaLLMService`.
|
||||
Use `OpenAIRealtimeLLMService` and `AzureRealtimeLLMService`, respectively.
|
||||
Each service will be removed in an upcoming version, 1.0.0.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `BaseOutputTransport` issue that caused incorrect detection of when
|
||||
the bot stopped talking while using an audio mixer.
|
||||
|
||||
- Fixed a `LiveKitTransport` issue where RTVI messages were not properly
|
||||
encoded.
|
||||
|
||||
- Add additional fixups to Mistral context messages to ensure they meet
|
||||
Mistral-specific requirements, avoiding Mistral "invalid request" errors.
|
||||
|
||||
- Fixed `DailyTransport` transcription handling to gracefully handle missing
|
||||
`rawResponse` field in transcription messages, preventing KeyError crashes.
|
||||
|
||||
## [0.0.84] - 2025-09-05
|
||||
|
||||
### Added
|
||||
|
||||
- Add the ability to send DTMF to `LiveKitTransport`.
|
||||
|
||||
- Expanded support for universal `LLMContext` to the Anthropic LLM service.
|
||||
Using the universal `LLMContext` and associated `LLMContextAggregatorPair` is
|
||||
a pre-requisite for using `LLMSwitcher` to switch between LLMs at runtime.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `daily-python` to 0.19.9.
|
||||
|
||||
- Restored `DailyTransport`'s native DTMF support using Daily's `send_dtmf()`
|
||||
method instead of generated audio tones.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a `AWSBedrockLLMService` crash caused by an extra `await`.
|
||||
|
||||
- Fixed a `OpenAIImageGenService` issue where it was not creating
|
||||
`URLImageRawFrame` correctly.
|
||||
|
||||
## [0.0.83] - 2025-09-03
|
||||
|
||||
### Added
|
||||
|
||||
- Added multilingual support for AsyncAI in `AsyncAITTSService` and `AsyncAIHttpTTSService`.
|
||||
|
||||
- New `languages`: `es`, `fr`, `de`, `it`.
|
||||
|
||||
- Added new frames `InputTransportMessageUrgentFrame` and
|
||||
`DailyInputTransportMessageUrgentFrame` for transport messages received from
|
||||
external sources.
|
||||
|
||||
|
||||
- Added `UserSpeakingFrame`. This will be sent upstream and downstream while VAD
|
||||
detects the user is speaking.
|
||||
|
||||
@@ -82,7 +307,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Added new config parameters to `GladiaSTTService`.
|
||||
- PreProcessingConfig > `audio_enhancer` to enhance audio quality.
|
||||
- CustomVocabularyItem > `pronunciations` and `language` to specify special pronunciations and in which language it will be pronounced.
|
||||
- CustomVocabularyItem > `pronunciations` and `language` to specify special
|
||||
pronunciations and in which language it will be pronounced.
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -101,7 +327,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- `pipecat.frames.frames.KeypadEntry` is deprecated and has been moved to
|
||||
`pipecat.audio.dtmf.types.KeypadEntry`.
|
||||
|
||||
- Updated `RimeTTSService`'s flush_audio message to conform with Rime's official API.
|
||||
- Updated `RimeTTSService`'s flush_audio message to conform with Rime's official
|
||||
API.
|
||||
|
||||
- Updated the default model for `CerebrasLLMService` to GPT-OSS-120B.
|
||||
|
||||
|
||||
73
README.md
73
README.md
@@ -21,6 +21,8 @@
|
||||
|
||||
🧭 Looking to build structured conversations? Check out [Pipecat Flows](https://github.com/pipecat-ai/pipecat-flows) for managing complex conversational states and transitions.
|
||||
|
||||
🔍 Looking for help debugging your pipeline and processors? Check out [Whisker](https://github.com/pipecat-ai/whisker), a real-time Pipecat debugger.
|
||||
|
||||
## 🧠 Why Pipecat?
|
||||
|
||||
- **Voice-first**: Integrates speech recognition, text-to-speech, and conversation handling
|
||||
@@ -28,6 +30,41 @@
|
||||
- **Composable Pipelines**: Build complex behavior from modular components
|
||||
- **Real-Time**: Ultra-low latency interaction with different transports (e.g. WebSockets or WebRTC)
|
||||
|
||||
## 📱 Client SDKs
|
||||
|
||||
You can connect to Pipecat from any platform using our official SDKs:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<img src="https://cdn.jsdelivr.net/gh/devicons/devicon/icons/javascript/javascript-original.svg" width="40" height="40" alt="JavaScript"/>
|
||||
<a href="https://docs.pipecat.ai/client/js/introduction">JavaScript</a>
|
||||
</td>
|
||||
<td>
|
||||
<img src="https://cdn.jsdelivr.net/gh/devicons/devicon/icons/react/react-original.svg" width="40" height="40" alt="React"/>
|
||||
<a href="https://docs.pipecat.ai/client/react/introduction">React</a>
|
||||
</td>
|
||||
<td>
|
||||
<img src="https://cdn.jsdelivr.net/gh/devicons/devicon/icons/react/react-original.svg" width="40" height="40" alt="React Native"/>
|
||||
<a href="https://docs.pipecat.ai/client/react-native/introduction">React Native</a>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src="https://cdn.jsdelivr.net/gh/devicons/devicon/icons/swift/swift-original.svg" width="40" height="40" alt="Swift"/>
|
||||
<a href="https://docs.pipecat.ai/client/ios/introduction">Swift</a>
|
||||
</td>
|
||||
<td>
|
||||
<img src="https://cdn.jsdelivr.net/gh/devicons/devicon/icons/kotlin/kotlin-original.svg" width="40" height="40" alt="Kotlin"/>
|
||||
<a href="https://docs.pipecat.ai/client/android/introduction">Kotlin</a>
|
||||
</td>
|
||||
<td>
|
||||
<img src="https://cdn.jsdelivr.net/gh/devicons/devicon/icons/cplusplus/cplusplus-original.svg" width="40" height="40" alt="JavaScript"/>
|
||||
<a href="https://docs.pipecat.ai/client/c++/introduction">C++</a>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 🎬 See it in action
|
||||
|
||||
<p float="left">
|
||||
@@ -38,17 +75,6 @@
|
||||
<a href="https://github.com/pipecat-ai/pipecat-examples/tree/main/moondream-chatbot"><img src="https://raw.githubusercontent.com/pipecat-ai/pipecat-examples/main/moondream-chatbot/image.png" width="400" /></a>
|
||||
</p>
|
||||
|
||||
## 📱 Client SDKs
|
||||
|
||||
You can connect to Pipecat from any platform using our official SDKs:
|
||||
|
||||
| Platform | SDK Repo | Description |
|
||||
| -------- | ------------------------------------------------------------------------------ | -------------------------------- |
|
||||
| Web | [pipecat-client-web](https://github.com/pipecat-ai/pipecat-client-web) | JavaScript and React client SDKs |
|
||||
| iOS | [pipecat-client-ios](https://github.com/pipecat-ai/pipecat-client-ios) | Swift SDK for iOS |
|
||||
| Android | [pipecat-client-android](https://github.com/pipecat-ai/pipecat-client-android) | Kotlin SDK for Android |
|
||||
| C++ | [pipecat-client-cxx](https://github.com/pipecat-ai/pipecat-client-cxx) | C++ client SDK |
|
||||
|
||||
## 🧩 Available services
|
||||
|
||||
| Category | Services |
|
||||
@@ -62,7 +88,7 @@ You can connect to Pipecat from any platform using our official SDKs:
|
||||
| Video | [HeyGen](https://docs.pipecat.ai/server/services/video/heygen), [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) |
|
||||
| Memory | [mem0](https://docs.pipecat.ai/server/services/memory/mem0) |
|
||||
| Vision & Image | [fal](https://docs.pipecat.ai/server/services/image-generation/fal), [Google Imagen](https://docs.pipecat.ai/server/services/image-generation/fal), [Moondream](https://docs.pipecat.ai/server/services/vision/moondream) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [ai-coustics](https://docs.pipecat.ai/server/utilities/audio/aic-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) |
|
||||
| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter), [ai-coustics](https://docs.pipecat.ai/server/utilities/audio/aic-filter) |
|
||||
| Analytics & Metrics | [OpenTelemetry](https://docs.pipecat.ai/server/utilities/opentelemetry), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) |
|
||||
|
||||
📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)
|
||||
@@ -129,7 +155,11 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
2. Install development and testing dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --group dev --all-extras --no-extra gstreamer --no-extra krisp --no-extra local
|
||||
uv sync --group dev --all-extras \
|
||||
--no-extra gstreamer \
|
||||
--no-extra krisp \
|
||||
--no-extra local \
|
||||
--no-extra ultravox # (ultravox not fully supported on macOS)
|
||||
```
|
||||
|
||||
3. Install the git pre-commit hooks:
|
||||
@@ -138,23 +168,6 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
### Python 3.13+ Compatibility
|
||||
|
||||
Some features require PyTorch, which doesn't yet support Python 3.13+. Install using:
|
||||
|
||||
```bash
|
||||
uv sync --group dev --all-extras \
|
||||
--no-extra gstreamer \
|
||||
--no-extra krisp \
|
||||
--no-extra local \
|
||||
--no-extra local-smart-turn \
|
||||
--no-extra mlx-whisper \
|
||||
--no-extra moondream \
|
||||
--no-extra ultravox
|
||||
```
|
||||
|
||||
> **Tip:** For full compatibility, use Python 3.12: `uv python pin 3.12`
|
||||
|
||||
> **Note**: Some extras (local, gstreamer) require system dependencies. See documentation if you encounter build errors.
|
||||
|
||||
### Running tests
|
||||
|
||||
@@ -11,7 +11,7 @@ import sys
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.frames.frames import TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
@@ -50,7 +50,7 @@ async def main():
|
||||
async def on_first_participant_joined(transport, participant_id):
|
||||
await asyncio.sleep(1)
|
||||
await task.queue_frame(
|
||||
TextFrame(
|
||||
TTSSpeakFrame(
|
||||
"Hello there! How are you doing today? Would you like to talk about the weather?"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -115,7 +115,7 @@ async def main():
|
||||
|
||||
await task.queue_frames(
|
||||
[
|
||||
BotInterruptionFrame(),
|
||||
InterruptionFrame(),
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(
|
||||
user_id=participant_id,
|
||||
|
||||
@@ -36,7 +36,6 @@ load_dotenv(override=True)
|
||||
audiobuffer = AudioBufferProcessor(
|
||||
num_channels=2, # 1 for mono, 2 for stereo (user left, bot right)
|
||||
enable_turn_audio=False, # Enable per-turn audio recording
|
||||
user_continuous_stream=True, # User has continuous audio stream
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
StartInterruptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -97,7 +97,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
@stt.event_handler("on_speech_started")
|
||||
async def on_speech_started(stt, *args, **kwargs):
|
||||
await task.queue_frames([StartInterruptionFrame(), UserStartedSpeakingFrame()])
|
||||
await task.queue_frames([InterruptionFrame(), UserStartedSpeakingFrame()])
|
||||
|
||||
@stt.event_handler("on_utterance_end")
|
||||
async def on_utterance_end(stt, *args, **kwargs):
|
||||
|
||||
@@ -16,10 +16,10 @@ from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -93,9 +93,8 @@ class UserAudioCollector(FrameProcessor):
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
self._user_speaking = False
|
||||
self._context.add_audio_frames_message(audio_frames=self._audio_frames)
|
||||
await self._user_context_aggregator.push_frame(
|
||||
self._user_context_aggregator.get_context_frame()
|
||||
)
|
||||
await self._user_context_aggregator.push_frame(LLMRunFrame())
|
||||
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
if self._user_speaking:
|
||||
self._audio_frames.append(frame)
|
||||
@@ -151,7 +150,7 @@ class TranscriptExtractor(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class TanscriptionContextFixup(FrameProcessor):
|
||||
class TranscriptionContextFixup(FrameProcessor):
|
||||
def __init__(self, context):
|
||||
super().__init__()
|
||||
self._context = context
|
||||
@@ -182,9 +181,7 @@ class TanscriptionContextFixup(FrameProcessor):
|
||||
|
||||
if isinstance(frame, MagicDemoTranscriptionFrame):
|
||||
self._transcript = frame.text
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(
|
||||
frame, StartInterruptionFrame
|
||||
):
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, InterruptionFrame):
|
||||
self.swap_user_audio()
|
||||
self.add_transcript_back_to_inference_output()
|
||||
self._transcript = ""
|
||||
@@ -245,7 +242,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
audio_collector = UserAudioCollector(context, context_aggregator.user())
|
||||
pull_transcript_out_of_llm_output = TranscriptExtractor(context)
|
||||
fixup_context_messages = TanscriptionContextFixup(context)
|
||||
fixup_context_messages = TranscriptionContextFixup(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
|
||||
@@ -11,12 +11,19 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame, TextFrame, TTSSpeakFrame, UserImageRequestFrame
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.user_response import UserResponseAggregator
|
||||
from pipecat.processors.aggregators.vision_image_frame import VisionImageFrameAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
@@ -34,6 +41,8 @@ load_dotenv(override=True)
|
||||
|
||||
|
||||
class UserImageRequester(FrameProcessor):
|
||||
"""Converts incoming text into requests for user images."""
|
||||
|
||||
def __init__(self, participant_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self._participant_id = participant_id
|
||||
@@ -46,9 +55,32 @@ class UserImageRequester(FrameProcessor):
|
||||
|
||||
if self._participant_id and isinstance(frame, TextFrame):
|
||||
await self.push_frame(
|
||||
UserImageRequestFrame(self._participant_id), FrameDirection.UPSTREAM
|
||||
UserImageRequestFrame(self._participant_id, context=frame.text),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class UserImageProcessor(FrameProcessor):
|
||||
"""Converts incoming user images into context frames."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserImageRawFrame):
|
||||
if frame.request and frame.request.context:
|
||||
context = LLMContext()
|
||||
context.add_image_frame_message(
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
)
|
||||
frame = LLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
@@ -78,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Initialize the image requester without setting the participant ID yet
|
||||
image_requester = UserImageRequester()
|
||||
|
||||
vision_aggregator = VisionImageFrameAggregator()
|
||||
image_processor = UserImageProcessor()
|
||||
|
||||
# If you run into weird description, try with use_cpu=True
|
||||
moondream = MoondreamService()
|
||||
@@ -96,7 +128,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt,
|
||||
user_response,
|
||||
image_requester,
|
||||
vision_aggregator,
|
||||
image_processor,
|
||||
moondream,
|
||||
tts,
|
||||
transport.output(),
|
||||
@@ -119,7 +151,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
image_requester.set_participant_id(client_id)
|
||||
|
||||
# Welcome message
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me what I see."))
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me about what I see."))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -11,12 +11,19 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame, TextFrame, TTSSpeakFrame, UserImageRequestFrame
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.user_response import UserResponseAggregator
|
||||
from pipecat.processors.aggregators.vision_image_frame import VisionImageFrameAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
@@ -34,6 +41,8 @@ load_dotenv(override=True)
|
||||
|
||||
|
||||
class UserImageRequester(FrameProcessor):
|
||||
"""Converts incoming text into requests for user images."""
|
||||
|
||||
def __init__(self, participant_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self._participant_id = participant_id
|
||||
@@ -46,9 +55,32 @@ class UserImageRequester(FrameProcessor):
|
||||
|
||||
if self._participant_id and isinstance(frame, TextFrame):
|
||||
await self.push_frame(
|
||||
UserImageRequestFrame(self._participant_id), FrameDirection.UPSTREAM
|
||||
UserImageRequestFrame(self._participant_id, context=frame.text),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class UserImageProcessor(FrameProcessor):
|
||||
"""Converts incoming user images into context frames."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserImageRawFrame):
|
||||
if frame.request and frame.request.context:
|
||||
context = LLMContext()
|
||||
context.add_image_frame_message(
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
)
|
||||
frame = LLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
@@ -78,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Initialize the image requester without setting the participant ID yet
|
||||
image_requester = UserImageRequester()
|
||||
|
||||
vision_aggregator = VisionImageFrameAggregator()
|
||||
image_processor = UserImageProcessor()
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
@@ -96,7 +128,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt,
|
||||
user_response,
|
||||
image_requester,
|
||||
vision_aggregator,
|
||||
image_processor,
|
||||
google,
|
||||
tts,
|
||||
transport.output(),
|
||||
@@ -123,7 +155,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
image_requester.set_participant_id(client_id)
|
||||
|
||||
# Welcome message
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me what I see."))
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me about what I see."))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -11,12 +11,19 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame, TextFrame, TTSSpeakFrame, UserImageRequestFrame
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.user_response import UserResponseAggregator
|
||||
from pipecat.processors.aggregators.vision_image_frame import VisionImageFrameAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
@@ -34,6 +41,8 @@ load_dotenv(override=True)
|
||||
|
||||
|
||||
class UserImageRequester(FrameProcessor):
|
||||
"""Converts incoming text into requests for user images."""
|
||||
|
||||
def __init__(self, participant_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self._participant_id = participant_id
|
||||
@@ -46,9 +55,32 @@ class UserImageRequester(FrameProcessor):
|
||||
|
||||
if self._participant_id and isinstance(frame, TextFrame):
|
||||
await self.push_frame(
|
||||
UserImageRequestFrame(self._participant_id), FrameDirection.UPSTREAM
|
||||
UserImageRequestFrame(self._participant_id, context=frame.text),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class UserImageProcessor(FrameProcessor):
|
||||
"""Converts incoming user images into context frames."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserImageRawFrame):
|
||||
if frame.request and frame.request.context:
|
||||
context = LLMContext()
|
||||
context.add_image_frame_message(
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
)
|
||||
frame = LLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
@@ -78,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Initialize the image requester without setting the participant ID yet
|
||||
image_requester = UserImageRequester()
|
||||
|
||||
vision_aggregator = VisionImageFrameAggregator()
|
||||
image_processor = UserImageProcessor()
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
@@ -96,7 +128,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt,
|
||||
user_response,
|
||||
image_requester,
|
||||
vision_aggregator,
|
||||
image_processor,
|
||||
openai,
|
||||
tts,
|
||||
transport.output(),
|
||||
@@ -123,7 +155,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
image_requester.set_participant_id(client_id)
|
||||
|
||||
# Welcome message
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me what I see."))
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me about what I see."))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
@@ -11,12 +11,19 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame, TextFrame, TTSSpeakFrame, UserImageRequestFrame
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.user_response import UserResponseAggregator
|
||||
from pipecat.processors.aggregators.vision_image_frame import VisionImageFrameAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
@@ -34,6 +41,8 @@ load_dotenv(override=True)
|
||||
|
||||
|
||||
class UserImageRequester(FrameProcessor):
|
||||
"""Converts incoming text into requests for user images."""
|
||||
|
||||
def __init__(self, participant_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self._participant_id = participant_id
|
||||
@@ -46,9 +55,32 @@ class UserImageRequester(FrameProcessor):
|
||||
|
||||
if self._participant_id and isinstance(frame, TextFrame):
|
||||
await self.push_frame(
|
||||
UserImageRequestFrame(self._participant_id), FrameDirection.UPSTREAM
|
||||
UserImageRequestFrame(self._participant_id, context=frame.text),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class UserImageProcessor(FrameProcessor):
|
||||
"""Converts incoming user images into context frames."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserImageRawFrame):
|
||||
if frame.request and frame.request.context:
|
||||
context = LLMContext()
|
||||
context.add_image_frame_message(
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
)
|
||||
frame = LLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
@@ -78,7 +110,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Initialize the image requester without setting the participant ID yet
|
||||
image_requester = UserImageRequester()
|
||||
|
||||
vision_aggregator = VisionImageFrameAggregator()
|
||||
image_processor = UserImageProcessor()
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
@@ -96,7 +128,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt,
|
||||
user_response,
|
||||
image_requester,
|
||||
vision_aggregator,
|
||||
image_processor,
|
||||
anthropic,
|
||||
tts,
|
||||
transport.output(),
|
||||
@@ -123,7 +155,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
image_requester.set_participant_id(client_id)
|
||||
|
||||
# Welcome message
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me what I see."))
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me about what I see."))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
187
examples/foundational/12d-describe-video-aws.py
Normal file
187
examples/foundational/12d-describe-video-aws.py
Normal file
@@ -0,0 +1,187 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
TextFrame,
|
||||
TTSSpeakFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.user_response import UserResponseAggregator
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
create_transport,
|
||||
get_transport_client_id,
|
||||
maybe_capture_participant_camera,
|
||||
)
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class UserImageRequester(FrameProcessor):
|
||||
"""Converts incoming text into requests for user images."""
|
||||
|
||||
def __init__(self, participant_id: Optional[str] = None):
|
||||
super().__init__()
|
||||
self._participant_id = participant_id
|
||||
|
||||
def set_participant_id(self, participant_id: str):
|
||||
self._participant_id = participant_id
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if self._participant_id and isinstance(frame, TextFrame):
|
||||
await self.push_frame(
|
||||
UserImageRequestFrame(self._participant_id, context=frame.text),
|
||||
FrameDirection.UPSTREAM,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class UserImageProcessor(FrameProcessor):
|
||||
"""Converts incoming user images into context frames."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserImageRawFrame):
|
||||
if frame.request and frame.request.context:
|
||||
# Note: AWS Bedrock does not yet support the universal LLMContext
|
||||
context = LLMContext()
|
||||
context.add_image_frame_message(
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
)
|
||||
frame = LLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
user_response = UserResponseAggregator()
|
||||
|
||||
# Initialize the image requester without setting the participant ID yet
|
||||
image_requester = UserImageRequester()
|
||||
|
||||
image_processor = UserImageProcessor()
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
# AWS for vision analysis
|
||||
aws = AWSBedrockLLMService(
|
||||
aws_region="us-west-2",
|
||||
model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
# Note: usually, prefer providing latency="optimized" param.
|
||||
# Here we can't because AWS Bedrock doesn't support it for Claude 3.7,
|
||||
# which we need for image input.
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
user_response,
|
||||
image_requester,
|
||||
image_processor,
|
||||
aws,
|
||||
tts,
|
||||
transport.output(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
|
||||
await maybe_capture_participant_camera(transport, client)
|
||||
|
||||
# Set the participant ID in the image requester
|
||||
client_id = get_transport_client_id(transport, client)
|
||||
image_requester.set_participant_id(client_id)
|
||||
|
||||
# Welcome message
|
||||
await task.queue_frame(TTSSpeakFrame("Hi there! Feel free to ask me about what I see."))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -31,6 +31,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -32,6 +32,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
async def main():
|
||||
transport = LocalAudioTransport(
|
||||
|
||||
@@ -31,6 +31,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -31,6 +31,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -40,6 +40,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
elif isinstance(frame, TranslationFrame):
|
||||
print(f"Translation ({frame.language}): {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -31,6 +31,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -52,6 +52,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
self._last_transcription_time = time.time()
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -31,6 +31,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -53,6 +53,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
self._last_transcription_time = time.time()
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -32,6 +32,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
|
||||
@@ -32,6 +32,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
|
||||
@@ -32,6 +32,9 @@ class TranscriptionLogger(FrameProcessor):
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
create_transport,
|
||||
get_transport_client_id,
|
||||
maybe_capture_participant_camera,
|
||||
)
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# Global variable to store the client ID
|
||||
client_id = ""
|
||||
|
||||
|
||||
async def get_weather(params: FunctionCallParams):
|
||||
location = params.arguments["location"]
|
||||
await params.result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def get_image(params: FunctionCallParams):
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={client_id}, question={question}")
|
||||
|
||||
# Request the image frame
|
||||
await params.llm.request_image_frame(
|
||||
user_id=client_id,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
text_content=question,
|
||||
)
|
||||
|
||||
# Wait a short time for the frame to be processed
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Return a result to complete the function call
|
||||
await params.result_callback(
|
||||
f"I've captured an image from your camera and I'm analyzing what you asked about: {question}"
|
||||
)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = AWSBedrockLLMService(
|
||||
aws_region="us-west-2",
|
||||
model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
# Note: usually, prefer providing latency="optimized" param.
|
||||
# Here we can't because AWS Bedrock doesn't support it for Claude 3.7,
|
||||
# which we need for image input.
|
||||
params=AWSBedrockLLMService.InputParams(temperature=0.8),
|
||||
)
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
get_image_function = FunctionSchema(
|
||||
name="get_image",
|
||||
description="Get an image from the video stream.",
|
||||
properties={
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question that the user is asking about the image.",
|
||||
}
|
||||
},
|
||||
required=["question"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function, get_image_function])
|
||||
|
||||
system_prompt = """\
|
||||
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
|
||||
|
||||
Your response will be turned into speech so use only simple words and punctuation.
|
||||
|
||||
You have access to two tools: get_weather and get_image.
|
||||
|
||||
You can respond to questions about the weather using the get_weather tool.
|
||||
|
||||
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
|
||||
indicate you should use the get_image tool are:
|
||||
- What do you see?
|
||||
- What's in the video?
|
||||
- Can you describe the video?
|
||||
- Tell me about what you see.
|
||||
- Tell me something interesting about what you see.
|
||||
- What's happening in the video?
|
||||
|
||||
If you need to use a tool, simply use the tool. Do not tell the user the tool you are using. Be brief and concise.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "Start the conversation by introducing yourself."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User speech to text
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
|
||||
await maybe_capture_participant_camera(transport, client)
|
||||
|
||||
global client_id
|
||||
client_id = get_transport_client_id(transport, client)
|
||||
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -97,7 +97,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
model="claude-3-7-sonnet-latest",
|
||||
enable_prompt_caching_beta=True,
|
||||
params=AnthropicLLMService.InputParams(enable_prompt_caching=True),
|
||||
)
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import (
|
||||
create_transport,
|
||||
get_transport_client_id,
|
||||
maybe_capture_participant_camera,
|
||||
)
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# Global variable to store the client ID
|
||||
client_id = ""
|
||||
|
||||
|
||||
async def get_weather(params: FunctionCallParams):
|
||||
location = params.arguments["location"]
|
||||
await params.result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def get_image(params: FunctionCallParams):
|
||||
question = params.arguments["question"]
|
||||
logger.debug(f"Requesting image with user_id={client_id}, question={question}")
|
||||
|
||||
# Request the image frame
|
||||
await params.llm.request_image_frame(
|
||||
user_id=client_id,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
text_content=question,
|
||||
)
|
||||
|
||||
# Wait a short time for the frame to be processed
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Return a result to complete the function call
|
||||
await params.result_callback(
|
||||
f"I've captured an image from your camera and I'm analyzing what you asked about: {question}"
|
||||
)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
model="claude-3-7-sonnet-latest",
|
||||
params=AnthropicLLMService.InputParams(enable_prompt_caching=True),
|
||||
)
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
get_image_function = FunctionSchema(
|
||||
name="get_image",
|
||||
description="Get an image from the video stream.",
|
||||
properties={
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question that the user is asking about the image.",
|
||||
}
|
||||
},
|
||||
required=["question"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function, get_image_function])
|
||||
|
||||
system_prompt = """\
|
||||
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
|
||||
|
||||
Your response will be turned into speech so use only simple words and punctuation.
|
||||
|
||||
You have access to two tools: get_weather and get_image.
|
||||
|
||||
You can respond to questions about the weather using the get_weather tool.
|
||||
|
||||
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
|
||||
indicate you should use the get_image tool are:
|
||||
- What do you see?
|
||||
- What's in the video?
|
||||
- Can you describe the video?
|
||||
- Tell me about what you see.
|
||||
- Tell me something interesting about what you see.
|
||||
- What's happening in the video?
|
||||
|
||||
If you need to use a tool, simply use the tool. Do not tell the user the tool you are using. Be brief and concise.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "Start the conversation by introducing yourself."},
|
||||
]
|
||||
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User speech to text
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
|
||||
await maybe_capture_participant_camera(transport, client)
|
||||
|
||||
global client_id
|
||||
client_id = get_transport_client_id(transport, client)
|
||||
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
228
examples/foundational/19-openai-realtime.py
Normal file
228
examples/foundational/19-openai-realtime.py
Normal file
@@ -0,0 +1,228 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, TranscriptionMessage
|
||||
from pipecat.observers.loggers.transcription_log_observer import TranscriptionLogObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime import (
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeLLMService,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai_realtime.events import AudioConfiguration, AudioInput
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
session_properties = SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=SemanticTurnDetection(),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
noise_reduction=InputAudioNoiseReduction(type="near_field"),
|
||||
)
|
||||
),
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
You have access to the following tools:
|
||||
- get_current_weather: Get the current weather for a given location.
|
||||
- get_restaurant_recommendation: Get a restaurant recommendation for a given location.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually. Respond in English.""",
|
||||
)
|
||||
|
||||
llm = OpenAIRealtimeLLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
start_audio_paused=False,
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
transcript = TranscriptProcessor()
|
||||
|
||||
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||
# OpenAIRealtimeLLMService will convert this internally to messages that the
|
||||
# openai WebSocket API can understand.
|
||||
context = OpenAILLMContext(
|
||||
[{"role": "user", "content": "Say hello!"}],
|
||||
tools,
|
||||
)
|
||||
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
transcript.user(), # Placed after the LLM, as LLM pushes TranscriptionFrames downstream
|
||||
transport.output(), # Transport bot output
|
||||
transcript.assistant(), # After the transcript output, to time with the audio output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
observers=[TranscriptionLogObserver()],
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
# Register event handler for transcript updates
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
for msg in frame.messages:
|
||||
if isinstance(msg, TranscriptionMessage):
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
line = f"{timestamp}{msg.role}: {msg.content}"
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
221
examples/foundational/19a-azure-realtime.py
Normal file
221
examples/foundational/19a-azure-realtime.py
Normal file
@@ -0,0 +1,221 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime import (
|
||||
AzureRealtimeLLMService,
|
||||
InputAudioTranscription,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai_realtime.events import AudioConfiguration, AudioInput
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
# Define weather function using standardized schema
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
session_properties = SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(model="whisper-1"),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
# turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
)
|
||||
),
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
-
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
You have access to the following tools:
|
||||
- get_current_weather: Get the current weather for a given location.
|
||||
- get_restaurant_recommendation: Get a restaurant recommendation for a given location.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually. Respond in English.""",
|
||||
)
|
||||
|
||||
llm = AzureRealtimeLLMService(
|
||||
api_key=os.getenv("AZURE_REALTIME_API_KEY"),
|
||||
base_url=os.getenv("AZURE_REALTIME_BASE_URL"),
|
||||
session_properties=session_properties,
|
||||
start_audio_paused=False,
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||
# OpenAIRealtimeBetaLLMService will convert this internally to messages that the
|
||||
# openai WebSocket API can understand.
|
||||
context = OpenAILLMContext(
|
||||
[{"role": "user", "content": "Say hello!"}],
|
||||
# [{"role": "user", "content": [{"type": "text", "text": "Say hello!"}]}],
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type": "text", "text": "Say"},
|
||||
# {"type": "text", "text": "yo what's up!"},
|
||||
# ],
|
||||
# }
|
||||
# ],
|
||||
tools,
|
||||
)
|
||||
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -22,7 +22,7 @@ from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioNoiseReduction,
|
||||
|
||||
234
examples/foundational/19b-openai-realtime-text.py
Normal file
234
examples/foundational/19b-openai-realtime-text.py
Normal file
@@ -0,0 +1,234 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, TranscriptionMessage
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime import (
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeLLMService,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai_realtime.events import AudioConfiguration, AudioInput
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
session_properties = SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=SemanticTurnDetection(),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
noise_reduction=InputAudioNoiseReduction(type="near_field"),
|
||||
)
|
||||
),
|
||||
output_modalities=["text"],
|
||||
# tools=tools,
|
||||
instructions="""You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
You have access to the following tools:
|
||||
- get_current_weather: Get the current weather for a given location.
|
||||
- get_restaurant_recommendation: Get a restaurant recommendation for a given location.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually. Respond in English.""",
|
||||
)
|
||||
|
||||
llm = OpenAIRealtimeLLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
start_audio_paused=False,
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
transcript = TranscriptProcessor()
|
||||
|
||||
# Create a standard OpenAI LLM context object using the normal messages format. The
|
||||
# OpenAIRealtimeLLMService will convert this internally to messages that the
|
||||
# openai WebSocket API can understand.
|
||||
context = OpenAILLMContext(
|
||||
[{"role": "user", "content": "Say hello!"}],
|
||||
tools,
|
||||
)
|
||||
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transcript.user(), # Placed after the LLM, as LLM pushes TranscriptionFrames downstream
|
||||
transport.output(), # Transport bot output
|
||||
transcript.assistant(), # After the transcript output, to time with the audio output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
# Register event handler for transcript updates
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
for msg in frame.messages:
|
||||
if isinstance(msg, TranscriptionMessage):
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
line = f"{timestamp}{msg.role}: {msg.content}"
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -0,0 +1,274 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeBetaLLMService,
|
||||
SessionProperties,
|
||||
TurnDetection,
|
||||
)
|
||||
from pipecat.services.openai_realtime_beta.events import AudioConfiguration, AudioInput
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
BASE_FILENAME = "/tmp/pipecat_conversation_"
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
temperature = 75 if params.arguments["format"] == "fahrenheit" else 24
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_saved_conversation_filenames(params: FunctionCallParams):
|
||||
# Construct the full pattern including the BASE_FILENAME
|
||||
full_pattern = f"{BASE_FILENAME}*.json"
|
||||
|
||||
# Use glob to find all matching files
|
||||
matching_files = glob.glob(full_pattern)
|
||||
logger.debug(f"matching files: {matching_files}")
|
||||
|
||||
await params.result_callback({"filenames": matching_files})
|
||||
|
||||
|
||||
async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.messages, indent=4)}"
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
await params.result_callback({"success": True})
|
||||
except Exception as e:
|
||||
await params.result_callback({"success": False, "error": str(e)})
|
||||
|
||||
|
||||
async def load_conversation(params: FunctionCallParams):
|
||||
async def _reset():
|
||||
filename = params.arguments["filename"]
|
||||
logger.debug(f"loading conversation from {filename}")
|
||||
try:
|
||||
with open(filename, "r") as file:
|
||||
params.context.set_messages(json.load(file))
|
||||
await params.llm.reset_conversation()
|
||||
await params.llm._create_response()
|
||||
except Exception as e:
|
||||
await params.result_callback({"success": False, "error": str(e)})
|
||||
|
||||
asyncio.create_task(_reset())
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "save_conversation",
|
||||
"description": "Save the current conversatione. Use this function to persist the current conversation to external storage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_saved_conversation_filenames",
|
||||
"description": "Get a list of saved conversation histories. Returns a list of filenames. Each filename includes a date and timestamp. Each file is conversation history that can be loaded into this session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "load_conversation",
|
||||
"description": "Load a conversation history. Use this function to load a conversation history into the current session.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The filename of the conversation history to load.",
|
||||
}
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
session_properties = SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
)
|
||||
),
|
||||
# tools=tools,
|
||||
instructions="""Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
|
||||
|
||||
Act like a human, but remember that you aren't a human and that you can't do human
|
||||
things in the real world. Your voice and personality should be warm and engaging, with a lively and
|
||||
playful tone.
|
||||
|
||||
If interacting in a non-English language, start by using the standard accent or dialect familiar to
|
||||
the user. Talk quickly. You should always call a function if you can. Do not refer to these rules,
|
||||
even if you're asked about them.
|
||||
-
|
||||
You are participating in a voice conversation. Keep your responses concise, short, and to the point
|
||||
unless specifically asked to elaborate on a topic.
|
||||
|
||||
Remember, your responses should be short. Just one or two sentences, usually.""",
|
||||
)
|
||||
|
||||
llm = OpenAIRealtimeBetaLLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
start_audio_paused=False,
|
||||
)
|
||||
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("save_conversation", save_conversation)
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext([], tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(),
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -25,12 +25,13 @@ from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai_realtime_beta import (
|
||||
from pipecat.services.openai_realtime import (
|
||||
InputAudioTranscription,
|
||||
OpenAIRealtimeBetaLLMService,
|
||||
OpenAIRealtimeLLMService,
|
||||
SessionProperties,
|
||||
TurnDetection,
|
||||
)
|
||||
from pipecat.services.openai_realtime.events import AudioConfiguration, AudioInput
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -182,12 +183,16 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
session_properties = SessionProperties(
|
||||
input_audio_transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(),
|
||||
# Set openai TurnDetection parameters. Not setting this at all will turn it
|
||||
# on by default
|
||||
turn_detection=TurnDetection(silence_duration_ms=1000),
|
||||
# Or set to False to disable openai turn detection and use transport VAD
|
||||
# turn_detection=False,
|
||||
)
|
||||
),
|
||||
# tools=tools,
|
||||
instructions="""Your knowledge cutoff is 2023-10. You are a helpful and friendly AI.
|
||||
|
||||
@@ -205,7 +210,7 @@ unless specifically asked to elaborate on a topic.
|
||||
Remember, your responses should be short. Just one or two sentences, usually.""",
|
||||
)
|
||||
|
||||
llm = OpenAIRealtimeBetaLLMService(
|
||||
llm = OpenAIRealtimeLLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
session_properties=session_properties,
|
||||
start_audio_paused=False,
|
||||
|
||||
@@ -18,9 +18,9 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -144,7 +144,7 @@ class OutputGate(FrameProcessor):
|
||||
await self._start()
|
||||
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._frames_buffer = []
|
||||
self.close_gate()
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -232,7 +232,7 @@ class TurnDetectionLLM(Pipeline):
|
||||
async def pass_only_llm_trigger_frames(frame):
|
||||
return (
|
||||
isinstance(frame, OpenAILLMContextFrame)
|
||||
or isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, InterruptionFrame)
|
||||
or isinstance(frame, FunctionCallInProgressFrame)
|
||||
or isinstance(frame, FunctionCallResultFrame)
|
||||
)
|
||||
|
||||
@@ -18,9 +18,9 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -347,7 +347,7 @@ class OutputGate(FrameProcessor):
|
||||
await self._start()
|
||||
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._frames_buffer = []
|
||||
self.close_gate()
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -426,7 +426,7 @@ class TurnDetectionLLM(Pipeline):
|
||||
async def pass_only_llm_trigger_frames(frame):
|
||||
return (
|
||||
isinstance(frame, OpenAILLMContextFrame)
|
||||
or isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, InterruptionFrame)
|
||||
or isinstance(frame, FunctionCallInProgressFrame)
|
||||
or isinstance(frame, FunctionCallResultFrame)
|
||||
)
|
||||
|
||||
@@ -20,10 +20,10 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -570,7 +570,7 @@ class OutputGate(FrameProcessor):
|
||||
await self._start()
|
||||
if isinstance(frame, (EndFrame, CancelFrame)):
|
||||
await self._stop()
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._frames_buffer = []
|
||||
self.close_gate()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -15,8 +15,8 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
InterruptionFrame,
|
||||
LLMRunFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
)
|
||||
@@ -48,7 +48,7 @@ class CustomObserver(BaseObserver):
|
||||
"""Observer to log interruptions and bot speaking events to the console.
|
||||
|
||||
Logs all frame instances of:
|
||||
- StartInterruptionFrame
|
||||
- InterruptionFrame
|
||||
- BotStartedSpeakingFrame
|
||||
- BotStoppedSpeakingFrame
|
||||
|
||||
@@ -69,7 +69,7 @@ class CustomObserver(BaseObserver):
|
||||
# Create direction arrow
|
||||
arrow = "→" if direction == FrameDirection.DOWNSTREAM else "←"
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame) and isinstance(src, BaseOutputTransport):
|
||||
if isinstance(frame, InterruptionFrame) and isinstance(src, BaseOutputTransport):
|
||||
logger.info(f"⚡ INTERRUPTION START: {src} {arrow} {dst} at {time_sec:.2f}s")
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
logger.info(f"🤖 BOT START SPEAKING: {src} {arrow} {dst} at {time_sec:.2f}s")
|
||||
|
||||
@@ -11,7 +11,7 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
@@ -30,23 +30,6 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# To use this locally, set the environment variable LOCAL_SMART_TURN_MODEL_PATH
|
||||
# to the path where the smart-turn repo is cloned.
|
||||
#
|
||||
# Example setup:
|
||||
#
|
||||
# # Git LFS (Large File Storage)
|
||||
# brew install git-lfs
|
||||
# # Hugging Face uses LFS to store large model files, including .mlpackage
|
||||
# git lfs install
|
||||
# # Clone the repo with the smart_turn_classifier.mlpackage
|
||||
# git clone https://huggingface.co/pipecat-ai/smart-turn-v2
|
||||
#
|
||||
# Then set the env variable:
|
||||
# export LOCAL_SMART_TURN_MODEL_PATH=./smart-turn
|
||||
# or add it to your .env file
|
||||
smart_turn_model_path = os.getenv("LOCAL_SMART_TURN_MODEL_PATH")
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -55,25 +38,19 @@ transport_params = {
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -21,9 +21,10 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.heygen.api import AvatarQuality, NewSessionRequest
|
||||
from pipecat.services.heygen.video import HeyGenVideoService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.daily.transport import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -38,6 +39,7 @@ transport_params = {
|
||||
video_out_is_live=True,
|
||||
video_out_width=1280,
|
||||
video_out_height=720,
|
||||
video_out_bitrate=2_000_000, # 2MBps
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
@@ -64,7 +66,13 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
llm = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"))
|
||||
|
||||
heyGen = HeyGenVideoService(api_key=os.getenv("HEYGEN_API_KEY"), session=session)
|
||||
heyGen = HeyGenVideoService(
|
||||
api_key=os.getenv("HEYGEN_API_KEY"),
|
||||
session=session,
|
||||
session_request=NewSessionRequest(
|
||||
avatar_id="Shawn_Therapist_public", version="v2", quality=AvatarQuality.high
|
||||
),
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -101,6 +109,18 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Updating publishing settings to enable adaptive bitrate
|
||||
if isinstance(transport, DailyTransport):
|
||||
await transport.update_publishing(
|
||||
publishing_settings={
|
||||
"camera": {
|
||||
"sendSettings": {
|
||||
"allowAdaptiveLayers": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Kick off the conversation.
|
||||
messages.append(
|
||||
{
|
||||
|
||||
@@ -4,7 +4,7 @@ version = "0.1.0"
|
||||
description = "Quickstart example for building voice AI bots with Pipecat"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"pipecat-ai[webrtc,daily,silero,deepgram,openai,cartesia,runner]>=0.0.82",
|
||||
"pipecat-ai[webrtc,daily,silero,deepgram,openai,cartesia,runner]>=0.0.83",
|
||||
"pipecatcloud>=0.2.4"
|
||||
]
|
||||
|
||||
|
||||
3158
examples/quickstart/uv.lock
generated
3158
examples/quickstart/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -47,32 +47,32 @@ Website = "https://pipecat.ai"
|
||||
[project.optional-dependencies]
|
||||
aic = [ "aic-sdk~=1.0.1" ]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "websockets>=13.1,<15.0" ]
|
||||
asyncai = [ "websockets>=13.1,<15.0" ]
|
||||
aws = [ "aioboto3~=15.0.0", "websockets>=13.1,<15.0" ]
|
||||
assemblyai = [ "pipecat-ai[websockets-base]" ]
|
||||
asyncai = [ "pipecat-ai[websockets-base]" ]
|
||||
aws = [ "aioboto3~=15.0.0", "pipecat-ai[websockets-base]" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2; python_version>='3.12'" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "websockets>=13.1,<15.0" ]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
cerebras = []
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.19.8" ]
|
||||
daily = [ "daily-python~=0.19.9" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
elevenlabs = [ "websockets>=13.1,<15.0" ]
|
||||
elevenlabs = [ "pipecat-ai[websockets-base]" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets>=13.1,<15.0" ]
|
||||
gladia = [ "websockets>=13.1,<15.0" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "websockets>=13.1,<15.0" ]
|
||||
fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ]
|
||||
gladia = [ "pipecat-ai[websockets-base]" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "pipecat-ai[websockets-base]" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
heygen = [ "livekit>=0.22.0", "websockets>=13.1,<15.0" ]
|
||||
heygen = [ "livekit>=1.0.13", "pipecat-ai[websockets-base]" ]
|
||||
inworld = []
|
||||
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
|
||||
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "websockets>=13.1,<15.0" ]
|
||||
livekit = [ "livekit~=1.0.13", "livekit-api~=1.0.5", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "pipecat-ai[websockets-base]" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
mcp = [ "mcp[cli]~=1.9.4" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
@@ -80,33 +80,35 @@ mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
nim = []
|
||||
neuphonic = [ "websockets>=13.1,<15.0" ]
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "websockets>=13.1,<15.0" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe~=4.50.0" ]
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "websockets>=13.1,<15.0" ]
|
||||
playht = [ "pipecat-ai[websockets-base]" ]
|
||||
qwen = []
|
||||
rime = [ "websockets>=13.1,<15.0" ]
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sambanova = []
|
||||
sarvam = [ "websockets>=13.1,<15.0" ]
|
||||
sarvam = [ "pipecat-ai[websockets-base]" ]
|
||||
sentry = [ "sentry-sdk~=2.23.1" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
|
||||
local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1, <2" ]
|
||||
remote-smart-turn = []
|
||||
silero = [ "onnxruntime~=1.20.1" ]
|
||||
silero = [ "onnxruntime>=1.20.1, <2" ]
|
||||
simli = [ "simli-ai~=0.1.10"]
|
||||
soniox = [ "websockets>=13.1,<15.0" ]
|
||||
soniox = [ "pipecat-ai[websockets-base]" ]
|
||||
soundfile = [ "soundfile~=0.13.0" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.4.0" ]
|
||||
tavus=[]
|
||||
together = []
|
||||
tracing = [ "opentelemetry-sdk>=1.33.0", "opentelemetry-api>=1.33.0", "opentelemetry-instrumentation>=0.54b0" ]
|
||||
ultravox = [ "transformers>=4.48.0", "vllm>=0.9.0" ]
|
||||
webrtc = [ "aiortc~=1.11.0", "opencv-python~=4.11.0.86" ]
|
||||
websocket = [ "websockets>=13.1,<15.0", "fastapi>=0.115.6,<0.117.0" ]
|
||||
webrtc = [ "aiortc~=1.13.0", "opencv-python~=4.11.0.86" ]
|
||||
websocket = [ "pipecat-ai[websockets-base]", "fastapi>=0.115.6,<0.117.0" ]
|
||||
websockets-base = [ "websockets>=13.1,<16.0" ]
|
||||
whisper = [ "faster-whisper~=1.1.1" ]
|
||||
|
||||
[dependency-groups]
|
||||
@@ -154,6 +156,7 @@ where = ["src"]
|
||||
"src/pipecat/audio/dtmf/dtmf-star.wav",
|
||||
]
|
||||
"pipecat.services.aws_nova_sonic" = ["src/pipecat/services/aws_nova_sonic/ready.wav"]
|
||||
"pipecat.audio.turn.smart_turn.data" = ["src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--verbose"
|
||||
|
||||
@@ -47,7 +47,7 @@ from pipecat.transports.daily.transport import DailyParams, DailyTransport
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
|
||||
PIPELINE_IDLE_TIMEOUT_SECS = 60
|
||||
EVAL_TIMEOUT_SECS = 90
|
||||
EVAL_TIMEOUT_SECS = 120
|
||||
|
||||
EvalPrompt = str | Tuple[str, ImageFile]
|
||||
|
||||
@@ -266,8 +266,11 @@ async def run_eval_pipeline(
|
||||
elif isinstance(prompt, tuple):
|
||||
example_prompt, example_image = prompt
|
||||
|
||||
eval_prompt = f"The answer is correct if it's appropriate for the context and matches: {eval}."
|
||||
common_system_prompt = f"Call the eval function with your assessment only if the user answers the question. {eval_prompt}"
|
||||
eval_prompt = f"The answer is correct if it matches: {eval}."
|
||||
common_system_prompt = (
|
||||
"The user might say things other than the answer and that's allowed. "
|
||||
f"You should only call the eval function with your assessment when the user actually answers the question. {eval_prompt}"
|
||||
)
|
||||
if user_speaks_first:
|
||||
system_prompt = f"You are an LLM eval, be extremly brief. You will start the conversation by saying: '{example_prompt}'. {common_system_prompt}"
|
||||
else:
|
||||
|
||||
@@ -135,6 +135,25 @@ TESTS_14 = [
|
||||
("14r-function-calling-aws.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14v-function-calling-openai.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14w-function-calling-mistral.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("14x-function-calling-universal-context.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
(
|
||||
"14y-function-calling-google-universal-context.py",
|
||||
PROMPT_WEATHER,
|
||||
EVAL_WEATHER,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
(
|
||||
"14z-function-calling-anthropic-universal-context.py",
|
||||
PROMPT_WEATHER,
|
||||
EVAL_WEATHER,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
(
|
||||
"14aa-function-calling-aws-universal-context.py",
|
||||
PROMPT_WEATHER,
|
||||
EVAL_WEATHER,
|
||||
BOT_SPEAKS_FIRST,
|
||||
),
|
||||
# Currently not working.
|
||||
# ("14c-function-calling-together.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
# ("14l-function-calling-deepseek.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
@@ -148,6 +167,7 @@ TESTS_15 = [
|
||||
TESTS_19 = [
|
||||
("19-openai-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19a-azure-realtime-beta.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19b-openai-realtime-text.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
("19b-openai-realtime-beta-text.py", PROMPT_WEATHER, EVAL_WEATHER, BOT_SPEAKS_FIRST),
|
||||
]
|
||||
|
||||
|
||||
12
scripts/mem-watch.sh
Executable file
12
scripts/mem-watch.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
PID=$1
|
||||
|
||||
while true; do
|
||||
# Clear the screen
|
||||
clear
|
||||
# Print the header + RSS in GB
|
||||
ps -p "$PID" -o pid,comm,rss | \
|
||||
awk 'NR==1 {print $0, "rss_GB"} NR>1 {printf "%s %s %s %.2f\n", $1,$2,$3,$3/1024/1024}'
|
||||
sleep 1
|
||||
done
|
||||
@@ -16,7 +16,12 @@ from typing import Any, Dict, Generic, List, TypeVar
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
# Should be a TypedDict
|
||||
TLLMInvocationParams = TypeVar("TLLMInvocationParams", bound=dict[str, Any])
|
||||
@@ -38,12 +43,23 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
Subclasses must implement provider-specific conversion logic.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> TLLMInvocationParams:
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for this LLM provider.
|
||||
|
||||
Returns:
|
||||
The identifier string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_invocation_params(self, context: LLMContext, **kwargs) -> TLLMInvocationParams:
|
||||
"""Get provider-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
**kwargs: Additional provider-specific arguments that subclasses can use.
|
||||
|
||||
Returns:
|
||||
Provider-specific parameters for invoking the LLM.
|
||||
@@ -75,6 +91,28 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return LLMSpecificMessage(llm=self.id_for_llm_specific_messages, message=message)
|
||||
|
||||
def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
"""Get messages from the LLM context, including standard and LLM-specific messages.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages including standard and LLM-specific messages.
|
||||
"""
|
||||
return context.get_messages(self.id_for_llm_specific_messages)
|
||||
|
||||
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
|
||||
"""Convert tools from standard format to provider format.
|
||||
|
||||
|
||||
@@ -6,21 +6,33 @@
|
||||
|
||||
"""Anthropic LLM adapter for Pipecat."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
|
||||
from anthropic import NOT_GIVEN, NotGiven
|
||||
from anthropic.types.message_param import MessageParam
|
||||
from anthropic.types.tool_union_param import ToolUnionParam
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMSpecificMessage,
|
||||
LLMStandardMessage,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicLLMInvocationParams(TypedDict):
|
||||
"""Context-based parameters for invoking Anthropic's LLM API.
|
||||
"""Context-based parameters for invoking Anthropic's LLM API."""
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
|
||||
"""
|
||||
|
||||
pass
|
||||
system: str | NotGiven
|
||||
messages: List[MessageParam]
|
||||
tools: List[ToolUnionParam]
|
||||
|
||||
|
||||
class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
@@ -30,33 +42,278 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
to the specific format required by Anthropic's Claude models for function calling.
|
||||
"""
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AnthropicLLMInvocationParams:
|
||||
"""Get Anthropic-specific LLM invocation parameters from a universal LLM context.
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Anthropic."""
|
||||
return "anthropic"
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
|
||||
def get_llm_invocation_params(
|
||||
self, context: LLMContext, enable_prompt_caching: bool
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
"""Get Anthropic-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
enable_prompt_caching: Whether prompt caching should be enabled.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for invoking Anthropic's LLM API.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": (
|
||||
self._with_cache_control_markers(messages.messages)
|
||||
if enable_prompt_caching
|
||||
else messages.messages
|
||||
),
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools) or [],
|
||||
}
|
||||
|
||||
def get_messages_for_logging(self, context) -> List[Dict[str, Any]]:
|
||||
def get_messages_for_logging(self, context: LLMContext) -> List[Dict[str, Any]]:
|
||||
"""Get messages from a universal LLM context in a format ready for logging about Anthropic.
|
||||
|
||||
Removes or truncates sensitive data like image content for safe logging.
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Anthropic.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging about Anthropic.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
for message in messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image":
|
||||
item["source"]["data"] = "..."
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
messages: List[MessageParam]
|
||||
system: str | NotGiven
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, universal_context_messages: List[LLMContextMessage]
|
||||
) -> ConvertedMessages:
|
||||
system = NOT_GIVEN
|
||||
messages = []
|
||||
|
||||
# first, map messages using self._from_universal_context_message(m)
|
||||
try:
|
||||
messages = [self._from_universal_context_message(m) for m in universal_context_messages]
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
|
||||
# See if we should pull the system message out of our messages list.
|
||||
if messages and messages[0]["role"] == "system":
|
||||
if len(messages) == 1:
|
||||
# If we have only have a system message in the list, all we can really do
|
||||
# without introducing too much magic is change the role to "user".
|
||||
messages[0]["role"] = "user"
|
||||
else:
|
||||
# If we have more than one message, we'll pull the system message out of the
|
||||
# list.
|
||||
system = messages[0]["content"]
|
||||
messages.pop(0)
|
||||
|
||||
# Convert any subsequent "system"-role messages to "user"-role
|
||||
# messages, as Anthropic doesn't support system input messages.
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
message["role"] = "user"
|
||||
|
||||
# Merge consecutive messages with the same role.
|
||||
i = 0
|
||||
while i < len(messages) - 1:
|
||||
current_message = messages[i]
|
||||
next_message = messages[i + 1]
|
||||
if current_message["role"] == next_message["role"]:
|
||||
# Convert content to list of dictionaries if it's a string
|
||||
if isinstance(current_message["content"], str):
|
||||
current_message["content"] = [
|
||||
{"type": "text", "text": current_message["content"]}
|
||||
]
|
||||
if isinstance(next_message["content"], str):
|
||||
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
|
||||
# Concatenate the content
|
||||
current_message["content"].extend(next_message["content"])
|
||||
# Remove the next message from the list
|
||||
messages.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Avoid empty content in messages
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str) and message["content"] == "":
|
||||
message["content"] = "(empty)"
|
||||
elif isinstance(message["content"], list) and len(message["content"]) == 0:
|
||||
message["content"] = [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
return self.ConvertedMessages(messages=messages, system=system)
|
||||
|
||||
def _from_universal_context_message(self, message: LLMContextMessage) -> MessageParam:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
return copy.deepcopy(message.message)
|
||||
return self._from_standard_message(message)
|
||||
|
||||
def _from_standard_message(self, message: LLMStandardMessage) -> MessageParam:
|
||||
"""Convert standard universal context message to Anthropic format.
|
||||
|
||||
Handles conversion of text content, tool calls, and tool results.
|
||||
Empty text content is converted to "(empty)".
|
||||
|
||||
Args:
|
||||
message: Message in standard universal context format.
|
||||
|
||||
Returns:
|
||||
Message in Anthropic format.
|
||||
|
||||
Examples:
|
||||
Input standard format::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Output Anthropic format::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "123",
|
||||
"name": "search",
|
||||
"input": {"q": "test"}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
message = copy.deepcopy(message)
|
||||
if message["role"] == "tool":
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
],
|
||||
}
|
||||
if message.get("tool_calls"):
|
||||
tc = message["tool_calls"]
|
||||
ret = {"role": "assistant", "content": []}
|
||||
for tool_call in tc:
|
||||
function = tool_call["function"]
|
||||
arguments = json.loads(function["arguments"])
|
||||
new_tool_use = {
|
||||
"type": "tool_use",
|
||||
"id": tool_call["id"],
|
||||
"name": function["name"],
|
||||
"input": arguments,
|
||||
}
|
||||
ret["content"].append(new_tool_use)
|
||||
return ret
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
# fix empty text
|
||||
if content == "":
|
||||
content = "(empty)"
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
# fix empty text
|
||||
if item["type"] == "text" and item["text"] == "":
|
||||
item["text"] = "(empty)"
|
||||
# handle image_url -> image conversion
|
||||
if item["type"] == "image_url":
|
||||
item["type"] = "image"
|
||||
item["source"] = {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": item["image_url"]["url"].split(",")[1],
|
||||
}
|
||||
del item["image_url"]
|
||||
# In the case where there's a single image in the list (like what
|
||||
# would result from a UserImageRawFrame), ensure that the image
|
||||
# comes before text, as recommended by Anthropic docs
|
||||
# (https://docs.anthropic.com/en/docs/build-with-claude/vision#example-one-image)
|
||||
image_indices = [i for i, item in enumerate(content) if item["type"] == "image"]
|
||||
text_indices = [i for i, item in enumerate(content) if item["type"] == "text"]
|
||||
if len(image_indices) == 1 and text_indices:
|
||||
img_idx = image_indices[0]
|
||||
first_txt_idx = text_indices[0]
|
||||
if img_idx > first_txt_idx:
|
||||
# Move image before the first text
|
||||
image_item = content.pop(img_idx)
|
||||
content.insert(first_txt_idx, image_item)
|
||||
|
||||
return message
|
||||
|
||||
def _with_cache_control_markers(self, messages: List[MessageParam]) -> List[MessageParam]:
|
||||
"""Add cache control markers to messages for prompt caching.
|
||||
|
||||
Args:
|
||||
messages: List of messages in Anthropic format.
|
||||
|
||||
Returns:
|
||||
List of messages with cache control markers added.
|
||||
"""
|
||||
|
||||
def add_cache_control_marker(message: MessageParam):
|
||||
if isinstance(message["content"], str):
|
||||
message["content"] = [{"type": "text", "text": message["content"]}]
|
||||
message["content"][-1]["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
try:
|
||||
# Add cache control markers to the most recent two user messages.
|
||||
# - The marker at the most recent user message tells Anthropic to
|
||||
# cache the prompt up to that point.
|
||||
# - The marker at the second-most-recent user message tells Anthropic
|
||||
# to look up the cached prompt that goes up to that point (the
|
||||
# point that *was* the last user message the previous turn).
|
||||
# If we only added the marker to the last user message, we'd only
|
||||
# ever be adding to the cache, never looking up from it.
|
||||
# Why user messages? We're assuming that we're primarily running
|
||||
# inference as soon as user turns come in. In Anthropic, turns
|
||||
# strictly alternate between user and assistant.
|
||||
|
||||
messages_with_markers = copy.deepcopy(messages)
|
||||
|
||||
# Find the most recent two user messages
|
||||
user_message_indices = []
|
||||
for i in range(len(messages_with_markers) - 1, -1, -1):
|
||||
if messages_with_markers[i]["role"] == "user":
|
||||
user_message_indices.append(i)
|
||||
if len(user_message_indices) == 2:
|
||||
break
|
||||
|
||||
# Add cache control markers to the identified user messages
|
||||
for index in user_message_indices:
|
||||
add_cache_control_marker(messages_with_markers[index])
|
||||
|
||||
return messages_with_markers
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding cache control marker: {e}")
|
||||
return messages_with_markers
|
||||
|
||||
@staticmethod
|
||||
def _to_anthropic_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||
|
||||
@@ -31,6 +31,11 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
specific function-calling format, enabling tool use with Nova Sonic models.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
|
||||
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -6,21 +6,33 @@
|
||||
|
||||
"""AWS Bedrock LLM adapter for Pipecat."""
|
||||
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
import base64
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMContextToolChoice,
|
||||
LLMSpecificMessage,
|
||||
LLMStandardMessage,
|
||||
)
|
||||
|
||||
|
||||
class AWSBedrockLLMInvocationParams(TypedDict):
|
||||
"""Context-based parameters for invoking AWS Bedrock's LLM API.
|
||||
"""Context-based parameters for invoking AWS Bedrock's LLM API."""
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
|
||||
"""
|
||||
|
||||
pass
|
||||
system: Optional[List[dict[str, Any]]] # [{"text": "system message"}]
|
||||
messages: List[dict[str, Any]]
|
||||
tools: List[dict[str, Any]]
|
||||
tool_choice: LLMContextToolChoice
|
||||
|
||||
|
||||
class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
@@ -30,33 +42,244 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
into AWS Bedrock's expected tool format for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Bedrock."""
|
||||
return "aws"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
|
||||
"""Get AWS Bedrock-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages, tools, etc.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Bedrock's LLM API.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system": messages.system,
|
||||
"messages": messages.messages,
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools) or [],
|
||||
# To avoid refactoring in AWSBedrockLLMService, we just pass through tool_choice.
|
||||
# Eventually (when we don't have to maintain the non-LLMContext code path) we should do
|
||||
# the conversion to Bedrock's expected format here rather than in AWSBedrockLLMService.
|
||||
"tool_choice": context.tool_choice,
|
||||
}
|
||||
|
||||
def get_messages_for_logging(self, context) -> List[Dict[str, Any]]:
|
||||
"""Get messages from a universal LLM context in a format ready for logging about AWS Bedrock.
|
||||
|
||||
Removes or truncates sensitive data like image content for safe logging.
|
||||
|
||||
This is a placeholder until support for universal LLMContext machinery is added for Bedrock.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing messages.
|
||||
|
||||
Returns:
|
||||
List of messages in a format ready for logging about AWS Bedrock.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
|
||||
# Get messages in Anthropic's format
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
for message in messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item.get("image"):
|
||||
item["image"]["source"]["bytes"] = "..."
|
||||
messages_for_logging.append(msg)
|
||||
return messages_for_logging
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Anthropic-formatted messages converted from universal context."""
|
||||
|
||||
messages: List[dict[str, Any]]
|
||||
system: Optional[str]
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, universal_context_messages: List[LLMContextMessage]
|
||||
) -> ConvertedMessages:
|
||||
system = None
|
||||
messages = []
|
||||
|
||||
# first, map messages using self._from_universal_context_message(m)
|
||||
try:
|
||||
messages = [self._from_universal_context_message(m) for m in universal_context_messages]
|
||||
except Exception as e:
|
||||
logger.error(f"Error mapping messages: {e}")
|
||||
|
||||
# See if we should pull the system message out of our messages list
|
||||
if messages and messages[0]["role"] == "system":
|
||||
system = messages[0]["content"]
|
||||
messages.pop(0)
|
||||
|
||||
# Convert any subsequent "system"-role messages to "user"-role
|
||||
# messages, as AWS Bedrock doesn't support system input messages.
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
message["role"] = "user"
|
||||
|
||||
# Merge consecutive messages with the same role.
|
||||
i = 0
|
||||
while i < len(messages) - 1:
|
||||
current_message = messages[i]
|
||||
next_message = messages[i + 1]
|
||||
if current_message["role"] == next_message["role"]:
|
||||
# Convert content to list of dictionaries if it's a string
|
||||
if isinstance(current_message["content"], str):
|
||||
current_message["content"] = [
|
||||
{"type": "text", "text": current_message["content"]}
|
||||
]
|
||||
if isinstance(next_message["content"], str):
|
||||
next_message["content"] = [{"type": "text", "text": next_message["content"]}]
|
||||
# Concatenate the content
|
||||
current_message["content"].extend(next_message["content"])
|
||||
# Remove the next message from the list
|
||||
messages.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Avoid empty content in messages
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str) and message["content"] == "":
|
||||
message["content"] = "(empty)"
|
||||
elif isinstance(message["content"], list) and len(message["content"]) == 0:
|
||||
message["content"] = [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
return self.ConvertedMessages(messages=messages, system=system)
|
||||
|
||||
def _from_universal_context_message(self, message: LLMContextMessage) -> dict[str, Any]:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
return copy.deepcopy(message.message)
|
||||
return self._from_standard_message(message)
|
||||
|
||||
def _from_standard_message(self, message: LLMStandardMessage) -> dict[str, Any]:
|
||||
"""Convert standard format message to AWS Bedrock format.
|
||||
|
||||
Handles conversion of text content, tool calls, and tool results.
|
||||
Empty text content is converted to "(empty)".
|
||||
|
||||
Args:
|
||||
message: Message in standard format.
|
||||
|
||||
Returns:
|
||||
Message in AWS Bedrock format.
|
||||
|
||||
Examples:
|
||||
Standard format input::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
AWS Bedrock format output::
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "123",
|
||||
"name": "search",
|
||||
"input": {"q": "test"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
message = copy.deepcopy(message)
|
||||
if message["role"] == "tool":
|
||||
# Try to parse the content as JSON if it looks like JSON
|
||||
try:
|
||||
if message["content"].strip().startswith("{") and message[
|
||||
"content"
|
||||
].strip().endswith("}"):
|
||||
content_json = json.loads(message["content"])
|
||||
tool_result_content = [{"json": content_json}]
|
||||
else:
|
||||
tool_result_content = [{"text": message["content"]}]
|
||||
except:
|
||||
tool_result_content = [{"text": message["content"]}]
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message["tool_call_id"],
|
||||
"content": tool_result_content,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
if message.get("tool_calls"):
|
||||
tc = message["tool_calls"]
|
||||
ret = {"role": "assistant", "content": []}
|
||||
for tool_call in tc:
|
||||
function = tool_call["function"]
|
||||
arguments = json.loads(function["arguments"])
|
||||
new_tool_use = {
|
||||
"toolUse": {
|
||||
"toolUseId": tool_call["id"],
|
||||
"name": function["name"],
|
||||
"input": arguments,
|
||||
}
|
||||
}
|
||||
ret["content"].append(new_tool_use)
|
||||
return ret
|
||||
|
||||
# Handle text content
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
if content == "":
|
||||
return {"role": message["role"], "content": [{"text": "(empty)"}]}
|
||||
else:
|
||||
return {"role": message["role"], "content": [{"text": content}]}
|
||||
elif isinstance(content, list):
|
||||
new_content = []
|
||||
for item in content:
|
||||
# fix empty text
|
||||
if item.get("type", "") == "text":
|
||||
text_content = item["text"] if item["text"] != "" else "(empty)"
|
||||
new_content.append({"text": text_content})
|
||||
# handle image_url -> image conversion
|
||||
if item["type"] == "image_url":
|
||||
new_item = {
|
||||
"image": {
|
||||
"format": "jpeg",
|
||||
"source": {
|
||||
"bytes": base64.b64decode(item["image_url"]["url"].split(",")[1])
|
||||
},
|
||||
}
|
||||
}
|
||||
new_content.append(new_item)
|
||||
# In the case where there's a single image in the list (like what
|
||||
# would result from a UserImageRawFrame), ensure that the image
|
||||
# comes before text
|
||||
image_indices = [i for i, item in enumerate(new_content) if "image" in item]
|
||||
text_indices = [i for i, item in enumerate(new_content) if "text" in item]
|
||||
if len(image_indices) == 1 and text_indices:
|
||||
img_idx = image_indices[0]
|
||||
first_txt_idx = text_indices[0]
|
||||
if img_idx > first_txt_idx:
|
||||
# Move image before the first text
|
||||
image_item = new_content.pop(img_idx)
|
||||
new_content.insert(first_txt_idx, image_item)
|
||||
return {"role": message["role"], "content": new_content}
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||
|
||||
@@ -54,6 +54,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging with Gemini.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for Google."""
|
||||
return "google"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> GeminiLLMInvocationParams:
|
||||
"""Get Gemini-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -63,11 +68,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for Gemini's API.
|
||||
"""
|
||||
messages = self._from_universal_context_messages(self._get_messages(context))
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
}
|
||||
|
||||
@@ -103,7 +108,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
List of messages in a format ready for logging about Gemini.
|
||||
"""
|
||||
# Get messages in Gemini's format
|
||||
messages = self._from_universal_context_messages(self._get_messages(context)).messages
|
||||
messages = self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
# Sanitize messages for logging
|
||||
messages_for_logging = []
|
||||
@@ -119,9 +124,6 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
messages_for_logging.append(obj)
|
||||
return messages_for_logging
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("google")
|
||||
|
||||
@dataclass
|
||||
class ConvertedMessages:
|
||||
"""Container for Google-formatted messages converted from universal context."""
|
||||
@@ -192,14 +194,14 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
def _from_standard_message(
|
||||
self, message: LLMStandardMessage, already_have_system_instruction: bool
|
||||
) -> Content | str:
|
||||
"""Convert universal context message to Google Content object.
|
||||
"""Convert standard universal context message to Google Content object.
|
||||
|
||||
Handles conversion of text, images, and function calls to Google's
|
||||
format.
|
||||
System instructions are returned as a plain string.
|
||||
|
||||
Args:
|
||||
message: Message in universal context format.
|
||||
message: Message in standard universal context format.
|
||||
already_have_system_instruction: Whether we already have a system instruction
|
||||
|
||||
Returns:
|
||||
@@ -308,5 +310,4 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
audio_bytes = base64.b64decode(input_audio["data"])
|
||||
parts.append(Part(inline_data=Blob(mime_type="audio/wav", data=audio_bytes)))
|
||||
|
||||
message = Content(role=role, parts=parts)
|
||||
return message
|
||||
return Content(role=role, parts=parts)
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
LLMContextToolChoice,
|
||||
LLMSpecificMessage,
|
||||
NotGiven,
|
||||
)
|
||||
|
||||
@@ -47,6 +48,11 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
- Extracting and sanitizing messages from the LLM context for logging about OpenAI.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI."""
|
||||
return "openai"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAILLMInvocationParams:
|
||||
"""Get OpenAI-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
@@ -57,7 +63,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
Dictionary of parameters for OpenAI's ChatCompletion API.
|
||||
"""
|
||||
return {
|
||||
"messages": self._from_universal_context_messages(self._get_messages(context)),
|
||||
"messages": self._from_universal_context_messages(self.get_messages(context)),
|
||||
# NOTE; LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools),
|
||||
"tool_choice": context.tool_choice,
|
||||
@@ -91,7 +97,7 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
List of messages in a format ready for logging about OpenAI.
|
||||
"""
|
||||
msgs = []
|
||||
for message in self._get_messages(context):
|
||||
for message in self.get_messages(context):
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
@@ -104,14 +110,18 @@ class OpenAILLMAdapter(BaseLLMAdapter[OpenAILLMInvocationParams]):
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
def _get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
|
||||
return context.get_messages("openai")
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, messages: List[LLMContextMessage]
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
# Just a pass-through: messages are already the right type
|
||||
return messages
|
||||
result = []
|
||||
for message in messages:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
# Extract the actual message content from LLMSpecificMessage
|
||||
result.append(message.message)
|
||||
else:
|
||||
# Standard message, pass through unchanged
|
||||
result.append(message)
|
||||
return result
|
||||
|
||||
def _from_standard_tool_choice(
|
||||
self, tool_choice: LLMContextToolChoice | NotGiven
|
||||
|
||||
@@ -30,6 +30,11 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter):
|
||||
OpenAI's Realtime API for function calling capabilities.
|
||||
"""
|
||||
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for OpenAI Realtime."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for OpenAI Realtime.")
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> OpenAIRealtimeLLMInvocationParams:
|
||||
"""Get OpenAI Realtime-specific LLM invocation parameters from a universal LLM context.
|
||||
|
||||
|
||||
@@ -33,6 +33,10 @@ class NoisereduceFilter(BaseAudioFilter):
|
||||
Applies spectral gating noise reduction algorithms to suppress background
|
||||
noise in audio streams. Uses the noisereduce library's default noise
|
||||
reduction parameters.
|
||||
|
||||
.. deprecated:: 0.0.85
|
||||
`NoisereduceFilter` is deprecated and will be removed in a future version.
|
||||
We recommend using other real-time audio filters like `KrispFilter` or `AICFilter`.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -40,6 +44,17 @@ class NoisereduceFilter(BaseAudioFilter):
|
||||
self._filtering = True
|
||||
self._sample_rate = 0
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`NoisereduceFilter` is deprecated. "
|
||||
"Use other real-time audio filters like `KrispFilter` or `AICFilter`.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
async def start(self, sample_rate: int):
|
||||
"""Initialize the filter with the transport's sample rate.
|
||||
|
||||
|
||||
0
src/pipecat/audio/turn/smart_turn/data/__init__.py
Normal file
0
src/pipecat/audio/turn/smart_turn/data/__init__.py
Normal file
BIN
src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx
Normal file
BIN
src/pipecat/audio/turn/smart_turn/data/smart-turn-v3.0.onnx
Normal file
Binary file not shown.
124
src/pipecat/audio/turn/smart_turn/local_smart_turn_v3.py
Normal file
124
src/pipecat/audio/turn/smart_turn/local_smart_turn_v3.py
Normal file
@@ -0,0 +1,124 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Local turn analyzer for on-device ML inference using the smart-turn-v3 model.
|
||||
|
||||
This module provides a smart turn analyzer that uses an ONNX model for
|
||||
local end-of-turn detection without requiring network connectivity.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
from transformers import WhisperFeatureExtractor
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use LocalSmartTurnAnalyzerV3, you need to `pip install pipecat-ai[local-smart-turn-v3]`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
"""Local turn analyzer using the smart-turn-v3 ONNX model.
|
||||
|
||||
Provides end-of-turn detection using locally-stored ONNX model,
|
||||
enabling offline operation without network dependencies.
|
||||
"""
|
||||
|
||||
def __init__(self, *, smart_turn_model_path: Optional[str] = None, **kwargs):
|
||||
"""Initialize the local ONNX smart-turn-v3 analyzer.
|
||||
|
||||
Args:
|
||||
smart_turn_model_path: Path to the ONNX model file. If this is not
|
||||
set, the bundled smart-turn-v3.0 model will be used.
|
||||
**kwargs: Additional arguments passed to BaseSmartTurn.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
logger.debug("Loading Local Smart Turn v3 model...")
|
||||
|
||||
if not smart_turn_model_path:
|
||||
# Load bundled model
|
||||
model_name = "smart-turn-v3.0.onnx"
|
||||
package_path = "pipecat.audio.turn.smart_turn.data"
|
||||
|
||||
try:
|
||||
import importlib_resources as impresources
|
||||
|
||||
smart_turn_model_path = str(impresources.files(package_path).joinpath(model_name))
|
||||
except BaseException:
|
||||
from importlib import resources as impresources
|
||||
|
||||
try:
|
||||
with impresources.path(package_path, model_name) as f:
|
||||
smart_turn_model_path = f
|
||||
except BaseException:
|
||||
smart_turn_model_path = str(
|
||||
impresources.files(package_path).joinpath(model_name)
|
||||
)
|
||||
|
||||
so = ort.SessionOptions()
|
||||
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
so.inter_op_num_threads = 1
|
||||
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
|
||||
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
|
||||
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
|
||||
|
||||
logger.debug("Loaded Local Smart Turn v3")
|
||||
|
||||
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Predict end-of-turn using local ONNX model."""
|
||||
|
||||
def truncate_audio_to_last_n_seconds(audio_array, n_seconds=8, sample_rate=16000):
|
||||
"""Truncate audio to last n seconds or pad with zeros to meet n seconds."""
|
||||
max_samples = n_seconds * sample_rate
|
||||
if len(audio_array) > max_samples:
|
||||
return audio_array[-max_samples:]
|
||||
elif len(audio_array) < max_samples:
|
||||
# Pad with zeros at the beginning
|
||||
padding = max_samples - len(audio_array)
|
||||
return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
|
||||
return audio_array
|
||||
|
||||
# Truncate to 8 seconds (keeping the end) or pad to 8 seconds
|
||||
audio_array = truncate_audio_to_last_n_seconds(audio_array, n_seconds=8)
|
||||
|
||||
# Process audio using Whisper's feature extractor
|
||||
inputs = self._feature_extractor(
|
||||
audio_array,
|
||||
sampling_rate=16000,
|
||||
return_tensors="np",
|
||||
padding="max_length",
|
||||
max_length=8 * 16000,
|
||||
truncation=True,
|
||||
do_normalize=True,
|
||||
)
|
||||
|
||||
# Extract features and ensure correct shape for ONNX
|
||||
input_features = inputs.input_features.squeeze(0).astype(np.float32)
|
||||
input_features = np.expand_dims(input_features, axis=0) # Add batch dimension
|
||||
|
||||
# Run ONNX inference
|
||||
outputs = self._session.run(None, {"input_features": input_features})
|
||||
|
||||
# Extract probability (ONNX model returns sigmoid probabilities)
|
||||
probability = outputs[0][0].item()
|
||||
|
||||
# Make prediction (1 for Complete, 0 for Incomplete)
|
||||
prediction = 1 if probability > 0.5 else 0
|
||||
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"probability": probability,
|
||||
}
|
||||
Binary file not shown.
@@ -21,7 +21,6 @@ from typing import List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -360,7 +359,7 @@ class ClassificationProcessor(FrameProcessor):
|
||||
await self._voicemail_notifier.notify() # Clear buffered TTS frames
|
||||
|
||||
# Interrupt the current pipeline to stop any ongoing processing
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
# Set the voicemail event to trigger the voicemail handler
|
||||
self._voicemail_event.clear()
|
||||
|
||||
@@ -788,43 +788,6 @@ class FatalErrorFrame(ErrorFrame):
|
||||
fatal: bool = field(default=True, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndTaskFrame(SystemFrame):
|
||||
"""Frame to request graceful pipeline task closure.
|
||||
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
closed nicely (flushing all the queued frames) by pushing an EndFrame
|
||||
downstream. This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CancelTaskFrame(SystemFrame):
|
||||
"""Frame to request immediate pipeline task cancellation.
|
||||
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
stopped immediately by pushing a CancelFrame downstream. This frame
|
||||
should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopTaskFrame(SystemFrame):
|
||||
"""Frame to request pipeline task stop while keeping processors running.
|
||||
|
||||
This is used to notify the pipeline task that it should be stopped as
|
||||
soon as possible (flushing all the queued frames) but that the pipeline
|
||||
processors should be kept in a running state. This frame should be pushed
|
||||
upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameProcessorPauseUrgentFrame(SystemFrame):
|
||||
"""Frame to pause frame processing immediately.
|
||||
@@ -857,7 +820,7 @@ class FrameProcessorResumeUrgentFrame(SystemFrame):
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartInterruptionFrame(SystemFrame):
|
||||
class InterruptionFrame(SystemFrame):
|
||||
"""Frame indicating user started speaking (interruption detected).
|
||||
|
||||
Emitted by the BaseInputTransport to indicate that a user has started
|
||||
@@ -869,6 +832,34 @@ class StartInterruptionFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartInterruptionFrame(InterruptionFrame):
|
||||
"""Frame indicating user started speaking (interruption detected).
|
||||
|
||||
.. deprecated:: 0.0.85
|
||||
This frame is deprecated and will be removed in a future version.
|
||||
Instead, use `InterruptionFrame`.
|
||||
|
||||
Emitted by the BaseInputTransport to indicate that a user has started
|
||||
speaking (i.e. is interrupting). This is similar to
|
||||
UserStartedSpeakingFrame except that it should be pushed concurrently
|
||||
with other frames (so the order is not guaranteed).
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"StartInterruptionFrame is deprecated and will be removed in a future version. "
|
||||
"Instead, use InterruptionFrame.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserStartedSpeakingFrame(SystemFrame):
|
||||
"""Frame indicating user has started speaking.
|
||||
@@ -944,20 +935,6 @@ class VADUserStoppedSpeakingFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(SystemFrame):
|
||||
"""Frame indicating the bot should be interrupted.
|
||||
|
||||
Emitted when the bot should be interrupted. This will mainly cause the
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
This frame should be pushed upstreams. It results in the BaseInputTransport
|
||||
starting an interruption by pushing a StartInterruptionFrame downstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotStartedSpeakingFrame(SystemFrame):
|
||||
"""Frame indicating the bot started speaking.
|
||||
@@ -1253,23 +1230,6 @@ class UserImageRawFrame(InputImageRawFrame):
|
||||
return f"{self.name}(pts: {pts}, user: {self.user_id}, source: {self.transport_source}, size: {self.size}, format: {self.format}, request: {self.request})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionImageRawFrame(InputImageRawFrame):
|
||||
"""Image frame for vision/image analysis with associated text prompt.
|
||||
|
||||
An image with an associated text to ask for a description of it.
|
||||
|
||||
Parameters:
|
||||
text: Optional text prompt describing what to analyze in the image.
|
||||
"""
|
||||
|
||||
text: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
return f"{self.name}(pts: {pts}, text: [{self.text}], size: {self.size}, format: {self.format})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputDTMFFrame(DTMFFrame, SystemFrame):
|
||||
"""DTMF keypress input frame from transport."""
|
||||
@@ -1306,6 +1266,103 @@ class SpeechControlParamsFrame(SystemFrame):
|
||||
turn_params: Optional[SmartTurnParams] = None
|
||||
|
||||
|
||||
#
|
||||
# Task frames
|
||||
#
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskFrame(SystemFrame):
|
||||
"""Base frame for task frames.
|
||||
|
||||
This is a base class for frames that are meant to be sent and handled
|
||||
upstream by the pipeline task. This might result in a corresponding frame
|
||||
sent downstream (e.g. `InterruptionTaskFrame` / `InterruptionFrame` or
|
||||
`EndTaskFrame` / `EndFrame`).
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndTaskFrame(TaskFrame):
|
||||
"""Frame to request graceful pipeline task closure.
|
||||
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
closed nicely (flushing all the queued frames) by pushing an EndFrame
|
||||
downstream. This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CancelTaskFrame(TaskFrame):
|
||||
"""Frame to request immediate pipeline task cancellation.
|
||||
|
||||
This is used to notify the pipeline task that the pipeline should be
|
||||
stopped immediately by pushing a CancelFrame downstream. This frame
|
||||
should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopTaskFrame(TaskFrame):
|
||||
"""Frame to request pipeline task stop while keeping processors running.
|
||||
|
||||
This is used to notify the pipeline task that it should be stopped as
|
||||
soon as possible (flushing all the queued frames) but that the pipeline
|
||||
processors should be kept in a running state. This frame should be pushed
|
||||
upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterruptionTaskFrame(TaskFrame):
|
||||
"""Frame indicating the bot should be interrupted.
|
||||
|
||||
Emitted when the bot should be interrupted. This will mainly cause the
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterruptionFrame(InterruptionTaskFrame):
|
||||
"""Frame indicating the bot should be interrupted.
|
||||
|
||||
.. deprecated:: 0.0.85
|
||||
This frame is deprecated and will be removed in a future version.
|
||||
Instead, use `InterruptionTaskFrame`.
|
||||
|
||||
Emitted when the bot should be interrupted. This will mainly cause the
|
||||
same actions as if the user interrupted except that the
|
||||
UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated.
|
||||
This frame should be pushed upstream.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"BotInterruptionFrame is deprecated and will be removed in a future version. "
|
||||
"Instead, use InterruptionTaskFrame.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Control frames
|
||||
#
|
||||
@@ -1547,7 +1604,7 @@ class MixerEnableFrame(MixerControlFrame):
|
||||
|
||||
@dataclass
|
||||
class ServiceSwitcherFrame(ControlFrame):
|
||||
"""A base class for frames that control ServiceSwitcher behavior."""
|
||||
"""A base class for frames that affect ServiceSwitcher behavior."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ class DebugLogObserver(BaseObserver):
|
||||
|
||||
Log frames with specific source/destination filters::
|
||||
|
||||
from pipecat.frames.frames import StartInterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame
|
||||
from pipecat.frames.frames import InterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame
|
||||
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver, FrameEndpoint
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.services.stt_service import STTService
|
||||
@@ -62,8 +62,8 @@ class DebugLogObserver(BaseObserver):
|
||||
observers=[
|
||||
DebugLogObserver(
|
||||
frame_types={
|
||||
# Only log StartInterruptionFrame when source is BaseOutputTransport
|
||||
StartInterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
# Only log InterruptionFrame when source is BaseOutputTransport
|
||||
InterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE),
|
||||
# Only log UserStartedSpeakingFrame when destination is STTService
|
||||
UserStartedSpeakingFrame: (STTService, FrameEndpoint.DESTINATION),
|
||||
# Log LLMTextFrame regardless of source or destination type
|
||||
|
||||
@@ -30,25 +30,17 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
|
||||
"""Get the currently active LLM, if any."""
|
||||
return self.strategy.active_service
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context, using the currently active LLM.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context. If both are provided, the
|
||||
one in the context takes precedence.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
if self.active_llm:
|
||||
return await self.active_llm.run_inference(
|
||||
context=context, system_instruction=system_instruction
|
||||
)
|
||||
return await self.active_llm.run_inference(context=context)
|
||||
return None
|
||||
|
||||
def register_function(
|
||||
|
||||
@@ -6,9 +6,15 @@
|
||||
|
||||
"""Service switcher for switching between different services at runtime, with different switching strategies."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, List, Optional, Type, TypeVar
|
||||
|
||||
from pipecat.frames.frames import Frame, ManuallySwitchServiceFrame, ServiceSwitcherFrame
|
||||
from pipecat.frames.frames import (
|
||||
ControlFrame,
|
||||
Frame,
|
||||
ManuallySwitchServiceFrame,
|
||||
ServiceSwitcherFrame,
|
||||
)
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.processors.filters.function_filter import FunctionFilter
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
@@ -22,19 +28,6 @@ class ServiceSwitcherStrategy:
|
||||
self.services = services
|
||||
self.active_service: Optional[FrameProcessor] = None
|
||||
|
||||
def is_active(self, service: FrameProcessor) -> bool:
|
||||
"""Determine if the given service is the currently active one.
|
||||
|
||||
This method should be overridden by subclasses to implement specific logic.
|
||||
|
||||
Args:
|
||||
service: The service to check.
|
||||
|
||||
Returns:
|
||||
True if the given service is the active one, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method.")
|
||||
|
||||
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
@@ -60,17 +53,6 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
super().__init__(services)
|
||||
self.active_service = services[0] if services else None
|
||||
|
||||
def is_active(self, service: FrameProcessor) -> bool:
|
||||
"""Check if the given service is the currently active one.
|
||||
|
||||
Args:
|
||||
service: The service to check.
|
||||
|
||||
Returns:
|
||||
True if the given service is the active one, False otherwise.
|
||||
"""
|
||||
return service == self.active_service
|
||||
|
||||
def handle_frame(self, frame: ServiceSwitcherFrame, direction: FrameDirection):
|
||||
"""Handle a frame that controls service switching.
|
||||
|
||||
@@ -79,20 +61,21 @@ class ServiceSwitcherStrategyManual(ServiceSwitcherStrategy):
|
||||
direction: The direction of the frame (upstream or downstream).
|
||||
"""
|
||||
if isinstance(frame, ManuallySwitchServiceFrame):
|
||||
self._set_active(frame.service)
|
||||
self._set_active_if_available(frame.service)
|
||||
else:
|
||||
raise ValueError(f"Unsupported frame type: {type(frame)}")
|
||||
|
||||
def _set_active(self, service: FrameProcessor):
|
||||
"""Set the active service to the given one.
|
||||
def _set_active_if_available(self, service: FrameProcessor):
|
||||
"""Set the active service to the given one, if it is in the list of available services.
|
||||
|
||||
If it's not in the list, the request is ignored, as it may have been
|
||||
intended for another ServiceSwitcher in the pipeline.
|
||||
|
||||
Args:
|
||||
service: The service to set as active.
|
||||
"""
|
||||
if service in self.services:
|
||||
self.active_service = service
|
||||
else:
|
||||
raise ValueError(f"Service {service} is not in the list of available services.")
|
||||
|
||||
|
||||
StrategyType = TypeVar("StrategyType", bound=ServiceSwitcherStrategy)
|
||||
@@ -108,6 +91,43 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
self.services = services
|
||||
self.strategy = strategy
|
||||
|
||||
class ServiceSwitcherFilter(FunctionFilter):
|
||||
"""An internal filter that allows frames to pass through to the wrapped service only if it's the active service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wrapped_service: FrameProcessor,
|
||||
active_service: FrameProcessor,
|
||||
direction: FrameDirection,
|
||||
):
|
||||
"""Initialize the service switcher filter with a strategy and direction."""
|
||||
|
||||
async def filter(_: Frame) -> bool:
|
||||
return self._wrapped_service == self._active_service
|
||||
|
||||
super().__init__(filter, direction)
|
||||
self._wrapped_service = wrapped_service
|
||||
self._active_service = active_service
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
"""Process a frame through the filter, handling special internal filter-updating frames."""
|
||||
if isinstance(frame, ServiceSwitcher.ServiceSwitcherFilterFrame):
|
||||
self._active_service = frame.active_service
|
||||
# Two ServiceSwitcherFilters "sandwich" a service. Push the
|
||||
# frame only to update the other side of the sandwich, but
|
||||
# otherwise don't let it leave the sandwich.
|
||||
if direction == self._direction:
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@dataclass
|
||||
class ServiceSwitcherFilterFrame(ControlFrame):
|
||||
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
|
||||
|
||||
active_service: FrameProcessor
|
||||
|
||||
@staticmethod
|
||||
def _make_pipeline_definitions(
|
||||
services: List[FrameProcessor], strategy: ServiceSwitcherStrategy
|
||||
@@ -121,14 +141,18 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
def _make_pipeline_definition(
|
||||
service: FrameProcessor, strategy: ServiceSwitcherStrategy
|
||||
) -> Any:
|
||||
async def filter(frame) -> bool:
|
||||
_ = frame
|
||||
return strategy.is_active(service)
|
||||
|
||||
return [
|
||||
FunctionFilter(filter, direction=FrameDirection.DOWNSTREAM),
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
active_service=strategy.active_service,
|
||||
direction=FrameDirection.DOWNSTREAM,
|
||||
),
|
||||
service,
|
||||
FunctionFilter(filter, direction=FrameDirection.UPSTREAM),
|
||||
ServiceSwitcher.ServiceSwitcherFilter(
|
||||
wrapped_service=service,
|
||||
active_service=strategy.active_service,
|
||||
direction=FrameDirection.UPSTREAM,
|
||||
),
|
||||
]
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -142,3 +166,7 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
|
||||
if isinstance(frame, ServiceSwitcherFrame):
|
||||
self.strategy.handle_frame(frame, direction)
|
||||
service_switcher_filter_frame = ServiceSwitcher.ServiceSwitcherFilterFrame(
|
||||
active_service=self.strategy.active_service
|
||||
)
|
||||
await super().process_frame(service_switcher_filter_frame, direction)
|
||||
|
||||
@@ -32,6 +32,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
HeartbeatFrame,
|
||||
InputAudioRawFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
StopFrame,
|
||||
@@ -113,9 +115,28 @@ class PipelineTask(BasePipelineTask):
|
||||
- on_frame_reached_downstream: Called when downstream frames reach the sink
|
||||
- on_idle_timeout: Called when pipeline is idle beyond timeout threshold
|
||||
- on_pipeline_started: Called when pipeline starts with StartFrame
|
||||
- on_pipeline_stopped: Called when pipeline stops with StopFrame
|
||||
- on_pipeline_ended: Called when pipeline ends with EndFrame
|
||||
- on_pipeline_cancelled: Called when pipeline is cancelled
|
||||
- on_pipeline_stopped: [deprecated] Called when pipeline stops with StopFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_ended: [deprecated] Called when pipeline ends with EndFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_cancelled: [deprecated] Called when pipeline is cancelled with CancelFrame
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
Use `on_pipeline_finished` instead.
|
||||
|
||||
- on_pipeline_finished: Called after the pipeline has reached any terminal state.
|
||||
This includes:
|
||||
- StopFrame: pipeline was stopped (processors keep connections open)
|
||||
- EndFrame: pipeline ended normally
|
||||
- CancelFrame: pipeline was cancelled
|
||||
Use this event for cleanup, logging, or post-processing tasks. Users can inspect
|
||||
the frame if they need to handle specific cases.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -126,6 +147,10 @@ class PipelineTask(BasePipelineTask):
|
||||
@task.event_handler("on_idle_timeout")
|
||||
async def on_pipeline_idle_timeout(task):
|
||||
...
|
||||
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(task, frame):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -262,6 +287,7 @@ class PipelineTask(BasePipelineTask):
|
||||
self._register_event_handler("on_pipeline_stopped")
|
||||
self._register_event_handler("on_pipeline_ended")
|
||||
self._register_event_handler("on_pipeline_cancelled")
|
||||
self._register_event_handler("on_pipeline_finished")
|
||||
|
||||
@property
|
||||
def params(self) -> PipelineParams:
|
||||
@@ -290,6 +316,27 @@ class PipelineTask(BasePipelineTask):
|
||||
"""
|
||||
return self._turn_trace_observer
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
"""Decorator for registering event handlers.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event to handle.
|
||||
|
||||
Returns:
|
||||
The decorator function that registers the handler.
|
||||
"""
|
||||
if event_name in ["on_pipeline_stopped", "on_pipeline_ended", "on_pipeline_cancelled"]:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
f"Event '{event_name}' is deprecated, use 'on_pipeline_finished' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
return super().event_handler(event_name)
|
||||
|
||||
def add_observer(self, observer: BaseObserver):
|
||||
"""Add an observer to monitor pipeline execution.
|
||||
|
||||
@@ -532,6 +579,7 @@ class PipelineTask(BasePipelineTask):
|
||||
)
|
||||
finally:
|
||||
await self._call_event_handler("on_pipeline_cancelled", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
|
||||
logger.debug(f"{self}: Closing. Waiting for {frame} to reach the end of the pipeline...")
|
||||
|
||||
@@ -627,13 +675,23 @@ class PipelineTask(BasePipelineTask):
|
||||
|
||||
if isinstance(frame, EndTaskFrame):
|
||||
# Tell the task we should end nicely.
|
||||
logger.debug(f"{self}: received end task frame {frame}")
|
||||
await self.queue_frame(EndFrame())
|
||||
elif isinstance(frame, CancelTaskFrame):
|
||||
# Tell the task we should end right away.
|
||||
logger.debug(f"{self}: received cancel task frame {frame}")
|
||||
await self.queue_frame(CancelFrame())
|
||||
elif isinstance(frame, StopTaskFrame):
|
||||
# Tell the task we should stop nicely.
|
||||
logger.debug(f"{self}: received stop task frame {frame}")
|
||||
await self.queue_frame(StopFrame())
|
||||
elif isinstance(frame, InterruptionTaskFrame):
|
||||
# Tell the task we should interrupt the pipeline. Note that we are
|
||||
# bypassing the push queue and directly queue into the
|
||||
# pipeline. This is in case the push task is blocked waiting for a
|
||||
# pipeline-ending frame to finish traversing the pipeline.
|
||||
logger.debug(f"{self}: received interruption task frame {frame}")
|
||||
await self._pipeline.queue_frame(InterruptionFrame())
|
||||
elif isinstance(frame, ErrorFrame):
|
||||
if frame.fatal:
|
||||
logger.error(f"A fatal error occurred: {frame}")
|
||||
@@ -642,7 +700,7 @@ class PipelineTask(BasePipelineTask):
|
||||
# Tell the task we should stop.
|
||||
await self.queue_frame(StopTaskFrame())
|
||||
else:
|
||||
logger.warning(f"Something went wrong: {frame}")
|
||||
logger.warning(f"{self}: Something went wrong: {frame}")
|
||||
|
||||
async def _sink_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames coming downstream from the pipeline.
|
||||
@@ -669,9 +727,11 @@ class PipelineTask(BasePipelineTask):
|
||||
self._pipeline_start_event.set()
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self._call_event_handler("on_pipeline_ended", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, StopFrame):
|
||||
await self._call_event_handler("on_pipeline_stopped", frame)
|
||||
await self._call_event_handler("on_pipeline_finished", frame)
|
||||
self._pipeline_end_event.set()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
self._pipeline_end_event.set()
|
||||
|
||||
@@ -16,7 +16,6 @@ from typing import Optional
|
||||
|
||||
from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
@@ -24,7 +23,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
@@ -105,7 +104,7 @@ class DTMFAggregator(FrameProcessor):
|
||||
|
||||
# For first digit, schedule interruption.
|
||||
if is_first_digit:
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
# Check for immediate flush conditions
|
||||
if frame.button == self._termination_digit:
|
||||
|
||||
@@ -22,7 +22,6 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -36,6 +35,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
@@ -48,7 +48,6 @@ from pipecat.frames.frames import (
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserImageRawFrame,
|
||||
@@ -138,7 +137,7 @@ class LLMFullResponseAggregator(FrameProcessor):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._call_event_handler("on_completion", self._aggregation, False)
|
||||
self._aggregation = ""
|
||||
self._started = False
|
||||
@@ -532,9 +531,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
|
||||
"Interruption conditions met - pushing interruption and aggregation"
|
||||
)
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
@@ -838,7 +837,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
@@ -904,7 +903,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: StartInterruptionFrame):
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
|
||||
await self.push_aggregation()
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
@@ -13,7 +13,6 @@ LLM processing, and text-to-speech components in conversational AI pipelines.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
from loguru import logger
|
||||
@@ -23,7 +22,6 @@ from pipecat.audio.interruptions.base_interruption_strategy import BaseInterrupt
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -37,6 +35,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextAssistantTimestampFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
@@ -48,7 +47,6 @@ from pipecat.frames.frames import (
|
||||
LLMSetToolsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserImageRawFrame,
|
||||
@@ -311,9 +309,9 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
if should_interrupt:
|
||||
logger.debug(
|
||||
"Interruption conditions met - pushing BotInterruptionFrame and aggregation"
|
||||
"Interruption conditions met - pushing interruption and aggregation"
|
||||
)
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self._process_aggregation()
|
||||
else:
|
||||
logger.debug("Interruption conditions not met - not pushing aggregation")
|
||||
@@ -579,7 +577,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
@@ -645,7 +643,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame(FrameDirection.UPSTREAM)
|
||||
|
||||
async def _handle_interruptions(self, frame: StartInterruptionFrame):
|
||||
async def _handle_interruptions(self, frame: InterruptionFrame):
|
||||
await self._push_aggregation()
|
||||
self._started = 0
|
||||
await self.reset()
|
||||
|
||||
@@ -10,13 +10,22 @@ This module provides frame aggregation functionality to combine text and image
|
||||
frames into vision frames for multimodal processing.
|
||||
"""
|
||||
|
||||
from pipecat.frames.frames import Frame, InputImageRawFrame, TextFrame, VisionImageRawFrame
|
||||
from pipecat.frames.frames import Frame, InputImageRawFrame, TextFrame
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class VisionImageFrameAggregator(FrameProcessor):
|
||||
"""Aggregates consecutive text and image frames into vision frames.
|
||||
|
||||
.. deprecated:: 0.0.85
|
||||
VisionImageRawFrame has been removed in favor of context frames
|
||||
(LLMContextFrame or OpenAILLMContextFrame), so this aggregator is not
|
||||
needed anymore. See the 12* examples for the new recommended pattern.
|
||||
|
||||
This aggregator waits for a consecutive TextFrame and an InputImageRawFrame.
|
||||
After the InputImageRawFrame arrives it will output a VisionImageRawFrame
|
||||
combining both the text and image data for multimodal processing.
|
||||
@@ -28,6 +37,17 @@ class VisionImageFrameAggregator(FrameProcessor):
|
||||
The aggregator starts with no cached text, waiting for the first
|
||||
TextFrame to arrive before it can create vision frames.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"VisionImageFrameAggregator is deprecated. "
|
||||
"VisionImageRawFrame has been removed in favor of context frames "
|
||||
"(LLMContextFrame or OpenAILLMContextFrame), so this aggregator is "
|
||||
"not needed anymore. See the 12* examples for the new recommended "
|
||||
"pattern.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__()
|
||||
self._describe_text = None
|
||||
|
||||
@@ -47,12 +67,14 @@ class VisionImageFrameAggregator(FrameProcessor):
|
||||
self._describe_text = frame.text
|
||||
elif isinstance(frame, InputImageRawFrame):
|
||||
if self._describe_text:
|
||||
frame = VisionImageRawFrame(
|
||||
context = OpenAILLMContext()
|
||||
context.add_image_frame_message(
|
||||
text=self._describe_text,
|
||||
image=frame.image,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
)
|
||||
frame = OpenAILLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
self._describe_text = None
|
||||
else:
|
||||
|
||||
@@ -137,12 +137,12 @@ class AudioBufferProcessor(FrameProcessor):
|
||||
return self._num_channels
|
||||
|
||||
def has_audio(self) -> bool:
|
||||
"""Check if both user and bot audio buffers contain data.
|
||||
"""Check if either user or bot audio buffers contain data.
|
||||
|
||||
Returns:
|
||||
True if both buffers contain audio data.
|
||||
True if either buffer contains audio data.
|
||||
"""
|
||||
return self._buffer_has_audio(self._user_audio_buffer) and self._buffer_has_audio(
|
||||
return self._buffer_has_audio(self._user_audio_buffer) or self._buffer_has_audio(
|
||||
self._bot_audio_buffer
|
||||
)
|
||||
|
||||
|
||||
@@ -25,8 +25,8 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -204,7 +204,7 @@ class STTMuteFilter(FrameProcessor):
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
|
||||
@@ -28,8 +28,9 @@ from pipecat.frames.frames import (
|
||||
FrameProcessorPauseUrgentFrame,
|
||||
FrameProcessorResumeFrame,
|
||||
FrameProcessorResumeUrgentFrame,
|
||||
InterruptionFrame,
|
||||
InterruptionTaskFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
|
||||
@@ -219,6 +220,14 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_event: Optional[asyncio.Event] = None
|
||||
self.__process_frame_task: Optional[asyncio.Task] = None
|
||||
|
||||
# To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream.
|
||||
# Then we wait for the corresponding `InterruptionFrame` to travel from
|
||||
# the start of the pipeline back to the processor that sent the
|
||||
# `InterruptionTaskFrame`. This wait is handled using the following
|
||||
# event.
|
||||
self._wait_for_interruption = False
|
||||
self._wait_interruption_event = asyncio.Event()
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
"""Get the unique identifier for this processor.
|
||||
@@ -542,6 +551,14 @@ class FrameProcessor(BaseObject):
|
||||
if self._cancelling:
|
||||
return
|
||||
|
||||
# If we are waiting for an interruption we will bypass all queued system
|
||||
# frames and we will process the frame right away. This is because a
|
||||
# previous system frame might be waiting for the interruption frame and
|
||||
# it's blocking the input task.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
return
|
||||
|
||||
if self._enable_direct_mode:
|
||||
await self.__process_frame(frame, direction, callback)
|
||||
else:
|
||||
@@ -551,11 +568,17 @@ class FrameProcessor(BaseObject):
|
||||
"""Pause processing of queued frames."""
|
||||
logger.trace(f"{self}: pausing frame processing")
|
||||
self.__should_block_frames = True
|
||||
# We should also unset the process event here, in case it was set immediately after an interruption
|
||||
if self.__process_event:
|
||||
self.__process_event.clear()
|
||||
|
||||
async def pause_processing_system_frames(self):
|
||||
"""Pause processing of queued system frames."""
|
||||
logger.trace(f"{self}: pausing system frame processing")
|
||||
self.__should_block_system_frames = True
|
||||
# We should also unset the input event here, in case it was set immediately after an interruption
|
||||
if self.__input_event:
|
||||
self.__input_event.clear()
|
||||
|
||||
async def resume_processing_frames(self):
|
||||
"""Resume processing of queued frames."""
|
||||
@@ -588,7 +611,7 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.__start(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._start_interruption()
|
||||
await self.stop_all_metrics()
|
||||
elif isinstance(frame, CancelFrame):
|
||||
@@ -620,6 +643,34 @@ class FrameProcessor(BaseObject):
|
||||
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
|
||||
# If we are waiting for an interruption and we get an interruption, then
|
||||
# we can unblock `push_interruption_task_frame_and_wait()`.
|
||||
if self._wait_for_interruption and isinstance(frame, InterruptionFrame):
|
||||
self._wait_interruption_event.set()
|
||||
|
||||
async def push_interruption_task_frame_and_wait(self):
|
||||
"""Push an interruption task frame upstream and wait for the interruption.
|
||||
|
||||
This function sends an `InterruptionTaskFrame` upstream to the pipeline
|
||||
task and waits to receive the corresponding `InterruptionFrame`. When
|
||||
the function finishes it is guaranteed that the `InterruptionFrame` has
|
||||
been pushed downstream.
|
||||
"""
|
||||
self._wait_for_interruption = True
|
||||
|
||||
await self.push_frame(InterruptionTaskFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
# Wait for an `InterruptionFrame` to come to this processor and be
|
||||
# pushed. Take a look at `push_frame()` to see how we first push the
|
||||
# `InterruptionFrame` and then we set the event in order to maintain
|
||||
# frame ordering.
|
||||
await self._wait_interruption_event.wait()
|
||||
|
||||
# Clean the event.
|
||||
self._wait_interruption_event.clear()
|
||||
|
||||
self._wait_for_interruption = False
|
||||
|
||||
async def __start(self, frame: StartFrame):
|
||||
"""Handle the start frame to initialize processor state.
|
||||
|
||||
@@ -669,20 +720,22 @@ class FrameProcessor(BaseObject):
|
||||
async def _start_interruption(self):
|
||||
"""Start handling an interruption by cancelling current tasks."""
|
||||
try:
|
||||
# Cancel the process task. This will stop processing queued frames.
|
||||
await self.__cancel_process_task()
|
||||
if self._wait_for_interruption:
|
||||
# If we get here we know the process task was just waiting for
|
||||
# an interruption (push_interruption_task_frame_and_wait()), so
|
||||
# we can't cancel the task because it might still need to do
|
||||
# more things (e.g. pushing a frame after the
|
||||
# interruption). Instead we just drain the queue because this is
|
||||
# an interruption.
|
||||
self.__reset_process_task()
|
||||
else:
|
||||
# Cancel and re-create the process task including the queue.
|
||||
await self.__cancel_process_task()
|
||||
self.__create_process_task()
|
||||
except Exception as e:
|
||||
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
|
||||
await self.push_error(ErrorFrame(str(e)))
|
||||
|
||||
# Create a new process queue and task.
|
||||
self.__create_process_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
"""Stop handling an interruption."""
|
||||
# Nothing to do right now.
|
||||
pass
|
||||
|
||||
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Internal method to push frames to adjacent processors.
|
||||
|
||||
@@ -764,6 +817,17 @@ class FrameProcessor(BaseObject):
|
||||
self.__process_queue = asyncio.Queue()
|
||||
self.__process_frame_task = self.create_task(self.__process_frame_task_handler())
|
||||
|
||||
def __reset_process_task(self):
|
||||
"""Reset non-system frame processing task."""
|
||||
if self._enable_direct_mode:
|
||||
return
|
||||
|
||||
self.__should_block_frames = False
|
||||
self.__process_event = asyncio.Event()
|
||||
while not self.__process_queue.empty():
|
||||
self.__process_queue.get_nowait()
|
||||
self.__process_queue.task_done()
|
||||
|
||||
async def __cancel_process_task(self):
|
||||
"""Cancel the non-system frame processing task."""
|
||||
if self.__process_frame_task:
|
||||
|
||||
@@ -30,7 +30,6 @@ from loguru import logger
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -1206,7 +1205,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def interrupt_bot(self):
|
||||
"""Send a bot interruption frame upstream."""
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
|
||||
async def send_server_message(self, data: Any):
|
||||
"""Send a server message to the client."""
|
||||
|
||||
@@ -19,7 +19,7 @@ from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
@@ -86,7 +86,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
transcript messages. Utterances are completed when:
|
||||
|
||||
- The bot stops speaking (BotStoppedSpeakingFrame)
|
||||
- The bot is interrupted (StartInterruptionFrame)
|
||||
- The bot is interrupted (InterruptionFrame)
|
||||
- The pipeline ends (EndFrame)
|
||||
"""
|
||||
|
||||
@@ -185,7 +185,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
|
||||
- TTSTextFrame: Aggregates text for current utterance
|
||||
- BotStoppedSpeakingFrame: Completes current utterance
|
||||
- StartInterruptionFrame: Completes current utterance due to interruption
|
||||
- InterruptionFrame: Completes current utterance due to interruption
|
||||
- EndFrame: Completes current utterance at pipeline end
|
||||
- CancelFrame: Completes current utterance due to cancellation
|
||||
|
||||
@@ -195,7 +195,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (StartInterruptionFrame, CancelFrame)):
|
||||
if isinstance(frame, (InterruptionFrame, CancelFrame)):
|
||||
# Push frame first otherwise our emitted transcription update frame
|
||||
# might get cleaned up.
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -17,7 +17,6 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
StartFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -185,15 +184,13 @@ class UserIdleProcessor(FrameProcessor):
|
||||
|
||||
Runs in a loop until cancelled or callback indicates completion.
|
||||
"""
|
||||
while True:
|
||||
running = True
|
||||
while running:
|
||||
try:
|
||||
await asyncio.wait_for(self._idle_event.wait(), timeout=self._timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if not self._interrupted:
|
||||
self._retry_count += 1
|
||||
should_continue = await self._callback(self, self._retry_count)
|
||||
if not should_continue:
|
||||
await self._stop()
|
||||
break
|
||||
running = await self._callback(self, self._retry_count)
|
||||
finally:
|
||||
self._idle_event.clear()
|
||||
|
||||
@@ -70,7 +70,6 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -183,13 +182,14 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.transports.smallwebrtc.request_handler import (
|
||||
SmallWebRTCRequest,
|
||||
SmallWebRTCRequestHandler,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"WebRTC transport dependencies not installed: {e}")
|
||||
return
|
||||
|
||||
# Store connections by pc_id
|
||||
pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
# Mount the frontend
|
||||
app.mount("/client", SmallWebRTCPrebuiltUI)
|
||||
|
||||
@@ -198,51 +198,33 @@ def _setup_webrtc_routes(app: FastAPI, esp32_mode: bool = False, host: str = "lo
|
||||
"""Redirect root requests to client interface."""
|
||||
return RedirectResponse(url="/client/")
|
||||
|
||||
# Initialize the SmallWebRTC request handler
|
||||
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler(
|
||||
esp32_mode=esp32_mode, host=host
|
||||
)
|
||||
|
||||
@app.post("/api/offer")
|
||||
async def offer(request: dict, background_tasks: BackgroundTasks):
|
||||
"""Handle WebRTC offer requests and manage peer connections."""
|
||||
pc_id = request.get("pc_id")
|
||||
|
||||
if pc_id and pc_id in pcs_map:
|
||||
pipecat_connection = pcs_map[pc_id]
|
||||
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
|
||||
await pipecat_connection.renegotiate(
|
||||
sdp=request["sdp"],
|
||||
type=request["type"],
|
||||
restart_pc=request.get("restart_pc", False),
|
||||
)
|
||||
else:
|
||||
pipecat_connection = SmallWebRTCConnection()
|
||||
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
|
||||
|
||||
@pipecat_connection.event_handler("closed")
|
||||
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
|
||||
"""Handle WebRTC connection closure and cleanup."""
|
||||
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
|
||||
pcs_map.pop(webrtc_connection.pc_id, None)
|
||||
async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks):
|
||||
"""Handle WebRTC offer requests via SmallWebRTCRequestHandler."""
|
||||
|
||||
# Prepare runner arguments with the callback to run your bot
|
||||
async def webrtc_connection_callback(connection):
|
||||
bot_module = _get_bot_module()
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=pipecat_connection)
|
||||
runner_args = SmallWebRTCRunnerArguments(webrtc_connection=connection)
|
||||
background_tasks.add_task(bot_module.bot, runner_args)
|
||||
|
||||
answer = pipecat_connection.get_answer()
|
||||
|
||||
# Apply ESP32 SDP munging if enabled
|
||||
if esp32_mode and host != "localhost":
|
||||
from pipecat.runner.utils import smallwebrtc_sdp_munging
|
||||
|
||||
answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], host)
|
||||
|
||||
pcs_map[answer["pc_id"]] = pipecat_connection
|
||||
# Delegate handling to SmallWebRTCRequestHandler
|
||||
answer = await small_webrtc_handler.handle_web_request(
|
||||
request=request,
|
||||
webrtc_connection_callback=webrtc_connection_callback,
|
||||
)
|
||||
return answer
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage FastAPI application lifecycle and cleanup connections."""
|
||||
yield
|
||||
coros = [pc.disconnect() for pc in pcs_map.values()]
|
||||
await asyncio.gather(*coros)
|
||||
pcs_map.clear()
|
||||
await small_webrtc_handler.close()
|
||||
|
||||
app.router.lifespan_context = lifespan
|
||||
|
||||
|
||||
@@ -51,9 +51,11 @@ class WebSocketRunnerArguments(RunnerArguments):
|
||||
|
||||
Parameters:
|
||||
websocket: WebSocket connection for audio streaming
|
||||
body: Additional request data
|
||||
"""
|
||||
|
||||
websocket: WebSocket
|
||||
body: Optional[Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -99,16 +99,35 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
tuple: (transport_type: str, call_data: dict)
|
||||
|
||||
call_data contains provider-specific fields:
|
||||
- Twilio: {"stream_id": str, "call_id": str}
|
||||
- Telnyx: {"stream_id": str, "call_control_id": str, "outbound_encoding": str}
|
||||
- Plivo: {"stream_id": str, "call_id": str}
|
||||
- Exotel: {"stream_id": str, "call_id": str, "account_sid": str}
|
||||
- Twilio: {
|
||||
"stream_id": str,
|
||||
"call_id": str,
|
||||
"body": dict
|
||||
}
|
||||
- Telnyx: {
|
||||
"stream_id": str,
|
||||
"call_control_id": str,
|
||||
"outbound_encoding": str,
|
||||
"from": str,
|
||||
"to": str,
|
||||
}
|
||||
- Plivo: {
|
||||
"stream_id": str,
|
||||
"call_id": str,
|
||||
}
|
||||
- Exotel: {
|
||||
"stream_id": str,
|
||||
"call_id": str,
|
||||
"account_sid": str,
|
||||
"from": str,
|
||||
"to": str,
|
||||
}
|
||||
|
||||
Example usage::
|
||||
|
||||
transport_type, call_data = await parse_telephony_websocket(websocket)
|
||||
if transport_type == "telnyx":
|
||||
outbound_encoding = call_data["outbound_encoding"]
|
||||
if transport_type == "twilio":
|
||||
user_id = call_data["body"]["user_id"]
|
||||
"""
|
||||
# Read first two messages
|
||||
start_data = websocket.iter_text()
|
||||
@@ -151,9 +170,12 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
# Extract provider-specific data
|
||||
if transport_type == "twilio":
|
||||
start_data = call_data_raw.get("start", {})
|
||||
body_data = start_data.get("customParameters", {})
|
||||
call_data = {
|
||||
"stream_id": start_data.get("streamSid"),
|
||||
"call_id": start_data.get("callSid"),
|
||||
# All custom parameters
|
||||
"body": body_data,
|
||||
}
|
||||
|
||||
elif transport_type == "telnyx":
|
||||
@@ -163,6 +185,8 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
"outbound_encoding": call_data_raw.get("start", {})
|
||||
.get("media_format", {})
|
||||
.get("encoding"),
|
||||
"from": call_data_raw.get("start", {}).get("from", ""),
|
||||
"to": call_data_raw.get("start", {}).get("to", ""),
|
||||
}
|
||||
|
||||
elif transport_type == "plivo":
|
||||
@@ -178,6 +202,8 @@ async def parse_telephony_websocket(websocket: WebSocket):
|
||||
"stream_id": start_data.get("stream_sid"),
|
||||
"call_id": start_data.get("call_sid"),
|
||||
"account_sid": start_data.get("account_sid"),
|
||||
"from": start_data.get("from", ""),
|
||||
"to": start_data.get("to", ""),
|
||||
}
|
||||
|
||||
else:
|
||||
|
||||
@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -98,7 +98,7 @@ class ExotelFrameSerializer(FrameSerializer):
|
||||
Returns:
|
||||
Serialized data as string or bytes, or None if the frame isn't handled.
|
||||
"""
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
answer = {"event": "clear", "streamSid": self._stream_sid}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
|
||||
@@ -22,8 +22,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -122,7 +122,7 @@ class PlivoFrameSerializer(FrameSerializer):
|
||||
self._hangup_attempted = True
|
||||
await self._hang_up_call()
|
||||
return None
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
answer = {"event": "clearAudio", "streamId": self._stream_id}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
|
||||
@@ -29,8 +29,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType
|
||||
|
||||
@@ -137,7 +137,7 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
self._hangup_attempted = True
|
||||
await self._hang_up_call()
|
||||
return None
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
answer = {"event": "clear"}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
|
||||
@@ -22,8 +22,8 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
@@ -122,7 +122,7 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
self._hangup_attempted = True
|
||||
await self._hang_up_call()
|
||||
return None
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
answer = {"event": "clear", "streamSid": self._stream_sid}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
|
||||
@@ -24,7 +24,10 @@ from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||
from pipecat.adapters.services.anthropic_adapter import (
|
||||
AnthropicLLMAdapter,
|
||||
AnthropicLLMInvocationParams,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
@@ -39,7 +42,6 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
@@ -112,7 +114,12 @@ class AnthropicLLMService(LLMService):
|
||||
"""Input parameters for Anthropic model inference.
|
||||
|
||||
Parameters:
|
||||
enable_prompt_caching_beta: Whether to enable beta prompt caching feature.
|
||||
enable_prompt_caching: Whether to enable the prompt caching feature.
|
||||
enable_prompt_caching_beta (deprecated): Whether to enable the beta prompt caching feature.
|
||||
|
||||
.. deprecated:: 0.0.84
|
||||
Use the `enable_prompt_caching` parameter instead.
|
||||
|
||||
max_tokens: Maximum tokens to generate. Must be at least 1.
|
||||
temperature: Sampling temperature between 0.0 and 1.0.
|
||||
top_k: Top-k sampling parameter.
|
||||
@@ -120,13 +127,26 @@ class AnthropicLLMService(LLMService):
|
||||
extra: Additional parameters to pass to the API.
|
||||
"""
|
||||
|
||||
enable_prompt_caching_beta: Optional[bool] = False
|
||||
enable_prompt_caching: Optional[bool] = None
|
||||
enable_prompt_caching_beta: Optional[bool] = None
|
||||
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
|
||||
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
|
||||
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def model_post_init(self, __context):
|
||||
"""Post-initialization to handle deprecated parameters."""
|
||||
if self.enable_prompt_caching_beta is not None:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"enable_prompt_caching_beta is deprecated. Use enable_prompt_caching instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -159,7 +179,15 @@ class AnthropicLLMService(LLMService):
|
||||
self._retry_on_timeout = retry_on_timeout
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"enable_prompt_caching_beta": params.enable_prompt_caching_beta or False,
|
||||
"enable_prompt_caching": (
|
||||
params.enable_prompt_caching
|
||||
if params.enable_prompt_caching is not None
|
||||
else (
|
||||
params.enable_prompt_caching_beta
|
||||
if params.enable_prompt_caching_beta is not None
|
||||
else False
|
||||
)
|
||||
),
|
||||
"temperature": params.temperature,
|
||||
"top_k": params.top_k,
|
||||
"top_p": params.top_p,
|
||||
@@ -199,34 +227,28 @@ class AnthropicLLMService(LLMService):
|
||||
response = await api_call(**params)
|
||||
return response
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context. If both are provided, the
|
||||
one in the context takes precedence.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
messages = []
|
||||
system = []
|
||||
system = NOT_GIVEN
|
||||
if isinstance(context, LLMContext):
|
||||
# Future code will be something like this:
|
||||
# adapter = self.get_llm_adapter()
|
||||
# params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
# messages = params["messages"]
|
||||
# system = params["system_instruction"]
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
params = adapter.get_llm_invocation_params(
|
||||
context, enable_prompt_caching=self._settings["enable_prompt_caching"]
|
||||
)
|
||||
messages = params["messages"]
|
||||
system = params["system"]
|
||||
else:
|
||||
context = AnthropicLLMContext.upgrade_to_anthropic(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) or system_instruction
|
||||
system = getattr(context, "system", NOT_GIVEN)
|
||||
|
||||
# LLM completion
|
||||
response = await self._client.messages.create(
|
||||
@@ -239,15 +261,6 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
@property
|
||||
def enable_prompt_caching_beta(self) -> bool:
|
||||
"""Check if prompt caching beta feature is enabled.
|
||||
|
||||
Returns:
|
||||
True if prompt caching is enabled.
|
||||
"""
|
||||
return self._enable_prompt_caching_beta
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
@@ -277,8 +290,31 @@ class AnthropicLLMService(LLMService):
|
||||
assistant = AnthropicAssistantContextAggregator(context, params=assistant_params)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
def _get_llm_invocation_params(
|
||||
self, context: OpenAILLMContext | LLMContext
|
||||
) -> AnthropicLLMInvocationParams:
|
||||
# Universal LLMContext
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
|
||||
params = adapter.get_llm_invocation_params(
|
||||
context, enable_prompt_caching=self._settings["enable_prompt_caching"]
|
||||
)
|
||||
return params
|
||||
|
||||
# Anthropic-specific context
|
||||
messages = (
|
||||
context.get_messages_with_cache_control_markers()
|
||||
if self._settings["enable_prompt_caching"]
|
||||
else context.messages
|
||||
)
|
||||
return AnthropicLLMInvocationParams(
|
||||
system=context.system,
|
||||
messages=messages,
|
||||
tools=context.tools or [],
|
||||
)
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
|
||||
# completion_tokens. We also estimate the completion tokens from output text
|
||||
# and use that estimate if we are interrupted, because we almost certainly won't
|
||||
@@ -294,24 +330,22 @@ class AnthropicLLMService(LLMService):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
|
||||
params_from_context = self._get_llm_invocation_params(context)
|
||||
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
context_type_for_logging = "universal"
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
else:
|
||||
context_type_for_logging = "LLM-specific"
|
||||
messages_for_logging = context.get_messages_for_logging()
|
||||
logger.debug(
|
||||
f"{self}: Generating chat [{context.system}] | {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from {context_type_for_logging} context [{params_from_context['system']}] | {messages_for_logging}"
|
||||
)
|
||||
|
||||
messages = context.messages
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
messages = context.get_messages_with_cache_control_markers()
|
||||
|
||||
api_call = self._client.messages.create
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
api_call = self._client.beta.prompt_caching.messages.create
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
params = {
|
||||
"tools": context.tools or [],
|
||||
"system": context.system,
|
||||
"messages": messages,
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"stream": True,
|
||||
@@ -320,9 +354,12 @@ class AnthropicLLMService(LLMService):
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
# Messages, system, tools
|
||||
params.update(params_from_context)
|
||||
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
response = await self._create_message_stream(api_call, params)
|
||||
response = await self._create_message_stream(self._client.messages.create, params)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -405,7 +442,10 @@ class AnthropicLLMService(LLMService):
|
||||
prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
)
|
||||
if total_input_tokens >= 1024:
|
||||
context.turns_above_cache_threshold += 1
|
||||
if hasattr(
|
||||
context, "turns_above_cache_threshold"
|
||||
): # LLMContext doesn't have this attribute
|
||||
context.turns_above_cache_threshold += 1
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
@@ -451,20 +491,14 @@ class AnthropicLLMService(LLMService):
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for Anthropic.")
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = AnthropicLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
# This is only useful in very simple pipelines because it creates
|
||||
# a new context. Generally we want a context manager to catch
|
||||
# UserImageRawFrames coming through the pipeline and add them
|
||||
# to the context.
|
||||
context = AnthropicLLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, LLMEnablePromptCachingFrame):
|
||||
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
|
||||
self._settings["enable_prompt_caching_beta"] = frame.enable
|
||||
self._settings["enable_prompt_caching"] = frame.enable
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -585,22 +619,6 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
self._restructure_from_openai_messages()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AnthropicLLMContext":
|
||||
"""Create context from a vision image frame.
|
||||
|
||||
Args:
|
||||
frame: The vision image frame to process.
|
||||
|
||||
Returns:
|
||||
New Anthropic context with the image message.
|
||||
"""
|
||||
context = cls()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
return context
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
"""Set the messages list and reset cache tracking.
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -52,6 +52,10 @@ def language_to_async_language(language: Language) -> Optional[str]:
|
||||
"""
|
||||
BASE_LANGUAGES = {
|
||||
Language.EN: "en",
|
||||
Language.FR: "fr",
|
||||
Language.ES: "es",
|
||||
Language.DE: "de",
|
||||
Language.IT: "it",
|
||||
}
|
||||
|
||||
result = BASE_LANGUAGES.get(language)
|
||||
@@ -115,7 +119,6 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
"""
|
||||
super().__init__(
|
||||
aggregate_sentences=aggregate_sentences,
|
||||
push_text_frames=False,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
sample_rate=sample_rate,
|
||||
@@ -271,7 +274,7 @@ class AsyncAITTSService(InterruptibleTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
|
||||
async def _receive_messages(self):
|
||||
|
||||
@@ -25,7 +25,10 @@ from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
from pipecat.adapters.services.bedrock_adapter import (
|
||||
AWSBedrockLLMAdapter,
|
||||
AWSBedrockLLMInvocationParams,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
@@ -39,7 +42,6 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
@@ -180,22 +182,6 @@ class AWSBedrockLLMContext(OpenAILLMContext):
|
||||
self._restructure_from_openai_messages()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AWSBedrockLLMContext":
|
||||
"""Create AWS Bedrock context from vision image frame.
|
||||
|
||||
Args:
|
||||
frame: The vision image frame to convert.
|
||||
|
||||
Returns:
|
||||
New AWS Bedrock LLM context instance.
|
||||
"""
|
||||
context = cls()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
return context
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
"""Set the messages list and restructure for Bedrock format.
|
||||
|
||||
@@ -399,9 +385,33 @@ class AWSBedrockLLMContext(OpenAILLMContext):
|
||||
elif isinstance(content, list):
|
||||
new_content = []
|
||||
for item in content:
|
||||
# fix empty text
|
||||
if item.get("type", "") == "text":
|
||||
text_content = item["text"] if item["text"] != "" else "(empty)"
|
||||
new_content.append({"text": text_content})
|
||||
# handle image_url -> image conversion
|
||||
if item["type"] == "image_url":
|
||||
new_item = {
|
||||
"image": {
|
||||
"format": "jpeg",
|
||||
"source": {
|
||||
"bytes": base64.b64decode(item["image_url"]["url"].split(",")[1])
|
||||
},
|
||||
}
|
||||
}
|
||||
new_content.append(new_item)
|
||||
# In the case where there's a single image in the list (like what
|
||||
# would result from a UserImageRawFrame), ensure that the image
|
||||
# comes before text
|
||||
image_indices = [i for i, item in enumerate(new_content) if "image" in item]
|
||||
text_indices = [i for i, item in enumerate(new_content) if "text" in item]
|
||||
if len(image_indices) == 1 and text_indices:
|
||||
img_idx = image_indices[0]
|
||||
first_txt_idx = text_indices[0]
|
||||
if img_idx > first_txt_idx:
|
||||
# Move image before the first text
|
||||
image_item = new_content.pop(img_idx)
|
||||
new_content.insert(first_txt_idx, image_item)
|
||||
return {"role": message["role"], "content": new_content}
|
||||
|
||||
return message
|
||||
@@ -569,7 +579,7 @@ class AWSBedrockLLMContext(OpenAILLMContext):
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item.get("image"):
|
||||
item["source"]["bytes"] = "..."
|
||||
item["image"]["source"]["bytes"] = "..."
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
@@ -792,79 +802,64 @@ class AWSBedrockLLMService(LLMService):
|
||||
"""
|
||||
return True
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context. If both are provided, the
|
||||
one in the context takes precedence.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
"""
|
||||
try:
|
||||
messages = []
|
||||
system = []
|
||||
if isinstance(context, LLMContext):
|
||||
# Future code will be something like this:
|
||||
# adapter = self.get_llm_adapter()
|
||||
# params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
# messages = params["messages"]
|
||||
# system = params["system_instruction"]
|
||||
raise NotImplementedError(
|
||||
"Universal LLMContext is not yet supported for AWS Bedrock."
|
||||
)
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) or system_instruction
|
||||
messages = []
|
||||
system = []
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
messages = params["messages"]
|
||||
system = params["system"] # [{"text": "system message"}]
|
||||
else:
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system", None) # [{"text": "system message"}]
|
||||
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
# Determine if we're using Claude or Nova based on model ID
|
||||
model_id = self.model_name
|
||||
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": model_id,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"topP": 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
if system:
|
||||
request_params["system"] = [{"text": system}]
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
) as client:
|
||||
# Call Bedrock without streaming
|
||||
response = await client.converse(**request_params)
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
) as client:
|
||||
# Call Bedrock without streaming
|
||||
response = await client.converse(**request_params)
|
||||
|
||||
# Extract the response text
|
||||
if (
|
||||
"output" in response
|
||||
and "message" in response["output"]
|
||||
and "content" in response["output"]["message"]
|
||||
):
|
||||
content = response["output"]["message"]["content"]
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
# Extract the response text
|
||||
if (
|
||||
"output" in response
|
||||
and "message" in response["output"]
|
||||
and "content" in response["output"]["message"]
|
||||
):
|
||||
content = response["output"]["message"]["content"]
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Bedrock summary generation failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _create_converse_stream(self, client, request_params):
|
||||
@@ -880,7 +875,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
if self._retry_on_timeout:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
await client.converse_stream(**request_params), timeout=self._retry_timeout_secs
|
||||
client.converse_stream(**request_params), timeout=self._retry_timeout_secs
|
||||
)
|
||||
return response
|
||||
except (ReadTimeoutError, asyncio.TimeoutError) as e:
|
||||
@@ -939,8 +934,25 @@ class AWSBedrockLLMService(LLMService):
|
||||
}
|
||||
}
|
||||
|
||||
def _get_llm_invocation_params(
|
||||
self, context: OpenAILLMContext | LLMContext
|
||||
) -> AWSBedrockLLMInvocationParams:
|
||||
# Universal LLMContext
|
||||
if isinstance(context, LLMContext):
|
||||
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
|
||||
params = adapter.get_llm_invocation_params(context)
|
||||
return params
|
||||
|
||||
# AWS Bedrock-specific context
|
||||
return AWSBedrockLLMInvocationParams(
|
||||
system=getattr(context, "system", None),
|
||||
messages=context.messages,
|
||||
tools=context.tools or [],
|
||||
tool_choice=context.tool_choice,
|
||||
)
|
||||
|
||||
@traced_llm
|
||||
async def _process_context(self, context: AWSBedrockLLMContext):
|
||||
async def _process_context(self, context: AWSBedrockLLMContext | LLMContext):
|
||||
# Usage tracking
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
@@ -957,6 +969,12 @@ class AWSBedrockLLMService(LLMService):
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
params_from_context = self._get_llm_invocation_params(context)
|
||||
messages = params_from_context["messages"]
|
||||
system = params_from_context["system"]
|
||||
tools = params_from_context["tools"]
|
||||
tool_choice = params_from_context["tool_choice"]
|
||||
|
||||
# Set up inference config
|
||||
inference_config = {
|
||||
"maxTokens": self._settings["max_tokens"],
|
||||
@@ -967,17 +985,18 @@ class AWSBedrockLLMService(LLMService):
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"modelId": self.model_name,
|
||||
"messages": context.messages,
|
||||
"messages": messages,
|
||||
"inferenceConfig": inference_config,
|
||||
"additionalModelRequestFields": self._settings["additional_model_request_fields"],
|
||||
}
|
||||
|
||||
# Add system message
|
||||
request_params["system"] = context.system
|
||||
if system:
|
||||
request_params["system"] = system
|
||||
|
||||
# Check if messages contain tool use or tool result content blocks
|
||||
has_tool_content = False
|
||||
for message in context.messages:
|
||||
for message in messages:
|
||||
if isinstance(message.get("content"), list):
|
||||
for content_item in message["content"]:
|
||||
if "toolUse" in content_item or "toolResult" in content_item:
|
||||
@@ -987,7 +1006,6 @@ class AWSBedrockLLMService(LLMService):
|
||||
break
|
||||
|
||||
# Handle tools: use current tools, or no-op if tool content exists but no current tools
|
||||
tools = context.tools or []
|
||||
if has_tool_content and not tools:
|
||||
tools = [self._create_no_op_tool()]
|
||||
using_noop_tool = True
|
||||
@@ -996,17 +1014,15 @@ class AWSBedrockLLMService(LLMService):
|
||||
tool_config = {"tools": tools}
|
||||
|
||||
# Only add tool_choice if we have real tools (not just no-op)
|
||||
if not using_noop_tool and context.tool_choice:
|
||||
if context.tool_choice == "auto":
|
||||
if not using_noop_tool and tool_choice:
|
||||
if tool_choice == "auto":
|
||||
tool_config["toolChoice"] = {"auto": {}}
|
||||
elif context.tool_choice == "none":
|
||||
elif tool_choice == "none":
|
||||
# Skip adding toolChoice for "none"
|
||||
pass
|
||||
elif (
|
||||
isinstance(context.tool_choice, dict) and "function" in context.tool_choice
|
||||
):
|
||||
elif isinstance(tool_choice, dict) and "function" in tool_choice:
|
||||
tool_config["toolChoice"] = {
|
||||
"tool": {"name": context.tool_choice["function"]["name"]}
|
||||
"tool": {"name": tool_choice["function"]["name"]}
|
||||
}
|
||||
|
||||
request_params["toolConfig"] = tool_config
|
||||
@@ -1015,7 +1031,17 @@ class AWSBedrockLLMService(LLMService):
|
||||
if self._settings["latency"] in ["standard", "optimized"]:
|
||||
request_params["performanceConfig"] = {"latency": self._settings["latency"]}
|
||||
|
||||
logger.debug(f"Calling AWS Bedrock model with: {request_params}")
|
||||
# Log request params with messages redacted for logging
|
||||
if isinstance(context, LLMContext):
|
||||
adapter = self.get_llm_adapter()
|
||||
context_type_for_logging = "universal"
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
else:
|
||||
context_type_for_logging = "LLM-specific"
|
||||
messages_for_logging = context.get_messages_for_logging()
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from {context_type_for_logging} context [{system}] | {messages_for_logging}"
|
||||
)
|
||||
|
||||
async with self._aws_session.client(
|
||||
service_name="bedrock-runtime", **self._aws_params
|
||||
@@ -1123,15 +1149,9 @@ class AWSBedrockLLMService(LLMService):
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context)
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Bedrock.")
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = AWSBedrockLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
# This is only useful in very simple pipelines because it creates
|
||||
# a new context. Generally we want a context manager to catch
|
||||
# UserImageRawFrames coming through the pipeline and add them
|
||||
# to the context.
|
||||
context = AWSBedrockLLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
|
||||
@@ -532,9 +532,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
logger.debug(f"{self} Other message type received: {headers}")
|
||||
logger.debug(f"{self} Payload: {payload}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.error(
|
||||
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
|
||||
)
|
||||
logger.error(f"{self} WebSocket connection closed in receive loop: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} Unexpected error in receive loop: {e}")
|
||||
|
||||
@@ -247,13 +247,14 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._ready_to_send_context = False
|
||||
self._handling_bot_stopped_speaking = False
|
||||
self._triggering_assistant_response = False
|
||||
self._assistant_response_trigger_audio: Optional[bytes] = (
|
||||
None # Not cleared on _disconnect()
|
||||
)
|
||||
self._disconnecting = False
|
||||
self._connected_time: Optional[float] = None
|
||||
self._wants_connection = False
|
||||
|
||||
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
|
||||
with wave.open(file_path.open("rb"), "rb") as wav_file:
|
||||
self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
#
|
||||
# standard AIService frame handling
|
||||
#
|
||||
@@ -1099,20 +1100,13 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
self._triggering_assistant_response = True
|
||||
|
||||
# Read audio bytes, if we don't already have them cached
|
||||
if not self._assistant_response_trigger_audio:
|
||||
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
|
||||
with wave.open(file_path.open("rb"), "rb") as wav_file:
|
||||
self._assistant_response_trigger_audio = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
# Send the trigger audio, if we're fully connected and set up
|
||||
if self._connected_time is not None:
|
||||
if self._connected_time:
|
||||
await self._send_assistant_response_trigger()
|
||||
|
||||
async def _send_assistant_response_trigger(self):
|
||||
if (
|
||||
not self._assistant_response_trigger_audio or self._connected_time is None
|
||||
): # should never happen
|
||||
if not self._connected_time:
|
||||
# should never happen
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
@@ -21,13 +21,13 @@ from pipecat.frames.frames import (
|
||||
DataFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
@@ -306,7 +306,7 @@ class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
StartInterruptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
TextFrame,
|
||||
|
||||
@@ -19,6 +19,7 @@ from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
@@ -140,6 +141,7 @@ class AzureSTTService(STTService):
|
||||
self._speech_recognizer = SpeechRecognizer(
|
||||
speech_config=self._speech_config, audio_config=audio_config
|
||||
)
|
||||
self._speech_recognizer.recognizing.connect(self._on_handle_recognizing)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
|
||||
@@ -197,3 +199,15 @@ class AzureSTTService(STTService):
|
||||
self._handle_transcription(event.result.text, True, language), self.get_event_loop()
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
|
||||
def _on_handle_recognizing(self, event):
|
||||
if event.result.reason == ResultReason.RecognizingSpeech and len(event.result.text) > 0:
|
||||
language = getattr(event.result, "language", None) or self._settings.get("language")
|
||||
frame = InterimTranscriptionFrame(
|
||||
event.result.text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=event,
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
|
||||
@@ -20,8 +20,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -371,7 +371,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
if self._context_id:
|
||||
|
||||
@@ -25,9 +25,9 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -460,7 +460,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
if isinstance(frame, TTSStoppedFrame):
|
||||
await self.add_word_timestamps([("Reset", 0)])
|
||||
@@ -549,7 +549,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
"""Handle interruption by closing the current context."""
|
||||
await super()._handle_interruption(frame, direction)
|
||||
|
||||
@@ -558,7 +558,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
|
||||
logger.trace(f"Closing context {self._context_id} due to interruption")
|
||||
try:
|
||||
# ElevenLabs requires that Pipecat manages the contexts and closes them
|
||||
# when they're not longer in use. Since a StartInterruptionFrame is pushed
|
||||
# when they're not longer in use. Since an InterruptionFrame is pushed
|
||||
# every time the user speaks, we'll use this as a trigger to close the context
|
||||
# and reset the state.
|
||||
# Note: We do not need to call remove_audio_context here, as the context is
|
||||
@@ -856,7 +856,7 @@ class ElevenLabsHttpTTSService(WordTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (StartInterruptionFrame, TTSStoppedFrame)):
|
||||
if isinstance(frame, (InterruptionFrame, TTSStoppedFrame)):
|
||||
# Reset timing on interruption or stop
|
||||
self._reset_state()
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -259,7 +259,7 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
return self._websocket
|
||||
raise Exception("Websocket not connected")
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
|
||||
await super()._handle_interruption(frame, direction)
|
||||
await self.stop_all_metrics()
|
||||
self._request_id = None
|
||||
|
||||
@@ -33,6 +33,8 @@ from pipecat.frames.frames import (
|
||||
InputAudioRawFrame,
|
||||
InputImageRawFrame,
|
||||
InputTextRawFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
@@ -40,7 +42,6 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
@@ -738,6 +739,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
# Support just one tool call per context frame for now
|
||||
tool_result_message = context.messages[-1]
|
||||
await self._tool_result(tool_result_message)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
raise NotImplementedError(
|
||||
"Universal LLMContext is not yet supported for Gemini Multimodal Live."
|
||||
)
|
||||
elif isinstance(frame, InputTextRawFrame):
|
||||
await self._send_user_text(frame.text)
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -747,7 +752,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
elif isinstance(frame, InputImageRawFrame):
|
||||
await self._send_user_video(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
|
||||
@@ -13,6 +13,7 @@ supporting multiple languages, custom vocabulary, and various audio processing o
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, Literal, Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -173,8 +174,6 @@ class _InputParamsDescriptor:
|
||||
"""Descriptor for backward compatibility with deprecation warning."""
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
@@ -208,7 +207,7 @@ class GladiaSTTService(STTService):
|
||||
api_key: str,
|
||||
region: Literal["us-west", "eu-west"] | None = None,
|
||||
url: str = "https://api.gladia.io/v2/live",
|
||||
confidence: float = 0.5,
|
||||
confidence: Optional[float] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
model: str = "solaria-1",
|
||||
params: Optional[GladiaInputParams] = None,
|
||||
@@ -224,6 +223,11 @@ class GladiaSTTService(STTService):
|
||||
region: Region used to process audio. eu-west or us-west. Defaults to eu-west.
|
||||
url: Gladia API URL. Defaults to "https://api.gladia.io/v2/live".
|
||||
confidence: Minimum confidence threshold for transcriptions (0.0-1.0).
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
The 'confidence' parameter is deprecated and will be removed in a future version.
|
||||
No confidence threshold is applied.
|
||||
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
model: Model to use for transcription. Defaults to "solaria-1".
|
||||
params: Additional configuration parameters for Gladia service.
|
||||
@@ -236,7 +240,6 @@ class GladiaSTTService(STTService):
|
||||
|
||||
params = params or GladiaInputParams()
|
||||
|
||||
# Warn about deprecated language parameter if it's used
|
||||
if params.language is not None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
@@ -247,11 +250,20 @@ class GladiaSTTService(STTService):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if confidence:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The 'confidence' parameter is deprecated and will be removed in a future version. "
|
||||
"No confidence threshold is applied.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._region = region
|
||||
self._url = url
|
||||
self.set_model_name(model)
|
||||
self._confidence = confidence
|
||||
self._params = params
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
@@ -575,43 +587,40 @@ class GladiaSTTService(STTService):
|
||||
|
||||
elif content["type"] == "transcript":
|
||||
utterance = content["data"]["utterance"]
|
||||
confidence = utterance.get("confidence", 0)
|
||||
language = utterance["language"]
|
||||
transcript = utterance["text"]
|
||||
is_final = content["data"]["is_final"]
|
||||
if confidence >= self._confidence:
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
if is_final:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=is_final,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=is_final,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
language,
|
||||
result=content,
|
||||
)
|
||||
)
|
||||
elif content["type"] == "translation":
|
||||
translated_utterance = content["data"]["translated_utterance"]
|
||||
original_language = content["data"]["original_language"]
|
||||
translated_language = translated_utterance["language"]
|
||||
confidence = translated_utterance.get("confidence", 0)
|
||||
translation = translated_utterance["text"]
|
||||
if translated_language != original_language and confidence >= self._confidence:
|
||||
if translated_language != original_language:
|
||||
await self.push_frame(
|
||||
TranslationFrame(
|
||||
translation, "", time_now_iso8601(), translated_language
|
||||
|
||||
@@ -36,7 +36,6 @@ from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
@@ -733,17 +732,11 @@ class GoogleLLMService(LLMService):
|
||||
def _create_client(self, api_key: str, http_options: Optional[HttpOptions] = None):
|
||||
self._client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context. If both are provided, the
|
||||
one in the context takes precedence.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
@@ -758,7 +751,7 @@ class GoogleLLMService(LLMService):
|
||||
else:
|
||||
context = GoogleLLMContext.upgrade_to_google(context)
|
||||
messages = context.messages
|
||||
system = getattr(context, "system_message", None) or system_instruction
|
||||
system = getattr(context, "system_message", None)
|
||||
|
||||
generation_config = GenerateContentConfig(system_instruction=system)
|
||||
|
||||
@@ -858,8 +851,7 @@ class GoogleLLMService(LLMService):
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncIterator[GenerateContentResponse]:
|
||||
logger.debug(
|
||||
# f"{self}: Generating chat [{self._system_instruction}] | {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from OpenAI context {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from LLM-specific context [{context.system_message}] | {context.get_messages_for_logging()}"
|
||||
)
|
||||
|
||||
params = GeminiLLMInvocationParams(
|
||||
@@ -874,13 +866,12 @@ class GoogleLLMService(LLMService):
|
||||
self, context: LLMContext
|
||||
) -> AsyncIterator[GenerateContentResponse]:
|
||||
adapter = self.get_llm_adapter()
|
||||
logger.debug(
|
||||
# f"{self}: Generating chat [{self._system_instruction}] | {context.get_messages_for_logging()}"
|
||||
f"{self}: Generating chat from universal context {adapter.get_messages_for_logging(context)}"
|
||||
)
|
||||
|
||||
params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params(context)
|
||||
|
||||
logger.debug(
|
||||
f"{self}: Generating chat from universal context [{params['system_instruction']}] | {adapter.get_messages_for_logging(context)}"
|
||||
)
|
||||
|
||||
return await self._stream_content(params)
|
||||
|
||||
@traced_llm
|
||||
@@ -1021,15 +1012,6 @@ class GoogleLLMService(LLMService):
|
||||
# NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal
|
||||
# LLMContext with it
|
||||
context = GoogleLLMContext(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
# This is only useful in very simple pipelines because it creates
|
||||
# a new context. Generally we want a context manager to catch
|
||||
# UserImageRawFrames coming through the pipeline and add them
|
||||
# to the context.
|
||||
context = GoogleLLMContext()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
|
||||
@@ -83,14 +83,23 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
self._api_key = self._get_api_token(credentials, credentials_path)
|
||||
|
||||
super().__init__(
|
||||
api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
|
||||
api_key=self._api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_base_url(params: InputParams) -> str:
|
||||
"""Construct the base URL for Vertex AI API."""
|
||||
# Determine the correct API host based on location
|
||||
if params.location == "global":
|
||||
api_host = "aiplatform.googleapis.com"
|
||||
else:
|
||||
api_host = f"{params.location}-aiplatform.googleapis.com"
|
||||
return (
|
||||
f"https://{params.location}-aiplatform.googleapis.com/v1/"
|
||||
f"https://{api_host}/v1/"
|
||||
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
|
||||
)
|
||||
|
||||
@@ -118,12 +127,14 @@ class GoogleVertexLLMService(OpenAILLMService):
|
||||
if credentials:
|
||||
# Parse and load credentials from JSON string
|
||||
creds = service_account.Credentials.from_service_account_info(
|
||||
json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
json.loads(credentials),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
elif credentials_path:
|
||||
# Load credentials from JSON file
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
credentials_path,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -500,9 +500,11 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
Parameters:
|
||||
language: Language for synthesis. Defaults to English.
|
||||
speaking_rate: The speaking rate, in the range [0.25, 4.0].
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN
|
||||
speaking_rate: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -510,6 +512,7 @@ class GoogleTTSService(TTSService):
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
voice_id: str = "en-US-Chirp3-HD-Charon",
|
||||
voice_cloning_key: Optional[str] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
@@ -520,6 +523,7 @@ class GoogleTTSService(TTSService):
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to Google Cloud service account JSON file.
|
||||
voice_id: Google TTS voice identifier (e.g., "en-US-Chirp3-HD-Charon").
|
||||
voice_cloning_key: The voice cloning key for Chirp 3 custom voices.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses default.
|
||||
params: Language configuration parameters.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
@@ -532,8 +536,10 @@ class GoogleTTSService(TTSService):
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-US",
|
||||
"speaking_rate": params.speaking_rate,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._voice_cloning_key = voice_cloning_key
|
||||
self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
|
||||
credentials, credentials_path
|
||||
)
|
||||
@@ -600,15 +606,24 @@ class GoogleTTSService(TTSService):
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._settings["language"], name=self._voice_id
|
||||
)
|
||||
if self._voice_cloning_key:
|
||||
voice_clone_params = texttospeech_v1.VoiceCloneParams(
|
||||
voice_cloning_key=self._voice_cloning_key
|
||||
)
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._settings["language"], voice_clone=voice_clone_params
|
||||
)
|
||||
else:
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._settings["language"], name=self._voice_id
|
||||
)
|
||||
|
||||
streaming_config = texttospeech_v1.StreamingSynthesizeConfig(
|
||||
voice=voice,
|
||||
streaming_audio_config=texttospeech_v1.StreamingAudioConfig(
|
||||
audio_encoding=texttospeech_v1.AudioEncoding.PCM,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
speaking_rate=self._settings["speaking_rate"],
|
||||
),
|
||||
)
|
||||
config_request = texttospeech_v1.StreamingSynthesizeRequest(
|
||||
|
||||
@@ -240,6 +240,7 @@ class HeyGenVideoService(AIService):
|
||||
# As soon as we receive actual audio, the base output transport will create a
|
||||
# BotStartedSpeakingFrame, which we can use as a signal for the TTFB metrics.
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -36,15 +36,15 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallResultProperties,
|
||||
FunctionCallsStartedFrame,
|
||||
InterruptionFrame,
|
||||
LLMConfigureOutputFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -195,18 +195,24 @@ class LLMService(AIService):
|
||||
"""
|
||||
return self._adapter
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, system_instruction: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
|
||||
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
|
||||
|
||||
Args:
|
||||
message: The message content.
|
||||
|
||||
Returns:
|
||||
A LLMSpecificMessage instance.
|
||||
"""
|
||||
return self.get_llm_adapter().create_llm_specific_message(message)
|
||||
|
||||
async def run_inference(self, context: LLMContext | OpenAILLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
|
||||
|
||||
Must be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
context: The LLM context containing conversation history.
|
||||
system_instruction: Optional system instruction to guide the LLM's
|
||||
behavior. You could also (again, optionally) provide a system
|
||||
instruction directly in the context.
|
||||
|
||||
Returns:
|
||||
The LLM's response as a string, or None if no response is generated.
|
||||
@@ -274,7 +280,7 @@ class LLMService(AIService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, LLMConfigureOutputFrame):
|
||||
self._skip_tts = frame.skip_tts
|
||||
@@ -291,7 +297,7 @@ class LLMService(AIService):
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
async def _handle_interruptions(self, _: StartInterruptionFrame):
|
||||
async def _handle_interruptions(self, _: InterruptionFrame):
|
||||
for function_name, entry in self._functions.items():
|
||||
if entry.cancel_on_interruption:
|
||||
await self._cancel_function_call(function_name)
|
||||
|
||||
@@ -16,8 +16,8 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -180,7 +180,7 @@ class LmntTTSService(InterruptibleTTSService):
|
||||
direction: The direction to push the frame.
|
||||
"""
|
||||
await super().push_frame(frame, direction)
|
||||
if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
|
||||
if isinstance(frame, (TTSStoppedFrame, InterruptionFrame)):
|
||||
self._started = False
|
||||
|
||||
async def _connect(self):
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -28,6 +28,8 @@ except ModuleNotFoundError as e:
|
||||
logger.error("In order to use an MCP client, you need to `pip install pipecat-ai[mcp]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
|
||||
|
||||
|
||||
class MCPClient(BaseObject):
|
||||
"""Client for Model Context Protocol (MCP) servers.
|
||||
@@ -42,7 +44,7 @@ class MCPClient(BaseObject):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_params: Tuple[StdioServerParameters, SseServerParameters, StreamableHttpParameters],
|
||||
server_params: ServerParameters,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the MCP client with server parameters.
|
||||
|
||||
@@ -57,16 +57,18 @@ class MistralLLMService(OpenAILLMService):
|
||||
logger.debug(f"Creating Mistral client with api {base_url}")
|
||||
return super().create_client(api_key, base_url, **kwargs)
|
||||
|
||||
def _apply_mistral_assistant_prefix(
|
||||
def _apply_mistral_fixups(
|
||||
self, messages: List[ChatCompletionMessageParam]
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
"""Apply Mistral's assistant message prefix requirement.
|
||||
"""Apply fixups to messages to meet Mistral-specific requirements.
|
||||
|
||||
Mistral requires assistant messages to have prefix=True when they
|
||||
are the final message in a conversation. According to Mistral's API:
|
||||
- Assistant messages with prefix=True MUST be the last message
|
||||
- Only add prefix=True to the final assistant message when needed
|
||||
- This allows assistant messages to be accepted as the last message
|
||||
1. A "tool"-role message must be followed by an assistant message.
|
||||
|
||||
2. "system"-role messages must only appear at the start of a
|
||||
conversation.
|
||||
|
||||
3. Assistant messages must have prefix=True when they are the final
|
||||
message in a conversation (but at no other point).
|
||||
|
||||
Args:
|
||||
messages: The original list of messages.
|
||||
@@ -80,6 +82,25 @@ class MistralLLMService(OpenAILLMService):
|
||||
# Create a copy to avoid modifying the original
|
||||
fixed_messages = [dict(msg) for msg in messages]
|
||||
|
||||
# Ensure all tool responses are followed by an assistant message
|
||||
assistant_insert_indices = []
|
||||
for i, msg in enumerate(fixed_messages):
|
||||
if msg.get("role") == "tool":
|
||||
# If this is the last message or the next message is not assistant
|
||||
if i == len(fixed_messages) - 1 or fixed_messages[i + 1].get("role") != "assistant":
|
||||
assistant_insert_indices.append(i + 1)
|
||||
for idx in reversed(assistant_insert_indices):
|
||||
fixed_messages.insert(idx, {"role": "assistant", "content": " "})
|
||||
|
||||
# Convert any "system" messages that aren't at the start (i.e., after the initial contiguous block) to "user"
|
||||
first_non_system_idx = next(
|
||||
(i for i, msg in enumerate(fixed_messages) if msg.get("role") != "system"),
|
||||
len(fixed_messages),
|
||||
)
|
||||
for i, msg in enumerate(fixed_messages):
|
||||
if msg.get("role") == "system" and i >= first_non_system_idx:
|
||||
msg["role"] = "user"
|
||||
|
||||
# Get the last message
|
||||
last_message = fixed_messages[-1]
|
||||
|
||||
@@ -158,7 +179,7 @@ class MistralLLMService(OpenAILLMService):
|
||||
- Core completion settings
|
||||
"""
|
||||
# Apply Mistral's assistant prefix requirement for API compatibility
|
||||
fixed_messages = self._apply_mistral_assistant_prefix(params_from_context["messages"])
|
||||
fixed_messages = self._apply_mistral_fixups(params_from_context["messages"])
|
||||
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user