Compare commits
134 Commits
v0.0.76
...
jpt/runner
| Author | SHA1 | Date | |
|---|---|---|---|
| | 2b1f056aa7 | | |
| | 2be615066c | | |
| | 1bb821a07d | | |
| | d8bcb81f35 | | |
| | 3ce0ab8c6d | | |
| | 097d786431 | | |
| | 662f04879c | | |
| | 7a69f57e11 | | |
| | 5b7b4efdc9 | | |
| | cfa26524ca | | |
| | 3d4ab7158d | | |
| | 26d1ca3c98 | | |
| | 083b32887e | | |
| | 3391929127 | | |
| | ebf9bc2741 | | |
| | f5edde42f6 | | |
| | 37bb7ef926 | | |
| | a63d1530a4 | | |
| | 960bc9df5b | | |
| | e2a153ee01 | | |
| | 300f19ad23 | | |
| | 7955080da2 | | |
| | 994e82c1ef | | |
| | b07b947352 | | |
| | a6527c3856 | | |
| | 0e6874b605 | | |
| | 9ba172c49f | | |
| | f710c94b6e | | |
| | 6e3a0a2d5d | | |
| | 9530b8b842 | | |
| | 26c937af87 | | |
| | 976f6168f0 | | |
| | 0be64e0fd9 | | |
| | 7d527c3a6b | | |
| | c6f6930c27 | | |
| | c33dfe8309 | | |
| | 769cd1ef06 | | |
| | 6d72f60571 | | |
| | e8d0712ac1 | | |
| | 88b2c817ac | | |
| | f8f6c9918d | | |
| | 8ee608bbfe | | |
| | fad2ba4570 | | |
| | f609f7eb53 | | |
| | ea09813a2b | | |
| | 53abfc27a7 | | |
| | 9c72e96a2c | | |
| | f66c67c4ab | | |
| | b623face03 | | |
| | 698d60f3ae | | |
| | c9717a23a5 | | |
| | d981ce6e56 | | |
| | 1bbd3bd8ab | | |
| | a20915caa7 | | |
| | 28cab5a606 | | |
| | cfea56064d | | |
| | 8467d87cfc | | |
| | b20d020bea | | |
| | 948257c66e | | |
| | b54d1fb7fd | | |
| | ec361df0d1 | | |
| | b1a5cddde4 | | |
| | e165d38277 | | |
| | 8ba340a8a5 | | |
| | d4e33663b2 | | |
| | d7d1b16dad | | |
| | 0bc2ea13f2 | | |
| | b5d1301221 | | |
| | ed8f30ec71 | | |
| | a74a935ca0 | | |
| | 7cfd56699b | | |
| | cb984237a7 | | |
| | c969fdddb9 | | |
| | 9931ad2ce1 | | |
| | fd73feb645 | | |
| | ee78428a2a | | |
| | ae02249255 | | |
| | 727af2e6fb | | |
| | 8fd5576879 | | |
| | 1f85dcee7c | | |
| | 138890bc5c | | |
| | a094efc9e6 | | |
| | 1f9e2fdecc | | |
| | 4a2b4660bc | | |
| | b3ac90015a | | |
| | 2fe06f0a4e | | |
| | fe8573322f | | |
| | 5c3fb73cef | | |
| | 2e84c91748 | | |
| | 650d45c1f4 | | |
| | 61ac77be72 | | |
| | c093eb5b63 | | |
| | 98e24131bd | | |
| | 7becce9e8c | | |
| | 3cdaeb719a | | |
| | 8daaea5969 | | |
| | dc47516e14 | | |
| | 0f727248d2 | | |
| | 7ed4fe50d4 | | |
| | 6f66ec1727 | | |
| | c7e758fc36 | | |
| | 14c22234bb | | |
| | d565e9ae53 | | |
| | 4951c97eab | | |
| | 9b38f3e2fa | | |
| | a297e4208e | | |
| | 1cf0b35ac1 | | |
| | c54084b7a4 | | |
| | e3fe040017 | | |
| | ae5e3e2dc4 | | |
| | 77378d2779 | | |
| | 4106f0dabe | | |
| | 2ed1ed6821 | | |
| | 6d3a38842d | | |
| | 7360f79413 | | |
| | 8d55e13750 | | |
| | 737e8e79c9 | | |
| | 4d977fede0 | | |
| | 8070e156d8 | | |
| | 43c6f1f5cd | | |
| | f53f5445ba | | |
| | 7263d11ee4 | | |
| | f2d5b9ad69 | | |
| | 40c7e3c52c | | |
| | ee5fea4221 | | |
| | db7b60cfe9 | | |
| | 51b79bd6a1 | | |
| | 95fe762776 | | |
| | 2968c846ce | | |
| | e27da96cdc | | |
| | d86502e79a | | |
| | 59c7744590 | | |
| | 949971dea9 | | |
| | cd4a893c65 | | |
86
CHANGELOG.md
@@ -5,6 +5,92 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- Added a new field `handle_sigterm` to `PipelineRunner`. It defaults to `False`.
  This field handles SIGTERM signals. The `handle_sigint` field still defaults
  to `True`, but now it handles only SIGINT signals.

- Added foundational example `14u-function-calling-ollama.py` for Ollama
  function calling.

- Added `LocalSmartTurnAnalyzerV2`, which supports local on-device inference
  with the new `smart-turn-v2` turn detection model.

- Added `set_log_level` to `DailyTransport`, which sets the logging level
  for Daily's internal logging system.

### Changed

- Play delayed messages from `ElevenLabsTTSService` if they still belong to the
  current context.

- Dependency compatibility improvements: Relaxed version constraints for core
  dependencies to support broader version ranges while maintaining stability:

  - `aiohttp`, `Markdown`, `nltk`, `numpy`, `Pillow`, `pydantic`, `openai`,
    `numba`: Now support up to the next major version (e.g. `numpy>=1.26.4,<3`)
  - `pyht`: Relaxed to `>=0.1.6` to resolve `grpcio` conflicts with
    `nvidia-riva-client`
  - `fastapi`: Updated to support versions `>=0.115.6,<0.117.0`
  - `torch`/`torchaudio`: Changed from exact pinning (`==2.5.0`) to compatible
    range (`~=2.5.0`)
  - `aws_sdk_bedrock_runtime`: Added Python 3.12+ constraint via environment
    marker
  - `numba`: Reduced minimum version to `0.60.0` for better compatibility

- Changed `NeuphonicHttpTTSService` to use a POST-based request instead of the
  `pyneuphonic` package. This removes a package requirement, allowing Neuphonic
  to work with more services.

- Updated the `deepgram` optional dependency to 4.7.0, which downgrades the
  `tasks cancelled error` to a debug log. This keeps the log from appearing
  in Pipecat logs upon leaving.

- Upgraded the `websockets` implementation to the new asyncio implementation.
  Along with this change, we're updating support for versions >=13.1.0 and
  <15.0.0. All services have been updated to use the asyncio implementation.

- Updated `MiniMaxHttpTTSService` with a `base_url` arg where you can specify
  the Global endpoint (default) or Mainland China.

- Replaced regex-based sentence detection in `match_endofsentence` with NLTK's
  punkt_tab tokenizer for more reliable sentence boundary detection.

- Changed the `livekit` optional dependency for `tenacity` to
  `tenacity>=8.2.3,<10.0.0` in order to support the `google-genai` package.

- For `LmntTTSService`, changed the default `model` to `blizzard`, LMNT's
  recommended model.

### Fixed

- Fixed a dependency issue for uv users where an `llvmlite` version required Python 3.9.

- Fixed an issue in `MiniMaxHttpTTSService` where the `pitch` param was the
  incorrect type.

- Fixed an issue with OpenTelemetry tracing where the `enable_tracing` flag did
  not disable the internal tracing decorator functions.

- Fixed an issue in `OLLamaLLMService` where kwargs were not passed correctly
  to the parent class.

- Fixed an issue in `ElevenLabsTTSService` where word boundaries were
  calculated incorrectly for the word/timestamp pairs.

- Fixed an issue where, in some edge cases, the `EmulateUserStartedSpeakingFrame`
  could be created even if we didn't have a transcription.

- Fixed an issue in `GoogleLLMContext` where it would inject the
  `system_message` as a "user" message in cases where it was not meant to;
  it was only meant to do that when there were no "regular" (non-function-call)
  messages in the context, to ensure that inference would run properly.

- Fixed an issue in `LiveKitTransport` where the `on_audio_track_subscribed` event was never emitted.

## [0.0.76] - 2025-07-11

### Added

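The `handle_sigint` / `handle_sigterm` split described in the Unreleased notes can be pictured with plain asyncio. This is a minimal sketch under stated assumptions, not Pipecat's implementation: `MiniRunner` and `signals_to_handle` are hypothetical names, and it assumes a Unix event loop where `loop.add_signal_handler` is available.

```python
import asyncio
import signal


class MiniRunner:
    """Hypothetical stand-in for a runner with separate SIGINT/SIGTERM flags."""

    def __init__(self, handle_sigint: bool = True, handle_sigterm: bool = False):
        self.handle_sigint = handle_sigint
        self.handle_sigterm = handle_sigterm
        self.cancelled_by = None

    def signals_to_handle(self) -> list:
        # Each flag controls exactly one signal.
        selected = []
        if self.handle_sigint:
            selected.append(signal.SIGINT)
        if self.handle_sigterm:
            selected.append(signal.SIGTERM)
        return selected

    async def run(self, work) -> None:
        loop = asyncio.get_running_loop()
        task = asyncio.ensure_future(work)

        def on_signal(sig):
            # Remember which signal arrived and cancel the running work.
            self.cancelled_by = sig
            task.cancel()

        for sig in self.signals_to_handle():
            loop.add_signal_handler(sig, on_signal, sig)
        try:
            await task
        except asyncio.CancelledError:
            pass
        finally:
            for sig in self.signals_to_handle():
                loop.remove_signal_handler(sig)
```

With the defaults above only SIGINT is intercepted, mirroring the changelog's stated defaults; a process that should also shut down cleanly on SIGTERM would presumably opt in with `handle_sigterm=True`.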
@@ -53,7 +53,7 @@ You can connect to Pipecat from any platform using our official SDKs:
| Category | Services |
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
+| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova) [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
@@ -11,3 +11,10 @@ ruff~=0.12.1
setuptools~=78.1.1
setuptools_scm~=8.3.1
python-dotenv~=1.1.1

# For running examples
uvicorn
python-dotenv
fastapi
aiohttp
aiortc
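The `~=` pins in this hunk (for example `setuptools~=78.1.1`) and the `~=2.5.0` range mentioned in the changelog use the "compatible release" operator. The sketch below illustrates its meaning for purely numeric versions only. It is a simplified illustration, not how pip resolves versions: `parse` and `compatible_release` are hypothetical helpers, and pre-release or local version segments are ignored. For real checks, `packaging.specifiers.SpecifierSet` is the usual tool.

```python
def parse(version: str) -> tuple:
    """Parse a purely numeric dotted version such as "2.5.0" into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def compatible_release(spec: str, candidate: str) -> bool:
    """Return True if `candidate` satisfies `~=spec` for simple numeric versions."""
    base = parse(spec)
    if len(base) < 2:
        raise ValueError("~= needs at least two version components")
    cand = parse(candidate)
    # Pad the candidate with zeros so "2.5" compares like "2.5.0".
    cand = cand + (0,) * (len(base) - len(cand))
    # ~=X.Y.Z means: >= X.Y.Z and the same X.Y prefix (== X.Y.*).
    prefix = base[:-1]
    return cand >= base and cand[: len(prefix)] == prefix
```

Under this reading, `~=2.5.0` admits patch releases such as `2.5.1` but not `2.6.0`, which is why it is looser than `==2.5.0` while still bounded.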
@@ -77,6 +77,7 @@ autodoc_mock_imports = [
    "openpipe",
    "simli",
    "soundfile",
    "soniox",
    "pipecat_ai_krisp",
    "pyaudio",
    "_tkinter",

@@ -46,6 +46,7 @@ pipecat-ai[sambanova]
pipecat-ai[silero]
pipecat-ai[simli]
pipecat-ai[soundfile]
pipecat-ai[soniox]
pipecat-ai[speechmatics]
pipecat-ai[tavus]
pipecat-ai[together]

@@ -109,6 +109,9 @@ MINIMAX_GROUP_ID=...
# Sarvam AI
SARVAM_API_KEY=...

# Soniox
SONIOX_API_KEY=

# Speechmatics
SPEECHMATICS_API_KEY=...

60
examples/aws-strands/README.md
Normal file
@@ -0,0 +1,60 @@
# AWS Strands Examples

This folder contains two Python examples demonstrating how to use Pipecat with the AWS Strands agent.

## Overview

These examples show how to delegate complex, multi-step tasks to a Strands agent, which can reason step-by-step and call tools to accomplish user requests.

These examples are intentionally simplified for demonstration, using mock API calls. They work best with a question such as:

> What's the weather where the Golden Gate Bridge is?

## Example Scripts

### `black-box.py`

A minimal example that demonstrates how to use the Strands agent with Pipecat. The agent can handle multi-step queries by calling tools, but does not explain its reasoning out loud.

### `explain-thinking.py`

An enhanced example where the Strands agent explains each step of its reasoning in clear, simple language as it works through a multi-step task.

## Quick Start

1. **Clone the repository and navigate to this example:**

   ```bash
   git clone https://github.com/pipecat-ai/pipecat.git
   cd pipecat/examples/aws-strands
   ```

2. **Set up a virtual environment:**

   ```bash
   python -m venv venv
   source venv/bin/activate  # On Windows: venv\Scripts\activate
   ```

3. **Install dependencies:**

   ```bash
   pip install -r requirements.txt
   ```

4. **Configure environment variables:**

   Copy the provided `env.example` file to `.env` and fill in the necessary credentials:

   ```bash
   cp env.example .env
   # Then edit .env with your preferred editor
   ```

5. **Run an example:**

   ```bash
   python black-box.py
   # or
   python explain-thinking.py
   ```
206
examples/aws-strands/black-box.py
Normal file
@@ -0,0 +1,206 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import argparse
import asyncio
import os

from dotenv import load_dotenv
from loguru import logger
from strands import Agent, tool
from strands.models import BedrockModel

from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
from pipecat.transports.services.daily import DailyParams

load_dotenv(override=True)

"""This example demonstrates how to use the Strands agent with Pipecat.

You can delegate complex, multi-step tasks to the Strands agent, which can cycle through LLM-based reasoning and tool calls to accomplish the task.

Try asking: "What's the weather where the Golden Gate Bridge is?"
"""

# Strands agent tools


@tool
def get_location_name_from_landmark(landmark: str) -> str:
    """
    Get the location name from a landmark.

    Args:
        landmark (str): The name of the landmark, e.g. "Golden Gate Bridge".
    """
    # Simulate fetching location
    return "San Francisco, CA"


@tool
def get_lat_long_from_location_name(location: str) -> dict:
    """
    Get the latitude and longitude for a location name.

    Args:
        location (str): The city and state, e.g. "San Francisco, CA".
    """
    # Simulate fetching lat/long from a geocoding service
    return {"lat": 37.7749, "long": -122.4194}


@tool
def get_current_weather_from_lat_long(lat: float, long: float) -> dict:
    """
    Get the current weather for a specific latitude and longitude.

    Args:
        lat (float): The latitude of the location.
        long (float): The longitude of the location.
    """
    # Simulate fetching weather data from a weather service
    return {"conditions": "nice", "temperature": "75"}


# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
    "daily": lambda: DailyParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    ),
    "twilio": lambda: FastAPIWebsocketParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    ),
    "webrtc": lambda: TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    ),
}


async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
    logger.info(f"Starting bot")

    strands_agent = Agent(
        model=BedrockModel(
            model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", max_tokens=64000
        ),
        tools=[
            get_location_name_from_landmark,
            get_lat_long_from_location_name,
            get_current_weather_from_lat_long,
        ],
        system_prompt="""
        You are a helpful personal assistant who can look up information about places and weather.

        Your key capabilities:
        1. Look up where landmarks are located.
        2. Find latitude and longitude for a location.
        3. Look up the current weather for a specific latitude and longitude.

        Explain each step of your reasoning in clear, simple, and concise language. Your responses will be converted to audio, so avoid special characters and numbered lists.
        """,
    )

    async def handle_location_or_weather_related_queries(params: FunctionCallParams, query: str):
        """
        Handle location or weather related queries.

        Args:
            query (str): The user's query, e.g. "What's the weather where the Golden Gate Bridge is?".
        """
        # Run in a background thread
        # (Otherwise the agent blocks the event loop; one effect of that is that we don't hear
        # "let me check on that" until the agent finishes)
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, strands_agent, query)
        await params.result_callback(result.message)

    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

    llm.register_direct_function(handle_location_or_weather_related_queries)

    @llm.event_handler("on_function_calls_started")
    async def on_function_calls_started(service, function_calls):
        await tts.queue_frame(TTSSpeakFrame("Let me check on that."))

    tools = ToolsSchema(standard_tools=[handle_location_or_weather_related_queries])

    messages = [
        {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way. Start by suggesting that the user ask about the weather where the Golden Gate Bridge is.",
        },
    ]

    context = OpenAILLMContext(messages, tools)
    context_aggregator = llm.create_context_aggregator(context)

    pipeline = Pipeline(
        [
            transport.input(),
            stt,
            context_aggregator.user(),
            llm,
            tts,
            transport.output(),
            context_aggregator.assistant(),
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info(f"Client connected")
        # Kick off the conversation.
        await task.queue_frames([context_aggregator.user().get_context_frame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info(f"Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=handle_sigint)

    await runner.run(task)


if __name__ == "__main__":
    from pipecat.examples.run import main

    main(run_example, transport_params=transport_params)
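The comment in `handle_location_or_weather_related_queries` explains that the agent is run in a background thread so the event loop stays free to play "Let me check on that." The effect can be reproduced with the standard library alone. This is a minimal sketch: `blocking_agent` and `speak_filler` are hypothetical stand-ins for the Strands agent call and the TTS frame, not Pipecat or Strands APIs.

```python
import asyncio
import time


def blocking_agent(query: str) -> str:
    """Hypothetical stand-in for a synchronous agent call."""
    time.sleep(0.2)
    return f"answer to {query!r}"


async def main() -> list:
    events = []

    async def speak_filler():
        # Runs on the event loop while the blocking call is in a worker thread.
        events.append("filler")

    loop = asyncio.get_running_loop()
    filler = asyncio.create_task(speak_filler())
    # The blocking call is handed to the default thread pool executor.
    result = await loop.run_in_executor(None, blocking_agent, "weather")
    await filler
    events.append(result)
    return events
```

Calling `blocking_agent` directly inside the coroutine would instead hold the loop for the full duration, so the filler task could only run after the answer was ready.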
8
examples/aws-strands/env.example
Normal file
@@ -0,0 +1,8 @@
OPENAI_API_KEY=
CARTESIA_API_KEY=
DEEPGRAM_API_KEY=
DAILY_API_KEY=
DAILY_SAMPLE_ROOM_URL=
AWS_SECRET_ACCESS_KEY=
AWS_ACCESS_KEY_ID=
AWS_REGION=
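Because the example scripts read these values with `os.getenv`, a missing entry silently becomes `None`. A small pre-flight check can surface that earlier. This is a sketch only: `REQUIRED_KEYS` and `missing_keys` are hypothetical names, and treating all eight variables as required is an assumption (the Daily ones are presumably only needed when the Daily transport is selected).

```python
import os

# Names taken from env.example; treating every one as required is an assumption.
REQUIRED_KEYS = [
    "OPENAI_API_KEY",
    "CARTESIA_API_KEY",
    "DEEPGRAM_API_KEY",
    "DAILY_API_KEY",
    "DAILY_SAMPLE_ROOM_URL",
    "AWS_SECRET_ACCESS_KEY",
    "AWS_ACCESS_KEY_ID",
    "AWS_REGION",
]


def missing_keys(environ=os.environ) -> list:
    """Return the names from REQUIRED_KEYS that are unset or empty."""
    return [key for key in REQUIRED_KEYS if not environ.get(key)]
```

A script could call `missing_keys()` after `load_dotenv(override=True)` and log or abort if the list is not empty.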
249
examples/aws-strands/explain-thinking.py
Normal file
@@ -0,0 +1,249 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import argparse
import asyncio
import os
import threading
import time

from dotenv import load_dotenv
from loguru import logger
from strands import Agent, tool
from strands.models import BedrockModel

from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
from pipecat.transports.services.daily import DailyParams

load_dotenv(override=True)

"""This example demonstrates how to use the Strands agent with Pipecat in a way where the agent explains its reasoning step-by-step.

You can delegate complex, multi-step tasks to the Strands agent, which can cycle through LLM-based reasoning and tool calls to accomplish the task.

Try asking: "What's the weather where the Golden Gate Bridge is?"
"""


# Strands agent tools


@tool
def get_location_name_from_landmark(landmark: str) -> str:
    """
    Get the location name from a landmark.

    Args:
        landmark (str): The name of the landmark, e.g. "Golden Gate Bridge".
    """
    # Simulate fetching location (slowly)
    time.sleep(3)
    return "San Francisco, CA"


@tool
def get_lat_long_from_location_name(location: str) -> dict:
    """
    Get the latitude and longitude for a location name.

    Args:
        location (str): The city and state, e.g. "San Francisco, CA".
    """
    # Simulate fetching lat/long from a geocoding service (slowly)
    time.sleep(3)
    return {"lat": 37.7749, "long": -122.4194}


@tool
def get_current_weather_from_lat_long(lat: float, long: float) -> dict:
    """
    Get the current weather for a specific latitude and longitude.

    Args:
        lat (float): The latitude of the location.
        long (float): The longitude of the location.
    """
    # Simulate fetching weather data from a weather service (slowly)
    time.sleep(3)
    return {"conditions": "nice", "temperature": "75"}


# We store functions so objects (e.g. SileroVADAnalyzer) don't get
# instantiated. The function will be called when the desired transport gets
# selected.
transport_params = {
    "daily": lambda: DailyParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    ),
    "twilio": lambda: FastAPIWebsocketParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    ),
    "webrtc": lambda: TransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    ),
}


async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
    logger.info(f"Starting bot")

    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    next_strands_message_is_last = False
    strands_messages_queue = asyncio.Queue()

    def strands_callback_handler(**kwargs):
        """
        Handle events from the Strands agent.
        """
        nonlocal next_strands_message_is_last
        if "event" in kwargs:
            event_obj = kwargs["event"]
            if event_obj and "messageStop" in event_obj:
                message_stop = event_obj["messageStop"]
                if message_stop and "stopReason" in message_stop:
                    stop_reason = message_stop["stopReason"]
                    if stop_reason == "end_turn":
                        next_strands_message_is_last = True
        elif "message" in kwargs:
            message_obj = kwargs["message"]
            if message_obj and "content" in message_obj and "role" in message_obj:
                role = message_obj["role"]
                content = message_obj["content"]
                if role == "assistant" and isinstance(content, list):
                    for content_obj in content:
                        if isinstance(content_obj, dict) and "text" in content_obj:
                            message = content_obj["text"]
                            if not next_strands_message_is_last:
                                strands_messages_queue.put_nowait(message)

    async def process_strands_messages():
        while True:
            message = await strands_messages_queue.get()
            await tts.queue_frame(TTSSpeakFrame(message))
            strands_messages_queue.task_done()

    asyncio.create_task(process_strands_messages())

    strands_agent = Agent(
        model=BedrockModel(
            model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", max_tokens=64000
        ),
        tools=[
            get_location_name_from_landmark,
            get_lat_long_from_location_name,
            get_current_weather_from_lat_long,
        ],
        system_prompt="""
        You are a helpful personal assistant who can look up information about places and weather.

        Your key capabilities:
        1. Look up where landmarks are located.
        2. Find latitude and longitude for a location.
        3. Look up the current weather for a specific latitude and longitude.

        Explain each step of your reasoning in clear, simple, and concise language. Your responses will be converted to audio, so avoid special characters and numbered lists.
        """,
        callback_handler=strands_callback_handler,
    )

    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

    async def handle_location_or_weather_related_queries(params: FunctionCallParams, query: str):
        """
        Handle location or weather related queries.

        Args:
            query (str): The user's query, e.g. "What's the weather where the Golden Gate Bridge is?".
        """
        # Run in a background thread
        # (Otherwise the agent blocks the event loop; one effect of that is that we don't hear
        # the agent's "thinking" messages until the agent finishes)
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, strands_agent, query)
        await params.result_callback(result.message)

    llm.register_direct_function(handle_location_or_weather_related_queries)

    @llm.event_handler("on_function_calls_started")
    async def on_function_calls_started(service, function_calls):
        await tts.queue_frame(TTSSpeakFrame("Let me check on that."))

    tools = ToolsSchema(standard_tools=[handle_location_or_weather_related_queries])

    messages = [
        {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way. Start by suggesting that the user ask about the weather where the Golden Gate Bridge is.",
        },
    ]

    context = OpenAILLMContext(messages, tools)
    context_aggregator = llm.create_context_aggregator(context)

    pipeline = Pipeline(
        [
            transport.input(),
            stt,
            context_aggregator.user(),
            llm,
            tts,
            transport.output(),
            context_aggregator.assistant(),
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info(f"Client connected")
        # Kick off the conversation.
        await task.queue_frames([context_aggregator.user().get_context_frame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info(f"Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=handle_sigint)

    await runner.run(task)


if __name__ == "__main__":
    from pipecat.examples.run import main

    main(run_example, transport_params=transport_params)
@@ -2,4 +2,5 @@ fastapi
|
||||
uvicorn
|
||||
python-dotenv
|
||||
pipecat-ai[webrtc,daily,deepgram,cartesia]
|
||||
pipecat-ai-small-webrtc-prebuilt
|
||||
pipecat-ai-small-webrtc-prebuilt
|
||||
strands-agents
|
||||
@@ -301,7 +301,7 @@ def fastapi_app():
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include the endpoints from endpoints.py
|
||||
# Include the endpoints from this file
|
||||
web_app.include_router(router)
|
||||
|
||||
return web_app
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"name": "my-daily-app",
|
||||
"version": "0.1.0",
|
||||
"dependencies": {
|
||||
"axios": "^1.6.0",
|
||||
"axios": "^1.11.0",
|
||||
"next": "^14.0.0",
|
||||
"pino": "^8.15.0",
|
||||
"react": "^18.2.0",
|
||||
@@ -1165,13 +1165,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/axios": {
|
||||
"version": "1.8.4",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.8.4.tgz",
|
||||
"integrity": "sha512-eBSYY4Y68NNlHbHBMdeDmKNtDgXWhQsJcGqzO3iLUM0GraQFSS9cVgPX5I9b3lbdFKyYoAEGAZF1DwhTaljNAw==",
|
||||
"version": "1.11.0",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.11.0.tgz",
|
||||
"integrity": "sha512-1Lx3WLFQWm3ooKDYZD1eXmoGO9fxYQjrycfHFC8P0sCfQVXyROp0p9PFWBehewBOdCwHc+f/b8I0fMto5eSfwA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"follow-redirects": "^1.15.6",
|
||||
"form-data": "^4.0.0",
|
||||
"form-data": "^4.0.4",
|
||||
"proxy-from-env": "^1.1.0"
|
||||
}
|
||||
},
|
||||
@@ -2436,14 +2436,15 @@
|
||||
}
|
||||
},
|
||||
"node_modules/form-data": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.2.tgz",
|
||||
"integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz",
|
||||
"integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"asynckit": "^0.4.0",
|
||||
"combined-stream": "^1.0.8",
|
||||
"es-set-tostringtag": "^2.1.0",
|
||||
"hasown": "^2.0.2",
|
||||
"mime-types": "^2.1.12"
|
||||
},
|
||||
"engines": {
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"lint": "next lint"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.6.0",
|
||||
"axios": "^1.11.0",
|
||||
"next": "^14.0.0",
|
||||
"pino": "^8.15.0",
|
||||
"react": "^18.2.0",
|
||||
|
||||
@@ -90,7 +90,7 @@ async def main(transport: DailyTransport):
|
||||
logger.info("Participant left: {}", participant)
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner()
|
||||
runner = PipelineRunner(handle_sigint=False, force_gc=True)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.transports.services.daily import DailyLogLevel, DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -43,6 +43,7 @@ async def main():
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
transport.set_log_level(DailyLogLevel.Info)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
|
||||
109
examples/foundational/07aa-interruptible-soniox.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.soniox.stt import SonioxSTTService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = SonioxSTTService(
|
||||
api_key=os.getenv("SONIOX_API_KEY"),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.examples.run import main
|
||||
|
||||
main(run_example, transport_params=transport_params)
|
||||
@@ -7,6 +7,7 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
@@ -50,60 +51,63 @@ transport_params = {
|
||||
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
# Create an HTTP session
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = NeuphonicHttpTTSService(
|
||||
api_key=os.getenv("NEUPHONIC_API_KEY"),
|
||||
voice_id="fc854436-2dac-4d21-aa69-ae17b54e98eb", # Emily
|
||||
)
|
||||
tts = NeuphonicHttpTTSService(
|
||||
api_key=os.getenv("NEUPHONIC_API_KEY"),
|
||||
voice_id="fc854436-2dac-4d21-aa69-ae17b54e98eb", # Emily
|
||||
aiohttp_session=session,
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
runner = PipelineRunner(handle_sigint=handle_sigint)
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
await runner.run(task)
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
81
examples/foundational/13i-soniox-transcription.py
Normal file
@@ -0,0 +1,81 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame, TranscriptionFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.soniox.stt import SonioxSTTService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class TranscriptionLogger(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = SonioxSTTService(
|
||||
api_key=os.getenv("SONIOX_API_KEY"),
|
||||
)
|
||||
|
||||
tl = TranscriptionLogger()
|
||||
|
||||
pipeline = Pipeline([transport.input(), stt, tl])
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.examples.run import main
|
||||
|
||||
main(run_example, transport_params=transport_params)
|
||||
162
examples/foundational/14u-function-calling-ollama.py
Normal file
@@ -0,0 +1,162 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.ollama.llm import OLLamaLLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = OLLamaLLMService(model="llama3.2") # Update to the model you're running locally
|
||||
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
llm.register_function("get_restaurant_recommendation", fetch_restaurant_recommendation)
|
||||
|
||||
@llm.event_handler("on_function_calls_started")
|
||||
async def on_function_calls_started(service, function_calls):
|
||||
await tts.queue_frame(TTSSpeakFrame("Let me check on that."))
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the user's location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
restaurant_function = FunctionSchema(
|
||||
name="get_restaurant_recommendation",
|
||||
description="Get a restaurant recommendation",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
required=["location"],
|
||||
)
|
||||
tools = ToolsSchema(standard_tools=[weather_function, restaurant_function])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.examples.run import main
|
||||
|
||||
main(run_example, transport_params=transport_params)
|
||||
@@ -0,0 +1,165 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.gemini_multimodal_live.gemini import GeminiMultimodalLiveLLMService
|
||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.network.fastapi_websocket import FastAPIWebsocketParams
|
||||
from pipecat.transports.services.daily import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=False,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.5)),
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=False,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.5)),
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_in_enabled=False,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.5)),
|
||||
),
|
||||
}
|
||||
|
||||
SYSTEM_INSTRUCTION = """
|
||||
You are a helpful AI assistant that actively uses Google Search to provide up-to-date, accurate information.
|
||||
|
||||
IMPORTANT: For ANY question about current events, news, recent developments, real-time information, or anything that might have changed recently, you MUST use the google_search tool to get the latest information.
|
||||
|
||||
You should use Google Search for:
|
||||
- Current news and events
|
||||
- Recent developments in any field
|
||||
- Today's weather, stock prices, or other real-time data
|
||||
- Any question that starts with "what's happening", "latest", "recent", "current", "today", etc.
|
||||
- When you're not certain about recent information
|
||||
|
||||
Always be proactive about using search when the user asks about anything that could benefit from real-time information.
|
||||
|
||||
Your output will be converted to audio so don't include special characters in your answers.
|
||||
|
||||
Respond to what the user said in a creative and helpful way, always using search for current information.
|
||||
"""
|
||||
|
||||
|
||||
class GroundingMetadataProcessor(FrameProcessor):
|
||||
"""Processor to capture and display grounding metadata from Gemini Live API."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._grounding_count = 0
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMSearchResponseFrame):
|
||||
self._grounding_count += 1
|
||||
logger.info(f"\n\n🔍 GROUNDING METADATA RECEIVED #{self._grounding_count}\n")
|
||||
logger.info(f"📝 Search Result Text: {frame.search_result[:200]}...")
|
||||
|
||||
if frame.rendered_content:
|
||||
logger.info(f"🔗 Rendered Content: {frame.rendered_content}")
|
||||
|
||||
if frame.origins:
|
||||
logger.info(f"📍 Number of Origins: {len(frame.origins)}")
|
||||
for i, origin in enumerate(frame.origins):
|
||||
logger.info(f" Origin {i + 1}: {origin.site_title} - {origin.site_uri}")
|
||||
if origin.results:
|
||||
logger.info(f" Results: {len(origin.results)} items")
|
||||
|
||||
# Always push the frame downstream
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
async def run_example(transport: BaseTransport, _: argparse.Namespace, handle_sigint: bool):
|
||||
logger.info(f"Starting Gemini Live Grounding Metadata Test Bot")
|
||||
|
||||
# Create tools using ToolsSchema with custom tools for Gemini
|
||||
tools = ToolsSchema(
|
||||
standard_tools=[], # No standard function declarations needed
|
||||
custom_tools={AdapterType.GEMINI: [{"google_search": {}}, {"code_execution": {}}]},
|
||||
)
|
||||
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=SYSTEM_INSTRUCTION,
|
||||
voice_id="Charon", # Aoede, Charon, Fenrir, Kore, Puck
|
||||
transcribe_user_audio=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Create a processor to capture grounding metadata
|
||||
grounding_processor = GroundingMetadataProcessor()
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please introduce yourself and let me know that you can help with current information by searching the web. Ask me what current information I'd like to know about.",
|
||||
},
|
||||
]
|
||||
|
||||
# Set up conversation context and management
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
grounding_processor, # Add our grounding processor here
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.examples.run import main
|
||||
|
||||
main(run_example, transport_params=transport_params)
|
||||
@@ -11,7 +11,7 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn import LocalSmartTurnAnalyzer
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v2 import LocalSmartTurnAnalyzerV2
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
@@ -37,7 +37,7 @@ load_dotenv(override=True)
|
||||
# # Hugging Face uses LFS to store large model files, including .mlpackage
|
||||
# git lfs install
|
||||
# # Clone the repo with the smart_turn_classifier.mlpackage
|
||||
# git clone https://huggingface.co/pipecat-ai/smart-turn
|
||||
# git clone https://huggingface.co/pipecat-ai/smart-turn-v2
|
||||
#
|
||||
# Then set the env variable:
|
||||
# export LOCAL_SMART_TURN_MODEL_PATH=./smart-turn
|
||||
@@ -52,7 +52,7 @@ transport_params = {
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzer(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
),
|
||||
@@ -60,7 +60,7 @@ transport_params = {
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzer(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
),
|
||||
@@ -68,7 +68,7 @@ transport_params = {
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
turn_analyzer=LocalSmartTurnAnalyzer(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV2(
|
||||
smart_turn_model_path=smart_turn_model_path, params=SmartTurnParams()
|
||||
),
|
||||
),
|
||||
|
||||
@@ -295,6 +295,22 @@ This project uses TypeScript, React, and Next.js, making it a perfect fit for [V
Again, we'll use Pipecat Cloud. Follow the steps from above. The only difference will be the secrets required; in addition to a GOOGLE_API_KEY, you'll need `GOOGLE_APPLICATION_CREDENTIALS` in the format of a .json file with your [Google Cloud service account](https://console.cloud.google.com/iam-admin/serviceaccounts) information.

You'll need to modify the Dockerfile so that the credentials.json and word_list.py are accessible. This Dockerfile will work:

```Dockerfile
FROM dailyco/pipecat-base:latest

COPY ./requirements.txt requirements.txt

RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY ./word_list.py word_list.py
COPY ./credentials.json credentials.json
COPY ./bot_phone_twilio.py bot.py
```

Note: Your `credentials.json` file should have your Google service account credentials.
|
||||
|
||||
#### Buy and Configure a Twilio Number
|
||||
|
||||
Check out the [Twilio Websocket Telephony guide](https://docs.pipecat.daily.co/pipecat-in-production/telephony/twilio-mediastreams) for a step-by-step walkthrough on how to purchase a phone number, configure your TwiML, and make or receive calls.
|
||||
|
||||
@@ -20,19 +20,22 @@ classifiers = [
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence"
|
||||
]
|
||||
dependencies = [
|
||||
"aiohttp~=3.11.12",
|
||||
"aiohttp>=3.11.12,<4",
|
||||
"audioop-lts~=0.2.1; python_version>='3.13'",
|
||||
"docstring_parser~=0.16",
|
||||
"loguru~=0.7.3",
|
||||
"Markdown~=3.7",
|
||||
"numpy>=1.26.4",
|
||||
"Pillow~=11.1.0",
|
||||
"Markdown>=3.7,<4",
|
||||
"nltk>=3.9.1,<4",
|
||||
"numpy>=1.26.4,<3",
|
||||
"Pillow>=11.1.0,<12",
|
||||
"protobuf~=5.29.3",
|
||||
"pydantic~=2.10.6",
|
||||
"pydantic>=2.10.6,<3",
|
||||
"pyloudnorm~=0.1.1",
|
||||
"resampy~=0.4.3",
|
||||
"soxr~=0.5.0",
|
||||
"openai~=1.74.0",
|
||||
"openai>=1.74.0,<2",
|
||||
# Explicit dependency pins for Python 3.11+ compatibility
|
||||
"numba>=0.60.0,<1",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -41,59 +44,60 @@ Website = "https://pipecat.ai"
|
||||
|
||||
[project.optional-dependencies]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "websockets~=13.1" ]
|
||||
aws = [ "aioboto3~=15.0.0", "websockets~=13.1" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2" ]
|
||||
assemblyai = [ "websockets>=13.1,<15.0" ]
|
||||
aws = [ "aioboto3~=15.0.0", "websockets>=13.1,<15.0" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2; python_version>='3.12'" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "websockets~=13.1" ]
|
||||
cartesia = [ "cartesia~=2.0.3", "websockets>=13.1,<15.0" ]
|
||||
cerebras = []
|
||||
deepseek = []
|
||||
daily = [ "daily-python~=0.19.4" ]
|
||||
deepgram = [ "deepgram-sdk~=4.1.0" ]
|
||||
elevenlabs = [ "websockets~=13.1" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0" ]
|
||||
elevenlabs = [ "websockets>=13.1,<15.0" ]
|
||||
fal = [ "fal-client~=0.5.9" ]
|
||||
fireworks = []
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
|
||||
gladia = [ "websockets~=13.1" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "websockets~=13.1" ]
|
||||
fish = [ "ormsgpack~=1.7.0", "websockets>=13.1,<15.0" ]
|
||||
gladia = [ "websockets>=13.1,<15.0" ]
|
||||
google = [ "google-cloud-speech~=2.32.0", "google-cloud-texttospeech~=2.26.0", "google-genai~=1.24.0", "websockets>=13.1,<15.0" ]
|
||||
grok = []
|
||||
groq = [ "groq~=0.23.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
krisp = [ "pipecat-ai-krisp~=0.4.0" ]
|
||||
koala = [ "pvkoala~=2.0.3" ]
|
||||
langchain = [ "langchain~=0.3.20", "langchain-community~=0.3.20", "langchain-openai~=0.3.9" ]
|
||||
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity~=9.0.0" ]
|
||||
lmnt = [ "websockets~=13.1" ]
|
||||
livekit = [ "livekit~=0.22.0", "livekit-api~=0.8.2", "tenacity>=8.2.3,<10.0.0" ]
|
||||
lmnt = [ "websockets>=13.1,<15.0" ]
|
||||
local = [ "pyaudio~=0.2.14" ]
|
||||
mcp = [ "mcp[cli]~=1.9.4" ]
|
||||
mem0 = [ "mem0ai~=0.1.94" ]
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "einops~=0.8.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
nim = []
|
||||
neuphonic = [ "pyneuphonic~=1.5.13", "websockets~=13.1" ]
|
||||
neuphonic = [ "websockets>=13.1,<15.0" ]
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
openai = [ "websockets~=13.1" ]
|
||||
openai = [ "websockets>=13.1,<15.0" ]
|
||||
openpipe = [ "openpipe~=4.50.0" ]
|
||||
openrouter = []
|
||||
perplexity = []
|
||||
playht = [ "pyht~=0.1.12", "websockets~=13.1" ]
|
||||
playht = [ "pyht>=0.1.6", "websockets>=13.1,<15.0" ]
|
||||
qwen = []
|
||||
rime = [ "websockets~=13.1" ]
|
||||
rime = [ "websockets>=13.1,<15.0" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
sambanova = []
|
||||
sentry = [ "sentry-sdk~=2.23.1" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch==2.5.0", "torchaudio==2.5.0" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch~=2.5.0", "torchaudio~=2.5.0" ]
|
||||
remote-smart-turn = []
|
||||
silero = [ "onnxruntime~=1.20.1" ]
|
||||
simli = [ "simli-ai~=0.1.10"]
|
||||
soniox = [ "websockets>=13.1,<15.0" ]
|
||||
soundfile = [ "soundfile~=0.13.0" ]
|
||||
speechmatics = [ "speechmatics-rt>=0.3.1" ]
|
||||
tavus=[]
|
||||
together = []
|
||||
tracing = [ "opentelemetry-sdk>=1.33.0", "opentelemetry-api>=1.33.0", "opentelemetry-instrumentation>=0.54b0" ]
|
||||
ultravox = [ "transformers~=4.48.0", "vllm~=0.7.3" ]
|
||||
ultravox = [ "transformers>=4.48.0", "vllm~=0.7.3" ]
|
||||
webrtc = [ "aiortc~=1.11.0", "opencv-python~=4.11.0.86" ]
|
||||
websocket = [ "websockets~=13.1", "fastapi~=0.115.6" ]
|
||||
websocket = [ "websockets>=13.1,<15.0", "fastapi>=0.115.6,<0.117.0" ]
|
||||
whisper = [ "faster-whisper~=1.1.1" ]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
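Several pins in this hunk move from the compatible-release operator (`~=`) to an explicit `>=,<` range. The sketch below illustrates the practical difference under simplifying assumptions: it handles plain numeric releases only (no pre-releases or epochs), and the helper names `parse`, `compatible_release`, and `bounded_range` are hypothetical, not part of any packaging library.

```python
def parse(version: str) -> tuple:
    """Turn a plain release such as '3.11.12' into (3, 11, 12)."""
    return tuple(int(part) for part in version.split("."))


def compatible_release(base: str, candidate: str) -> bool:
    """Approximate `~=base`: at least `base`, all but the last component fixed."""
    b, c = parse(base), parse(candidate)
    return c >= b and c[: len(b) - 1] == b[:-1]


def bounded_range(lower: str, upper: str, candidate: str) -> bool:
    """Approximate `>=lower,<upper`."""
    return parse(lower) <= parse(candidate) < parse(upper)


# aiohttp: "~=3.11.12" rejects 3.12.0, while ">=3.11.12,<4" accepts it.
old_aiohttp = compatible_release("3.11.12", "3.12.0")
new_aiohttp = bounded_range("3.11.12", "4", "3.12.0")

# websockets: "~=13.1" rejects 14.2, while ">=13.1,<15.0" accepts it.
old_websockets = compatible_release("13.1", "14.2")
new_websockets = bounded_range("13.1", "15.0", "14.2")
```

The wider ranges let a resolver pick newer minor or major releases, which is presumably why the websockets call sites later in this diff are updated as well.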
@@ -148,3 +152,6 @@ convention = "google"
|
||||
command_line = "--module pytest"
|
||||
source = ["src"]
|
||||
omit = ["*/tests/*"]
|
||||
|
||||
[project.scripts]
|
||||
pipecat = "pipecat.__main__:main"
|
||||
101
src/pipecat/__main__.py
Normal file
@@ -0,0 +1,101 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import argparse
import importlib.util
import sys
from pathlib import Path
from typing import Any, Callable

from loguru import logger


def load_bot_module(file_path: str, function_name: str = "run_example") -> Callable[..., Any]:
    """Load a bot module from a Python file and return the specified function.

    Args:
        file_path: Path to the Python file containing the bot
        function_name: Name of the function to load (default: run_example)

    Returns:
        The callable function from the module

    Raises:
        SystemExit: If the file doesn't exist, isn't a Python file, or the function isn't found
    """
    logger.info(f"Loading bot module from: {file_path}")
    logger.info(f"Looking for function: {function_name}")

    file_path_obj = Path(file_path)
    if not file_path_obj.exists():
        print(f"Error: File '{file_path}' not found", file=sys.stderr)
        sys.exit(1)

    if file_path_obj.suffix != ".py":
        print(f"Error: File '{file_path}' is not a Python file", file=sys.stderr)
        sys.exit(1)

    # Import the module
    try:
        logger.info(f"Importing module from: {file_path}")
        spec = importlib.util.spec_from_file_location("bot_module", file_path_obj)
        if spec is None or spec.loader is None:
            print(f"Error: Could not load module from '{file_path}'", file=sys.stderr)
            sys.exit(1)

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        logger.info(f"Successfully imported module: {module.__name__}")
    except Exception as e:
        print(f"Error importing module from '{file_path}': {e}", file=sys.stderr)
        sys.exit(1)

    # Find the function to run
    if not hasattr(module, function_name):
        print(f"Error: Function '{function_name}' not found in '{file_path}'", file=sys.stderr)
        print(
            f"Available functions: {[name for name in dir(module) if not name.startswith('_')]}", file=sys.stderr)
        sys.exit(1)

    run_example = getattr(module, function_name)
    if not callable(run_example):
        print(f"Error: '{function_name}' is not a callable function", file=sys.stderr)
        sys.exit(1)

    logger.info(f"Successfully loaded function: {function_name}")
    return run_example


def main():
    """Main entry point for the pipecat command line tool.

    This function is called by the entry point script and handles argument parsing
    and module loading before calling the actual main execution logic.
    """
    # Set up argument parser for our specific arguments
    parser = argparse.ArgumentParser(description="Run a Pipecat bot from a Python file")
    parser.add_argument("file", help="Python file containing the bot to run")
    parser.add_argument("--function", "-f", default="run_example",
                        help="Function name to run (default: run_example)")

    # Parse our arguments first
    args, remaining_args = parser.parse_known_args()

    # Load the bot module and get the function
    run_example = load_bot_module(args.file, args.function)

    # Set sys.argv to the remaining arguments for the run_main function
    sys.argv = [sys.argv[0]] + remaining_args

    # Import run_main only when we need it
    from pipecat.examples.run import main as run_main

    # Call the main function from pipecat.examples.run
    run_main(run_example)


if __name__ == "__main__":
    main()
|
||||
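The file-based loading that `load_bot_module` performs can be tried in isolation. The following is a minimal sketch, not the project's code: it uses only the standard `importlib.util` calls seen above, raises exceptions instead of calling `sys.exit`, and the names `load_function` and `demo_bot.py` are hypothetical.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_function(file_path: str, function_name: str):
    """Load `function_name` from the Python file at `file_path`."""
    spec = importlib.util.spec_from_file_location("bot_module", Path(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load module from {file_path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    func = getattr(module, function_name, None)
    if not callable(func):
        raise AttributeError(f"{function_name!r} is not a callable in {file_path!r}")
    return func


# Demo: write a throwaway module to disk, then load one function from it.
with tempfile.TemporaryDirectory() as tmp:
    bot_file = Path(tmp) / "demo_bot.py"
    bot_file.write_text("def run_example(name):\n    return f'hello {name}'\n")
    run_example = load_function(str(bot_file), "run_example")
    result = run_example("pipecat")
```

Because the module is executed on load, any import error or top-level exception in the target file surfaces at `exec_module`, which is why the diff wraps that step in a `try` block.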
196
src/pipecat/audio/turn/smart_turn/local_smart_turn_v2.py
Normal file
@@ -0,0 +1,196 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Local PyTorch turn analyzer for on-device ML inference using the smart-turn-v2 model.

This module provides a smart turn analyzer that uses PyTorch models for
local end-of-turn detection without requiring network connectivity.
"""

from typing import Any, Dict

import numpy as np
from loguru import logger

from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn

try:
    import torch
    import torch.nn.functional as F
    from torch import nn
    from transformers import (
        Wav2Vec2Config,
        Wav2Vec2Model,
        Wav2Vec2PreTrainedModel,
        Wav2Vec2Processor,
    )
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use LocalSmartTurnAnalyzerV2, you need to `pip install pipecat-ai[local-smart-turn]`."
    )
    raise Exception(f"Missing module: {e}")


class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
    """Local turn analyzer using the smart-turn-v2 PyTorch model.

    Provides end-of-turn detection using locally-stored PyTorch models,
    enabling offline operation without network dependencies. Uses
    Wav2Vec2 architecture for audio sequence classification.
    """

    def __init__(self, *, smart_turn_model_path: str, **kwargs):
        """Initialize the local PyTorch smart-turn-v2 analyzer.

        Args:
            smart_turn_model_path: Path to directory containing the PyTorch model
                and feature extractor files. If empty, uses default HuggingFace model.
            **kwargs: Additional arguments passed to BaseSmartTurn.
        """
        super().__init__(**kwargs)

        if not smart_turn_model_path:
            # Define the path to the pretrained model on Hugging Face
            smart_turn_model_path = "pipecat-ai/smart-turn-v2"

        logger.debug("Loading Local Smart Turn v2 model...")
        # Load the pretrained model for sequence classification
        self._turn_model = _Wav2Vec2ForEndpointing.from_pretrained(smart_turn_model_path)
        # Load the corresponding feature extractor for preprocessing audio
        self._turn_processor = Wav2Vec2Processor.from_pretrained(smart_turn_model_path)
        # Use platform-optimized backend if available (MPS for Apple silicon, CUDA for NVIDIA)
        self._device = "cpu"
        if torch.backends.mps.is_available():
            self._device = "mps"
        elif torch.cuda.is_available():
            self._device = "cuda"
        # Move model to selected device and set it to evaluation mode
        self._turn_model = self._turn_model.to(self._device)
        self._turn_model.eval()
        logger.debug("Loaded Local Smart Turn v2")

    async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
        """Predict end-of-turn using local PyTorch model."""
        inputs = self._turn_processor(
            audio_array,
            sampling_rate=16000,
            padding="max_length",
            truncation=True,
            max_length=16000 * 16,  # 16 seconds at 16kHz
            return_attention_mask=True,
            return_tensors="pt",
        )

        # Move inputs to device
        inputs = {k: v.to(self._device) for k, v in inputs.items()}

        # Run inference
        with torch.no_grad():
            outputs = self._turn_model(**inputs)

            # The model returns sigmoid probabilities directly in the logits field
            probability = outputs["logits"][0].item()

            # Make prediction (1 for Complete, 0 for Incomplete)
            prediction = 1 if probability > 0.5 else 0

        return {
            "prediction": prediction,
            "probability": probability,
        }


class _Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)

        self.pool_attention = nn.Sequential(
            nn.Linear(config.hidden_size, 256), nn.Tanh(), nn.Linear(256, 1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

        for module in self.classifier:
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.1)
                if module.bias is not None:
                    module.bias.data.zero_()

        for module in self.pool_attention:
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.1)
                if module.bias is not None:
                    module.bias.data.zero_()

    def attention_pool(self, hidden_states, attention_mask):
        # Calculate attention weights
        attention_weights = self.pool_attention(hidden_states)

        if attention_mask is None:
            raise ValueError("attention_mask must be provided for attention pooling")

        attention_weights = attention_weights + (
            (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
        )

        attention_weights = F.softmax(attention_weights, dim=1)

        # Apply attention to hidden states
        weighted_sum = torch.sum(hidden_states * attention_weights, dim=1)

        return weighted_sum

    def forward(self, input_values, attention_mask=None, labels=None):
        outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
        hidden_states = outputs[0]

        # Create transformer padding mask
        if attention_mask is not None:
            input_length = attention_mask.size(1)
            hidden_length = hidden_states.size(1)
            ratio = input_length / hidden_length
            indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
            attention_mask = attention_mask[:, indices]
            attention_mask = attention_mask.bool()
        else:
            attention_mask = None

        pooled = self.attention_pool(hidden_states, attention_mask)

        logits = self.classifier(pooled)

        if torch.isnan(logits).any():
            raise ValueError("NaN values detected in logits")

        if labels is not None:
            # Calculate positive sample weight based on batch statistics
            pos_weight = ((labels == 0).sum() / (labels == 1).sum()).clamp(min=0.1, max=10.0)
            loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            labels = labels.float()
            loss = loss_fct(logits.view(-1), labels.view(-1))

            # Add L2 regularization for classifier layers
            l2_lambda = 0.01
            l2_reg = torch.tensor(0.0, device=logits.device)
            for param in self.classifier.parameters():
                l2_reg += torch.norm(param)
            loss += l2_lambda * l2_reg

            probs = torch.sigmoid(logits.detach())
            return {"loss": loss, "logits": probs}

        probs = torch.sigmoid(logits)
        return {"logits": probs}
|
||||
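The masking step in `attention_pool` may be easier to follow with plain numbers. This is a minimal sketch in pure Python rather than PyTorch, for a single sequence with no batch dimension. The standalone function and its `scores` argument (standing in for the output of `pool_attention`) are assumptions made for illustration.

```python
import math


def attention_pool(hidden_states, scores, mask):
    """Masked softmax pooling over time for one sequence.

    hidden_states: list of T feature vectors (lists of floats)
    scores: list of T unnormalized attention scores
    mask: list of T ints, 1 for real frames and 0 for padding
    """
    # Push padded positions to a very negative score so softmax gives them ~0 weight.
    masked = [s + (1.0 - m) * -1e9 for s, m in zip(scores, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(hidden_states[0])
    pooled = [sum(w * h[d] for w, h in zip(weights, hidden_states)) for d in range(dim)]
    return pooled, weights


frames = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # the last frame is padding
pooled, weights = attention_pool(frames, scores=[0.0, 0.0, 3.0], mask=[1, 1, 0])
```

Even though the padded frame has the highest raw score, its weight collapses to zero after masking, so only the two real frames contribute to the pooled vector.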
@@ -614,6 +614,7 @@ class StartFrame(SystemFrame):
|
||||
audio_out_sample_rate: Output audio sample rate in Hz.
|
||||
allow_interruptions: Whether to allow user interruptions.
|
||||
enable_metrics: Whether to enable performance metrics collection.
|
||||
enable_tracing: Whether to enable OpenTelemetry tracing.
|
||||
enable_usage_metrics: Whether to enable usage metrics collection.
|
||||
interruption_strategies: List of interruption handling strategies.
|
||||
report_only_initial_ttfb: Whether to report only initial time-to-first-byte.
|
||||
@@ -623,6 +624,7 @@ class StartFrame(SystemFrame):
|
||||
audio_out_sample_rate: int = 24000
|
||||
allow_interruptions: bool = False
|
||||
enable_metrics: bool = False
|
||||
enable_tracing: bool = False
|
||||
enable_usage_metrics: bool = False
|
||||
interruption_strategies: List[BaseInterruptionStrategy] = field(default_factory=list)
|
||||
report_only_initial_ttfb: bool = False
|
||||
|
||||
@@ -38,14 +38,16 @@ class PipelineRunner(BaseObject):
|
||||
handle_sigint: bool = True,
|
||||
force_gc: bool = False,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
handle_sigterm: bool = False,
|
||||
):
|
||||
"""Initialize the pipeline runner.
|
||||
|
||||
Args:
|
||||
name: Optional name for the runner instance.
|
||||
handle_sigint: Whether to automatically handle SIGINT/SIGTERM signals.
|
||||
handle_sigint: Whether to automatically handle SIGINT signals.
|
||||
force_gc: Whether to force garbage collection after task completion.
|
||||
loop: Event loop to use. If None, uses the current running loop.
|
||||
handle_sigterm: Whether to automatically handle SIGTERM signals.
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
|
||||
@@ -57,6 +59,9 @@ class PipelineRunner(BaseObject):
|
||||
if handle_sigint:
|
||||
self._setup_sigint()
|
||||
|
||||
if handle_sigterm:
|
||||
self._setup_sigterm()
|
||||
|
||||
async def run(self, task: PipelineTask):
|
||||
"""Run a pipeline task to completion.
|
||||
|
||||
@@ -96,6 +101,10 @@ class PipelineRunner(BaseObject):
|
||||
"""Set up signal handlers for graceful shutdown."""
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.add_signal_handler(signal.SIGINT, lambda *args: self._sig_handler())
|
||||
|
||||
def _setup_sigterm(self):
|
||||
"""Set up signal handlers for graceful shutdown."""
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.add_signal_handler(signal.SIGTERM, lambda *args: self._sig_handler())
|
||||
|
||||
def _sig_handler(self):
|
||||
|
||||
@@ -638,6 +638,7 @@ class PipelineTask(BasePipelineTask):
|
||||
audio_in_sample_rate=self._params.audio_in_sample_rate,
|
||||
audio_out_sample_rate=self._params.audio_out_sample_rate,
|
||||
enable_metrics=self._params.enable_metrics,
|
||||
enable_tracing=self._enable_tracing,
|
||||
enable_usage_metrics=self._params.enable_usage_metrics,
|
||||
report_only_initial_ttfb=self._params.report_only_initial_ttfb,
|
||||
interruption_strategies=self._params.interruption_strategies,
|
||||
|
||||
@@ -693,7 +693,11 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
# to emulate VAD (i.e. user start/stopped speaking), but we do it only
|
||||
# if the bot is not speaking. If the bot is speaking and we really have
|
||||
# a short utterance we don't really want to interrupt the bot.
|
||||
if not self._user_speaking and not self._waiting_for_aggregation:
|
||||
if (
|
||||
not self._user_speaking
|
||||
and not self._waiting_for_aggregation
|
||||
and len(self._aggregation) > 0
|
||||
):
|
||||
if self._bot_speaking:
|
||||
# If we reached this case and the bot is speaking, let's ignore
|
||||
# what the user said.
|
||||
|
||||
@@ -44,6 +44,7 @@ from .models import (
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error('In order to use AssemblyAI, you need to `pip install "pipecat-ai[assemblyai]"`.')
|
||||
@@ -190,9 +191,9 @@ class AssemblyAISTTService(STTService):
|
||||
"Authorization": self._api_key,
|
||||
"User-Agent": f"AssemblyAI/1.0 (integration=Pipecat/{pipecat_version})",
|
||||
}
|
||||
self._websocket = await websockets.connect(
|
||||
self._websocket = await websocket_connect(
|
||||
ws_url,
|
||||
extra_headers=headers,
|
||||
additional_headers=headers,
|
||||
)
|
||||
self._connected = True
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
|
||||
@@ -36,6 +36,8 @@ from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
|
||||
@@ -133,7 +135,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
await self._connect()
|
||||
if self._ws_client and self._ws_client.open:
|
||||
if self._ws_client and self._ws_client.state is State.OPEN:
|
||||
logger.info("Successfully established WebSocket connection")
|
||||
return
|
||||
logger.warning("WebSocket connection not established after connect")
|
||||
@@ -174,7 +176,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
"""
|
||||
try:
|
||||
# Ensure WebSocket is connected
|
||||
if not self._ws_client or not self._ws_client.open:
|
||||
if not self._ws_client or self._ws_client.state is State.CLOSED:
|
||||
logger.debug("WebSocket not connected, attempting to reconnect...")
|
||||
try:
|
||||
await self._connect()
|
||||
@@ -208,7 +210,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to AWS Transcribe with connection state management."""
|
||||
if self._ws_client and self._ws_client.open and self._receive_task:
|
||||
if self._ws_client and self._ws_client.state is State.OPEN and self._receive_task:
|
||||
logger.debug(f"{self} Already connected")
|
||||
return
|
||||
|
||||
@@ -238,7 +240,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
)
|
||||
|
||||
# Add required headers
|
||||
extra_headers = {
|
||||
additional_headers = {
|
||||
"Origin": "https://localhost",
|
||||
"Sec-WebSocket-Key": websocket_key,
|
||||
"Sec-WebSocket-Version": "13",
|
||||
@@ -268,9 +270,9 @@ class AWSTranscribeSTTService(STTService):
|
||||
logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")
|
||||
|
||||
# Connect with the required headers and settings
|
||||
self._ws_client = await websockets.connect(
|
||||
self._ws_client = await websocket_connect(
|
||||
presigned_url,
|
||||
extra_headers=extra_headers,
|
||||
additional_headers=additional_headers,
|
||||
subprotocols=["mqtt"],
|
||||
ping_interval=None,
|
||||
ping_timeout=None,
|
||||
@@ -299,7 +301,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
self._receive_task = None
|
||||
|
||||
try:
|
||||
if self._ws_client and self._ws_client.open:
|
||||
if self._ws_client and self._ws_client.state is State.OPEN:
|
||||
# Send end-stream message
|
||||
end_stream = {"message-type": "event", "event": "end"}
|
||||
await self._ws_client.send(json.dumps(end_stream))
|
||||
@@ -341,7 +343,7 @@ class AWSTranscribeSTTService(STTService):
|
||||
async def _receive_loop(self):
|
||||
"""Background task to receive and process messages from AWS Transcribe."""
|
||||
while True:
|
||||
if not self._ws_client or not self._ws_client.open:
|
||||
if not self._ws_client or self._ws_client.state is State.CLOSED:
|
||||
logger.warning(f"{self} WebSocket closed in receive loop")
|
||||
break
|
||||
|
||||
|
||||
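One point to check when reviewing these hunks: several `not ws.open` checks become `ws.state is State.CLOSED`, and the two are not the same predicate if the connection can be observed while connecting or closing. The sketch below uses a stand-in enum that mirrors the member names of `websockets.protocol.State`. The enum and the two helper functions are hypothetical, written only to list the states where the old and new checks disagree.

```python
from enum import Enum


class State(Enum):
    """Stand-in mirroring the member names of websockets.protocol.State."""

    CONNECTING = 0
    OPEN = 1
    CLOSING = 2
    CLOSED = 3


def was_not_open(state: State) -> bool:
    """What the removed `not ws.open` checks expressed."""
    return state is not State.OPEN


def is_closed(state: State) -> bool:
    """What the new `ws.state is State.CLOSED` checks express."""
    return state is State.CLOSED


# States where the old and new conditions give different answers.
differing = [s.name for s in State if was_not_open(s) != is_closed(s)]
```

Under this assumption, a socket in the CLOSING state would no longer trigger a reconnect or a loop exit. Whether that matters depends on how each service handles a send or receive on a closing connection.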
@@ -15,7 +15,6 @@ import json
|
||||
import urllib.parse
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -34,6 +33,15 @@ from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class CartesiaLiveOptions:
|
||||
"""Configuration options for Cartesia Live STT service.
|
||||
@@ -216,7 +224,7 @@ class CartesiaSTTService(STTService):
|
||||
None - transcription results are handled via WebSocket responses.
|
||||
"""
|
||||
# If the connection is closed, due to timeout, we need to reconnect when the user starts speaking again
|
||||
if not self._connection or self._connection.closed:
|
||||
if not self._connection or self._connection.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
await self._connection.send(audio)
|
||||
@@ -229,7 +237,7 @@ class CartesiaSTTService(STTService):
|
||||
headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key}
|
||||
|
||||
try:
|
||||
self._connection = await websockets.connect(ws_url, extra_headers=headers)
|
||||
self._connection = await websocket_connect(ws_url, additional_headers=headers)
|
||||
# Setup the receiver task to handle the incoming messages from the Cartesia server
|
||||
if self._receiver_task is None or self._receiver_task.done():
|
||||
self._receiver_task = asyncio.create_task(self._receive_messages())
|
||||
@@ -240,7 +248,7 @@ class CartesiaSTTService(STTService):
|
||||
async def _receive_messages(self):
|
||||
try:
|
||||
while True:
|
||||
if not self._connection or self._connection.closed:
|
||||
if not self._connection or self._connection.state is State.CLOSED:
|
||||
break
|
||||
|
||||
message = await self._connection.recv()
|
||||
@@ -320,7 +328,7 @@ class CartesiaSTTService(STTService):
|
||||
logger.exception(f"Unexpected exception while cancelling task: {e}")
|
||||
self._receiver_task = None
|
||||
|
||||
if self._connection and self._connection.open:
|
||||
if self._connection and self._connection.state is State.OPEN:
|
||||
logger.debug("Disconnecting from Cartesia")
|
||||
|
||||
await self._connection.close()
|
||||
@@ -344,5 +352,5 @@ class CartesiaSTTService(STTService):
|
||||
await self.start_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# Send finalize command to flush the transcription session
|
||||
if self._connection and self._connection.open:
|
||||
if self._connection and self._connection.state is State.OPEN:
|
||||
await self._connection.send("finalize")
|
||||
|
||||
@@ -36,8 +36,9 @@ from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
# See .env.example for Cartesia configuration needed
|
||||
try:
|
||||
import websockets
|
||||
from cartesia import AsyncCartesia
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`.")
|
||||
@@ -288,10 +289,10 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
|
||||
async def _connect_websocket(self):
|
||||
try:
|
||||
if self._websocket and self._websocket.open:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
logger.debug("Connecting to Cartesia")
|
||||
self._websocket = await websockets.connect(
|
||||
self._websocket = await websocket_connect(
|
||||
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -380,7 +381,7 @@ class CartesiaTTSService(AudioContextWordTTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if not self._websocket or self._websocket.closed:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
if not self._context_id:
|
||||
|
||||
@@ -44,6 +44,8 @@ from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
# See .env.example for ElevenLabs configuration needed
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use ElevenLabs, you need to `pip install pipecat-ai[elevenlabs]`.")
|
||||
@@ -178,20 +180,46 @@ def calculate_word_times(
|
||||
Returns:
|
||||
List of (word, timestamp) tuples.
|
||||
"""
|
||||
zipped_times = list(zip(alignment_info["chars"], alignment_info["charStartTimesMs"]))
|
||||
chars = alignment_info["chars"]
|
||||
char_start_times_ms = alignment_info["charStartTimesMs"]
|
||||
|
||||
words = "".join(alignment_info["chars"]).split(" ")
|
||||
if len(chars) != len(char_start_times_ms):
|
||||
logger.error(
|
||||
f"calculate_word_times: length mismatch - chars={len(chars)}, times={len(char_start_times_ms)}"
|
||||
)
|
||||
return []
|
||||
|
||||
# Calculate start time for each word. We do this by finding a space character
|
||||
# and using the previous word time, also taking into account there might not
|
||||
# be a space at the end.
|
||||
times = []
|
||||
for i, (a, b) in enumerate(zipped_times):
|
||||
if a == " " or i == len(zipped_times) - 1:
|
||||
t = cumulative_time + (zipped_times[i - 1][1] / 1000.0)
|
||||
times.append(t)
|
||||
# Build words and track their start positions
|
||||
words = []
|
||||
word_start_indices = []
|
||||
current_word = ""
|
||||
word_start_index = None
|
||||
|
||||
word_times = list(zip(words, times))
|
||||
for i, char in enumerate(chars):
|
||||
if char == " ":
|
||||
# End of current word
|
||||
if current_word: # Only add non-empty words
|
||||
words.append(current_word)
|
||||
word_start_indices.append(word_start_index)
|
||||
current_word = ""
|
||||
word_start_index = None
|
||||
else:
|
||||
# Building a word
|
||||
if word_start_index is None: # First character of new word
|
||||
word_start_index = i
|
||||
current_word += char
|
||||
|
||||
# Handle the last word if there's no trailing space
|
||||
if current_word and word_start_index is not None:
|
||||
words.append(current_word)
|
||||
word_start_indices.append(word_start_index)
|
||||
|
||||
# Calculate timestamps for each word
|
||||
word_times = []
|
||||
for word, start_idx in zip(words, word_start_indices):
|
||||
# Convert from milliseconds to seconds and add cumulative offset
|
||||
start_time_seconds = cumulative_time + (char_start_times_ms[start_idx] / 1000.0)
|
||||
word_times.append((word, start_time_seconds))
|
||||
|
||||
return word_times
|
||||
|
||||
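A small worked example shows what the rewritten `calculate_word_times` is meant to produce. This is a compact sketch of the same idea (each word is stamped with the start time of its first character, and repeated spaces yield no empty words). The function name `word_start_times` and the sample alignment data are hypothetical.

```python
from typing import List, Tuple


def word_start_times(
    chars: List[str], start_ms: List[int], offset_s: float = 0.0
) -> List[Tuple[str, float]]:
    """Return (word, start time in seconds) pairs from per-character timings."""
    if len(chars) != len(start_ms):
        return []  # mismatched alignment data, nothing reliable to report
    result = []
    word, first_idx = "", None
    # A sentinel space at the end flushes the last word.
    for i, ch in enumerate(chars + [" "]):
        if ch == " ":
            if word:  # skip empty words produced by consecutive spaces
                result.append((word, offset_s + start_ms[first_idx] / 1000.0))
            word, first_idx = "", None
        else:
            if first_idx is None:
                first_idx = i  # first character of a new word
            word += ch
    return result


# "hi  yo" with a double space; the offset plays the role of cumulative_time.
timings = word_start_times(list("hi  yo"), [0, 100, 200, 250, 300, 400], offset_s=1.0)
```

Here `timings` holds two entries, `hi` at the offset itself and `yo` at the offset plus 0.3 seconds. The removed implementation split on single spaces and took the time of the character before each space, so its words and times could fall out of step.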
@@ -213,7 +241,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
             similarity_boost: Similarity boost control (0.0 to 1.0).
             style: Style control for voice expression (0.0 to 1.0).
             use_speaker_boost: Whether to use speaker boost enhancement.
-            speed: Voice speed control (0.25 to 4.0).
+            speed: Voice speed control (0.7 to 1.2).
             auto_mode: Whether to enable automatic mode optimization.
             enable_ssml_parsing: Whether to parse SSML tags in text.
             enable_logging: Whether to enable ElevenLabs logging.
@@ -421,7 +449,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):

     async def _connect_websocket(self):
         try:
-            if self._websocket and self._websocket.open:
+            if self._websocket and self._websocket.state is State.OPEN:
                 return

             logger.debug("Connecting to ElevenLabs")
@@ -448,8 +476,8 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
             )

             # Set max websocket message size to 16MB for large audio responses
-            self._websocket = await websockets.connect(
-                url, max_size=16 * 1024 * 1024, extra_headers={"xi-api-key": self._api_key}
+            self._websocket = await websocket_connect(
+                url, max_size=16 * 1024 * 1024, additional_headers={"xi-api-key": self._api_key}
             )

         except Exception as e:
@@ -520,8 +548,14 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
             # Check if this message belongs to the current context.
             # This should never happen, so warn about it.
             if not self.audio_context_available(received_ctx_id):
-                logger.warning(f"Ignoring message from unavailable context: {received_ctx_id}")
-                continue
+                if self._context_id == received_ctx_id:
+                    logger.debug(
+                        f"Received a delayed message, recreating the context: {self._context_id}"
+                    )
+                    await self.create_audio_context(self._context_id)
+                else:
+                    logger.warning(f"Ignoring message from unavailable context: {received_ctx_id}")
+                    continue

             if msg.get("audio"):
                 await self.stop_ttfb_metrics()
@@ -530,10 +564,29 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
                 audio = base64.b64decode(msg["audio"])
                 frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
                 await self.append_to_audio_context(received_ctx_id, frame)
+
             if msg.get("alignment"):
-                word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
-                await self.add_word_timestamps(word_times)
-                self._cumulative_time = word_times[-1][1]
+                alignment = msg["alignment"]
+                word_times = calculate_word_times(alignment, self._cumulative_time)
+
+                if word_times:
+                    await self.add_word_timestamps(word_times)
+
+                    # Calculate the actual end time of this audio chunk
+                    char_start_times_ms = alignment.get("charStartTimesMs", [])
+                    char_durations_ms = alignment.get("charDurationsMs", [])
+
+                    if char_start_times_ms and char_durations_ms:
+                        # End time = start time of last character + duration of last character
+                        chunk_end_time_ms = char_start_times_ms[-1] + char_durations_ms[-1]
+                        chunk_end_time_seconds = chunk_end_time_ms / 1000.0
+                        self._cumulative_time += chunk_end_time_seconds
+                    else:
+                        # Fallback: use the last word's start time (current behavior)
+                        self._cumulative_time = word_times[-1][1]
+                        logger.warning(
+                            "_receive_messages: using fallback timing method - consider investigating alignment data structure"
+                        )

     async def _keepalive_task_handler(self):
         """Send periodic keepalive messages to maintain WebSocket connection."""
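The hunk above changes how the running offset advances between audio chunks: from the start time of the last word to the end time of the last character. A minimal sketch of that accumulation, assuming each alignment dict carries `charStartTimesMs` and `charDurationsMs` as in the diff; the helper name `chunk_duration_seconds` is hypothetical, and unlike the diff's fallback branch it simply returns 0.0 when either list is missing.

```python
from typing import Dict, List


def chunk_duration_seconds(alignment: Dict[str, List[int]]) -> float:
    """Length of one chunk: start of the last character plus its duration, in seconds."""
    starts = alignment.get("charStartTimesMs", [])
    durations = alignment.get("charDurationsMs", [])
    if not starts or not durations:
        # Simplification for this sketch; the diff falls back to the last word's start.
        return 0.0
    return (starts[-1] + durations[-1]) / 1000.0


cumulative_time = 0.0
chunks = [
    {"charStartTimesMs": [0, 250], "charDurationsMs": [250, 250]},
    {"charStartTimesMs": [0, 500], "charDurationsMs": [500, 250]},
]
offsets = []
for alignment in chunks:
    offsets.append(cumulative_time)  # offset applied to this chunk's word times
    cumulative_time += chunk_duration_seconds(alignment)

print(offsets, cumulative_time)
# [0.0, 0.5] 1.25
```

Because per-chunk start times restart at zero, adding the full chunk length keeps the next chunk's words from overlapping the tail of the previous one.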
@@ -542,7 +595,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
             self.reset_watchdog()
             await asyncio.sleep(KEEPALIVE_SLEEP)
             try:
-                if self._websocket and self._websocket.open:
+                if self._websocket and self._websocket.state is State.OPEN:
                     if self._context_id:
                         # Send keepalive with context ID to keep the connection alive
                         keepalive_message = {
@@ -580,7 +633,7 @@ class ElevenLabsTTSService(AudioContextWordTTSService):
         logger.debug(f"{self}: Generating TTS [{text}]")

         try:
-            if not self._websocket or self._websocket.closed:
+            if not self._websocket or self._websocket.state is State.CLOSED:
                 await self._connect()

             try:
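The `.open` and `.closed` checks replaced throughout this diff map onto comparisons against a connection state enum. The sketch below illustrates the two predicates with a local stand-in: the `State` class here only mirrors the member names of `websockets.protocol.State`, and `FakeConnection`, `is_open`, and `needs_connect` are hypothetical names used so the example runs without the library installed.

```python
import enum
from typing import Optional


class State(enum.IntEnum):
    """Stand-in mirroring the member names of websockets.protocol.State."""

    CONNECTING = 0
    OPEN = 1
    CLOSING = 2
    CLOSED = 3


class FakeConnection:
    """Hypothetical connection object exposing only a `state` attribute."""

    def __init__(self, state: State):
        self.state = state


def is_open(ws: Optional[FakeConnection]) -> bool:
    # Counterpart of the legacy `ws and ws.open` check.
    return ws is not None and ws.state is State.OPEN


def needs_connect(ws: Optional[FakeConnection]) -> bool:
    # Counterpart of the legacy `not ws or ws.closed` check.
    return ws is None or ws.state is State.CLOSED


for state in State:
    ws = FakeConnection(state)
    print(state.name, is_open(ws), needs_connect(ws))
```

A connection in `CONNECTING` or `CLOSING` satisfies neither predicate, so the two checks are not simple negations of each other.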
@@ -34,7 +34,8 @@ from pipecat.utils.tracing.service_decorators import traced_tts

 try:
     import ormsgpack
-    import websockets
+    from websockets.asyncio.client import connect as websocket_connect
+    from websockets.protocol import State
 except ModuleNotFoundError as e:
     logger.error(f"Exception: {e}")
     logger.error("In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`.")
@@ -210,13 +211,13 @@ class FishAudioTTSService(InterruptibleTTSService):

     async def _connect_websocket(self):
         try:
-            if self._websocket and self._websocket.open:
+            if self._websocket and self._websocket.state is State.OPEN:
                 return

             logger.debug("Connecting to Fish Audio")
             headers = {"Authorization": f"Bearer {self._api_key}"}
             headers["model"] = self.model_name
-            self._websocket = await websockets.connect(self._base_url, extra_headers=headers)
+            self._websocket = await websocket_connect(self._base_url, additional_headers=headers)

             # Send initial start message with ormsgpack
             start_message = {"event": "start", "request": {"text": "", **self._settings}}
@@ -246,7 +247,7 @@ class FishAudioTTSService(InterruptibleTTSService):
     async def flush_audio(self):
         """Flush any buffered audio by sending a flush event to Fish Audio."""
         logger.trace(f"{self}: Flushing audio buffers")
-        if not self._websocket or self._websocket.closed:
+        if not self._websocket or self._websocket.state is State.CLOSED:
             return
         flush_message = {"event": "flush"}
         await self._get_websocket().send(ormsgpack.packb(flush_message))
@@ -292,7 +293,7 @@ class FishAudioTTSService(InterruptibleTTSService):
         """
         logger.debug(f"{self}: Generating Fish TTS: [{text}]")
         try:
-            if not self._websocket or self._websocket.closed:
+            if not self._websocket or self._websocket.state is State.CLOSED:
                 await self._connect()

             if not self._request_id:
@@ -248,6 +248,55 @@ class Config(BaseModel):
     setup: Setup


+#
+# Grounding metadata models
+#
+
+
+class SearchEntryPoint(BaseModel):
+    """Represents the search entry point with rendered content for search suggestions."""
+
+    renderedContent: Optional[str] = None
+
+
+class WebSource(BaseModel):
+    """Represents a web source from grounding chunks."""
+
+    uri: Optional[str] = None
+    title: Optional[str] = None
+
+
+class GroundingChunk(BaseModel):
+    """Represents a grounding chunk containing web source information."""
+
+    web: Optional[WebSource] = None
+
+
+class GroundingSegment(BaseModel):
+    """Represents a segment of text that is grounded."""
+
+    startIndex: Optional[int] = None
+    endIndex: Optional[int] = None
+    text: Optional[str] = None
+
+
+class GroundingSupport(BaseModel):
+    """Represents support information for grounded text segments."""
+
+    segment: Optional[GroundingSegment] = None
+    groundingChunkIndices: Optional[List[int]] = None
+    confidenceScores: Optional[List[float]] = None
+
+
+class GroundingMetadata(BaseModel):
+    """Represents grounding metadata from Google Search."""
+
+    searchEntryPoint: Optional[SearchEntryPoint] = None
+    groundingChunks: Optional[List[GroundingChunk]] = None
+    groundingSupports: Optional[List[GroundingSupport]] = None
+    webSearchQueries: Optional[List[str]] = None
+
+
 #
 # Server events
 #
@@ -339,6 +388,7 @@ class ServerContent(BaseModel):
     turnComplete: Optional[bool] = None
     inputTranscription: Optional[BidiGenerateContentTranscription] = None
     outputTranscription: Optional[BidiGenerateContentTranscription] = None
+    groundingMetadata: Optional[GroundingMetadata] = None


 class FunctionCall(BaseModel):
@@ -75,7 +75,7 @@ from . import events
 from .file_api import GeminiFileAPI

 try:
-    import websockets
+    from websockets.asyncio.client import connect as websocket_connect
 except ModuleNotFoundError as e:
     logger.error(f"Exception: {e}")
     logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
@@ -271,6 +271,7 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
                         parts.append({"text": part.get("text")})
                     elif part.get("type") == "file_data":
                         file_data = part.get("file_data", {})
+
                         parts.append(
                             {
                                 "fileData": {
@@ -572,6 +573,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
         # Initialize the File API client
         self.file_api = GeminiFileAPI(api_key=api_key, base_url=file_api_base_url)

+        # Grounding metadata tracking
+        self._search_result_buffer = ""
+        self._accumulated_grounding_metadata = None
+
     def can_generate_metrics(self) -> bool:
         """Check if the service can generate usage metrics.

@@ -786,7 +791,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
         try:
             logger.info(f"Connecting to wss://{self._base_url}")
             uri = f"wss://{self._base_url}?key={self._api_key}"
-            self._websocket = await websockets.connect(uri=uri)
+            self._websocket = await websocket_connect(uri=uri)
             self._receive_task = self.create_task(self._receive_task_handler())

             # Create the basic configuration
@@ -936,6 +941,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
                 await self._handle_evt_input_transcription(evt)
             elif evt.serverContent and evt.serverContent.outputTranscription:
                 await self._handle_evt_output_transcription(evt)
+            elif evt.serverContent and evt.serverContent.groundingMetadata:
+                await self._handle_evt_grounding_metadata(evt)
             elif evt.toolCall:
                 await self._handle_evt_tool_call(evt)
             elif False:  # !!! todo: error events?
@@ -1027,6 +1034,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
                         parts.append({"text": part.get("text")})
                     elif part.get("type") == "file_data":
                         file_data = part.get("file_data", {})
+
                         parts.append(
                             {
                                 "fileData": {
@@ -1107,8 +1115,13 @@ class GeminiMultimodalLiveLLMService(LLMService):
                 await self.push_frame(LLMFullResponseStartFrame())

             self._bot_text_buffer += text
+            self._search_result_buffer += text  # Also accumulate for grounding
             await self.push_frame(LLMTextFrame(text=text))

+        # Check for grounding metadata in server content
+        if evt.serverContent and evt.serverContent.groundingMetadata:
+            self._accumulated_grounding_metadata = evt.serverContent.groundingMetadata
+
         inline_data = part.inlineData
         if not inline_data:
             return
@@ -1176,6 +1189,16 @@ class GeminiMultimodalLiveLLMService(LLMService):
         self._bot_text_buffer = ""
         self._llm_output_buffer = ""

+        # Process grounding metadata if we have accumulated any
+        if self._accumulated_grounding_metadata:
+            await self._process_grounding_metadata(
+                self._accumulated_grounding_metadata, self._search_result_buffer
+            )
+
+        # Reset grounding tracking for next response
+        self._search_result_buffer = ""
+        self._accumulated_grounding_metadata = None
+
         # Only push the TTSStoppedFrame if the bot is outputting audio
         # when text is found, modalities is set to TEXT and no audio
         # is produced.
@@ -1252,12 +1275,74 @@ class GeminiMultimodalLiveLLMService(LLMService):
         if not text:
             return

+        # Accumulate text for grounding as well
+        self._search_result_buffer += text
+
+        # Check for grounding metadata in server content
+        if evt.serverContent and evt.serverContent.groundingMetadata:
+            self._accumulated_grounding_metadata = evt.serverContent.groundingMetadata
         # Collect text for tracing
         self._llm_output_buffer += text

         await self.push_frame(LLMTextFrame(text=text))
         await self.push_frame(TTSTextFrame(text=text))

+    async def _handle_evt_grounding_metadata(self, evt):
+        """Handle dedicated grounding metadata events."""
+        if evt.serverContent and evt.serverContent.groundingMetadata:
+            grounding_metadata = evt.serverContent.groundingMetadata
+            # Process the grounding metadata immediately
+            await self._process_grounding_metadata(grounding_metadata, self._search_result_buffer)
+
+    async def _process_grounding_metadata(
+        self, grounding_metadata: events.GroundingMetadata, search_result: str = ""
+    ):
+        """Process grounding metadata and emit LLMSearchResponseFrame."""
+        if not grounding_metadata:
+            return
+
+        # Extract rendered content for search suggestions
+        rendered_content = None
+        if (
+            grounding_metadata.searchEntryPoint
+            and grounding_metadata.searchEntryPoint.renderedContent
+        ):
+            rendered_content = grounding_metadata.searchEntryPoint.renderedContent
+
+        # Convert grounding chunks and supports to LLMSearchOrigin format
+        origins = []
+
+        if grounding_metadata.groundingChunks and grounding_metadata.groundingSupports:
+            # Create a mapping of chunk indices to origins
+            chunk_to_origin = {}
+
+            for index, chunk in enumerate(grounding_metadata.groundingChunks):
+                if chunk.web:
+                    origin = LLMSearchOrigin(
+                        site_uri=chunk.web.uri, site_title=chunk.web.title, results=[]
+                    )
+                    chunk_to_origin[index] = origin
+                    origins.append(origin)
+
+            # Add grounding support results to the appropriate origins
+            for support in grounding_metadata.groundingSupports:
+                if support.segment and support.groundingChunkIndices:
+                    text = support.segment.text or ""
+                    confidence_scores = support.confidenceScores or []
+
+                    # Add this result to all origins referenced by this support
+                    for chunk_index in support.groundingChunkIndices:
+                        if chunk_index in chunk_to_origin:
+                            result = LLMSearchResult(text=text, confidence=confidence_scores)
+                            chunk_to_origin[chunk_index].results.append(result)
+
+        # Create and push the search response frame
+        search_frame = LLMSearchResponseFrame(
+            search_result=search_result, origins=origins, rendered_content=rendered_content
+        )
+
+        await self.push_frame(search_frame)
+
     async def _handle_evt_usage_metadata(self, evt):
         """Handle the usage metadata event."""
         if not evt.usageMetadata:
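The chunk-to-origin mapping in `_process_grounding_metadata` can be tried outside the service. This is a minimal sketch under stated assumptions: `SearchResult` and `SearchOrigin` are hypothetical dataclass stand-ins for the `LLMSearchResult` and `LLMSearchOrigin` types used in the diff, the metadata is a plain dict rather than a pydantic model, and the example.com URIs are placeholders.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class SearchResult:
    text: str
    confidence: List[float] = field(default_factory=list)


@dataclass
class SearchOrigin:
    site_uri: Optional[str]
    site_title: Optional[str]
    results: List[SearchResult] = field(default_factory=list)


def build_origins(metadata: Dict[str, Any]) -> List[SearchOrigin]:
    """Attach each grounded segment to every web chunk it cites."""
    chunks = metadata.get("groundingChunks") or []
    supports = metadata.get("groundingSupports") or []
    if not chunks or not supports:
        return []

    # Index origins by chunk position so supports can refer to them by index.
    chunk_to_origin: Dict[int, SearchOrigin] = {}
    for index, chunk in enumerate(chunks):
        web = chunk.get("web")
        if web:
            chunk_to_origin[index] = SearchOrigin(web.get("uri"), web.get("title"))

    for support in supports:
        segment = support.get("segment")
        indices = support.get("groundingChunkIndices")
        if not segment or not indices:
            continue
        text = segment.get("text") or ""
        scores = support.get("confidenceScores") or []
        for chunk_index in indices:
            origin = chunk_to_origin.get(chunk_index)
            if origin is not None:
                origin.results.append(SearchResult(text, scores))

    # Dicts preserve insertion order, so origins come back in chunk order.
    return list(chunk_to_origin.values())


metadata = {
    "groundingChunks": [
        {"web": {"uri": "https://example.com/a", "title": "A"}},
        {"web": {"uri": "https://example.com/b", "title": "B"}},
    ],
    "groundingSupports": [
        {
            "segment": {"startIndex": 0, "endIndex": 5, "text": "Hello"},
            "groundingChunkIndices": [0, 1],
            "confidenceScores": [0.9, 0.7],
        }
    ],
}
origins = build_origins(metadata)
print([(o.site_title, [r.text for r in o.results]) for o in origins])
# [('A', ['Hello']), ('B', ['Hello'])]
```

As in the diff, every result keeps the support's full list of confidence scores rather than the single score for its own chunk.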
@@ -37,6 +37,8 @@ from pipecat.utils.tracing.service_decorators import traced_stt

 try:
     import websockets
+    from websockets.asyncio.client import connect as websocket_connect
+    from websockets.protocol import State
 except ModuleNotFoundError as e:
     logger.error(f"Exception: {e}")
     logger.error("In order to use Gladia, you need to `pip install pipecat-ai[gladia]`.")
@@ -402,7 +404,7 @@ class GladiaSTTService(STTService):
             logger.warning(f"Audio buffer exceeded max size, trimmed {trim_size} bytes")

         # Send audio if connected
-        if self._connection_active and self._websocket and not self._websocket.closed:
+        if self._connection_active and self._websocket and self._websocket.state is State.OPEN:
             try:
                 await self._send_audio(audio)
             except websockets.exceptions.ConnectionClosed as e:
@@ -423,7 +425,7 @@ class GladiaSTTService(STTService):
                 self._reconnection_attempts = 0

                 # Connect with automatic reconnection
-                async with websockets.connect(self._session_url) as websocket:
+                async with websocket_connect(self._session_url) as websocket:
                     try:
                         self._websocket = websocket
                         self._connection_active = True
@@ -507,7 +509,7 @@ class GladiaSTTService(STTService):

     async def _send_audio(self, audio: bytes):
         """Send audio chunk with proper message format."""
-        if self._websocket and not self._websocket.closed:
+        if self._websocket and self._websocket.state is State.OPEN:
             data = base64.b64encode(audio).decode("utf-8")
             message = {"type": "audio_chunk", "data": {"chunk": data}}
             await self._websocket.send(json.dumps(message))
@@ -520,7 +522,7 @@ class GladiaSTTService(STTService):
             await self._send_audio(bytes(self._audio_buffer))

     async def _send_stop_recording(self):
-        if self._websocket and not self._websocket.closed:
+        if self._websocket and self._websocket.state is State.OPEN:
             await self._websocket.send(json.dumps({"type": "stop_recording"}))

     async def _keepalive_task_handler(self):
@@ -531,7 +533,7 @@ class GladiaSTTService(STTService):
             self.reset_watchdog()
             # Send keepalive (Gladia times out after 30 seconds)
             await asyncio.sleep(KEEPALIVE_SLEEP)
-            if self._websocket and not self._websocket.closed:
+            if self._websocket and self._websocket.state is State.OPEN:
                 # Send an empty audio chunk as keepalive
                 empty_audio = b""
                 await self._send_audio(empty_audio)
@@ -627,9 +627,9 @@ class GoogleLLMContext(OpenAILLMContext):
         # Check if we only have function-related messages (no regular text)
         has_regular_messages = any(
             len(msg.parts) == 1
-            and not getattr(msg.parts[0], "text", None)
-            and getattr(msg.parts[0], "function_call", None)
-            and getattr(msg.parts[0], "function_response", None)
+            and getattr(msg.parts[0], "text", None)
+            and not getattr(msg.parts[0], "function_call", None)
+            and not getattr(msg.parts[0], "function_response", None)
             for msg in self._messages
         )

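The corrected predicate treats a message as "regular" only when its single part has text and carries neither a function call nor a function response. A minimal sketch of that check, using `SimpleNamespace` objects as hypothetical stand-ins for the real message and part types; the helper names are illustrative only.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def is_plain_text_message(msg: Any) -> bool:
    """True for a single-part message with text and no function call or response."""
    if len(msg.parts) != 1:
        return False
    part = msg.parts[0]
    return bool(
        getattr(part, "text", None)
        and not getattr(part, "function_call", None)
        and not getattr(part, "function_response", None)
    )


def has_regular_messages(messages: Iterable[Any]) -> bool:
    return any(is_plain_text_message(m) for m in messages)


text_msg = SimpleNamespace(parts=[SimpleNamespace(text="hello")])
call_msg = SimpleNamespace(parts=[SimpleNamespace(function_call={"name": "get_weather"})])

print(has_regular_messages([call_msg]))
# False
print(has_regular_messages([call_msg, text_msg]))
# True
```

The removed condition required a part with no text but both a function call and a function response, so it was false for plain text messages, which is the opposite of what the variable name promises.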
@@ -176,6 +176,7 @@ class LLMService(AIService):
         self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
         self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
         self._sequential_runner_task: Optional[asyncio.Task] = None
+        self._tracing_enabled: bool = False

         self._register_event_handler("on_function_calls_started")
         self._register_event_handler("on_completion_timeout")
@@ -218,6 +219,7 @@ class LLMService(AIService):
         await super().start(frame)
         if not self._run_in_parallel:
             await self._create_sequential_runner_task()
+        self._tracing_enabled = frame.enable_tracing

     async def stop(self, frame: EndFrame):
         """Stop the LLM service.
@@ -29,7 +29,8 @@ from pipecat.utils.tracing.service_decorators import traced_tts

 # See .env.example for LMNT configuration needed
 try:
-    import websockets
+    from websockets.asyncio.client import connect as websocket_connect
+    from websockets.protocol import State
 except ModuleNotFoundError as e:
     logger.error(f"Exception: {e}")
     logger.error("In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`.")
@@ -95,7 +96,7 @@ class LmntTTSService(InterruptibleTTSService):
         voice_id: str,
         sample_rate: Optional[int] = None,
         language: Language = Language.EN,
-        model: str = "aurora",
+        model: str = "blizzard",
         **kwargs,
     ):
         """Initialize the LMNT TTS service.
@@ -105,7 +106,7 @@ class LmntTTSService(InterruptibleTTSService):
             voice_id: ID of the voice to use for synthesis.
             sample_rate: Audio sample rate. If None, uses default.
             language: Language for synthesis. Defaults to English.
-            model: TTS model to use. Defaults to "aurora".
+            model: TTS model to use. Defaults to "blizzard".
             **kwargs: Additional arguments passed to parent InterruptibleTTSService.
         """
         super().__init__(
@@ -200,7 +201,7 @@ class LmntTTSService(InterruptibleTTSService):
     async def _connect_websocket(self):
         """Connect to LMNT websocket."""
         try:
-            if self._websocket and self._websocket.open:
+            if self._websocket and self._websocket.state is State.OPEN:
                 return

             logger.debug("Connecting to LMNT")
@@ -216,7 +217,7 @@ class LmntTTSService(InterruptibleTTSService):
             }

             # Connect to LMNT's websocket directly
-            self._websocket = await websockets.connect("wss://api.lmnt.com/v1/ai/speech/stream")
+            self._websocket = await websocket_connect("wss://api.lmnt.com/v1/ai/speech/stream")

             # Send initialization message
             await self._websocket.send(json.dumps(init_msg))
@@ -251,7 +252,7 @@ class LmntTTSService(InterruptibleTTSService):

     async def flush_audio(self):
         """Flush any pending audio synthesis."""
-        if not self._websocket or self._websocket.closed:
+        if not self._websocket or self._websocket.state is State.CLOSED:
             return
         await self._get_websocket().send(json.dumps({"flush": True}))

@@ -292,7 +293,7 @@ class LmntTTSService(InterruptibleTTSService):
         logger.debug(f"{self}: Generating TTS [{text}]")

         try:
-            if not self._websocket or self._websocket.closed:
+            if not self._websocket or self._websocket.state is State.CLOSED:
                 await self._connect()

             try:
@@ -13,6 +13,7 @@ from loguru import logger

 from pipecat.adapters.schemas.function_schema import FunctionSchema
 from pipecat.adapters.schemas.tools_schema import ToolsSchema
+from pipecat.services.llm_service import FunctionCallParams
 from pipecat.utils.base_object import BaseObject

 try:
@@ -165,27 +166,24 @@ class MCPClient(BaseObject):
             A ToolsSchema containing all registered tools
         """

-        async def mcp_tool_wrapper(
-            function_name: str,
-            tool_call_id: str,
-            arguments: Dict[str, Any],
-            llm: any,
-            context: any,
-            result_callback: any,
-        ) -> None:
+        async def mcp_tool_wrapper(params: FunctionCallParams) -> None:
             """Wrapper for mcp tool calls to match Pipecat's function call interface."""
-            logger.debug(f"Executing tool '{function_name}' with call ID: {tool_call_id}")
-            logger.trace(f"Tool arguments: {json.dumps(arguments, indent=2)}")
+            logger.debug(
+                f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}"
+            )
+            logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
             try:
                 async with self._client(**self._server_params.model_dump()) as (read, write):
                     async with self._session(read, write) as session:
                         await session.initialize()
-                        await self._call_tool(session, function_name, arguments, result_callback)
+                        await self._call_tool(
+                            session, params.function_name, params.arguments, params.result_callback
+                        )
             except Exception as e:
-                error_msg = f"Error calling mcp tool {function_name}: {str(e)}"
+                error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
                 logger.error(error_msg)
                 logger.exception("Full exception details:")
-                await result_callback(error_msg)
+                await params.result_callback(error_msg)

         logger.debug(f"SSE server parameters: {self._server_params}")
         logger.debug("Starting registration of mcp tools")
@@ -205,27 +203,24 @@ class MCPClient(BaseObject):
             A ToolsSchema containing all registered tools
         """

-        async def mcp_tool_wrapper(
-            function_name: str,
-            tool_call_id: str,
-            arguments: Dict[str, Any],
-            llm: any,
-            context: any,
-            result_callback: any,
-        ) -> None:
+        async def mcp_tool_wrapper(params: FunctionCallParams) -> None:
             """Wrapper for mcp tool calls to match Pipecat's function call interface."""
-            logger.debug(f"Executing tool '{function_name}' with call ID: {tool_call_id}")
-            logger.trace(f"Tool arguments: {json.dumps(arguments, indent=2)}")
+            logger.debug(
+                f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}"
+            )
+            logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
             try:
                 async with self._client(self._server_params) as streams:
                     async with self._session(streams[0], streams[1]) as session:
                         await session.initialize()
-                        await self._call_tool(session, function_name, arguments, result_callback)
+                        await self._call_tool(
+                            session, params.function_name, params.arguments, params.result_callback
+                        )
             except Exception as e:
-                error_msg = f"Error calling mcp tool {function_name}: {str(e)}"
+                error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
                 logger.error(error_msg)
                 logger.exception("Full exception details:")
-                await result_callback(error_msg)
+                await params.result_callback(error_msg)

         logger.debug("Starting registration of mcp tools")

@@ -244,17 +239,12 @@ class MCPClient(BaseObject):
             A ToolsSchema containing all registered tools
         """

-        async def mcp_tool_wrapper(
-            function_name: str,
-            tool_call_id: str,
-            arguments: Dict[str, Any],
-            llm: any,
-            context: any,
-            result_callback: any,
-        ) -> None:
+        async def mcp_tool_wrapper(params: FunctionCallParams) -> None:
             """Wrapper for mcp tool calls to match Pipecat's function call interface."""
-            logger.debug(f"Executing tool '{function_name}' with call ID: {tool_call_id}")
-            logger.trace(f"Tool arguments: {json.dumps(arguments, indent=2)}")
+            logger.debug(
+                f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}"
+            )
+            logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
             try:
                 async with self._client(**self._server_params.model_dump()) as (
                     read_stream,
@@ -263,12 +253,14 @@ class MCPClient(BaseObject):
                 ):
                     async with self._session(read_stream, write_stream) as session:
                         await session.initialize()
-                        await self._call_tool(session, function_name, arguments, result_callback)
+                        await self._call_tool(
+                            session, params.function_name, params.arguments, params.result_callback
+                        )
             except Exception as e:
-                error_msg = f"Error calling mcp tool {function_name}: {str(e)}"
+                error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
                 logger.error(error_msg)
                 logger.exception("Full exception details:")
-                await result_callback(error_msg)
+                await params.result_callback(error_msg)

         logger.debug("Starting registration of mcp tools using streamable HTTP")

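The three wrappers above move from six positional parameters to a single params object. A minimal sketch of that calling convention, assuming nothing about Pipecat beyond the four attributes the diff reads: `CallParams` is a hypothetical stand-in for `FunctionCallParams`, and the "echo" tool is invented for illustration.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List


@dataclass
class CallParams:
    """Hypothetical stand-in holding only the fields the wrappers in this diff use."""

    function_name: str
    tool_call_id: str
    arguments: Dict[str, Any]
    result_callback: Callable[[Any], Awaitable[None]]


async def tool_wrapper(params: CallParams) -> None:
    """Run a tool and report the result, or the error message, via the callback."""
    try:
        if params.function_name != "echo":
            raise ValueError("unknown tool")
        await params.result_callback(params.arguments["text"])
    except Exception as e:
        # Errors are reported through the same callback instead of being raised.
        await params.result_callback(f"Error calling mcp tool {params.function_name}: {e}")


async def main() -> List[Any]:
    collected: List[Any] = []

    async def collect(result: Any) -> None:
        collected.append(result)

    await tool_wrapper(CallParams("echo", "call-1", {"text": "hi"}, collect))
    await tool_wrapper(CallParams("missing", "call-2", {}, collect))
    return collected


results = asyncio.run(main())
print(results)
# ['hi', 'Error calling mcp tool missing: unknown tool']
```

Bundling the arguments means a wrapper can ignore fields it does not need, which is why the unused `llm` and `context` parameters disappear from the signatures.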
@@ -69,6 +69,7 @@ class Mem0MemoryService(FrameProcessor):
         agent_id: Optional[str] = None,
         run_id: Optional[str] = None,
         params: Optional[InputParams] = None,
+        host: Optional[str] = None,
     ):
         """Initialize the Mem0 memory service.

@@ -79,6 +80,7 @@ class Mem0MemoryService(FrameProcessor):
             agent_id: The agent ID to associate with memories in Mem0.
             run_id: The run ID to associate with memories in Mem0.
             params: Configuration parameters for memory retrieval and storage.
+            host: The host of the Mem0 server.

         Raises:
             ValueError: If none of user_id, agent_id, or run_id are provided.
@@ -92,7 +94,7 @@ class Mem0MemoryService(FrameProcessor):
         if local_config:
             self.memory_client = Memory.from_config(local_config)
         else:
-            self.memory_client = MemoryClient(api_key=api_key)
+            self.memory_client = MemoryClient(api_key=api_key, host=host)
         # At least one of user_id, agent_id, or run_id must be provided
         if not any([user_id, agent_id, run_id]):
             raise ValueError("At least one of user_id, agent_id, or run_id must be provided")
@@ -109,7 +109,7 @@ class MiniMaxHttpTTSService(TTSService):
         language: Optional[Language] = Language.EN
         speed: Optional[float] = 1.0
         volume: Optional[float] = 1.0
-        pitch: Optional[float] = 0
+        pitch: Optional[int] = 0
         emotion: Optional[str] = None
         english_normalization: Optional[bool] = None

@@ -117,6 +117,7 @@ class MiniMaxHttpTTSService(TTSService):
         self,
         *,
         api_key: str,
+        base_url: str = "https://api.minimax.io/v1/t2a_v2",
         group_id: str,
         model: str = "speech-02-turbo",
         voice_id: str = "Calm_Woman",
@@ -129,6 +130,9 @@ class MiniMaxHttpTTSService(TTSService):

         Args:
             api_key: MiniMax API key for authentication.
+            base_url: API base URL, defaults to MiniMax's T2A endpoint.
+                Global: https://api.minimax.io/v1/t2a_v2
+                Mainland China: https://api.minimaxi.chat/v1/t2a_v2
             group_id: MiniMax Group ID to identify project.
             model: TTS model name. Defaults to "speech-02-turbo". Options include
                 "speech-02-hd", "speech-02-turbo", "speech-01-hd", "speech-01-turbo".
@@ -144,7 +148,7 @@ class MiniMaxHttpTTSService(TTSService):

         self._api_key = api_key
         self._group_id = group_id
-        self._base_url = f"https://api.minimaxi.chat/v1/t2a_v2?GroupId={group_id}"
+        self._base_url = f"{base_url}?GroupId={group_id}"
         self._session = aiohttp_session
         self._model_name = model
         self._voice_id = voice_id
@@ -15,6 +15,7 @@ import base64
 import json
 from typing import Any, AsyncGenerator, Mapping, Optional

+import aiohttp
 from loguru import logger
 from pydantic import BaseModel

@@ -39,8 +40,8 @@ from pipecat.utils.asyncio.watchdog_async_iterator import WatchdogAsyncIterator
 from pipecat.utils.tracing.service_decorators import traced_tts

 try:
-    import websockets
-    from pyneuphonic import Neuphonic, TTSConfig
+    from websockets.asyncio.client import connect as websocket_connect
+    from websockets.protocol import State
 except ModuleNotFoundError as e:
     logger.error(f"Exception: {e}")
     logger.error("In order to use Neuphonic, you need to `pip install pipecat-ai[neuphonic]`.")
@@ -271,7 +272,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
     async def _connect_websocket(self):
         """Establish WebSocket connection to Neuphonic API."""
         try:
-            if self._websocket and self._websocket.open:
+            if self._websocket and self._websocket.state is State.OPEN:
                 return

             logger.debug("Connecting to Neuphonic")
@@ -292,7 +293,7 @@ class NeuphonicTTSService(InterruptibleTTSService):

             headers = {"x-api-key": self._api_key}

-            self._websocket = await websockets.connect(url, extra_headers=headers)
+            self._websocket = await websocket_connect(url, additional_headers=headers)
         except Exception as e:
             logger.error(f"{self} initialization error: {e}")
             self._websocket = None
@@ -359,7 +360,7 @@ class NeuphonicTTSService(InterruptibleTTSService):
         logger.debug(f"Generating TTS: [{text}]")

         try:
-            if not self._websocket or self._websocket.closed:
+            if not self._websocket or self._websocket.state is State.CLOSED:
                 await self._connect()

             try:
@@ -406,9 +407,10 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: Optional[str] = None,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
url: str = "https://api.neuphonic.com",
|
||||
sample_rate: Optional[int] = 22050,
|
||||
encoding: str = "pcm_linear",
|
||||
encoding: Optional[str] = "pcm_linear",
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -417,6 +419,7 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
Args:
|
||||
api_key: Neuphonic API key for authentication.
|
||||
voice_id: ID of the voice to use for synthesis.
|
||||
aiohttp_session: Shared aiohttp session for HTTP requests.
|
||||
url: Base URL for the Neuphonic HTTP API.
|
||||
sample_rate: Audio sample rate in Hz. Defaults to 22050.
|
||||
encoding: Audio encoding format. Defaults to "pcm_linear".
|
||||
@@ -428,13 +431,11 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
params = params or NeuphonicHttpTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._settings = {
|
||||
"lang_code": self.language_to_service_language(params.language),
|
||||
"speed": params.speed,
|
||||
"encoding": encoding,
|
||||
"sampling_rate": sample_rate,
|
||||
}
|
||||
self._session = aiohttp_session
|
||||
self._base_url = url.rstrip("/")
|
||||
self._lang_code = self.language_to_service_language(params.language) or "en"
|
||||
self._speed = params.speed
|
||||
self._encoding = encoding
|
||||
self.set_voice(voice_id)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
@@ -472,6 +473,40 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _parse_sse_message(self, message: str) -> dict | None:
|
||||
"""Parse a Server-Sent Event message.
|
||||
|
||||
Args:
|
||||
message: The SSE message to parse.
|
||||
|
||||
Returns:
|
||||
Parsed message dictionary or None if not a data message.
|
||||
"""
|
||||
message = message.strip()
|
||||
|
||||
if not message or "data" not in message:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Split on ": " and take the part after "data: "
|
||||
_, data_content = message.split(": ", 1)
|
||||
|
||||
if not data_content or data_content == "[DONE]":
|
||||
return None
|
||||
|
||||
message_dict = json.loads(data_content)
|
||||
|
||||
# Check for errors in the response
|
||||
if message_dict.get("errors") is not None:
|
||||
raise Exception(
|
||||
f"Neuphonic API error {message_dict.get('status_code', 'unknown')}: {message_dict['errors']}"
|
||||
)
|
||||
|
||||
return message_dict
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
logger.warning(f"Failed to parse SSE message: {e}")
|
||||
return None
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using Neuphonic streaming API.
|
||||
@@ -484,26 +519,71 @@ class NeuphonicHttpTTSService(TTSService):
|
||||
"""
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
client = Neuphonic(api_key=self._api_key, base_url=self._url.replace("https://", ""))
|
||||
url = f"{self._base_url}/sse/speak/{self._lang_code}"
|
||||
|
||||
sse = client.tts.AsyncSSEClient()
|
||||
headers = {
|
||||
"X-API-KEY": self._api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"lang_code": self._lang_code,
|
||||
"encoding": self._encoding,
|
||||
"sampling_rate": self.sample_rate,
|
||||
"speed": self._speed,
|
||||
}
|
||||
|
||||
if self._voice_id:
|
||||
payload["voice_id"] = self._voice_id
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
response = sse.send(text, TTSConfig(**self._settings, voice_id=self._voice_id))
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame()
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
error_message = f"Neuphonic API error: HTTP {response.status} - {error_text}"
|
||||
logger.error(error_message)
|
||||
yield ErrorFrame(error=error_message)
|
||||
return
|
||||
|
||||
async for message in response:
|
||||
if message.status_code != 200:
|
||||
logger.error(f"{self} error: {message.errors}")
|
||||
yield ErrorFrame(error=f"Neuphonic API error: {message.errors}")
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame()
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(message.data.audio, self.sample_rate, 1)
|
||||
# Process SSE stream line by line
|
||||
async for line in response.content:
|
||||
if not line:
|
||||
continue
|
||||
|
||||
message = line.decode("utf-8", errors="ignore")
|
||||
if not message.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
parsed_message = self._parse_sse_message(message)
|
||||
|
||||
if (
|
||||
parsed_message is not None
|
||||
and parsed_message.get("data", {}).get("audio") is not None
|
||||
):
|
||||
audio_b64 = parsed_message["data"]["audio"]
|
||||
audio_bytes = base64.b64decode(audio_b64)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(audio_bytes, self.sample_rate, 1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SSE message: {e}")
|
||||
# Don't yield error frame for individual message failures
|
||||
continue
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("TTS generation cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=str(e))
|
||||
logger.exception(f"Error in run_tts: {e}")
|
||||
yield ErrorFrame(error=f"Neuphonic TTS error: {str(e)}")
|
||||
finally:
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
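The `_parse_sse_message` helper added in this hunk pulls a JSON payload out of each `data:` line and the caller base64-decodes its `data.audio` field. The following is a minimal standalone sketch of that flow, not the service's actual code. The function names `parse_sse_data_line` and `extract_audio` are hypothetical, and the payload shape `{"data": {"audio": <base64>}}` is assumed from how the diff reads it. Unlike the diff, the sketch checks for an explicit `data:` prefix rather than testing whether the substring `data` occurs anywhere in the line.

```python
import base64
import json
from typing import Optional


def parse_sse_data_line(line: str) -> Optional[dict]:
    """Return the JSON object carried by an SSE `data:` line, else None."""
    line = line.strip()
    # Only `data:` fields carry payloads; skip comments, events, blank lines.
    if not line.startswith("data:"):
        return None
    content = line[len("data:"):].strip()
    # An empty field or the "[DONE]" sentinel carries no audio.
    if not content or content == "[DONE]":
        return None
    try:
        payload = json.loads(content)
    except json.JSONDecodeError:
        return None
    return payload if isinstance(payload, dict) else None


def extract_audio(payload: Optional[dict]) -> Optional[bytes]:
    """Decode the base64 audio chunk from a parsed payload, if present."""
    if payload is None:
        return None
    data = payload.get("data")
    if not isinstance(data, dict):
        return None
    audio_b64 = data.get("audio")
    if audio_b64 is None:
        return None
    return base64.b64decode(audio_b64)
```

Keeping parsing and decoding as separate pure functions makes each step easy to test without an HTTP session.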
@@ -42,4 +42,4 @@ class OLLamaLLMService(OpenAILLMService):
|
||||
An OpenAI-compatible client configured for Ollama.
|
||||
"""
|
||||
logger.debug(f"Creating Ollama client with api {base_url}")
|
||||
return super().create_client(base_url, **kwargs)
|
||||
return super().create_client(base_url=base_url, **kwargs)
|
||||
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
from .openai import OpenAIRealtimeBetaLLMService
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -55,9 +55,9 @@ class AzureRealtimeBetaLLMService(OpenAIRealtimeBetaLLMService):
|
||||
return
|
||||
|
||||
logger.info(f"Connecting to {self.base_url}, api key: {self.api_key}")
|
||||
self._websocket = await websockets.connect(
|
||||
self._websocket = await websocket_connect(
|
||||
uri=self.base_url,
|
||||
extra_headers={
|
||||
additional_headers={
|
||||
"api-key": self.api_key,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -66,7 +66,7 @@ from .context import (
|
||||
from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use OpenAI, you need to `pip install pipecat-ai[openai]`.")
|
||||
@@ -387,9 +387,9 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
# Here we assume that if we have a websocket, we are connected. We
|
||||
# handle disconnections in the send/recv code paths.
|
||||
return
|
||||
self._websocket = await websockets.connect(
|
||||
self._websocket = await websocket_connect(
|
||||
uri=self.base_url,
|
||||
extra_headers={
|
||||
additional_headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
},
|
||||
|
||||
@@ -17,7 +17,6 @@ import uuid
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -41,6 +40,8 @@ try:
|
||||
from pyht.async_client import AsyncClient
|
||||
from pyht.client import Format, TTSOptions
|
||||
from pyht.client import Language as PlayHTLanguage
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use PlayHT, you need to `pip install pipecat-ai[playht]`.")
|
||||
@@ -244,7 +245,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to PlayHT websocket."""
|
||||
try:
|
||||
if self._websocket and self._websocket.open:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
logger.debug("Connecting to PlayHT")
|
||||
@@ -255,7 +256,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
if not isinstance(self._websocket_url, str):
|
||||
raise ValueError("WebSocket URL is not a string")
|
||||
|
||||
self._websocket = await websockets.connect(self._websocket_url)
|
||||
self._websocket = await websocket_connect(self._websocket_url)
|
||||
except ValueError as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -362,7 +363,7 @@ class PlayHTTTSService(InterruptibleTTSService):
|
||||
|
||||
try:
|
||||
# Reconnect if the websocket is closed
|
||||
if not self._websocket or self._websocket.closed:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
if not self._request_id:
|
||||
|
||||
@@ -39,7 +39,8 @@ from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Rime, you need to `pip install pipecat-ai[rime]`.")
|
||||
@@ -238,13 +239,13 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to Rime websocket API with configured settings."""
|
||||
try:
|
||||
if self._websocket and self._websocket.open:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
params = "&".join(f"{k}={v}" for k, v in self._settings.items())
|
||||
url = f"{self._url}?{params}"
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
self._websocket = await websockets.connect(url, extra_headers=headers)
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._websocket = None
|
||||
@@ -380,7 +381,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
"""
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
try:
|
||||
if not self._websocket or self._websocket.closed:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
await self._connect()
|
||||
|
||||
try:
|
||||
|
||||
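The hunks above replace the legacy `.open` and `.closed` properties with identity checks against `websockets.protocol.State`. The sketch below illustrates that pattern. So that it runs without the `websockets` package or a network, it uses a local stand-in enum `ConnState` and a `FakeConnection` class; both are hypothetical and only mirror the four state names (`CONNECTING`, `OPEN`, `CLOSING`, `CLOSED`). The helper names `is_usable` and `needs_reconnect` are also illustrative.

```python
import enum


class ConnState(enum.Enum):
    """Stand-in for websockets.protocol.State (assumption: same member names)."""
    CONNECTING = enum.auto()
    OPEN = enum.auto()
    CLOSING = enum.auto()
    CLOSED = enum.auto()


class FakeConnection:
    """Hypothetical connection object exposing only a `state` attribute."""

    def __init__(self, state: ConnState):
        self.state = state


def is_usable(ws) -> bool:
    # Mirrors `self._websocket and self._websocket.state is State.OPEN`.
    return ws is not None and ws.state is ConnState.OPEN


def needs_reconnect(ws) -> bool:
    # Mirrors `not self._websocket or self._websocket.state is State.CLOSED`.
    return ws is None or ws.state is ConnState.CLOSED
```

Note that a connection in the `CLOSING` state is neither usable nor flagged for reconnect by these two checks, so callers may want to treat that state explicitly.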
0
src/pipecat/services/soniox/__init__.py
Normal file
398
src/pipecat/services/soniox/stt.py
Normal file
@@ -0,0 +1,398 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Soniox speech-to-text service implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Soniox, you need to `pip install pipecat-ai[soniox]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
KEEPALIVE_MESSAGE = '{"type": "keepalive"}'
|
||||
|
||||
FINALIZE_MESSAGE = '{"type": "finalize"}'
|
||||
|
||||
END_TOKEN = "<end>"
|
||||
|
||||
FINALIZED_TOKEN = "<fin>"
|
||||
|
||||
|
||||
class SonioxInputParams(BaseModel):
|
||||
"""Real-time transcription settings.
|
||||
|
||||
See Soniox WebSocket API documentation for more details:
|
||||
https://soniox.com/docs/speech-to-text/api-reference/websocket-api#configuration-parameters
|
||||
|
||||
Parameters:
|
||||
model: Model to use for transcription.
|
||||
audio_format: Audio format to use for transcription.
|
||||
num_channels: Number of channels to use for transcription.
|
||||
language_hints: List of language hints to use for transcription.
|
||||
context: Customization for transcription.
|
||||
enable_non_final_tokens: Whether to enable non-final tokens. If false, only final tokens will be returned.
|
||||
max_non_final_tokens_duration_ms: Maximum duration of non-final tokens.
|
||||
client_reference_id: Client reference ID to use for transcription.
|
||||
"""
|
||||
|
||||
model: str = "stt-rt-preview"
|
||||
|
||||
audio_format: Optional[str] = "pcm_s16le"
|
||||
num_channels: Optional[int] = 1
|
||||
|
||||
language_hints: Optional[List[Language]] = None
|
||||
context: Optional[str] = None
|
||||
|
||||
enable_non_final_tokens: Optional[bool] = True
|
||||
max_non_final_tokens_duration_ms: Optional[int] = None
|
||||
|
||||
client_reference_id: Optional[str] = None
|
||||
|
||||
|
||||
def is_end_token(token: dict) -> bool:
|
||||
"""Determine if a token is an end token."""
|
||||
return token["text"] == END_TOKEN or token["text"] == FINALIZED_TOKEN
|
||||
|
||||
|
||||
def language_to_soniox_language(language: Language) -> str:
|
||||
"""Pipecat Language enum uses same ISO 2-letter codes as Soniox, except with added regional variants.
|
||||
|
||||
For a list of all supported languages, see: https://soniox.com/docs/speech-to-text/core-concepts/supported-languages
|
||||
"""
|
||||
lang_str = str(language.value).lower()
|
||||
if "-" in lang_str:
|
||||
return lang_str.split("-")[0]
|
||||
return lang_str
|
||||
|
||||
|
||||
def _prepare_language_hints(
|
||||
language_hints: Optional[List[Language]],
|
||||
) -> Optional[List[str]]:
|
||||
if language_hints is None:
|
||||
return None
|
||||
|
||||
prepared_languages = [language_to_soniox_language(lang) for lang in language_hints]
|
||||
# Remove duplicates (in case of language_hints with multiple regions).
|
||||
return list(set(prepared_languages))
|
||||
|
||||
|
||||
class SonioxSTTService(STTService):
|
||||
"""Speech-to-Text service using Soniox's WebSocket API.
|
||||
|
||||
This service connects to Soniox's WebSocket API for real-time transcription
|
||||
with support for multiple languages, custom context, speaker diarization,
|
||||
and more.
|
||||
|
||||
For complete API documentation, see: https://soniox.com/docs/speech-to-text/api-reference/websocket-api
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
url: str = "wss://stt-rt.soniox.com/transcribe-websocket",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[SonioxInputParams] = None,
|
||||
vad_force_turn_endpoint: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Soniox STT service.
|
||||
|
||||
Args:
|
||||
api_key: Soniox API key.
|
||||
url: Soniox WebSocket API URL.
|
||||
sample_rate: Audio sample rate.
|
||||
params: Additional configuration parameters, such as language hints, context and
|
||||
speaker diarization.
|
||||
vad_force_turn_endpoint: If True, send a finalize message to Soniox when a `UserStoppedSpeakingFrame` arrives. If False, Soniox's own endpoint detection decides when speech has ended.
|
||||
**kwargs: Additional arguments passed to the STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
params = params or SonioxInputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self.set_model_name(params.model)
|
||||
self._params = params
|
||||
self._vad_force_turn_endpoint = vad_force_turn_endpoint
|
||||
self._websocket = None
|
||||
|
||||
self._final_transcription_buffer = []
|
||||
self._last_tokens_received: Optional[float] = None
|
||||
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Soniox STT websocket connection.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
if self._websocket:
|
||||
return
|
||||
|
||||
self._websocket = await websocket_connect(self._url)
|
||||
|
||||
if not self._websocket:
|
||||
logger.error(f"Unable to connect to Soniox API at {self._url}")
|
||||
|
||||
# If vad_force_turn_endpoint is not enabled, we need to enable endpoint detection.
|
||||
# Either one or the other is required.
|
||||
enable_endpoint_detection = not self._vad_force_turn_endpoint
|
||||
|
||||
# Send the initial configuration message.
|
||||
config = {
|
||||
"api_key": self._api_key,
|
||||
"model": self._model_name,
|
||||
"audio_format": self._params.audio_format,
|
||||
"num_channels": self._params.num_channels or 1,
|
||||
"enable_endpoint_detection": enable_endpoint_detection,
|
||||
"sample_rate": self.sample_rate,
|
||||
"language_hints": _prepare_language_hints(self._params.language_hints),
|
||||
"context": self._params.context,
|
||||
"enable_non_final_tokens": self._params.enable_non_final_tokens,
|
||||
"max_non_final_tokens_duration_ms": self._params.max_non_final_tokens_duration_ms,
|
||||
"client_reference_id": self._params.client_reference_id,
|
||||
}
|
||||
|
||||
# Send the configuration message.
|
||||
await self._websocket.send(json.dumps(config))
|
||||
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
if self._websocket and not self._keepalive_task:
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
|
||||
async def _cleanup(self):
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
# A task cannot wait on itself. If the receive task called _cleanup(), it is expected to return on its own.
|
||||
if self._receive_task != asyncio.current_task():
|
||||
await self.wait_for_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Soniox STT websocket connection.
|
||||
|
||||
Stopping waits for the server to close the connection as we might receive
|
||||
additional final tokens after sending the stop recording message.
|
||||
|
||||
Args:
|
||||
frame: The end frame.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._send_stop_recording()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Soniox STT websocket connection.
|
||||
|
||||
Compared to stop, this method closes the connection immediately without waiting
|
||||
for the server to close it. This is useful when we want to stop the connection
|
||||
immediately without waiting for the server to send any final tokens.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._cleanup()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to Soniox STT Service.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
Frame: None (transcription results come via WebSocket callbacks).
|
||||
"""
|
||||
await self.start_processing_metrics()
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(audio)
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
yield None
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Processes a frame of audio data, either buffering or transcribing it.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame processing.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserStoppedSpeakingFrame) and self._vad_force_turn_endpoint:
|
||||
# Send finalize message to Soniox so we get the final tokens asap.
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(FINALIZE_MESSAGE)
|
||||
logger.debug(f"Triggered finalize event on: {frame.name=}, {direction=}")
|
||||
|
||||
async def _send_stop_recording(self):
|
||||
"""Send stop recording message to Soniox."""
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
# Send stop recording message
|
||||
await self._websocket.send("")
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Connection has to be open all the time."""
|
||||
try:
|
||||
while True:
|
||||
logger.debug("Sending keepalive message")
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
await self._websocket.send(KEEPALIVE_MESSAGE)
|
||||
else:
|
||||
logger.debug("WebSocket connection closed.")
|
||||
break
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
# Expected when closing the connection
|
||||
logger.debug("WebSocket connection closed, keepalive task stopped.")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error (_keepalive_task_handler): {e}")
|
||||
await self.push_error(ErrorFrame(f"{self} error (_keepalive_task_handler): {e}"))
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
if not self._websocket:
|
||||
return
|
||||
|
||||
# Transcription frame will be only sent after we get the "endpoint" event.
|
||||
self._final_transcription_buffer = []
|
||||
|
||||
async def send_endpoint_transcript():
|
||||
if self._final_transcription_buffer:
|
||||
text = "".join(map(lambda token: token["text"], self._final_transcription_buffer))
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
text=text,
|
||||
user_id=self._user_id,
|
||||
timestamp=time_now_iso8601(),
|
||||
result=self._final_transcription_buffer,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(text, is_final=True)
|
||||
await self.stop_processing_metrics()
|
||||
self._final_transcription_buffer = []
|
||||
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
content = json.loads(message)
|
||||
|
||||
tokens = content["tokens"]
|
||||
|
||||
if tokens:
|
||||
if len(tokens) == 1 and tokens[0]["text"] == FINALIZED_TOKEN:
|
||||
# Ignore finalized token, prevent auto-finalize cycling.
|
||||
pass
|
||||
else:
|
||||
# Got at least one token, so we can reset the auto finalize delay.
|
||||
self._last_tokens_received = time.time()
|
||||
|
||||
# We will only send the final tokens after we get the "endpoint" event.
|
||||
non_final_transcription = []
|
||||
|
||||
for token in tokens:
|
||||
if token["is_final"]:
|
||||
if is_end_token(token):
|
||||
# Found an endpoint, tokens until here will be sent as transcript,
|
||||
# the rest will be sent as interim tokens (even final tokens).
|
||||
await send_endpoint_transcript()
|
||||
else:
|
||||
self._final_transcription_buffer.append(token)
|
||||
else:
|
||||
non_final_transcription.append(token)
|
||||
|
||||
if self._final_transcription_buffer or non_final_transcription:
|
||||
final_text = "".join(
|
||||
map(lambda token: token["text"], self._final_transcription_buffer)
|
||||
)
|
||||
non_final_text = "".join(
|
||||
map(lambda token: token["text"], non_final_transcription)
|
||||
)
|
||||
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
# Even final tokens are sent as interim tokens as we want to send
|
||||
# nicely formatted messages - therefore waiting for the endpoint.
|
||||
text=final_text + non_final_text,
|
||||
user_id=self._user_id,
|
||||
timestamp=time_now_iso8601(),
|
||||
result=self._final_transcription_buffer + non_final_transcription,
|
||||
)
|
||||
)
|
||||
|
||||
error_code = content.get("error_code")
|
||||
error_message = content.get("error_message")
|
||||
if error_code or error_message:
|
||||
# In case of error, still send the final transcript (if any remaining in the buffer).
|
||||
await send_endpoint_transcript()
|
||||
logger.error(
|
||||
f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
|
||||
)
|
||||
await self.push_error(
|
||||
ErrorFrame(
|
||||
f"{self} error: {error_code} (_receive_task_handler) - {error_message}"
|
||||
)
|
||||
)
|
||||
|
||||
finished = content.get("finished")
|
||||
if finished:
|
||||
# When finished, still send the final transcript (if any remaining in the buffer).
|
||||
await send_endpoint_transcript()
|
||||
logger.debug("Transcription finished.")
|
||||
await self._cleanup()
|
||||
return
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
# Expected when closing the connection.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error: {e}")
|
||||
await self.push_error(ErrorFrame(f"{self} error: {e}"))
|
||||
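The receive handler in the new `stt.py` buffers final tokens until an end token (`<end>` or `<fin>`) arrives, emits the buffered text as a full transcript, and reports everything still pending as interim text. The sketch below shows that buffering logic as a pure function, under the assumption that each token is a dict with `text` and `is_final` keys as the diff reads them. The name `split_tokens` and its return shape are hypothetical, and no frames or WebSocket I/O are involved.

```python
from typing import List, Tuple

# End markers taken from the END_TOKEN / FINALIZED_TOKEN constants in the diff.
END_TOKENS = {"<end>", "<fin>"}


def split_tokens(
    buffer: List[dict], tokens: List[dict]
) -> Tuple[List[str], List[dict], str]:
    """Fold one message's tokens into the running buffer.

    Returns (completed transcripts, updated final-token buffer, interim text).
    """
    transcripts: List[str] = []
    buffer = list(buffer)  # copy so the caller's list is not mutated
    non_final: List[dict] = []

    for token in tokens:
        if token["is_final"]:
            if token["text"] in END_TOKENS:
                # Endpoint reached: flush buffered final tokens as a transcript.
                if buffer:
                    transcripts.append("".join(t["text"] for t in buffer))
                    buffer = []
            else:
                buffer.append(token)
        else:
            non_final.append(token)

    # Final tokens still waiting for an endpoint count as interim text too.
    interim = "".join(t["text"] for t in buffer + non_final)
    return transcripts, buffer, interim
```

Final tokens that follow an endpoint in the same message stay in the buffer, so they show up in the interim text until the next endpoint flushes them.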
@@ -56,6 +56,7 @@ class STTService(AIService):
|
||||
self._init_sample_rate = sample_rate
|
||||
self._sample_rate = 0
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._tracing_enabled: bool = False
|
||||
self._muted: bool = False
|
||||
self._user_id: str = ""
|
||||
|
||||
@@ -116,6 +117,7 @@ class STTService(AIService):
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._sample_rate = self._init_sample_rate or frame.audio_in_sample_rate
|
||||
self._tracing_enabled = frame.enable_tracing
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
logger.info(f"Updating STT settings: {self._settings}")
|
||||
|
||||
@@ -116,6 +116,7 @@ class TTSService(AIService):
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
self._tracing_enabled: bool = False
|
||||
|
||||
if text_filter:
|
||||
import warnings
|
||||
@@ -224,6 +225,7 @@ class TTSService(AIService):
|
||||
self._sample_rate = self._init_sample_rate or frame.audio_out_sample_rate
|
||||
if self._push_stop_frames and not self._stop_frame_task:
|
||||
self._stop_frame_task = self.create_task(self._stop_frame_handler())
|
||||
self._tracing_enabled = frame.enable_tracing
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the TTS service.
|
||||
|
||||
@@ -43,7 +43,7 @@ class WebsocketService(ABC):
|
||||
True if connection is verified working, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if not self._websocket or self._websocket.closed:
|
||||
if not self._websocket or self._websocket.state is State.CLOSED:
|
||||
return False
|
||||
await self._websocket.ping()
|
||||
return True
|
||||
@@ -82,7 +82,7 @@ class WebsocketService(ABC):
|
||||
try:
|
||||
await self._receive_messages()
|
||||
retry_count = 0 # Reset counter on successful message receive
|
||||
if self._websocket and self._websocket.state == State.CLOSED:
|
||||
if self._websocket and self._websocket.state is State.CLOSED:
|
||||
raise websockets.ConnectionClosedOK(
|
||||
self._websocket.close_rcvd,
|
||||
self._websocket.close_sent,
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Awaitable, Callable, Optional
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
@@ -129,7 +130,7 @@ class WebsocketClientSession:
|
||||
return
|
||||
|
||||
try:
|
||||
self._websocket = await websockets.connect(uri=self._uri, open_timeout=10)
|
||||
self._websocket = await websocket_connect(uri=self._uri, open_timeout=10)
|
||||
self._client_task = self.task_manager.create_task(
|
||||
self._client_task_handler(),
|
||||
f"{self._transport_name}::WebsocketClientSession::_client_task_handler",
|
||||
|
||||
@@ -39,6 +39,8 @@ from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.asyncio.server import serve as websocket_serve
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use websockets, you need to `pip install pipecat-ai[websocket]`.")
|
||||
@@ -177,11 +179,11 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
async def _server_task_handler(self):
|
||||
"""Handle WebSocket server startup and client connections."""
|
||||
logger.info(f"Starting websocket server on {self._host}:{self._port}")
|
||||
async with websockets.serve(self._client_handler, self._host, self._port) as server:
|
||||
async with websocket_serve(self._client_handler, self._host, self._port) as server:
|
||||
await self._callbacks.on_websocket_ready()
|
||||
await self._stop_server_event.wait()
|
||||
|
||||
async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, path):
|
||||
async def _client_handler(self, websocket: websockets.WebSocketServerProtocol):
|
||||
"""Handle individual client connections and message processing."""
|
||||
logger.info(f"New client connection from {websocket.remote_address}")
|
||||
if self._websocket:
|
||||
@@ -231,7 +233,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
"""Monitor WebSocket connection for session timeout."""
|
||||
try:
|
||||
await asyncio.sleep(session_timeout)
|
||||
if not websocket.closed:
|
||||
if websocket.state is not State.CLOSED:
|
||||
await self._callbacks.on_session_timeout(websocket)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Monitoring task cancelled for: {websocket.remote_address}")
|
||||
|
||||
@@ -62,6 +62,9 @@ try:
         VirtualCameraDevice,
         VirtualSpeakerDevice,
     )
+    from daily import (
+        LogLevel as DailyLogLevel,
+    )
 except ModuleNotFoundError as e:
     logger.error(f"Exception: {e}")
     logger.error(
@@ -1924,6 +1927,18 @@ class DailyTransport(BaseTransport):
         """
         return self._client.participant_id

+    def set_log_level(self, level: DailyLogLevel):
+        """Set the logging level for Daily's internal logging system.
+
+        Args:
+            level: The log level to set. Should be a member of the DailyLogLevel enum,
+                such as DailyLogLevel.Info, DailyLogLevel.Debug, etc.
+
+        Example:
+            transport.set_log_level(DailyLogLevel.Info)
+        """
+        Daily.set_log_level(level)
+
     async def send_image(self, frame: OutputImageRawFrame | SpriteFrame):
         """Send an image frame to the Daily call.
@@ -439,6 +439,7 @@ class LiveKitTransportClient:
                 self._process_audio_stream(audio_stream, participant.sid),
                 f"{self}::_process_audio_stream",
             )
+            await self._callbacks.on_audio_track_subscribed(participant.sid)

     async def _async_on_track_unsubscribed(
         self,
@@ -9,29 +9,72 @@
 This module provides utilities for natural language text processing including
 sentence boundary detection, email and number pattern handling, and XML-style
 tag parsing for structured text content.
+
+Dependencies:
+    This module uses NLTK (Natural Language Toolkit) for robust sentence
+    tokenization. NLTK is licensed under the Apache License 2.0.
+    See: https://www.nltk.org/
+    Source: https://www.nltk.org/api/nltk.tokenize.punkt.html
 """

 import re
-from typing import Optional, Sequence, Tuple
+from typing import FrozenSet, Optional, Sequence, Tuple

-ENDOFSENTENCE_PATTERN_STR = r"""
-    (?<![A-Z]) # Negative lookbehind: not preceded by an uppercase letter (e.g., "U.S.A.")
-    (?<!\d\.\d) # Not preceded by a decimal number (e.g., "3.14159")
-    (?<!^\d\.) # Not preceded by a numbered list item (e.g., "1. Let's start")
-    (?<!\d\s[ap]) # Negative lookbehind: not preceded by time (e.g., "3:00 a.m.")
-    (?<!Mr|Ms|Dr) # Negative lookbehind: not preceded by Mr, Ms, Dr (combined bc. length is the same)
-    (?<!Mrs) # Negative lookbehind: not preceded by "Mrs"
-    (?<!Prof) # Negative lookbehind: not preceded by "Prof"
-    (\.\s*\.\s*\.|[\.\?\!;])| # Match a period, question mark, exclamation point, or semicolon
-    (\。\s*\。\s*\。|[。?!;।]) # the full-width version (mainly used in East Asian languages such as Chinese, Hindi)
-    $ # End of string
-"""
+import nltk
+from nltk.tokenize import sent_tokenize

-ENDOFSENTENCE_PATTERN = re.compile(ENDOFSENTENCE_PATTERN_STR, re.VERBOSE)
+# Ensure punkt_tab tokenizer data is available
+try:
+    nltk.data.find("tokenizers/punkt_tab")
+except LookupError:
+    nltk.download("punkt_tab", quiet=True)

-EMAIL_PATTERN = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
-
-NUMBER_PATTERN = re.compile(r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?")
+SENTENCE_ENDING_PUNCTUATION: FrozenSet[str] = frozenset(
+    {
+        # Latin script punctuation (most European languages, Filipino, etc.)
+        ".",
+        "!",
+        "?",
+        ";",
+        # East Asian punctuation (Chinese (Traditional & Simplified), Japanese, Korean)
+        "。",  # Ideographic full stop
+        "?",  # Full-width question mark
+        "!",  # Full-width exclamation mark
+        ";",  # Full-width semicolon
+        ".",  # Full-width period
+        "。",  # Halfwidth ideographic period
+        # Indic scripts punctuation (Hindi, Sanskrit, Marathi, Nepali, Bengali, Tamil, Telugu, Kannada, Malayalam, Gujarati, Punjabi, Oriya, Assamese)
+        "।",  # Devanagari danda (single vertical bar)
+        "॥",  # Devanagari double danda (double vertical bar)
+        # Arabic script punctuation (Arabic, Persian, Urdu, Pashto)
+        "؟",  # Arabic question mark
+        "؛",  # Arabic semicolon
+        "۔",  # Urdu full stop
+        "؏",  # Arabic sign misra (classical texts)
+        # Thai
+        "।",  # Thai uses Devanagari-style punctuation in some contexts
+        # Myanmar/Burmese
+        "၊",  # Myanmar sign little section
+        "။",  # Myanmar sign section
+        # Khmer
+        "។",  # Khmer sign khan
+        "៕",  # Khmer sign bariyoosan
+        # Lao
+        "໌",  # Lao cancellation mark (used as period)
+        "༎",  # Tibetan mark delimiter tsheg bstar (also used in Lao contexts)
+        # Tibetan
+        "།",  # Tibetan mark intersyllabic tsheg
+        "༎",  # Tibetan mark delimiter tsheg bstar
+        # Armenian
+        "։",  # Armenian full stop
+        "՜",  # Armenian exclamation mark
+        "՞",  # Armenian question mark
+        # Ethiopic script (Amharic)
+        "።",  # Ethiopic full stop
+        "፧",  # Ethiopic question mark
+        "፨",  # Ethiopic paragraph separator
+    }
+)

 StartEndTags = Tuple[str, str]
@@ -58,10 +101,9 @@ def replace_match(text: str, match: re.Match, old: str, new: str) -> str:
 def match_endofsentence(text: str) -> int:
     """Find the position of the end of a sentence in the provided text.

-    This function processes the input text by replacing periods in email
-    addresses and numbers with ampersands to prevent them from being
-    misidentified as sentence terminals. It then searches for the end of a
-    sentence using a specified regex pattern.
+    This function uses NLTK's sentence tokenizer to detect sentence boundaries
+    in the input text, combined with punctuation verification to ensure that
+    single tokens without proper sentence endings aren't considered complete sentences.

     Args:
         text: The input text in which to find the end of the sentence.
@@ -71,21 +113,33 @@ def match_endofsentence(text: str) -> int:
     """
     text = text.rstrip()

-    # Replace email dots by ampersands so we can find the end of sentence. For
-    # example, first.last@email.com becomes first&last@email&com.
-    emails = list(EMAIL_PATTERN.finditer(text))
-    for email_match in emails:
-        text = replace_match(text, email_match, ".", "&")
+    if not text:
+        return 0

-    # Replace number dots by ampersands so we can find the end of sentence.
-    numbers = list(NUMBER_PATTERN.finditer(text))
-    for number_match in numbers:
-        text = replace_match(text, number_match, ".", "&")
+    # Use NLTK's sentence tokenizer to find sentence boundaries
+    sentences = sent_tokenize(text)

-    # Match against the new text.
-    match = ENDOFSENTENCE_PATTERN.search(text)
+    if not sentences:
+        return 0

-    return match.end() if match else 0
+    first_sentence = sentences[0]
+
+    # If there's only one sentence that equals the entire text,
+    # verify it actually ends with sentence-ending punctuation.
+    # This is required as NLTK may return a single sentence for
+    # text that's a single word. In the case of LLM tokens, it's
+    # common for text to be single words, so we need to ensure
+    # sentence-ending punctuation is present.
+    if len(sentences) == 1 and first_sentence == text:
+        return len(text) if text and text[-1] in SENTENCE_ENDING_PUNCTUATION else 0
+
+    # If there are multiple sentences, the first one is complete by definition
+    # (NLTK found a boundary, so there must be proper punctuation)
+    if len(sentences) > 1:
+        return len(first_sentence)
+
+    # Single sentence that doesn't equal the full text means incomplete
+    return 0


 def parse_start_end_tags(
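The decision logic in the new `match_endofsentence` (tokenize, then require terminal punctuation when only one candidate sentence comes back) can be sketched without NLTK. The example below is an assumption-laden illustration, not the module's code: `naive_split` is a hypothetical stand-in for `sent_tokenize` that only splits after `.`, `!` or `?` followed by whitespace, `end_of_first_sentence` is a hypothetical name, and `SENTENCE_ENDINGS` is a small subset of the punctuation set in the diff.

```python
import re
from typing import Callable, List

# Small illustrative subset of the sentence-ending punctuation set.
SENTENCE_ENDINGS = frozenset({".", "!", "?", ";", "。", "।"})


def naive_split(text: str) -> List[str]:
    """Hypothetical stand-in for a sentence tokenizer (much weaker than NLTK)."""
    return [part for part in re.split(r"(?<=[.!?])\s+", text) if part]


def end_of_first_sentence(
    text: str, tokenize: Callable[[str], List[str]] = naive_split
) -> int:
    """Return the end index of the first complete sentence, or 0 if none."""
    text = text.rstrip()
    if not text:
        return 0
    sentences = tokenize(text)
    if not sentences:
        return 0
    first = sentences[0]
    if len(sentences) == 1:
        # A lone candidate only counts if it really ends with punctuation.
        return len(text) if first == text and text[-1] in SENTENCE_ENDINGS else 0
    # The tokenizer found a boundary, so the first sentence is complete.
    return len(first)


# Simulate streaming tokens: flush the buffer each time a sentence completes.
buffer = ""
spoken = []
for token in ["Hi", " there", ".", " How", " are", " you", "?"]:
    buffer += token
    end = end_of_first_sentence(buffer)
    if end:
        spoken.append(buffer[:end].strip())
        buffer = buffer[end:]
```

Passing the tokenizer as a parameter keeps the punctuation check independent of the tokenizer; swapping in a real sentence tokenizer would change where boundaries fall for abbreviations and numbers, which the naive splitter does not handle.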
@@ -134,7 +134,8 @@ def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -
             Yields:
                 The active span for the TTS operation.
             """
-            if not is_tracing_available():
+            # Check if tracing is enabled for this service instance
+            if not getattr(self, "_tracing_enabled", False):
                 yield None
                 return

@@ -178,7 +179,8 @@ def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -
             @functools.wraps(f)
             async def gen_wrapper(self, text, *args, **kwargs):
                 try:
-                    if not is_tracing_available():
+                    # Check if tracing is enabled for this service instance
+                    if not getattr(self, "_tracing_enabled", False):
                         async for item in f(self, text, *args, **kwargs):
                             yield item
                         return
@@ -198,7 +200,8 @@ def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -
             @functools.wraps(f)
             async def wrapper(self, text, *args, **kwargs):
                 try:
-                    if not is_tracing_available():
+                    # Check if tracing is enabled for this service instance
+                    if not getattr(self, "_tracing_enabled", False):
                         return await f(self, text, *args, **kwargs)

                     async with tracing_context(self, text):
@@ -239,7 +242,8 @@ def traced_stt(func: Optional[Callable] = None, *, name: Optional[str] = None) -
         @functools.wraps(f)
         async def wrapper(self, transcript, is_final, language=None):
             try:
-                if not is_tracing_available():
+                # Check if tracing is enabled for this service instance
+                if not getattr(self, "_tracing_enabled", False):
                     return await f(self, transcript, is_final, language)

                 service_class_name = self.__class__.__name__
@@ -320,7 +324,8 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
         @functools.wraps(f)
         async def wrapper(self, context, *args, **kwargs):
             try:
-                if not is_tracing_available():
+                # Check if tracing is enabled for this service instance
+                if not getattr(self, "_tracing_enabled", False):
                     return await f(self, context, *args, **kwargs)

                 service_class_name = self.__class__.__name__
@@ -522,7 +527,8 @@ def traced_gemini_live(operation: str) -> Callable:
         @functools.wraps(func)
         async def wrapper(self, *args, **kwargs):
             try:
-                if not is_tracing_available():
+                # Check if tracing is enabled for this service instance
+                if not getattr(self, "_tracing_enabled", False):
                     return await func(self, *args, **kwargs)

                 service_class_name = self.__class__.__name__
@@ -826,7 +832,8 @@ def traced_openai_realtime(operation: str) -> Callable:
         @functools.wraps(func)
         async def wrapper(self, *args, **kwargs):
             try:
-                if not is_tracing_available():
+                # Check if tracing is enabled for this service instance
+                if not getattr(self, "_tracing_enabled", False):
                     return await func(self, *args, **kwargs)

                 service_class_name = self.__class__.__name__
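The hunks above replace a global availability check with a per-instance opt-in read through `getattr(self, "_tracing_enabled", False)`. A minimal sketch of that pattern follows. Only the `getattr` check is taken from the diff. The decorator name `traced`, the `spans` list (a stand-in for opening a real span), and both service classes are hypothetical.

```python
import asyncio
import functools


def traced(func):
    """Hypothetical decorator: record a span only if the instance opted in."""

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        # Instances without the attribute are treated as "tracing disabled".
        if not getattr(self, "_tracing_enabled", False):
            return await func(self, *args, **kwargs)
        self.spans.append(func.__name__)  # stand-in for a real tracing span
        return await func(self, *args, **kwargs)

    return wrapper


class Service:
    def __init__(self, tracing_enabled: bool = False):
        self._tracing_enabled = tracing_enabled
        self.spans = []

    @traced
    async def run(self, text: str) -> str:
        return text.upper()


class LegacyService:
    """Never sets _tracing_enabled; the getattr default keeps it untraced."""

    @traced
    async def run(self, text: str) -> str:
        return text.upper()


enabled, disabled = Service(tracing_enabled=True), Service()
results = [
    asyncio.run(enabled.run("hi")),
    asyncio.run(disabled.run("hi")),
    asyncio.run(LegacyService().run("hi")),
]
```

Using `getattr` with a `False` default means a decorated method still works on objects that predate the flag, at the cost of silently skipping tracing when the attribute name is misspelled.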
@@ -16,10 +16,13 @@ class TestUtilsString(unittest.IsolatedAsyncioTestCase):
         assert match_endofsentence("This is a sentence?") == 19
         assert match_endofsentence("This is a sentence;") == 19
         assert match_endofsentence("This is a sentence...") == 21
-        assert match_endofsentence("This is a sentence . . .") == 24
-        assert match_endofsentence("This is a sentence. ..") == 22
+        assert match_endofsentence("This is a sentence. This is another one") == 19
         assert match_endofsentence("This is for Mr. and Mrs. Jones.") == 31
-        assert match_endofsentence("U.S.A and U.S.A..") == 17
+        assert match_endofsentence("Meet the new Mr. and Mrs.") == 25
+        assert match_endofsentence("U.S.A. and N.A.S.A.") == 19
+        assert match_endofsentence("USA and NASA.") == 13
+        assert match_endofsentence("My number is 123-456-7890.") == 26
+        assert match_endofsentence("For information, call 411.") == 26
         assert match_endofsentence("My emails are foo@pipecat.ai and bar@pipecat.ai.") == 48
         assert match_endofsentence("My email is foo.bar@pipecat.ai.") == 31
         assert match_endofsentence("My email is spell(foo.bar@pipecat.ai).") == 38
@@ -27,41 +30,162 @@ class TestUtilsString(unittest.IsolatedAsyncioTestCase):
         assert match_endofsentence("The number pi is 3.14159.") == 25
         assert match_endofsentence("Valid scientific notation 1.23e4.") == 33
         assert match_endofsentence("Valid scientific notation 0.e4.") == 31
+        assert match_endofsentence("It still early, it's 3:00 a.m.") == 30
         assert not match_endofsentence("This is not a sentence")
         assert not match_endofsentence("This is not a sentence,")
         assert not match_endofsentence("This is not a sentence, ")
         assert not match_endofsentence("Ok, Mr. Smith let's ")
         assert not match_endofsentence("Dr. Walker, I presume ")
         assert not match_endofsentence("Prof. Walker, I presume ")
-        assert not match_endofsentence("zweitens, und 3.")
-        assert not match_endofsentence("Heute ist Dienstag, der 3.")  # 3. Juli 2024
-        assert not match_endofsentence("America, or the U.")  # U.S.A.
-        assert not match_endofsentence("It still early, it's 3:00 a.")  # 3:00 a.m.
+        assert not match_endofsentence("zweitens, und 3")
+        assert not match_endofsentence("Heute ist Dienstag, der 3")  # 3. Juli 2024
+        assert not match_endofsentence("America, or the U.S")  # U.S.A.
         assert not match_endofsentence("My emails are foo@pipecat.ai and bar@pipecat.ai")
         assert not match_endofsentence("The number pi is 3.14159")

-    async def test_endofsentence_zh(self):
+    async def test_endofsentence_multilingual(self):
+        """Test sentence detection across various language families and scripts."""
+
+        # Arabic script (Arabic, Urdu, Persian)
+        arabic_sentences = [
+            "مرحبا؟",  # Arabic question mark
+            "السلام عليكم؛",  # Arabic semicolon
+            "یہ اردو ہے۔",  # Urdu full stop
+        ]
+        for sentence in arabic_sentences:
+            assert match_endofsentence(sentence), f"Failed for Arabic/Urdu: {sentence}"
+
+        # Should not match incomplete Arabic
+        assert not match_endofsentence("مرحبا،"), "Arabic comma should not end sentence"
+
         chinese_sentences = [
             "你好。",
             "你好!",
             "吃了吗?",
             "安全第一;",
         ]
-        for i in chinese_sentences:
-            assert match_endofsentence(i)
+        for sentence in chinese_sentences:
+            assert match_endofsentence(sentence), f"Failed for Chinese: {sentence}"
         assert not match_endofsentence("你好,")

-    async def test_endofsentence_hi(self):
         hindi_sentences = [
             "हैलो।",
             "हैलो!",
             "आप खाये हैं?",
             "सुरक्षा पहले।",
         ]
-        for i in hindi_sentences:
-            assert match_endofsentence(i)
+        for sentence in hindi_sentences:
+            assert match_endofsentence(sentence), f"Failed for Hindi: {sentence}"
         assert not match_endofsentence("हैलो,")

+        # East Asian (Japanese, Korean)
+        japanese_sentences = [
+            "こんにちは。",  # Japanese
+            "元気ですか?",  # Japanese question
+            "ありがとう!",  # Japanese exclamation
+        ]
+        for sentence in japanese_sentences:
+            assert match_endofsentence(sentence), f"Failed for Japanese: {sentence}"
+
+        korean_sentences = [
+            "안녕하세요。",  # Korean with ideographic period
+            "어떻게 지내세요?",  # Korean question
+        ]
+        for sentence in korean_sentences:
+            assert match_endofsentence(sentence), f"Failed for Korean: {sentence}"
+
+        # Southeast Asian scripts
+        thai_sentences = [
+            "สวัสดี।",  # Thai with Devanagari-style punctuation
+        ]
+        for sentence in thai_sentences:
+            assert match_endofsentence(sentence), f"Failed for Thai: {sentence}"
+
+        myanmar_sentences = [
+            "မင်္ဂလာပါ၊",  # Myanmar little section
+            "ကျေးဇူးတင်ပါတယ်။",  # Myanmar section
+        ]
+        for sentence in myanmar_sentences:
+            assert match_endofsentence(sentence), f"Failed for Myanmar: {sentence}"
+
+        # Other Indic scripts (same punctuation as Hindi but different scripts)
+        bengali_sentences = [
+            "নমস্কার।",  # Bengali
+            "আপনি কেমন আছেন?",  # Bengali question (uses Latin ?)
+        ]
+        for sentence in bengali_sentences:
+            assert match_endofsentence(sentence), f"Failed for Bengali: {sentence}"
+
+        tamil_sentences = [
+            "வணக்கம்।",  # Tamil
+            "நீங்கள் எப்படி இருக்கிறீர்கள்?",  # Tamil question
+        ]
+        for sentence in tamil_sentences:
+            assert match_endofsentence(sentence), f"Failed for Tamil: {sentence}"
+
+        # Armenian
+        armenian_sentences = [
+            "Բարև։",  # Armenian full stop
+            "Ինչպես եք՞",  # Armenian question mark
+            "Շնորհակալություն՜",  # Armenian exclamation
+        ]
+        for sentence in armenian_sentences:
+            assert match_endofsentence(sentence), f"Failed for Armenian: {sentence}"
+
+        # Ethiopic (Amharic)
+        amharic_sentences = [
+            "ሰላም።",  # Ethiopic full stop
+            "እንዴት ነዎት፧",  # Ethiopic question mark
+        ]
+        for sentence in amharic_sentences:
+            assert match_endofsentence(sentence), f"Failed for Amharic: {sentence}"
+
+        # Languages using Latin punctuation (should still work)
+        latin_script_sentences = [
+            "Hola.",  # Spanish
+            "Bonjour!",  # French
+            "Guten Tag?",  # German
+            "Привет.",  # Russian (Cyrillic but uses Latin punctuation)
+            "Γεια σας.",  # Greek
+            "שלום.",  # Hebrew
+            "გამარჯობა.",  # Georgian
+        ]
+        for sentence in latin_script_sentences:
+            assert match_endofsentence(sentence), f"Failed for Latin script: {sentence}"
+
+    async def test_endofsentence_streaming_tokens(self):
+        """Test the specific use case of streaming LLM tokens."""
+
+        # These are the scenarios that were problematic with the original regex
+        # Single tokens should not be considered complete sentences
+        assert not match_endofsentence("Hello"), "Single token should not be sentence"
+        assert not match_endofsentence("world"), "Single token should not be sentence"
+        assert not match_endofsentence("The"), "Single token should not be sentence"
+        assert not match_endofsentence("quick"), "Single token should not be sentence"
+
+        # But accumulating tokens should eventually form sentences
+        assert not match_endofsentence("Hello world"), "No punctuation = incomplete"
+        assert match_endofsentence("Hello world.") == 12, "With punctuation = complete"
+
+        # Test progressive building (simulating token streaming)
+        tokens = ["The", " quick", " brown", " fox", " jumps", "."]
+        accumulated = ""
+        for i, token in enumerate(tokens):
+            accumulated += token
+            if i < len(tokens) - 1:  # All but the last token
+                assert not match_endofsentence(accumulated), (
+                    f"Should be incomplete at token {i}: '{accumulated}'"
+                )
+            else:  # Last token adds the period
+                assert match_endofsentence(accumulated) == len(accumulated), (
+                    f"Should be complete: '{accumulated}'"
+                )
+
+        # Test with multiple sentences
+        assert match_endofsentence("First sentence. Second incomplete") == 15, (
+            "Should return end of first sentence"
+        )
+

 class TestStartEndTags(unittest.IsolatedAsyncioTestCase):
     async def test_empty(self):