Compare commits

...

144 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
19354c6f2d Merge pull request #2078 from pipecat-ai/aleix/hotfix-0.0.73
just a quick hotfix for 0.0.73
2025-06-26 17:31:40 -07:00
Aleix Conchillo Flaqué
0b2079ad41 update CHANGELOG for 0.0.73 2025-06-26 17:02:12 -07:00
Aleix Conchillo Flaqué
5f18c3af70 OpenAIRealtimeLLMContext: fix circular dependency 2025-06-26 17:01:45 -07:00
Aleix Conchillo Flaqué
0a40285d43 update FrameProcessor.watchdog_timers_enabled references 2025-06-26 16:26:12 -07:00
Vanessa Pyne
5b1c328541 Merge pull request #2075 from pipecat-ai/vp-mcp-lint
mcp_service: lint
2025-06-26 15:25:39 -05:00
vipyne
37929533af mcp_service: lint 2025-06-26 15:00:20 -05:00
Vanessa Pyne
3b92113680 Merge pull request #2030 from yousifa/mcp-streaming-http
MCPClient streamable_http transport support
2025-06-26 14:57:31 -05:00
Yousif
46b52cb9bb Merge branch 'main' into mcp-streaming-http 2025-06-26 12:30:43 -07:00
Mark Backman
f0bcc9d9ba Add MCPClient docstrings. Removed google specific cleanup, changed example to openai 2025-06-26 12:29:45 -07:00
Yousif Astarabadi
1cac028bfe example using http transport for mcp client 2025-06-26 12:16:35 -07:00
Yousif Astarabadi
4956886819 updated error message with StreamableHttpParameters 2025-06-26 12:16:28 -07:00
Yousif Astarabadi
c720cfc7c7 updated streamablehttp to use StreamableHttpParameters type 2025-06-26 12:16:26 -07:00
Yousif Astarabadi
8fcef5628f added streamablehttp support, bumped mcp version, added additional headers and streamable_http params to MCPClient 2025-06-26 12:16:19 -07:00
Aleix Conchillo Flaqué
c4a72802f0 Merge pull request #2074 from pipecat-ai/aleix/pipecat-0.0.72
update CHANGELOG for 0.0.72
2025-06-26 12:10:14 -07:00
Aleix Conchillo Flaqué
917394803c update CHANGELOG for 0.0.72 2025-06-26 11:42:52 -07:00
Mark Backman
01040ddcdd Merge pull request #2071 from pipecat-ai/mb/services-docstrings-update
Add/update docstrings to LLM services
2025-06-26 14:42:32 -04:00
Aleix Conchillo Flaqué
7947497f7e Merge pull request #2073 from a6kme/patch-1
Start HeartBeat when all processors have processed StartFrame
2025-06-26 11:34:46 -07:00
Aleix Conchillo Flaqué
539ca5856f Merge pull request #2072 from pipecat-ai/aleix/utils-watchdog-cleanup
utils(asyncio): simplify watchdog helpers
2025-06-26 11:29:21 -07:00
Abhishek
89c801f82c Start HeartBeat when all processors have processed StartFrame
Some of the processors like STTService and TTSService don't push StartFrame ahead in the pipeline, unless they have connected with their service providers. This delays StartFrame in downstream processors. 

If we receive HeartBeat frame before StartFrame, we will get AttributeError `'Processor' object has no attribute '_FrameProcessor__input_queue'`. 

Idea is to start HeartBeats after StartFrame has been processed by all the Processors in the pipeline.
2025-06-26 23:28:37 +05:30
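The ordering described in this commit message can be pictured with a minimal, self-contained sketch. It is not Pipecat's implementation: `ToyProcessor`, `ToyPipeline` and their methods are hypothetical names, and the only idea carried over is that heartbeats begin once every processor has handled its start frame.

```python
import asyncio


class ToyProcessor:
    """Hypothetical processor that may take a while to handle the start frame."""

    def __init__(self, name: str, startup_delay: float):
        self.name = name
        self.startup_delay = startup_delay
        self.started = asyncio.Event()
        self.heartbeats = 0
        self.heartbeats_before_start = 0

    async def start(self):
        # Simulates, for example, an STT/TTS service connecting to its provider.
        await asyncio.sleep(self.startup_delay)
        self.started.set()

    def on_heartbeat(self):
        # Count heartbeats that arrive before the processor is ready.
        if not self.started.is_set():
            self.heartbeats_before_start += 1
        self.heartbeats += 1


class ToyPipeline:
    """Hypothetical pipeline that delays heartbeats until all processors started."""

    def __init__(self, processors):
        self.processors = processors

    async def run(self, heartbeat_count: int = 3, period: float = 0.01):
        start_tasks = [asyncio.create_task(p.start()) for p in self.processors]
        # Wait until every processor reports that it has started.
        await asyncio.gather(*(p.started.wait() for p in self.processors))
        for _ in range(heartbeat_count):
            for p in self.processors:
                p.on_heartbeat()
            await asyncio.sleep(period)
        await asyncio.gather(*start_tasks)


def demo():
    async def main():
        procs = [ToyProcessor("stt", 0.05), ToyProcessor("tts", 0.02)]
        await ToyPipeline(procs).run()
        return procs

    return asyncio.run(main())
```

Calling `demo()` returns the processors; in this sketch none of them sees a heartbeat before it has started, because the heartbeat loop only begins after all `started` events are set.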
Aleix Conchillo Flaqué
3de4f22d34 utils(asyncio): simplify watchdog helpers 2025-06-26 09:40:42 -07:00
Mark Backman
0e4d2be98c Update AzureRealtimeBetaLLMService docstrings 2025-06-26 12:12:00 -04:00
Mark Backman
d8ce108ccd Update OpenAIRealtimeBetaLLMService docstrings 2025-06-26 12:06:47 -04:00
Mark Backman
d123cd4b2b Update GeminiMultimodalLiveLLMService docstrings 2025-06-26 11:47:30 -04:00
Aleix Conchillo Flaqué
4d34aa7cd6 Merge pull request #2069 from pipecat-ai/aleix/utils-asyncio-package
move things to new utils.asyncio package
2025-06-26 08:26:47 -07:00
Aleix Conchillo Flaqué
b860e94582 move things to new utils.asyncio package 2025-06-26 08:24:25 -07:00
Aleix Conchillo Flaqué
9d653e3788 Merge pull request #2068 from pipecat-ai/aleix/task-manager-dont-warn-reset-watchdog
TaskManager: don't warn on reset_watchdog()
2025-06-26 08:23:51 -07:00
Mark Backman
9e518cf2ba Update AWSNovaSonicLLMService docstrings 2025-06-26 11:21:18 -04:00
Mark Backman
2856372ad6 Update TogetherLLMService docstrings 2025-06-26 11:01:35 -04:00
Mark Backman
efbf574613 Update SambaNovaLLMService docstrings 2025-06-26 11:00:40 -04:00
Mark Backman
c018eb2f0e Update QwenLLMService docstrings 2025-06-26 10:57:42 -04:00
Mark Backman
d7bfe54b7c Update PerplexityLLMService docstrings 2025-06-26 10:56:48 -04:00
Mark Backman
137282b7a9 Update OpenRouterLLMService docstrings 2025-06-26 10:53:42 -04:00
Mark Backman
769f8c8f34 Update OpenPipeLLMService docstrings 2025-06-26 10:53:05 -04:00
Mark Backman
8b8a37ae7c Update OLLamaLLMService docstrings 2025-06-26 10:48:19 -04:00
Mark Backman
56e2b006f5 Update NimLLMService docstrings 2025-06-26 10:47:26 -04:00
Mark Backman
79cca05e43 Update GroqLLMService docstrings 2025-06-26 10:46:07 -04:00
Mark Backman
166c8e8e82 Update GrokLLMService docstrings 2025-06-26 10:39:46 -04:00
Mark Backman
9b64d2c325 Update GoogleLLMService docstrings 2025-06-26 10:37:22 -04:00
Mark Backman
03e3e9fae9 Update FireworksLLMService docstrings 2025-06-26 10:28:35 -04:00
Mark Backman
65234ae41a Update DeepSeekLLMService docstrings 2025-06-26 10:27:36 -04:00
Mark Backman
3828df8cf9 Update CerebrasLLMService docstrings 2025-06-26 10:26:42 -04:00
Mark Backman
9cbe85bf99 Update AzureLLMService docstrings 2025-06-26 10:25:17 -04:00
Mark Backman
7bf805b829 Update AWSBedrock docstrings 2025-06-26 10:23:40 -04:00
Mark Backman
990ee436e1 Add Anthropic docstrings 2025-06-26 07:42:22 -04:00
Mark Backman
1cd42066a6 Merge pull request #2067 from pipecat-ai/mb/update-docstrings-for-ref-docs
Update base service class docstrings for better docs auto-generation
2025-06-26 07:07:59 -04:00
Filipi da Silva Fuchter
ba43558049 Merge pull request #2066 from pipecat-ai/filipi/sentry_freeze_test
Enabling watchdog and sentry into the freeze-test
2025-06-26 08:01:51 -03:00
Mark Backman
951c8d34da Add special case handling for STT, TTS, LLM 2025-06-26 00:15:09 -04:00
Mark Backman
ac61139243 Add OpenAI LLM docstrings 2025-06-26 00:06:57 -04:00
Mark Backman
5b8f1fe3e3 Add Cartesia TTS docstrings 2025-06-25 23:50:55 -04:00
Mark Backman
0aa197e4a4 Add docstrings to DeepgramSTTService 2025-06-25 23:36:04 -04:00
Mark Backman
f04e058c96 Programmatically set the copyright date in docs 2025-06-25 23:29:37 -04:00
Mark Backman
6ef2ae12b7 Mock mcp imports 2025-06-25 23:29:37 -04:00
Mark Backman
fe6bbdaefe Skip dataclass attributes to remove duplicate entries 2025-06-25 23:29:37 -04:00
Mark Backman
cc66fddca9 Modify docs auto-gen rules to remove duplicate parameters listing 2025-06-25 23:29:37 -04:00
Mark Backman
04b70ddf13 Add MCPClient docstrings 2025-06-25 22:38:11 -04:00
Mark Backman
bb3bb8d9c6 Improve WebsocketService docstrings 2025-06-25 22:38:11 -04:00
Mark Backman
f80f62c7d1 Add VisionService docstrings 2025-06-25 22:38:11 -04:00
Mark Backman
2007ae4317 Add ImageGenService docstrings 2025-06-25 22:38:11 -04:00
Mark Backman
a1e5a1eff4 Add AIService docstrings 2025-06-25 22:38:11 -04:00
Mark Backman
691999b402 Add AIServices docstring 2025-06-25 22:38:11 -04:00
Mark Backman
33f3a4cea1 Add TTSService docstrings 2025-06-25 22:38:11 -04:00
Mark Backman
ab1d2dbe6a Add STTService docstrings 2025-06-25 22:27:07 -04:00
Mark Backman
f622b281d0 Make call_start_function a private function in llm_service 2025-06-25 22:23:13 -04:00
Mark Backman
fb12bf9b4c Update LLMService docstrings 2025-06-25 22:23:13 -04:00
Aleix Conchillo Flaqué
27af50087e TaskManager: don't warn on reset_watchdog() 2025-06-25 17:29:45 -07:00
Filipi Fuchter
03502bed52 Enabling watchdog and sentry into the freeze-test 2025-06-25 20:53:30 -03:00
Aleix Conchillo Flaqué
27c7e2d150 Merge pull request #2063 from pipecat-ai/aleix/watchdog-timers-remove-start-watchdog
no need to call start_watchdog() only reset_watchdog()
2025-06-25 16:47:44 -07:00
Aleix Conchillo Flaqué
e81d387971 TaskManager: rely on add_done_callback() 2025-06-25 16:44:20 -07:00
Aleix Conchillo Flaqué
ef1ade3a71 allow enabling watchdog timers per frame processor or task 2025-06-25 16:36:19 -07:00
Aleix Conchillo Flaqué
4f032f5b96 update keepalive times depending on watchdog timers 2025-06-25 15:55:16 -07:00
Aleix Conchillo Flaqué
72cb967780 update CHANGELOG with watchdog timers updates 2025-06-25 15:55:16 -07:00
Aleix Conchillo Flaqué
357934a644 watchdog timers are disabled by default use enable_watchdog_timers 2025-06-25 15:55:16 -07:00
Aleix Conchillo Flaqué
327973657f TaskManager: remove watchdog timer when main task is done 2025-06-25 11:26:21 -07:00
Aleix Conchillo Flaqué
d2730e6741 GoogleSTTService: cleanup request queues 2025-06-25 11:12:32 -07:00
Aleix Conchillo Flaqué
eb5ecab104 no need to call start_watchdog() only reset_watchdog() 2025-06-25 11:12:32 -07:00
Mark Backman
202055a9b8 Merge pull request #2065 from pipecat-ai/mb/fix-configdict-openai-realtime
fix: add missing ConfigDict import in openai_realtime_beta/events
2025-06-25 11:40:35 -04:00
Mark Backman
7034a9e3fd fix: add missing ConfigDict import in openai_realtime_beta/events 2025-06-25 11:32:29 -04:00
Filipi da Silva Fuchter
8f7ed12262 Merge pull request #2061 from pipecat-ai/not_force_bot_speaking
Not forcing the bot resume speaking in case we receive no transcription.
2025-06-24 20:57:46 -03:00
Aleix Conchillo Flaqué
96b5320ef9 Merge pull request #2055 from pipecat-ai/aleix/fix-sentry-async
SentryMetrics: send metrics to sentry asynchronously
2025-06-24 16:32:01 -07:00
Filipi Fuchter
d5cd742237 Not forcing the bot resume speaking in case we receive no transcription. 2025-06-24 20:12:49 -03:00
Aleix Conchillo Flaqué
1f1da8942d SentryMetrics: send metrics to sentry asynchronously 2025-06-24 15:56:08 -07:00
Mark Backman
7953e1e9d9 Merge pull request #2054 from pipecat-ai/mb/telnyx-catch-hangup-error
fix: Telnyx, catch error when user has hung up the call first
2025-06-24 18:04:19 -04:00
Mark Backman
d6f7ecc0a3 fix: Telnyx, catch error when user has hung up the call first 2025-06-24 17:28:00 -04:00
Mark Backman
3eed316049 Merge pull request #2020 from snova-jorgep/snova-jorgep/sambanova-integration
Add Sambanova LLM and STT integration
2025-06-24 17:04:24 -04:00
Jorge Piedrahita Ortiz
851cf079c3 Merge branch 'main' into snova-jorgep/sambanova-integration 2025-06-24 16:00:28 -05:00
jhpiedrahitao
dfb0da32a9 fmt 2025-06-24 15:59:40 -05:00
Aleix Conchillo Flaqué
f450da57e5 Merge pull request #2056 from pipecat-ai/khk/fix-22d
Update google libraries used in google audio-in examples
2025-06-24 13:47:59 -07:00
Aleix Conchillo Flaqué
2ec6b6c995 Merge pull request #2060 from pipecat-ai/aleix/watchdog-timeout-secs
FrameProcessor: use watchdog_timeout_secs
2025-06-24 13:36:39 -07:00
Aleix Conchillo Flaqué
53b769a8ec FrameProcessor: use watchdog_timeout_secs 2025-06-24 13:33:47 -07:00
Filipi da Silva Fuchter
4f9adc173a Merge pull request #2004 from pipecat-ai/filipi/pipeline_freeze
Pipeline freeze improvements
2025-06-24 17:20:38 -03:00
Filipi Fuchter
dc4a58877e Fixing merge conflict. 2025-06-24 17:12:40 -03:00
Filipi Fuchter
a6243a6fe7 Merge branch 'main' into filipi/pipeline_freeze
# Conflicts:
#	CHANGELOG.md
#	src/pipecat/pipeline/task.py
#	src/pipecat/processors/frame_processor.py
#	src/pipecat/transports/base_input.py
2025-06-24 17:11:21 -03:00
Aleix Conchillo Flaqué
cf5f1b541a Merge pull request #2049 from pipecat-ai/aleix/introduce-watchdog-timers
introduce watchdog timers
2025-06-24 13:00:57 -07:00
Filipi Fuchter
70e6c48233 Mentioning the fixes in the changelog. 2025-06-24 16:56:46 -03:00
Filipi Fuchter
51f7d14d0a Merge branch 'main' into filipi/pipeline_freeze 2025-06-24 16:44:07 -03:00
Filipi Fuchter
4853d5d1fc Handling the case where user stopped speaking but no new aggregation received. 2025-06-24 16:42:10 -03:00
Aleix Conchillo Flaqué
076a8938f0 add start_watchdog/reset_watchdog to tasks 2025-06-24 11:56:20 -07:00
Aleix Conchillo Flaqué
5a3457ba33 introduce task watchdog timers 2025-06-24 11:56:20 -07:00
Aleix Conchillo Flaqué
2fc224384d Merge pull request #2059 from pipecat-ai/aleix/heartbeatframe-control-frames
HeartbeatFrames are now control frames
2025-06-24 11:55:18 -07:00
Aleix Conchillo Flaqué
a4e6ea5a3f HeartbeatFrames are now control frames 2025-06-24 11:27:39 -07:00
Vanessa Pyne
d3c211f293 Merge pull request #2058 from pipecat-ai/vp-mcp-sse-up
follow up to #1887 - proper MCP SSE support
2025-06-24 13:06:01 -05:00
vipyne
20047c369e mcp: update examples to use SseServerParameter 2025-06-24 12:58:39 -05:00
vipyne
dd1ff237a8 lint mcp_service 2025-06-24 12:58:33 -05:00
Vanessa Pyne
39d80d0b0e Merge pull request #1887 from ezun-kim/feat/mcp-sse-params
Fix SSE server connection handling for MCP client
2025-06-24 12:58:05 -05:00
Kwindla Hultman Kramer
7a48316534 update google libraries used in google audio-in examples 2025-06-24 09:52:04 -07:00
Filipi da Silva Fuchter
031a93ac46 Merge pull request #2053 from pipecat-ai/sentry_dsn_environment_variable
Creating an environment variable for sentry dsn.
2025-06-24 12:10:20 -03:00
Mark Backman
ea6cc1aa95 Merge pull request #2052 from pipecat-ai/mb/11labs-keepalive
Send context_id when available in ElevenLabsTTSService keepalive message
2025-06-24 11:07:07 -04:00
Filipi Fuchter
365260ec44 Creating an environment variable for sentry dsn. 2025-06-24 11:57:14 -03:00
Mark Backman
2eb244c80a Send context_id when available in ElevenLabsTTSService keepalive message 2025-06-24 10:52:49 -04:00
Mark Backman
aee3011d61 Merge pull request #2037 from pipecat-ai/mb/11labs-close-context
Fix: Correctly close the context for ElevenLabsTTSService
2025-06-24 07:44:22 -04:00
Aleix Conchillo Flaqué
40496e7b0f Merge pull request #2034 from pipecat-ai/khk/pause-frames
small fix for processor pause/resume frames
2025-06-23 17:08:41 -07:00
Kwindla Hultman Kramer
6b24f89fa7 small fix for processor pause/resume frames 2025-06-23 16:44:32 -07:00
Filipi Fuchter
2097800042 Allowing to clear the turn analyser 2025-06-23 18:50:37 -03:00
Filipi Fuchter
6739318e68 Forcing user stopped speaking due to timeout to receive audio frame! 2025-06-23 18:50:02 -03:00
Filipi Fuchter
d0bd563d42 Logging the BaseException inside the cancel_task. 2025-06-23 18:48:44 -03:00
Filipi Fuchter
74280829fc Fixed an issue with the FastAPIWebsocketClient to disconnect in case the websocket is already closed. 2025-06-23 18:48:03 -03:00
Filipi Fuchter
3fde8880f2 Fixed a couple of places inside the FrameProcessor where we should not raise the exceptions. 2025-06-23 18:47:54 -03:00
Filipi Fuchter
98d39e0d38 Logging the last 10 frames received in case idle timeout is detected. 2025-06-23 18:47:17 -03:00
Filipi Fuchter
c9cebb5ffe Created an example for testing the bot and try to create freezing conditions. 2025-06-23 18:46:58 -03:00
Mark Backman
f52ac6e99c Merge pull request #1998 from pipecat-ai/mb/fix-38-smart-turn-fal 2025-06-23 17:15:29 -04:00
Mark Backman
787a6b1c6a Merge pull request #2038 from pipecat-ai/mb/openai-realtime-model-update
Update OpenAIRealtimeBetaLLMService model to gpt-4o-realtime-preview-…
2025-06-23 16:30:31 -04:00
Mark Backman
d00a91074e Update OpenAIRealtimeBetaLLMService model to gpt-4o-realtime-preview-2025-06-03 2025-06-23 16:26:42 -04:00
Mark Backman
4e11497a38 Merge pull request #2048 from thibaudbrg/patch-1
Fix missing video_in_enabled in vision bot.py for Moondream template
2025-06-23 16:11:50 -04:00
Tibo
0443d5202a Fix missing video_in_enabled in vision bot.py for Moondream template
The parameter video_in_enabled=True was missing in DailyParams, which prevented image capture 
from working. Without this parameter, UserImageRequestFrame would be sent but no actual image data would be captured from participants.

This fix enables the "Let me take a look" functionality to work as 
intended by allowing the transport to capture video frames for vision processing with Moondream.
2025-06-23 21:17:41 +02:00
Mark Backman
633c25cb13 Merge pull request #2039 from pipecat-ai/mb/remove-lang-validation
OpenAIRealtimeBetaLLMService accepts language for all InputAudioTrans…
2025-06-23 14:41:09 -04:00
jhpiedrahitao
d07f45132f update changelog 2025-06-23 12:54:00 -05:00
jhpiedrahitao
a51280afa6 add 13 and 14 type foundational examples for sambanova integration 2025-06-23 12:53:32 -05:00
Jorge Piedrahita Ortiz
be14eb2460 Merge branch 'pipecat-ai:main' into snova-jorgep/sambanova-integration 2025-06-23 12:23:00 -05:00
jhpiedrahitao
e26dbffcbe update sambanova init imports 2025-06-23 12:22:08 -05:00
Mark Backman
59992fd24a Merge pull request #2044 from pipecat-ai/mb/daily-rest-docstring
Add missing arg docstring in DailyRESTHelper
2025-06-23 11:24:44 -04:00
Mark Backman
455362ccaf Merge pull request #2022 from pipecat-ai/mb/turn-tracking-end-cancel-frame
TurnTrackingObserver ends turn upon seeing EndFrame, CancelFrame
2025-06-23 11:24:27 -04:00
Mark Backman
16c0e2460b TurnTrackingObserver ends turn upon seeing EndFrame, CancelFrame 2025-06-23 11:08:51 -04:00
Mark Backman
92246f7125 Add missing arg docstring in DailyRESTHelper 2025-06-22 13:44:59 -04:00
Mark Backman
7737335ec9 OpenAIRealtimeBetaLLMService accepts language for all InputAudioTranscription models 2025-06-21 10:08:46 -04:00
Mark Backman
5cc9b7e0d1 Fix: Correctly close the context for ElevenLabsTTSService 2025-06-20 15:47:03 -04:00
Mark Backman
8c6a441064 Merge pull request #2035 from smokyabdulrahman/feat/aws-polly-lexicon-names-support
Support AWS Polly Lexicon Names parameter
2025-06-20 10:03:27 -04:00
Alrahma
fddc058ce2 add CHANGELOG entry 2025-06-20 14:15:24 +01:00
Alrahma
89750086c5 Support AWS Polly Lexicon Names parameter
Documentation reference
[AWS Managing
Lexicons](https://docs.aws.amazon.com/polly/latest/dg/managing-lexicons.html)
2025-06-20 09:47:46 +01:00
Aleix Conchillo Flaqué
e69406c7e2 Merge pull request #2032 from pipecat-ai/aleix/aws-nova-sonic-function-calls
AWSNovaSonicLLMService: fix function calling
2025-06-19 14:42:47 -07:00
Aleix Conchillo Flaqué
878ae42d84 AWSNovaSonicLLMService: fix function calling 2025-06-19 14:26:34 -07:00
jhpiedrahitao
fae2d272d5 fmt 2025-06-18 10:53:06 -05:00
jhpiedrahitao
03a067d3e6 add sambanova llm and stt 2025-06-18 10:50:42 -05:00
Mark Backman
c94c51d44f Fix: 38-smart-turn-fal 2025-06-17 15:10:52 -04:00
ezun-kim
3da711ba8b Fix SSE server connection handling for MCP client
### Summary
This PR improves the MCP (Model Context Protocol) client's SSE (Server-Sent Events) server connection handling by replacing the generic string parameter with a proper `SseServerParameters` class.

### Changes
- **Breaking Change**: Changed `server_params` type from `Union[StdioServerParameters, str]` to `Union[StdioServerParameters, SseServerParameters]`
- Added import for `SseServerParameters` from `mcp.client.session_group`
- Updated SSE client connection to use structured parameters instead of a simple URL string
- Fixed error message to correctly reflect the expected parameter types
- Improved logging by changing info-level log to debug-level for consistency

### Details

#### Before
The SSE client connection only accepted a URL string:
```python
async with self._client(self._server_params) as (read, write):
```

#### After
Now properly unpacks SSE server parameters:
```python
async with self._client(
    url=self._server_params.url,
    headers=self._server_params.headers,
    timeout=self._server_params.timeout,
    sse_read_timeout=self._server_params.sse_read_timeout
) as (read, write):
```

### Benefits
- **Type Safety**: Stronger type checking with dedicated `SseServerParameters` class
- **Extended Configuration**: Support for custom headers (authentication), timeouts, and SSE-specific settings
- **Better Error Messages**: Clear type error messages when incorrect parameters are provided
- **Improved Debugging**: Debug logging of SSE server parameters for troubleshooting

### Migration Guide
Users need to update their SSE server initialization:
```python
# Before
client = MCPClient("https://example.com/sse")

# After
from mcp.client.session_group import SseServerParameters
client = MCPClient(SseServerParameters(
    url="https://example.com/sse",
    headers={"Authorization": "Bearer token"},
    timeout=30,
    sse_read_timeout=60
))
```

### Testing
- [ ] Tested with StdioServerParameters (unchanged behavior)
- [ ] Tested with SseServerParameters with various configurations
- [ ] Verified error handling for invalid parameter types

---

This is a necessary change to support production-ready SSE connections with proper authentication and timeout handling.
2025-05-24 22:35:57 +09:00
116 changed files with 8324 additions and 924 deletions

@@ -5,12 +5,47 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [0.0.73] - 2025-06-26
### Fixed
- Fixed an issue introduced in 0.0.72 that would cause `ElevenLabsTTSService`,
`GladiaSTTService`, `NeuphonicTTSService` and `OpenAIRealtimeBetaLLMService`
to throw an error.
## [0.0.72] - 2025-06-26
### Added
- Added logging and improved error handling to help diagnose and prevent potential
Pipeline freezes.
- Added `WatchdogQueue`, `WatchdogPriorityQueue`, `WatchdogEvent` and
`WatchdogAsyncIterator`. These helper utilities reset watchdog timers
appropriately before they expire. When watchdog timers are disabled, the
utilities behave as standard counterparts without side effects.
- Introduce task watchdog timers. Watchdog timers are used to detect if a
Pipecat task is taking longer than expected (by default 5 seconds). Watchdog
timers are disabled by default and can be enabled globally by passing the
`enable_watchdog_timers` argument to the `PipelineTask` constructor. You can
change the default timeout with the `watchdog_timeout` argument, and log how
long it takes to reset the watchdog timers with `enable_watchdog_logging`.
All of these settings can also be controlled per frame processor or even per
task: set `enable_watchdog_timers`, `enable_watchdog_logging` and
`watchdog_timeout` through the constructor arguments of any frame processor,
or when you create a task with `FrameProcessor.create_task()`. Note that
watchdog timers only work with Pipecat tasks and will not work if you use
`asyncio.create_task()` or similar.
- Added `lexicon_names` parameter to `AWSPollyTTSService.InputParams`.
- Added reconnection logic and audio buffer management to `GladiaSTTService`.
- The `TurnTrackingObserver` now ends a turn upon observing an `EndFrame` or
`CancelFrame`.
- Added Polish support to `AWSTranscribeSTTService`.
- Added new frames `FrameProcessorPauseFrame` and `FrameProcessorResumeFrame`
@@ -27,8 +62,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`LLMAssistantContextAggregator` that exposes whether a function call is in
progress.
- Added `SambaNovaLLMService` which provides LLM API integration with an
OpenAI-compatible interface.
- Added `SambaNovaSTTService` which provides speech-to-text functionality using
SambaNova's (Whisper) API.
- Added foundational examples for function calling and transcription:
`14s-function-calling-sambanova.py` and `13g-sambanova-transcription.py`.
### Changed
- `HeartbeatFrame`s are now control frames. This will make it easier to detect
pipeline freezes. Previously, heartbeat frames were system frames, which meant
they were not queued with other frames, making it difficult to detect
pipeline stalls.
- Updated `OpenAIRealtimeBetaLLMService` to accept `language` in the
`InputAudioTranscription` class for all models.
- Updated the default model for `OpenAIRealtimeBetaLLMService` to
`gpt-4o-realtime-preview-2025-06-03`.
- The `PipelineParams` arg `allow_interruptions` now defaults to `True`.
- `TavusTransport` and `TavusVideoService` now send audio to Tavus using WebRTC
@@ -39,6 +94,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed an issue that would cause heartbeat frames to be sent before processors
were started.
- Fixed an event loop blocking issue when using `SentryMetrics`.
- Fixed an issue in `FastAPIWebsocketClient` to ensure proper disconnection
when the websocket is already closed.
- Fixed an issue where the `UserStoppedSpeakingFrame` was not received if the
transport was not receiving new audio frames.
- Fixed an edge case where if the user interrupted the bot but no new aggregation
was received, the bot would not resume speaking.
- Fixed an issue with `TelnyxFrameSerializer` where it would throw an exception
when the user hung up the call.
- Fixed an issue with `ElevenLabsTTSService` where the context was not being
closed.
- Fixed function calling in `AWSNovaSonicLLMService`.
- Fixed an issue that would cause multiple `PipelineTask.on_idle_timeout`
events to be triggered repeatedly.
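As a rough illustration of the watchdog timers described in the 0.0.72 entry above, the following sketch shows the general idea with plain `asyncio`. It does not use Pipecat's API: `ToyWatchdog`, `worker` and `run_with_watchdog` are hypothetical names, and the real timers are configured through the arguments listed in the changelog.

```python
import asyncio
import time


class ToyWatchdog:
    """Hypothetical watchdog that flags a task which has not reset it in time."""

    def __init__(self, timeout: float):
        self.timeout = timeout
        self.expired_count = 0
        self._last_reset = time.monotonic()

    def reset(self):
        # Called by the monitored task to signal that it is still making progress.
        self._last_reset = time.monotonic()

    async def monitor(self, poll: float = 0.01):
        # Periodically check how long it has been since the last reset.
        while True:
            await asyncio.sleep(poll)
            if time.monotonic() - self._last_reset > self.timeout:
                self.expired_count += 1
                self.reset()  # avoid counting the same instant repeatedly


async def worker(watchdog: ToyWatchdog, step_durations):
    # Each step resets the watchdog when it finishes.
    for duration in step_durations:
        await asyncio.sleep(duration)
        watchdog.reset()


async def run_with_watchdog(step_durations, timeout: float) -> int:
    watchdog = ToyWatchdog(timeout)
    monitor = asyncio.create_task(watchdog.monitor())
    try:
        await worker(watchdog, step_durations)
    finally:
        monitor.cancel()
        try:
            await monitor
        except asyncio.CancelledError:
            pass
    return watchdog.expired_count
```

With short steps and a generous timeout the count stays at zero; a single step much longer than the timeout makes the watchdog expire at least once. This is why helpers that reset the timer while waiting, such as the `Watchdog*` utilities listed above, matter for tasks that legitimately block for a long time.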

@@ -41,36 +41,76 @@ We use Ruff for code linting and formatting. Please ensure your code passes all
We follow Google-style docstrings with these specific conventions:
- Class docstrings should fully document all parameters used in `__init__`
- We don't require separate docstrings for `__init__` methods when parameters are documented in the class docstring
- Property methods should have docstrings explaining their purpose and return value
**Regular Classes:**

- Class docstring describes the class purpose and documents all `__init__` parameters in an `Args:` section
- No separate `__init__` docstring needed
- All public methods must have docstrings with `Args:` and `Returns:` sections as appropriate

**Dataclasses:**

- Class docstring describes the purpose and documents all fields in a `Parameters:` section
- No `__init__` docstring (auto-generated)

**Properties:**

- Must have docstrings with `Returns:` section

**Abstract Methods:**

- Must have docstrings explaining what subclasses should implement

#### Examples:

```python
# Regular class
class MyService(BaseService):
    """Description of what the service does.

    Args:
        param1: Description of param1.
        param2: Description of param2. Defaults to True.
        **kwargs: Additional arguments passed to parent.
    """

    def __init__(self, param1: str, param2: bool = True, **kwargs):
        # No docstring - parameters documented above
        super().__init__(**kwargs)

    @property
    def sample_rate(self) -> int:
        """Get the current sample rate.

        Returns:
            The sample rate in Hz.
        """
        return self._sample_rate

    async def process_data(self, data: str) -> bool:
        """Process the provided data.

        Args:
            data: The data to process.

        Returns:
            True if processing succeeded.
        """
        pass


# Dataclass
@dataclass
class ConfigParams:
    """Configuration parameters for the service.

    Parameters:
        host: The host address.
        port: The port number. Defaults to 8080.
        timeout: Connection timeout in seconds.
    """

    host: str
    port: int = 8080
    timeout: float = 30.0
```
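The conventions above also cover abstract methods, which the examples do not show. A minimal sketch of how that rule might look in practice follows; `BaseTranscriber` and `EchoTranscriber` are hypothetical classes invented for illustration and are not part of Pipecat.

```python
from abc import ABC, abstractmethod


class BaseTranscriber(ABC):
    """Hypothetical base class for services that transcribe audio.

    Args:
        language: Language code used for transcription. Defaults to "en".
    """

    def __init__(self, language: str = "en"):
        # No docstring - parameters documented above
        self._language = language

    @property
    def language(self) -> str:
        """Get the configured language.

        Returns:
            The language code.
        """
        return self._language

    @abstractmethod
    def transcribe(self, audio: bytes) -> str:
        """Transcribe raw audio into text.

        Subclasses should implement the actual transcription and return
        an empty string when no speech is detected.

        Args:
            audio: Raw audio bytes to transcribe.

        Returns:
            The transcribed text.
        """
        pass


class EchoTranscriber(BaseTranscriber):
    """Hypothetical subclass that describes the audio it receives."""

    def transcribe(self, audio: bytes) -> str:
        """Describe the audio instead of transcribing it.

        Args:
            audio: Raw audio bytes.

        Returns:
            A short description of the input.
        """
        return f"{len(audio)} bytes ({self.language})"
```

The abstract method's docstring states what subclasses are expected to implement, while the concrete subclass documents only its own behavior.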
# Contributor Covenant Code of Conduct

@@ -53,8 +53,8 @@ You can connect to Pipecat from any platform using our official SDKs:
| Category | Services |
| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |


@@ -1,5 +1,6 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
# Configure logging
@@ -13,7 +14,8 @@ sys.path.insert(0, str(project_root / "src"))
# Project information
project = "pipecat-ai"
copyright = "2024, Daily"
current_year = datetime.now().year
copyright = f"2024-{current_year}, Daily" if current_year > 2024 else "2024, Daily"
author = "Daily"
# General configuration
@@ -27,15 +29,14 @@ extensions = [
# Napoleon settings
napoleon_google_docstring = True
napoleon_numpy_docstring = False
napoleon_include_init_with_doc = True
napoleon_include_init_with_doc = False
# AutoDoc settings
autodoc_default_options = {
"members": True,
"member-order": "bysource",
"special-members": "__init__",
"undoc-members": True,
"exclude-members": "__weakref__",
"exclude-members": "__weakref__,__init__",
"no-index": True,
"show-inheritance": True,
}
@@ -145,6 +146,28 @@ autodoc_mock_imports = [
"transformers.AutoFeatureExtractor",
# Also add specific classes that are imported
"AutoFeatureExtractor",
# Sentry dependencies
"sentry_sdk",
# AWS Nova Sonic dependencies
"aws_sdk_bedrock_runtime",
"aws_sdk_bedrock_runtime.client",
"aws_sdk_bedrock_runtime.config",
"aws_sdk_bedrock_runtime.models",
"smithy_aws_core",
"smithy_aws_core.credentials_resolvers",
"smithy_aws_core.credentials_resolvers.static",
"smithy_aws_core.identity",
"smithy_core",
"smithy_core.aio",
"smithy_core.aio.eventstream",
# MCP dependencies (you may already have these)
"mcp",
"mcp.client",
"mcp.client.session_group",
"mcp.client.sse",
"mcp.client.stdio",
"mcp.ClientSession",
"mcp.StdioServerParameters",
]
# HTML output settings
@@ -249,6 +272,9 @@ def clean_title(title: str) -> str:
"playht": "PlayHT",
"xtts": "XTTS",
"lmnt": "LMNT",
"stt": "STT",
"tts": "TTS",
"llm": "LLM",
}
# Check if the entire title is a special case
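
The copyright change in this configuration hunk switches from a fixed year to a range that grows with the calendar. A minimal sketch of the same rule as a standalone function is shown here; the function name `copyright_notice` and its parameters are illustrative assumptions, not part of the Sphinx configuration itself.

```python
from datetime import datetime


def copyright_notice(current_year: int, start_year: int = 2024, holder: str = "Daily") -> str:
    """Return a single year, or a start-current range once the year has advanced."""
    if current_year > start_year:
        return f"{start_year}-{current_year}, {holder}"
    return f"{start_year}, {holder}"


if __name__ == "__main__":
    # Mirrors the conf.py logic: evaluated at documentation build time.
    print(copyright_notice(datetime.now().year))
```

Taking the year as an argument rather than reading the clock inside the function keeps the rule easy to check for any year.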


@@ -42,6 +42,7 @@ pipecat-ai[openai]
pipecat-ai[qwen]
pipecat-ai[remote-smart-turn]
# pipecat-ai[riva] # Mocked
pipecat-ai[sambanova]
pipecat-ai[silero]
pipecat-ai[simli]
pipecat-ai[soundfile]


@@ -107,4 +107,10 @@ MINIMAX_API_KEY=...
MINIMAX_GROUP_ID=...
# Sarvam AI
SARVAM_API_KEY=...
SARVAM_API_KEY=...
# SambaNova
SAMBANOVA_API_KEY=...
# Sentry
SENTRY_DSN=...


@@ -8,8 +8,8 @@ import argparse
import os
from dataclasses import dataclass
import google.ai.generativelanguage as glm
from dotenv import load_dotenv
from google.genai.types import Content, Part
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
@@ -164,9 +164,7 @@ class TanscriptionContextFixup(FrameProcessor):
and last_part.inline_data
and last_part.inline_data.mime_type == "audio/wav"
):
self._context.messages[-2] = glm.Content(
role="user", parts=[glm.Part(text=self._transcript)]
)
self._context.messages[-2] = Content(role="user", parts=[Part(text=self._transcript)])
def add_transcript_back_to_inference_output(self):
if not self._transcript:


@@ -0,0 +1,108 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
import time
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import Frame, TranscriptionFrame, UserStoppedSpeakingFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.sambanova.stt import SambaNovaSTTService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
from pipecat.transports.services.daily import DailyParams
load_dotenv(override=True)
STOP_SECS = 2.0
class TranscriptionLogger(FrameProcessor):
"""Measures transcription latency.
Uses the (intentionally) long STOP_SECS parameter to give the transcription time to finish,
then outputs the timing between when the VAD first classified audio input as not-speech and
the delivery of the last transcription frame.
"""
def __init__(self):
super().__init__()
self._last_transcription_time = time.time()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, UserStoppedSpeakingFrame):
logger.debug(
f"Transcription latency: {(STOP_SECS - (time.time() - self._last_transcription_time)):.2f}"
)
if isinstance(frame, TranscriptionFrame):
self._last_transcription_time = time.time()
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=STOP_SECS)),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=STOP_SECS)),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=STOP_SECS)),
),
}
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
logger.info(f"Starting bot")
stt = SambaNovaSTTService(
model="Whisper-Large-v3",
api_key=os.getenv("SAMBANOVA_API_KEY"),
)
tl = TranscriptionLogger()
pipeline = Pipeline([transport.input(), stt, tl])
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
)
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=handle_sigint)
await runner.run(task)
if __name__ == "__main__":
from pipecat.examples.run import main
main(run_example, transport_params=transport_params)
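
The latency expression logged by `TranscriptionLogger` is easier to follow with concrete numbers. The sketch below restates the same arithmetic as a pure function; `transcription_latency` and the timestamps are hypothetical values chosen for illustration, not measurements.

```python
STOP_SECS = 2.0  # same value as in the example above


def transcription_latency(
    stop_frame_time: float,
    last_transcription_time: float,
    stop_secs: float = STOP_SECS,
) -> float:
    """Seconds between the VAD first seeing non-speech and the last transcription.

    The stop frame arrives about `stop_secs` after speech ended, so speech ended
    at `stop_frame_time - stop_secs`. Subtracting that from the time of the last
    transcription gives the latency.
    """
    return stop_secs - (stop_frame_time - last_transcription_time)


if __name__ == "__main__":
    # Assumed timeline: speech ends at t=100.0, the last transcription arrives
    # at t=100.4, and the stop frame is delivered at t=102.0.
    latency = transcription_latency(stop_frame_time=102.0, last_transcription_time=100.4)
    print(f"Transcription latency: {latency:.2f}")  # prints "Transcription latency: 0.40"
```

A negative result would mean the last transcription was delivered before the VAD first classified the input as not-speech.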


@@ -0,0 +1,152 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import LLMUserAggregatorParams
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.sambanova.llm import SambaNovaLLMService
from pipecat.services.sambanova.stt import SambaNovaSTTService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
from pipecat.transports.services.daily import DailyParams
load_dotenv(override=True)
async def fetch_weather_from_api(params: FunctionCallParams):
await params.result_callback({"conditions": "nice", "temperature": "75"})
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
}
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
logger.info(f"Starting bot")
stt = SambaNovaSTTService(
model="Whisper-Large-v3",
api_key=os.getenv("SAMBANOVA_API_KEY"),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
llm = SambaNovaLLMService(
api_key=os.getenv("SAMBANOVA_API_KEY"),
model="Llama-4-Maverick-17B-128E-Instruct",
)
# You can also register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_function("get_current_weather", fetch_weather_from_api)
@llm.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls):
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
required=["location"],
)
tools = ToolsSchema(standard_tools=[weather_function])
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(
context, user_params=LLMUserAggregatorParams(aggregation_timeout=0.05)
)
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
tts,
transport.output(),
context_aggregator.assistant(),
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
# Kick off the conversation.
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=handle_sigint)
await runner.run(task)
if __name__ == "__main__":
from pipecat.examples.run import main
main(run_example, transport_params=transport_params)
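
For readers unfamiliar with function calling, the weather schema in this example corresponds roughly to an OpenAI-style tool definition. The sketch below builds that dictionary by hand. It is not pipecat's adapter code, and `to_openai_tool` is a hypothetical helper; how `FunctionSchema` and `ToolsSchema` are converted internally is not shown in this diff.

```python
import json
from typing import Any, Dict, List


def to_openai_tool(
    name: str,
    description: str,
    properties: Dict[str, Any],
    required: List[str],
) -> Dict[str, Any]:
    """Build an OpenAI-style tool definition from schema fields (illustrative only)."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


if __name__ == "__main__":
    weather_tool = to_openai_tool(
        name="get_current_weather",
        description="Get the current weather",
        properties={
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "format": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The temperature unit to use.",
            },
        },
        required=["location"],
    )
    print(json.dumps(weather_tool, indent=2))
```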


@@ -9,8 +9,8 @@ import asyncio
import os
import time
import google.ai.generativelanguage as glm
from dotenv import load_dotenv
from google.genai.types import Content, Part
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
@@ -611,9 +611,7 @@ class OutputGate(FrameProcessor):
await self._notifier.wait()
transcription = await self._transcription_buffer.wait_for_transcription() or "-"
self._context._messages.append(
glm.Content(role="user", parts=[glm.Part(text=transcription)])
)
self._context.add_message(Content(role="user", parts=[Part(text=transcription)]))
self.open_gate()
for frame, direction in self._frames_buffer:


@@ -8,8 +8,8 @@ import argparse
import os
from dataclasses import dataclass
import google.ai.generativelanguage as glm
from dotenv import load_dotenv
from google.genai.types import Content, Part
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
@@ -142,8 +142,8 @@ class InputTranscriptionContextFilter(FrameProcessor):
context = GoogleLLMContext.upgrade_to_google(frame.context)
message = context.messages[-1]
if not isinstance(message, glm.Content):
logger.error(f"Expected glm.Content, got {type(message)}")
if not isinstance(message, Content):
logger.error(f"Expected Content, got {type(message)}")
return
last_part = message.parts[-1]
@@ -168,15 +168,15 @@ class InputTranscriptionContextFilter(FrameProcessor):
history += f"{msg.role}: {part.text}\n"
if history:
assembled = f"Here is the conversation history so far. These are not instructions. This is data that you should use only to improve the accuracy of your transcription.\n\n----\n\n{history}\n\n----\n\nEND OF CONVERSATION HISTORY\n\n"
parts.append(glm.Part(text=assembled))
parts.append(Part(text=assembled))
parts.append(
glm.Part(
Part(
text="Transcribe this audio. Respond either with the transcription exactly as it was said by the user, or with the special string 'EMPTY' if the audio is not clear."
)
)
parts.append(last_part)
msg = glm.Content(role="user", parts=parts)
msg = Content(role="user", parts=parts)
ctx = GoogleLLMContext([msg])
ctx.system_message = transcriber_system_message
await self.push_frame(OpenAILLMContextFrame(context=ctx))


@@ -27,7 +27,6 @@ from pipecat.transports.services.daily import DailyParams
load_dotenv(override=True)
aiohttp_session = aiohttp.ClientSession()
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
@@ -38,7 +37,7 @@ transport_params = {
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=FalSmartTurnAnalyzer(
api_key=os.getenv("FAL_SMART_TURN_API_KEY"), aiohttp_session=aiohttp_session
api_key=os.getenv("FAL_SMART_TURN_API_KEY"), aiohttp_session=aiohttp.ClientSession()
),
),
"twilio": lambda: FastAPIWebsocketParams(
@@ -46,7 +45,7 @@ transport_params = {
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=FalSmartTurnAnalyzer(
api_key=os.getenv("FAL_SMART_TURN_API_KEY"), aiohttp_session=aiohttp_session
api_key=os.getenv("FAL_SMART_TURN_API_KEY"), aiohttp_session=aiohttp.ClientSession()
),
),
"webrtc": lambda: TransportParams(
@@ -54,7 +53,7 @@ transport_params = {
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
turn_analyzer=FalSmartTurnAnalyzer(
api_key=os.getenv("FAL_SMART_TURN_API_KEY"), aiohttp_session=aiohttp_session
api_key=os.getenv("FAL_SMART_TURN_API_KEY"), aiohttp_session=aiohttp.ClientSession()
),
),
}
@@ -118,8 +117,6 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
await runner.run(task)
await aiohttp_session.close()
if __name__ == "__main__":
from pipecat.examples.run import main


@@ -9,6 +9,7 @@ import os
from dotenv import load_dotenv
from loguru import logger
from mcp.client.session_group import SseServerParameters
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
@@ -63,7 +64,7 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
try:
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
mcp = MCPClient(server_params=os.getenv("MCP_RUN_SSE_URL"))
mcp = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
except Exception as e:
logger.error(f"error setting up mcp")
logger.exception("error trace:")


@@ -15,6 +15,7 @@ import aiohttp
from dotenv import load_dotenv
from loguru import logger
from mcp import StdioServerParameters
from mcp.client.session_group import SseServerParameters
from PIL import Image
from pipecat.adapters.schemas.tools_schema import ToolsSchema
@@ -149,7 +150,7 @@ async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_si
# https://docs.mcp.run/integrating/tutorials/mcp-run-sse-openai-agents/
# ie. "https://www.mcp.run/api/mcp/sse?..."
# ensure the profile has a tool or few installed
mcp_run = MCPClient(server_params=os.getenv("MCP_RUN_SSE_URL"))
mcp_run = MCPClient(server_params=SseServerParameters(url=os.getenv("MCP_RUN_SSE_URL")))
except Exception as e:
logger.error(f"error setting up mcp.run")
logger.exception("error trace:")


@@ -0,0 +1,133 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import os
from dotenv import load_dotenv
from loguru import logger
from mcp.client.session_group import StreamableHttpParameters
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.mcp_service import MCPClient
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
from pipecat.transports.services.daily import DailyParams
load_dotenv(override=True)
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
}
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini")
try:
# GitHub MCP docs: https://github.com/github/github-mcp-server
# Enable GitHub Copilot on your GitHub account. The free tier is fine. (https://github.com/settings/copilot)
# Generate a personal access token. It must be a fine-grained token; classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
# Set the permissions you want to use (e.g. "all repositories", "profile: read/write", etc.)
mcp = MCPClient(
server_params=StreamableHttpParameters(
url="https://api.githubcopilot.com/mcp/",
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
)
)
except Exception as e:
logger.error(f"error setting up mcp")
logger.exception("error trace:")
tools = await mcp.register_tools(llm)
system = f"""
You are a helpful LLM in a WebRTC call.
Your goal is to answer questions about the user's GitHub repositories and account.
You have access to a number of tools provided by Github. Use any and all tools to help users.
Your output will be converted to audio so don't include special characters in your answers.
Don't overexplain what you are doing.
Just respond with short sentences when you are carrying out tool calls.
"""
messages = [{"role": "system", "content": system}]
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User spoken responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses and tool context
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
enable_metrics=True,
enable_usage_metrics=True,
),
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected: {client}")
# Kick off the conversation.
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=handle_sigint)
await runner.run(task)
if __name__ == "__main__":
from pipecat.examples.run import main
main(run_example, transport_params=transport_params)
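
One detail worth noting in the example above: if `GITHUB_PERSONAL_ACCESS_TOKEN` is unset, `os.getenv` returns `None` and the f-string produces the header value `Bearer None`. A small guard can fail early instead. The helper below is a hypothetical sketch and is not part of the example file; its name and error message are assumptions.

```python
import os
from typing import Dict, Mapping, Optional


def github_mcp_headers(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Build the Authorization header, raising if the token is missing or blank."""
    source = os.environ if env is None else env
    token = source.get("GITHUB_PERSONAL_ACCESS_TOKEN", "").strip()
    if not token:
        raise RuntimeError("GITHUB_PERSONAL_ACCESS_TOKEN is not set")
    return {"Authorization": f"Bearer {token}"}


if __name__ == "__main__":
    try:
        print(github_mcp_headers())
    except RuntimeError as exc:
        print(f"Cannot configure the MCP client: {exc}")
```

The returned dictionary could then be passed as the `headers` argument shown in the example.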


@@ -0,0 +1,59 @@
# Freeze Test Client
This example provides an environment for stress-testing the bot and trying to reproduce freezing conditions.
### Approach 1: Server-Side Testing with `SimulateFreezeInput`
- Utilize only the bot `freeze_test_bot.py` with the `SimulateFreezeInput` processor. This input continuously injects frames, simulating user speech interruptions at random intervals.
- This approach excludes the use of input transport and speech-to-text (STT) functionalities.
### Approach 2: Server-Side with TypeScript Client
- Combine server-side operations with a TypeScript client.
- The client first records a short segment of audio, e.g., 5–10 seconds long. The content can be anything.
- After that, it replays this recorded audio to the server at random intervals, mimicking user input interruptions.
- This helps test interruptions in the pipeline as if real users were interacting with the bot.
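
The record-and-replay loop described in Approach 2 can be sketched in Python as follows. This is an illustration of the idea only: the actual client is the TypeScript app in `client/`, and the names `Recorder`, `RecordedChunk`, and `replay`, as well as the 1 to 10 second pause bounds, are assumptions made for this sketch.

```python
import asyncio
import random
from dataclasses import dataclass, field
from typing import Awaitable, Callable, List, Optional


@dataclass
class RecordedChunk:
    data: bytes
    delay_ms: float  # gap since the previous chunk was captured


@dataclass
class Recorder:
    chunks: List[RecordedChunk] = field(default_factory=list)
    _last_ms: Optional[float] = None

    def add(self, data: bytes, now_ms: float) -> None:
        """Store a chunk together with the time elapsed since the previous one."""
        delay = 0.0 if self._last_ms is None else now_ms - self._last_ms
        self._last_ms = now_ms
        self.chunks.append(RecordedChunk(data, delay))


async def replay(
    chunks: List[RecordedChunk],
    send: Callable[[bytes], Awaitable[None]],
    rounds: int,
    min_pause_ms: float = 1000.0,
    max_pause_ms: float = 10000.0,
    sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> None:
    """Resend the recording `rounds` times, pausing a random interval in between."""
    for _ in range(rounds):
        for chunk in chunks:
            # Preserve the original pacing between chunks.
            await sleep(chunk.delay_ms / 1000.0)
            await send(chunk.data)
        # Random gap before the next replay, to mimic unpredictable user input.
        await sleep(random.uniform(min_pause_ms, max_pause_ms) / 1000.0)
```

Keeping the per-chunk delays means the replayed audio reaches the server at roughly the original pace, while the random pause varies when each "interruption" lands relative to the bot's output.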
## Setup
Follow these steps to set up and run the Freeze Test Client:
1. **Run the Bot Server**
- Set up and activate your virtual environment:
```bash
python3 -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
```
- Install dependencies:
```bash
pip install -r requirements.txt
```
- Create your `.env` file and set your env vars:
```bash
cp env.example .env
```
- Run the server:
```bash
python freeze_test_bot.py
```
2. **Navigate to the Client Directory**
```bash
cd client
```
3. **Install Dependencies**
```bash
npm install
```
4. **Run the Client Application**
```bash
npm run dev
```
5. **Access the Client in Your Browser**
Visit [http://localhost:5173](http://localhost:5173) to interact with the Freeze Test Client.


@@ -0,0 +1,43 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Chatbot</title>
</head>
<body>
<div class="container">
<div class="status-bar">
<div class="status">
Transport: <span id="connection-status">Disconnected</span>
</div>
<div class="controls">
<button id="connect-btn">Connect</button>
<button id="disconnect-btn" disabled>Disconnect</button>
</div>
</div>
<div class="status-bar">
<div class="status">
Playing audio: <span id="play-audio-status"></span>
</div>
<div class="controls">
<button id="play-btn">Start</button>
<button id="stop-btn" disabled>Stop</button>
</div>
</div>
<audio id="bot-audio" autoplay></audio>
<div class="debug-panel">
<h3>Debug Info</h3>
<div id="debug-log"></div>
</div>
</div>
<script type="module" src="/src/app.ts"></script>
<link rel="stylesheet" href="/src/style.css">
</body>
</html>

File diff suppressed because it is too large


@@ -0,0 +1,26 @@
{
"name": "client",
"version": "1.0.0",
"main": "index.js",
"scripts": {
"dev": "vite",
"build": "tsc && vite build",
"preview": "vite preview"
},
"keywords": [],
"author": "",
"license": "ISC",
"description": "",
"devDependencies": {
"@types/node": "^22.15.30",
"@types/protobufjs": "^6.0.0",
"@vitejs/plugin-react-swc": "^3.10.1",
"typescript": "^5.8.3",
"vite": "^6.3.5"
},
"dependencies": {
"@pipecat-ai/client-js": "^0.4.0",
"@pipecat-ai/websocket-transport": "^0.4.1",
"protobufjs": "^7.4.0"
}
}


@@ -0,0 +1,328 @@
/**
* Copyright (c) 2024–2025, Daily
*
* SPDX-License-Identifier: BSD 2-Clause License
*/
/**
* RTVI Client Implementation
*
* This client connects to an RTVI-compatible bot server using WebSocket.
*
* Requirements:
* - A running RTVI bot server (defaults to http://localhost:7860)
*/
import {
RTVIClient,
RTVIClientOptions,
RTVIEvent,
} from '@pipecat-ai/client-js';
import {
ProtobufFrameSerializer,
WebSocketTransport
} from "@pipecat-ai/websocket-transport";
class RecordingSerializer extends ProtobufFrameSerializer {
private lastTimestamp: number | null = null;
private recordingAudioToSend: boolean = false;
private _recordedAudio: { data: ArrayBuffer; delay: number }[] = [];
public startRecording() {
this.recordingAudioToSend = true;
this._recordedAudio = [];
this.lastTimestamp = null;
}
public stopRecording() {
this.recordingAudioToSend = false;
}
// @ts-ignore
serializeAudio(data: ArrayBuffer, sampleRate: number, numChannels: number): Uint8Array | null {
if (this.recordingAudioToSend) {
const now = Date.now();
// Compute delay since last packet
const delay = this.lastTimestamp ? now - this.lastTimestamp : 0;
this.lastTimestamp = now;
// Save audio chunk and delay
this._recordedAudio.push({ data, delay });
return null;
} else {
return super.serializeAudio(data, sampleRate, numChannels);
}
}
public get recordedAudio() {
return this._recordedAudio
}
}
class WebsocketClientApp {
private ENABLE_RECORDING_MODE = false
private RECORDING_TIME_MS = 10000
private rtviClient: RTVIClient | null = null;
private connectBtn: HTMLButtonElement | null = null;
private disconnectBtn: HTMLButtonElement | null = null;
private statusSpan: HTMLElement | null = null;
private debugLog: HTMLElement | null = null;
private botAudio: HTMLAudioElement;
private declare websocketTransport: WebSocketTransport;
private sendRecordedAudio: boolean = false
private declare recordingSerializer: RecordingSerializer;
private playBtn: HTMLButtonElement | null = null;
private stopBtn: HTMLButtonElement | null = null;
constructor() {
this.botAudio = document.createElement('audio');
this.botAudio.autoplay = true;
//this.botAudio.playsInline = true;
document.body.appendChild(this.botAudio);
this.setupDOMElements();
this.setupEventListeners();
}
/**
* Set up references to DOM elements and create necessary media elements
*/
private setupDOMElements(): void {
this.connectBtn = document.getElementById('connect-btn') as HTMLButtonElement;
this.disconnectBtn = document.getElementById('disconnect-btn') as HTMLButtonElement;
this.statusSpan = document.getElementById('connection-status');
this.debugLog = document.getElementById('debug-log');
this.playBtn = document.getElementById('play-btn') as HTMLButtonElement;
this.stopBtn = document.getElementById('stop-btn') as HTMLButtonElement;
}
/**
* Set up event listeners for connect/disconnect buttons
*/
private setupEventListeners(): void {
this.connectBtn?.addEventListener('click', () => this.connect());
this.disconnectBtn?.addEventListener('click', () => this.disconnect());
this.playBtn?.addEventListener('click', () => this.startSendingRecordedAudio());
this.stopBtn?.addEventListener('click', () => this.stopSendingRecordedAudio());
}
/**
* Add a timestamped message to the debug log
*/
private log(message: string): void {
if (!this.debugLog) return;
const entry = document.createElement('div');
entry.textContent = `${new Date().toISOString()} - ${message}`;
if (message.startsWith('User: ')) {
entry.style.color = '#2196F3';
} else if (message.startsWith('Bot: ')) {
entry.style.color = '#4CAF50';
}
this.debugLog.appendChild(entry);
this.debugLog.scrollTop = this.debugLog.scrollHeight;
console.log(message);
}
/**
* Update the connection status display
*/
private updateStatus(status: string): void {
if (this.statusSpan) {
this.statusSpan.textContent = status;
}
this.log(`Status: ${status}`);
}
/**
* Check for available media tracks and set them up if present
* This is called when the bot is ready or when the transport state changes to ready
*/
setupMediaTracks() {
if (!this.rtviClient) return;
const tracks = this.rtviClient.tracks();
if (tracks.bot?.audio) {
this.setupAudioTrack(tracks.bot.audio);
}
}
/**
* Set up listeners for track events (start/stop)
* This handles new tracks being added during the session
*/
setupTrackListeners() {
if (!this.rtviClient) return;
// Listen for new tracks starting
this.rtviClient.on(RTVIEvent.TrackStarted, (track, participant) => {
// Only handle non-local (bot) tracks
if (!participant?.local && track.kind === 'audio') {
this.setupAudioTrack(track);
}
});
// Listen for tracks stopping
this.rtviClient.on(RTVIEvent.TrackStopped, (track, participant) => {
this.log(`Track stopped: ${track.kind} from ${participant?.name || 'unknown'}`);
});
}
/**
* Set up an audio track for playback
* Handles both initial setup and track updates
*/
private setupAudioTrack(track: MediaStreamTrack): void {
this.log('Setting up audio track');
if (this.botAudio.srcObject && "getAudioTracks" in this.botAudio.srcObject) {
const oldTrack = this.botAudio.srcObject.getAudioTracks()[0];
if (oldTrack?.id === track.id) return;
}
this.botAudio.srcObject = new MediaStream([track]);
}
/**
* Initialize and connect to the bot
* This sets up the RTVI client, initializes devices, and establishes the connection
*/
public async connect(): Promise<void> {
try {
const startTime = Date.now();
this.recordingSerializer = new RecordingSerializer()
const transport = this.ENABLE_RECORDING_MODE ? new WebSocketTransport({serializer: this.recordingSerializer}) : new WebSocketTransport();
this.websocketTransport = transport
const RTVIConfig: RTVIClientOptions = {
transport,
params: {
// The baseURL and endpoint of your bot server that the client will connect to
baseUrl: 'http://localhost:7860',
endpoints: { connect: '/connect' },
},
enableMic: true,
enableCam: false,
callbacks: {
onConnected: () => {
this.updateStatus('Connected');
if (this.connectBtn) this.connectBtn.disabled = true;
if (this.disconnectBtn) this.disconnectBtn.disabled = false;
},
onDisconnected: () => {
this.updateStatus('Disconnected');
if (this.connectBtn) this.connectBtn.disabled = false;
if (this.disconnectBtn) this.disconnectBtn.disabled = true;
this.log('Client disconnected');
},
onBotReady: (data) => {
this.log(`Bot ready: ${JSON.stringify(data)}`);
this.setupMediaTracks();
},
onUserTranscript: (data) => {
if (data.final) {
this.log(`User: ${data.text}`);
}
},
onBotTranscript: (data) => this.log(`Bot: ${data.text}`),
onMessageError: (error) => console.error('Message error:', error),
onError: (error) => console.error('Error:', error),
},
}
this.rtviClient = new RTVIClient(RTVIConfig);
this.setupTrackListeners();
this.log('Initializing devices...');
await this.rtviClient.initDevices();
this.log('Connecting to bot...');
await this.rtviClient.connect();
const timeTaken = Date.now() - startTime;
this.log(`Connection complete, timeTaken: ${timeTaken}`);
if (this.ENABLE_RECORDING_MODE) {
this.log(`Starting to record the next ${(this.RECORDING_TIME_MS/1000)}s of audio`);
this.recordingSerializer.startRecording()
await this.sleep(this.RECORDING_TIME_MS)
this.recordingSerializer.stopRecording()
this.log("Recording stopped");
this.rtviClient.enableMic(false)
this.startSendingRecordedAudio()
}
} catch (error) {
this.log(`Error connecting: ${(error as Error).message}`);
this.updateStatus('Error');
// Clean up if there's an error
if (this.rtviClient) {
try {
await this.rtviClient.disconnect();
} catch (disconnectError) {
this.log(`Error during disconnect: ${disconnectError}`);
}
}
}
}
/**
* Disconnect from the bot and clean up media resources
*/
public async disconnect(): Promise<void> {
if (this.rtviClient) {
try {
this.stopSendingRecordedAudio()
await this.rtviClient.disconnect();
this.rtviClient = null;
if (this.botAudio.srcObject && "getAudioTracks" in this.botAudio.srcObject) {
this.botAudio.srcObject.getAudioTracks().forEach((track) => track.stop());
this.botAudio.srcObject = null;
}
} catch (error) {
this.log(`Error disconnecting: ${(error as Error).message}`);
}
}
}
private startSendingRecordedAudio() {
this.sendRecordedAudio = true
if (this.playBtn) this.playBtn.disabled = true;
if (this.stopBtn) this.stopBtn.disabled = false;
void this.replayAudio()
}
private stopSendingRecordedAudio() {
if (this.stopBtn) this.stopBtn.disabled = true;
if (this.playBtn) this.playBtn.disabled = false;
this.sendRecordedAudio = false
}
private async replayAudio() {
if (this.sendRecordedAudio) {
this.log("Sending recorded audio")
for (const chunk of this.recordingSerializer.recordedAudio) {
await this.sleep(chunk.delay);
this.websocketTransport.handleUserAudioStream(chunk.data);
}
const randomDelay = 1000 + Math.random() * (10000 - 500);
await this.sleep(randomDelay);
void this.replayAudio()
}
}
private sleep(ms: number): Promise<void> {
return new Promise(resolve => setTimeout(resolve, ms));
}
}
declare global {
interface Window {
WebsocketClientApp: typeof WebsocketClientApp;
}
}
window.addEventListener('DOMContentLoaded', () => {
window.WebsocketClientApp = WebsocketClientApp;
new WebsocketClientApp();
});

@@ -0,0 +1,98 @@
body {
margin: 0;
padding: 20px;
font-family: Arial, sans-serif;
background-color: #f0f0f0;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.status-bar {
display: flex;
justify-content: space-between;
align-items: center;
padding: 10px;
background-color: #fff;
border-radius: 8px;
margin-bottom: 20px;
}
.controls button {
padding: 8px 16px;
margin-left: 10px;
border: none;
border-radius: 4px;
cursor: pointer;
}
#connect-btn {
background-color: #4caf50;
color: white;
}
#disconnect-btn {
background-color: #f44336;
color: white;
}
button:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.main-content {
background-color: #fff;
border-radius: 8px;
padding: 20px;
margin-bottom: 20px;
}
.bot-container {
display: flex;
flex-direction: column;
align-items: center;
}
#bot-video-container {
width: 640px;
height: 360px;
background-color: #e0e0e0;
border-radius: 8px;
margin: 20px auto;
overflow: hidden;
display: flex;
align-items: center;
justify-content: center;
}
#bot-video-container video {
width: 100%;
height: 100%;
object-fit: cover;
}
.debug-panel {
background-color: #fff;
border-radius: 8px;
padding: 20px;
}
.debug-panel h3 {
margin: 0 0 10px 0;
font-size: 16px;
font-weight: bold;
}
#debug-log {
height: 500px;
overflow-y: auto;
background-color: #f8f8f8;
padding: 10px;
border-radius: 4px;
font-family: monospace;
font-size: 12px;
line-height: 1.4;
}

@@ -0,0 +1,111 @@
{
"compilerOptions": {
/* Visit https://aka.ms/tsconfig to read more about this file */
/* Projects */
// "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */
// "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */
// "tsBuildInfoFile": "./.tsbuildinfo", /* Specify the path to .tsbuildinfo incremental compilation file. */
// "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects. */
// "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */
// "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */
/* Language and Environment */
"target": "es2016", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */
// "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
// "jsx": "preserve", /* Specify what JSX code is generated. */
// "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
// "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
// "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */
// "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */
// "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */
// "reactNamespace": "", /* Specify the object invoked for 'createElement'. This only applies when targeting 'react' JSX emit. */
// "noLib": true, /* Disable including any library files, including the default lib.d.ts. */
// "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. */
// "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */
/* Modules */
"module": "commonjs", /* Specify what module code is generated. */
// "rootDir": "./", /* Specify the root folder within your source files. */
// "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */
// "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
// "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */
// "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */
// "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */
// "types": [], /* Specify type package names to be included without being referenced in a source file. */
// "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */
// "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */
// "allowImportingTsExtensions": true, /* Allow imports to include TypeScript file extensions. Requires '--moduleResolution bundler' and either '--noEmit' or '--emitDeclarationOnly' to be set. */
// "rewriteRelativeImportExtensions": true, /* Rewrite '.ts', '.tsx', '.mts', and '.cts' file extensions in relative import paths to their JavaScript equivalent in output files. */
// "resolvePackageJsonExports": true, /* Use the package.json 'exports' field when resolving package imports. */
// "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */
// "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */
// "noUncheckedSideEffectImports": true, /* Check side effect imports. */
// "resolveJsonModule": true, /* Enable importing .json files. */
// "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */
// "noResolve": true, /* Disallow 'import's, 'require's or '<reference>'s from expanding the number of files TypeScript should add to a project. */
/* JavaScript Support */
// "allowJs": true, /* Allow JavaScript files to be a part of your program. Use the 'checkJS' option to get errors from these files. */
// "checkJs": true, /* Enable error reporting in type-checked JavaScript files. */
// "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */
/* Emit */
// "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */
// "declarationMap": true, /* Create sourcemaps for d.ts files. */
// "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */
// "sourceMap": true, /* Create source map files for emitted JavaScript files. */
// "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */
// "noEmit": true, /* Disable emitting files from a compilation. */
// "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */
// "outDir": "./", /* Specify an output folder for all emitted files. */
// "removeComments": true, /* Disable emitting comments. */
// "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */
// "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */
// "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */
// "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
// "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */
// "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */
// "newLine": "crlf", /* Set the newline character for emitting files. */
// "stripInternal": true, /* Disable emitting declarations that have '@internal' in their JSDoc comments. */
// "noEmitHelpers": true, /* Disable generating custom helper functions like '__extends' in compiled output. */
// "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */
// "preserveConstEnums": true, /* Disable erasing 'const enum' declarations in generated code. */
// "declarationDir": "./", /* Specify the output directory for generated declaration files. */
/* Interop Constraints */
// "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */
// "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */
// "isolatedDeclarations": true, /* Require sufficient annotation on exports so other tools can trivially generate declaration files. */
// "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */
"esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */
// "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */
"forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. */
/* Type Checking */
"strict": true, /* Enable all strict type-checking options. */
// "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */
// "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */
// "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */
// "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */
// "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */
// "strictBuiltinIteratorReturn": true, /* Built-in iterators are instantiated with a 'TReturn' type of 'undefined' instead of 'any'. */
// "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */
// "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */
// "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */
// "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */
// "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */
// "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */
// "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */
// "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */
// "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */
// "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */
// "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type. */
// "allowUnusedLabels": true, /* Disable error reporting for unused labels. */
// "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */
/* Completeness */
// "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */
"skipLibCheck": true /* Skip type checking all .d.ts files. */
}
}

@@ -0,0 +1,15 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react-swc';
export default defineConfig({
plugins: [react()],
server: {
proxy: {
// Proxy /connect requests to the backend server
'/connect': {
target: 'http://0.0.0.0:7860', // Replace with your backend URL
changeOrigin: true,
},
},
},
});

@@ -0,0 +1,322 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import asyncio
import os
import random
from contextlib import asynccontextmanager
from typing import Any, Dict
import sentry_sdk
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from loguru import logger
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
StopFrame,
StopInterruptionFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.observers.loggers.debug_log_observer import DebugLogObserver
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIProcessor
from pipecat.processors.metrics.sentry import SentryMetrics
from pipecat.serializers.protobuf import ProtobufFrameSerializer
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.network.fastapi_websocket import (
FastAPIWebsocketParams,
FastAPIWebsocketTransport,
)
from pipecat.utils.time import time_now_iso8601
load_dotenv(override=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Handles FastAPI startup and shutdown."""
yield # Run app
# Initialize FastAPI app with lifespan manager
app = FastAPI(lifespan=lifespan)
# Configure CORS to allow requests from any origin
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount the prebuilt frontend at /client
app.mount("/client", SmallWebRTCPrebuiltUI)
class SimulateFreezeInput(FrameProcessor):
def __init__(
self,
**kwargs,
):
super().__init__(**kwargs)
# Whether we have seen a StartFrame already.
self._initialized = False
self._send_frames_task = None
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
# Push StartFrame before start(), because we want StartFrame to be
# processed by every processor before any other frame is processed.
await self.push_frame(frame, direction)
await self._start(frame)
elif isinstance(frame, CancelFrame):
logger.info("SimulateFreezeInput: Received cancel frame")
await self._stop()
await self.push_frame(frame, direction)
elif isinstance(frame, EndFrame):
logger.info("SimulateFreezeInput: Received end frame")
await self.push_frame(frame, direction)
await self._stop()
elif isinstance(frame, StopFrame):
logger.info("SimulateFreezeInput: Received stop frame")
await self.push_frame(frame, direction)
await self._stop()
async def _start(self, frame: StartFrame):
if self._initialized:
return
logger.info(f"Starting SimulateFreezeInput")
self._initialized = True
if not self._send_frames_task:
self._send_frames_task = self.create_task(self._send_frames())
async def _stop(self):
logger.info(f"Stopping SimulateFreezeInput")
self._initialized = False
if self._send_frames_task:
await self.cancel_task(self._send_frames_task)
self._send_frames_task = None
async def _send_user_text(self, text: str):
self.reset_watchdog()
# Emulate the user speaking and the STT service transcribing the speech
await self.push_frame(UserStartedSpeakingFrame())
await self.push_frame(StartInterruptionFrame())
await self.push_frame(
TranscriptionFrame(
text,
"",
time_now_iso8601(),
)
)
# Need to wait before sending the UserStoppedSpeakingFrame,
# otherwise TranscriptionFrame will be processed
# later than the UserStoppedSpeakingFrame
await asyncio.sleep(0.1)
await self.push_frame(UserStoppedSpeakingFrame())
await self.push_frame(StopInterruptionFrame())
async def _send_frames(self):
try:
i = 0
while True:
logger.debug("SimulateFreezeInput _send_frames")
await self._send_user_text("Tell me a brief history of Brazil!")
await asyncio.sleep(3)
await self._send_user_text("and who has discovered it")
i += 1
if i >= 20:
break
# Sleep for a random 1 to 10 seconds before the next interruption
wait_time = random.uniform(1, 10)
await asyncio.sleep(wait_time)
except Exception as e:
logger.error(f"{self} exception sending frames: {e.__class__.__name__} ({e})")
async def run_example(websocket_client):
logger.info(f"Starting bot")
# Create a transport using the FastAPI WebSocket connection
transport = FastAPIWebsocketTransport(
websocket=websocket_client,
params=FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
add_wav_header=False,
vad_analyzer=SileroVADAnalyzer(),
serializer=ProtobufFrameSerializer(),
),
)
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
traces_sample_rate=1.0,
)
freeze = SimulateFreezeInput()
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
metrics=SentryMetrics(),
)
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
metrics=SentryMetrics(),
)
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
ParallelPipeline(
[
freeze,
],
[
transport.input(),
stt,
],
),
rtvi,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
idle_timeout_secs=120,
observers=[
DebugLogObserver(
frame_types={
InterimTranscriptionFrame: None,
TranscriptionFrame: None,
# TTSTextFrame: None,
# LLMTextFrame: None,
OpenAILLMContextFrame: None,
LLMFullResponseEndFrame: None,
},
exclude_fields={
"result",
"metadata",
"audio",
"image",
"images",
},
),
],
enable_watchdog_timers=True,
)
@transport.event_handler("on_client_connected")
async def on_client_connected(transport, client):
logger.info(f"Client connected")
@rtvi.event_handler("on_client_ready")
async def on_client_ready(rtvi):
logger.info(f"Client ready")
await rtvi.set_bot_ready()
# Kick off the conversation.
# messages.append({"role": "system", "content": "Please introduce yourself to the user."})
# await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(transport, client):
logger.info(f"Client disconnected")
await task.cancel()
runner = PipelineRunner(handle_sigint=False)
await runner.run(task)
@app.get("/", include_in_schema=False)
async def root_redirect():
return RedirectResponse(url="/client/")
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
print("WebSocket connection accepted")
try:
await run_example(websocket)
except Exception as e:
print(f"Exception in run_example: {e}")
@app.post("/connect")
async def bot_connect(request: Request) -> Dict[Any, Any]:
server_mode = os.getenv("WEBSOCKET_SERVER", "fast_api")
if server_mode == "websocket_server":
ws_url = "ws://localhost:8765"
else:
ws_url = "ws://localhost:7860/ws"
return {"ws_url": ws_url}
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Pipecat Bot Runner")
parser.add_argument(
"--host", default="localhost", help="Host for HTTP server (default: localhost)"
)
parser.add_argument(
"--port", type=int, default=7860, help="Port for HTTP server (default: 7860)"
)
args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port)
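
The `/connect` handler in this file only picks a WebSocket URL based on the `WEBSOCKET_SERVER` environment variable. A minimal sketch of that selection as a pure function, which is easier to test in isolation, might look like this. The helper name `pick_ws_url` is hypothetical and not part of the example; the two URLs are copied from the handler.

```python
import os
from typing import Mapping


def pick_ws_url(env: Mapping[str, str] = os.environ) -> str:
    """Return the WebSocket URL the client should connect to.

    Hypothetical helper: mirrors the branch in bot_connect() above.
    """
    server_mode = env.get("WEBSOCKET_SERVER", "fast_api")
    if server_mode == "websocket_server":
        # Standalone websocket server
        return "ws://localhost:8765"
    # Default: the FastAPI /ws endpoint served by this app
    return "ws://localhost:7860/ws"
```

Passing the environment as a mapping keeps the function free of global state, so both branches can be exercised without touching the process environment.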

@@ -143,6 +143,7 @@ async def main():
DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
video_in_enabled=True,
video_out_enabled=True,
video_out_width=1024,
video_out_height=576,

@@ -49,7 +49,7 @@ async def main():
# Initialize Sentry
sentry_sdk.init(
dsn="your-project-dsn",
dsn=os.getenv("SENTRY_DSN"),
traces_sample_rate=1.0,
)

@@ -64,7 +64,7 @@ langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-ope
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity~=9.0.0" ]
lmnt = [ "websockets~=13.1" ]
local = [ "pyaudio~=0.2.14" ]
mcp = [ "mcp[cli]~=1.6.0" ]
mcp = [ "mcp[cli]~=1.9.4" ]
mem0 = [ "mem0ai~=0.1.94" ]
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
moondream = [ "einops~=0.8.0", "timm~=1.0.13", "transformers~=4.48.0" ]
@@ -79,6 +79,7 @@ playht = [ "pyht~=0.1.12", "websockets~=13.1" ]
qwen = []
rime = [ "websockets~=13.1" ]
riva = [ "nvidia-riva-client~=2.19.1" ]
sambanova = []
sentry = [ "sentry-sdk~=2.23.1" ]
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch==2.5.0", "torchaudio==2.5.0" ]
remote-smart-turn = []
@@ -122,8 +123,7 @@ select = [
"D", # Docstring rules
"I", # Import rules
]
# We ignore D107 because class docstrings already document __init__ parameters
# and our Sphinx configuration uses napoleon_include_init_with_doc=True
# Ignore requirement for __init__ docstrings
ignore = ["D107"]
[tool.ruff.lint.pydocstyle]

@@ -78,3 +78,8 @@ class BaseTurnAnalyzer(ABC):
EndOfTurnState: The result of the end of turn analysis.
"""
pass
@abstractmethod
def clear(self):
"""Reset the turn analyzer to its initial state."""
pass
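
Making `clear()` abstract means every turn analyzer subclass has to implement it before it can be instantiated. The sketch below illustrates that effect with plain `abc`; all class names are hypothetical stand-ins, not Pipecat classes.

```python
from abc import ABC, abstractmethod


class TurnAnalyzerSketch(ABC):
    """Hypothetical stand-in for a base class with an abstract clear()."""

    @abstractmethod
    def clear(self) -> None:
        """Reset the analyzer to its initial state."""


class MissingClear(TurnAnalyzerSketch):
    """Does not override clear(), so it stays abstract."""


class WithClear(TurnAnalyzerSketch):
    def __init__(self) -> None:
        self.speech_triggered = True

    def clear(self) -> None:
        self.speech_triggered = False


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if it is still abstract."""
    try:
        cls()
        return True
    except TypeError:
        return False
```

Any existing custom analyzer that lacks `clear()` would fail in the same way as `MissingClear` once the base class gains the new abstract method.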

@@ -98,6 +98,9 @@ class BaseSmartTurn(BaseTurnAnalyzer):
logger.debug(f"End of Turn result: {state}")
return state, result
def clear(self):
self._clear(EndOfTurnState.COMPLETE)
def _clear(self, turn_state: EndOfTurnState):
# If the state is still incomplete, keep the _speech_triggered as True
self._speech_triggered = turn_state == EndOfTurnState.INCOMPLETE
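
The public `clear()` delegates to `_clear()` with a `COMPLETE` state, which resets the speech-triggered flag; only an `INCOMPLETE` state keeps it set. A reduced sketch of just that flag logic follows. The class and enum names are hypothetical, and the real `_clear()` may reset more state than is visible in this hunk.

```python
from enum import Enum


class EndOfTurnStateSketch(Enum):
    COMPLETE = 1
    INCOMPLETE = 2


class SmartTurnSketch:
    """Hypothetical reduction of the flag handling shown in the diff."""

    def __init__(self) -> None:
        self._speech_triggered = False

    def clear(self) -> None:
        # Public reset: always treated as a completed turn
        self._clear(EndOfTurnStateSketch.COMPLETE)

    def _clear(self, turn_state: EndOfTurnStateSketch) -> None:
        # Keep the flag set only while the turn is still incomplete
        self._speech_triggered = turn_state == EndOfTurnStateSketch.INCOMPLETE
```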

@@ -7,6 +7,7 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -26,6 +27,9 @@ from pipecat.transcriptions.language import Language
from pipecat.utils.time import nanoseconds_to_str
from pipecat.utils.utils import obj_count, obj_id
if TYPE_CHECKING:
from pipecat.processors.frame_processor import FrameProcessor
class KeypadEntry(str, Enum):
"""DTMF entries."""
@@ -449,8 +453,8 @@ class StartFrame(SystemFrame):
allow_interruptions: bool = False
enable_metrics: bool = False
enable_usage_metrics: bool = False
report_only_initial_ttfb: bool = False
interruption_strategies: List[BaseInterruptionStrategy] = field(default_factory=list)
report_only_initial_ttfb: bool = False
@dataclass
@@ -485,16 +489,6 @@ class FatalErrorFrame(ErrorFrame):
fatal: bool = field(default=True, init=False)
@dataclass
class HeartbeatFrame(SystemFrame):
"""This frame is used by the pipeline task as a mechanism to know if the
pipeline is running properly.
"""
timestamp: int
@dataclass
class EndTaskFrame(SystemFrame):
"""This is used to notify the pipeline task that the pipeline should be
@@ -529,25 +523,25 @@ class StopTaskFrame(SystemFrame):
@dataclass
class FrameProcessorPauseUrgentFrame(SystemFrame):
"""This processor is used to pause frame processing for the given processor
as fast as possible. Pausing frame processing will keep frames in the
internal queue which will then be processed when frame processing is resumed
with `FrameProcessorResumeFrame`.
"""This frame is used to pause frame processing for the given processor as
fast as possible. Pausing frame processing will keep frames in the internal
queue which will then be processed when frame processing is resumed with
`FrameProcessorResumeFrame`.
"""
processor: str
processor: "FrameProcessor"
@dataclass
class FrameProcessorResumeUrgentFrame(SystemFrame):
"""This processor is used to resume frame processing for the given processor
"""This frame is used to resume frame processing for the given processor
if it was previously paused as fast as possible. After resuming frame
processing all queued frames will be processed in the order received.
"""
processor: str
processor: "FrameProcessor"
@dataclass
@@ -877,25 +871,37 @@ class StopFrame(ControlFrame):
pass
@dataclass
class HeartbeatFrame(ControlFrame):
"""This frame is used by the pipeline task as a mechanism to know if the
pipeline is running properly.
"""
timestamp: int
@dataclass
class FrameProcessorPauseFrame(ControlFrame):
"""This processor is used to pause frame processing for the given
"""This frame is used to pause frame processing for the given
processor. Pausing frame processing will keep frames in the internal queue
which will then be processed when frame processing is resumed with
`FrameProcessorResumeFrame`."""
`FrameProcessorResumeFrame`.
processor: str
"""
processor: "FrameProcessor"
@dataclass
class FrameProcessorResumeFrame(ControlFrame):
"""This processor is used to resume frame processing for the given processor
if it was previously paused. After resuming frame processing all queued
frames will be processed in the order received.
"""This frame is used to resume frame processing for the given processor if
it was previously paused. After resuming frame processing all queued frames
will be processed in the order received.
"""
processor: str
processor: "FrameProcessor"
@dataclass
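
The change from `processor: str` to `processor: "FrameProcessor"` relies on a `TYPE_CHECKING` guard so the frames module never imports the processor module at runtime, which avoids a circular import. A minimal sketch of the pattern, assuming a hypothetical module `hypothetical_processors` that does not need to exist for the code to run:

```python
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; skipped at runtime, so the
    # (hypothetical) module is never actually imported here.
    from hypothetical_processors import Processor


@dataclass
class PauseFrameSketch:
    # The quoted annotation is stored as a string and not evaluated
    processor: "Processor"


class FakeProcessor:
    name = "fake"


frame = PauseFrameSketch(processor=FakeProcessor())
```

Because the annotation stays a string, the dataclass accepts any object at runtime; the type is enforced only by the type checker.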

@@ -12,6 +12,8 @@ from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
StartFrame,
UserStartedSpeakingFrame,
)
@@ -73,6 +75,8 @@ class TurnTrackingObserver(BaseObserver):
# We only want to end the turn if the bot was previously speaking
elif isinstance(data.frame, BotStoppedSpeakingFrame) and self._is_bot_speaking:
await self._handle_bot_stopped_speaking(data)
elif isinstance(data.frame, (EndFrame, CancelFrame)):
await self._handle_pipeline_end(data)
def _schedule_turn_end(self, data: FramePushed):
"""Schedule turn end with a timeout."""
@@ -134,6 +138,14 @@ class TurnTrackingObserver(BaseObserver):
# This can happen with HTTP TTS services or function calls
self._schedule_turn_end(data)
async def _handle_pipeline_end(self, data: FramePushed):
"""Handle pipeline end or cancellation by flushing any active turn."""
if self._is_turn_active:
# Cancel any pending turn end timer
self._cancel_turn_end_timer()
# End the current turn
await self._end_turn(data, was_interrupted=True)
async def _start_turn(self, data: FramePushed):
"""Start a new turn."""
self._is_turn_active = True
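
The new `_handle_pipeline_end` flushes a turn that is still open when the pipeline ends or is cancelled: it cancels the pending turn-end timer and ends the turn as interrupted. The following is a simplified sketch of that flow with a hypothetical `TurnTrackerSketch` class; it is not the observer's actual implementation.

```python
import asyncio
from typing import List, Optional


class TurnTrackerSketch:
    """Hypothetical, reduced model of the turn bookkeeping."""

    def __init__(self) -> None:
        self.is_turn_active = False
        self.ended: List[bool] = []  # was_interrupted flag per ended turn
        self._timer: Optional[asyncio.Task] = None

    def start_turn(self) -> None:
        self.is_turn_active = True

    def schedule_turn_end(self, timeout: float) -> None:
        async def _later() -> None:
            await asyncio.sleep(timeout)
            await self.end_turn(was_interrupted=False)

        self._timer = asyncio.create_task(_later())

    def cancel_turn_end_timer(self) -> None:
        if self._timer:
            self._timer.cancel()
            self._timer = None

    async def end_turn(self, was_interrupted: bool) -> None:
        if not self.is_turn_active:
            return
        self.is_turn_active = False
        self.ended.append(was_interrupted)

    async def handle_pipeline_end(self) -> None:
        # Flush an open turn so it is not lost on EndFrame/CancelFrame
        if self.is_turn_active:
            self.cancel_turn_end_timer()
            await self.end_turn(was_interrupted=True)


async def demo_turn_flush() -> List[bool]:
    tracker = TurnTrackerSketch()
    tracker.start_turn()
    tracker.schedule_turn_end(timeout=10.0)
    await tracker.handle_pipeline_end()
    await tracker.handle_pipeline_end()  # no active turn: nothing happens
    return tracker.ended
```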

@@ -6,18 +6,21 @@
import asyncio
from abc import abstractmethod
from dataclasses import dataclass
from typing import AsyncIterable, Iterable
from pipecat.frames.frames import Frame
from pipecat.utils.base_object import BaseObject
class BaseTask(BaseObject):
@abstractmethod
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
"""Sets the event loop that this task will run on."""
pass
@dataclass
class PipelineTaskParams:
"""Specific configuration for the pipeline task."""
loop: asyncio.AbstractEventLoop
class BasePipelineTask(BaseObject):
@abstractmethod
def has_finished(self) -> bool:
"""Indicates whether the task has finished. That is, all processors
@@ -40,7 +43,7 @@ class BaseTask(BaseObject):
pass
@abstractmethod
async def run(self):
async def run(self, params: PipelineTaskParams):
"""Starts running the given pipeline."""
pass

@@ -21,6 +21,7 @@ from pipecat.frames.frames import (
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
class ParallelPipelineSource(FrameProcessor):
@@ -76,20 +77,36 @@ class ParallelPipeline(BasePipeline):
if len(args) == 0:
raise Exception(f"ParallelPipeline needs at least one argument")
self._args = args
self._sources = []
self._sinks = []
self._pipelines = []
self._seen_ids = set()
self._endframe_counter: Dict[int, int] = {}
self._up_task = None
self._down_task = None
self._up_queue = asyncio.Queue()
self._down_queue = asyncio.Queue()
self._pipelines = []
#
# BasePipeline
#
def processors_with_metrics(self) -> List[FrameProcessor]:
return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines))
#
# Frame processor
#
async def setup(self, setup: FrameProcessorSetup):
await super().setup(setup)
self._up_queue = WatchdogQueue(setup.task_manager)
self._down_queue = WatchdogQueue(setup.task_manager)
logger.debug(f"Creating {self} pipelines")
for processors in args:
for processors in self._args:
if not isinstance(processors, list):
raise TypeError(f"ParallelPipeline argument {processors} is not a list")
@@ -107,19 +124,6 @@ class ParallelPipeline(BasePipeline):
logger.debug(f"Finished creating {self} pipelines")
#
# BasePipeline
#
def processors_with_metrics(self) -> List[FrameProcessor]:
return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines))
#
# Frame processor
#
async def setup(self, setup: FrameProcessorSetup):
await super().setup(setup)
await asyncio.gather(*[s.setup(setup) for s in self._sources])
await asyncio.gather(*[p.setup(setup) for p in self._pipelines])
await asyncio.gather(*[s.setup(setup) for s in self._sinks])
@@ -134,7 +138,7 @@ class ParallelPipeline(BasePipeline):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self._start()
await self._start(frame)
elif isinstance(frame, EndFrame):
self._endframe_counter[frame.id] = len(self._pipelines)
elif isinstance(frame, CancelFrame):
@@ -154,7 +158,7 @@ class ParallelPipeline(BasePipeline):
elif isinstance(frame, EndFrame):
await self._stop()
async def _start(self):
async def _start(self, frame: StartFrame):
await self._create_tasks()
async def _stop(self):

View File

@@ -11,6 +11,7 @@ from typing import Optional
from loguru import logger
from pipecat.pipeline.base_task import PipelineTaskParams
from pipecat.pipeline.task import PipelineTask
from pipecat.utils.base_object import BaseObject
@@ -37,8 +38,8 @@ class PipelineRunner(BaseObject):
async def run(self, task: PipelineTask):
logger.debug(f"Runner {self} started running {task}")
self._tasks[task.name] = task
task.set_event_loop(self._loop)
await task.run()
params = PipelineTaskParams(loop=self._loop)
await task.run(params)
del self._tasks[task.name]
# Cleanup base object.
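
Note on the hunk above: the runner no longer calls `task.set_event_loop()` before `task.run()`; the loop now travels inside a params object passed to `run()`. A minimal sketch of that shape, assuming nothing about pipecat beyond what the diff shows. `SketchTaskParams` and `SketchTask` are hypothetical stand-ins, not the real `PipelineTaskParams` or `PipelineTask`.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchTaskParams:
    # Hypothetical stand-in for PipelineTaskParams; only `loop` appears in the diff.
    loop: asyncio.AbstractEventLoop


class SketchTask:
    def __init__(self) -> None:
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    async def run(self, params: SketchTaskParams) -> None:
        # The loop arrives with the call itself, so run() cannot be invoked
        # before a separate setter has been called.
        self.loop = params.loop


async def demo() -> bool:
    task = SketchTask()
    loop = asyncio.get_running_loop()
    await task.run(SketchTaskParams(loop=loop))
    return task.loop is loop


if __name__ == "__main__":
    print(asyncio.run(demo()))
```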

View File

@@ -15,6 +15,7 @@ from pipecat.frames.frames import ControlFrame, EndFrame, Frame, SystemFrame
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
@dataclass
@@ -61,15 +62,30 @@ class SyncParallelPipeline(BasePipeline):
if len(args) == 0:
raise Exception(f"SyncParallelPipeline needs at least one argument")
self._args = args
self._sinks = []
self._sources = []
self._pipelines = []
self._up_queue = asyncio.Queue()
self._down_queue = asyncio.Queue()
#
# BasePipeline
#
def processors_with_metrics(self) -> List[FrameProcessor]:
return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines))
#
# Frame processor
#
async def setup(self, setup: FrameProcessorSetup):
await super().setup(setup)
self._up_queue = WatchdogQueue(setup.task_manager)
self._down_queue = WatchdogQueue(setup.task_manager)
logger.debug(f"Creating {self} pipelines")
for processors in args:
for processors in self._args:
if not isinstance(processors, list):
raise TypeError(f"SyncParallelPipeline argument {processors} is not a list")
@@ -92,19 +108,6 @@ class SyncParallelPipeline(BasePipeline):
logger.debug(f"Finished creating {self} pipelines")
#
# BasePipeline
#
def processors_with_metrics(self) -> List[FrameProcessor]:
return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines))
#
# Frame processor
#
async def setup(self, setup: FrameProcessorSetup):
await super().setup(setup)
await asyncio.gather(*[s["processor"].setup(setup) for s in self._sources])
await asyncio.gather(*[p.setup(setup) for p in self._pipelines])
await asyncio.gather(*[s["processor"].setup(setup) for s in self._sinks])

View File

@@ -6,7 +6,8 @@
import asyncio
import time
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Sequence, Tuple, Type
from collections import deque
from typing import Any, AsyncIterable, Deque, Dict, Iterable, List, Optional, Tuple, Type
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field
@@ -23,6 +24,7 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
HeartbeatFrame,
InputAudioRawFrame,
LLMFullResponseEndFrame,
MetricsFrame,
StartFrame,
@@ -33,19 +35,28 @@ from pipecat.metrics.metrics import ProcessingMetricsData, TTFBMetricsData
from pipecat.observers.base_observer import BaseObserver
from pipecat.observers.turn_tracking_observer import TurnTrackingObserver
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.base_task import BaseTask
from pipecat.pipeline.base_task import BasePipelineTask, PipelineTaskParams
from pipecat.pipeline.task_observer import TaskObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
from pipecat.utils.asyncio import BaseTaskManager, TaskManager
from pipecat.utils.asyncio.task_manager import (
WATCHDOG_TIMEOUT,
BaseTaskManager,
TaskManager,
TaskManagerParams,
)
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
from pipecat.utils.tracing.setup import is_tracing_available
from pipecat.utils.tracing.turn_trace_observer import TurnTraceObserver
HEARTBEAT_SECONDS = 1.0
HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 5
HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 10
class PipelineParams(BaseModel):
"""Configuration parameters for pipeline execution.
"""Configuration parameters for pipeline execution. These parameters are
usually passed to all frame processors through `StartFrame`. For other
generic pipeline task parameters use `PipelineTask` constructor arguments
instead.
Attributes:
allow_interruptions: Whether to allow pipeline interruptions.
@@ -60,6 +71,7 @@ class PipelineParams(BaseModel):
send_initial_empty_metrics: Whether to send initial empty metrics.
start_metadata: Additional metadata for pipeline start.
interruption_strategies: Strategies for bot interruption behavior.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -71,11 +83,11 @@ class PipelineParams(BaseModel):
enable_metrics: bool = False
enable_usage_metrics: bool = False
heartbeats_period_secs: float = HEARTBEAT_SECONDS
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
observers: List[BaseObserver] = Field(default_factory=list)
report_only_initial_ttfb: bool = False
send_initial_empty_metrics: bool = True
start_metadata: Dict[str, Any] = Field(default_factory=dict)
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
class PipelineTaskSource(FrameProcessor):
@@ -125,7 +137,7 @@ class PipelineTaskSink(FrameProcessor):
await self._down_queue.put(frame)
class PipelineTask(BaseTask):
class PipelineTask(BasePipelineTask):
"""Manages the execution of a pipeline, handling frame processing and task lifecycle.
It has a couple of event handlers `on_frame_reached_upstream` and
@@ -172,21 +184,25 @@ class PipelineTask(BaseTask):
Args:
pipeline: The pipeline to execute.
params: Configuration parameters for the pipeline.
observers: List of observers for monitoring pipeline execution.
clock: Clock implementation for timing operations.
additional_span_attributes: Optional dictionary of attributes to propagate as
OpenTelemetry conversation span attributes.
cancel_on_idle_timeout: Whether the pipeline task should be cancelled if
the idle timeout is reached.
check_dangling_tasks: Whether to check for processors' tasks finishing properly.
clock: Clock implementation for timing operations.
conversation_id: Optional custom ID for the conversation.
enable_tracing: Whether to enable tracing.
enable_turn_tracking: Whether to enable turn tracking.
enable_watchdog_logging: Whether to print task processing times.
enable_watchdog_timers: Whether to enable task watchdog timers.
idle_timeout_frames: A tuple with the frames that should trigger an idle
timeout if not received within `idle_timeout_secs`.
idle_timeout_secs: Timeout (in seconds) to consider pipeline idle or
None. If a pipeline is idle the pipeline task will be cancelled
automatically.
idle_timeout_frames: A tuple with the frames that should trigger an idle
timeout if not received within `idle_timeout_secs`.
cancel_on_idle_timeout: Whether the pipeline task should be cancelled if
the idle timeout is reached.
enable_turn_tracking: Whether to enable turn tracking.
enable_turn_tracing: Whether to enable turn tracing.
conversation_id: Optional custom ID for the conversation.
additional_span_attributes: Optional dictionary of attributes to propagate as
OpenTelemetry conversation span attributes.
observers: List of observers for monitoring pipeline execution.
watchdog_timeout_secs: Watchdog timer timeout (in seconds). A warning
will be logged if the watchdog timer is not reset before this timeout.
"""
def __init__(
@@ -194,33 +210,39 @@ class PipelineTask(BaseTask):
pipeline: BasePipeline,
*,
params: Optional[PipelineParams] = None,
observers: Optional[List[BaseObserver]] = None,
clock: Optional[BaseClock] = None,
task_manager: Optional[BaseTaskManager] = None,
additional_span_attributes: Optional[dict] = None,
cancel_on_idle_timeout: bool = True,
check_dangling_tasks: bool = True,
idle_timeout_secs: Optional[float] = 300,
clock: Optional[BaseClock] = None,
conversation_id: Optional[str] = None,
enable_tracing: bool = False,
enable_turn_tracking: bool = True,
enable_watchdog_logging: bool = False,
enable_watchdog_timers: bool = False,
idle_timeout_frames: Tuple[Type[Frame], ...] = (
BotSpeakingFrame,
LLMFullResponseEndFrame,
),
cancel_on_idle_timeout: bool = True,
enable_turn_tracking: bool = True,
enable_tracing: bool = False,
conversation_id: Optional[str] = None,
additional_span_attributes: Optional[dict] = None,
idle_timeout_secs: Optional[float] = 300,
observers: Optional[List[BaseObserver]] = None,
task_manager: Optional[BaseTaskManager] = None,
watchdog_timeout_secs: float = WATCHDOG_TIMEOUT,
):
super().__init__()
self._pipeline = pipeline
self._clock = clock or SystemClock()
self._params = params or PipelineParams()
self._check_dangling_tasks = check_dangling_tasks
self._idle_timeout_secs = idle_timeout_secs
self._idle_timeout_frames = idle_timeout_frames
self._cancel_on_idle_timeout = cancel_on_idle_timeout
self._enable_turn_tracking = enable_turn_tracking
self._enable_tracing = enable_tracing and is_tracing_available()
self._conversation_id = conversation_id
self._additional_span_attributes = additional_span_attributes or {}
self._cancel_on_idle_timeout = cancel_on_idle_timeout
self._check_dangling_tasks = check_dangling_tasks
self._clock = clock or SystemClock()
self._conversation_id = conversation_id
self._enable_tracing = enable_tracing and is_tracing_available()
self._enable_turn_tracking = enable_turn_tracking
self._enable_watchdog_logging = enable_watchdog_logging
self._enable_watchdog_timers = enable_watchdog_timers
self._idle_timeout_frames = idle_timeout_frames
self._idle_timeout_secs = idle_timeout_secs
self._watchdog_timeout_secs = watchdog_timeout_secs
if self._params.observers:
import warnings
@@ -247,19 +269,29 @@ class PipelineTask(BaseTask):
self._finished = False
self._cancelled = False
# This task manager will handle all the asyncio tasks created by this
# PipelineTask and its frame processors.
self._task_manager = task_manager or TaskManager()
# This queue receives frames coming from the pipeline upstream.
self._up_queue = asyncio.Queue()
self._up_queue = WatchdogQueue(self._task_manager)
self._process_up_task: Optional[asyncio.Task] = None
# This queue receives frames coming from the pipeline downstream.
self._down_queue = asyncio.Queue()
self._down_queue = WatchdogQueue(self._task_manager)
self._process_down_task: Optional[asyncio.Task] = None
# This queue is the queue used to push frames to the pipeline.
self._push_queue = asyncio.Queue()
self._push_queue = WatchdogQueue(self._task_manager)
self._process_push_task: Optional[asyncio.Task] = None
# This is the heartbeat queue. When a heartbeat frame is received in the
# down queue we add it to the heartbeat queue for processing.
self._heartbeat_queue = asyncio.Queue()
self._heartbeat_queue = WatchdogQueue(self._task_manager)
self._heartbeat_push_task: Optional[asyncio.Task] = None
self._heartbeat_monitor_task: Optional[asyncio.Task] = None
# This is the idle queue. When frames are received downstream they are
# put in the queue. If no frame is received the pipeline is considered
# idle.
self._idle_queue = asyncio.Queue()
self._idle_queue = WatchdogQueue(self._task_manager)
self._idle_monitor_task: Optional[asyncio.Task] = None
# This event is used to indicate a finalize frame (e.g. EndFrame,
# StopFrame) has been received in the down queue.
self._pipeline_end_event = asyncio.Event()
@@ -276,10 +308,6 @@ class PipelineTask(BaseTask):
self._sink = PipelineTaskSink(self._down_queue)
pipeline.link(self._sink)
# This task manager will handle all the asyncio tasks created by this
# PipelineTask and its frame processors.
self._task_manager = task_manager or TaskManager()
# The task observer acts as a proxy to the provided observers. This way,
# we only need to pass a single observer (using the StartFrame) which
# then just acts as a proxy.
@@ -322,9 +350,6 @@ class PipelineTask(BaseTask):
async def remove_observer(self, observer: BaseObserver):
await self._observer.remove_observer(observer)
def set_event_loop(self, loop: asyncio.AbstractEventLoop):
self._task_manager.set_event_loop(loop)
def set_reached_upstream_filter(self, types: Tuple[Type[Frame], ...]):
"""Sets which frames will be checked before calling the
on_frame_reached_upstream event handler.
@@ -358,14 +383,14 @@ class PipelineTask(BaseTask):
"""Stops the running pipeline immediately."""
await self._cancel()
async def run(self):
async def run(self, params: PipelineTaskParams):
"""Starts and manages the pipeline execution until completion or cancellation."""
if self.has_finished():
return
cleanup_pipeline = True
try:
# Setup processors.
await self._setup()
await self._setup(params)
# Create all main tasks and wait for the main push task. This is the
# task that pushes frames to the very beginning of our pipeline (our
@@ -423,7 +448,9 @@ class PipelineTask(BaseTask):
# we want to cancel right away.
await self._source.push_frame(CancelFrame())
# Only cancel the push task. Everything else will be cancelled in run().
await self._task_manager.cancel_task(self._process_push_task)
if self._process_push_task:
await self._task_manager.cancel_task(self._process_push_task)
self._process_push_task = None
async def _create_tasks(self):
self._process_up_task = self._task_manager.create_task(
@@ -441,7 +468,7 @@ class PipelineTask(BaseTask):
return self._process_push_task
def _maybe_start_heartbeat_tasks(self):
if self._params.enable_heartbeats:
if self._params.enable_heartbeats and self._heartbeat_push_task is None:
self._heartbeat_push_task = self._task_manager.create_task(
self._heartbeat_push_handler(), f"{self}::_heartbeat_push_handler"
)
@@ -458,20 +485,33 @@ class PipelineTask(BaseTask):
async def _cancel_tasks(self):
await self._observer.stop()
await self._task_manager.cancel_task(self._process_up_task)
await self._task_manager.cancel_task(self._process_down_task)
if self._process_up_task:
await self._task_manager.cancel_task(self._process_up_task)
self._process_up_task = None
if self._process_down_task:
await self._task_manager.cancel_task(self._process_down_task)
self._process_down_task = None
await self._maybe_cancel_heartbeat_tasks()
await self._maybe_cancel_idle_task()
async def _maybe_cancel_heartbeat_tasks(self):
if self._params.enable_heartbeats:
if not self._params.enable_heartbeats:
return
if self._heartbeat_push_task:
await self._task_manager.cancel_task(self._heartbeat_push_task)
self._heartbeat_push_task = None
if self._heartbeat_monitor_task:
await self._task_manager.cancel_task(self._heartbeat_monitor_task)
self._heartbeat_monitor_task = None
async def _maybe_cancel_idle_task(self):
if self._idle_timeout_secs:
if self._idle_timeout_secs and self._idle_monitor_task:
await self._task_manager.cancel_task(self._idle_monitor_task)
self._idle_monitor_task = None
def _initial_metrics_frame(self) -> MetricsFrame:
processors = self._pipeline.processors_with_metrics()
@@ -485,11 +525,20 @@ class PipelineTask(BaseTask):
await self._pipeline_end_event.wait()
self._pipeline_end_event.clear()
async def _setup(self):
async def _setup(self, params: PipelineTaskParams):
mgr_params = TaskManagerParams(
loop=params.loop,
enable_watchdog_logging=self._enable_watchdog_logging,
enable_watchdog_timers=self._enable_watchdog_timers,
watchdog_timeout=self._watchdog_timeout_secs,
)
self._task_manager.setup(mgr_params)
setup = FrameProcessorSetup(
clock=self._clock,
task_manager=self._task_manager,
observer=self._observer,
watchdog_timers_enabled=self._enable_watchdog_timers,
)
await self._source.setup(setup)
await self._pipeline.setup(setup)
@@ -517,7 +566,6 @@ class PipelineTask(BaseTask):
"""
self._clock.start()
self._maybe_start_heartbeat_tasks()
self._maybe_start_idle_task()
start_frame = StartFrame(
@@ -599,6 +647,10 @@ class PipelineTask(BaseTask):
if isinstance(frame, StartFrame):
await self._call_event_handler("on_pipeline_started", frame)
# Start heartbeat tasks now that StartFrame has been processed
# by all processors in the pipeline
self._maybe_start_heartbeat_tasks()
elif isinstance(frame, EndFrame):
await self._call_event_handler("on_pipeline_ended", frame)
self._pipeline_end_event.set()
@@ -646,12 +698,17 @@ class PipelineTask(BaseTask):
"""
running = True
last_frame_time = 0
frame_buffer = deque(maxlen=10) # Store last 10 frames
while running:
try:
frame = await asyncio.wait_for(
self._idle_queue.get(), timeout=self._idle_timeout_secs
)
if not isinstance(frame, InputAudioRawFrame):
frame_buffer.append(frame)
if isinstance(frame, StartFrame) or isinstance(frame, self._idle_timeout_frames):
# If we find a StartFrame or one of the frames that prevents a
# time out we update the time.
@@ -662,7 +719,7 @@ class PipelineTask(BaseTask):
# valid frames.
diff_time = time.time() - last_frame_time
if diff_time >= self._idle_timeout_secs:
running = await self._idle_timeout_detected()
running = await self._idle_timeout_detected(frame_buffer)
# Reset `last_frame_time` so we don't trigger another
# immediate idle timeout if we are not cancelling. For
# example, we might want to force the bot to say goodbye
@@ -670,15 +727,20 @@ class PipelineTask(BaseTask):
last_frame_time = time.time()
self._idle_queue.task_done()
except asyncio.TimeoutError:
running = await self._idle_timeout_detected()
async def _idle_timeout_detected(self) -> bool:
except asyncio.TimeoutError:
running = await self._idle_timeout_detected(frame_buffer)
async def _idle_timeout_detected(self, last_frames: Deque[Frame]) -> bool:
"""Logic for when the pipeline is idle.
Returns:
bool: Whether the pipeline task is being cancelled or not.
"""
logger.warning("Idle timeout detected. Last 10 frames received:")
for i, frame in enumerate(last_frames, 1):
logger.warning(f"Frame {i}: {frame}")
await self._call_event_handler("on_idle_timeout")
if self._cancel_on_idle_timeout:
logger.warning(f"Idle pipeline detected, cancelling pipeline task...")
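
Note on the idle monitor hunks above: the last ten non-audio frames are kept in a bounded `deque` and logged when the idle timeout fires. A minimal sketch of that pattern, with plain strings standing in for frames. `idle_monitor` and `demo` are hypothetical names for illustration, and the audio-frame filter and the `last_frame_time` check from the diff are left out.

```python
import asyncio
from collections import deque
from typing import Deque, List


async def idle_monitor(queue: asyncio.Queue, timeout_secs: float, keep: int = 10) -> List[str]:
    """Consume items until none arrives within timeout_secs.

    Returns the most recent `keep` items seen before the timeout.
    """
    # A deque with maxlen drops its oldest entry automatically when full.
    recent: Deque[str] = deque(maxlen=keep)
    while True:
        try:
            item = await asyncio.wait_for(queue.get(), timeout=timeout_secs)
        except asyncio.TimeoutError:
            # Nothing arrived in time: report what was seen last.
            return list(recent)
        recent.append(item)
        queue.task_done()


async def demo() -> List[str]:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(15):
        queue.put_nowait(f"frame-{i}")
    return await idle_monitor(queue, timeout_secs=0.05)


if __name__ == "__main__":
    for i, frame in enumerate(asyncio.run(demo()), 1):
        print(f"Frame {i}: {frame}")
```

Only the tail of the stream survives, so the log stays short no matter how long the pipeline ran before going idle.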

View File

@@ -11,7 +11,8 @@ from typing import Dict, List, Optional
from attr import dataclass
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.utils.asyncio import BaseTaskManager
from pipecat.utils.asyncio.task_manager import BaseTaskManager
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
@dataclass
@@ -82,6 +83,9 @@ class TaskObserver(BaseObserver):
async def stop(self):
"""Stops all proxy observer tasks."""
if not self._proxies:
return
for proxy in self._proxies.values():
await self._task_manager.cancel_task(proxy.task)
@@ -93,7 +97,7 @@ class TaskObserver(BaseObserver):
return self._proxies is not None
def _create_proxy(self, observer: BaseObserver) -> Proxy:
queue = asyncio.Queue()
queue = WatchdogQueue(self._task_manager)
task = self._task_manager.create_task(
self._proxy_task_handler(queue, observer),
f"TaskObserver::{observer}::_proxy_task_handler",

View File

@@ -119,6 +119,7 @@ class DTMFAggregator(FrameProcessor):
await asyncio.wait_for(self._digit_event.wait(), timeout=self._idle_timeout)
self._digit_event.clear()
except asyncio.TimeoutError:
self.reset_watchdog()
if self._aggregation:
await self._flush_aggregation()

View File

@@ -266,6 +266,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
self._user_speaking = False
self._bot_speaking = False
self._was_bot_speaking = False
self._emulating_vad = False
self._seen_interim_results = False
self._waiting_for_aggregation = False
@@ -275,6 +276,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
async def reset(self):
await super().reset()
self._was_bot_speaking = False
self._seen_interim_results = False
self._waiting_for_aggregation = False
[await s.reset() for s in self._interruption_strategies]
@@ -355,6 +357,20 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
else:
# No interruption config - normal behavior (always push aggregation)
await self._process_aggregation()
# Handles the case where both the user and the bot are not speaking,
# and the bot was previously speaking before the user interruption.
# Normally, when the user stops speaking, new text is expected,
# which triggers the bot to respond. However, if no new text
# is received, this safeguard ensures
# the bot doesn't hang indefinitely while waiting to speak again.
elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
logger.warning("User stopped speaking but no new aggregation received.")
# Resetting it so we don't trigger this twice
self._was_bot_speaking = False
# TODO: we are not enabling this for now, due to some STT services which can take as long as 2 seconds to return a transcription
# So we need more tests and probably make this feature configurable, disabled by default.
# We are just pushing the same previous context to be processed again in this case
# await self.push_frame(OpenAILLMContextFrame(self._context))
async def _should_interrupt_based_on_strategies(self) -> bool:
"""Check if interruption should occur based on configured strategies."""
@@ -381,6 +397,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame):
self._user_speaking = True
self._waiting_for_aggregation = True
self._was_bot_speaking = self._bot_speaking
# If we get a non-emulated UserStartedSpeakingFrame but we are in the
# middle of emulating VAD, let's stop emulating VAD (i.e. don't send the
@@ -393,8 +410,15 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
# We just stopped speaking. Let's see if there's some aggregation to
# push. If the last thing we saw is an interim transcription, let's wait
# pushing the aggregation as we will probably get a final transcription.
if not self._seen_interim_results:
await self.push_aggregation()
if len(self._aggregation) > 0:
if not self._seen_interim_results:
await self.push_aggregation()
# Handles the case where both the user and the bot are not speaking,
# and the bot was previously speaking before the user interruption.
# So in this case we are resetting the aggregation timer
elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking:
# Reset aggregation timer.
self._aggregation_event.set()
async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame):
self._bot_speaking = True
@@ -446,6 +470,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
)
self._emulating_vad = False
finally:
self.reset_watchdog()
self._aggregation_event.clear()
async def _maybe_emulate_user_speaking(self):

View File

@@ -10,6 +10,7 @@ from typing import Awaitable, Callable, Optional
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.producer_processor import ProducerProcessor, identity_transformer
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
class ConsumerProcessor(FrameProcessor):
@@ -31,7 +32,7 @@ class ConsumerProcessor(FrameProcessor):
super().__init__(**kwargs)
self._transformer = transformer
self._direction = direction
self._queue: asyncio.Queue = producer.add_consumer()
self._producer = producer
self._consumer_task: Optional[asyncio.Task] = None
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -48,6 +49,7 @@ class ConsumerProcessor(FrameProcessor):
async def _start(self, _: StartFrame):
if not self._consumer_task:
self._queue: WatchdogQueue = self._producer.add_consumer()
self._consumer_task = self.create_task(self._consumer_task_handler())
async def _stop(self, _: EndFrame):

View File

@@ -29,7 +29,9 @@ from pipecat.frames.frames import (
from pipecat.metrics.metrics import LLMTokenUsage, MetricsData
from pipecat.observers.base_observer import BaseObserver, FramePushed
from pipecat.processors.metrics.frame_processor_metrics import FrameProcessorMetrics
from pipecat.utils.asyncio import BaseTaskManager
from pipecat.utils.asyncio.task_manager import BaseTaskManager
from pipecat.utils.asyncio.watchdog_event import WatchdogEvent
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
from pipecat.utils.base_object import BaseObject
@@ -43,6 +45,7 @@ class FrameProcessorSetup:
clock: BaseClock
task_manager: BaseTaskManager
observer: Optional[BaseObserver] = None
watchdog_timers_enabled: bool = False
class FrameProcessor(BaseObject):
@@ -50,7 +53,10 @@ class FrameProcessor(BaseObject):
self,
*,
name: Optional[str] = None,
enable_watchdog_logging: Optional[bool] = None,
enable_watchdog_timers: Optional[bool] = None,
metrics: Optional[FrameProcessorMetrics] = None,
watchdog_timeout_secs: Optional[float] = None,
**kwargs,
):
super().__init__(name=name)
@@ -58,6 +64,15 @@ class FrameProcessor(BaseObject):
self._prev: Optional["FrameProcessor"] = None
self._next: Optional["FrameProcessor"] = None
# Enable watchdog timers for all tasks created by this frame processor.
self._enable_watchdog_timers = enable_watchdog_timers
# Enable watchdog logging for all tasks created by this frame processor.
self._enable_watchdog_logging = enable_watchdog_logging
# Allow this frame processor to control its tasks' timeout.
self._watchdog_timeout_secs = watchdog_timeout_secs
# Clock
self._clock: Optional[BaseClock] = None
@@ -93,7 +108,7 @@ class FrameProcessor(BaseObject):
# is called. To resume processing frames we need to call
# `resume_processing_frames()` which will wake up the event.
self.__should_block_frames = False
self.__input_event = asyncio.Event()
self.__input_event = None
self.__input_frame_task: Optional[asyncio.Task] = None
# Every processor in Pipecat should only output frames from a single
@@ -129,6 +144,12 @@ class FrameProcessor(BaseObject):
def interruption_strategies(self) -> Sequence[BaseInterruptionStrategy]:
return self._interruption_strategies
@property
def task_manager(self) -> BaseTaskManager:
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
return self._task_manager
def can_generate_metrics(self) -> bool:
return False
@@ -171,34 +192,62 @@ class FrameProcessor(BaseObject):
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task:
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
def create_task(
self,
coroutine: Coroutine,
name: Optional[str] = None,
*,
enable_watchdog_logging: Optional[bool] = None,
enable_watchdog_timers: Optional[bool] = None,
watchdog_timeout_secs: Optional[float] = None,
) -> asyncio.Task:
if name:
name = f"{self}::{name}"
else:
name = f"{self}::{coroutine.cr_code.co_name}"
return self._task_manager.create_task(coroutine, name)
return self.task_manager.create_task(
coroutine,
name,
enable_watchdog_logging=(
enable_watchdog_logging
if enable_watchdog_logging
else self._enable_watchdog_logging
),
enable_watchdog_timers=(
enable_watchdog_timers if enable_watchdog_timers else self._enable_watchdog_timers
),
watchdog_timeout=(
watchdog_timeout_secs if watchdog_timeout_secs else self._watchdog_timeout_secs
),
)
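
Note on the `create_task()` overrides above: each per-task argument falls back to the processor-level value through a truthiness check (`x if x else default`). A minimal sketch of what that implies, assuming the values are plain `Optional[bool]`: an explicit `False` is treated like `None`, so it does not switch off a processor-level `True`. `pick_truthy` and `pick_not_none` are hypothetical helpers, not pipecat functions; the second is shown only for comparison.

```python
from typing import Optional


def pick_truthy(override: Optional[bool], default: Optional[bool]) -> Optional[bool]:
    # Same shape as the expressions in create_task(): any falsy override
    # (None or False) falls back to the default.
    return override if override else default


def pick_not_none(override: Optional[bool], default: Optional[bool]) -> Optional[bool]:
    # For comparison: only a missing (None) override falls back.
    return override if override is not None else default


if __name__ == "__main__":
    print(pick_truthy(False, True))    # explicit False is replaced by the default
    print(pick_not_none(False, True))  # explicit False is kept
```

The same applies to the timeout argument: a value of `0` would also fall through to the processor-level timeout.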
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None):
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
await self._task_manager.cancel_task(task, timeout)
await self.task_manager.cancel_task(task, timeout)
async def wait_for_task(self, task: asyncio.Task, timeout: Optional[float] = None):
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
await self._task_manager.wait_for_task(task, timeout)
await self.task_manager.wait_for_task(task, timeout)
def reset_watchdog(self):
self.task_manager.task_reset_watchdog()
async def setup(self, setup: FrameProcessorSetup):
self._clock = setup.clock
self._task_manager = setup.task_manager
self._observer = setup.observer
self._watchdog_timers_enabled = (
self._enable_watchdog_timers
if self._enable_watchdog_timers
else setup.watchdog_timers_enabled
)
if self._metrics is not None:
await self._metrics.setup(self._task_manager)
async def cleanup(self):
await super().cleanup()
await self.__cancel_input_task()
await self.__cancel_push_task()
if self._metrics is not None:
await self._metrics.cleanup()
def link(self, processor: "FrameProcessor"):
self._next = processor
@@ -206,9 +255,7 @@ class FrameProcessor(BaseObject):
logger.debug(f"Linking {self} -> {self._next}")
def get_event_loop(self) -> asyncio.AbstractEventLoop:
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
return self._task_manager.get_event_loop()
return self.task_manager.get_event_loop()
def set_parent(self, parent: "FrameProcessor"):
self._parent = parent
@@ -221,11 +268,6 @@ class FrameProcessor(BaseObject):
raise Exception(f"{self} Clock is still not initialized.")
return self._clock
def get_task_manager(self) -> BaseTaskManager:
if not self._task_manager:
raise Exception(f"{self} TaskManager is still not initialized.")
return self._task_manager
async def queue_frame(
self,
frame: Frame,
@@ -251,7 +293,8 @@ class FrameProcessor(BaseObject):
async def resume_processing_frames(self):
logger.trace(f"{self}: resuming frame processing")
self.__input_event.set()
if self.__input_event:
self.__input_event.set()
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, StartFrame):
@@ -285,8 +328,8 @@ class FrameProcessor(BaseObject):
self._allow_interruptions = frame.allow_interruptions
self._enable_metrics = frame.enable_metrics
self._enable_usage_metrics = frame.enable_usage_metrics
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
self._interruption_strategies = frame.interruption_strategies
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
self.__create_input_task()
self.__create_push_task()
@@ -296,11 +339,11 @@ class FrameProcessor(BaseObject):
await self.__cancel_push_task()
async def __pause(self, frame: FrameProcessorPauseFrame | FrameProcessorPauseUrgentFrame):
if frame.name == self.name:
if frame.processor.name == self.name:
await self.pause_processing_frames()
async def __resume(self, frame: FrameProcessorResumeFrame | FrameProcessorResumeUrgentFrame):
if frame.name == self.name:
if frame.processor.name == self.name:
await self.resume_processing_frames()
#
@@ -315,9 +358,8 @@ class FrameProcessor(BaseObject):
# Cancel the input task. This will stop processing queued frames.
await self.__cancel_input_task()
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
logger.exception(f"Uncaught exception in {self} when handling _start_interruption: {e}")
await self.push_error(ErrorFrame(str(e)))
raise
# Create a new input queue and task.
self.__create_input_task()
@@ -360,7 +402,6 @@ class FrameProcessor(BaseObject):
except Exception as e:
logger.exception(f"Uncaught exception in {self}: {e}")
await self.push_error(ErrorFrame(str(e)))
raise
def _check_started(self, frame: Frame):
if not self.__started:
@@ -370,8 +411,10 @@ class FrameProcessor(BaseObject):
def __create_input_task(self):
if not self.__input_frame_task:
self.__should_block_frames = False
if not self.__input_event:
self.__input_event = WatchdogEvent(self.task_manager)
self.__input_event.clear()
self.__input_queue = asyncio.Queue()
self.__input_queue = WatchdogQueue(self.task_manager)
self.__input_frame_task = self.create_task(self.__input_frame_task_handler())
async def __cancel_input_task(self):
@@ -381,7 +424,7 @@ class FrameProcessor(BaseObject):
async def __input_frame_task_handler(self):
while True:
if self.__should_block_frames:
if self.__should_block_frames and self.__input_event:
logger.trace(f"{self}: frame processing paused")
await self.__input_event.wait()
self.__input_event.clear()
@@ -389,19 +432,21 @@ class FrameProcessor(BaseObject):
logger.trace(f"{self}: frame processing resumed")
(frame, direction, callback) = await self.__input_queue.get()
# Process the frame.
await self.process_frame(frame, direction)
# If this frame has an associated callback, call it now.
if callback:
await callback(self, frame, direction)
self.__input_queue.task_done()
try:
# Process the frame.
await self.process_frame(frame, direction)
# If this frame has an associated callback, call it now.
if callback:
await callback(self, frame, direction)
except Exception as e:
logger.exception(f"{self}: error processing frame: {e}")
await self.push_error(ErrorFrame(str(e)))
finally:
self.__input_queue.task_done()
def __create_push_task(self):
if not self.__push_frame_task:
self.__push_queue = asyncio.Queue()
self.__push_queue = WatchdogQueue(self.task_manager)
self.__push_frame_task = self.create_task(self.__push_frame_task_handler())
async def __cancel_push_task(self):
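
Note on the input-task hunk above: moving `task_done()` into a `finally` block means a frame that raises during processing is still acknowledged, so anything waiting on the queue is not left hanging. The following is a minimal sketch of that pattern only, using a plain `asyncio.Queue` rather than `WatchdogQueue`; `worker`, `run_demo`, and the `errors` list are hypothetical names standing in for the processor, and the list append stands in for `push_error(ErrorFrame(...))`.

```python
import asyncio


async def worker(queue: asyncio.Queue, errors: list) -> None:
    """Consume items forever; a failing item must not stall the queue."""
    while True:
        item = await queue.get()
        try:
            # Hypothetical processing step that rejects negative items.
            if item < 0:
                raise ValueError(f"negative item: {item}")
        except Exception as e:
            # Stand-in for push_error(ErrorFrame(str(e))) in the diff.
            errors.append(str(e))
        finally:
            # Always acknowledge the item, even when processing failed.
            queue.task_done()


async def run_demo() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    errors: list = []
    task = asyncio.create_task(worker(queue, errors))
    for item in (1, -2, 3):
        await queue.put(item)
    # join() returns only because task_done() also ran for the failing item.
    await asyncio.wait_for(queue.join(), timeout=1.0)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    return errors
```

Without the `finally`, the failing item would never be acknowledged and `queue.join()` would time out in this sketch.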


@@ -67,6 +67,7 @@ from pipecat.services.llm_service import (
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
from pipecat.utils.string import match_endofsentence
RTVI_PROTOCOL_VERSION = "0.3.0"
@@ -650,11 +651,9 @@ class RTVIProcessor(FrameProcessor):
self._registered_services: Dict[str, RTVIService] = {}
# A task to process incoming action frames.
self._action_queue = asyncio.Queue()
self._action_task: Optional[asyncio.Task] = None
# A task to process incoming transport messages.
self._message_queue = asyncio.Queue()
self._message_task: Optional[asyncio.Task] = None
self._register_event_handler("on_bot_started")
@@ -756,8 +755,10 @@ class RTVIProcessor(FrameProcessor):
async def _start(self, frame: StartFrame):
if not self._action_task:
self._action_queue = WatchdogQueue(self.task_manager)
self._action_task = self.create_task(self._action_task_handler())
if not self._message_task:
self._message_queue = WatchdogQueue(self.task_manager)
self._message_task = self.create_task(self._message_task_handler())
await self._call_event_handler("on_bot_started")


@@ -18,15 +18,29 @@ from pipecat.metrics.metrics import (
TTFBMetricsData,
TTSUsageMetricsData,
)
from pipecat.utils.asyncio.task_manager import BaseTaskManager
from pipecat.utils.base_object import BaseObject
class FrameProcessorMetrics:
class FrameProcessorMetrics(BaseObject):
def __init__(self):
super().__init__()
self._task_manager = None
self._start_ttfb_time = 0
self._start_processing_time = 0
self._last_ttfb_time = 0
self._should_report_ttfb = True
async def setup(self, task_manager: BaseTaskManager):
self._task_manager = task_manager
async def cleanup(self):
await super().cleanup()
@property
def task_manager(self) -> BaseTaskManager:
return self._task_manager
@property
def ttfb(self) -> Optional[float]:
"""Get the current TTFB value in seconds.


@@ -4,8 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
from loguru import logger
from pipecat.utils.asyncio.task_manager import BaseTaskManager
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
try:
import sentry_sdk
except ModuleNotFoundError as e:
@@ -24,6 +29,24 @@ class SentryMetrics(FrameProcessorMetrics):
self._sentry_available = sentry_sdk.is_initialized()
if not self._sentry_available:
logger.warning("Sentry SDK not initialized. Sentry features will be disabled.")
self._sentry_task = None
async def setup(self, task_manager: BaseTaskManager):
await super().setup(task_manager)
if self._sentry_available:
self._sentry_queue = WatchdogQueue(task_manager)
self._sentry_task = self.task_manager.create_task(
self._sentry_task_handler(), name=f"{self}::_sentry_task_handler"
)
async def cleanup(self):
await super().cleanup()
if self._sentry_task:
await self._sentry_queue.put(None)
await self.task_manager.wait_for_task(self._sentry_task)
self._sentry_task = None
logger.trace(f"{self} Flushing Sentry metrics")
sentry_sdk.flush(timeout=5.0)
async def start_ttfb_metrics(self, report_only_initial_ttfb):
await super().start_ttfb_metrics(report_only_initial_ttfb)
@@ -34,14 +57,15 @@ class SentryMetrics(FrameProcessorMetrics):
name=f"TTFB for {self._processor_name()}",
)
logger.debug(
f"Sentry transaction started (ID: {self._ttfb_metrics_tx.span_id} Name: {self._ttfb_metrics_tx.name})"
f"{self} Sentry transaction started (ID: {self._ttfb_metrics_tx.span_id} Name: {self._ttfb_metrics_tx.name})"
)
async def stop_ttfb_metrics(self):
await super().stop_ttfb_metrics()
if self._sentry_available and self._ttfb_metrics_tx:
self._ttfb_metrics_tx.finish()
await self._sentry_queue.put(self._ttfb_metrics_tx)
self._ttfb_metrics_tx = None
async def start_processing_metrics(self):
await super().start_processing_metrics()
@@ -52,11 +76,21 @@ class SentryMetrics(FrameProcessorMetrics):
name=f"Processing for {self._processor_name()}",
)
logger.debug(
f"Sentry transaction started (ID: {self._processing_metrics_tx.span_id} Name: {self._processing_metrics_tx.name})"
f"{self} Sentry transaction started (ID: {self._processing_metrics_tx.span_id} Name: {self._processing_metrics_tx.name})"
)
async def stop_processing_metrics(self):
await super().stop_processing_metrics()
if self._sentry_available and self._processing_metrics_tx:
self._processing_metrics_tx.finish()
await self._sentry_queue.put(self._processing_metrics_tx)
self._processing_metrics_tx = None
async def _sentry_task_handler(self):
running = True
while running:
tx = await self._sentry_queue.get()
if tx:
await self.task_manager.get_event_loop().run_in_executor(None, tx.finish)
running = tx is not None
self._sentry_queue.task_done()
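
Note on `_sentry_task_handler` above: the handler drains a queue of transactions, runs each blocking `finish()` call in the default executor, and exits when it receives `None`. Below is a minimal sketch of that sentinel-terminated worker, assuming a plain `asyncio.Queue` and a hypothetical `FakeTransaction` class in place of the Sentry SDK objects; it is an illustration of the pattern, not the Sentry API.

```python
import asyncio
import time


class FakeTransaction:
    """Hypothetical stand-in for a transaction whose finish() blocks."""

    def __init__(self, name: str):
        self.name = name
        self.finished = False

    def finish(self) -> None:
        time.sleep(0.01)  # simulate blocking network I/O
        self.finished = True


async def finish_worker(queue: asyncio.Queue) -> None:
    loop = asyncio.get_running_loop()
    running = True
    while running:
        tx = await queue.get()
        if tx:
            # Run the blocking call in a thread so the event loop stays free.
            await loop.run_in_executor(None, tx.finish)
        # A None item is the shutdown sentinel.
        running = tx is not None
        queue.task_done()


async def run_demo() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    worker = asyncio.create_task(finish_worker(queue))
    txs = [FakeTransaction("ttfb"), FakeTransaction("processing")]
    for tx in txs:
        await queue.put(tx)
    await queue.put(None)  # ask the worker to stop after draining
    await worker
    return [tx.finished for tx in txs]
```

Because the sentinel is queued after the transactions, every pending transaction is finished before the worker exits, which matches the order used in `cleanup()` above.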


@@ -9,6 +9,7 @@ from typing import Awaitable, Callable, List
from pipecat.frames.frames import Frame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
async def identity_transformer(frame: Frame):
@@ -43,7 +44,7 @@ class ProducerProcessor(FrameProcessor):
Returns:
asyncio.Queue: The queue for the newly added consumer.
"""
queue = asyncio.Queue()
queue = WatchdogQueue(self.task_manager)
self._consumers.append(queue)
return queue


@@ -196,8 +196,31 @@ class TelnyxFrameSerializer(FrameSerializer):
async with session.post(endpoint, headers=headers) as response:
if response.status == 200:
logger.info(f"Successfully terminated Telnyx call {call_control_id}")
elif response.status == 422:
# Handle the case where the call has already ended
# Error code 90018: "Call has already ended"
# Source: https://developers.telnyx.com/api/errors/90018
try:
error_data = await response.json()
if any(
error.get("code") == "90018"
for error in error_data.get("errors", [])
):
logger.debug(
f"Telnyx call {call_control_id} was already terminated"
)
return
except Exception:
pass # Fall through to log the raw error
# Log other 422 errors
error_text = await response.text()
logger.error(
f"Failed to terminate Telnyx call {call_control_id}: "
f"Status {response.status}, Response: {error_text}"
)
else:
# Get the error details for better debugging
# Log other errors
error_text = await response.text()
logger.error(
f"Failed to terminate Telnyx call {call_control_id}: "


@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base AI service implementation.
Provides the foundation for all AI services in the Pipecat framework, including
model management, settings handling, and frame processing lifecycle methods.
"""
from typing import Any, AsyncGenerator, Dict, Mapping
from loguru import logger
@@ -20,6 +26,17 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class AIService(FrameProcessor):
"""Base class for all AI services.
Provides common functionality for AI services including model management,
settings handling, session properties, and frame processing lifecycle.
Subclasses should implement specific AI functionality while leveraging
this base infrastructure.
Args:
**kwargs: Additional arguments passed to the parent FrameProcessor.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._model_name: str = ""
@@ -28,19 +45,53 @@ class AIService(FrameProcessor):
@property
def model_name(self) -> str:
"""Get the current model name.
Returns:
The name of the AI model being used.
"""
return self._model_name
def set_model_name(self, model: str):
"""Set the AI model name and update metrics.
Args:
model: The name of the AI model to use.
"""
self._model_name = model
self.set_core_metrics_data(MetricsData(processor=self.name, model=self._model_name))
async def start(self, frame: StartFrame):
"""Start the AI service.
Called when the service should begin processing. Subclasses should
override this method to perform service-specific initialization.
Args:
frame: The start frame containing initialization parameters.
"""
pass
async def stop(self, frame: EndFrame):
"""Stop the AI service.
Called when the service should stop processing. Subclasses should
override this method to perform cleanup operations.
Args:
frame: The end frame.
"""
pass
async def cancel(self, frame: CancelFrame):
"""Cancel the AI service.
Called when the service should cancel all operations. Subclasses should
override this method to handle cancellation logic.
Args:
frame: The cancel frame.
"""
pass
async def _update_settings(self, settings: Mapping[str, Any]):
@@ -87,6 +138,15 @@ class AIService(FrameProcessor):
logger.warning(f"Unknown setting for {self.name} service: {key}")
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames and handle service lifecycle.
Automatically handles StartFrame, EndFrame, and CancelFrame by calling
the appropriate lifecycle methods.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
@@ -97,6 +157,14 @@ class AIService(FrameProcessor):
await self.stop(frame)
async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
"""Process frames from an async generator.
Takes an async generator that yields frames and processes each one,
handling error frames specially by pushing them as errors.
Args:
generator: An async generator that yields Frame objects or None.
"""
async for f in generator:
if f:
if isinstance(f, ErrorFrame):


@@ -4,6 +4,17 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Deprecated AI services module.
This module is deprecated. Import services directly from their respective modules:
- pipecat.services.ai_service
- pipecat.services.image_service
- pipecat.services.llm_service
- pipecat.services.stt_service
- pipecat.services.tts_service
- pipecat.services.vision_service
"""
import sys
from pipecat.services import DeprecatedModuleProxy


@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Anthropic AI service integration for Pipecat.
This module provides LLM services and context management for Anthropic's Claude models,
including support for function calling, vision, and prompt caching features.
"""
import asyncio
import base64
import copy
@@ -46,6 +52,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.tracing.service_decorators import traced_llm
try:
@@ -58,27 +65,66 @@ except ModuleNotFoundError as e:
@dataclass
class AnthropicContextAggregatorPair:
"""Pair of context aggregators for Anthropic conversations.
Encapsulates both user and assistant context aggregators
to manage conversation flow and message formatting.
Parameters:
_user: The user context aggregator.
_assistant: The assistant context aggregator.
"""
_user: "AnthropicUserContextAggregator"
_assistant: "AnthropicAssistantContextAggregator"
def user(self) -> "AnthropicUserContextAggregator":
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> "AnthropicAssistantContextAggregator":
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
class AnthropicLLMService(LLMService):
"""This class implements inference with Anthropic's AI models.
"""LLM service for Anthropic's Claude models.
Can provide a custom client via the `client` kwarg, allowing you to
use `AsyncAnthropicBedrock` and `AsyncAnthropicVertex` clients
Provides inference capabilities with Claude models including support for
function calling, vision processing, streaming responses, and prompt caching.
Can use custom clients like AsyncAnthropicBedrock and AsyncAnthropicVertex.
Args:
api_key: Anthropic API key for authentication.
model: Model name to use. Defaults to "claude-sonnet-4-20250514".
params: Optional model parameters for inference.
client: Optional custom Anthropic client instance.
**kwargs: Additional arguments passed to parent LLMService.
"""
# Overriding the default adapter to use the Anthropic one.
adapter_class = AnthropicLLMAdapter
class InputParams(BaseModel):
"""Input parameters for Anthropic model inference.
Parameters:
enable_prompt_caching_beta: Whether to enable beta prompt caching feature.
max_tokens: Maximum tokens to generate. Must be at least 1.
temperature: Sampling temperature between 0.0 and 1.0.
top_k: Top-k sampling parameter.
top_p: Top-p sampling parameter between 0.0 and 1.0.
extra: Additional parameters to pass to the API.
"""
enable_prompt_caching_beta: Optional[bool] = False
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
@@ -111,10 +157,20 @@ class AnthropicLLMService(LLMService):
}
def can_generate_metrics(self) -> bool:
"""Check if this service can generate usage metrics.
Returns:
True, as Anthropic provides detailed token usage metrics.
"""
return True
@property
def enable_prompt_caching_beta(self) -> bool:
"""Check if prompt caching beta feature is enabled.
Returns:
True if prompt caching is enabled.
"""
return self._enable_prompt_caching_beta
def create_context_aggregator(
@@ -124,22 +180,19 @@ class AnthropicLLMService(LLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> AnthropicContextAggregatorPair:
"""Create an instance of AnthropicContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
"""Create Anthropic-specific context aggregators.
Creates a pair of context aggregators optimized for Anthropic's message format,
including support for function calls, tool usage, and image handling.
Args:
context (OpenAILLMContext): The LLM context.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): User
aggregator parameters.
context: The LLM context.
user_params: User aggregator parameters.
assistant_params: Assistant aggregator parameters.
Returns:
AnthropicContextAggregatorPair: A pair of context aggregators, one
for the user and one for the assistant, encapsulated in an
AnthropicContextAggregatorPair.
A pair of context aggregators, one for the user and one for the assistant,
encapsulated in an AnthropicContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())
@@ -203,7 +256,7 @@ class AnthropicLLMService(LLMService):
json_accumulator = ""
function_calls = []
async for event in response:
async for event in WatchdogAsyncIterator(response, manager=self.task_manager):
# Aggregate streaming content, create frames, trigger events
if event.type == "content_block_delta":
@@ -307,6 +360,15 @@ class AnthropicLLMService(LLMService):
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames and route them appropriately.
Handles various frame types including context frames, message frames,
vision frames, and settings updates.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
context = None
@@ -358,6 +420,19 @@ class AnthropicLLMService(LLMService):
class AnthropicLLMContext(OpenAILLMContext):
"""LLM context specialized for Anthropic's message format and features.
Extends OpenAILLMContext to handle Anthropic-specific features like
system messages, prompt caching, and message format conversions.
Manages conversation state and message history formatting.
Args:
messages: Initial list of conversation messages.
tools: Available function calling tools.
tool_choice: Tool selection preference.
system: System message content.
"""
def __init__(
self,
messages: Optional[List[dict]] = None,
@@ -378,6 +453,16 @@ class AnthropicLLMContext(OpenAILLMContext):
@staticmethod
def upgrade_to_anthropic(obj: OpenAILLMContext) -> "AnthropicLLMContext":
"""Upgrade an OpenAI context to Anthropic format.
Converts message format and restructures content for Anthropic compatibility.
Args:
obj: The OpenAI context to upgrade.
Returns:
The upgraded Anthropic context.
"""
logger.debug(f"Upgrading to Anthropic: {obj}")
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AnthropicLLMContext):
obj.__class__ = AnthropicLLMContext
@@ -386,6 +471,14 @@ class AnthropicLLMContext(OpenAILLMContext):
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
"""Create Anthropic context from OpenAI context.
Args:
openai_context: The OpenAI context to convert.
Returns:
New Anthropic context with converted messages.
"""
self = cls(
messages=openai_context.messages,
tools=openai_context.tools,
@@ -397,12 +490,28 @@ class AnthropicLLMContext(OpenAILLMContext):
@classmethod
def from_messages(cls, messages: List[dict]) -> "AnthropicLLMContext":
"""Create context from a list of messages.
Args:
messages: List of conversation messages.
Returns:
New Anthropic context with the provided messages.
"""
self = cls(messages=messages)
self._restructure_from_openai_messages()
return self
@classmethod
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AnthropicLLMContext":
"""Create context from a vision image frame.
Args:
frame: The vision image frame to process.
Returns:
New Anthropic context with the image message.
"""
context = cls()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
@@ -410,11 +519,15 @@ class AnthropicLLMContext(OpenAILLMContext):
return context
def set_messages(self, messages: List):
"""Set the messages list and reset cache tracking.
Args:
messages: New list of messages to set.
"""
self.turns_above_cache_threshold = 0
self._messages[:] = messages
self._restructure_from_openai_messages()
# convert a message in Anthropic format into one or more messages in OpenAI format
def to_standard_messages(self, obj):
"""Convert Anthropic message format to standard structured format.
@@ -555,6 +668,17 @@ class AnthropicLLMContext(OpenAILLMContext):
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
"""Add an image message to the context.
Converts the image to base64 JPEG format and adds it as a user message
with optional accompanying text.
Args:
format: The image format (e.g., 'RGB', 'RGBA').
size: Image dimensions as (width, height).
image: Raw image bytes.
text: Optional text to accompany the image.
"""
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
@@ -575,6 +699,14 @@ class AnthropicLLMContext(OpenAILLMContext):
self.add_message({"role": "user", "content": content})
def add_message(self, message):
"""Add a message to the context, merging with previous message if same role.
Anthropic requires alternating roles, so consecutive messages from the same
role are merged together.
Args:
message: The message to add to the context.
"""
try:
if self.messages:
# Anthropic requires that roles alternate. If this message's role is the same as the
@@ -600,6 +732,14 @@ class AnthropicLLMContext(OpenAILLMContext):
logger.error(f"Error adding message: {e}")
def get_messages_with_cache_control_markers(self) -> List[dict]:
"""Get messages with prompt caching markers applied.
Adds cache control markers to appropriate messages based on the
number of turns above the cache threshold.
Returns:
List of messages with cache control markers added.
"""
try:
messages = copy.deepcopy(self.messages)
if self.turns_above_cache_threshold >= 1 and messages[-1]["role"] == "user":
@@ -667,12 +807,26 @@ class AnthropicLLMContext(OpenAILLMContext):
message["content"] = [{"type": "text", "text": "(empty)"}]
def get_messages_for_persistent_storage(self):
"""Get messages formatted for persistent storage.
Includes system message at the beginning if present.
Returns:
List of messages suitable for storage.
"""
messages = super().get_messages_for_persistent_storage()
if self.system:
messages.insert(0, {"role": "system", "content": self.system})
return messages
def get_messages_for_logging(self) -> str:
"""Get messages formatted for logging with sensitive data redacted.
Replaces image data with placeholder text for cleaner logs.
Returns:
JSON string representation of messages for logging.
"""
msgs = []
for message in self.messages:
msg = copy.deepcopy(message)
@@ -686,6 +840,12 @@ class AnthropicLLMContext(OpenAILLMContext):
class AnthropicUserContextAggregator(LLMUserContextAggregator):
"""Anthropic-specific user context aggregator.
Handles aggregation of user messages for Anthropic LLM services.
Inherits all functionality from the base LLMUserContextAggregator.
"""
pass
@@ -700,7 +860,20 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
"""Context aggregator for assistant messages in Anthropic conversations.
Handles function call lifecycle management including in-progress tracking,
result handling, and cancellation for Anthropic's tool use format.
"""
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
"""Handle a function call that is starting.
Creates tool use message and placeholder tool result for tracking.
Args:
frame: Frame containing function call details.
"""
assistant_message = {"role": "assistant", "content": []}
assistant_message["content"].append(
{
@@ -725,6 +898,13 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
"""Handle the result of a completed function call.
Updates the tool result with actual return value or completion status.
Args:
frame: Frame containing function call result.
"""
if frame.result:
result = json.dumps(frame.result)
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
@@ -734,6 +914,13 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
)
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
"""Handle cancellation of a function call.
Updates the tool result to indicate cancellation.
Args:
frame: Frame containing function call cancellation details.
"""
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "CANCELLED"
)
@@ -752,6 +939,14 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
content["content"] = result
async def handle_user_image_frame(self, frame: UserImageRawFrame):
"""Handle a user image frame with function call context.
Marks the associated function call as completed and adds the image
to the conversation context.
Args:
frame: User image frame with request context.
"""
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
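
Note on the `add_message` docstring in this file: it states that consecutive messages from the same role are merged because roles must alternate. The sketch below illustrates that idea on a plain list of dicts. It is not the `AnthropicLLMContext` implementation; `add_message` here is a free function, and `_as_blocks` is a hypothetical helper that normalizes string content into text blocks before merging.

```python
from typing import Any, Dict, List

Message = Dict[str, Any]


def _as_blocks(content: Any) -> List[dict]:
    """Normalize string content into a list of text blocks."""
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    return list(content)


def add_message(messages: List[Message], message: Message) -> None:
    """Append a message, merging it into the previous one if the role repeats."""
    if messages and messages[-1]["role"] == message["role"]:
        previous = messages[-1]
        previous["content"] = _as_blocks(previous["content"]) + _as_blocks(message["content"])
    else:
        # Copy so later merges do not mutate the caller's dict.
        messages.append(dict(message))


history: List[Message] = []
add_message(history, {"role": "user", "content": "Hi"})
add_message(history, {"role": "user", "content": "Are you there?"})
add_message(history, {"role": "assistant", "content": "Yes."})
```

After these three calls `history` holds two messages: one merged user message with two text blocks, then the assistant reply.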


@@ -189,9 +189,11 @@ class AssemblyAISTTService(STTService):
try:
while self._connected:
try:
message = await self._websocket.recv()
message = await asyncio.wait_for(self._websocket.recv(), timeout=1.0)
data = json.loads(message)
await self._handle_message(data)
except asyncio.TimeoutError:
self.reset_watchdog()
except websockets.exceptions.ConnectionClosedOK:
break
except Exception as e:


@@ -4,6 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""AWS Bedrock integration for Large Language Model services.
This module provides AWS Bedrock LLM service implementation with support for
Amazon Nova and Anthropic Claude models, including vision capabilities and
function calling.
"""
import asyncio
import base64
import copy
@@ -61,17 +68,50 @@ except ModuleNotFoundError as e:
@dataclass
class AWSBedrockContextAggregatorPair:
"""Container for AWS Bedrock context aggregators.
Provides convenient access to both user and assistant context aggregators
for AWS Bedrock LLM operations.
Parameters:
_user: The user context aggregator instance.
_assistant: The assistant context aggregator instance.
"""
_user: "AWSBedrockUserContextAggregator"
_assistant: "AWSBedrockAssistantContextAggregator"
def user(self) -> "AWSBedrockUserContextAggregator":
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> "AWSBedrockAssistantContextAggregator":
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
class AWSBedrockLLMContext(OpenAILLMContext):
"""AWS Bedrock-specific LLM context implementation.
Extends OpenAI LLM context to handle AWS Bedrock's specific message format
and system message handling. Manages conversion between OpenAI and Bedrock
message formats.
Args:
messages: List of conversation messages in OpenAI format.
tools: List of available function calling tools.
tool_choice: Tool selection strategy or specific tool choice.
system: System message content for AWS Bedrock.
"""
def __init__(
self,
messages: Optional[List[dict]] = None,
@@ -85,6 +125,14 @@ class AWSBedrockLLMContext(OpenAILLMContext):
@staticmethod
def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext":
"""Upgrade an OpenAI LLM context to AWS Bedrock format.
Args:
obj: The OpenAI LLM context to upgrade.
Returns:
The upgraded AWS Bedrock LLM context.
"""
logger.debug(f"Upgrading to AWS Bedrock: {obj}")
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSBedrockLLMContext):
obj.__class__ = AWSBedrockLLMContext
@@ -95,6 +143,14 @@ class AWSBedrockLLMContext(OpenAILLMContext):
@classmethod
def from_openai_context(cls, openai_context: OpenAILLMContext):
"""Create AWS Bedrock context from OpenAI context.
Args:
openai_context: The OpenAI LLM context to convert.
Returns:
New AWS Bedrock LLM context instance.
"""
self = cls(
messages=openai_context.messages,
tools=openai_context.tools,
@@ -106,12 +162,28 @@ class AWSBedrockLLMContext(OpenAILLMContext):
@classmethod
def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext":
"""Create AWS Bedrock context from message list.
Args:
messages: List of messages in OpenAI format.
Returns:
New AWS Bedrock LLM context instance.
"""
self = cls(messages=messages)
self._restructure_from_openai_messages()
return self
@classmethod
def from_image_frame(cls, frame: VisionImageRawFrame) -> "AWSBedrockLLMContext":
"""Create AWS Bedrock context from vision image frame.
Args:
frame: The vision image frame to convert.
Returns:
New AWS Bedrock LLM context instance.
"""
context = cls()
context.add_image_frame_message(
format=frame.format, size=frame.size, image=frame.image, text=frame.text
@@ -119,10 +191,14 @@ class AWSBedrockLLMContext(OpenAILLMContext):
return context
def set_messages(self, messages: List):
"""Set the messages list and restructure for Bedrock format.
Args:
messages: List of messages to set.
"""
self._messages[:] = messages
self._restructure_from_openai_messages()
# convert a message in AWS Bedrock format into one or more messages in OpenAI format
def to_standard_messages(self, obj):
"""Convert AWS Bedrock message format to standard structured format.
@@ -295,6 +371,14 @@ class AWSBedrockLLMContext(OpenAILLMContext):
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
"""Add an image message to the context.
Args:
format: The image format (e.g., 'RGB', 'RGBA').
size: The image dimensions as (width, height).
image: The raw image data as bytes.
text: Optional text to accompany the image.
"""
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
@@ -306,6 +390,14 @@ class AWSBedrockLLMContext(OpenAILLMContext):
self.add_message({"role": "user", "content": content})
def add_message(self, message):
"""Add a message to the context, merging with previous message if same role.
AWS Bedrock requires alternating roles, so consecutive messages from the
same role are merged together.
Args:
message: The message to add to the context.
"""
try:
if self.messages:
# AWS Bedrock requires that roles alternate. If this message's
@@ -330,10 +422,10 @@ class AWSBedrockLLMContext(OpenAILLMContext):
logger.error(f"Error adding message: {e}")
def _restructure_from_bedrock_messages(self):
"""Restructure messages in AWS Bedrock format by handling system
messages, merging consecutive messages with the same role, and ensuring
proper content formatting.
"""Restructure messages in AWS Bedrock format.
Handles system messages, merging consecutive messages with the same role,
and ensuring proper content formatting.
"""
# Handle system message if present at the beginning
if self.messages and self.messages[0]["role"] == "system":
@@ -416,12 +508,22 @@ class AWSBedrockLLMContext(OpenAILLMContext):
message["content"] = [{"type": "text", "text": "(empty)"}]
def get_messages_for_persistent_storage(self):
"""Get messages formatted for persistent storage.
Returns:
List of messages including system message if present.
"""
messages = super().get_messages_for_persistent_storage()
if self.system:
messages.insert(0, {"role": "system", "content": self.system})
return messages
def get_messages_for_logging(self) -> str:
"""Get messages formatted for logging with sensitive data redacted.
Returns:
JSON string representation of messages with image data redacted.
"""
msgs = []
for message in self.messages:
msg = copy.deepcopy(message)
@@ -435,11 +537,36 @@ class AWSBedrockLLMContext(OpenAILLMContext):
class AWSBedrockUserContextAggregator(LLMUserContextAggregator):
"""User context aggregator for AWS Bedrock LLM service.
Handles aggregation of user messages and frames for AWS Bedrock format.
Inherits all functionality from the base LLM user context aggregator.
Args:
context: The LLM context to aggregate messages into.
params: Configuration parameters for the aggregator.
"""
pass
class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
"""Assistant context aggregator for AWS Bedrock LLM service.
Handles aggregation of assistant responses and function calls for AWS Bedrock
format, including tool use and tool result handling.
Args:
context: The LLM context to aggregate messages into.
params: Configuration parameters for the aggregator.
"""
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
"""Handle function call in progress frame.
Args:
frame: The function call in progress frame to handle.
"""
# Format tool use according to AWS Bedrock API
self._context.add_message(
{
@@ -470,6 +597,11 @@ class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
"""Handle function call result frame.
Args:
frame: The function call result frame to handle.
"""
if frame.result:
result = json.dumps(frame.result)
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
@@ -479,6 +611,11 @@ class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
)
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
"""Handle function call cancel frame.
Args:
frame: The function call cancel frame to handle.
"""
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "CANCELLED"
)
@@ -497,6 +634,11 @@ class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
content["toolResult"]["content"] = [{"text": result}]
async def handle_user_image_frame(self, frame: UserImageRawFrame):
"""Handle user image frame.
Args:
frame: The user image frame to handle.
"""
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
@@ -509,18 +651,38 @@ class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator):
class AWSBedrockLLMService(LLMService):
"""This class implements inference with AWS Bedrock models including Amazon
Nova and Anthropic Claude.
"""AWS Bedrock Large Language Model service implementation.
Requires AWS credentials to be configured in the environment or through
boto3 configuration.
Provides inference capabilities for AWS Bedrock models including Amazon Nova
and Anthropic Claude. Supports streaming responses, function calling, and
vision capabilities.
Args:
model: The AWS Bedrock model identifier to use.
aws_access_key: AWS access key ID. If None, uses default credentials.
aws_secret_key: AWS secret access key. If None, uses default credentials.
aws_session_token: AWS session token for temporary credentials.
aws_region: AWS region for the Bedrock service.
params: Model parameters and configuration.
client_config: Custom boto3 client configuration.
**kwargs: Additional arguments passed to parent LLMService.
"""
# Overriding the default adapter to use the AWS Bedrock one.
adapter_class = AWSBedrockLLMAdapter
class InputParams(BaseModel):
"""Input parameters for AWS Bedrock LLM service.
Parameters:
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature between 0.0 and 1.0.
top_p: Nucleus sampling parameter between 0.0 and 1.0.
stop_sequences: List of strings that stop generation.
latency: Performance mode - "standard" or "optimized".
additional_model_request_fields: Additional model-specific parameters.
"""
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
temperature: Optional[float] = Field(default_factory=lambda: 0.7, ge=0.0, le=1.0)
top_p: Optional[float] = Field(default_factory=lambda: 0.999, ge=0.0, le=1.0)
@@ -573,6 +735,11 @@ class AWSBedrockLLMService(LLMService):
logger.info(f"Using AWS Bedrock model: {model}")
def can_generate_metrics(self) -> bool:
"""Check if the service can generate usage metrics.
Returns:
True if metrics generation is supported.
"""
return True
def create_context_aggregator(
@@ -582,21 +749,21 @@ class AWSBedrockLLMService(LLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> AWSBedrockContextAggregatorPair:
-"""Create an instance of AWSBedrockContextAggregatorPair from an
-OpenAILLMContext. Constructor keyword arguments for both the user and
-assistant aggregators can be provided.
+"""Create AWS Bedrock-specific context aggregators.
+Creates a pair of context aggregators optimized for AWS Bedrock's message
+format, including support for function calls, tool usage, and image handling.
Args:
-context (OpenAILLMContext): The LLM context.
-user_params (LLMUserAggregatorParams, optional): User aggregator
-parameters.
-assistant_params (LLMAssistantAggregatorParams, optional): User
-aggregator parameters.
+context: The LLM context to create aggregators for.
+user_params: Parameters for user message aggregation.
+assistant_params: Parameters for assistant message aggregation.
Returns:
-AWSBedrockContextAggregatorPair: A pair of context aggregators, one
-for the user and one for the assistant, encapsulated in an
+AWSBedrockContextAggregatorPair: A pair of context aggregators, one for
+the user and one for the assistant, encapsulated in an
AWSBedrockContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())
@@ -711,6 +878,8 @@ class AWSBedrockLLMService(LLMService):
function_calls = []
for event in response["stream"]:
self.reset_watchdog()
# Handle text content
if "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
@@ -762,6 +931,7 @@ class AWSBedrockLLMService(LLMService):
completion_tokens += usage.get("outputTokens", 0)
cache_read_input_tokens += usage.get("cacheReadInputTokens", 0)
cache_creation_input_tokens += usage.get("cacheWriteInputTokens", 0)
await self.run_function_calls(function_calls)
except asyncio.CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
@@ -789,6 +959,12 @@ class AWSBedrockLLMService(LLMService):
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames and handle LLM-specific frame types.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
context = None


@@ -284,7 +284,8 @@ class AWSTranscribeSTTService(STTService):
break
try:
-response = await self._ws_client.recv()
+response = await asyncio.wait_for(self._ws_client.recv(), timeout=1.0)
headers, payload = decode_event(response)
if headers.get(":message-type") == "event":
@@ -334,6 +335,8 @@ class AWSTranscribeSTTService(STTService):
else:
logger.debug(f"{self} Other message type received: {headers}")
logger.debug(f"{self} Payload: {payload}")
except asyncio.TimeoutError:
self.reset_watchdog()
except websockets.exceptions.ConnectionClosed as e:
logger.error(
f"{self} WebSocket connection closed in receive loop with code {e.code}: {e.reason}"
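
The receive-loop change above bounds each `recv()` call with a one-second timeout so that an idle connection still reaches `reset_watchdog()`. A minimal sketch of that pattern follows. It is not the service's code: the `asyncio.Queue`, the `producer` coroutine, and the `watchdog_resets` counter are hypothetical stand-ins for the WebSocket client and the watchdog.

```python
import asyncio


async def producer(queue, messages, delay):
    """Hypothetical message source: delivers one message every `delay` seconds."""
    for message in messages:
        await asyncio.sleep(delay)
        await queue.put(message)


async def receive_loop(queue, expected, timeout):
    """Wait for messages, counting a watchdog reset each time a wait times out."""
    received = []
    watchdog_resets = 0
    while len(received) < expected:
        try:
            # Bound the wait so the loop wakes up even when nothing arrives.
            message = await asyncio.wait_for(queue.get(), timeout=timeout)
            received.append(message)
        except asyncio.TimeoutError:
            # Nothing arrived in time; the loop is still alive, so signal that
            # (the diff calls self.reset_watchdog() at this point).
            watchdog_resets += 1
    return received, watchdog_resets


async def main():
    queue = asyncio.Queue()
    task = asyncio.create_task(producer(queue, ["a", "b"], delay=0.2))
    result = await receive_loop(queue, expected=2, timeout=0.05)
    await task
    return result


received, resets = asyncio.run(main())
print(received, resets)
```

The timeout only controls how often the idle branch runs; messages are still delivered in order because a timed-out `queue.get()` does not consume anything.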


@@ -6,7 +6,7 @@
import asyncio
import os
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, List, Optional
from loguru import logger
from pydantic import BaseModel
@@ -115,6 +115,7 @@ class AWSPollyTTSService(TTSService):
pitch: Optional[str] = None
rate: Optional[str] = None
volume: Optional[str] = None
lexicon_names: Optional[List[str]] = None
def __init__(
self,
@@ -147,6 +148,7 @@ class AWSPollyTTSService(TTSService):
"pitch": params.pitch,
"rate": params.rate,
"volume": params.volume,
"lexicon_names": params.lexicon_names,
}
self._resampler = create_default_resampler()
@@ -235,6 +237,7 @@ class AWSPollyTTSService(TTSService):
"Engine": self._settings["engine"],
# AWS only supports 8000 and 16000 for PCM. We select 16000.
"SampleRate": "16000",
"LexiconNames": self._settings["lexicon_names"],
}
# Filter out None values
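
Because `lexicon_names` defaults to `None`, the new `"LexiconNames"` entry only reaches the request when the caller sets it, thanks to the None-filtering step. A minimal sketch of that idea, assuming a hypothetical helper `build_polly_params` (the diff does not show the service's actual filtering code):

```python
def build_polly_params(text, voice_id, engine=None, lexicon_names=None):
    """Hypothetical helper: assemble request parameters, dropping unset ones."""
    params = {
        "Text": text,
        "VoiceId": voice_id,
        "OutputFormat": "pcm",
        "Engine": engine,
        "SampleRate": "16000",
        "LexiconNames": lexicon_names,
    }
    # Keep only the keys that were given a value.
    return {key: value for key, value in params.items() if value is not None}


without_lexicons = build_polly_params("hello", "Joanna")
with_lexicons = build_polly_params("hello", "Joanna", lexicon_names=["my-lexicon"])
print(sorted(without_lexicons))
print(with_lexicons["LexiconNames"])
```

The voice name and lexicon name here are placeholders chosen for the example.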


@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""AWS Nova Sonic LLM service implementation for Pipecat AI framework.
This module provides a speech-to-speech LLM service using AWS Nova Sonic, which supports
bidirectional audio streaming, text generation, and function calling capabilities.
"""
import asyncio
import base64
import json
@@ -25,6 +31,7 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
FunctionCallFromLLM,
InputAudioRawFrame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
@@ -82,22 +89,37 @@ except ModuleNotFoundError as e:
class AWSNovaSonicUnhandledFunctionException(Exception):
"""Exception raised when the LLM attempts to call an unregistered function."""
pass
class ContentType(Enum):
"""Content types supported by AWS Nova Sonic."""
AUDIO = "AUDIO"
TEXT = "TEXT"
TOOL = "TOOL"
class TextStage(Enum):
"""Text generation stages in AWS Nova Sonic responses."""
FINAL = "FINAL" # what has been said
SPECULATIVE = "SPECULATIVE" # what's planned to be said
@dataclass
class CurrentContent:
"""Represents content currently being received from AWS Nova Sonic.
Parameters:
type: The type of content (audio, text, or tool).
role: The role generating the content (user, assistant, etc.).
text_stage: The stage of text generation (final or speculative).
text_content: The actual text content if applicable.
"""
type: ContentType
role: Role
text_stage: TextStage # None if not text
@@ -114,6 +136,20 @@ class CurrentContent:
class Params(BaseModel):
"""Configuration parameters for AWS Nova Sonic.
Attributes:
input_sample_rate: Audio input sample rate in Hz.
input_sample_size: Audio input sample size in bits.
input_channel_count: Number of input audio channels.
output_sample_rate: Audio output sample rate in Hz.
output_sample_size: Audio output sample size in bits.
output_channel_count: Number of output audio channels.
max_tokens: Maximum number of tokens to generate.
top_p: Nucleus sampling parameter.
temperature: Sampling temperature for text generation.
"""
# Audio input
input_sample_rate: Optional[int] = Field(default=16000)
input_sample_size: Optional[int] = Field(default=16)
@@ -131,6 +167,24 @@ class Params(BaseModel):
class AWSNovaSonicLLMService(LLMService):
"""AWS Nova Sonic speech-to-speech LLM service.
Provides bidirectional audio streaming, real-time transcription, text generation,
and function calling capabilities using the AWS Nova Sonic model.
Args:
secret_access_key: AWS secret access key for authentication.
access_key_id: AWS access key ID for authentication.
region: AWS region where the service is hosted.
model: Model identifier. Defaults to "amazon.nova-sonic-v1:0".
voice_id: Voice ID for speech synthesis. Options: matthew, tiffany, amy.
params: Model parameters for audio configuration and inference.
system_instruction: System-level instruction for the model.
tools: Available tools/functions for the model to use.
send_transcription_frames: Whether to emit transcription frames.
**kwargs: Additional arguments passed to the parent LLMService.
"""
# Override the default adapter to use the AWSNovaSonicLLMAdapter one
adapter_class = AWSNovaSonicLLMAdapter
@@ -187,16 +241,31 @@ class AWSNovaSonicLLMService(LLMService):
#
async def start(self, frame: StartFrame):
"""Start the service and initiate connection to AWS Nova Sonic.
Args:
frame: The start frame triggering service initialization.
"""
await super().start(frame)
self._wants_connection = True
await self._start_connecting()
async def stop(self, frame: EndFrame):
"""Stop the service and close connections.
Args:
frame: The end frame triggering service shutdown.
"""
await super().stop(frame)
self._wants_connection = False
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the service and close connections.
Args:
frame: The cancel frame triggering service cancellation.
"""
await super().cancel(frame)
self._wants_connection = False
await self._disconnect()
@@ -206,6 +275,11 @@ class AWSNovaSonicLLMService(LLMService):
#
async def reset_conversation(self):
"""Reset the conversation state while preserving context.
Handles bot stopped speaking event, disconnects from the service,
and reconnects with the preserved context.
"""
logger.debug("Resetting conversation")
await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=False)
@@ -221,6 +295,12 @@ class AWSNovaSonicLLMService(LLMService):
#
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames and handle service-specific logic.
Args:
frame: The frame to process.
direction: The direction the frame is traveling.
"""
await super().process_frame(frame, direction)
if isinstance(frame, OpenAILLMContextFrame):
@@ -696,7 +776,9 @@ class AWSNovaSonicLLMService(LLMService):
try:
while self._stream and not self._disconnecting:
output = await self._stream.await_output()
-result = await output[1].receive()
+result = await asyncio.wait_for(output[1].receive(), timeout=1.0)
+self.reset_watchdog()
if result.value and result.value.bytes_:
response_data = result.value.bytes_.decode("utf-8")
@@ -725,7 +807,8 @@ class AWSNovaSonicLLMService(LLMService):
elif "completionEnd" in event_json:
# Handle the LLM completion ending
await self._handle_completion_end_event(event_json)
except asyncio.TimeoutError:
self.reset_watchdog()
except Exception as e:
logger.error(f"{self} error processing responses: {e}")
if self._wants_connection:
@@ -804,12 +887,16 @@ class AWSNovaSonicLLMService(LLMService):
# Call tool function
if self.has_function(function_name):
if function_name in self._functions.keys() or None in self._functions.keys():
await self.call_function(
context=self._context,
tool_call_id=tool_call_id,
function_name=function_name,
arguments=arguments,
)
function_calls_llm = [
FunctionCallFromLLM(
context=self._context,
tool_call_id=tool_call_id,
function_name=function_name,
arguments=arguments,
)
]
await self.run_function_calls(function_calls_llm)
else:
raise AWSNovaSonicUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
@@ -952,6 +1039,16 @@ class AWSNovaSonicLLMService(LLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> AWSNovaSonicContextAggregatorPair:
"""Create context aggregator pair for managing conversation context.
Args:
context: The OpenAI LLM context to upgrade.
user_params: Parameters for the user context aggregator.
assistant_params: Parameters for the assistant context aggregator.
Returns:
A pair of user and assistant context aggregators.
"""
context.set_llm_adapter(self.get_llm_adapter())
user = AWSNovaSonicUserContextAggregator(context=context, params=user_params)
@@ -970,6 +1067,14 @@ class AWSNovaSonicLLMService(LLMService):
)
async def trigger_assistant_response(self):
"""Trigger an assistant response by sending an audio cue.
Sends a pre-recorded "ready" audio trigger to prompt the assistant
to start speaking. This is useful for controlling conversation flow.
Returns:
False if already triggering a response, True otherwise.
"""
if self._triggering_assistant_response:
return False


@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Context management for AWS Nova Sonic LLM service.
This module provides specialized context aggregators and message handling for AWS Nova Sonic,
including conversation history management and role-specific message processing.
"""
import copy
from dataclasses import dataclass, field
from enum import Enum
@@ -35,6 +41,8 @@ from pipecat.services.openai.llm import (
class Role(Enum):
"""Roles supported in AWS Nova Sonic conversations."""
SYSTEM = "SYSTEM"
USER = "USER"
ASSISTANT = "ASSISTANT"
@@ -43,17 +51,42 @@ class Role(Enum):
@dataclass
class AWSNovaSonicConversationHistoryMessage:
"""A single message in AWS Nova Sonic conversation history.
Parameters:
role: The role of the message sender (USER or ASSISTANT only).
text: The text content of the message.
"""
role: Role # only USER and ASSISTANT
text: str
@dataclass
class AWSNovaSonicConversationHistory:
"""Complete conversation history for AWS Nova Sonic initialization.
Parameters:
system_instruction: System-level instruction for the conversation.
messages: List of conversation messages between user and assistant.
"""
system_instruction: str = None
messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list)
class AWSNovaSonicLLMContext(OpenAILLMContext):
"""Specialized LLM context for AWS Nova Sonic service.
Extends OpenAI context with Nova Sonic-specific message handling,
conversation history management, and text buffering capabilities.
Args:
messages: Initial messages for the context.
tools: Available tools for the context.
**kwargs: Additional arguments passed to parent class.
"""
def __init__(self, messages=None, tools=None, **kwargs):
super().__init__(messages=messages, tools=tools, **kwargs)
self.__setup_local()
@@ -67,6 +100,15 @@ class AWSNovaSonicLLMContext(OpenAILLMContext):
def upgrade_to_nova_sonic(
obj: OpenAILLMContext, system_instruction: str
) -> "AWSNovaSonicLLMContext":
"""Upgrade an OpenAI context to AWS Nova Sonic context.
Args:
obj: The OpenAI context to upgrade.
system_instruction: System instruction for the context.
Returns:
The upgraded AWS Nova Sonic context.
"""
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext):
obj.__class__ = AWSNovaSonicLLMContext
obj.__setup_local(system_instruction)
@@ -74,6 +116,14 @@ class AWSNovaSonicLLMContext(OpenAILLMContext):
# NOTE: this method has the side-effect of updating _system_instruction from messages
def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory:
"""Get conversation history for initializing AWS Nova Sonic session.
Processes stored messages and extracts system instruction and conversation
history in the format expected by AWS Nova Sonic.
Returns:
Formatted conversation history with system instruction and messages.
"""
history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction)
# Bail if there are no messages
@@ -103,6 +153,11 @@ class AWSNovaSonicLLMContext(OpenAILLMContext):
return history
def get_messages_for_persistent_storage(self):
"""Get messages formatted for persistent storage.
Returns:
List of messages including system instruction if present.
"""
messages = super().get_messages_for_persistent_storage()
# If we have a system instruction and messages doesn't already contain it, add it
if self._system_instruction and not (messages and messages[0].get("role") == "system"):
@@ -110,6 +165,14 @@ class AWSNovaSonicLLMContext(OpenAILLMContext):
return messages
def from_standard_message(self, message) -> AWSNovaSonicConversationHistoryMessage:
"""Convert standard message format to Nova Sonic format.
Args:
message: Standard message dictionary to convert.
Returns:
Nova Sonic conversation history message, or None if not convertible.
"""
role = message.get("role")
if message.get("role") == "user" or message.get("role") == "assistant":
content = message.get("content")
@@ -131,10 +194,20 @@ class AWSNovaSonicLLMContext(OpenAILLMContext):
# Sonic conversation history
def buffer_user_text(self, text):
"""Buffer user text for later flushing to context.
Args:
text: User text to buffer.
"""
self._user_text += f" {text}" if self._user_text else text
# logger.debug(f"User text buffered: {self._user_text}")
def flush_aggregated_user_text(self) -> str:
"""Flush buffered user text to context as a complete message.
Returns:
The flushed user text, or empty string if no text was buffered.
"""
if not self._user_text:
return ""
user_text = self._user_text
@@ -148,10 +221,16 @@ class AWSNovaSonicLLMContext(OpenAILLMContext):
return user_text
def buffer_assistant_text(self, text):
"""Buffer assistant text for later flushing to context.
Args:
text: Assistant text to buffer.
"""
self._assistant_text += text
# logger.debug(f"Assistant text buffered: {self._assistant_text}")
def flush_aggregated_assistant_text(self):
"""Flush buffered assistant text to context as a complete message."""
if not self._assistant_text:
return
message = {
@@ -165,13 +244,31 @@ class AWSNovaSonicLLMContext(OpenAILLMContext):
@dataclass
class AWSNovaSonicMessagesUpdateFrame(DataFrame):
"""Frame containing updated AWS Nova Sonic context.
Parameters:
context: The updated AWS Nova Sonic LLM context.
"""
context: AWSNovaSonicLLMContext
class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
"""Context aggregator for user messages in AWS Nova Sonic conversations.
Extends the OpenAI user context aggregator to emit Nova Sonic-specific
context update frames.
"""
async def process_frame(
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
):
"""Process frames and emit Nova Sonic-specific context updates.
Args:
frame: The frame to process.
direction: The direction the frame is traveling.
"""
await super().process_frame(frame, direction)
# Parent does not push LLMMessagesUpdateFrame
@@ -180,7 +277,19 @@ class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
"""Context aggregator for assistant messages in AWS Nova Sonic conversations.
Provides specialized handling for assistant responses and function calls
in AWS Nova Sonic context, with custom frame processing logic.
"""
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with Nova Sonic-specific logic.
Args:
frame: The frame to process.
direction: The direction the frame is traveling.
"""
# HACK: For now, disable the context aggregator by making it just pass through all frames
# that the parent handles (except the function call stuff, which we still need).
# For an explanation of this hack, see
@@ -205,6 +314,11 @@ class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
await super().process_frame(frame, direction)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
"""Handle function call results for AWS Nova Sonic.
Args:
frame: The function call result frame to handle.
"""
await super().handle_function_call_result(frame)
# The standard function callback code path pushes the FunctionCallResultFrame from the LLM
@@ -217,11 +331,28 @@ class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
@dataclass
class AWSNovaSonicContextAggregatorPair:
"""Pair of user and assistant context aggregators for AWS Nova Sonic.
Parameters:
_user: The user context aggregator.
_assistant: The assistant context aggregator.
"""
_user: AWSNovaSonicUserContextAggregator
_assistant: AWSNovaSonicAssistantContextAggregator
def user(self) -> AWSNovaSonicUserContextAggregator:
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> AWSNovaSonicAssistantContextAggregator:
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
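
The `buffer_user_text` and `flush_aggregated_user_text` docstrings in the context diff above describe a buffer-then-flush pattern. A minimal standalone sketch of it, assuming a hypothetical `BufferedContext` class and a plain `{"role", "content"}` message shape (the diff elides the message format the real context stores):

```python
class BufferedContext:
    """Hypothetical stand-in for a context that buffers partial user text."""

    def __init__(self):
        self.messages = []
        self._user_text = ""

    def buffer_user_text(self, text):
        # Join fragments with a single space, without a leading space.
        self._user_text += f" {text}" if self._user_text else text

    def flush_aggregated_user_text(self):
        # Nothing buffered: return an empty string and leave messages untouched.
        if not self._user_text:
            return ""
        user_text = self._user_text
        # Assumed message shape, for illustration only.
        self.messages.append({"role": "user", "content": user_text})
        self._user_text = ""
        return user_text


ctx = BufferedContext()
ctx.buffer_user_text("hello")
ctx.buffer_user_text("world")
first_flush = ctx.flush_aggregated_user_text()
second_flush = ctx.flush_aggregated_user_text()
print(first_flush, ctx.messages)
```

Flushing twice in a row is safe: the second call finds an empty buffer and adds no message.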


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Custom frames for AWS Nova Sonic LLM service."""
from dataclasses import dataclass
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
@@ -11,4 +13,13 @@ from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
@dataclass
class AWSNovaSonicFunctionCallResultFrame(DataFrame):
"""Frame containing function call result for AWS Nova Sonic processing.
This frame wraps a standard function call result frame to enable
AWS Nova Sonic-specific handling and context updates.
Parameters:
result_frame: The underlying function call result frame.
"""
result_frame: FunctionCallResultFrame


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Azure OpenAI service implementation for the Pipecat AI framework."""
from loguru import logger
from openai import AsyncAzureOpenAI
@@ -17,11 +19,11 @@ class AzureLLMService(OpenAILLMService):
maintaining full compatibility with OpenAI's interface and functionality.
Args:
-api_key (str): The API key for accessing Azure OpenAI
-endpoint (str): The Azure endpoint URL
-model (str): The model identifier to use
-api_version (str, optional): Azure API version. Defaults to "2024-09-01-preview"
-**kwargs: Additional keyword arguments passed to OpenAILLMService
+api_key: The API key for accessing Azure OpenAI.
+endpoint: The Azure endpoint URL.
+model: The model identifier to use.
+api_version: Azure API version. Defaults to "2024-09-01-preview".
+**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -40,7 +42,16 @@ class AzureLLMService(OpenAILLMService):
super().__init__(api_key=api_key, model=model, **kwargs)
def create_client(self, api_key=None, base_url=None, **kwargs):
-"""Create OpenAI-compatible client for Azure OpenAI endpoint."""
+"""Create OpenAI-compatible client for Azure OpenAI endpoint.
Args:
api_key: API key for authentication. Uses instance key if None.
base_url: Base URL for the client. Ignored for Azure implementation.
**kwargs: Additional keyword arguments. Ignored for Azure implementation.
Returns:
AsyncAzureOpenAI: Configured Azure OpenAI client instance.
"""
logger.debug(f"Creating Azure OpenAI client with endpoint {self._endpoint}")
return AsyncAzureOpenAI(
api_key=api_key,


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Cartesia text-to-speech service implementations."""
import base64
import json
import uuid
@@ -27,6 +29,7 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import AudioContextWordTTSService, TTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
from pipecat.utils.tracing.service_decorators import traced_tts
@@ -42,6 +45,14 @@ except ModuleNotFoundError as e:
def language_to_cartesia_language(language: Language) -> Optional[str]:
"""Convert a Language enum to Cartesia language code.
Args:
language: The Language enum value to convert.
Returns:
The corresponding Cartesia language code, or None if not supported.
"""
BASE_LANGUAGES = {
Language.DE: "de",
Language.EN: "en",
@@ -74,7 +85,35 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
class CartesiaTTSService(AudioContextWordTTSService):
"""Cartesia TTS service with WebSocket streaming and word timestamps.
Provides text-to-speech using Cartesia's streaming WebSocket API.
Supports word-level timestamps, audio context management, and various voice
customization options including speed and emotion controls.
Args:
api_key: Cartesia API key for authentication.
voice_id: ID of the voice to use for synthesis.
cartesia_version: API version string for Cartesia service.
url: WebSocket URL for Cartesia TTS API.
model: TTS model to use (e.g., "sonic-2").
sample_rate: Audio sample rate. If None, uses default.
encoding: Audio encoding format.
container: Audio container format.
params: Additional input parameters for voice customization.
text_aggregator: Custom text aggregator for processing input text.
**kwargs: Additional arguments passed to the parent service.
"""
class InputParams(BaseModel):
"""Input parameters for Cartesia TTS configuration.
Parameters:
language: Language to use for synthesis.
speed: Voice speed control (string or float).
emotion: List of emotion controls (deprecated).
"""
language: Optional[Language] = Language.EN
speed: Optional[Union[str, float]] = ""
emotion: Optional[List[str]] = []
@@ -137,14 +176,32 @@ class CartesiaTTSService(AudioContextWordTTSService):
self._receive_task = None
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Cartesia service supports metrics generation.
"""
return True
async def set_model(self, model: str):
"""Set the TTS model.
Args:
model: The model name to use for synthesis.
"""
self._model_id = model
await super().set_model(model)
logger.info(f"Switching TTS model to: [{model}]")
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Cartesia language format.
Args:
language: The language to convert.
Returns:
The Cartesia-specific language code, or None if not supported.
"""
return language_to_cartesia_language(language)
def _build_msg(
@@ -182,15 +239,30 @@ class CartesiaTTSService(AudioContextWordTTSService):
return json.dumps(msg)
async def start(self, frame: StartFrame):
"""Start the Cartesia TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._settings["output_format"]["sample_rate"] = self.sample_rate
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Cartesia TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Cartesia TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
@@ -247,6 +319,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
self._context_id = None
async def flush_audio(self):
"""Flush any pending audio and finalize the current context."""
if not self._context_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
@@ -255,7 +328,9 @@ class CartesiaTTSService(AudioContextWordTTSService):
self._context_id = None
async def _receive_messages(self):
-async for message in self._get_websocket():
+async for message in WatchdogAsyncIterator(
+self._get_websocket(), manager=self.task_manager
+):
msg = json.loads(message)
if not msg or not self.audio_context_available(msg["context_id"]):
continue
@@ -287,6 +362,14 @@ class CartesiaTTSService(AudioContextWordTTSService):
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Cartesia's streaming API.
Args:
text: The text to synthesize into speech.
Yields:
Frame: Audio frames containing the synthesized speech.
"""
logger.debug(f"{self}: Generating TTS [{text}]")
try:
@@ -316,7 +399,34 @@ class CartesiaTTSService(AudioContextWordTTSService):
class CartesiaHttpTTSService(TTSService):
"""Cartesia HTTP-based TTS service.
Provides text-to-speech using Cartesia's HTTP API for simpler, non-streaming
synthesis. Suitable for use cases where streaming is not required and simpler
integration is preferred.
Args:
api_key: Cartesia API key for authentication.
voice_id: ID of the voice to use for synthesis.
model: TTS model to use (e.g., "sonic-2").
base_url: Base URL for Cartesia HTTP API.
cartesia_version: API version string for Cartesia service.
sample_rate: Audio sample rate. If None, uses default.
encoding: Audio encoding format.
container: Audio container format.
params: Additional input parameters for voice customization.
**kwargs: Additional arguments passed to the parent TTSService.
"""
class InputParams(BaseModel):
"""Input parameters for Cartesia HTTP TTS configuration.
Parameters:
language: Language to use for synthesis.
speed: Voice speed control (string or float).
emotion: List of emotion controls (deprecated).
"""
language: Optional[Language] = Language.EN
speed: Optional[Union[str, float]] = ""
emotion: Optional[List[str]] = Field(default_factory=list)
@@ -363,25 +473,61 @@ class CartesiaHttpTTSService(TTSService):
)
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Cartesia HTTP service supports metrics generation.
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a Language enum to Cartesia language format.
Args:
language: The language to convert.
Returns:
The Cartesia-specific language code, or None if not supported.
"""
return language_to_cartesia_language(language)
async def start(self, frame: StartFrame):
"""Start the Cartesia HTTP TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._settings["output_format"]["sample_rate"] = self.sample_rate
async def stop(self, frame: EndFrame):
"""Stop the Cartesia HTTP TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._client.close()
async def cancel(self, frame: CancelFrame):
"""Cancel the Cartesia HTTP TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._client.close()
@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Cartesia's HTTP API.
Args:
text: The text to synthesize into speech.
Yields:
Frame: Audio frames containing the synthesized speech.
"""
logger.debug(f"{self}: Generating TTS [{text}]")
try:


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Cerebras LLM service implementation using OpenAI-compatible interface."""
from typing import List
from loguru import logger
@@ -21,10 +23,10 @@ class CerebrasLLMService(OpenAILLMService):
maintaining full compatibility with OpenAI's interface and functionality.
Args:
api_key (str): The API key for accessing Cerebras's API
base_url (str, optional): The base URL for Cerebras API. Defaults to "https://api.cerebras.ai/v1"
model (str, optional): The model identifier to use. Defaults to "llama-3.3-70b"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing Cerebras's API.
base_url: The base URL for Cerebras API. Defaults to "https://api.cerebras.ai/v1".
model: The model identifier to use. Defaults to "llama-3.3-70b".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -38,7 +40,16 @@ class CerebrasLLMService(OpenAILLMService):
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create OpenAI-compatible client for Cerebras API endpoint."""
"""Create OpenAI-compatible client for Cerebras API endpoint.
Args:
api_key: The API key for authentication. If None, uses instance key.
base_url: The base URL for the API. If None, uses instance URL.
**kwargs: Additional arguments passed to the client constructor.
Returns:
An OpenAI-compatible client configured for Cerebras API.
"""
logger.debug(f"Creating Cerebras client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)
@@ -48,14 +59,14 @@ class CerebrasLLMService(OpenAILLMService):
"""Create a streaming chat completion using Cerebras's API.
Args:
context (OpenAILLMContext): The context object containing tools configuration
and other settings for the chat completion.
messages (List[ChatCompletionMessageParam]): The list of messages comprising
the conversation history and current request.
context: The context object containing tools configuration
and other settings for the chat completion.
messages: The list of messages comprising
the conversation history and current request.
Returns:
AsyncStream[ChatCompletionChunk]: A streaming response of chat completion
chunks that can be processed asynchronously.
A streaming response of chat completion
chunks that can be processed asynchronously.
"""
params = {
"model": self.model_name,


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Deepgram speech-to-text service implementation."""
from typing import AsyncGenerator, Dict, Optional
from loguru import logger
@@ -41,6 +43,22 @@ except ModuleNotFoundError as e:
class DeepgramSTTService(STTService):
"""Deepgram speech-to-text service.
Provides real-time speech recognition using Deepgram's WebSocket API.
Supports configurable models, languages, VAD events, and various audio
processing options.
Args:
api_key: Deepgram API key for authentication.
url: Deprecated. Use base_url instead.
base_url: Custom Deepgram API base URL.
sample_rate: Audio sample rate. If None, uses default or live_options value.
live_options: Deepgram LiveOptions for detailed configuration.
addons: Additional Deepgram features to enable.
**kwargs: Additional arguments passed to the parent STTService.
"""
def __init__(
self,
*,
@@ -108,12 +126,27 @@ class DeepgramSTTService(STTService):
@property
def vad_enabled(self):
"""Check if Deepgram VAD events are enabled.
Returns:
True if VAD events are enabled in the current settings.
"""
return self._settings["vad_events"]
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as Deepgram service supports metrics generation.
"""
return True
async def set_model(self, model: str):
"""Set the Deepgram model and reconnect.
Args:
model: The Deepgram model name to use.
"""
await super().set_model(model)
logger.info(f"Switching STT model to: [{model}]")
self._settings["model"] = model
@@ -121,25 +154,53 @@ class DeepgramSTTService(STTService):
await self._connect()
async def set_language(self, language: Language):
"""Set the recognition language and reconnect.
Args:
language: The language to use for speech recognition.
"""
logger.info(f"Switching STT language to: [{language}]")
self._settings["language"] = language
await self._disconnect()
await self._connect()
async def start(self, frame: StartFrame):
"""Start the Deepgram STT service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._settings["sample_rate"] = self.sample_rate
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the Deepgram STT service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the Deepgram STT service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Send audio data to Deepgram for transcription.
Args:
audio: Raw audio bytes to transcribe.
Yields:
Frame: None (transcription results come via WebSocket callbacks).
"""
await self._connection.send(audio)
yield None
@@ -172,6 +233,7 @@ class DeepgramSTTService(STTService):
await self._connection.finish()
async def start_metrics(self):
"""Start TTFB and processing metrics collection."""
await self.start_ttfb_metrics()
await self.start_processing_metrics()
@@ -235,6 +297,12 @@ class DeepgramSTTService(STTService):
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with Deepgram-specific handling.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame) and not self.vad_enabled:


@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""DeepSeek LLM service implementation using OpenAI-compatible interface."""
from typing import List
@@ -22,10 +23,10 @@ class DeepSeekLLMService(OpenAILLMService):
maintaining full compatibility with OpenAI's interface and functionality.
Args:
api_key (str): The API key for accessing DeepSeek's API
base_url (str, optional): The base URL for DeepSeek API. Defaults to "https://api.deepseek.com/v1"
model (str, optional): The model identifier to use. Defaults to "deepseek-chat"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing DeepSeek's API.
base_url: The base URL for DeepSeek API. Defaults to "https://api.deepseek.com/v1".
model: The model identifier to use. Defaults to "deepseek-chat".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -39,24 +40,33 @@ class DeepSeekLLMService(OpenAILLMService):
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create OpenAI-compatible client for DeepSeek API endpoint."""
"""Create OpenAI-compatible client for DeepSeek API endpoint.
Args:
api_key: The API key for authentication. If None, uses instance default.
base_url: The base URL for the API. If None, uses instance default.
**kwargs: Additional keyword arguments for client configuration.
Returns:
An OpenAI-compatible client configured for DeepSeek's API.
"""
logger.debug(f"Creating DeepSeek client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)
async def get_chat_completions(
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
) -> AsyncStream[ChatCompletionChunk]:
"""Create a streaming chat completion using Cerebras's API.
"""Create a streaming chat completion using DeepSeek's API.
Args:
context (OpenAILLMContext): The context object containing tools configuration
and other settings for the chat completion.
messages (List[ChatCompletionMessageParam]): The list of messages comprising
the conversation history and current request.
context: The context object containing tools configuration
and other settings for the chat completion.
messages: The list of messages comprising the conversation
history and current request.
Returns:
AsyncStream[ChatCompletionChunk]: A streaming response of chat completion
chunks that can be processed asynchronously.
A streaming response of chat completion chunks that can be
processed asynchronously.
"""
params = {
"model": self.model_name,


@@ -32,6 +32,7 @@ from pipecat.services.tts_service import (
WordTTSService,
)
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.tracing.service_decorators import traced_tts
# See .env.example for ElevenLabs configuration needed
@@ -284,7 +285,6 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
logger.trace(f"{self}: flushing audio")
msg = {"context_id": self._context_id, "flush": True}
await self._websocket.send(json.dumps(msg))
self._context_id = None
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
await super().push_frame(frame, direction)
@@ -380,6 +380,12 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
if self._context_id and self._websocket:
logger.trace(f"Closing context {self._context_id} due to interruption")
try:
# ElevenLabs requires that Pipecat manages the contexts and closes them
# when they're no longer in use. Since a StartInterruptionFrame is pushed
# every time the user speaks, we'll use this as a trigger to close the context
# and reset the state.
# Note: We do not need to call remove_audio_context here, as the context is
# automatically reset when super()._handle_interruption is called.
await self._websocket.send(
json.dumps({"context_id": self._context_id, "close_context": True})
)
@@ -389,12 +395,24 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
self._started = False
async def _receive_messages(self):
async for message in self._get_websocket():
async for message in WatchdogAsyncIterator(
self._get_websocket(), manager=self.task_manager
):
msg = json.loads(message)
# Check if this message belongs to the current context
received_ctx_id = msg.get("contextId")
# Handle final messages first, regardless of context availability
# At the moment, this message is received AFTER the close_context message is
# sent, so it doesn't serve any functional purpose. For now, we'll just log it.
if msg.get("isFinal") is True:
logger.trace(f"Received final message for context {received_ctx_id}")
continue
# Check if this message belongs to an available context.
# A message for an unavailable context should never arrive, so warn about it.
if not self.audio_context_available(received_ctx_id):
logger.trace(f"Ignoring message from unavailable context: {received_ctx_id}")
logger.warning(f"Ignoring message from unavailable context: {received_ctx_id}")
continue
if msg.get("audio"):
@@ -408,21 +426,28 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
if msg.get("isFinal"):
logger.trace(f"Received final message for context {received_ctx_id}")
await self.remove_audio_context(received_ctx_id)
# Reset context tracking if this was our active context
if self._context_id == received_ctx_id:
self._context_id = None
self._started = False
async def _keepalive_task_handler(self):
KEEPALIVE_SLEEP = 10 if self.task_manager.task_watchdog_enabled else 3
while True:
await asyncio.sleep(10)
self.reset_watchdog()
await asyncio.sleep(KEEPALIVE_SLEEP)
try:
# Send an empty message to keep the connection alive
if self._websocket and self._websocket.open:
await self._websocket.send(json.dumps({}))
if self._context_id:
# Send keepalive with context ID to keep the connection alive
keepalive_message = {
"text": "",
"context_id": self._context_id,
}
logger.trace(f"Sending keepalive for context {self._context_id}")
else:
# It's possible to have a user interruption which clears the context
# without generating a new TTS response. In this case, we'll just send
# an empty message to keep the connection alive.
keepalive_message = {"text": ""}
logger.trace("Sending keepalive without context")
await self._websocket.send(json.dumps(keepalive_message))
except websockets.ConnectionClosed as e:
logger.warning(f"{self} keepalive error: {e}")
break
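The keepalive branch in this hunk chooses between two payload shapes depending on whether a context is active. A minimal sketch of that choice, pulled out into a pure function so it can be exercised without a WebSocket. `build_keepalive_message` is a hypothetical helper name used only for illustration; it is not part of the service in this diff.

```python
import json
from typing import Optional


def build_keepalive_message(context_id: Optional[str]) -> dict:
    """Return the keepalive payload for a possibly missing context (hypothetical helper)."""
    if context_id:
        # Active context: tag the empty-text keepalive with the context ID.
        return {"text": "", "context_id": context_id}
    # No active context (e.g. after an interruption): send a bare empty-text message.
    return {"text": ""}


# Serialize the payload the same way the handler does before sending it.
payload = json.dumps(build_keepalive_message("ctx-1"))
```

Keeping the payload construction separate from the send loop makes the two cases easy to unit test; the sleep interval and watchdog reset stay in the task handler.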
@@ -441,14 +466,6 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
await self._connect()
try:
# Close previous context if there was one
if self._context_id and not self._started:
await self._websocket.send(
json.dumps({"context_id": self._context_id, "close_context": True})
)
await self.remove_audio_context(self._context_id)
self._context_id = None
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
@@ -473,9 +490,6 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
logger.error(f"{self} error sending message: {e}")
yield TTSStoppedFrame()
self._started = False
if self._context_id:
await self.remove_audio_context(self._context_id)
self._context_id = None
return
yield None
except Exception as e:


@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Fireworks AI service implementation using OpenAI-compatible interface."""
from typing import List
@@ -21,10 +22,10 @@ class FireworksLLMService(OpenAILLMService):
maintaining full compatibility with OpenAI's interface and functionality.
Args:
api_key (str): The API key for accessing Fireworks AI
model (str, optional): The model identifier to use. Defaults to "accounts/fireworks/models/firefunction-v2"
base_url (str, optional): The base URL for Fireworks API. Defaults to "https://api.fireworks.ai/inference/v1"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing Fireworks AI.
model: The model identifier to use. Defaults to "accounts/fireworks/models/firefunction-v2".
base_url: The base URL for Fireworks API. Defaults to "https://api.fireworks.ai/inference/v1".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -38,7 +39,16 @@ class FireworksLLMService(OpenAILLMService):
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create OpenAI-compatible client for Fireworks API endpoint."""
"""Create OpenAI-compatible client for Fireworks API endpoint.
Args:
api_key: API key for authentication. If None, uses instance default.
base_url: Base URL for the API. If None, uses instance default.
**kwargs: Additional arguments passed to the client constructor.
Returns:
Configured OpenAI client instance for Fireworks API.
"""
logger.debug(f"Creating Fireworks client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)
@@ -47,7 +57,15 @@ class FireworksLLMService(OpenAILLMService):
):
"""Get chat completions from Fireworks API.
Removes OpenAI-specific parameters not supported by Fireworks.
Removes OpenAI-specific parameters not supported by Fireworks and
configures the request with Fireworks-compatible settings.
Args:
context: The OpenAI LLM context containing tools and settings.
messages: List of chat completion message parameters.
Returns:
Async generator yielding chat completion chunks from Fireworks API.
"""
params = {
"model": self.model_name,


@@ -3,7 +3,8 @@
#
# SPDX-License-Identifier: BSD 2-Clause License
#
#
"""Event models and utilities for Google Gemini Multimodal Live API."""
import base64
import io
@@ -22,16 +23,37 @@ from pipecat.frames.frames import ImageRawFrame
class MediaChunk(BaseModel):
"""Represents a chunk of media data for transmission.
Parameters:
mimeType: MIME type of the media content.
data: Base64-encoded media data.
"""
mimeType: str
data: str
class ContentPart(BaseModel):
"""Represents a part of content that can contain text or media.
Parameters:
text: Text content. Defaults to None.
inlineData: Inline media data. Defaults to None.
"""
text: Optional[str] = Field(default=None, validate_default=False)
inlineData: Optional[MediaChunk] = Field(default=None, validate_default=False)
class Turn(BaseModel):
"""Represents a conversational turn in the dialogue.
Parameters:
role: The role of the speaker, either "user" or "model". Defaults to "user".
parts: List of content parts that make up the turn.
"""
role: Literal["user", "model"] = "user"
parts: List[ContentPart]
@@ -53,7 +75,15 @@ class EndSensitivity(str, Enum):
class AutomaticActivityDetection(BaseModel):
"""Configures automatic detection of activity."""
"""Configures automatic detection of voice activity.
Parameters:
disabled: Whether automatic activity detection is disabled. Defaults to None.
start_of_speech_sensitivity: Sensitivity for detecting speech start. Defaults to None.
prefix_padding_ms: Padding before speech start in milliseconds. Defaults to None.
end_of_speech_sensitivity: Sensitivity for detecting speech end. Defaults to None.
silence_duration_ms: Duration of silence to detect speech end. Defaults to None.
"""
disabled: Optional[bool] = None
start_of_speech_sensitivity: Optional[StartSensitivity] = None
@@ -63,25 +93,57 @@ class AutomaticActivityDetection(BaseModel):
class RealtimeInputConfig(BaseModel):
"""Configures the realtime input behavior."""
"""Configures the realtime input behavior.
Parameters:
automatic_activity_detection: Voice activity detection configuration. Defaults to None.
"""
automatic_activity_detection: Optional[AutomaticActivityDetection] = None
class RealtimeInput(BaseModel):
"""Contains realtime input media chunks.
Parameters:
mediaChunks: List of media chunks for realtime processing.
"""
mediaChunks: List[MediaChunk]
class ClientContent(BaseModel):
"""Content sent from client to the Gemini Live API.
Parameters:
turns: List of conversation turns. Defaults to None.
turnComplete: Whether the client's turn is complete. Defaults to False.
"""
turns: Optional[List[Turn]] = None
turnComplete: bool = False
class AudioInputMessage(BaseModel):
"""Message containing audio input data.
Parameters:
realtimeInput: Realtime input containing audio chunks.
"""
realtimeInput: RealtimeInput
@classmethod
def from_raw_audio(cls, raw_audio: bytes, sample_rate: int) -> "AudioInputMessage":
"""Create an audio input message from raw audio data.
Args:
raw_audio: Raw audio bytes.
sample_rate: Audio sample rate in Hz.
Returns:
AudioInputMessage instance with encoded audio data.
"""
data = base64.b64encode(raw_audio).decode("utf-8")
return cls(
realtimeInput=RealtimeInput(
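The hunk ends before the media chunk is built. A stdlib-only sketch of the encoding step, using plain dicts in place of the pydantic models. The `audio/pcm;rate=...` MIME string is an assumption: the diff does not show the value actually passed as `mimeType`. `audio_input_message` is a hypothetical function name.

```python
import base64


def audio_input_message(raw_audio: bytes, sample_rate: int) -> dict:
    """Build a realtimeInput-style payload from raw audio bytes (illustrative only)."""
    # Base64-encode the bytes so they can travel inside a JSON message.
    data = base64.b64encode(raw_audio).decode("utf-8")
    return {
        "realtimeInput": {
            "mediaChunks": [
                # Assumed MIME format; not confirmed by the diff.
                {"mimeType": f"audio/pcm;rate={sample_rate}", "data": data}
            ]
        }
    }


message = audio_input_message(b"\x00\x01\x02\x03", 16000)
```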
@@ -91,10 +153,24 @@ class AudioInputMessage(BaseModel):
class VideoInputMessage(BaseModel):
"""Message containing video/image input data.
Parameters:
realtimeInput: Realtime input containing video/image chunks.
"""
realtimeInput: RealtimeInput
@classmethod
def from_image_frame(cls, frame: ImageRawFrame) -> "VideoInputMessage":
"""Create a video input message from an image frame.
Args:
frame: Image frame to encode.
Returns:
VideoInputMessage instance with encoded image data.
"""
buffer = io.BytesIO()
Image.frombytes(frame.format, frame.size, frame.image).save(buffer, format="JPEG")
data = base64.b64encode(buffer.getvalue()).decode("utf-8")
@@ -104,18 +180,44 @@ class VideoInputMessage(BaseModel):
class ClientContentMessage(BaseModel):
"""Message containing client content for the API.
Parameters:
clientContent: The client content to send.
"""
clientContent: ClientContent
class SystemInstruction(BaseModel):
"""System instruction for the model.
Parameters:
parts: List of content parts that make up the system instruction.
"""
parts: List[ContentPart]
class AudioTranscriptionConfig(BaseModel):
"""Configuration for audio transcription."""
pass
class Setup(BaseModel):
"""Setup configuration for the Gemini Live session.
Parameters:
model: Model identifier to use.
system_instruction: System instruction for the model. Defaults to None.
tools: List of available tools/functions. Defaults to None.
generation_config: Generation configuration parameters. Defaults to None.
input_audio_transcription: Input audio transcription config. Defaults to None.
output_audio_transcription: Output audio transcription config. Defaults to None.
realtime_input_config: Realtime input configuration. Defaults to None.
"""
model: str
system_instruction: Optional[SystemInstruction] = None
tools: Optional[List[dict]] = None
@@ -126,6 +228,12 @@ class Setup(BaseModel):
class Config(BaseModel):
"""Configuration message for session setup.
Parameters:
setup: Setup configuration for the session.
"""
setup: Setup
@@ -135,36 +243,86 @@ class Config(BaseModel):
class SetupComplete(BaseModel):
"""Indicates that session setup is complete."""
pass
class InlineData(BaseModel):
"""Inline data embedded in server responses.
Parameters:
mimeType: MIME type of the data.
data: Base64-encoded data content.
"""
mimeType: str
data: str
class Part(BaseModel):
"""Part of a server response containing data or text.
Parameters:
inlineData: Inline binary data. Defaults to None.
text: Text content. Defaults to None.
"""
inlineData: Optional[InlineData] = None
text: Optional[str] = None
class ModelTurn(BaseModel):
"""Represents a turn from the model in the conversation.
Parameters:
parts: List of content parts in the model's response.
"""
parts: List[Part]
class ServerContentInterrupted(BaseModel):
"""Indicates server content was interrupted.
Parameters:
interrupted: Whether the content was interrupted.
"""
interrupted: bool
class ServerContentTurnComplete(BaseModel):
"""Indicates the server's turn is complete.
Parameters:
turnComplete: Whether the turn is complete.
"""
turnComplete: bool
class BidiGenerateContentTranscription(BaseModel):
"""Transcription data from bidirectional content generation.
Parameters:
text: The transcribed text content.
"""
text: str
class ServerContent(BaseModel):
"""Content sent from server to client.
Parameters:
modelTurn: Model's conversational turn. Defaults to None.
interrupted: Whether content was interrupted. Defaults to None.
turnComplete: Whether the turn is complete. Defaults to None.
inputTranscription: Transcription of input audio. Defaults to None.
outputTranscription: Transcription of output audio. Defaults to None.
"""
modelTurn: Optional[ModelTurn] = None
interrupted: Optional[bool] = None
turnComplete: Optional[bool] = None
@@ -173,12 +331,26 @@ class ServerContent(BaseModel):
class FunctionCall(BaseModel):
"""Represents a function call from the model.
Parameters:
id: Unique identifier for the function call.
name: Name of the function to call.
args: Arguments to pass to the function.
"""
id: str
name: str
args: dict
class ToolCall(BaseModel):
"""Contains one or more function calls.
Parameters:
functionCalls: List of function calls to execute.
"""
functionCalls: List[FunctionCall]
@@ -193,14 +365,32 @@ class Modality(str, Enum):
class ModalityTokenCount(BaseModel):
"""Token count for a specific modality."""
"""Token count for a specific modality.
Parameters:
modality: The modality type.
tokenCount: Number of tokens for this modality.
"""
modality: Modality
tokenCount: int
class UsageMetadata(BaseModel):
"""Usage metadata about the response."""
"""Usage metadata about the API response.
Parameters:
promptTokenCount: Number of tokens in the prompt. Defaults to None.
cachedContentTokenCount: Number of cached content tokens. Defaults to None.
responseTokenCount: Number of tokens in the response. Defaults to None.
toolUsePromptTokenCount: Number of tokens for tool use prompts. Defaults to None.
thoughtsTokenCount: Number of tokens for model thoughts. Defaults to None.
totalTokenCount: Total number of tokens used. Defaults to None.
promptTokensDetails: Detailed breakdown of prompt tokens by modality. Defaults to None.
cacheTokensDetails: Detailed breakdown of cache tokens by modality. Defaults to None.
responseTokensDetails: Detailed breakdown of response tokens by modality. Defaults to None.
toolUsePromptTokensDetails: Detailed breakdown of tool use tokens by modality. Defaults to None.
"""
promptTokenCount: Optional[int] = None
cachedContentTokenCount: Optional[int] = None
@@ -215,6 +405,15 @@ class UsageMetadata(BaseModel):
class ServerEvent(BaseModel):
"""Server event received from the Gemini Live API.
Parameters:
setupComplete: Setup completion notification. Defaults to None.
serverContent: Content from the server. Defaults to None.
toolCall: Tool/function call request. Defaults to None.
usageMetadata: Token usage metadata. Defaults to None.
"""
setupComplete: Optional[SetupComplete] = None
serverContent: Optional[ServerContent] = None
toolCall: Optional[ToolCall] = None
@@ -222,6 +421,14 @@ class ServerEvent(BaseModel):
def parse_server_event(str):
"""Parse a server event from JSON string.
Args:
str: JSON string containing the server event.
Returns:
ServerEvent instance if parsing succeeds, None otherwise.
"""
try:
evt = json.loads(str)
return ServerEvent.model_validate(evt)
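The docstring describes a parse-or-None contract. A stdlib-only sketch of that contract, assuming malformed JSON is the failure being guarded against; the `except` clause itself is outside this hunk, and model validation is left out here. `parse_event` is a hypothetical name, and the parameter is called `raw` to avoid shadowing the built-in `str`.

```python
import json
from typing import Optional


def parse_event(raw: str) -> Optional[dict]:
    """Return the decoded event dict, or None if the payload is not a JSON object."""
    try:
        evt = json.loads(raw)
    except json.JSONDecodeError:
        # Malformed JSON: signal failure with None instead of raising.
        return None
    # Only JSON objects are treated as events in this sketch.
    return evt if isinstance(evt, dict) else None
```

Callers can then branch on `None` rather than wrapping every receive in their own `try` block.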
@@ -231,7 +438,12 @@ def parse_server_event(str):
class ContextWindowCompressionConfig(BaseModel):
"""Configuration for context window compression."""
"""Configuration for context window compression.
Parameters:
sliding_window: Whether to use sliding window compression. Defaults to True.
trigger_tokens: Token count threshold to trigger compression. Defaults to None.
"""
sliding_window: Optional[bool] = Field(default=True)
trigger_tokens: Optional[int] = Field(default=None)


@@ -4,6 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Google Gemini Multimodal Live API service implementation.
This module provides real-time conversational AI capabilities using Google's
Gemini Multimodal Live API, supporting both text and audio modalities with
voice transcription, streaming responses, and tool usage.
"""
import base64
import json
import time
@@ -58,9 +65,10 @@ from pipecat.services.openai.llm import (
OpenAIUserContextAggregator,
)
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.string import match_endofsentence
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_gemini_live, traced_stt, traced_tts
from pipecat.utils.tracing.service_decorators import traced_gemini_live, traced_stt
from . import events
@@ -78,7 +86,11 @@ def language_to_gemini_language(language: Language) -> Optional[str]:
Source:
https://ai.google.dev/api/generate-content#MediaResolution
Returns None if the language is not supported by Gemini Live.
Args:
language: The language enum value to convert.
Returns:
The Gemini language code string, or None if the language is not supported.
"""
language_map = {
# Arabic
@@ -165,8 +177,22 @@ def language_to_gemini_language(language: Language) -> Optional[str]:
class GeminiMultimodalLiveContext(OpenAILLMContext):
"""Extended OpenAI context for Gemini Multimodal Live API.
Provides Gemini-specific context management including system instruction
extraction and message format conversion for the Live API.
"""
@staticmethod
def upgrade(obj: OpenAILLMContext) -> "GeminiMultimodalLiveContext":
"""Upgrade an OpenAI context to Gemini context.
Args:
obj: The OpenAI context to upgrade.
Returns:
The upgraded Gemini context instance.
"""
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GeminiMultimodalLiveContext):
logger.debug(f"Upgrading to Gemini Multimodal Live Context: {obj}")
obj.__class__ = GeminiMultimodalLiveContext
@@ -177,6 +203,11 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
pass
def extract_system_instructions(self):
"""Extract system instructions from context messages.
Returns:
Combined system instruction text from all system messages.
"""
system_instruction = ""
for item in self.messages:
if item.get("role") == "system":
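Only the start of the loop is visible in this hunk. A standalone sketch of combining system messages into one instruction string, under two assumptions not confirmed by the diff: content is either a plain string or a list of `{"type": "text", "text": ...}` parts, and the pieces are joined with newlines.

```python
from typing import Any, Dict, List


def extract_system_instructions(messages: List[Dict[str, Any]]) -> str:
    """Concatenate the text of every system message, in order (illustrative only)."""
    parts: List[str] = []
    for item in messages:
        if item.get("role") != "system":
            continue
        content = item.get("content", "")
        if isinstance(content, str):
            parts.append(content)
        elif isinstance(content, list):
            # Assumed shape: [{"type": "text", "text": "..."}, ...]
            parts.extend(
                str(part.get("text", ""))
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
    # Joining with newlines is an assumption made for this sketch.
    return "\n".join(p for p in parts if p)


messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi"},
    {"role": "system", "content": [{"type": "text", "text": "Answer in English."}]},
]
combined = extract_system_instructions(messages)
```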
@@ -188,6 +219,11 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
return system_instruction
def get_messages_for_initializing_history(self):
"""Get messages formatted for Gemini history initialization.
Returns:
List of messages in Gemini format for conversation history.
"""
messages = []
for item in self.messages:
role = item.get("role")
@@ -215,7 +251,19 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
"""User context aggregator for Gemini Multimodal Live.
Extends OpenAI user aggregator to handle Gemini-specific message passing
while maintaining compatibility with the standard aggregation pipeline.
"""
async def process_frame(self, frame, direction):
"""Process incoming frames for user context aggregation.
Args:
frame: The frame to process.
direction: The frame processing direction.
"""
await super().process_frame(frame, direction)
# kind of a hack just to pass the LLMMessagesAppendFrame through, but it's fine for now
if isinstance(frame, LLMMessagesAppendFrame):
@@ -223,15 +271,33 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
# but the GeminiMultimodalLiveAssistantContextAggregator pushes LLMTextFrames and TTSTextFrames. We
# need to override this process_frame for LLMTextFrame, so that only the TTSTextFrames
# are processed. This ensures that the context gets only one set of messages.
"""Assistant context aggregator for Gemini Multimodal Live.
Handles assistant response aggregation while filtering out LLMTextFrames
to prevent duplicate context entries, as Gemini Live pushes both
LLMTextFrames and TTSTextFrames.
"""
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames for assistant context aggregation.
Args:
frame: The frame to process.
direction: The frame processing direction.
"""
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
# but the GeminiMultimodalLiveAssistantContextAggregator pushes LLMTextFrames and TTSTextFrames. We
# need to override this process_frame for LLMTextFrame, so that only the TTSTextFrames
# are processed. This ensures that the context gets only one set of messages.
if not isinstance(frame, LLMTextFrame):
await super().process_frame(frame, direction)
async def handle_user_image_frame(self, frame: UserImageRawFrame):
"""Handle user image frames.
Args:
frame: The user image frame to handle.
"""
# We don't want to store any images in the context. Revisit this later
# when the API evolves.
pass
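The `process_frame` override in this hunk drops `LLMTextFrame` so that each piece of text reaches the context once. A self-contained sketch of that filtering pattern; every class here is a hypothetical stand-in, not a Pipecat type, and the frame hierarchy is an assumption made for illustration.

```python
import asyncio
from typing import List


class TextFrame:
    def __init__(self, text: str):
        self.text = text


class LLMTextFrame(TextFrame):
    pass


class TTSTextFrame(TextFrame):
    pass


class BaseAggregator:
    """Stand-in parent that aggregates every TextFrame it sees."""

    def __init__(self) -> None:
        self.aggregated: List[str] = []

    async def process_frame(self, frame: object) -> None:
        if isinstance(frame, TextFrame):
            self.aggregated.append(frame.text)


class FilteringAggregator(BaseAggregator):
    """Skips LLMTextFrame so the same text is not aggregated twice."""

    async def process_frame(self, frame: object) -> None:
        if not isinstance(frame, LLMTextFrame):
            await super().process_frame(frame)


async def demo() -> List[str]:
    agg = FilteringAggregator()
    # The same text arrives twice, once per frame type.
    for frame in (LLMTextFrame("hello"), TTSTextFrame("hello")):
        await agg.process_frame(frame)
    return agg.aggregated


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Filtering in the subclass keeps the parent's aggregation logic untouched, which is the same trade-off the diff makes.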
@@ -239,17 +305,36 @@ class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggre
@dataclass
class GeminiMultimodalLiveContextAggregatorPair:
"""Pair of user and assistant context aggregators for Gemini Multimodal Live.
Parameters:
_user: The user context aggregator instance.
_assistant: The assistant context aggregator instance.
"""
_user: GeminiMultimodalLiveUserContextAggregator
_assistant: GeminiMultimodalLiveAssistantContextAggregator
def user(self) -> GeminiMultimodalLiveUserContextAggregator:
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> GeminiMultimodalLiveAssistantContextAggregator:
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
class GeminiMultimodalModalities(Enum):
"""Supported modalities for Gemini Multimodal Live."""
TEXT = "TEXT"
AUDIO = "AUDIO"
@@ -264,7 +349,15 @@ class GeminiMediaResolution(str, Enum):
class GeminiVADParams(BaseModel):
"""Voice Activity Detection parameters."""
"""Voice Activity Detection parameters for Gemini Live.
Parameters:
disabled: Whether to disable VAD. Defaults to None.
start_sensitivity: Sensitivity for speech start detection. Defaults to None.
end_sensitivity: Sensitivity for speech end detection. Defaults to None.
prefix_padding_ms: Prefix padding in milliseconds. Defaults to None.
silence_duration_ms: Silence duration threshold in milliseconds. Defaults to None.
"""
disabled: Optional[bool] = Field(default=None)
start_sensitivity: Optional[events.StartSensitivity] = Field(default=None)
@@ -274,7 +367,12 @@ class GeminiVADParams(BaseModel):
class ContextWindowCompressionParams(BaseModel):
"""Parameters for context window compression."""
"""Parameters for context window compression in Gemini Live.
Parameters:
enabled: Whether compression is enabled. Defaults to False.
trigger_tokens: Token count to trigger compression. None uses 80% of context window.
"""
enabled: bool = Field(default=False)
trigger_tokens: Optional[int] = Field(
@@ -283,6 +381,23 @@ class ContextWindowCompressionParams(BaseModel):
class InputParams(BaseModel):
"""Input parameters for Gemini Multimodal Live generation.
Parameters:
frequency_penalty: Frequency penalty for generation (0.0-2.0). Defaults to None.
max_tokens: Maximum tokens to generate. Must be >= 1. Defaults to 4096.
presence_penalty: Presence penalty for generation (0.0-2.0). Defaults to None.
temperature: Sampling temperature (0.0-2.0). Defaults to None.
top_k: Top-k sampling parameter. Must be >= 0. Defaults to None.
top_p: Top-p sampling parameter (0.0-1.0). Defaults to None.
modalities: Response modalities. Defaults to AUDIO.
language: Language for generation. Defaults to EN_US.
media_resolution: Media resolution setting. Defaults to UNSPECIFIED.
vad: Voice activity detection parameters. Defaults to None.
context_window_compression: Context compression settings. Defaults to None.
extra: Additional parameters. Defaults to empty dict.
"""
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
max_tokens: Optional[int] = Field(default=4096, ge=1)
presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
@@ -309,23 +424,18 @@ class GeminiMultimodalLiveLLMService(LLMService):
responses, and tool usage.
Args:
api_key (str): Google AI API key
base_url (str, optional): API endpoint base URL. Defaults to
"generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent".
model (str, optional): Model identifier to use. Defaults to
"models/gemini-2.0-flash-live-001".
voice_id (str, optional): TTS voice identifier. Defaults to "Charon".
start_audio_paused (bool, optional): Whether to start with audio input paused.
Defaults to False.
start_video_paused (bool, optional): Whether to start with video input paused.
Defaults to False.
system_instruction (str, optional): System prompt for the model. Defaults to None.
tools (Union[List[dict], ToolsSchema], optional): Tools/functions available to the model.
Defaults to None.
params (InputParams, optional): Configuration parameters for the model.
Defaults to InputParams().
inference_on_context_initialization (bool, optional): Whether to generate a response
when context is first set. Defaults to True.
api_key: Google AI API key for authentication.
base_url: API endpoint base URL. Defaults to the official Gemini Live endpoint.
model: Model identifier to use. Defaults to "models/gemini-2.0-flash-live-001".
voice_id: TTS voice identifier. Defaults to "Charon".
start_audio_paused: Whether to start with audio input paused. Defaults to False.
start_video_paused: Whether to start with video input paused. Defaults to False.
system_instruction: System prompt for the model. Defaults to None.
tools: Tools/functions available to the model. Defaults to None.
params: Configuration parameters for the model. Defaults to InputParams().
inference_on_context_initialization: Whether to generate a response when context
is first set. Defaults to True.
**kwargs: Additional arguments passed to parent LLMService.
"""
# Overriding the default adapter to use the Gemini one.
@@ -407,19 +517,43 @@ class GeminiMultimodalLiveLLMService(LLMService):
}
def can_generate_metrics(self) -> bool:
"""Check if the service can generate usage metrics.
Returns:
True as Gemini Live supports token usage metrics.
"""
return True
def set_audio_input_paused(self, paused: bool):
"""Set the audio input pause state.
Args:
paused: Whether to pause audio input.
"""
self._audio_input_paused = paused
def set_video_input_paused(self, paused: bool):
"""Set the video input pause state.
Args:
paused: Whether to pause video input.
"""
self._video_input_paused = paused
def set_model_modalities(self, modalities: GeminiMultimodalModalities):
"""Set the model response modalities.
Args:
modalities: The modalities to use for responses.
"""
self._settings["modalities"] = modalities
def set_language(self, language: Language):
"""Set the language for generation."""
"""Set the language for generation.
Args:
language: The language to use for generation.
"""
self._language = language
self._language_code = language_to_gemini_language(language) or "en-US"
self._settings["language"] = self._language_code
@@ -432,6 +566,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
way to trigger the pipeline. This sends the history to the server. The `inference_on_context_initialization`
flag controls whether to set the turnComplete flag when we do this. Without that flag, the model will
not respond. This is often what we want when setting the context at the beginning of a conversation.
Args:
context: The OpenAI LLM context to set.
"""
if self._context:
logger.error(
@@ -446,14 +583,29 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def start(self, frame: StartFrame):
"""Start the service and establish websocket connection.
Args:
frame: The start frame.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the service and close connections.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the service and close connections.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._disconnect()
@@ -488,6 +640,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames for the Gemini Live service.
Args:
frame: The frame to process.
direction: The frame processing direction.
"""
await super().process_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
@@ -543,6 +701,11 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def send_client_event(self, event):
"""Send a client event to the Gemini Live API.
Args:
event: The event to send.
"""
await self._ws_send(event.model_dump(exclude_none=True))
async def _connect(self):
@@ -686,7 +849,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def _receive_task_handler(self):
async for message in self._websocket:
async for message in WatchdogAsyncIterator(self._websocket, manager=self.task_manager):
evt = events.parse_server_event(message)
# logger.debug(f"Received event: {message[:500]}")
# logger.debug(f"Received event: {evt}")
@@ -708,8 +871,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
await self._handle_evt_error(evt)
# errors are fatal, so exit the receive loop
return
else:
pass
#
#
@@ -1032,22 +1193,19 @@ class GeminiMultimodalLiveLLMService(LLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> GeminiMultimodalLiveContextAggregatorPair:
"""Create an instance of GeminiMultimodalLiveContextAggregatorPair from
an OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
"""Create an instance of GeminiMultimodalLiveContextAggregatorPair from an OpenAILLMContext.
Constructor keyword arguments for both the user and assistant aggregators can be provided.
Args:
context (OpenAILLMContext): The LLM context.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): User
aggregator parameters.
context: The LLM context to use.
user_params: User aggregator parameters. Defaults to LLMUserAggregatorParams().
assistant_params: Assistant aggregator parameters. Defaults to LLMAssistantAggregatorParams().
Returns:
GeminiMultimodalLiveContextAggregatorPair: A pair of context
aggregators, one for the user and one for the assistant,
encapsulated in a GeminiMultimodalLiveContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())
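The aggregator pair documented in this file is a small dataclass with two accessor methods. A minimal sketch of that shape follows; `UserAggregator`, `AssistantAggregator` and `AggregatorPair` are hypothetical stand-ins, not the classes from the diff.

```python
from dataclasses import dataclass


class UserAggregator:
    """Hypothetical stand-in for a user context aggregator."""


class AssistantAggregator:
    """Hypothetical stand-in for an assistant context aggregator."""


@dataclass
class AggregatorPair:
    """Holds both aggregators and exposes them through accessor methods."""

    _user: UserAggregator
    _assistant: AssistantAggregator

    def user(self) -> UserAggregator:
        # Return the user-side aggregator.
        return self._user

    def assistant(self) -> AssistantAggregator:
        # Return the assistant-side aggregator.
        return self._assistant


pair = AggregatorPair(UserAggregator(), AssistantAggregator())
```

Callers would typically place `pair.user()` before the LLM service in a pipeline and `pair.assistant()` after it. That placement is an assumption about usage, not something this diff shows.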


@@ -25,6 +25,7 @@ from pipecat.frames.frames import (
from pipecat.services.gladia.config import GladiaInputParams
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
@@ -391,8 +392,8 @@ class GladiaSTTService(STTService):
await self._send_buffered_audio()
# Start tasks
self._receive_task = asyncio.create_task(self._receive_task_handler())
self._keepalive_task = asyncio.create_task(self._keepalive_task_handler())
self._receive_task = self.create_task(self._receive_task_handler())
self._keepalive_task = self.create_task(self._keepalive_task_handler())
# Wait for tasks to complete
await asyncio.gather(self._receive_task, self._keepalive_task)
@@ -403,9 +404,9 @@ class GladiaSTTService(STTService):
# Clean up tasks
if self._receive_task:
self._receive_task.cancel()
await self.cancel_task(self._receive_task)
if self._keepalive_task:
self._keepalive_task.cancel()
await self.cancel_task(self._keepalive_task)
# Attempt reconnect using helper
if not await self._maybe_reconnect():
@@ -484,9 +485,11 @@ class GladiaSTTService(STTService):
async def _keepalive_task_handler(self):
"""Send periodic empty audio chunks to keep the connection alive."""
try:
KEEPALIVE_SLEEP = 20 if self.task_manager.task_watchdog_enabled else 3
while self._connection_active:
# Send keepalive every 20 seconds (Gladia times out after 30 seconds)
await asyncio.sleep(20)
self.reset_watchdog()
# Send keepalive (Gladia times out after 30 seconds)
await asyncio.sleep(KEEPALIVE_SLEEP)
if self._websocket and not self._websocket.closed:
# Send an empty audio chunk as keepalive
empty_audio = b""
@@ -501,7 +504,7 @@ class GladiaSTTService(STTService):
async def _receive_task_handler(self):
try:
async for message in self._websocket:
async for message in WatchdogAsyncIterator(self._websocket, manager=self.task_manager):
content = json.loads(message)
# Handle audio chunk acknowledgments
@@ -559,6 +562,8 @@ class GladiaSTTService(STTService):
translation, "", time_now_iso8601(), translated_language
)
)
self.reset_watchdog()
except websockets.exceptions.ConnectionClosed:
# Expected when closing the connection
pass
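The keepalive change above resets the watchdog on each loop iteration, then sleeps, then sends an empty audio chunk. The loop can be sketched in isolation as below. All names here (`keepalive_loop`, `send`, `reset_watchdog`, `is_active`) are hypothetical, and the callbacks stand in for the websocket and task manager.

```python
import asyncio
from typing import Awaitable, Callable, List, Tuple


async def keepalive_loop(
    send: Callable[[bytes], Awaitable[None]],
    reset_watchdog: Callable[[], None],
    is_active: Callable[[], bool],
    interval: float,
) -> None:
    """Send an empty chunk every `interval` seconds while the connection is active."""
    while is_active():
        reset_watchdog()  # signal liveness before going idle
        await asyncio.sleep(interval)
        await send(b"")  # empty audio chunk used as the keepalive


async def demo() -> Tuple[int, int]:
    sent: List[bytes] = []
    resets = 0
    remaining = 3  # allow three iterations, then report the connection as closed

    async def send(chunk: bytes) -> None:
        sent.append(chunk)

    def reset() -> None:
        nonlocal resets
        resets += 1

    def is_active() -> bool:
        nonlocal remaining
        remaining -= 1
        return remaining >= 0

    await keepalive_loop(send, reset, is_active, interval=0.001)
    return len(sent), resets


sent_count, reset_count = asyncio.run(demo())
```

Because the reset happens once per iteration, the sleep interval has to stay below both the server's idle timeout and the watchdog's own timeout.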


@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Google Gemini integration for Pipecat.
This module provides Google Gemini integration for the Pipecat framework,
including LLM services, context management, and message aggregation.
"""
import base64
import io
import json
@@ -47,6 +53,7 @@ from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
)
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.tracing.service_decorators import traced_llm
# Suppress gRPC fork warnings
@@ -70,7 +77,14 @@ except ModuleNotFoundError as e:
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
"""Google-specific user context aggregator.
Extends OpenAI user context aggregator to handle Google AI's specific
Content and Part message format for user messages.
"""
async def push_aggregation(self):
"""Push aggregated user text as a Google Content message."""
if len(self._aggregation) > 0:
self._context.add_message(Content(role="user", parts=[Part(text=self._aggregation)]))
@@ -87,10 +101,26 @@ class GoogleUserContextAggregator(OpenAIUserContextAggregator):
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
"""Google-specific assistant context aggregator.
Extends OpenAI assistant context aggregator to handle Google AI's specific
Content and Part message format for assistant responses and function calls.
"""
async def handle_aggregation(self, aggregation: str):
"""Handle aggregated assistant text response.
Args:
aggregation: The aggregated text response from the assistant.
"""
self._context.add_message(Content(role="model", parts=[Part(text=aggregation)]))
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
"""Handle function call in progress frame.
Args:
frame: Frame containing function call details.
"""
self._context.add_message(
Content(
role="model",
@@ -119,6 +149,11 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
"""Handle function call result frame.
Args:
frame: Frame containing function call result.
"""
if frame.result:
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, frame.result
@@ -129,6 +164,11 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
)
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
"""Handle function call cancellation frame.
Args:
frame: Frame containing function call cancellation details.
"""
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "CANCELLED"
)
@@ -143,6 +183,11 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
part.function_response.response = {"value": json.dumps(result)}
async def handle_user_image_frame(self, frame: UserImageRawFrame):
"""Handle user image frame.
Args:
frame: Frame containing user image data and request context.
"""
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
@@ -156,17 +201,45 @@ class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
@dataclass
class GoogleContextAggregatorPair:
"""Pair of Google context aggregators for user and assistant messages.
Parameters:
_user: User context aggregator for handling user messages.
_assistant: Assistant context aggregator for handling assistant responses.
"""
_user: GoogleUserContextAggregator
_assistant: GoogleAssistantContextAggregator
def user(self) -> GoogleUserContextAggregator:
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> GoogleAssistantContextAggregator:
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
class GoogleLLMContext(OpenAILLMContext):
"""Google AI LLM context that extends OpenAI context for Google-specific formatting.
This class handles conversion between OpenAI-style messages and Google AI's
Content/Part format, including system messages, function calls, and media.
Args:
messages: Initial messages in OpenAI format.
tools: Available tools/functions for the model.
tool_choice: Tool choice configuration.
"""
def __init__(
self,
messages: Optional[List[dict]] = None,
@@ -178,6 +251,14 @@ class GoogleLLMContext(OpenAILLMContext):
@staticmethod
def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext":
"""Upgrade an OpenAI context to a Google context.
Args:
obj: OpenAI LLM context to upgrade.
Returns:
GoogleLLMContext instance with converted messages.
"""
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext):
logger.debug(f"Upgrading to Google: {obj}")
obj.__class__ = GoogleLLMContext
@@ -185,10 +266,20 @@ class GoogleLLMContext(OpenAILLMContext):
return obj
def set_messages(self, messages: List):
"""Set messages and restructure them for Google format.
Args:
messages: List of messages to set.
"""
self._messages[:] = messages
self._restructure_from_openai_messages()
def add_messages(self, messages: List):
"""Add messages to the context, converting to Google format as needed.
Args:
messages: List of messages to add (can be mixed formats).
"""
# Convert each message individually
converted_messages = []
for msg in messages:
@@ -205,6 +296,11 @@ class GoogleLLMContext(OpenAILLMContext):
self._messages.extend(converted_messages)
def get_messages_for_logging(self):
"""Get messages formatted for logging with sensitive data redacted.
Returns:
List of message dictionaries with inline data redacted.
"""
msgs = []
for message in self.messages:
obj = message.to_json_dict()
@@ -221,6 +317,14 @@ class GoogleLLMContext(OpenAILLMContext):
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
):
"""Add an image message to the context.
Args:
format: Image format (e.g., 'RGB', 'RGBA').
size: Image dimensions as (width, height).
image: Raw image bytes.
text: Optional text to accompany the image.
"""
buffer = io.BytesIO()
Image.frombytes(format, size, image).save(buffer, format="JPEG")
@@ -234,6 +338,12 @@ class GoogleLLMContext(OpenAILLMContext):
def add_audio_frames_message(
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
):
"""Add audio frames as a message to the context.
Args:
audio_frames: List of audio frames to add.
text: Text description of the audio content.
"""
if not audio_frames:
return
@@ -447,17 +557,37 @@ class GoogleLLMContext(OpenAILLMContext):
class GoogleLLMService(LLMService):
"""This class implements inference with Google's AI models.
"""Google AI (Gemini) LLM service implementation.
This service translates internally from OpenAILLMContext to the messages format
expected by the Google AI model. We are using the OpenAILLMContext as a lingua
franca for all LLM services, so that it is easy to switch between different LLMs.
This class implements inference with Google's AI models, translating internally
from OpenAILLMContext to the messages format expected by the Google AI model.
We use OpenAILLMContext as a lingua franca for all LLM services to enable
easy switching between different LLMs.
Args:
api_key: Google AI API key for authentication.
model: Model name to use. Defaults to "gemini-2.0-flash".
params: Input parameters for the model.
system_instruction: System instruction/prompt for the model.
tools: List of available tools/functions.
tool_config: Configuration for tool usage.
**kwargs: Additional arguments passed to parent class.
"""
# Overriding the default adapter to use the Gemini one.
adapter_class = GeminiLLMAdapter
class InputParams(BaseModel):
"""Input parameters for Google AI models.
Parameters:
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature between 0.0 and 2.0.
top_k: Top-k sampling parameter.
top_p: Top-p sampling parameter between 0.0 and 1.0.
extra: Additional parameters as a dictionary.
"""
max_tokens: Optional[int] = Field(default=4096, ge=1)
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_k: Optional[int] = Field(default=None, ge=0)
@@ -494,6 +624,11 @@ class GoogleLLMService(LLMService):
self._tool_config = tool_config
def can_generate_metrics(self) -> bool:
"""Check if the service can generate usage metrics.
Returns:
True, as Google AI provides token usage metrics.
"""
return True
def _create_client(self, api_key: str):
@@ -557,7 +692,7 @@ class GoogleLLMService(LLMService):
)
function_calls = []
async for chunk in response:
async for chunk in WatchdogAsyncIterator(response, manager=self.task_manager):
# Stop TTFB metrics after the first chunk
await self.stop_ttfb_metrics()
if chunk.usage_metadata:
@@ -650,6 +785,12 @@ class GoogleLLMService(LLMService):
await self.push_frame(LLMFullResponseEndFrame())
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames and handle different frame types.
Args:
frame: The frame to process.
direction: Direction of frame processing.
"""
await super().process_frame(frame, direction)
context = None
@@ -678,16 +819,15 @@ class GoogleLLMService(LLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> GoogleContextAggregatorPair:
"""Create an instance of GoogleContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
"""Create Google-specific context aggregators.
Creates a pair of context aggregators optimized for Google's message format,
including support for function calls, tool usage, and image handling.
Args:
context (OpenAILLMContext): The LLM context.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): User
aggregator parameters.
context: The LLM context to create aggregators for.
user_params: Parameters for user message aggregation.
assistant_params: Parameters for assistant message aggregation.
Returns:
GoogleContextAggregatorPair: A pair of context aggregators, one for
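The docstrings in this file describe converting OpenAI-style messages into Google's Content/Part format, with assistant text stored under the role "model". A rough sketch of that mapping for plain-text messages follows. It uses plain dicts as stand-ins for `Content` and `Part`, and it does not handle system messages, function calls or media. The function names and the fallback for unknown roles are assumptions.

```python
from typing import Any, Dict, List

# Role names on the Google side; the diff uses "model" for assistant text.
ROLE_MAP = {"user": "user", "assistant": "model"}


def to_google_style(message: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one plain-text OpenAI-style message into a Content/Part-shaped dict."""
    # Assumption: roles outside ROLE_MAP fall back to "user".
    role = ROLE_MAP.get(message["role"], "user")
    return {"role": role, "parts": [{"text": message["content"]}]}


def convert_all(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert a list of messages, preserving order."""
    return [to_google_style(m) for m in messages]


converted = convert_all(
    [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there"},
    ]
)
```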


@@ -11,6 +11,7 @@ from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
from pipecat.services.llm_service import FunctionCallFromLLM
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
@@ -53,7 +54,7 @@ class GoogleLLMOpenAIBetaService(OpenAILLMService):
context
)
async for chunk in chunk_stream:
async for chunk in WatchdogAsyncIterator(chunk_stream, manager=self.task_manager):
if chunk.usage:
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
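Several hunks in this comparison replace a bare `async for` with `WatchdogAsyncIterator(..., manager=self.task_manager)`. The diff does not show how that class is implemented. The sketch below only illustrates the general idea of wrapping an async iterator so that a callback runs whenever an item arrives. `ResettingAsyncIterator` and `on_item` are hypothetical names, and whether the real class also resets while waiting for an item is not covered here.

```python
import asyncio


class ResettingAsyncIterator:
    """Hypothetical wrapper that calls `on_item` each time the source yields."""

    def __init__(self, source, on_item):
        self._source = source.__aiter__()
        self._on_item = on_item

    def __aiter__(self):
        return self

    async def __anext__(self):
        # StopAsyncIteration from the source propagates and ends the loop.
        item = await self._source.__anext__()
        self._on_item()
        return item


async def chunks():
    # Stand-in for a streaming response.
    for text in ("a", "b", "c"):
        yield text


async def demo():
    resets = 0

    def on_item():
        nonlocal resets
        resets += 1

    items = [c async for c in ResettingAsyncIterator(chunks(), on_item)]
    return items, resets


items, resets = asyncio.run(demo())
```

Wrapping the iterator keeps the loop body unchanged, which matches the one-line edits in these hunks.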


@@ -9,6 +9,7 @@ import json
import os
import time
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.tracing.service_decorators import traced_stt
# Suppress gRPC fork warnings
@@ -436,7 +437,6 @@ class GoogleSTTService(STTService):
self._location = location
self._stream = None
self._config = None
self._request_queue = asyncio.Queue()
self._streaming_task = None
# Used for keep-alive logic
@@ -683,23 +683,15 @@ class GoogleSTTService(STTService):
),
)
self._request_queue = asyncio.Queue()
self._streaming_task = self.create_task(self._stream_audio())
async def _disconnect(self):
"""Clean up streaming recognition resources."""
if self._streaming_task:
logger.debug("Disconnecting from Google Speech-to-Text")
# Send sentinel value to stop request generator
await self._request_queue.put(None)
await self.cancel_task(self._streaming_task)
self._streaming_task = None
# Clear any remaining items in the queue
while not self._request_queue.empty():
try:
self._request_queue.get_nowait()
self._request_queue.task_done()
except asyncio.QueueEmpty:
break
async def _request_generator(self):
"""Generates requests for the streaming recognize method."""
@@ -714,29 +706,23 @@ class GoogleSTTService(STTService):
)
while True:
try:
audio_data = await self._request_queue.get()
if audio_data is None: # Sentinel value to stop
break
audio_data = await self._request_queue.get()
# Check streaming limit
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
logger.debug("Streaming limit reached, initiating graceful reconnection")
# Instead of immediate reconnection, we'll break and let the stream close naturally
self._last_audio_input = self._audio_input
self._audio_input = []
self._restart_counter += 1
# Put the current audio chunk back in the queue
await self._request_queue.put(audio_data)
break
self._request_queue.task_done()
self._audio_input.append(audio_data)
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
except asyncio.CancelledError:
# Check streaming limit
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
logger.debug("Streaming limit reached, initiating graceful reconnection")
# Instead of immediate reconnection, we'll break and let the stream close naturally
self._last_audio_input = self._audio_input
self._audio_input = []
self._restart_counter += 1
# Put the current audio chunk back in the queue
await self._request_queue.put(audio_data)
break
finally:
self._request_queue.task_done()
self._audio_input.append(audio_data)
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
except Exception as e:
logger.error(f"Error in request generator: {e}")
@@ -750,6 +736,7 @@ class GoogleSTTService(STTService):
if self._request_queue.empty():
# wait for 10ms in case we don't have audio
await asyncio.sleep(0.01)
self.reset_watchdog()
continue
# Start bi-directional streaming
@@ -765,7 +752,6 @@ class GoogleSTTService(STTService):
logger.debug("Reconnecting stream after timeout")
# Reset stream start time
self._stream_start_time = int(time.time() * 1000)
continue
else:
# Normal stream end
break
@@ -775,7 +761,6 @@ class GoogleSTTService(STTService):
await asyncio.sleep(1) # Brief delay before reconnecting
self._stream_start_time = int(time.time() * 1000)
continue
except Exception as e:
logger.error(f"Error in streaming task: {e}")
@@ -799,7 +784,9 @@ class GoogleSTTService(STTService):
async def _process_responses(self, streaming_recognize):
"""Process streaming recognition responses."""
try:
async for response in streaming_recognize:
async for response in WatchdogAsyncIterator(
streaming_recognize, manager=self.task_manager
):
# Check streaming limit
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
logger.debug("Stream timeout reached in response processing")
@@ -847,9 +834,8 @@ class GoogleSTTService(STTService):
result=result,
)
)
except Exception as e:
logger.error(f"Error processing Google STT responses: {e}")
# Re-raise the exception to let it propagate (e.g. in the case of a timeout, propagate to _stream_audio to reconnect)
# Re-raise the exception to let it propagate (e.g. in the case of a
# timeout, propagate to _stream_audio to reconnect)
raise
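The request generator in this file reads audio from a queue, and once the stream has run past its limit it puts the current chunk back and stops, so the next stream can pick it up. A simplified sketch of that control flow follows. It assumes an injectable clock (`now_ms`) so the behaviour can be checked without waiting, and it yields raw bytes instead of `StreamingRecognizeRequest` objects. The function and parameter names are hypothetical.

```python
import asyncio
from typing import AsyncGenerator, Callable


async def request_generator(
    queue: asyncio.Queue,
    started_ms: int,
    limit_ms: int,
    now_ms: Callable[[], int],
) -> AsyncGenerator[bytes, None]:
    """Yield queued chunks until the stream has been open longer than `limit_ms`."""
    while True:
        chunk = await queue.get()
        if now_ms() - started_ms > limit_ms:
            # Put the chunk back so the next stream picks it up, then stop.
            await queue.put(chunk)
            queue.task_done()
            break
        queue.task_done()
        yield chunk


async def demo():
    queue: asyncio.Queue = asyncio.Queue()
    for chunk in (b"a", b"b", b"c"):
        queue.put_nowait(chunk)
    ticks = iter([10, 20, 200])  # fake clock readings in milliseconds
    sent = [c async for c in request_generator(queue, 0, 100, lambda: next(ticks))]
    return sent, queue.qsize()


sent, left = asyncio.run(demo())
```

In this run the third clock reading exceeds the limit, so the third chunk stays in the queue for the next stream.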


@@ -4,6 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Grok LLM service implementation using OpenAI-compatible interface.
This module provides a service for interacting with Grok's API through an
OpenAI-compatible interface, including specialized token usage tracking
and context aggregation functionality.
"""
from dataclasses import dataclass
from loguru import logger
@@ -23,13 +30,33 @@ from pipecat.services.openai.llm import (
@dataclass
class GrokContextAggregatorPair:
"""Pair of context aggregators for user and assistant interactions.
Provides a convenient container for managing both user and assistant
context aggregators together for Grok LLM interactions.
Parameters:
_user: The user context aggregator instance.
_assistant: The assistant context aggregator instance.
"""
_user: OpenAIUserContextAggregator
_assistant: OpenAIAssistantContextAggregator
def user(self) -> OpenAIUserContextAggregator:
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> OpenAIAssistantContextAggregator:
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
@@ -38,12 +65,14 @@ class GrokLLMService(OpenAILLMService):
This service extends OpenAILLMService to connect to Grok's API endpoint while
maintaining full compatibility with OpenAI's interface and functionality.
Includes specialized token usage tracking that accumulates metrics during
processing and reports final totals.
Args:
api_key (str): The API key for accessing Grok's API
base_url (str, optional): The base URL for Grok API. Defaults to "https://api.x.ai/v1"
model (str, optional): The model identifier to use. Defaults to "grok-3-beta"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing Grok's API.
base_url: The base URL for Grok API. Defaults to "https://api.x.ai/v1".
model: The model identifier to use. Defaults to "grok-3-beta".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -63,7 +92,16 @@ class GrokLLMService(OpenAILLMService):
self._is_processing = False
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create OpenAI-compatible client for Grok API endpoint."""
"""Create OpenAI-compatible client for Grok API endpoint.
Args:
api_key: The API key to use. If None, uses instance default.
base_url: The base URL to use. If None, uses instance default.
**kwargs: Additional arguments passed to client creation.
Returns:
The configured client instance for Grok API.
"""
logger.debug(f"Creating Grok client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)
@@ -75,8 +113,8 @@ class GrokLLMService(OpenAILLMService):
them once at the end of processing.
Args:
context (OpenAILLMContext): The context to process, containing messages
and other information needed for the LLM interaction.
context: The context to process, containing messages and other
information needed for the LLM interaction.
"""
# Reset all counters and flags at the start of processing
self._prompt_tokens = 0
@@ -107,8 +145,8 @@ class GrokLLMService(OpenAILLMService):
The final accumulated totals are reported at the end of processing.
Args:
tokens (LLMTokenUsage): The token usage metrics for the current chunk
of processing, containing prompt_tokens and completion_tokens counts.
tokens: The token usage metrics for the current chunk of processing,
containing prompt_tokens and completion_tokens counts.
"""
# Only accumulate metrics during active processing
if not self._is_processing:
@@ -130,22 +168,20 @@ class GrokLLMService(OpenAILLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> GrokContextAggregatorPair:
"""Create an instance of GrokContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
"""Create an instance of GrokContextAggregatorPair from an OpenAILLMContext.
Constructor keyword arguments for both the user and assistant aggregators
can be provided.
Args:
context (OpenAILLMContext): The LLM context.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): User
aggregator parameters.
context: The LLM context to create aggregators for.
user_params: Parameters for configuring the user aggregator.
assistant_params: Parameters for configuring the assistant aggregator.
Returns:
GrokContextAggregatorPair: A pair of context aggregators, one for
the user and one for the assistant, encapsulated in a
GrokContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())
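The Grok service docstrings describe token usage that is accumulated during processing and reported once at the end, with reports outside active processing ignored. One way to sketch that pattern is shown below. `Usage` and `UsageAccumulator` are hypothetical. Keeping the largest value seen assumes each report carries a running total, which this diff does not confirm.

```python
from dataclasses import dataclass


@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int


class UsageAccumulator:
    """Hypothetical accumulator that reports totals once per request."""

    def __init__(self) -> None:
        self._is_processing = False
        self._prompt_tokens = 0
        self._completion_tokens = 0

    def start(self) -> None:
        # Reset counters and flags at the start of processing.
        self._is_processing = True
        self._prompt_tokens = 0
        self._completion_tokens = 0

    def record(self, usage: Usage) -> None:
        # Only accumulate metrics during active processing.
        if not self._is_processing:
            return
        # Assumption: reports are running totals, so keep the largest seen.
        self._prompt_tokens = max(self._prompt_tokens, usage.prompt_tokens)
        self._completion_tokens = max(self._completion_tokens, usage.completion_tokens)

    def finish(self) -> Usage:
        # Stop accumulating and return the final totals.
        self._is_processing = False
        return Usage(self._prompt_tokens, self._completion_tokens)


acc = UsageAccumulator()
acc.record(Usage(5, 5))  # ignored: processing has not started
acc.start()
acc.record(Usage(12, 3))
acc.record(Usage(12, 9))
total = acc.finish()
```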


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Groq LLM Service implementation using OpenAI-compatible interface."""
from loguru import logger
from pipecat.services.openai.llm import OpenAILLMService
@@ -16,10 +18,10 @@ class GroqLLMService(OpenAILLMService):
maintaining full compatibility with OpenAI's interface and functionality.
Args:
api_key (str): The API key for accessing Groq's API
base_url (str, optional): The base URL for Groq API. Defaults to "https://api.groq.com/openai/v1"
model (str, optional): The model identifier to use. Defaults to "llama-3.3-70b-versatile"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing Groq's API.
base_url: The base URL for Groq API. Defaults to "https://api.groq.com/openai/v1".
model: The model identifier to use. Defaults to "llama-3.3-70b-versatile".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -33,6 +35,15 @@ class GroqLLMService(OpenAILLMService):
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create OpenAI-compatible client for Groq API endpoint."""
"""Create OpenAI-compatible client for Groq API endpoint.
Args:
api_key: API key for authentication. If None, uses instance api_key.
base_url: Base URL for the API. If None, uses instance base_url.
**kwargs: Additional arguments passed to the client constructor.
Returns:
An OpenAI-compatible client configured for Groq's API.
"""
logger.debug(f"Creating Groq client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)

@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Image generation service implementation.
Provides base functionality for AI-powered image generation services that convert
text prompts into images.
"""
from abc import abstractmethod
from typing import AsyncGenerator
@@ -13,15 +19,46 @@ from pipecat.services.ai_service import AIService
class ImageGenService(AIService):
"""Base class for image generation services.
Processes TextFrames by using their content as prompts for image generation.
Subclasses must implement the run_image_gen method to provide actual image
generation functionality using their specific AI service.
Args:
**kwargs: Additional arguments passed to the parent AIService.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Renders the image. Returns an Image object.
@abstractmethod
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
"""Generate an image from a text prompt.
This method must be implemented by subclasses to provide actual image
generation functionality using their specific AI service.
Args:
prompt: The text prompt to generate an image from.
Yields:
Frame: Frames containing the generated image (typically ImageRawFrame
or URLImageRawFrame).
"""
pass
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for image generation.
TextFrames are used as prompts for image generation, while other frames
are passed through unchanged.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
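
Note: the contract described in these docstrings (an abstract `run_image_gen` that yields frames, with text frames used as prompts and everything else passed through) can be sketched in isolation as below. The `Frame`, `TextFrame`, and `ImageFrame` classes, as well as `ImageGenBase` and `FakeImageGen`, are hypothetical stand-ins written for this sketch; they are not pipecat's classes, and the real `process_frame` pushes frames rather than returning a list.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator, List


@dataclass
class Frame:
    """Stand-in for a pipeline frame (hypothetical, not pipecat's Frame)."""


@dataclass
class TextFrame(Frame):
    text: str = ""


@dataclass
class ImageFrame(Frame):
    # Hypothetical image payload used only by this sketch.
    prompt: str = ""
    data: bytes = b""


class ImageGenBase(ABC):
    """Mimics the shape of the base class: prompts in, image frames out."""

    @abstractmethod
    def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
        """Yield frames containing the generated image."""

    async def process_frame(self, frame: Frame) -> List[Frame]:
        # Text frames are treated as prompts; everything else passes through.
        if isinstance(frame, TextFrame):
            return [out async for out in self.run_image_gen(frame.text)]
        return [frame]


class FakeImageGen(ImageGenBase):
    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
        # A real subclass would call its image generation API here.
        yield ImageFrame(prompt=prompt, data=prompt.encode("utf-8"))


async def demo() -> List[Frame]:
    service = FakeImageGen()
    return await service.process_frame(TextFrame(text="a red bicycle"))


frames = asyncio.run(demo())
```

Declaring the abstract method as an async generator keeps the door open for services that stream several frames per prompt.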

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base classes for Large Language Model services with function calling support."""
import asyncio
import inspect
from dataclasses import dataclass
@@ -41,23 +43,34 @@ FunctionCallHandler = Callable[["FunctionCallParams"], Awaitable[None]]
# Type alias for a callback function that handles the result of an LLM function call.
class FunctionCallResultCallback(Protocol):
"""Protocol for function call result callbacks.
Handles the result of an LLM function call execution.
"""
async def __call__(
self, result: Any, *, properties: Optional[FunctionCallResultProperties] = None
) -> None: ...
) -> None:
"""Call the result callback.
Args:
result: The result of the function call.
properties: Optional properties for the result.
"""
...
@dataclass
class FunctionCallParams:
"""Parameters for a function call.
Attributes:
function_name (str): The name of the function being called.
arguments (Mapping[str, Any]): The arguments for the function.
tool_call_id (str): A unique identifier for the function call.
llm (LLMService): The LLMService instance being used.
context (OpenAILLMContext): The LLM context.
result_callback (FunctionCallResultCallback): Callback to handle the result of the function call.
Parameters:
function_name: The name of the function being called.
tool_call_id: A unique identifier for the function call.
arguments: The arguments for the function.
llm: The LLMService instance being used.
context: The LLM context.
result_callback: Callback to handle the result of the function call.
"""
function_name: str
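
Note: the callback protocol documented in this hunk is a structural type, so any async callable with a matching signature satisfies it. A minimal sketch follows, assuming a simplified handler signature; `ResultCallback`, `get_weather`, and `collect` are hypothetical names for illustration, and `properties` is typed as a plain dict here rather than `FunctionCallResultProperties`.

```python
import asyncio
from typing import Any, Dict, List, Mapping, Optional, Protocol, Tuple


class ResultCallback(Protocol):
    """Structural type: any async callable with this signature qualifies."""

    async def __call__(
        self, result: Any, *, properties: Optional[Dict[str, Any]] = None
    ) -> None: ...


async def get_weather(arguments: Mapping[str, Any], result_callback: ResultCallback) -> None:
    # Hypothetical handler: computes a result and reports it via the callback.
    await result_callback({"city": arguments["city"], "conditions": "sunny"})


async def demo() -> List[Tuple[Any, Optional[Dict[str, Any]]]]:
    received: List[Tuple[Any, Optional[Dict[str, Any]]]] = []

    async def collect(result: Any, *, properties: Optional[Dict[str, Any]] = None) -> None:
        # Plain async function; it matches ResultCallback without inheriting from it.
        received.append((result, properties))

    await get_weather({"city": "Lisbon"}, collect)
    return received


received = asyncio.run(demo())
```
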
@@ -70,14 +83,14 @@ class FunctionCallParams:
@dataclass
class FunctionCallRegistryItem:
"""Represents an entry in our function call registry. This is what the user
registers.
"""Represents an entry in the function call registry.
Attributes:
function_name (Optional[str]): The name of the function.
handler (FunctionCallHandler): The handler for processing function call parameters.
cancel_on_interruption (bool): Flag indicating whether to cancel the call on interruption.
This is what the user registers when calling register_function.
Parameters:
function_name: The name of the function (None for catch-all handler).
handler: The handler for processing function call parameters.
cancel_on_interruption: Whether to cancel the call on interruption.
"""
function_name: Optional[str]
@@ -87,16 +100,17 @@ class FunctionCallRegistryItem:
@dataclass
class FunctionCallRunnerItem:
"""Represents an internal function call entry to our function call
runner. The runner executes function calls in order.
"""Internal function call entry for the function call runner.
Attributes:
registry_name (Optional[str]): The function call name registration (could be None).
function_name (str): The name of the function.
tool_call_id (str): A unique identifier for the function call.
arguments (Mapping[str, Any]): The arguments for the function.
context (OpenAILLMContext): The LLM context.
The runner executes function calls in order.
Parameters:
registry_item: The registry item containing handler information.
function_name: The name of the function.
tool_call_id: A unique identifier for the function call.
arguments: The arguments for the function.
context: The LLM context.
run_llm: Optional flag to control LLM execution after function call.
"""
registry_item: FunctionCallRegistryItem
@@ -108,22 +122,32 @@ class FunctionCallRunnerItem:
class LLMService(AIService):
"""This is the base class for all LLM services. It handles function calling
registration and execution. The class also provides event handlers.
"""Base class for all LLM services.
An event to know when an LLM service completion timeout occurs:
Handles function calling registration and execution with support for both
parallel and sequential execution modes. Provides event handlers for
completion timeouts and function call lifecycle events.
@task.event_handler("on_completion_timeout")
async def on_completion_timeout(service):
...
Args:
run_in_parallel: Whether to run function calls in parallel or sequentially.
Defaults to True.
**kwargs: Additional arguments passed to the parent AIService.
And an event to know that function calls have been received from the LLM
service and that we are going to start executing them:
Event handlers:
on_completion_timeout: Called when an LLM completion timeout occurs.
on_function_calls_started: Called when function calls are received and
execution is about to start.
@task.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls: Sequence[FunctionCallFromLLM]):
...
Example:
```python
@task.event_handler("on_completion_timeout")
async def on_completion_timeout(service):
logger.warning("LLM completion timed out")
@task.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls):
logger.info(f"Starting {len(function_calls)} function calls")
```
"""
# OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations.
@@ -143,6 +167,11 @@ class LLMService(AIService):
self._register_event_handler("on_completion_timeout")
def get_llm_adapter(self) -> BaseLLMAdapter:
"""Get the LLM adapter instance.
Returns:
The adapter instance used for LLM communication.
"""
return self._adapter
def create_context_aggregator(
@@ -152,24 +181,57 @@ class LLMService(AIService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> Any:
"""Create a context aggregator for managing LLM conversation context.
Must be implemented by subclasses.
Args:
context: The LLM context to create an aggregator for.
user_params: Parameters for user message aggregation.
assistant_params: Parameters for assistant message aggregation.
Returns:
A context aggregator instance.
"""
pass
async def start(self, frame: StartFrame):
"""Start the LLM service.
Args:
frame: The start frame.
"""
await super().start(frame)
if not self._run_in_parallel:
await self._create_sequential_runner_task()
async def stop(self, frame: EndFrame):
"""Stop the LLM service.
Args:
frame: The end frame.
"""
await super().stop(frame)
if not self._run_in_parallel:
await self._cancel_sequential_runner_task()
async def cancel(self, frame: CancelFrame):
"""Cancel the LLM service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
if not self._run_in_parallel:
await self._cancel_sequential_runner_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process a frame.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
@@ -188,6 +250,18 @@ class LLMService(AIService):
*,
cancel_on_interruption: bool = True,
):
"""Register a function handler for LLM function calls.
Args:
function_name: The name of the function to handle. Use None to handle
all function calls with a catch-all handler.
handler: The function handler. Should accept a single FunctionCallParams
parameter.
start_callback: Legacy callback function (deprecated). Put initialization
code at the top of your handler instead.
cancel_on_interruption: Whether to cancel this function call when an
interruption occurs. Defaults to True.
"""
# Registering a function with the function_name set to None will run
# that handler for all functions
self._functions[function_name] = FunctionCallRegistryItem(
@@ -210,16 +284,38 @@ class LLMService(AIService):
self._start_callbacks[function_name] = start_callback
def unregister_function(self, function_name: Optional[str]):
"""Remove a registered function handler.
Args:
function_name: The name of the function handler to remove.
"""
del self._functions[function_name]
if function_name in self._start_callbacks:
del self._start_callbacks[function_name]
def has_function(self, function_name: str):
"""Check if a function handler is registered.
Args:
function_name: The name of the function to check.
Returns:
True if the function is registered or if a catch-all handler (None)
is registered.
"""
if None in self._functions.keys():
return True
return function_name in self._functions.keys()
async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
"""Execute a sequence of function calls from the LLM.
Triggers the on_function_calls_started event and executes functions
either in parallel or sequentially based on the run_in_parallel setting.
Args:
function_calls: The function calls to execute.
"""
if len(function_calls) == 0:
return
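
Note: the registration semantics documented above (a `None` key acting as a catch-all handler, and `has_function` returning True whenever a catch-all exists) can be sketched with a plain dictionary. `FunctionRegistry` and `lookup` are hypothetical names for this sketch. Preferring an exact match over the catch-all is an assumption, and `unregister_function` here uses `pop` so that removing an unknown name is a no-op, which differs from the `del` shown in the diff.

```python
from typing import Awaitable, Callable, Dict, Optional

Handler = Callable[[dict], Awaitable[None]]


class FunctionRegistry:
    """Dictionary-backed registry where the key None means 'handle everything'."""

    def __init__(self) -> None:
        self._functions: Dict[Optional[str], Handler] = {}

    def register_function(self, function_name: Optional[str], handler: Handler) -> None:
        self._functions[function_name] = handler

    def unregister_function(self, function_name: Optional[str]) -> None:
        # pop() with a default never raises, unlike `del` on a missing key.
        self._functions.pop(function_name, None)

    def has_function(self, function_name: str) -> bool:
        # A catch-all handler makes every function name "registered".
        if None in self._functions:
            return True
        return function_name in self._functions

    def lookup(self, function_name: str) -> Optional[Handler]:
        # Assumption: an exact match wins over the catch-all handler.
        if function_name in self._functions:
            return self._functions[function_name]
        return self._functions.get(None)


async def weather_handler(params: dict) -> None:
    pass


async def catch_all_handler(params: dict) -> None:
    pass


registry = FunctionRegistry()
registry.register_function("get_weather", weather_handler)
```
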
@@ -257,7 +353,7 @@ class LLMService(AIService):
else:
await self._sequential_runner_queue.put(runner_item)
async def call_start_function(self, context: OpenAILLMContext, function_name: str):
async def _call_start_function(self, context: OpenAILLMContext, function_name: str):
if function_name in self._start_callbacks.keys():
await self._start_callbacks[function_name](function_name, self, context)
elif None in self._start_callbacks.keys():
@@ -272,6 +368,18 @@ class LLMService(AIService):
text_content: Optional[str] = None,
video_source: Optional[str] = None,
):
"""Request an image from a user.
Pushes a UserImageRequestFrame upstream to request an image from the
specified user.
Args:
user_id: The ID of the user to request an image from.
function_name: Optional function name associated with the request.
tool_call_id: Optional tool call ID associated with the request.
text_content: Optional text content/context for the image request.
video_source: Optional video source identifier.
"""
await self.push_frame(
UserImageRequestFrame(
user_id=user_id,
@@ -316,7 +424,7 @@ class LLMService(AIService):
)
# NOTE(aleix): This needs to be removed after we remove the deprecation.
await self.call_start_function(runner_item.context, runner_item.function_name)
await self._call_start_function(runner_item.context, runner_item.function_name)
# Push a function call in-progress downstream. This frame will let our
# assistant context aggregator know that we are in the middle of a
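
Note: the difference between the two execution modes named in the `LLMService` docstring can be illustrated with plain asyncio. This is a sketch under stated assumptions, not the service's implementation: `run_calls`, `_sequential_runner`, and `demo` are hypothetical names, and the only detail taken from the diff is that the sequential mode feeds a queue consumed by a runner task.

```python
import asyncio
from typing import Awaitable, Callable, List

Call = Callable[[], Awaitable[None]]


async def _sequential_runner(queue: "asyncio.Queue[Call]") -> None:
    # Pulls calls off the queue one at a time, so they never overlap.
    while True:
        call = await queue.get()
        try:
            await call()
        finally:
            queue.task_done()


async def run_calls(calls: List[Call], run_in_parallel: bool) -> None:
    if run_in_parallel:
        # All calls are started together and awaited as a group.
        await asyncio.gather(*(call() for call in calls))
        return
    queue: "asyncio.Queue[Call]" = asyncio.Queue()
    runner = asyncio.create_task(_sequential_runner(queue))
    for call in calls:
        queue.put_nowait(call)
    await queue.join()  # wait until every queued call has finished
    runner.cancel()
    try:
        await runner
    except asyncio.CancelledError:
        pass


async def demo(run_in_parallel: bool) -> List[str]:
    events: List[str] = []

    def make_call(name: str, delay: float) -> Call:
        async def call() -> None:
            events.append(f"start:{name}")
            await asyncio.sleep(delay)
            events.append(f"end:{name}")

        return call

    await run_calls([make_call("a", 0.02), make_call("b", 0.01)], run_in_parallel)
    return events


sequential_events = asyncio.run(demo(run_in_parallel=False))
parallel_events = asyncio.run(demo(run_in_parallel=True))
```

In the sequential run each call finishes before the next one starts; in the parallel run both calls start before either finishes.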

@@ -1,5 +1,13 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
import json
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Tuple
from loguru import logger
@@ -8,10 +16,12 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.utils.base_object import BaseObject
try:
from mcp import ClientSession, StdioServerParameters, types
from mcp import ClientSession, StdioServerParameters
from mcp.client.session import ClientSession
from mcp.client.session_group import SseServerParameters, StreamableHttpParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use an MCP client, you need to `pip install pipecat-ai[mcp]`.")
@@ -19,26 +29,55 @@ except ModuleNotFoundError as e:
class MCPClient(BaseObject):
"""Client for Model Context Protocol (MCP) servers.
Enables integration with MCP servers to provide external tools and resources
to LLMs. Supports stdio, SSE, and streamable HTTP server connections with
automatic tool registration and schema conversion.
Args:
server_params: Server connection parameters (stdio, SSE, or streamable HTTP).
**kwargs: Additional arguments passed to the parent BaseObject.
Raises:
TypeError: If server_params is not a supported parameter type.
"""
def __init__(
self,
server_params: Union[StdioServerParameters, str],
server_params: Tuple[StdioServerParameters, SseServerParameters, StreamableHttpParameters],
**kwargs,
):
super().__init__(**kwargs)
self._server_params = server_params
self._session = ClientSession
if isinstance(server_params, StdioServerParameters):
self._client = stdio_client
self._register_tools = self._stdio_register_tools
elif isinstance(server_params, str):
elif isinstance(server_params, SseServerParameters):
self._client = sse_client
self._register_tools = self._sse_register_tools
elif isinstance(server_params, StreamableHttpParameters):
self._client = streamablehttp_client
self._register_tools = self._streamable_http_register_tools
else:
raise TypeError(
f"{self} invalid argument type: `server_params` must be either StdioServerParameters or an SSE server url string."
f"{self} invalid argument type: `server_params` must be either StdioServerParameters, SseServerParameters, or StreamableHttpParameters."
)
async def register_tools(self, llm) -> ToolsSchema:
"""Register all available MCP tools with an LLM service.
Connects to the MCP server, discovers available tools, converts their
schemas to Pipecat format, and registers them with the LLM service.
Args:
llm: The Pipecat LLM service to register tools with.
Returns:
A ToolsSchema containing all successfully registered tools.
"""
tools_schema = await self._register_tools(llm)
return tools_schema
@@ -46,13 +85,13 @@ class MCPClient(BaseObject):
self, tool_name: str, tool_schema: Dict[str, Any]
) -> FunctionSchema:
"""Convert an mcp tool schema to Pipecat's FunctionSchema format.
Args:
tool_name: The name of the tool
tool_schema: The mcp tool schema
Returns:
A FunctionSchema instance
"""
logger.debug(f"Converting schema for tool '{tool_name}'")
logger.trace(f"Original schema: {json.dumps(tool_schema, indent=2)}")
@@ -71,7 +110,8 @@ class MCPClient(BaseObject):
return schema
async def _sse_register_tools(self, llm) -> ToolsSchema:
"""Register all available mcp.run tools with the LLM service.
"""Register all available mcp tools with the LLM service.
Args:
llm: The Pipecat LLM service to register tools with
Returns:
@@ -86,11 +126,11 @@ class MCPClient(BaseObject):
context: any,
result_callback: any,
) -> None:
"""Wrapper for mcp.run tool calls to match Pipecat's function call interface."""
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
logger.debug(f"Executing tool '{function_name}' with call ID: {tool_call_id}")
logger.trace(f"Tool arguments: {json.dumps(arguments, indent=2)}")
try:
async with self._client(self._server_params) as (read, write):
async with self._client(**self._server_params.model_dump()) as (read, write):
async with self._session(read, write) as session:
await session.initialize()
await self._call_tool(session, function_name, arguments, result_callback)
@@ -100,17 +140,18 @@ class MCPClient(BaseObject):
logger.exception("Full exception details:")
await result_callback(error_msg)
logger.debug("Starting registration of mcp.run tools")
tool_schemas: List[FunctionSchema] = []
logger.debug(f"SSE server parameters: {self._server_params}")
logger.debug("Starting registration of mcp tools")
async with self._client(self._server_params) as (read, write):
async with self._client(**self._server_params.model_dump()) as (read, write):
async with self._session(read, write) as session:
await session.initialize()
tools_schema = await self._list_tools(session, mcp_tool_wrapper, llm)
return tools_schema
async def _stdio_register_tools(self, llm) -> ToolsSchema:
"""Register all available mcp.run tools with the LLM service.
"""Register all available mcp tools with the LLM service.
Args:
llm: The Pipecat LLM service to register tools with
Returns:
@@ -125,7 +166,7 @@ class MCPClient(BaseObject):
context: any,
result_callback: any,
) -> None:
"""Wrapper for mcp.run tool calls to match Pipecat's function call interface."""
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
logger.debug(f"Executing tool '{function_name}' with call ID: {tool_call_id}")
logger.trace(f"Tool arguments: {json.dumps(arguments, indent=2)}")
try:
@@ -139,7 +180,7 @@ class MCPClient(BaseObject):
logger.exception("Full exception details:")
await result_callback(error_msg)
logger.debug("Starting registration of mcp.run tools")
logger.debug("Starting registration of mcp tools")
async with self._client(self._server_params) as streams:
async with self._session(streams[0], streams[1]) as session:
@@ -147,6 +188,52 @@ class MCPClient(BaseObject):
tools_schema = await self._list_tools(session, mcp_tool_wrapper, llm)
return tools_schema
async def _streamable_http_register_tools(self, llm) -> ToolsSchema:
"""Register all available mcp tools with the LLM service using streamable HTTP.
Args:
llm: The Pipecat LLM service to register tools with
Returns:
A ToolsSchema containing all registered tools
"""
async def mcp_tool_wrapper(
function_name: str,
tool_call_id: str,
arguments: Dict[str, Any],
llm: any,
context: any,
result_callback: any,
) -> None:
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
logger.debug(f"Executing tool '{function_name}' with call ID: {tool_call_id}")
logger.trace(f"Tool arguments: {json.dumps(arguments, indent=2)}")
try:
async with self._client(**self._server_params.model_dump()) as (
read_stream,
write_stream,
_,
):
async with self._session(read_stream, write_stream) as session:
await session.initialize()
await self._call_tool(session, function_name, arguments, result_callback)
except Exception as e:
error_msg = f"Error calling mcp tool {function_name}: {str(e)}"
logger.error(error_msg)
logger.exception("Full exception details:")
await result_callback(error_msg)
logger.debug("Starting registration of mcp tools using streamable HTTP")
async with self._client(**self._server_params.model_dump()) as (
read_stream,
write_stream,
_,
):
async with self._session(read_stream, write_stream) as session:
await session.initialize()
tools_schema = await self._list_tools(session, mcp_tool_wrapper, llm)
return tools_schema
async def _call_tool(self, session, function_name, arguments, result_callback):
logger.debug(f"Calling mcp tool '{function_name}'")
try:
@@ -190,8 +277,7 @@ class MCPClient(BaseObject):
try:
# Convert the schema
function_schema = self._convert_mcp_schema_to_pipecat(
tool_name,
{"description": tool.description, "input_schema": tool.inputSchema},
tool_name, {"description": tool.description, "input_schema": tool.inputSchema}
)
# Register the wrapped function
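
Note: the constructor change in this file dispatches on the type of `server_params` to pick a transport. A self-contained sketch of that dispatch is shown below. The three dataclasses are hypothetical stand-ins for the `mcp` parameter models, with illustrative fields only, and `select_transport` is a made-up helper that returns a label instead of wiring up a client.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Union


@dataclass
class StdioParams:
    # Hypothetical stand-in for a stdio parameter model.
    command: str
    args: List[str] = field(default_factory=list)


@dataclass
class SseParams:
    # Hypothetical stand-in for an SSE parameter model.
    url: str
    headers: Dict[str, str] = field(default_factory=dict)


@dataclass
class StreamableHttpParams:
    # Hypothetical stand-in for a streamable HTTP parameter model.
    url: str
    headers: Dict[str, str] = field(default_factory=dict)


ServerParams = Union[StdioParams, SseParams, StreamableHttpParams]


def select_transport(server_params: ServerParams) -> str:
    """Return a label for the transport that matches the parameter type."""
    if isinstance(server_params, StdioParams):
        return "stdio"
    if isinstance(server_params, SseParams):
        return "sse"
    if isinstance(server_params, StreamableHttpParams):
        return "streamable_http"
    raise TypeError(
        "server_params must be StdioParams, SseParams, or StreamableHttpParams, "
        f"got {type(server_params).__name__}"
    )


transport = select_transport(StreamableHttpParams(url="https://example.com/mcp"))
```

The sketch annotates the parameter with `Union`, which is how `typing` expresses "any one of these types"; `Tuple[A, B, C]` would instead describe a three-element tuple.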

@@ -29,6 +29,7 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.tracing.service_decorators import traced_tts
try:
@@ -221,7 +222,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
self._websocket = None
async def _receive_messages(self):
async for message in self._websocket:
async for message in WatchdogAsyncIterator(self._websocket, manager=self.task_manager):
if isinstance(message, str):
msg = json.loads(message)
if msg.get("data", {}).get("audio") is not None:
@@ -232,8 +233,10 @@ class NeuphonicTTSService(InterruptibleTTSService):
await self.push_frame(frame)
async def _keepalive_task_handler(self):
KEEPALIVE_SLEEP = 10 if self.task_manager.task_watchdog_enabled else 3
while True:
await asyncio.sleep(10)
self.reset_watchdog()
await asyncio.sleep(KEEPALIVE_SLEEP)
await self._send_text("")
async def _send_text(self, text: str):
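
Note: the keepalive handler in this hunk follows a common pattern: loop forever, signal liveness, sleep, then send an empty message to keep the connection open. A minimal sketch of that loop with plain asyncio follows. `keepalive_loop`, `on_tick`, and `fake_send` are hypothetical names, and the short interval is chosen only so the demo finishes quickly.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional


async def keepalive_loop(
    send_text: Callable[[str], Awaitable[None]],
    interval: float,
    on_tick: Optional[Callable[[], None]] = None,
) -> None:
    """Send an empty message every `interval` seconds until cancelled."""
    while True:
        if on_tick is not None:
            on_tick()  # e.g. tell a supervisor that this task is still alive
        await asyncio.sleep(interval)
        await send_text("")


async def demo() -> List[str]:
    sent: List[str] = []

    async def fake_send(text: str) -> None:
        sent.append(text)

    task = asyncio.create_task(keepalive_loop(fake_send, interval=0.01))
    await asyncio.sleep(0.1)
    task.cancel()  # the loop only ends through cancellation
    try:
        await task
    except asyncio.CancelledError:
        pass
    return sent


sent = asyncio.run(demo())
```
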

@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA NIM API service implementation.
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
Microservice) API while maintaining compatibility with the OpenAI-style interface.
"""
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.llm import OpenAILLMService
@@ -17,10 +23,10 @@ class NimLLMService(OpenAILLMService):
in token usage reporting between NIM (incremental) and OpenAI (final summary).
Args:
api_key (str): The API key for accessing NVIDIA's NIM API
base_url (str, optional): The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1"
model (str, optional): The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing NVIDIA's NIM API.
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -47,8 +53,8 @@ class NimLLMService(OpenAILLMService):
them once at the end of processing.
Args:
context (OpenAILLMContext): The context to process, containing messages
and other information needed for the LLM interaction.
context: The context to process, containing messages and other information
needed for the LLM interaction.
"""
# Reset all counters and flags at the start of processing
self._prompt_tokens = 0
@@ -79,8 +85,8 @@ class NimLLMService(OpenAILLMService):
The final accumulated totals are reported at the end of processing.
Args:
tokens (LLMTokenUsage): The token usage metrics for the current chunk
of processing, containing prompt_tokens and completion_tokens counts.
tokens: The token usage metrics for the current chunk of processing,
containing prompt_tokens and completion_tokens counts.
"""
# Only accumulate metrics during active processing
if not self._is_processing:
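
Note: the docstrings in this file describe collecting incremental token usage during processing and reporting the totals once at the end. The sketch below shows one way such an accumulator could look. It assumes each chunk reports running totals, so the largest value seen is kept; that policy, along with the `TokenUsage` and `UsageAccumulator` names, is an assumption of this sketch and not taken from the diff.

```python
from dataclasses import dataclass


@dataclass
class TokenUsage:
    # Hypothetical stand-in for a token usage metrics object.
    prompt_tokens: int
    completion_tokens: int

    @property
    def total_tokens(self) -> int:
        return self.prompt_tokens + self.completion_tokens


class UsageAccumulator:
    """Collects per-chunk usage and reports a single summary at the end."""

    def __init__(self) -> None:
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._is_processing = False

    def start(self) -> None:
        # Reset all counters and flags at the start of processing.
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._is_processing = True

    def add_chunk(self, tokens: TokenUsage) -> None:
        # Only accumulate metrics during active processing.
        if not self._is_processing:
            return
        # Assumption: chunks carry running totals, so keep the largest seen.
        self._prompt_tokens = max(self._prompt_tokens, tokens.prompt_tokens)
        self._completion_tokens = max(self._completion_tokens, tokens.completion_tokens)

    def finish(self) -> TokenUsage:
        self._is_processing = False
        return TokenUsage(self._prompt_tokens, self._completion_tokens)


accumulator = UsageAccumulator()
accumulator.add_chunk(TokenUsage(99, 99))  # ignored: processing has not started
accumulator.start()
for completion_so_far in (3, 7, 12):
    accumulator.add_chunk(TokenUsage(prompt_tokens=20, completion_tokens=completion_so_far))
summary = accumulator.finish()
```
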

@@ -4,9 +4,22 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OLLama LLM service implementation for Pipecat AI framework."""
from pipecat.services.openai.llm import OpenAILLMService
class OLLamaLLMService(OpenAILLMService):
"""OLLama LLM service that provides local language model capabilities.
This service extends OpenAILLMService to work with locally hosted OLLama models,
providing a compatible interface for running large language models locally.
Args:
model: The OLLama model to use. Defaults to "llama2".
base_url: The base URL for the OLLama API endpoint.
Defaults to "http://localhost:11434/v1".
"""
def __init__(self, *, model: str = "llama2", base_url: str = "http://localhost:11434/v1"):
super().__init__(model=model, base_url=base_url, api_key="ollama")

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base OpenAI LLM service implementation."""
import base64
import json
from typing import Any, Dict, List, Mapping, Optional
@@ -35,20 +37,44 @@ from pipecat.processors.aggregators.openai_llm_context import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.tracing.service_decorators import traced_llm
class BaseOpenAILLMService(LLMService):
"""This is the base for all services that use the AsyncOpenAI client.
"""Base class for all services that use the AsyncOpenAI client.
This service consumes OpenAILLMContextFrame frames, which contain a reference
to an OpenAILLMContext frame. The OpenAILLMContext object defines the context
sent to the LLM for a completion. This includes user, assistant and system messages
as well as tool choices and the tool, which is used if requesting function
calls from the LLM.
to an OpenAILLMContext object. The context defines what is sent to the LLM for
completion, including user, assistant, and system messages, as well as tool
choices and function call configurations.
Args:
model: The OpenAI model name to use (e.g., "gpt-4.1", "gpt-4o").
api_key: OpenAI API key. If None, uses environment variable.
base_url: Custom base URL for OpenAI API. If None, uses default.
organization: OpenAI organization ID.
project: OpenAI project ID.
default_headers: Additional HTTP headers to include in requests.
params: Input parameters for model configuration and behavior.
**kwargs: Additional arguments passed to the parent LLMService.
"""
class InputParams(BaseModel):
"""Input parameters for OpenAI model configuration.
Parameters:
frequency_penalty: Penalty for frequent tokens (-2.0 to 2.0).
presence_penalty: Penalty for new tokens (-2.0 to 2.0).
seed: Random seed for deterministic outputs.
temperature: Sampling temperature (0.0 to 2.0).
top_k: Top-k sampling parameter (currently ignored by OpenAI).
top_p: Top-p (nucleus) sampling parameter (0.0 to 1.0).
max_tokens: Maximum tokens in response (deprecated, use max_completion_tokens).
max_completion_tokens: Maximum completion tokens to generate.
extra: Additional model-specific parameters.
"""
frequency_penalty: Optional[float] = Field(
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
)
@@ -110,6 +136,19 @@ class BaseOpenAILLMService(LLMService):
default_headers=None,
**kwargs,
):
"""Create an AsyncOpenAI client instance.
Args:
api_key: OpenAI API key.
base_url: Custom base URL for the API.
organization: OpenAI organization ID.
project: OpenAI project ID.
default_headers: Additional HTTP headers.
**kwargs: Additional client configuration arguments.
Returns:
Configured AsyncOpenAI client instance.
"""
return AsyncOpenAI(
api_key=api_key,
base_url=base_url,
@@ -124,11 +163,25 @@ class BaseOpenAILLMService(LLMService):
)
def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Returns:
True, as OpenAI service supports metrics generation.
"""
return True
async def get_chat_completions(
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
) -> AsyncStream[ChatCompletionChunk]:
"""Get streaming chat completions from OpenAI API.
Args:
context: The LLM context containing tools and configuration.
messages: List of chat completion messages to send.
Returns:
Async stream of chat completion chunks.
"""
params = {
"model": self.model_name,
"stream": True,
@@ -192,7 +245,7 @@ class BaseOpenAILLMService(LLMService):
context
)
async for chunk in chunk_stream:
async for chunk in WatchdogAsyncIterator(chunk_stream, manager=self.task_manager):
if chunk.usage:
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
@@ -274,6 +327,15 @@ class BaseOpenAILLMService(LLMService):
await self.run_function_calls(function_calls)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for LLM completion requests.
Handles OpenAILLMContextFrame, LLMMessagesFrame, VisionImageRawFrame,
and LLMUpdateSettingsFrame to trigger LLM completions and manage settings.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
context = None
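
Note: this file now wraps `chunk_stream` in `WatchdogAsyncIterator`, whose internals are not part of this diff. The sketch below shows only the general shape of an async-iterator wrapper that runs a callback for every item it forwards. Reading that callback as "reset a watchdog timer" is an assumption based on the name, and `ResettingAsyncIterator` and `fake_chunks` are hypothetical.

```python
import asyncio
from typing import AsyncIterable, AsyncIterator, Callable, Generic, List, Tuple, TypeVar

T = TypeVar("T")


class ResettingAsyncIterator(Generic[T]):
    """Forwards items from `source`, invoking `on_item` after each one arrives."""

    def __init__(self, source: AsyncIterable[T], on_item: Callable[[], None]) -> None:
        self._iterator: AsyncIterator[T] = source.__aiter__()
        self._on_item = on_item

    def __aiter__(self) -> "ResettingAsyncIterator[T]":
        return self

    async def __anext__(self) -> T:
        # StopAsyncIteration from the source propagates and ends the loop.
        item = await self._iterator.__anext__()
        self._on_item()
        return item


async def fake_chunks() -> AsyncIterator[str]:
    for chunk in ("Hel", "lo", "!"):
        yield chunk


async def demo() -> Tuple[List[str], int]:
    resets = 0

    def reset() -> None:
        nonlocal resets
        resets += 1

    chunks = [chunk async for chunk in ResettingAsyncIterator(fake_chunks(), reset)]
    return chunks, resets


chunks, resets = asyncio.run(demo())
```

Because the wrapper is itself an async iterator, the call site changes only in the `async for` header, which matches the one-line change in the hunk above.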

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenAI LLM service implementation with context aggregators."""
import json
from dataclasses import dataclass
from typing import Any, Optional
@@ -26,17 +28,46 @@ from pipecat.services.openai.base_llm import BaseOpenAILLMService
@dataclass
class OpenAIContextAggregatorPair:
"""Pair of OpenAI context aggregators for user and assistant messages.
Parameters:
_user: User context aggregator for processing user messages.
_assistant: Assistant context aggregator for processing assistant messages.
"""
_user: "OpenAIUserContextAggregator"
_assistant: "OpenAIAssistantContextAggregator"
def user(self) -> "OpenAIUserContextAggregator":
"""Get the user context aggregator.
Returns:
The user context aggregator instance.
"""
return self._user
def assistant(self) -> "OpenAIAssistantContextAggregator":
"""Get the assistant context aggregator.
Returns:
The assistant context aggregator instance.
"""
return self._assistant
class OpenAILLMService(BaseOpenAILLMService):
"""OpenAI LLM service implementation.
Provides a complete OpenAI LLM service with context aggregation support.
Uses the BaseOpenAILLMService for core functionality and adds OpenAI-specific
context aggregator creation.
Args:
model: The OpenAI model name to use. Defaults to "gpt-4.1".
params: Input parameters for model configuration.
**kwargs: Additional arguments passed to the parent BaseOpenAILLMService.
"""
def __init__(
self,
*,
@@ -53,14 +84,15 @@ class OpenAILLMService(BaseOpenAILLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> OpenAIContextAggregatorPair:
"""Create an instance of OpenAIContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
"""Create OpenAI-specific context aggregators.
Creates a pair of context aggregators optimized for OpenAI's message format,
including support for function calls, tool usage, and image handling.
Args:
context (OpenAILLMContext): The LLM context.
user_params (LLMUserAggregatorParams, optional): User aggregator parameters.
assistant_params (LLMAssistantAggregatorParams, optional): User aggregator parameters.
context: The LLM context to create aggregators for.
user_params: Parameters for user message aggregation.
assistant_params: Parameters for assistant message aggregation.
Returns:
OpenAIContextAggregatorPair: A pair of context aggregators, one for
@@ -75,11 +107,32 @@ class OpenAILLMService(BaseOpenAILLMService):
class OpenAIUserContextAggregator(LLMUserContextAggregator):
"""OpenAI-specific user context aggregator.
Handles aggregation of user messages for OpenAI LLM services.
Inherits all functionality from the base LLMUserContextAggregator.
"""
pass
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
"""OpenAI-specific assistant context aggregator.
Handles aggregation of assistant messages for OpenAI LLM services,
with specialized support for OpenAI's function calling format,
tool usage tracking, and image message handling.
"""
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
"""Handle a function call in progress.
Adds the function call to the context with an IN_PROGRESS status
to track ongoing function execution.
Args:
frame: Frame containing function call progress information.
"""
self._context.add_message(
{
"role": "assistant",
@@ -104,6 +157,14 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
"""Handle the result of a function call.
Updates the context with the function call result, replacing any
previous IN_PROGRESS status.
Args:
frame: Frame containing the function call result.
"""
if frame.result:
result = json.dumps(frame.result)
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
@@ -113,6 +174,13 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
)
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
"""Handle a cancelled function call.
Updates the context to mark the function call as cancelled.
Args:
frame: Frame containing the function call cancellation information.
"""
await self._update_function_call_result(
frame.function_name, frame.tool_call_id, "CANCELLED"
)
@@ -129,6 +197,14 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
message["content"] = result
async def handle_user_image_frame(self, frame: UserImageRawFrame):
"""Handle a user image frame from a function call request.
Marks the associated function call as completed and adds the image
to the context for processing.
Args:
frame: Frame containing the user image and request context.
"""
await self._update_function_call_result(
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
)
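The handlers in this file describe a small lifecycle: a function call is first recorded with an IN_PROGRESS placeholder, and that placeholder is later overwritten with the JSON result, "CANCELLED", or "COMPLETED". A minimal stand-alone sketch of that bookkeeping follows. It assumes a plain list of message dicts in place of the real context object, and an OpenAI-style `tool_calls` message shape; the helpers `add_in_progress` and `update_result` are hypothetical names for illustration, not pipecat APIs.

```python
import json

IN_PROGRESS = "IN_PROGRESS"


def add_in_progress(messages, function_name, tool_call_id, arguments):
    """Record the assistant's tool call plus a placeholder tool message (hypothetical helper)."""
    messages.append(
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": tool_call_id,
                    "type": "function",
                    "function": {"name": function_name, "arguments": json.dumps(arguments)},
                }
            ],
        }
    )
    # The placeholder is what later gets replaced by the real outcome.
    messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": IN_PROGRESS})


def update_result(messages, tool_call_id, result):
    """Overwrite the content of the tool message matching tool_call_id (hypothetical helper)."""
    for message in messages:
        if message.get("role") == "tool" and message.get("tool_call_id") == tool_call_id:
            message["content"] = result


messages = []
add_in_progress(messages, "get_weather", "call_1", {"city": "Paris"})
add_in_progress(messages, "get_time", "call_2", {"tz": "UTC"})

# A finished call stores its JSON-encoded result; a cancelled call stores a marker string.
update_result(messages, "call_1", json.dumps({"temp_c": 21}))
update_result(messages, "call_2", "CANCELLED")
```

Matching on `tool_call_id` rather than on the function name keeps two concurrent calls to the same function from overwriting each other's results.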


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Azure OpenAI Realtime Beta LLM service implementation."""
from loguru import logger
from .openai import OpenAIRealtimeBetaLLMService
@@ -19,7 +21,18 @@ except ModuleNotFoundError as e:
class AzureRealtimeBetaLLMService(OpenAIRealtimeBetaLLMService):
"""Subclass of OpenAI Realtime API Service with adjustments for Azure's wss connection."""
"""Azure OpenAI Realtime Beta LLM service with Azure-specific authentication.
Extends the OpenAI Realtime service to work with Azure OpenAI endpoints,
using Azure's authentication headers and endpoint format. Provides the same
real-time audio and text communication capabilities as the base OpenAI service.
Args:
api_key: The API key for the Azure OpenAI service.
base_url: The full Azure WebSocket endpoint URL including api-version and deployment.
Example: "wss://my-project.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=my-realtime-deployment"
**kwargs: Additional arguments passed to parent OpenAIRealtimeBetaLLMService.
"""
def __init__(
self,
@@ -28,16 +41,6 @@ class AzureRealtimeBetaLLMService(OpenAIRealtimeBetaLLMService):
base_url: str,
**kwargs,
):
"""Constructor takes the same arguments as the parent class, OpenAIRealtimeBetaLLMService.
Note that the following are required arguments:
api_key: The API key for the Azure OpenAI service.
base_url: The base URL for the Azure OpenAI service.
base_url should be set to the full Azure endpoint URL including the api-version and the deployment name. For example,
wss://my-project.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=my-realtime-deployment
"""
super().__init__(base_url=base_url, api_key=api_key, **kwargs)
self.api_key = api_key
self.base_url = base_url
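The docstring requires `base_url` to be the full Azure WebSocket endpoint, including the `api-version` and `deployment` query parameters. One way to assemble such a URL without hand-concatenating the query string is sketched below. The helper name `build_azure_realtime_url` is an assumption for illustration and is not part of pipecat; the resource, deployment, and version values are the ones from the docstring's example.

```python
from urllib.parse import urlencode


def build_azure_realtime_url(resource, deployment, api_version):
    """Build a wss:// endpoint in the shape shown in the docstring example (hypothetical helper)."""
    # urlencode escapes any characters that are not safe in a query string.
    query = urlencode({"api-version": api_version, "deployment": deployment})
    return f"wss://{resource}.openai.azure.com/openai/realtime?{query}"


url = build_azure_realtime_url(
    resource="my-project",
    deployment="my-realtime-deployment",
    api_version="2024-10-01-preview",
)
```

The resulting string is what would be passed as the `base_url` argument, alongside `api_key`, when constructing the service.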


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenAI Realtime LLM context and aggregator implementations."""
import copy
import json
@@ -30,6 +32,18 @@ from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
class OpenAIRealtimeLLMContext(OpenAILLMContext):
"""OpenAI Realtime LLM context with session management and message conversion.
Extends the standard OpenAI LLM context to support real-time session properties,
instruction management, and conversion between standard message formats and
realtime conversation items.
Args:
messages: Initial conversation messages. Defaults to None.
tools: Available function tools. Defaults to None.
**kwargs: Additional arguments passed to parent OpenAILLMContext.
"""
def __init__(self, messages=None, tools=None, **kwargs):
super().__init__(messages=messages, tools=tools, **kwargs)
self.__setup_local()
@@ -43,6 +57,14 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
@staticmethod
def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext":
"""Upgrade a standard OpenAI LLM context to a realtime context.
Args:
obj: The OpenAILLMContext instance to upgrade.
Returns:
The upgraded OpenAIRealtimeLLMContext instance.
"""
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext):
obj.__class__ = OpenAIRealtimeLLMContext
obj.__setup_local()
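`upgrade_to_realtime` converts an existing context in place by reassigning `__class__` and then running the subclass's private setup, so every holder of the original reference sees the upgraded object. The sketch below reproduces that pattern with stand-in classes. `BaseContext`, `RealtimeContext`, and the `realtime_ready` attribute are hypothetical and only mirror the shape of the code above.

```python
class BaseContext:
    """Stand-in for the standard context class."""

    def __init__(self, messages=None):
        self.messages = messages or []


class RealtimeContext(BaseContext):
    """Stand-in for the realtime subclass with extra local state."""

    def __init__(self, messages=None):
        super().__init__(messages)
        self.__setup_local()

    def __setup_local(self):
        # Double-underscore names are mangled to _RealtimeContext__setup_local,
        # so this method is only reachable by that name from inside this class body.
        self.realtime_ready = True  # hypothetical attribute

    @staticmethod
    def upgrade(obj):
        # Only upgrade plain base instances; already-upgraded objects pass through.
        if isinstance(obj, BaseContext) and not isinstance(obj, RealtimeContext):
            obj.__class__ = RealtimeContext
            # __init__ is not re-run, so the subclass setup must be called explicitly.
            obj.__setup_local()
        return obj


ctx = BaseContext([{"role": "user", "content": "hi"}])
upgraded = RealtimeContext.upgrade(ctx)
```

Because `__init__` is bypassed, any state the subclass needs has to be created in the explicit setup call; that is why the real code keeps the setup in a separate method.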
@@ -52,6 +74,14 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
# - finish implementing all frames
def from_standard_message(self, message):
"""Convert a standard message format to a realtime conversation item.
Args:
message: The standard message dictionary to convert.
Returns:
A ConversationItem instance for the realtime API.
"""
if message.get("role") == "user":
content = message.get("content")
if isinstance(message.get("content"), list):
@@ -79,6 +109,14 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
logger.error(f"Unhandled message type in from_standard_message: {message}")
def get_messages_for_initializing_history(self):
"""Get conversation items for initializing the realtime session history.
Converts the context's messages to a format suitable for the realtime API,
handling system instructions and conversation history packaging.
Returns:
List of conversation items for session initialization.
"""
# We can't load a long conversation history into the openai realtime api yet. (The API/model
# forgets that it can do audio, if you do a series of `conversation.item.create` calls.) So
# our general strategy until this is fixed is just to put everything into a first "user"
@@ -133,6 +171,11 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
]
def add_user_content_item_as_message(self, item):
"""Add a user content item as a standard message to the context.
Args:
item: The conversation item to add as a user message.
"""
message = {
"role": "user",
"content": [{"type": "text", "text": item.content[0].transcript}],
@@ -141,9 +184,25 @@ class OpenAIRealtimeLLMContext(OpenAILLMContext):
class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
"""User context aggregator for OpenAI Realtime API.
Handles user input frames and generates appropriate context updates
for the realtime conversation, including message updates and tool settings.
Args:
context: The OpenAI realtime LLM context.
**kwargs: Additional arguments passed to parent aggregator.
"""
async def process_frame(
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
):
"""Process incoming frames and handle realtime-specific frame types.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
# Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline,
# messages are only processed by the user context aggregator, which is generally what we want. But
@@ -157,6 +216,11 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
await self.push_frame(frame, direction)
async def push_aggregation(self):
"""Push user input aggregation.
Currently ignores all user input coming into the pipeline as realtime
audio input is handled directly by the service.
"""
# for the moment, ignore all user input coming into the pipeline.
# todo: think about whether/how to fix this to allow for text input from
# upstream (transport/transcription, or other sources)
@@ -164,6 +228,16 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
"""Assistant context aggregator for OpenAI Realtime API.
Handles assistant output frames from the realtime service, filtering
out duplicate text frames and managing function call results.
Args:
context: The OpenAI realtime LLM context.
**kwargs: Additional arguments passed to parent aggregator.
"""
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
# need to override this process_frame for LLMTextFrame, so that only the TTSTextFrames
@@ -171,10 +245,21 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
# OpenAIRealtimeLLMService also pushes TranscriptionFrames and InterimTranscriptionFrames,
# so we need to ignore pushing those as well, as they're also TextFrames.
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process assistant frames, filtering out duplicate text content.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame, InterimTranscriptionFrame)):
await super().process_frame(frame, direction)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
"""Handle function call result and notify the realtime service.
Args:
frame: The function call result frame to handle.
"""
await super().handle_function_call_result(frame)
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
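The assistant aggregator's `process_frame` avoids double-counting text by dropping frame types that carry the same content twice and delegating everything else to the parent. A self-contained sketch of that filter is below. All classes here are simplified stand-ins that only borrow the frame names from the code above; the real frame classes and aggregator have more behavior than this.

```python
import asyncio


class Frame:
    pass


class TextFrame(Frame):
    def __init__(self, text):
        self.text = text


# Stand-in subclasses; in this sketch they are all plain TextFrames.
class LLMTextFrame(TextFrame):
    pass


class TTSTextFrame(TextFrame):
    pass


class TranscriptionFrame(TextFrame):
    pass


class InterimTranscriptionFrame(TextFrame):
    pass


class BaseAggregator:
    """Aggregates every TextFrame it sees."""

    def __init__(self):
        self.aggregation = []

    async def process_frame(self, frame):
        if isinstance(frame, TextFrame):
            self.aggregation.append(frame.text)


class RealtimeAggregator(BaseAggregator):
    """Skips frame types that would duplicate text before delegating to the parent."""

    async def process_frame(self, frame):
        if not isinstance(frame, (LLMTextFrame, TranscriptionFrame, InterimTranscriptionFrame)):
            await super().process_frame(frame)


async def demo():
    aggregator = RealtimeAggregator()
    frames = (LLMTextFrame("Hello"), TTSTextFrame("Hello"), TranscriptionFrame("hi there"))
    for frame in frames:
        await aggregator.process_frame(frame)
    return aggregator.aggregation


result = asyncio.run(demo())
```

Only the `TTSTextFrame` reaches the parent, so the assistant text is aggregated once even though the service emits it in two frame types.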


@@ -3,13 +3,14 @@
#
# SPDX-License-Identifier: BSD 2-Clause License
#
#
"""Event models and data structures for OpenAI Realtime API communication."""
import json
import uuid
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
#
# session properties
@@ -19,7 +20,7 @@ from pydantic import BaseModel, Field
class InputAudioTranscription(BaseModel):
"""Configuration for audio transcription settings.
Attributes:
Parameters:
model: Transcription model to use (e.g., "gpt-4o-transcribe", "whisper-1").
language: Optional language code for transcription.
prompt: Optional transcription hint text.
@@ -36,13 +37,18 @@ class InputAudioTranscription(BaseModel):
prompt: Optional[str] = None,
):
super().__init__(model=model, language=language, prompt=prompt)
if self.model != "gpt-4o-transcribe" and (self.language or self.prompt):
raise ValueError(
"Fields 'language' and 'prompt' are only supported when model is 'gpt-4o-transcribe'"
)
class TurnDetection(BaseModel):
"""Server-side voice activity detection configuration.
Parameters:
type: Detection type, must be "server_vad".
threshold: Voice activity detection threshold (0.0-1.0). Defaults to 0.5.
prefix_padding_ms: Padding before speech starts in milliseconds. Defaults to 300.
silence_duration_ms: Silence duration to detect speech end in milliseconds. Defaults to 800.
"""
type: Optional[Literal["server_vad"]] = "server_vad"
threshold: Optional[float] = 0.5
prefix_padding_ms: Optional[int] = 300
@@ -50,6 +56,15 @@ class TurnDetection(BaseModel):
class SemanticTurnDetection(BaseModel):
"""Semantic-based turn detection configuration.
Parameters:
type: Detection type, must be "semantic_vad".
eagerness: Turn detection eagerness level. Can be "low", "medium", "high", or "auto".
create_response: Whether to automatically create responses on turn detection.
interrupt_response: Whether to interrupt ongoing responses on turn detection.
"""
type: Optional[Literal["semantic_vad"]] = "semantic_vad"
eagerness: Optional[Literal["low", "medium", "high", "auto"]] = None
create_response: Optional[bool] = None
@@ -57,10 +72,33 @@ class SemanticTurnDetection(BaseModel):
class InputAudioNoiseReduction(BaseModel):
"""Input audio noise reduction configuration.
Parameters:
type: Noise reduction type for different microphone scenarios.
"""
type: Optional[Literal["near_field", "far_field"]]
class SessionProperties(BaseModel):
"""Configuration properties for an OpenAI Realtime session.
Parameters:
modalities: Communication modalities to enable (text, audio, or both).
instructions: System instructions for the assistant.
voice: Voice ID for text-to-speech output.
input_audio_format: Format for input audio data.
output_audio_format: Format for output audio data.
input_audio_transcription: Configuration for input audio transcription.
input_audio_noise_reduction: Configuration for input audio noise reduction.
turn_detection: Turn detection configuration or False to disable.
tools: Available function tools for the assistant.
tool_choice: Tool usage strategy ("auto", "none", or "required").
temperature: Sampling temperature for response generation.
max_response_output_tokens: Maximum tokens in response or "inf" for unlimited.
"""
modalities: Optional[List[Literal["text", "audio"]]] = None
instructions: Optional[str] = None
voice: Optional[str] = None
@@ -84,6 +122,15 @@ class SessionProperties(BaseModel):
class ItemContent(BaseModel):
"""Content within a conversation item.
Parameters:
type: Content type (text, audio, input_text, or input_audio).
text: Text content for text-based items.
audio: Base64-encoded audio data for audio items.
transcript: Transcribed text for audio items.
"""
type: Literal["text", "audio", "input_text", "input_audio"]
text: Optional[str] = None
audio: Optional[str] = None # base64-encoded audio
@@ -91,6 +138,21 @@ class ItemContent(BaseModel):
class ConversationItem(BaseModel):
"""A conversation item in the realtime session.
Parameters:
id: Unique identifier for the item, auto-generated if not provided.
object: Object type identifier for the realtime API.
type: Item type (message, function_call, or function_call_output).
status: Current status of the item.
role: Speaker role for message items (user, assistant, or system).
content: Content list for message items.
call_id: Function call identifier for function_call items.
name: Function name for function_call items.
arguments: Function arguments as JSON string for function_call items.
output: Function output as JSON string for function_call_output items.
"""
id: str = Field(default_factory=lambda: str(uuid.uuid4().hex))
object: Optional[Literal["realtime.item"]] = None
type: Literal["message", "function_call", "function_call_output"]
@@ -106,11 +168,31 @@ class ConversationItem(BaseModel):
class RealtimeConversation(BaseModel):
"""A realtime conversation session.
Parameters:
id: Unique identifier for the conversation.
object: Object type identifier, always "realtime.conversation".
"""
id: str
object: Literal["realtime.conversation"]
class ResponseProperties(BaseModel):
"""Properties for configuring assistant responses.
Parameters:
modalities: Output modalities for the response. Defaults to ["audio", "text"].
instructions: Specific instructions for this response.
voice: Voice ID for text-to-speech in this response.
output_audio_format: Audio format for this response.
tools: Available tools for this response.
tool_choice: Tool usage strategy for this response.
temperature: Sampling temperature for this response.
max_response_output_tokens: Maximum tokens for this response.
"""
modalities: Optional[List[Literal["text", "audio"]]] = ["audio", "text"]
instructions: Optional[str] = None
voice: Optional[str] = None
@@ -125,6 +207,16 @@ class ResponseProperties(BaseModel):
# error class
#
class RealtimeError(BaseModel):
"""Error information from the realtime API.
Parameters:
type: Error type identifier.
code: Specific error code.
message: Human-readable error message.
param: Parameter name that caused the error, if applicable.
event_id: Event ID associated with the error, if applicable.
"""
type: str
code: Optional[str] = ""
message: str
@@ -138,14 +230,38 @@ class RealtimeError(BaseModel):
class ClientEvent(BaseModel):
"""Base class for client events sent to the realtime API.
Parameters:
event_id: Unique identifier for the event, auto-generated if not provided.
"""
event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
class SessionUpdateEvent(ClientEvent):
"""Event to update session properties.
Parameters:
type: Event type, always "session.update".
session: Updated session properties.
"""
type: Literal["session.update"] = "session.update"
session: SessionProperties
def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
"""Serialize the event to a dictionary.
Handles special serialization for turn_detection where False becomes null.
Args:
*args: Positional arguments passed to parent model_dump.
**kwargs: Keyword arguments passed to parent model_dump.
Returns:
Dictionary representation of the event.
"""
dump = super().model_dump(*args, **kwargs)
# Handle turn_detection so that False is serialized as null
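The `model_dump` docstring above says that `turn_detection=False` is serialized as null. The same rule can be illustrated on a plain dict, independent of pydantic. The function `dump_session_update` below is a hypothetical stand-in for the override, and its body is an assumption about how such a rewrite could look, not the library's code.

```python
import json


def dump_session_update(session):
    """Build a session.update payload, mapping turn_detection=False to None (hypothetical helper)."""
    dump = {"type": "session.update", "session": dict(session)}
    # `is False` keeps other falsy values (None, 0, {}) untouched.
    if dump["session"].get("turn_detection") is False:
        dump["session"]["turn_detection"] = None
    return dump


payload = json.dumps(
    dump_session_update({"instructions": "Be brief.", "turn_detection": False})
)
```

In JSON the Python `None` becomes `null`, which is how a caller can express "turn detection disabled" while still using `False` on the Python side.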
@@ -157,25 +273,61 @@ class SessionUpdateEvent(ClientEvent):
class InputAudioBufferAppendEvent(ClientEvent):
"""Event to append audio data to the input buffer.
Parameters:
type: Event type, always "input_audio_buffer.append".
audio: Base64-encoded audio data to append.
"""
type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append"
audio: str # base64-encoded audio
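Since the `audio` field is a base64-encoded string, raw audio bytes have to be encoded before they can be placed in an `input_audio_buffer.append` event. A rough sketch using plain dicts instead of the pydantic models follows. The helper `make_append_event` is hypothetical, and the four sample bytes stand in for real audio data.

```python
import base64
import json
import uuid


def make_append_event(audio_bytes):
    """Build an input_audio_buffer.append payload from raw bytes (hypothetical helper)."""
    return {
        # Mirrors the auto-generated event_id of ClientEvent.
        "event_id": str(uuid.uuid4()),
        "type": "input_audio_buffer.append",
        # b64encode returns bytes; decode to str so the payload is JSON-serializable.
        "audio": base64.b64encode(audio_bytes).decode("ascii"),
    }


event = make_append_event(b"\x00\x01\x02\x03")
wire = json.dumps(event)
```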
class InputAudioBufferCommitEvent(ClientEvent):
"""Event to commit the current input audio buffer.
Parameters:
type: Event type, always "input_audio_buffer.commit".
"""
type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
class InputAudioBufferClearEvent(ClientEvent):
"""Event to clear the input audio buffer.
Parameters:
type: Event type, always "input_audio_buffer.clear".
"""
type: Literal["input_audio_buffer.clear"] = "input_audio_buffer.clear"
class ConversationItemCreateEvent(ClientEvent):
"""Event to create a new conversation item.
Parameters:
type: Event type, always "conversation.item.create".
previous_item_id: ID of the item to insert after, if any.
item: The conversation item to create.
"""
type: Literal["conversation.item.create"] = "conversation.item.create"
previous_item_id: Optional[str] = None
item: ConversationItem
class ConversationItemTruncateEvent(ClientEvent):
"""Event to truncate a conversation item's audio content.
Parameters:
type: Event type, always "conversation.item.truncate".
item_id: ID of the item to truncate.
content_index: Index of the content to truncate within the item.
audio_end_ms: End time in milliseconds for the truncated audio.
"""
type: Literal["conversation.item.truncate"] = "conversation.item.truncate"
item_id: str
content_index: int
@@ -183,21 +335,48 @@ class ConversationItemTruncateEvent(ClientEvent):
class ConversationItemDeleteEvent(ClientEvent):
"""Event to delete a conversation item.
Parameters:
type: Event type, always "conversation.item.delete".
item_id: ID of the item to delete.
"""
type: Literal["conversation.item.delete"] = "conversation.item.delete"
item_id: str
class ConversationItemRetrieveEvent(ClientEvent):
"""Event to retrieve a conversation item by ID.
Parameters:
type: Event type, always "conversation.item.retrieve".
item_id: ID of the item to retrieve.
"""
type: Literal["conversation.item.retrieve"] = "conversation.item.retrieve"
item_id: str
class ResponseCreateEvent(ClientEvent):
"""Event to create a new assistant response.
Parameters:
type: Event type, always "response.create".
response: Optional response configuration properties.
"""
type: Literal["response.create"] = "response.create"
response: Optional[ResponseProperties] = None
class ResponseCancelEvent(ClientEvent):
"""Event to cancel the current assistant response.
Parameters:
type: Event type, always "response.cancel".
"""
type: Literal["response.cancel"] = "response.cancel"
@@ -207,35 +386,79 @@ class ResponseCancelEvent(ClientEvent):
class ServerEvent(BaseModel):
"""Base class for server events received from the realtime API.
Parameters:
event_id: Unique identifier for the event.
type: Type of the server event.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
event_id: str
type: str
class Config:
arbitrary_types_allowed = True
class SessionCreatedEvent(ServerEvent):
"""Event indicating a session has been created.
Parameters:
type: Event type, always "session.created".
session: The created session properties.
"""
type: Literal["session.created"]
session: SessionProperties
class SessionUpdatedEvent(ServerEvent):
"""Event indicating a session has been updated.
Parameters:
type: Event type, always "session.updated".
session: The updated session properties.
"""
type: Literal["session.updated"]
session: SessionProperties
class ConversationCreated(ServerEvent):
"""Event indicating a conversation has been created.
Parameters:
type: Event type, always "conversation.created".
conversation: The created conversation.
"""
type: Literal["conversation.created"]
conversation: RealtimeConversation
class ConversationItemCreated(ServerEvent):
"""Event indicating a conversation item has been created.
Parameters:
type: Event type, always "conversation.item.created".
previous_item_id: ID of the previous item, if any.
item: The created conversation item.
"""
type: Literal["conversation.item.created"]
previous_item_id: Optional[str] = None
item: ConversationItem
class ConversationItemInputAudioTranscriptionDelta(ServerEvent):
"""Event containing incremental input audio transcription.
Parameters:
type: Event type, always "conversation.item.input_audio_transcription.delta".
item_id: ID of the conversation item being transcribed.
content_index: Index of the content within the item.
delta: Incremental transcription text.
"""
type: Literal["conversation.item.input_audio_transcription.delta"]
item_id: str
content_index: int
@@ -243,6 +466,15 @@ class ConversationItemInputAudioTranscriptionDelta(ServerEvent):
class ConversationItemInputAudioTranscriptionCompleted(ServerEvent):
"""Event indicating input audio transcription is complete.
Parameters:
type: Event type, always "conversation.item.input_audio_transcription.completed".
item_id: ID of the conversation item that was transcribed.
content_index: Index of the content within the item.
transcript: Complete transcription text.
"""
type: Literal["conversation.item.input_audio_transcription.completed"]
item_id: str
content_index: int
@@ -250,6 +482,15 @@ class ConversationItemInputAudioTranscriptionCompleted(ServerEvent):
class ConversationItemInputAudioTranscriptionFailed(ServerEvent):
"""Event indicating input audio transcription failed.
Parameters:
type: Event type, always "conversation.item.input_audio_transcription.failed".
item_id: ID of the conversation item that failed transcription.
content_index: Index of the content within the item.
error: Error details for the transcription failure.
"""
type: Literal["conversation.item.input_audio_transcription.failed"]
item_id: str
content_index: int
@@ -257,6 +498,15 @@ class ConversationItemInputAudioTranscriptionFailed(ServerEvent):
class ConversationItemTruncated(ServerEvent):
"""Event indicating a conversation item has been truncated.
Parameters:
type: Event type, always "conversation.item.truncated".
item_id: ID of the truncated conversation item.
content_index: Index of the content within the item.
audio_end_ms: End time in milliseconds for the truncated audio.
"""
type: Literal["conversation.item.truncated"]
item_id: str
content_index: int
@@ -264,26 +514,63 @@ class ConversationItemTruncated(ServerEvent):
class ConversationItemDeleted(ServerEvent):
"""Event indicating a conversation item has been deleted.
Parameters:
type: Event type, always "conversation.item.deleted".
item_id: ID of the deleted conversation item.
"""
type: Literal["conversation.item.deleted"]
item_id: str
class ConversationItemRetrieved(ServerEvent):
"""Event containing a retrieved conversation item.
Parameters:
type: Event type, always "conversation.item.retrieved".
item: The retrieved conversation item.
"""
type: Literal["conversation.item.retrieved"]
item: ConversationItem
class ResponseCreated(ServerEvent):
"""Event indicating an assistant response has been created.
Parameters:
type: Event type, always "response.created".
response: The created response object.
"""
type: Literal["response.created"]
response: "Response"
class ResponseDone(ServerEvent):
"""Event indicating an assistant response is complete.
Parameters:
type: Event type, always "response.done".
response: The completed response object.
"""
type: Literal["response.done"]
response: "Response"
class ResponseOutputItemAdded(ServerEvent):
"""Event indicating an output item has been added to a response.
Parameters:
type: Event type, always "response.output_item.added".
response_id: ID of the response.
output_index: Index of the output item.
item: The added conversation item.
"""
type: Literal["response.output_item.added"]
response_id: str
output_index: int
@@ -291,6 +578,15 @@ class ResponseOutputItemAdded(ServerEvent):
class ResponseOutputItemDone(ServerEvent):
"""Event indicating an output item is complete.
Parameters:
type: Event type, always "response.output_item.done".
response_id: ID of the response.
output_index: Index of the output item.
item: The completed conversation item.
"""
type: Literal["response.output_item.done"]
response_id: str
output_index: int
@@ -298,6 +594,17 @@ class ResponseOutputItemDone(ServerEvent):
class ResponseContentPartAdded(ServerEvent):
"""Event indicating a content part has been added to a response.
Parameters:
type: Event type, always "response.content_part.added".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
content_index: Index of the content part.
part: The added content part.
"""
type: Literal["response.content_part.added"]
response_id: str
item_id: str
@@ -307,6 +614,17 @@ class ResponseContentPartAdded(ServerEvent):
class ResponseContentPartDone(ServerEvent):
"""Event indicating a content part is complete.
Parameters:
type: Event type, always "response.content_part.done".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
content_index: Index of the content part.
part: The completed content part.
"""
type: Literal["response.content_part.done"]
response_id: str
item_id: str
@@ -316,6 +634,17 @@ class ResponseContentPartDone(ServerEvent):
class ResponseTextDelta(ServerEvent):
"""Event containing incremental text from a response.
Parameters:
type: Event type, always "response.text.delta".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
content_index: Index of the content part.
delta: Incremental text content.
"""
type: Literal["response.text.delta"]
response_id: str
item_id: str
@@ -325,6 +654,17 @@ class ResponseTextDelta(ServerEvent):
class ResponseTextDone(ServerEvent):
"""Event indicating text content is complete.
Parameters:
type: Event type, always "response.text.done".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
content_index: Index of the content part.
text: Complete text content.
"""
type: Literal["response.text.done"]
response_id: str
item_id: str
@@ -334,6 +674,17 @@ class ResponseTextDone(ServerEvent):
class ResponseAudioTranscriptDelta(ServerEvent):
"""Event containing incremental audio transcript from a response.
Parameters:
type: Event type, always "response.audio_transcript.delta".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
content_index: Index of the content part.
delta: Incremental transcript text.
"""
type: Literal["response.audio_transcript.delta"]
response_id: str
item_id: str
@@ -343,6 +694,17 @@ class ResponseAudioTranscriptDelta(ServerEvent):
class ResponseAudioTranscriptDone(ServerEvent):
"""Event indicating audio transcript is complete.
Parameters:
type: Event type, always "response.audio_transcript.done".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
content_index: Index of the content part.
transcript: Complete transcript text.
"""
type: Literal["response.audio_transcript.done"]
response_id: str
item_id: str
@@ -352,6 +714,17 @@ class ResponseAudioTranscriptDone(ServerEvent):
class ResponseAudioDelta(ServerEvent):
"""Event containing incremental audio data from a response.
Parameters:
type: Event type, always "response.audio.delta".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
content_index: Index of the content part.
delta: Base64-encoded incremental audio data.
"""
type: Literal["response.audio.delta"]
response_id: str
item_id: str
@@ -361,6 +734,16 @@ class ResponseAudioDelta(ServerEvent):
class ResponseAudioDone(ServerEvent):
"""Event indicating audio content is complete.
Parameters:
type: Event type, always "response.audio.done".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
content_index: Index of the content part.
"""
type: Literal["response.audio.done"]
response_id: str
item_id: str
@@ -369,6 +752,17 @@ class ResponseAudioDone(ServerEvent):
class ResponseFunctionCallArgumentsDelta(ServerEvent):
"""Event containing incremental function call arguments.
Parameters:
type: Event type, always "response.function_call_arguments.delta".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
call_id: ID of the function call.
delta: Incremental function arguments as JSON.
"""
type: Literal["response.function_call_arguments.delta"]
response_id: str
item_id: str
@@ -378,6 +772,17 @@ class ResponseFunctionCallArgumentsDelta(ServerEvent):
class ResponseFunctionCallArgumentsDone(ServerEvent):
"""Event indicating function call arguments are complete.
Parameters:
type: Event type, always "response.function_call_arguments.done".
response_id: ID of the response.
item_id: ID of the conversation item.
output_index: Index of the output item.
call_id: ID of the function call.
arguments: Complete function arguments as JSON string.
"""
type: Literal["response.function_call_arguments.done"]
response_id: str
item_id: str
@@ -387,38 +792,90 @@ class ResponseFunctionCallArgumentsDone(ServerEvent):
class InputAudioBufferSpeechStarted(ServerEvent):
"""Event indicating speech has started in the input audio buffer.
Parameters:
type: Event type, always "input_audio_buffer.speech_started".
audio_start_ms: Start time of speech in milliseconds.
item_id: ID of the associated conversation item.
"""
type: Literal["input_audio_buffer.speech_started"]
audio_start_ms: int
item_id: str
class InputAudioBufferSpeechStopped(ServerEvent):
"""Event indicating speech has stopped in the input audio buffer.
Parameters:
type: Event type, always "input_audio_buffer.speech_stopped".
audio_end_ms: End time of speech in milliseconds.
item_id: ID of the associated conversation item.
"""
type: Literal["input_audio_buffer.speech_stopped"]
audio_end_ms: int
item_id: str
class InputAudioBufferCommitted(ServerEvent):
"""Event indicating the input audio buffer has been committed.
Parameters:
type: Event type, always "input_audio_buffer.committed".
previous_item_id: ID of the previous item, if any.
item_id: ID of the committed conversation item.
"""
type: Literal["input_audio_buffer.committed"]
previous_item_id: Optional[str] = None
item_id: str
class InputAudioBufferCleared(ServerEvent):
"""Event indicating the input audio buffer has been cleared.
Parameters:
type: Event type, always "input_audio_buffer.cleared".
"""
type: Literal["input_audio_buffer.cleared"]
class ErrorEvent(ServerEvent):
"""Event indicating an error occurred.
Parameters:
type: Event type, always "error".
error: Error details.
"""
type: Literal["error"]
error: RealtimeError
class RateLimitsUpdated(ServerEvent):
"""Event indicating rate limits have been updated.
Parameters:
type: Event type, always "rate_limits.updated".
rate_limits: List of rate limit information.
"""
type: Literal["rate_limits.updated"]
rate_limits: List[Dict[str, Any]]
class TokenDetails(BaseModel):
"""Detailed token usage information.
Parameters:
cached_tokens: Number of cached tokens used. Defaults to 0.
text_tokens: Number of text tokens used. Defaults to 0.
audio_tokens: Number of audio tokens used. Defaults to 0.
"""
cached_tokens: Optional[int] = 0
text_tokens: Optional[int] = 0
audio_tokens: Optional[int] = 0
@@ -428,6 +885,16 @@ class TokenDetails(BaseModel):
class Usage(BaseModel):
"""Token usage statistics for a response.
Parameters:
total_tokens: Total number of tokens used.
input_tokens: Number of input tokens used.
output_tokens: Number of output tokens used.
input_token_details: Detailed breakdown of input token usage.
output_token_details: Detailed breakdown of output token usage.
"""
total_tokens: int
input_tokens: int
output_tokens: int
@@ -436,6 +903,17 @@ class Usage(BaseModel):
class Response(BaseModel):
"""A complete assistant response.
Parameters:
id: Unique identifier for the response.
object: Object type, always "realtime.response".
status: Current status of the response.
status_details: Additional status information.
output: List of conversation items in the response.
usage: Token usage statistics for the response.
"""
id: str
object: Literal["realtime.response"]
status: Literal["completed", "in_progress", "incomplete", "cancelled", "failed"]
@@ -479,6 +957,17 @@ _server_event_types = {
def parse_server_event(str):
"""Parse a server event from JSON string.
Args:
str: JSON string containing the server event.
Returns:
Parsed server event object of the appropriate type.
Raises:
Exception: If the event type is unimplemented or parsing fails.
"""
try:
event = json.loads(str)
event_type = event["type"]
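The type-based dispatch that `parse_server_event` documents can be sketched with the standard library alone. This is a minimal illustration, not the code from this diff: the dataclasses below stand in for the pydantic models, and `EVENT_TYPES` and `parse_event` are hypothetical names.

```python
import json
from dataclasses import dataclass


@dataclass
class SpeechStarted:
    # Stand-in for InputAudioBufferSpeechStarted (a pydantic model in the diff).
    type: str
    audio_start_ms: int
    item_id: str


@dataclass
class BufferCleared:
    # Stand-in for InputAudioBufferCleared.
    type: str


# Hypothetical registry mapping the "type" field to the class that models it.
EVENT_TYPES = {
    "input_audio_buffer.speech_started": SpeechStarted,
    "input_audio_buffer.cleared": BufferCleared,
}


def parse_event(raw: str):
    """Parse a JSON string into the class registered for its "type" field."""
    data = json.loads(raw)
    event_type = data["type"]
    if event_type not in EVENT_TYPES:
        raise ValueError(f"Unimplemented server event type: {event_type}")
    # Fields in the payload must match the dataclass fields exactly in this sketch.
    return EVENT_TYPES[event_type](**data)
```

Keeping the mapping in one dictionary means a new event only needs a class and a registry entry; the parsing function itself does not change.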

View File

@@ -4,16 +4,34 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Custom frame types for OpenAI Realtime API integration."""
from dataclasses import dataclass
from typing import TYPE_CHECKING
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
if TYPE_CHECKING:
from pipecat.services.openai_realtime_beta.context import OpenAIRealtimeLLMContext
@dataclass
class RealtimeMessagesUpdateFrame(DataFrame):
"""Frame indicating that the realtime context messages have been updated.
Parameters:
context: The updated OpenAI realtime LLM context.
"""
context: "OpenAIRealtimeLLMContext"
@dataclass
class RealtimeFunctionCallResultFrame(DataFrame):
"""Frame containing function call results for the realtime service.
Parameters:
result_frame: The function call result frame to send to the realtime API.
"""
result_frame: FunctionCallResultFrame
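The `TYPE_CHECKING` guard used in this file is a common way to break an import cycle while keeping type annotations. A minimal sketch of the pattern follows; `my_package.context` and `MyContext` are hypothetical names, not part of pipecat.

```python
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so this
    # import cannot take part in an import cycle. The module is hypothetical.
    from my_package.context import MyContext


@dataclass
class MessagesUpdated:
    # The annotation is a string, so Python does not need MyContext to be
    # importable when the class is created.
    context: "MyContext"
```

At runtime `TYPE_CHECKING` is `False`, so the guarded import is skipped and the dataclass accepts any object for `context`; only the type checker resolves the name.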

View File

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenAI Realtime Beta LLM service implementation with WebSocket support."""
import base64
import json
import time
@@ -51,8 +53,9 @@ from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_openai_realtime, traced_stt, traced_tts
from pipecat.utils.tracing.service_decorators import traced_openai_realtime, traced_stt
from . import events
from .context import (
@@ -72,6 +75,15 @@ except ModuleNotFoundError as e:
@dataclass
class CurrentAudioResponse:
"""Tracks the current audio response from the assistant.
Parameters:
item_id: Unique identifier for the audio response item.
content_index: Index of the audio content within the item.
start_time_ms: Timestamp when the audio response started in milliseconds.
total_size: Total size of audio data received in bytes. Defaults to 0.
"""
item_id: str
content_index: int
start_time_ms: int
@@ -79,6 +91,24 @@ class CurrentAudioResponse:
class OpenAIRealtimeBetaLLMService(LLMService):
"""OpenAI Realtime Beta LLM service providing real-time audio and text communication.
Implements the OpenAI Realtime API Beta with WebSocket communication for low-latency
bidirectional audio and text interactions. Supports function calling, conversation
management, and real-time transcription.
Args:
api_key: OpenAI API key for authentication.
model: OpenAI model name. Defaults to "gpt-4o-realtime-preview-2025-06-03".
base_url: WebSocket base URL for the realtime API.
Defaults to "wss://api.openai.com/v1/realtime".
session_properties: Configuration properties for the realtime session.
If None, uses default SessionProperties.
start_audio_paused: Whether to start with audio input paused. Defaults to False.
send_transcription_frames: Whether to emit transcription frames. Defaults to True.
**kwargs: Additional arguments passed to parent LLMService.
"""
# Overriding the default adapter to use the OpenAIRealtimeLLMAdapter one.
adapter_class = OpenAIRealtimeLLMAdapter
@@ -86,7 +116,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
self,
*,
api_key: str,
model: str = "gpt-4o-realtime-preview-2024-12-17",
model: str = "gpt-4o-realtime-preview-2025-06-03",
base_url: str = "wss://api.openai.com/v1/realtime",
session_properties: Optional[events.SessionProperties] = None,
start_audio_paused: bool = False,
@@ -124,12 +154,30 @@ class OpenAIRealtimeBetaLLMService(LLMService):
self._retrieve_conversation_item_futures = {}
def can_generate_metrics(self) -> bool:
"""Check if the service can generate usage metrics.
Returns:
True if metrics generation is supported.
"""
return True
def set_audio_input_paused(self, paused: bool):
"""Set whether audio input is paused.
Args:
paused: True to pause audio input, False to resume.
"""
self._audio_input_paused = paused
async def retrieve_conversation_item(self, item_id: str):
"""Retrieve a conversation item by ID from the server.
Args:
item_id: The ID of the conversation item to retrieve.
Returns:
The retrieved conversation item.
"""
future = self.get_event_loop().create_future()
retrieval_in_flight = False
if not self._retrieve_conversation_item_futures.get(item_id):
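The hunk above is cut off, but the visible lines suggest one future per caller plus a check for a retrieval already in flight. A minimal sketch of that idea, assuming concurrent requests for the same item should share a single server round trip; `ItemRetriever`, `send_request` and `on_item_retrieved` are hypothetical names.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List


class ItemRetriever:
    """Hypothetical helper that coalesces concurrent lookups of one item ID."""

    def __init__(self, send_request: Callable[[str], Awaitable[None]]):
        self._send_request = send_request
        self._futures: Dict[str, List[asyncio.Future]] = {}

    async def retrieve(self, item_id: str) -> Any:
        future = asyncio.get_running_loop().create_future()
        # If another caller already asked for this item, just wait with it.
        in_flight = item_id in self._futures
        self._futures.setdefault(item_id, []).append(future)
        if not in_flight:
            await self._send_request(item_id)
        return await future

    def on_item_retrieved(self, item_id: str, item: Any) -> None:
        # Called from the receive loop: resolve every waiter for this item.
        for future in self._futures.pop(item_id, []):
            if not future.done():
                future.set_result(item)
```

An error path would mirror `on_item_retrieved` but call `set_exception` on each pending future instead.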
@@ -153,14 +201,29 @@ class OpenAIRealtimeBetaLLMService(LLMService):
#
async def start(self, frame: StartFrame):
"""Start the service and establish WebSocket connection.
Args:
frame: The start frame triggering service initialization.
"""
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
"""Stop the service and close WebSocket connection.
Args:
frame: The end frame triggering service shutdown.
"""
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
"""Cancel the service and close WebSocket connection.
Args:
frame: The cancel frame triggering service cancellation.
"""
await super().cancel(frame)
await self._disconnect()
@@ -246,6 +309,12 @@ class OpenAIRealtimeBetaLLMService(LLMService):
#
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process incoming frames from the pipeline.
Args:
frame: The frame to process.
direction: The direction of frame flow in the pipeline.
"""
await super().process_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
@@ -303,6 +372,11 @@ class OpenAIRealtimeBetaLLMService(LLMService):
#
async def send_client_event(self, event: events.ClientEvent):
"""Send a client event to the OpenAI Realtime API.
Args:
event: The client event to send.
"""
await self._ws_send(event.model_dump(exclude_none=True))
async def _connect(self):
@@ -369,7 +443,7 @@ class OpenAIRealtimeBetaLLMService(LLMService):
#
async def _receive_task_handler(self):
async for message in self._websocket:
async for message in WatchdogAsyncIterator(self._websocket, manager=self.task_manager):
evt = events.parse_server_event(message)
if evt.type == "session.created":
await self._handle_evt_session_created(evt)
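The internals of `WatchdogAsyncIterator` are not shown in this diff. The sketch below only illustrates the general idea of wrapping an async iterator so that a callback runs whenever an item arrives; `ResettingIterator` and its `reset` parameter are hypothetical and are not pipecat's API.

```python
import asyncio
from typing import Callable


class ResettingIterator:
    """Hypothetical stand-in: call `reset` each time the wrapped iterator yields."""

    def __init__(self, iterable, *, reset: Callable[[], None]):
        self._iterator = iterable.__aiter__()
        self._reset = reset

    def __aiter__(self):
        return self

    async def __anext__(self):
        # StopAsyncIteration from the wrapped iterator propagates unchanged.
        item = await self._iterator.__anext__()
        self._reset()
        return item


async def numbers():
    for i in range(3):
        yield i


async def consume():
    resets = 0

    def reset():
        nonlocal resets
        resets += 1

    items = [item async for item in ResettingIterator(numbers(), reset=reset)]
    return items, resets
```

Because the wrapper is itself an async iterator, the call site changes only in the `async for` header, which matches the one-line change in the hunk above.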
@@ -475,6 +549,11 @@ class OpenAIRealtimeBetaLLMService(LLMService):
pass
async def handle_evt_input_audio_transcription_completed(self, evt):
"""Handle completion of input audio transcription.
Args:
evt: The transcription completed event.
"""
await self._call_event_handler("on_conversation_item_updated", evt.item_id, None)
if self._send_transcription_frames:
@@ -555,7 +634,9 @@ class OpenAIRealtimeBetaLLMService(LLMService):
await self.push_frame(UserStoppedSpeakingFrame())
async def _maybe_handle_evt_retrieve_conversation_item_error(self, evt: events.ErrorEvent):
"""If the given error event is an error retrieving a conversation item:
"""Maybe handle an error event related to retrieving a conversation item.
If the given error event is an error retrieving a conversation item:
- set an exception on the future that retrieve_conversation_item() is waiting on
- return true
Otherwise:
@@ -602,8 +683,11 @@ class OpenAIRealtimeBetaLLMService(LLMService):
#
async def reset_conversation(self):
# Disconnect/reconnect is the safest way to start a new conversation.
# Note that this will fail if called from the receive task.
"""Reset the conversation by disconnecting and reconnecting.
This is the safest way to start a new conversation. Note that this will
fail if called from the receive task.
"""
logger.debug("Resetting conversation")
await self._disconnect()
if self._context:
@@ -651,22 +735,19 @@ class OpenAIRealtimeBetaLLMService(LLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> OpenAIContextAggregatorPair:
"""Create an instance of OpenAIContextAggregatorPair from an
OpenAILLMContext. Constructor keyword arguments for both the user and
assistant aggregators can be provided.
"""Create an instance of OpenAIContextAggregatorPair from an OpenAILLMContext.
Constructor keyword arguments for both the user and assistant aggregators can be provided.
Args:
context (OpenAILLMContext): The LLM context.
user_params (LLMUserAggregatorParams, optional): User aggregator
parameters.
assistant_params (LLMAssistantAggregatorParams, optional): User
aggregator parameters.
context: The LLM context.
user_params: User aggregator parameters.
assistant_params: Assistant aggregator parameters.
Returns:
OpenAIContextAggregatorPair: A pair of context aggregators, one for
the user and one for the assistant, encapsulated in an
OpenAIContextAggregatorPair.
"""
context.set_llm_adapter(self.get_llm_adapter())

View File

@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenPipe LLM service implementation for Pipecat.
This module provides an OpenPipe-specific implementation of the OpenAI LLM service,
enabling integration with OpenPipe's fine-tuning and monitoring capabilities.
"""
from typing import Dict, List, Optional
from loguru import logger
@@ -22,6 +28,22 @@ except ModuleNotFoundError as e:
class OpenPipeLLMService(OpenAILLMService):
"""OpenPipe-powered Large Language Model service.
Extends OpenAI's LLM service to integrate with OpenPipe's fine-tuning and
monitoring platform. Provides enhanced request logging and tagging capabilities
for model training and evaluation.
Args:
model: The model name to use. Defaults to "gpt-4.1".
api_key: OpenAI API key for authentication. If None, reads from environment.
base_url: Custom OpenAI API endpoint URL. Uses default if None.
openpipe_api_key: OpenPipe API key for enhanced features. If None, reads from environment.
openpipe_base_url: OpenPipe API endpoint URL. Defaults to "https://app.openpipe.ai/api/v1".
tags: Optional dictionary of tags to apply to all requests for tracking.
**kwargs: Additional arguments passed to parent OpenAILLMService.
"""
def __init__(
self,
*,
@@ -44,6 +66,16 @@ class OpenPipeLLMService(OpenAILLMService):
self._tags = tags
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create an OpenPipe client instance.
Args:
api_key: OpenAI API key for authentication.
base_url: OpenAI API base URL.
**kwargs: Additional arguments including openpipe_api_key and openpipe_base_url.
Returns:
Configured OpenPipe AsyncOpenAI client instance.
"""
openpipe_api_key = kwargs.get("openpipe_api_key") or ""
openpipe_base_url = kwargs.get("openpipe_base_url") or ""
client = OpenPipeAI(
@@ -56,6 +88,15 @@ class OpenPipeLLMService(OpenAILLMService):
async def get_chat_completions(
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
) -> AsyncStream[ChatCompletionChunk]:
"""Generate streaming chat completions with OpenPipe logging.
Args:
context: The OpenAI LLM context containing conversation state.
messages: List of chat completion message parameters.
Returns:
Async stream of chat completion chunks.
"""
chunks = await self._client.chat.completions.create(
model=self.model_name,
stream=True,

View File

@@ -4,6 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""OpenRouter LLM service implementation.
This module provides an OpenAI-compatible interface for interacting with OpenRouter's API,
extending the base OpenAI LLM service functionality.
"""
from typing import Optional
from loguru import logger
@@ -18,10 +24,11 @@ class OpenRouterLLMService(OpenAILLMService):
maintaining full compatibility with OpenAI's interface and functionality.
Args:
api_key (str): The API key for accessing OpenRouter's API
base_url (str, optional): The base URL for OpenRouter API. Defaults to "https://openrouter.ai/api/v1"
model (str, optional): The model identifier to use. Defaults to "openai/gpt-4o-2024-11-20"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing OpenRouter's API. If None, will attempt
to read from environment variables.
model: The model identifier to use. Defaults to "openai/gpt-4o-2024-11-20".
base_url: The base URL for OpenRouter API. Defaults to "https://openrouter.ai/api/v1".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -40,5 +47,15 @@ class OpenRouterLLMService(OpenAILLMService):
)
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create an OpenRouter API client.
Args:
api_key: The API key to use for authentication. If None, uses instance default.
base_url: The base URL for the API. If None, uses instance default.
**kwargs: Additional arguments passed to the parent client creation method.
Returns:
The configured OpenRouter API client instance.
"""
logger.debug(f"Creating OpenRouter client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)

View File

@@ -4,6 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Perplexity LLM service implementation.
This module provides a service for interacting with Perplexity's API using
an OpenAI-compatible interface. It handles Perplexity's unique token usage
reporting patterns while maintaining compatibility with the Pipecat framework.
"""
from typing import List
from openai import NOT_GIVEN, AsyncStream
@@ -22,10 +29,10 @@ class PerplexityLLMService(OpenAILLMService):
in token usage reporting between Perplexity (incremental) and OpenAI (final summary).
Args:
api_key (str): The API key for accessing Perplexity's API
base_url (str, optional): The base URL for Perplexity's API. Defaults to "https://api.perplexity.ai"
model (str, optional): The model identifier to use. Defaults to "sonar"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing Perplexity's API.
base_url: The base URL for Perplexity's API. Defaults to "https://api.perplexity.ai".
model: The model identifier to use. Defaults to "sonar".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -50,11 +57,11 @@ class PerplexityLLMService(OpenAILLMService):
"""Get chat completions from Perplexity API using OpenAI-compatible parameters.
Args:
context: The context containing conversation history and settings
messages: The messages to send to the API
context: The context containing conversation history and settings.
messages: The messages to send to the API.
Returns:
A stream of chat completion chunks
A stream of chat completion chunks from the Perplexity API.
"""
params = {
"model": self.model_name,
@@ -85,8 +92,8 @@ class PerplexityLLMService(OpenAILLMService):
and reporting them once at the end of processing.
Args:
context (OpenAILLMContext): The context to process, containing messages
and other information needed for the LLM interaction.
context: The context to process, containing messages and other
information needed for the LLM interaction.
"""
# Reset all counters and flags at the start of processing
self._prompt_tokens = 0
@@ -115,6 +122,9 @@ class PerplexityLLMService(OpenAILLMService):
Perplexity reports token usage incrementally during streaming,
unlike OpenAI which provides a final summary. We accumulate the
counts and report the total at the end of processing.
Args:
tokens: Token usage information to accumulate.
"""
if not self._is_processing:
return
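The accumulate-then-report approach described in the docstring can be sketched as follows. This is not the code from the diff: it assumes each streamed update carries running totals, so the largest value seen is kept, and `TokenUsage` and `UsageAccumulator` are hypothetical stand-ins.

```python
from dataclasses import dataclass


@dataclass
class TokenUsage:
    # Simplified stand-in for a usage record.
    prompt_tokens: int
    completion_tokens: int

    @property
    def total_tokens(self) -> int:
        return self.prompt_tokens + self.completion_tokens


class UsageAccumulator:
    def __init__(self) -> None:
        self._is_processing = False
        self._prompt_tokens = 0
        self._completion_tokens = 0

    def begin(self) -> None:
        # Reset counters at the start of each request.
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._is_processing = True

    def update(self, tokens: TokenUsage) -> None:
        # Ignore usage that arrives outside of a request.
        if not self._is_processing:
            return
        # Assumption: updates are running totals, so keep the maximum seen.
        self._prompt_tokens = max(self._prompt_tokens, tokens.prompt_tokens)
        self._completion_tokens = max(self._completion_tokens, tokens.completion_tokens)

    def finish(self) -> TokenUsage:
        # Report once, at the end of processing.
        self._is_processing = False
        return TokenUsage(self._prompt_tokens, self._completion_tokens)
```

If a provider instead sent per-chunk deltas, `update` would add the values rather than take the maximum.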

View File

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Qwen LLM service implementation using OpenAI-compatible interface."""
from loguru import logger
from pipecat.services.openai.llm import OpenAILLMService
@@ -16,10 +18,10 @@ class QwenLLMService(OpenAILLMService):
maintaining full compatibility with OpenAI's interface and functionality.
Args:
api_key (str): The API key for accessing Qwen's API (DashScope API key)
base_url (str, optional): Base URL for Qwen API. Defaults to "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
model (str, optional): The model identifier to use. Defaults to "qwen-plus".
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing Qwen's API (DashScope API key).
base_url: Base URL for Qwen API. Defaults to "https://dashscope-intl.aliyuncs.com/compatible-mode/v1".
model: The model identifier to use. Defaults to "qwen-plus".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -34,6 +36,15 @@ class QwenLLMService(OpenAILLMService):
logger.info(f"Initialized Qwen LLM service with model: {model}")
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create OpenAI-compatible client for Qwen API endpoint."""
"""Create OpenAI-compatible client for Qwen API endpoint.
Args:
api_key: API key for authentication. If None, uses instance default.
base_url: Base URL for the API. If None, uses instance default.
**kwargs: Additional arguments passed to the parent client creation.
Returns:
An OpenAI-compatible client configured for Qwen's API.
"""
logger.debug(f"Creating Qwen client with base URL: {base_url}")
return super().create_client(api_key, base_url, **kwargs)

View File

@@ -21,6 +21,7 @@ from pipecat.frames.frames import (
)
from pipecat.services.stt_service import SegmentedSTTService, STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
@@ -198,7 +199,7 @@ class RivaSTTService(STTService):
self._thread_task = self.create_task(self._thread_task_handler())
if not self._response_task:
self._response_queue = asyncio.Queue()
self._response_queue = WatchdogQueue(self.task_manager)
self._response_task = self.create_task(self._response_task_handler())
async def stop(self, frame: EndFrame):
@@ -224,6 +225,7 @@ class RivaSTTService(STTService):
streaming_config=self._config,
)
for response in responses:
self.reset_watchdog()
if not response.results:
continue
asyncio.run_coroutine_threadsafe(
@@ -284,6 +286,7 @@ class RivaSTTService(STTService):
while True:
response = await self._response_queue.get()
await self._handle_response(response)
self._response_queue.task_done()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
await self.start_ttfb_metrics()
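The hunks above hand responses from a blocking thread to an asyncio queue and mark each one done after handling it. A minimal, self-contained sketch of that hand-off using a plain `asyncio.Queue` (the diff uses pipecat's `WatchdogQueue`, whose internals are not shown here); `collect_from_thread` is a hypothetical name.

```python
import asyncio
import threading


async def collect_from_thread(items):
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel marking the end of the blocking iterator

    def worker():
        # Runs in a plain thread, like a blocking streaming-response iterator.
        for item in items:
            asyncio.run_coroutine_threadsafe(queue.put(item), loop)
        asyncio.run_coroutine_threadsafe(queue.put(done), loop)

    thread = threading.Thread(target=worker)
    thread.start()

    results = []
    while True:
        item = await queue.get()
        # Pair every get() with task_done() so queue.join() could be used.
        queue.task_done()
        if item is done:
            break
        results.append(item)
    thread.join()
    return results
```

`asyncio.run_coroutine_threadsafe` is the thread-safe way to schedule a coroutine on a loop owned by another thread; calling `queue.put_nowait` directly from the worker thread would not be safe.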

View File

@@ -0,0 +1,8 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from .llm import *
from .stt import *

View File

@@ -0,0 +1,210 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""SambaNova LLM service implementation using OpenAI-compatible interface."""
import json
from typing import Any, Dict, List, Optional
from loguru import logger
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from pipecat.frames.frames import (
LLMTextFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.llm_service import FunctionCallFromLLM
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
from pipecat.utils.tracing.service_decorators import traced_llm
class SambaNovaLLMService(OpenAILLMService): # type: ignore
"""A service for interacting with SambaNova using the OpenAI-compatible interface.
This service extends OpenAILLMService to connect to SambaNova's API endpoint while
maintaining full compatibility with OpenAI's interface and functionality.
Args:
api_key: The API key for accessing SambaNova API.
model: The model identifier to use. Defaults to "Llama-4-Maverick-17B-128E-Instruct".
base_url: The base URL for SambaNova API. Defaults to "https://api.sambanova.ai/v1".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
self,
*,
api_key: str,
model: str = "Llama-4-Maverick-17B-128E-Instruct",
base_url: str = "https://api.sambanova.ai/v1",
**kwargs: Dict[Any, Any],
) -> None:
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
def create_client(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs: Dict[Any, Any],
) -> Any:
"""Create OpenAI-compatible client for SambaNova API endpoint.
Args:
api_key: API key for authentication. If None, uses instance default.
base_url: Base URL for the API endpoint. If None, uses instance default.
**kwargs: Additional keyword arguments for client configuration.
Returns:
Configured OpenAI-compatible client instance.
"""
logger.debug(f"Creating SambaNova client with API {base_url}")
return super().create_client(api_key, base_url, **kwargs)
async def get_chat_completions(
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
) -> Any:
"""Get chat completions from SambaNova API endpoint.
Args:
context: OpenAI LLM context containing tools and configuration.
messages: List of chat completion message parameters.
Returns:
Chat completion response stream from SambaNova API.
"""
params = {
"model": self.model_name,
"stream": True,
"messages": messages,
"tools": context.tools,
"tool_choice": context.tool_choice,
"stream_options": {"include_usage": True},
"temperature": self._settings["temperature"],
"top_p": self._settings["top_p"],
"max_tokens": self._settings["max_tokens"],
"max_completion_tokens": self._settings["max_completion_tokens"],
}
params.update(self._settings["extra"])
chunks = await self._client.chat.completions.create(**params)
return chunks
@traced_llm # type: ignore
async def _process_context(self, context: OpenAILLMContext) -> AsyncStream[ChatCompletionChunk]:
"""Process OpenAI LLM context and stream chat completion chunks.
This method handles the streaming response from SambaNova API, including
function call processing and text frame generation. It includes special
handling for SambaNova's API limitations with tool call indexing.
Args:
context: OpenAI LLM context containing conversation state and tools.
Returns:
Async stream of chat completion chunks.
"""
functions_list = []
arguments_list = []
tool_id_list = []
func_idx = 0
function_name = ""
arguments = ""
tool_call_id = ""
await self.start_ttfb_metrics()
chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
context
)
async for chunk in WatchdogAsyncIterator(chunk_stream, manager=self.task_manager):
if chunk.usage:
tokens = LLMTokenUsage(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
)
await self.start_llm_usage_metrics(tokens)
if chunk.choices is None or len(chunk.choices) == 0:
continue
await self.stop_ttfb_metrics()
if not chunk.choices[0].delta:
continue
if chunk.choices[0].delta.tool_calls:
# We're streaming the LLM response to enable the fastest response times.
# For text, we just yield each chunk as we receive it and count on consumers
# to do whatever coalescing they need (e.g. to pass full sentences to TTS).
#
# If the LLM response is a function call, we'll do some coalescing here.
# If the response contains a function name, we'll yield a frame to tell consumers
# that they can start preparing to call the function with that name.
# We accumulate all the arguments for the rest of the streamed response, then when
# the response is done, we package up all the arguments and the function name and
# yield a frame containing the function name and the arguments.
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != func_idx:
functions_list.append(function_name)
arguments_list.append(arguments)
tool_id_list.append(tool_call_id)
function_name = ""
arguments = ""
tool_call_id = ""
func_idx += 1
if tool_call.function and tool_call.function.name:
function_name += tool_call.function.name
tool_call_id = tool_call.id # type: ignore
if tool_call.function and tool_call.function.arguments:
# Keep iterating through the response to collect all the argument fragments
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
# we need to get LLMTextFrame for the transcript
elif hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio.get(
"transcript"
):
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio["transcript"]))
# if we got a function name and arguments, check to see if it's a function with
# a registered handler. If so, run the registered callback, save the result to
# the context, and re-prompt to get a chat answer. If we don't have a registered
# handler, raise an exception.
if function_name and arguments:
# Append the last function name and arguments, which the loop above has not yet added.
functions_list.append(function_name)
arguments_list.append(arguments)
tool_id_list.append(tool_call_id)
function_calls = []
for function_name, arguments, tool_id in zip(
functions_list, arguments_list, tool_id_list
):
# This allows compatibility until SambaNova API introduces indexing in tool calls.
if len(arguments) < 1:
continue
arguments = json.loads(arguments)
function_calls.append(
FunctionCallFromLLM(
context=context,
tool_call_id=tool_id,
function_name=function_name,
arguments=arguments,
)
)
await self.run_function_calls(function_calls)
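The coalescing described in the comments above, where streamed fragments are gathered into complete function calls, can be sketched independently of any SDK. This is a simplified illustration rather than the logic in the diff: `ToolCallDelta` and `coalesce_tool_calls` are hypothetical, and fragments are assumed to arrive grouped by index.

```python
import json
from typing import Iterable, List, NamedTuple, Optional


class ToolCallDelta(NamedTuple):
    # Simplified stand-in for one streamed tool-call fragment.
    index: int
    id: Optional[str] = None
    name: Optional[str] = None
    arguments: Optional[str] = None


def coalesce_tool_calls(deltas: Iterable[ToolCallDelta]) -> List[dict]:
    calls: List[dict] = []
    current: Optional[dict] = None
    for delta in deltas:
        # A new index means the previous call is complete and a new one starts.
        if current is None or delta.index != current["index"]:
            current = {"index": delta.index, "id": "", "name": "", "arguments": ""}
            calls.append(current)
        if delta.name:
            current["name"] += delta.name
            current["id"] = delta.id or current["id"]
        if delta.arguments:
            # Argument JSON arrives in pieces; concatenate before parsing.
            current["arguments"] += delta.arguments
    return [
        {"id": c["id"], "name": c["name"], "arguments": json.loads(c["arguments"])}
        for c in calls
        if c["name"] and c["arguments"]  # skip entries with nothing to parse
    ]
```

Parsing the JSON only after the stream ends matters because an individual fragment is usually not valid JSON on its own.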

View File

@@ -0,0 +1,65 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Any, Optional
from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
from pipecat.transcriptions.language import Language
class SambaNovaSTTService(BaseWhisperSTTService): # type: ignore
"""SambaNova Whisper speech-to-text service.
Uses SambaNova's Whisper API to convert audio to text.
Requires a SambaNova API key set via the api_key parameter or SAMBANOVA_API_KEY environment variable.
Args:
model: Whisper model to use. Defaults to "Whisper-Large-v3".
api_key: SambaNova API key. Defaults to None.
base_url: API base URL. Defaults to "https://api.sambanova.ai/v1".
language: Language of the audio input. Defaults to English.
prompt: Optional text to guide the model's style or continue a previous segment.
temperature: Optional sampling temperature between 0 and 1. Defaults to None, in which case it is omitted from the request.
**kwargs: Additional arguments passed to `pipecat.services.whisper.base_stt.BaseWhisperSTTService`.
"""
def __init__(
self,
*,
model: str = "Whisper-Large-v3",
api_key: Optional[str] = None,
base_url: str = "https://api.sambanova.ai/v1",
language: Optional[Language] = Language.EN,
prompt: Optional[str] = None,
temperature: Optional[float] = None,
**kwargs: Any,
) -> None:
super().__init__(
model=model,
api_key=api_key,
base_url=base_url,
language=language,
prompt=prompt,
temperature=temperature,
**kwargs,
)
async def _transcribe(self, audio: bytes) -> Transcription:
assert self._language is not None # Assigned in the BaseWhisperSTTService class
# Build kwargs dict with only set parameters
kwargs = {
"file": ("audio.wav", audio, "audio/wav"),
"model": self.model_name,
"response_format": "json",
"language": self._language,
}
if self._prompt is not None:
kwargs["prompt"] = self._prompt
if self._temperature is not None:
kwargs["temperature"] = self._temperature
return await self._client.audio.transcriptions.create(**kwargs)
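The "only set parameters" pattern in `_transcribe` relies on `is not None` rather than truthiness, which matters for a value such as `temperature=0.0`. A small sketch with a hypothetical helper name, `build_transcription_kwargs`:

```python
from typing import Any, Dict, Optional


def build_transcription_kwargs(
    audio: bytes,
    model: str,
    language: str,
    prompt: Optional[str] = None,
    temperature: Optional[float] = None,
) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = {
        "file": ("audio.wav", audio, "audio/wav"),
        "model": model,
        "response_format": "json",
        "language": language,
    }
    # `is not None` keeps falsy but meaningful values such as 0.0 or "".
    if prompt is not None:
        kwargs["prompt"] = prompt
    if temperature is not None:
        kwargs["temperature"] = temperature
    return kwargs
```

With a plain `if temperature:` check, an explicit `0.0` would be dropped silently and the server default would apply instead.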

View File

@@ -18,6 +18,7 @@ from pipecat.frames.frames import (
TTSAudioRawFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, StartFrame
from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
try:
from av.audio.frame import AudioFrame
@@ -61,7 +62,8 @@ class SimliVideoService(FrameProcessor):
async def _consume_and_process_audio(self):
await self._pipecat_resampler_event.wait()
async for audio_frame in self._simli_client.getAudioStreamIterator():
audio_iterator = self._simli_client.getAudioStreamIterator()
async for audio_frame in WatchdogAsyncIterator(audio_iterator, manager=self.task_manager):
resampled_frames = self._pipecat_resampler.resample(audio_frame)
for resampled_frame in resampled_frames:
audio_array = resampled_frame.to_ndarray()
@@ -77,7 +79,8 @@ class SimliVideoService(FrameProcessor):
async def _consume_and_process_video(self):
await self._pipecat_resampler_event.wait()
async for video_frame in self._simli_client.getVideoStreamIterator(targetFormat="rgb24"):
video_iterator = self._simli_client.getVideoStreamIterator(targetFormat="rgb24")
async for video_frame in WatchdogAsyncIterator(video_iterator, manager=self.task_manager):
# Process the video frame
convertedFrame: OutputImageRawFrame = OutputImageRawFrame(
image=video_frame.to_rgb().to_image().tobytes(),

View File

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base classes for Speech-to-Text services with continuous and segmented processing."""
import io
import wave
from abc import abstractmethod
@@ -26,7 +28,19 @@ from pipecat.transcriptions.language import Language
class STTService(AIService):
"""STTService is a base class for speech-to-text services."""
"""Base class for speech-to-text services.
Provides common functionality for STT services including audio passthrough,
muting, settings management, and audio processing. Subclasses must implement
the run_stt method to provide actual speech recognition.
Args:
audio_passthrough: Whether to pass audio frames downstream after processing.
Defaults to True.
sample_rate: The sample rate for audio input. If None, will be determined
from the start frame.
**kwargs: Additional arguments passed to the parent AIService.
"""
def __init__(
self,
@@ -44,25 +58,59 @@ class STTService(AIService):
@property
def is_muted(self) -> bool:
"""Returns whether the STT service is currently muted."""
"""Check if the STT service is currently muted.
Returns:
True if the service is muted and will not process audio.
"""
return self._muted
@property
def sample_rate(self) -> int:
"""Get the current sample rate for audio processing.
Returns:
The sample rate in Hz.
"""
return self._sample_rate
async def set_model(self, model: str):
"""Set the speech recognition model.
Args:
model: The name of the model to use for speech recognition.
"""
self.set_model_name(model)
async def set_language(self, language: Language):
"""Set the language for speech recognition.
Args:
language: The language to use for speech recognition.
"""
pass
@abstractmethod
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Returns transcript as a string"""
"""Run speech-to-text on the provided audio data.
This method must be implemented by subclasses to provide actual speech
recognition functionality.
Args:
audio: Raw audio bytes to transcribe.
Yields:
Frame: Frames containing transcription results (typically TextFrame).
"""
pass
async def start(self, frame: StartFrame):
"""Start the STT service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._sample_rate = self._init_sample_rate or frame.audio_in_sample_rate
@@ -80,13 +128,24 @@ class STTService(AIService):
logger.warning(f"Unknown setting for STT service: {key}")
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
"""Process an audio frame for speech recognition.
Args:
frame: The audio frame to process.
direction: The direction of frame processing.
"""
if self._muted:
return
await self.process_generator(self.run_stt(frame.audio))
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Processes a frame of audio data, either buffering or transcribing it."""
"""Process frames, handling VAD events and audio segmentation.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, AudioRawFrame):
@@ -106,14 +165,19 @@ class STTService(AIService):
class SegmentedSTTService(STTService):
"""SegmentedSTTService is an STTService that uses VAD events to detect
speech and will run speech-to-text on speech segments only, instead of a
continous stream. Since it uses VAD it means that VAD needs to be enabled in
the pipeline.
"""STT service that processes speech in segments using VAD events.
This service always keeps a small audio buffer to take into account that VAD
events are delayed from when the user speech really starts.
Uses Voice Activity Detection (VAD) events to detect speech segments and runs
speech-to-text only on those segments, rather than continuously.
Requires VAD to be enabled in the pipeline to function properly. Maintains a
small audio buffer to account for the delay between actual speech start and
VAD detection.
Args:
sample_rate: The sample rate for audio input. If None, will be determined
from the start frame.
**kwargs: Additional arguments passed to the parent STTService.
"""
def __init__(self, *, sample_rate: Optional[int] = None, **kwargs):
@@ -125,10 +189,16 @@ class SegmentedSTTService(STTService):
self._user_speaking = False
async def start(self, frame: StartFrame):
"""Start the segmented STT service and initialize audio buffer.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._audio_buffer_size_1s = self.sample_rate * 2
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames, handling VAD events and audio segmentation."""
await super().process_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
@@ -162,6 +232,15 @@ class SegmentedSTTService(STTService):
self._audio_buffer.clear()
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
"""Process audio frames by buffering them for segmented transcription.
Continuously buffers audio, growing the buffer while user is speaking and
maintaining a small buffer when not speaking to account for VAD delay.
Args:
frame: The audio frame to process.
direction: The direction of frame processing.
"""
# If the user is speaking the audio buffer will keep growing.
self._audio_buffer += frame.audio
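The buffering described in the docstring above can be sketched in isolation. This is a minimal illustration, not the service's actual code: `SegmentBuffer` and its attribute names are hypothetical, and the trimming rule is an assumption based on the `sample_rate * 2` buffer size (one second of 16-bit mono audio) set in `start()`.

```python
class SegmentBuffer:
    """Hypothetical stand-in for the buffering done by a segmented STT service."""

    def __init__(self, sample_rate: int):
        # 16-bit mono PCM uses 2 bytes per sample, so this is one second of audio.
        self.max_idle_bytes = sample_rate * 2
        self.user_speaking = False
        self.audio = bytearray()

    def append(self, chunk: bytes) -> None:
        # The buffer always grows first, whether or not the user is speaking.
        self.audio += chunk
        # Assumption: while the user is silent, keep only the most recent second
        # so that speech preceding the delayed VAD event is still available.
        if not self.user_speaking and len(self.audio) > self.max_idle_bytes:
            del self.audio[: len(self.audio) - self.max_idle_bytes]


buf = SegmentBuffer(sample_rate=16000)
buf.append(b"\x00" * 48000)  # 1.5 s of audio while idle
idle_len = len(buf.audio)  # trimmed back to one second
buf.user_speaking = True
buf.append(b"\x01" * 32000)  # 1 s of speech; nothing is discarded now
speaking_len = len(buf.audio)
```

Keeping a short tail while idle means the start of an utterance is not lost when the VAD event arrives late.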


@@ -27,6 +27,7 @@ from pipecat.frames.frames import (
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
from pipecat.services.ai_service import AIService
from pipecat.transports.services.tavus import TavusCallbacks, TavusParams, TavusTransportClient
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
class TavusVideoService(AIService):
@@ -71,7 +72,6 @@ class TavusVideoService(AIService):
self._resampler = create_default_resampler()
self._audio_buffer = bytearray()
self._queue = asyncio.Queue()
self._send_task: Optional[asyncio.Task] = None
# This is the custom track destination expected by Tavus
self._transport_destination: Optional[str] = "stream"
@@ -188,7 +188,7 @@ class TavusVideoService(AIService):
async def _create_send_task(self):
if not self._send_task:
self._queue = asyncio.Queue()
self._queue = WatchdogQueue(self.task_manager)
self._send_task = self.create_task(self._send_task_handler())
async def _cancel_send_task(self):
@@ -217,5 +217,6 @@ class TavusVideoService(AIService):
async def _send_task_handler(self):
while True:
frame = await self._queue.get()
if isinstance(frame, OutputAudioRawFrame):
if isinstance(frame, OutputAudioRawFrame) and self._client:
await self._client.write_audio_frame(frame)
self._queue.task_done()
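The send task above follows a common queue-consumer pattern, sketched here with a plain `asyncio.Queue` instead of `WatchdogQueue`. The names `send_worker` and `fake_write` are hypothetical, and marking every item done regardless of whether it was written is an assumption about the intended behavior.

```python
import asyncio
from typing import Callable, List, Optional


def fake_write(frame: bytes) -> int:
    """Hypothetical client call: pretend to send audio, return bytes written."""
    return len(frame)


async def send_worker(
    queue: asyncio.Queue,
    client: Optional[Callable[[bytes], int]],
    sent: List[int],
) -> None:
    while True:
        item = await queue.get()
        # Only audio is written, and only when a client is connected.
        if isinstance(item, bytes) and client is not None:
            sent.append(client(item))
        # Assumption: every item is marked done so queue.join() can complete.
        queue.task_done()


async def demo() -> List[int]:
    queue: asyncio.Queue = asyncio.Queue()
    sent: List[int] = []
    worker = asyncio.create_task(send_worker(queue, fake_write, sent))
    for item in (b"ab", "not audio", b"cdef"):
        await queue.put(item)
    await queue.join()  # wait until the worker has handled everything
    worker.cancel()
    try:
        await worker
    except asyncio.CancelledError:
        pass
    return sent


result = asyncio.run(demo())
```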


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Together.ai LLM service implementation using OpenAI-compatible interface."""
from loguru import logger
from pipecat.services.openai.llm import OpenAILLMService
@@ -16,10 +18,10 @@ class TogetherLLMService(OpenAILLMService):
maintaining full compatibility with OpenAI's interface and functionality.
Args:
api_key (str): The API key for accessing Together.ai's API
base_url (str, optional): The base URL for Together.ai API. Defaults to "https://api.together.xyz/v1"
model (str, optional): The model identifier to use. Defaults to "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
**kwargs: Additional keyword arguments passed to OpenAILLMService
api_key: The API key for accessing Together.ai's API.
base_url: The base URL for Together.ai API. Defaults to "https://api.together.xyz/v1".
model: The model identifier to use. Defaults to "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo".
**kwargs: Additional keyword arguments passed to OpenAILLMService.
"""
def __init__(
@@ -33,6 +35,15 @@ class TogetherLLMService(OpenAILLMService):
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
def create_client(self, api_key=None, base_url=None, **kwargs):
"""Create OpenAI-compatible client for Together.ai API endpoint."""
"""Create OpenAI-compatible client for Together.ai API endpoint.
Args:
api_key: The API key to use for the client. If None, uses instance api_key.
base_url: The base URL for the API. If None, uses instance base_url.
**kwargs: Additional keyword arguments passed to the parent create_client method.
Returns:
An OpenAI-compatible client configured for Together.ai's API.
"""
logger.debug(f"Creating Together.ai client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base classes for Text-to-speech services."""
import asyncio
from abc import abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Tuple
@@ -35,6 +37,7 @@ from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.services.websocket_service import WebsocketService
from pipecat.transcriptions.language import Language
from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
from pipecat.utils.text.base_text_filter import BaseTextFilter
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
@@ -42,6 +45,28 @@ from pipecat.utils.time import seconds_to_nanoseconds
class TTSService(AIService):
"""Base class for text-to-speech services.
Provides common functionality for TTS services including text aggregation,
filtering, audio generation, and frame management. Supports configurable
sentence aggregation, silence insertion, and frame processing control.
Args:
aggregate_sentences: Whether to aggregate text into sentences before synthesis.
push_text_frames: Whether to push TextFrames and LLMFullResponseEndFrames.
push_stop_frames: Whether to automatically push TTSStoppedFrames.
stop_frame_timeout_s: Idle time before pushing TTSStoppedFrame when push_stop_frames is True.
push_silence_after_stop: Whether to push silence audio after TTSStoppedFrame.
silence_time_s: Duration of silence to push when push_silence_after_stop is True.
pause_frame_processing: Whether to pause frame processing during audio generation.
sample_rate: Output sample rate for generated audio.
text_aggregator: Custom text aggregator for processing incoming text.
text_filters: Sequence of text filters to apply after aggregation.
text_filter: Single text filter (deprecated, use text_filters).
transport_destination: Destination for generated audio frames.
**kwargs: Additional arguments passed to the parent AIService.
"""
def __init__(
self,
*,
@@ -104,54 +129,113 @@ class TTSService(AIService):
@property
def sample_rate(self) -> int:
"""Get the current sample rate for audio output.
Returns:
The sample rate in Hz.
"""
return self._sample_rate
@property
def chunk_size(self) -> int:
"""This property indicates how much audio we download (from TTS services
"""Get the recommended chunk size for audio streaming.
This property indicates how much audio we download (from TTS services
that require chunking) before we start pushing the first audio
frame. This will make sure we download the rest of the audio while audio
is being played without causing audio glitches (especially at the
beginning). Of course, this will also depend on how fast the TTS service
generates bytes.
Returns:
The recommended chunk size in bytes.
"""
CHUNK_SECONDS = 0.5
return int(self.sample_rate * CHUNK_SECONDS * 2) # 2 bytes/sample
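As a worked example of the arithmetic in `chunk_size`, the standalone functions below reproduce the same formula outside the class. The helper names are hypothetical, and the factor of 2 assumes 16-bit mono PCM, as the `# 2 bytes/sample` comment suggests.

```python
CHUNK_SECONDS = 0.5
BYTES_PER_SAMPLE = 2  # assumption: 16-bit mono PCM


def chunk_size(sample_rate: int) -> int:
    """Bytes of audio to download before the first frame is pushed."""
    return int(sample_rate * CHUNK_SECONDS * BYTES_PER_SAMPLE)


def silence_num_bytes(sample_rate: int, silence_time_s: float) -> int:
    """Bytes needed for a block of silence of the given duration."""
    return int(silence_time_s * sample_rate * BYTES_PER_SAMPLE)


sizes = {rate: chunk_size(rate) for rate in (16000, 24000)}
two_seconds = silence_num_bytes(24000, 2.0)
```

Half a second of audio at 2 bytes per sample works out to exactly `sample_rate` bytes, so the chunk grows linearly with the output sample rate.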
async def set_model(self, model: str):
"""Set the TTS model to use.
Args:
model: The name of the TTS model.
"""
self.set_model_name(model)
def set_voice(self, voice: str):
"""Set the voice for speech synthesis.
Args:
voice: The voice identifier or name.
"""
self._voice_id = voice
# Converts the text to audio.
@abstractmethod
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Run text-to-speech synthesis on the provided text.
This method must be implemented by subclasses to provide actual TTS functionality.
Args:
text: The text to synthesize into speech.
Yields:
Frame: Audio frames containing the synthesized speech.
"""
pass
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert a language to the service-specific language format.
Args:
language: The language to convert.
Returns:
The service-specific language identifier, or None if not supported.
"""
return Language(language)
async def update_setting(self, key: str, value: Any):
"""Update a service-specific setting.
Args:
key: The setting key to update.
value: The new value for the setting.
"""
pass
async def flush_audio(self):
"""Flush any buffered audio data."""
pass
async def start(self, frame: StartFrame):
"""Start the TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._sample_rate = self._init_sample_rate or frame.audio_out_sample_rate
if self._push_stop_frames and not self._stop_frame_task:
self._stop_frame_task = self.create_task(self._stop_frame_handler())
async def stop(self, frame: EndFrame):
"""Stop the TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
if self._stop_frame_task:
await self.cancel_task(self._stop_frame_task)
self._stop_frame_task = None
async def cancel(self, frame: CancelFrame):
"""Cancel the TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
if self._stop_frame_task:
await self.cancel_task(self._stop_frame_task)
@@ -175,9 +259,23 @@ class TTSService(AIService):
logger.warning(f"Unknown setting for TTS service: {key}")
async def say(self, text: str):
"""Immediately speak the provided text.
Args:
text: The text to speak.
"""
await self.queue_frame(TTSSpeakFrame(text))
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for text-to-speech conversion.
Handles TextFrames for synthesis, interruption frames, settings updates,
and various control frames.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if (
@@ -222,6 +320,12 @@ class TTSService(AIService):
await self.push_frame(frame, direction)
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Push a frame downstream with TTS-specific handling.
Args:
frame: The frame to push.
direction: The direction to push the frame.
"""
if self._push_silence_after_stop and isinstance(frame, TTSStoppedFrame):
silence_num_bytes = int(self._silence_time_s * self.sample_rate * 2) # 16-bit
silence_frame = TTSAudioRawFrame(
@@ -315,46 +419,78 @@ class TTSService(AIService):
if has_started:
await self.push_frame(TTSStoppedFrame())
has_started = False
finally:
self.reset_watchdog()
class WordTTSService(TTSService):
"""This is a base class for TTS services that support word timestamps. Word
timestamps are useful to synchronize audio with text of the spoken
"""Base class for TTS services that support word timestamps.
Word timestamps are useful to synchronize audio with text of the spoken
words. This way only the spoken words are added to the conversation context.
Args:
**kwargs: Additional arguments passed to the parent TTSService.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._initial_word_timestamp = -1
self._words_queue = asyncio.Queue()
self._words_task = None
self._llm_response_started: bool = False
def start_word_timestamps(self):
"""Start tracking word timestamps from the current time."""
if self._initial_word_timestamp == -1:
self._initial_word_timestamp = self.get_clock().get_time()
def reset_word_timestamps(self):
"""Reset word timestamp tracking."""
self._initial_word_timestamp = -1
async def add_word_timestamps(self, word_times: List[Tuple[str, float]]):
"""Add word timestamps to the processing queue.
Args:
word_times: List of (word, timestamp) tuples where timestamp is in seconds.
"""
for word, timestamp in word_times:
await self._words_queue.put((word, seconds_to_nanoseconds(timestamp)))
async def start(self, frame: StartFrame):
"""Start the word TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._create_words_task()
async def stop(self, frame: EndFrame):
"""Stop the word TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
await self._stop_words_task()
async def cancel(self, frame: CancelFrame):
"""Cancel the word TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._stop_words_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with word timestamp awareness.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, LLMFullResponseStartFrame):
@@ -369,6 +505,7 @@ class WordTTSService(TTSService):
def _create_words_task(self):
if not self._words_task:
self._words_queue = WatchdogQueue(self.task_manager)
self._words_task = self.create_task(self._words_task_handler())
async def _stop_words_task(self):
@@ -400,15 +537,24 @@ class WordTTSService(TTSService):
class WebsocketTTSService(TTSService, WebsocketService):
"""This is a base class for websocket-based TTS services.
"""Base class for websocket-based TTS services.
If an error occurs with the websocket, an "on_connection_error" event will
be triggered:
Combines TTS functionality with websocket connectivity, providing automatic
error handling and reconnection capabilities.
@tts.event_handler("on_connection_error")
async def on_connection_error(tts: TTSService, error: str):
...
Args:
reconnect_on_error: Whether to automatically reconnect on websocket errors.
**kwargs: Additional arguments passed to parent classes.
Event handlers:
on_connection_error: Called when a websocket connection error occurs.
Example:
```python
@tts.event_handler("on_connection_error")
async def on_connection_error(tts: TTSService, error: str):
logger.error(f"TTS connection error: {error}")
```
"""
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
@@ -422,10 +568,13 @@ class WebsocketTTSService(TTSService, WebsocketService):
class InterruptibleTTSService(WebsocketTTSService):
"""This is a base class for websocket-based TTS services that don't support
word timestamps and that don't offer a way to correlate the generated audio
to the requested text.
"""Websocket-based TTS service that handles interruptions without word timestamps.
Designed for TTS services that don't support word timestamps. Handles interruptions
by reconnecting the websocket when the bot is speaking and gets interrupted.
Args:
**kwargs: Additional arguments passed to the parent WebsocketTTSService.
"""
def __init__(self, **kwargs):
@@ -443,6 +592,12 @@ class InterruptibleTTSService(WebsocketTTSService):
await self._connect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with bot speaking state tracking.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, BotStartedSpeakingFrame):
@@ -452,16 +607,23 @@ class InterruptibleTTSService(WebsocketTTSService):
class WebsocketWordTTSService(WordTTSService, WebsocketService):
"""This is a base class for websocket-based TTS services that support word
timestamps.
"""Base class for websocket-based TTS services that support word timestamps.
If an error occurs with the websocket a "on_connection_error" event will be
triggered:
Combines word timestamp functionality with websocket connectivity.
@tts.event_handler("on_connection_error")
async def on_connection_error(tts: TTSService, error: str):
...
Args:
reconnect_on_error: Whether to automatically reconnect on websocket errors.
**kwargs: Additional arguments passed to parent classes.
Event handlers:
on_connection_error: Called when a websocket connection error occurs.
Example:
```python
@tts.event_handler("on_connection_error")
async def on_connection_error(tts: TTSService, error: str):
logger.error(f"TTS connection error: {error}")
```
"""
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
@@ -475,10 +637,13 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService):
class InterruptibleWordTTSService(WebsocketWordTTSService):
"""This is a base class for websocket-based TTS services that support word
timestamps but don't offer a way to correlate the generated audio to the
requested text.
"""Websocket-based TTS service with word timestamps that handles interruptions.
For TTS services that support word timestamps but can't correlate generated
audio with requested text. Handles interruptions by reconnecting when needed.
Args:
**kwargs: Additional arguments passed to the parent WebsocketWordTTSService.
"""
def __init__(self, **kwargs):
@@ -496,6 +661,12 @@ class InterruptibleWordTTSService(WebsocketWordTTSService):
await self._connect()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames with bot speaking state tracking.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, BotStartedSpeakingFrame):
@@ -505,7 +676,9 @@ class InterruptibleWordTTSService(WebsocketWordTTSService):
class AudioContextWordTTSService(WebsocketWordTTSService):
"""This is a base class for websocket-based TTS services that support word
"""Websocket-based TTS service with word timestamps and audio context management.
This is a base class for websocket-based TTS services that support word
timestamps and also allow correlating the generated audio with the requested
text.
@@ -517,22 +690,32 @@ class AudioContextWordTTSService(WebsocketWordTTSService):
we requested audio for a context "A" and then audio for context "B", the
audio from context ID "A" will be played first.
Args:
**kwargs: Additional arguments passed to the parent WebsocketWordTTSService.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._contexts_queue = asyncio.Queue()
self._contexts: Dict[str, asyncio.Queue] = {}
self._audio_context_task = None
async def create_audio_context(self, context_id: str):
"""Create a new audio context."""
"""Create a new audio context for grouping related audio.
Args:
context_id: Unique identifier for the audio context.
"""
await self._contexts_queue.put(context_id)
self._contexts[context_id] = asyncio.Queue()
logger.trace(f"{self} created audio context {context_id}")
async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame):
"""Append audio to an existing context."""
"""Append audio to an existing context.
Args:
context_id: The context to append audio to.
frame: The audio frame to append.
"""
if self.audio_context_available(context_id):
logger.trace(f"{self} appending audio {frame} to audio context {context_id}")
await self._contexts[context_id].put(frame)
@@ -540,7 +723,11 @@ class AudioContextWordTTSService(WebsocketWordTTSService):
logger.warning(f"{self} unable to append audio to context {context_id}")
async def remove_audio_context(self, context_id: str):
"""Remove an existing audio context."""
"""Remove an existing audio context.
Args:
context_id: The context to remove.
"""
if self.audio_context_available(context_id):
# We just mark the audio context for deletion by appending
# None. Once we reach None while handling audio we know we can
@@ -551,14 +738,31 @@ class AudioContextWordTTSService(WebsocketWordTTSService):
logger.warning(f"{self} unable to remove context {context_id}")
def audio_context_available(self, context_id: str) -> bool:
"""Checks whether the given audio context is registered."""
"""Check whether the given audio context is registered.
Args:
context_id: The context ID to check.
Returns:
True if the context exists and is available.
"""
return context_id in self._contexts
async def start(self, frame: StartFrame):
"""Start the audio context TTS service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._create_audio_context_task()
async def stop(self, frame: EndFrame):
"""Stop the audio context TTS service.
Args:
frame: The end frame.
"""
await super().stop(frame)
if self._audio_context_task:
# Indicate no more audio contexts are available. This will end the
@@ -568,6 +772,11 @@ class AudioContextWordTTSService(WebsocketWordTTSService):
self._audio_context_task = None
async def cancel(self, frame: CancelFrame):
"""Cancel the audio context TTS service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
await self._stop_audio_context_task()
@@ -578,7 +787,7 @@ class AudioContextWordTTSService(WebsocketWordTTSService):
def _create_audio_context_task(self):
if not self._audio_context_task:
self._contexts_queue = asyncio.Queue()
self._contexts_queue = WatchdogQueue(self.task_manager)
self._contexts: Dict[str, asyncio.Queue] = {}
self._audio_context_task = self.create_task(self._audio_context_task_handler())
@@ -620,10 +829,12 @@ class AudioContextWordTTSService(WebsocketWordTTSService):
while running:
try:
frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
self.reset_watchdog()
if frame:
await self.push_frame(frame)
running = frame is not None
except asyncio.TimeoutError:
self.reset_watchdog()
# We didn't get audio, so let's consider this context finished.
logger.trace(f"{self} time out on audio context {context_id}")
break
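The ordering guarantee of `AudioContextWordTTSService` (audio for context "A" plays before context "B" if "A" was requested first) can be illustrated with plain asyncio queues. This is a simplified sketch under assumptions: `play_contexts` is a hypothetical function, the timeout value is shortened for the demo, and watchdog handling is omitted.

```python
import asyncio
from typing import Dict, List, Optional

AUDIO_CONTEXT_TIMEOUT = 0.1  # hypothetical value, shortened for this demo


async def play_contexts(
    contexts_queue: asyncio.Queue,
    contexts: Dict[str, asyncio.Queue],
    out: List[bytes],
) -> None:
    while True:
        context_id: Optional[str] = await contexts_queue.get()
        if context_id is None:  # sentinel: no more contexts will arrive
            break
        queue = contexts[context_id]
        running = True
        while running:
            try:
                frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT)
                if frame:
                    out.append(frame)
                # A None frame marks the context as finished.
                running = frame is not None
            except asyncio.TimeoutError:
                # No audio arrived in time, so treat the context as finished.
                break
        del contexts[context_id]


async def demo() -> List[bytes]:
    contexts_queue: asyncio.Queue = asyncio.Queue()
    contexts: Dict[str, asyncio.Queue] = {}
    out: List[bytes] = []
    for context_id in ("A", "B"):
        await contexts_queue.put(context_id)
        contexts[context_id] = asyncio.Queue()
    # Audio for "B" arrives before audio for "A".
    await contexts["B"].put(b"b1")
    await contexts["A"].put(b"a1")
    await contexts["A"].put(None)
    await contexts["B"].put(None)
    await contexts_queue.put(None)
    await play_contexts(contexts_queue, contexts, out)
    return out


played = asyncio.run(demo())
```

Because contexts are drained in creation order, the audio for "A" is emitted first even though the audio for "B" was queued earlier.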


@@ -4,6 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Vision service implementation.
Provides base classes and implementations for computer vision services that can
analyze images and generate textual descriptions or answers to questions about
visual content.
"""
from abc import abstractmethod
from typing import AsyncGenerator
@@ -13,7 +20,15 @@ from pipecat.services.ai_service import AIService
class VisionService(AIService):
"""VisionService is a base class for vision services."""
"""Base class for vision services.
Provides common functionality for vision services that process images and
generate textual responses. Handles image frame processing and integrates
with the AI service infrastructure for metrics and lifecycle management.
Args:
**kwargs: Additional arguments passed to the parent AIService.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -21,9 +36,31 @@ class VisionService(AIService):
@abstractmethod
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
"""Process a vision image frame and generate results.
This method must be implemented by subclasses to provide actual computer
vision functionality such as image description, object detection, or
visual question answering.
Args:
frame: The vision image frame to process, containing image data.
Yields:
Frame: Frames containing the vision analysis results, typically TextFrame
objects with descriptions or answers.
"""
pass
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames, handling vision image frames for analysis.
Automatically processes VisionImageRawFrame objects by calling run_vision
and handles metrics tracking. Other frames are passed through unchanged.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, VisionImageRawFrame):


@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base websocket service with automatic reconnection and error handling."""
import asyncio
from abc import ABC, abstractmethod
from typing import Awaitable, Callable, Optional
@@ -17,18 +19,26 @@ from pipecat.utils.network import exponential_backoff_time
class WebsocketService(ABC):
"""Base class for websocket-based services with reconnection logic."""
"""Base class for websocket-based services with automatic reconnection.
Provides websocket connection management, automatic reconnection with
exponential backoff, connection verification, and error handling.
Subclasses implement service-specific connection and message handling logic.
Args:
reconnect_on_error: Whether to automatically reconnect on connection errors.
**kwargs: Additional arguments (unused, for compatibility).
"""
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
"""Initialize websocket attributes."""
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
self._reconnect_on_error = reconnect_on_error
async def _verify_connection(self) -> bool:
"""Verify websocket connection is working.
"""Verify the websocket connection is active and responsive.
Returns:
bool: True if connection is verified working, False otherwise
True if connection is verified working, False otherwise.
"""
try:
if not self._websocket or self._websocket.closed:
@@ -40,13 +50,13 @@ class WebsocketService(ABC):
return False
async def _reconnect_websocket(self, attempt_number: int) -> bool:
"""Reconnect the websocket.
"""Reconnect the websocket with the current attempt number.
Args:
attempt_number: Current retry attempt number
attempt_number: Current retry attempt number for logging.
Returns:
bool: True if reconnection and verification successful, False otherwise
True if reconnection and verification successful, False otherwise.
"""
logger.warning(f"{self} reconnecting (attempt: {attempt_number})")
await self._disconnect_websocket()
@@ -54,10 +64,14 @@ class WebsocketService(ABC):
return await self._verify_connection()
async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]):
"""Handles WebSocket message receiving with automatic retry logic.
"""Handle websocket message receiving with automatic retry logic.
Continuously receives messages with automatic reconnection on errors.
Uses exponential backoff between retry attempts and reports fatal errors
after maximum retries are exhausted.
Args:
report_error: Callback to report errors
report_error: Callback function to report connection errors.
"""
retry_count = 0
MAX_RETRIES = 3
@@ -98,33 +112,45 @@ class WebsocketService(ABC):
@abstractmethod
async def _connect(self):
"""Implement service-specific connection logic. This function will
connect to the websocket via _connect_websocket() among other connection
logic."""
"""Connect to the service.
Implement service-specific connection logic including websocket connection
via _connect_websocket() and any additional setup required.
"""
pass
@abstractmethod
async def _disconnect(self):
"""Implement service-specific disconnection logic. This function will
disconnect to the websocket via _connect_websocket() among other
connection logic.
"""Disconnect from the service.
Implement service-specific disconnection logic including websocket
disconnection via _disconnect_websocket() and any cleanup required.
"""
pass
@abstractmethod
async def _connect_websocket(self):
"""Implement service-specific websocket connection logic. This function
should only connect to the websocket."""
"""Establish the websocket connection.
Implement the low-level websocket connection logic specific to the service.
Should only handle websocket connection, not additional service setup.
"""
pass
@abstractmethod
async def _disconnect_websocket(self):
"""Implement service-specific websocket disconnection logic. This
function should only disconnect from the websocket."""
"""Close the websocket connection.
Implement the low-level websocket disconnection logic specific to the service.
Should only handle websocket disconnection, not additional service cleanup.
"""
pass
@abstractmethod
async def _receive_messages(self):
"""Implement service-specific message receiving logic."""
"""Receive and process websocket messages.
Implement service-specific logic for receiving and handling messages
from the websocket connection. Called continuously by the receive task handler.
"""
pass
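The retry behavior described for `_receive_task_handler` (reconnect with exponential backoff, report a fatal error once retries are exhausted) can be sketched without a real websocket. Everything here is illustrative: `backoff_seconds` is a hypothetical stand-in for `exponential_backoff_time`, whose real signature is not shown in this diff, and the loop is simplified to return after the first successful receive.

```python
import asyncio
from typing import Awaitable, Callable, List

MAX_RETRIES = 3


def backoff_seconds(attempt: int, min_wait: float = 0.01, max_wait: float = 0.1) -> float:
    """Hypothetical backoff: doubles per attempt, capped at max_wait."""
    return min(max_wait, min_wait * (2 ** (attempt - 1)))


async def receive_with_retry(
    receive: Callable[[], Awaitable[None]],
    reconnect: Callable[[int], Awaitable[bool]],
    report_error: Callable[[str], Awaitable[None]],
) -> None:
    retry_count = 0
    while True:
        try:
            await receive()
            return  # the demo stops after one successful receive
        except Exception as exc:
            retry_count += 1
            if retry_count >= MAX_RETRIES:
                await report_error(f"giving up after {retry_count} attempts: {exc}")
                return
            # Wait longer after each failure before trying to reconnect.
            await asyncio.sleep(backoff_seconds(retry_count))
            await reconnect(retry_count)


async def demo() -> List[str]:
    events: List[str] = []
    failures = iter([True, True, False])  # fail twice, then succeed

    async def receive() -> None:
        if next(failures):
            raise ConnectionError("socket closed")
        events.append("received")

    async def reconnect(attempt: int) -> bool:
        events.append(f"reconnect {attempt}")
        return True

    async def report_error(message: str) -> None:
        events.append(f"error: {message}")

    await receive_with_retry(receive, reconnect, report_error)
    return events


events = asyncio.run(demo())
```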


@@ -43,6 +43,8 @@ from pipecat.metrics.metrics import MetricsData
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.transports.base_transport import TransportParams
AUDIO_INPUT_TIMEOUT_SECS = 0.5
class BaseInputTransport(FrameProcessor):
def __init__(self, params: TransportParams, **kwargs):
@@ -56,6 +58,9 @@ class BaseInputTransport(FrameProcessor):
# Track bot speaking state for interruption logic
self._bot_speaking = False
# Track user speaking state for interruption logic
self._user_speaking = False
# We read audio from a single queue one at a time and we then run VAD in
# a thread. Therefore, only one thread should be necessary.
self._executor = ThreadPoolExecutor(max_workers=1)
@@ -130,6 +135,7 @@ class BaseInputTransport(FrameProcessor):
async def start(self, frame: StartFrame):
self._paused = False
self._user_speaking = False
self._sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate
@@ -240,6 +246,7 @@ class BaseInputTransport(FrameProcessor):
async def _handle_user_interruption(self, frame: Frame):
if isinstance(frame, UserStartedSpeakingFrame):
logger.debug("User started speaking")
self._user_speaking = True
await self.push_frame(frame)
# Only push StartInterruptionFrame if:
@@ -263,6 +270,7 @@ class BaseInputTransport(FrameProcessor):
)
elif isinstance(frame, UserStoppedSpeakingFrame):
logger.debug("User stopped speaking")
self._user_speaking = False
await self.push_frame(frame)
if self.interruptions_allowed:
await self._stop_interruption()
@@ -355,26 +363,40 @@ class BaseInputTransport(FrameProcessor):
async def _audio_task_handler(self):
vad_state: VADState = VADState.QUIET
while True:
frame: InputAudioRawFrame = await self._audio_in_queue.get()
try:
frame: InputAudioRawFrame = await asyncio.wait_for(
self._audio_in_queue.get(), timeout=AUDIO_INPUT_TIMEOUT_SECS
)
# If an audio filter is available, run it before VAD.
if self._params.audio_in_filter:
frame.audio = await self._params.audio_in_filter.filter(frame.audio)
# If an audio filter is available, run it before VAD.
if self._params.audio_in_filter:
frame.audio = await self._params.audio_in_filter.filter(frame.audio)
# Check VAD and push event if necessary. We just care about
# changes from QUIET to SPEAKING and vice versa.
previous_vad_state = vad_state
if self._params.vad_analyzer:
vad_state = await self._handle_vad(frame, vad_state)
# Check VAD and push event if necessary. We just care about
# changes from QUIET to SPEAKING and vice versa.
previous_vad_state = vad_state
if self._params.vad_analyzer:
vad_state = await self._handle_vad(frame, vad_state)
if self._params.turn_analyzer:
await self._run_turn_analyzer(frame, vad_state, previous_vad_state)
if self._params.turn_analyzer:
await self._run_turn_analyzer(frame, vad_state, previous_vad_state)
# Push audio downstream if passthrough is set.
if self._params.audio_in_passthrough:
await self.push_frame(frame)
# Push audio downstream if passthrough is set.
if self._params.audio_in_passthrough:
await self.push_frame(frame)
self._audio_in_queue.task_done()
self._audio_in_queue.task_done()
except asyncio.TimeoutError:
if self._user_speaking:
logger.warning(
"Forcing user stopped speaking due to timeout receiving audio frame!"
)
vad_state = VADState.QUIET
if self._params.turn_analyzer:
self._params.turn_analyzer.clear()
await self._handle_user_interruption(UserStoppedSpeakingFrame())
finally:
self.reset_watchdog()
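The timeout handling added to `_audio_task_handler` can be reduced to the following sketch: if no audio arrives within the timeout while the user is marked as speaking, a stop event is forced. The string frames, the `None` sentinel, and the `events` list are hypothetical simplifications, and the timeout is shortened from the 0.5 seconds used in the diff so the demo runs quickly.

```python
import asyncio
from typing import List

AUDIO_INPUT_TIMEOUT_SECS = 0.05  # shortened for the demo; the diff uses 0.5


async def audio_loop(queue: asyncio.Queue, events: List[str]) -> None:
    user_speaking = False
    while True:
        try:
            frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_INPUT_TIMEOUT_SECS)
            if frame is None:  # hypothetical sentinel to end the demo loop
                queue.task_done()
                return
            if frame == "speech" and not user_speaking:
                user_speaking = True
                events.append("started")
            queue.task_done()
        except asyncio.TimeoutError:
            # Audio stalled: if the user was speaking, force a stop event so
            # downstream processors are not left waiting indefinitely.
            if user_speaking:
                user_speaking = False
                events.append("forced stop")


async def demo() -> List[str]:
    queue: asyncio.Queue = asyncio.Queue()
    events: List[str] = []
    task = asyncio.create_task(audio_loop(queue, events))
    await queue.put("speech")
    await asyncio.sleep(AUDIO_INPUT_TIMEOUT_SECS * 4)  # simulate stalled input
    await queue.put(None)
    await task
    return events


events = asyncio.run(demo())
```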
async def _handle_prediction_result(self, result: MetricsData):
"""Handle a prediction result event from the turn analyzer.

Some files were not shown because too many files have changed in this diff