Compare commits
59 Commits
filipi/asy
...
vp-fix/mcp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4b66aedc1 | ||
|
|
dc909e2713 | ||
|
|
e22f9f84bb | ||
|
|
57068f1b38 | ||
|
|
a29be38f48 | ||
|
|
976c644f90 | ||
|
|
34aa37f395 | ||
|
|
380867a87a | ||
|
|
cc3af59db4 | ||
|
|
f93d13efff | ||
|
|
c28b7e8f26 | ||
|
|
d1a2dee7a1 | ||
|
|
da1a1a59a4 | ||
|
|
134790b17c | ||
|
|
e5aa3bbc20 | ||
|
|
3be0ea05ef | ||
|
|
0c59819682 | ||
|
|
5b67dcd9e7 | ||
|
|
d503383c23 | ||
|
|
fa30268b84 | ||
|
|
2a118084bd | ||
|
|
87e8ed109a | ||
|
|
a5e1bbf4a3 | ||
|
|
f8267f1ea6 | ||
|
|
74acb0b7d0 | ||
|
|
41e3afbc2f | ||
|
|
d4824ffe8a | ||
|
|
2426f80789 | ||
|
|
5ce46df599 | ||
|
|
a6013ba437 | ||
|
|
279ca5a87b | ||
|
|
c6f79592d8 | ||
|
|
e74e497b8d | ||
|
|
d245b79bba | ||
|
|
8a794424dd | ||
|
|
f4743a6c91 | ||
|
|
ba32a48510 | ||
|
|
a9cafa2a3b | ||
|
|
58b1b7249e | ||
|
|
db8e73e5ca | ||
|
|
170f6dfe8b | ||
|
|
c763abc4ae | ||
|
|
197d96fc49 | ||
|
|
c8e9bf77fd | ||
|
|
48b25962e2 | ||
|
|
5d093c9ad7 | ||
|
|
d93f63deb5 | ||
|
|
09a57972f5 | ||
|
|
f83d062df9 | ||
|
|
a2a42b8703 | ||
|
|
e60a72e2d4 | ||
|
|
83f4989a78 | ||
|
|
5d2b288274 | ||
|
|
52ece87ac9 | ||
|
|
bc4bbb1895 | ||
|
|
eb014fffc4 | ||
|
|
e74930b954 | ||
|
|
6ed4109da9 | ||
|
|
53f809b7d5 |
2
.github/workflows/python-compatibility.yaml
vendored
2
.github/workflows/python-compatibility.yaml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ['3.10.19', '3.11.14', '3.12.12', '3.13.12']
|
||||
python-version: ['3.11.15', '3.12.13', '3.13.12', '3.14.3']
|
||||
|
||||
name: Python ${{ matrix.python-version }}
|
||||
steps:
|
||||
|
||||
@@ -149,8 +149,8 @@ You can get started with Pipecat running on your local machine, then move your a
|
||||
|
||||
### Prerequisites
|
||||
|
||||
**Minimum Python Version:** 3.10
|
||||
**Recommended Python Version:** 3.12
|
||||
**Minimum Python Version:** 3.11
|
||||
**Recommended Python Version:** >= 3.12
|
||||
|
||||
### Setup Steps
|
||||
|
||||
|
||||
1
changelog/3984.changed.md
Normal file
1
changelog/3984.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Updated `onnxruntime` from 1.23.2 to 1.24.3, adding support for Python 3.14.
|
||||
1
changelog/4034.changed.md
Normal file
1
changelog/4034.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- MCPClient now requires async with MCPClient(...) as mcp: or explicit start()/close() calls to manage the connection lifecycle.
|
||||
1
changelog/4034.fixed.md
Normal file
1
changelog/4034.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed MCPClient opening a new connection for every tool call instead of reusing the session.
|
||||
1
changelog/4219.added.md
Normal file
1
changelog/4219.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `enable_prompt_caching` setting to `AWSBedrockLLMService` for Bedrock ConverseStream prompt caching.
|
||||
1
changelog/4220.fixed.md
Normal file
1
changelog/4220.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed `CartesiaTTSService` failing with "Context has closed" errors when switching voice, model, or language via `TTSUpdateSettingsFrame`. The service now automatically flushes the current audio context and opens a fresh one when these settings change.
|
||||
13
changelog/4220.removed.md
Normal file
13
changelog/4220.removed.md
Normal file
@@ -0,0 +1,13 @@
|
||||
- ⚠️ Removed deprecated service parameters and shims that have been replaced by the `settings=Service.Settings(...)` pattern or direct `__init__` parameters:
|
||||
- `PollyTTSService` alias (use `AWSTTSService`)
|
||||
- `TTSService`: `text_aggregator`, `text_filter` init params
|
||||
- `AWSNovaSonicLLMService`: `send_transcription_frames` init param
|
||||
- `DeepgramSTTService`: `url` init param (use `base_url`)
|
||||
- `FishAudioTTSService`: `model` init param (use `reference_id` or `settings`)
|
||||
- `GladiaSTTService`: `language` and `confidence` from `GladiaInputParams`, `InputParams` class alias
|
||||
- `GeminiTTSService`: `api_key` init param
|
||||
- `GeminiLiveLLMService`: `base_url` init param (use `http_options`)
|
||||
- `GoogleVertexLLMService`: `InputParams` class with `location`/`project_id` fields (use direct init params); `project_id` is now required, `location` defaults to `"us-east4"`
|
||||
- `MiniMaxHttpTTSService`: `english_normalization` from `InputParams` (use `text_normalization`)
|
||||
- `SimliVideoService`: `simli_config` init param (use `api_key`/`face_id`), `use_turn_server` init param; `api_key` and `face_id` are now required
|
||||
- `AnthropicLLMService`: `enable_prompt_caching_beta` from `InputParams` (use `enable_prompt_caching`)
|
||||
1
changelog/4224.changed.md
Normal file
1
changelog/4224.changed.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ `LLMService.function_call_timeout_secs` now defaults to `None` instead of `10.0`. Deferred function calls will run indefinitely unless a timeout is explicitly set at the service level or per-call. If you relied on the previous 10-second default, pass `function_call_timeout_secs=10.0` explicitly.
|
||||
1
changelog/4225.removed.2.md
Normal file
1
changelog/4225.removed.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `pipecat.sync` package. Use `pipecat.utils.sync` instead.
|
||||
1
changelog/4225.removed.md
Normal file
1
changelog/4225.removed.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `pipecat.transports.services` and `pipecat.transports.network` module aliases. Update imports to use `pipecat.transports.daily.transport`, `pipecat.transports.livekit.transport`, `pipecat.transports.websocket.*`, `pipecat.transports.webrtc.*`, and `pipecat.transports.daily.utils` respectively.
|
||||
1
changelog/4228.removed.10.md
Normal file
1
changelog/4228.removed.10.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `add_pattern_pair` method from `PatternPairAggregator`. Use `add_pattern` instead.
|
||||
1
changelog/4228.removed.2.md
Normal file
1
changelog/4228.removed.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `interruption_strategies` parameter from `PipelineParams`, `StartFrame`, and `FrameProcessor`. Use `LLMUserAggregator`'s `user_turn_strategies` parameter instead.
|
||||
1
changelog/4228.removed.3.md
Normal file
1
changelog/4228.removed.3.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `EmulateUserStartedSpeakingFrame` and `EmulateUserStoppedSpeakingFrame` frames, and the `emulated` field from `UserStartedSpeakingFrame` / `UserStoppedSpeakingFrame`.
|
||||
1
changelog/4228.removed.4.md
Normal file
1
changelog/4228.removed.4.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `pipecat.audio.interruptions` module (`BaseInterruptionStrategy`, `MinWordsInterruptionStrategy`). Use `pipecat.turns.user_start.MinWordsUserTurnStartStrategy` with `LLMUserAggregator`'s `user_turn_strategies` parameter instead.
|
||||
1
changelog/4228.removed.5.md
Normal file
1
changelog/4228.removed.5.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `pipecat.processors.transcript_processor` module (`TranscriptProcessor`, `TranscriptProcessorConfig`). Use pipeline observers instead.
|
||||
1
changelog/4228.removed.6.md
Normal file
1
changelog/4228.removed.6.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `TranscriptionMessage`, `ThoughtTranscriptionMessage`, and `TranscriptionUpdateFrame` from `pipecat.frames.frames`.
|
||||
1
changelog/4228.removed.7.md
Normal file
1
changelog/4228.removed.7.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `STTMuteFilter`, `STTMuteConfig`, and `STTMuteStrategy` from `pipecat.processors.filters.stt_mute_filter`. Use `pipecat.turns.user_mute` strategies with `LLMUserAggregator`'s `user_mute_strategies` parameter instead.
|
||||
1
changelog/4228.removed.8.md
Normal file
1
changelog/4228.removed.8.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `UserResponseAggregator` class from `pipecat.processors.aggregators.user_response`. Use `LLMUserAggregator` instead.
|
||||
1
changelog/4228.removed.9.md
Normal file
1
changelog/4228.removed.9.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `pipecat.utils.tracing.class_decorators` module. Use `pipecat.utils.tracing.service_decorators` instead.
|
||||
1
changelog/4228.removed.md
Normal file
1
changelog/4228.removed.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed deprecated `allow_interruptions` parameter from `PipelineParams`, `StartFrame`, and `FrameProcessor`. Interruptions are now always allowed by default. Use `LLMUserAggregator`'s `user_turn_strategies` / `user_mute_strategies` parameters to control interruption behavior.
|
||||
1
changelog/4229.removed.2.md
Normal file
1
changelog/4229.removed.2.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed `ExternalUserTurnStrategies` and the automatic fallback to it in `LLMUserAggregator` when a `SpeechControlParamsFrame` was received from the transport.
|
||||
1
changelog/4229.removed.md
Normal file
1
changelog/4229.removed.md
Normal file
@@ -0,0 +1 @@
|
||||
- ⚠️ Removed `vad_analyzer` and `turn_analyzer` parameters from `TransportParams` and all transport input classes, along with all deprecated VAD/turn analysis logic in `BaseInputTransport`. VAD and turn detection are now handled entirely by `LLMUserAggregator`.
|
||||
@@ -45,7 +45,7 @@ from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSUpdateSettingsFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -54,6 +54,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_text_processor import LLMTextProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
@@ -100,39 +101,43 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# Create pattern pair aggregator for voice switching
|
||||
pattern_aggregator = PatternPairAggregator()
|
||||
llm_text_aggregator = PatternPairAggregator()
|
||||
|
||||
# Add pattern for voice switching
|
||||
pattern_aggregator.add_pattern(
|
||||
llm_text_aggregator.add_pattern(
|
||||
type="voice",
|
||||
start_pattern="<voice>",
|
||||
end_pattern="</voice>",
|
||||
action=MatchAction.REMOVE, # Remove tags from final text
|
||||
action=MatchAction.AGGREGATE,
|
||||
)
|
||||
|
||||
# Register handler for voice switching
|
||||
async def on_voice_tag(match: PatternMatch):
|
||||
voice_name = match.text.strip().lower()
|
||||
if voice_name in VOICE_IDS:
|
||||
# First flush any existing audio to finish the current context
|
||||
await tts.flush_audio()
|
||||
# Then set the new voice
|
||||
await tts.set_voice(VOICE_IDS[voice_name])
|
||||
await llm_text_processor.push_frame(
|
||||
TTSUpdateSettingsFrame(
|
||||
delta=CartesiaTTSService.Settings(voice=VOICE_IDS[voice_name])
|
||||
)
|
||||
)
|
||||
logger.info(f"Switched to {voice_name} voice")
|
||||
else:
|
||||
logger.warning(f"Unknown voice: {voice_name}")
|
||||
|
||||
pattern_aggregator.on_pattern_match("voice", on_voice_tag)
|
||||
llm_text_aggregator.on_pattern_match("voice", on_voice_tag)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
# Process LLM text through the pattern aggregator before TTS
|
||||
llm_text_processor = LLMTextProcessor(text_aggregator=llm_text_aggregator)
|
||||
|
||||
# Initialize TTS with narrator voice as default
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
settings=CartesiaTTSService.Settings(
|
||||
voice=VOICE_IDS["narrator"],
|
||||
),
|
||||
text_aggregator=pattern_aggregator,
|
||||
skip_aggregator_types=["voice"], # Skip voice tags in TTS speech
|
||||
)
|
||||
|
||||
# System prompt for storytelling with voice switching
|
||||
@@ -204,7 +209,8 @@ Remember: Use narrator voice for EVERYTHING except the actual quoted dialogue.""
|
||||
stt,
|
||||
user_aggregator,
|
||||
llm,
|
||||
tts, # TTS with pattern aggregator
|
||||
llm_text_processor,
|
||||
tts,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
|
||||
@@ -5,27 +5,17 @@
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.client.session_group import StreamableHttpParameters
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
LLMRunFrame,
|
||||
URLImageRawFrame,
|
||||
)
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -34,7 +24,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
@@ -47,66 +36,16 @@ from pipecat.transports.daily.transport import DailyParams
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class UrlToImageProcessor(FrameProcessor):
|
||||
def __init__(self, aiohttp_session: aiohttp.ClientSession, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, FunctionCallResultFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
image_url = self.extract_url(frame.result)
|
||||
if image_url:
|
||||
await self.run_image_process(image_url)
|
||||
# sometimes we get multiple image urls- process 1 at a time
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
except (json.JSONDecodeError, KeyError, TypeError):
|
||||
pass
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
try:
|
||||
logger.debug(f"handling image from url: '{image_url}'")
|
||||
async with self._aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
image = image.convert("RGB")
|
||||
frame = URLImageRawFrame(
|
||||
url=image_url, image=image.tobytes(), size=image.size, format="RGB"
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
error_msg = f"Error handling image url {image_url}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_out_enabled=True,
|
||||
video_out_width=1024,
|
||||
video_out_height=1024,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_out_enabled=True,
|
||||
video_out_width=1024,
|
||||
video_out_height=1024,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -114,85 +53,72 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# Create an HTTP session for API calls
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
settings=CartesiaTTSService.Settings(
|
||||
voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
settings=CartesiaTTSService.Settings(
|
||||
voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
),
|
||||
)
|
||||
|
||||
system_prompt = f"""
|
||||
You are a helpful LLM in a voice call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to memory tools that let you store and recall information,
|
||||
and tools to answer questions about the user's GitHub repositories and account.
|
||||
Offer to remember things for the user, like their name, preferences, or anything they'd like.
|
||||
You can also recall things you've previously stored.
|
||||
You can also offer to answer users questions about their GitHub repositories and account.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
settings=AnthropicLLMService.Settings(
|
||||
system_instruction=system_prompt,
|
||||
),
|
||||
)
|
||||
|
||||
async with (
|
||||
# https://github.com/modelcontextprotocol/servers/tree/main/src/memory
|
||||
MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
args=["-y", "@modelcontextprotocol/server-memory"],
|
||||
# env={"MEMORY_FILE_PATH": "/tmp/pipecat_memory.jsonl"}, # Optional: specify MEMORY_FILE_PATH
|
||||
),
|
||||
)
|
||||
|
||||
system_prompt = f"""
|
||||
You are a helpful LLM in a voice call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to tools to search the Rijksmuseum collection and the user's GitHub repositories and account.
|
||||
Offer, for example, to show a floral still life, use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
You can also offer to answer users questions about their GitHub repositories and account.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
settings=AnthropicLLMService.Settings(
|
||||
system_instruction=system_prompt,
|
||||
) as memory_mcp,
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={
|
||||
"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"
|
||||
},
|
||||
),
|
||||
)
|
||||
) as github_mcp,
|
||||
):
|
||||
memory_tools = await memory_mcp.register_tools(llm)
|
||||
github_tools = await github_mcp.register_tools(llm)
|
||||
|
||||
try:
|
||||
rijksmuseum_mcp = MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up rijksmuseum mcp")
|
||||
logger.exception("error trace:")
|
||||
try:
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
github_mcp = MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={
|
||||
"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp.run")
|
||||
logger.exception("error trace:")
|
||||
|
||||
rijksmuseum_tools = {}
|
||||
github_tools = {}
|
||||
try:
|
||||
rijksmuseum_tools = await rijksmuseum_mcp.register_tools(llm)
|
||||
github_tools = await github_mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
all_standard_tools = rijksmuseum_tools.standard_tools + github_tools.standard_tools
|
||||
all_standard_tools = memory_tools.standard_tools + github_tools.standard_tools
|
||||
all_tools = ToolsSchema(standard_tools=all_standard_tools)
|
||||
|
||||
context = LLMContext(tools=all_tools)
|
||||
context = LLMContext(
|
||||
messages=[{"role": "user", "content": "Please introduce yourself."}],
|
||||
tools=all_tools,
|
||||
)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
mcp_image_processor = UrlToImageProcessor(aiohttp_session=session)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -201,7 +127,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
mcp_image_processor, # URL image -> output
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
@@ -239,9 +164,9 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY") or not os.getenv("GITHUB_PERSONAL_ACCESS_TOKEN"):
|
||||
if not os.getenv("GITHUB_PERSONAL_ACCESS_TOKEN"):
|
||||
logger.error(
|
||||
f"Please set `RIJKSMUSEUM_API_KEY` and `GITHUB_PERSONAL_ACCESS_TOKEN` environment variables. See https://github.com/r-huijts/rijksmuseum-mcp."
|
||||
f"Please set `GITHUB_PERSONAL_ACCESS_TOKEN` environment variable."
|
||||
)
|
||||
import sys
|
||||
|
||||
|
||||
@@ -4,26 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from mcp import StdioServerParameters
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
LLMRunFrame,
|
||||
URLImageRawFrame,
|
||||
)
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -32,7 +21,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.anthropic.llm import AnthropicLLMService
|
||||
@@ -44,86 +32,16 @@ from pipecat.transports.daily.transport import DailyParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class UrlToImageProcessor(FrameProcessor):
|
||||
def __init__(self, aiohttp_session: aiohttp.ClientSession, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, FunctionCallResultFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
image_url = self.extract_url(frame.result)
|
||||
if image_url:
|
||||
await self.run_image_process(image_url)
|
||||
# sometimes we get multiple image urls- process 1 at a time
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
def extract_url(self, text: str):
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if "artObject" in data:
|
||||
return data["artObject"]["webImage"]["url"]
|
||||
if "artworks" in data and len(data["artworks"]):
|
||||
return data["artworks"][0]["webImage"]["url"]
|
||||
except (json.JSONDecodeError, KeyError, TypeError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def run_image_process(self, image_url: str):
|
||||
try:
|
||||
logger.debug(f"handling image from url: '{image_url}'")
|
||||
async with self._aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
image = image.convert("RGB")
|
||||
frame = URLImageRawFrame(
|
||||
url=image_url, image=image.tobytes(), size=image.size, format="RGB"
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
error_msg = f"Error handling image url {image_url}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
|
||||
# full list of tools available from rijksmuseum MCP:
|
||||
# - get_artwork_details
|
||||
# - get_artwork_image
|
||||
# - get_user_sets
|
||||
# - get_user_set_details
|
||||
# - open_image_in_browser
|
||||
# - get_artist_timeline
|
||||
|
||||
mcp_tools_filter = ["get_artwork_details", "get_artwork_image", "open_image_in_browser"]
|
||||
|
||||
|
||||
def open_image_output_filter(output: str):
|
||||
pattern = r"Successfully opened image in browser: "
|
||||
text_to_print = re.sub(pattern, "", output)
|
||||
print(f"🖼️ link to high resolution artwork: {text_to_print}")
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_out_enabled=True,
|
||||
video_out_width=1024,
|
||||
video_out_height=1024,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
video_out_enabled=True,
|
||||
video_out_width=1024,
|
||||
video_out_height=1024,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -131,63 +49,48 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
# Create an HTTP session for API calls
|
||||
async with aiohttp.ClientSession() as session:
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
settings=CartesiaTTSService.Settings(
|
||||
voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
),
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
settings=CartesiaTTSService.Settings(
|
||||
voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
),
|
||||
)
|
||||
|
||||
system_prompt = f"""
|
||||
You are a helpful LLM in a voice call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to memory tools that let you store and recall information.
|
||||
Offer to remember things for the user, like their name, preferences, or anything they'd like.
|
||||
You can also recall things you've previously stored.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
settings=AnthropicLLMService.Settings(
|
||||
system_instruction=system_prompt,
|
||||
),
|
||||
)
|
||||
|
||||
# https://github.com/modelcontextprotocol/servers/tree/main/src/memory
|
||||
async with MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
args=["-y", "@modelcontextprotocol/server-memory"],
|
||||
# env={"MEMORY_FILE_PATH": "/tmp/pipecat_memory.jsonl"}, # Optional: specify MEMORY_FILE_PATH
|
||||
),
|
||||
) as mcp:
|
||||
tools = await mcp.register_tools(llm)
|
||||
|
||||
context = LLMContext(
|
||||
messages=[{"role": "user", "content": "Please introduce yourself."}],
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
system_prompt = f"""
|
||||
You are a helpful LLM in a voice call.
|
||||
Your goal is to demonstrate your capabilities in a succinct way.
|
||||
You have access to tools to search the Rijksmuseum collection.
|
||||
Offer, for example, to show a floral still life, use the `search_artwork` tool.
|
||||
The tool may respond with a JSON object with an `artworks` array. Choose the art from that array.
|
||||
Once the tool has responded, tell the user the title and use the `open_image_in_browser` tool.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Respond to what the user said in a creative and helpful way.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
llm = AnthropicLLMService(
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
settings=AnthropicLLMService.Settings(
|
||||
system_instruction=system_prompt,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
mcp = MCPClient(
|
||||
server_params=StdioServerParameters(
|
||||
command=shutil.which("npx"),
|
||||
# https://github.com/r-huijts/rijksmuseum-mcp
|
||||
args=["-y", "mcp-server-rijksmuseum"],
|
||||
env={"RIJKSMUSEUM_API_KEY": os.getenv("RIJKSMUSEUM_API_KEY")},
|
||||
),
|
||||
# Optional
|
||||
tools_filter=mcp_tools_filter, # Optional
|
||||
tools_output_filters={"open_image_in_browser": open_image_output_filter},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
mcp_image = UrlToImageProcessor(aiohttp_session=session)
|
||||
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
@@ -200,7 +103,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
mcp_image, # URL image -> output
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
@@ -238,13 +140,6 @@ async def bot(runner_args: RunnerArguments):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.getenv("RIJKSMUSEUM_API_KEY"):
|
||||
logger.error(
|
||||
f"Please set RIJKSMUSEUM_API_KEY environment variable for this example. See https://github.com/r-huijts/rijksmuseum-mcp and https://www.rijksmuseum.nl/en/register?redirectUrl=https://www.https://www.rijksmuseum.nl/en/rijksstudio/my/profile"
|
||||
)
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
|
||||
@@ -63,28 +63,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
mcp = MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.get_tools_schema()
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
system = f"""
|
||||
You are a helpful LLM in a voice call.
|
||||
Your goal is to answer questions about the user's GitHub repositories and account.
|
||||
@@ -94,53 +72,65 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system,
|
||||
tools=tools,
|
||||
)
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
async with MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
|
||||
)
|
||||
) as mcp:
|
||||
tools = await mcp.get_tools_schema()
|
||||
|
||||
await mcp.register_tools_schema(tools, llm)
|
||||
llm = GeminiLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
context = LLMContext([{"role": "developer", "content": "Please introduce yourself."}])
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
await mcp.register_tools_schema(tools, llm)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
context = LLMContext([{"role": "user", "content": "Please introduce yourself."}])
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
await runner.run(task)
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
|
||||
@@ -63,83 +63,78 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
),
|
||||
)
|
||||
|
||||
system_prompt = f"""
|
||||
You are a helpful LLM in a voice call.
|
||||
Your goal is to answer questions about the user's GitHub repositories and account.
|
||||
You have access to a number of tools provided by Github. Use any and all tools to help users.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
system_prompt = """\
|
||||
You are a helpful LLM in a voice call.
|
||||
Your goal is to answer questions about the user's GitHub repositories and account.
|
||||
You have access to a number of tools provided by Github. Use any and all tools to help users.
|
||||
Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points.
|
||||
Don't overexplain what you are doing.
|
||||
Just respond with short sentences when you are carrying out tool calls.
|
||||
"""
|
||||
|
||||
llm = GoogleLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system_prompt,
|
||||
)
|
||||
|
||||
try:
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
mcp = MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"error setting up mcp")
|
||||
logger.exception("error trace:")
|
||||
|
||||
tools = {}
|
||||
try:
|
||||
tools = await mcp.register_tools(llm)
|
||||
except Exception as e:
|
||||
logger.error(f"error registering tools")
|
||||
logger.exception("error trace:")
|
||||
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
settings=GoogleLLMService.Settings(
|
||||
system_instruction=system_prompt,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
# Github MCP docs: https://github.com/github/github-mcp-server
|
||||
# Enable Github Copilot on your GitHub account. Free tier is ok. (https://github.com/settings/copilot)
|
||||
# Generate a personal access token. It must be a Fine-grained token, classic tokens are not supported. (https://github.com/settings/personal-access-tokens)
|
||||
# Set permissions you want to use (eg. "all repositories", "profile: read/write", etc)
|
||||
async with MCPClient(
|
||||
server_params=StreamableHttpParameters(
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
headers={"Authorization": f"Bearer {os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')}"},
|
||||
)
|
||||
) as mcp:
|
||||
tools = await mcp.register_tools(llm)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
context = LLMContext(
|
||||
messages=[{"role": "user", "content": "Please introduce yourself."}],
|
||||
tools=tools,
|
||||
)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User spoken responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
await runner.run(task)
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected: {client}")
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
|
||||
@@ -96,7 +96,6 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
allow_interruptions=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ description = "An open source framework for voice (and multimodal) assistants"
|
||||
license = "BSD-2-Clause"
|
||||
license-files = ["LICENSE"]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.11"
|
||||
keywords = ["webrtc", "audio", "video", "ai"]
|
||||
classifiers = [
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
@@ -41,7 +41,7 @@ dependencies = [
|
||||
# Required by LocalSmartTurnAnalyzerV3
|
||||
# Inlined here instead of using a self-referential extra for Poetry compatibility.
|
||||
"transformers>=4.48.0,<6",
|
||||
"onnxruntime~=1.23.2",
|
||||
"onnxruntime~=1.24.3",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Base interruption strategy for determining when users can interrupt bot speech."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseInterruptionStrategy(ABC):
|
||||
"""Base class for interruption strategies.
|
||||
|
||||
This is a base class for interruption strategies. Interruption strategies
|
||||
decide when the user can interrupt the bot while the bot is speaking. For
|
||||
example, there could be strategies based on audio volume or strategies based
|
||||
on the number of words the user spoke.
|
||||
"""
|
||||
|
||||
async def append_audio(self, audio: bytes, sample_rate: int):
|
||||
"""Append audio data to the strategy for analysis.
|
||||
|
||||
Not all strategies handle audio. Default implementation does nothing.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to append.
|
||||
sample_rate: Sample rate of the audio data in Hz.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def append_text(self, text: str):
|
||||
"""Append text data to the strategy for analysis.
|
||||
|
||||
Not all strategies handle text. Default implementation does nothing.
|
||||
|
||||
Args:
|
||||
text: Text string to append for analysis.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def should_interrupt(self) -> bool:
|
||||
"""Determine if the user should interrupt the bot.
|
||||
|
||||
This is called when the user stops speaking and it's time to decide
|
||||
whether the user should interrupt the bot. The decision will be based on
|
||||
the aggregated audio and/or text.
|
||||
|
||||
Returns:
|
||||
True if the user should interrupt the bot, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self):
|
||||
"""Reset the current accumulated text and/or audio."""
|
||||
pass
|
||||
@@ -1,75 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Minimum words interruption strategy for word count-based interruptions."""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
|
||||
|
||||
class MinWordsInterruptionStrategy(BaseInterruptionStrategy):
|
||||
"""Interruption strategy based on minimum number of words spoken.
|
||||
|
||||
This is an interruption strategy based on a minimum number of words said
|
||||
by the user. That is, the strategy will be true if the user has said at
|
||||
least that amount of words.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
|
||||
This class is deprecated, use
|
||||
`pipecat.turns.user_start.MinWordsUserTurnStartStrategy` with `PipelineTask`'s
|
||||
new `user_turn_strategies` parameter instead.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *, min_words: int):
|
||||
"""Initialize the minimum words interruption strategy.
|
||||
|
||||
Args:
|
||||
min_words: Minimum number of words required to trigger an interruption.
|
||||
"""
|
||||
super().__init__()
|
||||
self._min_words = min_words
|
||||
self._text = ""
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"'pipecat.audio.interruptions' is deprecated. "
|
||||
"Use `pipecat.turns.user_start.MinWordsUserTurnStartStrategy` with `PipelineTask`'s "
|
||||
"new `user_turn_strategies` parameter instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
async def append_text(self, text: str):
|
||||
"""Append text for word count analysis.
|
||||
|
||||
Args:
|
||||
text: Text string to append to the accumulated text.
|
||||
|
||||
Note: Not all strategies need to handle text.
|
||||
"""
|
||||
self._text += text
|
||||
|
||||
async def should_interrupt(self) -> bool:
|
||||
"""Check if the minimum word count has been reached.
|
||||
|
||||
Returns:
|
||||
True if the user has spoken at least the minimum number of words.
|
||||
"""
|
||||
word_count = len(self._text.split())
|
||||
interrupt = word_count >= self._min_words
|
||||
logger.debug(
|
||||
f"should_interrupt={interrupt} num_spoken_words={word_count} min_words={self._min_words}"
|
||||
)
|
||||
return interrupt
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the accumulated text for the next analysis cycle."""
|
||||
self._text = ""
|
||||
@@ -29,7 +29,6 @@ from typing import (
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
@@ -462,137 +461,6 @@ class LLMContextAssistantTimestampFrame(DataFrame):
|
||||
timestamp: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionMessage:
|
||||
"""A message in a conversation transcript.
|
||||
|
||||
A message in a conversation transcript containing the role and content.
|
||||
Messages are in standard format with roles normalized to user/assistant.
|
||||
|
||||
Parameters:
|
||||
role: The role of the message sender (user or assistant).
|
||||
content: The message content/text.
|
||||
user_id: Optional identifier for the user.
|
||||
timestamp: Optional timestamp when the message was created.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`TranscriptionMessage` is deprecated and will be removed in a future version.
|
||||
Use `LLMUserAggregator`'s and `LLMAssistantAggregator`'s new events instead.
|
||||
"""
|
||||
|
||||
role: Literal["user", "assistant"]
|
||||
content: str
|
||||
user_id: Optional[str] = None
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"TranscriptionMessage is deprecated and will be removed in a future version. "
|
||||
"Use `LLMUserAggregator`'s and `LLMAssistantAggregator`'s new events instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThoughtTranscriptionMessage:
|
||||
"""An LLM thought message in a conversation transcript.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`ThoughtTranscriptionMessage` is deprecated and will be removed in a future version.
|
||||
Use `LLMAssistantAggregator`'s new events instead.
|
||||
"""
|
||||
|
||||
role: Literal["assistant"] = field(default="assistant", init=False)
|
||||
content: str
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"ThoughtTranscriptionMessage is deprecated and will be removed in a future version. "
|
||||
"Use `LLMAssistantAggregator`'s new events instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionUpdateFrame(DataFrame):
|
||||
"""Frame containing new messages added to conversation transcript.
|
||||
|
||||
A frame containing new messages added to the conversation transcript.
|
||||
This frame is emitted when new messages are added to the conversation history,
|
||||
containing only the newly added messages rather than the full transcript.
|
||||
Messages have normalized roles (user/assistant) regardless of the LLM service used.
|
||||
Messages are always in the OpenAI standard message format, which supports both:
|
||||
|
||||
Examples:
|
||||
Simple format::
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi, how are you?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Great! And you?"
|
||||
}
|
||||
]
|
||||
|
||||
Content list format::
|
||||
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hi, how are you?"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Great! And you?"}]
|
||||
}
|
||||
]
|
||||
|
||||
OpenAI supports both formats. Anthropic and Google messages are converted to the
|
||||
content list format.
|
||||
|
||||
Parameters:
|
||||
messages: List of new transcript messages that were added.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`TranscriptionUpdateFrame` is deprecated and will be removed in a future version.
|
||||
Use `LLMUserAggregator`'s and `LLMAssistantAggregator`'s new events instead.
|
||||
"""
|
||||
|
||||
messages: List[TranscriptionMessage | ThoughtTranscriptionMessage]
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"TranscriptionUpdateFrame is deprecated and will be removed in a future version. "
|
||||
"Use `LLMUserAggregator`'s and `LLMAssistantAggregator`'s new events instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
pts = format_pts(self.pts)
|
||||
return f"{self.name}(pts: {pts}, messages: {len(self.messages)})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMContextFrame(Frame):
|
||||
"""Frame containing a universal LLM context.
|
||||
@@ -878,30 +746,18 @@ class StartFrame(SystemFrame):
|
||||
Parameters:
|
||||
audio_in_sample_rate: Input audio sample rate in Hz.
|
||||
audio_out_sample_rate: Output audio sample rate in Hz.
|
||||
allow_interruptions: Whether to allow user interruptions.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.
|
||||
|
||||
enable_metrics: Whether to enable performance metrics collection.
|
||||
enable_tracing: Whether to enable OpenTelemetry tracing.
|
||||
enable_usage_metrics: Whether to enable usage metrics collection.
|
||||
interruption_strategies: List of interruption handling strategies.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
Use `LLMUserAggregator`'s new `user_turn_strategies` parameter instead.
|
||||
|
||||
report_only_initial_ttfb: Whether to report only initial time-to-first-byte.
|
||||
tracing_context: Pipeline-scoped tracing context for span hierarchy.
|
||||
"""
|
||||
|
||||
audio_in_sample_rate: int = 16000
|
||||
audio_out_sample_rate: int = 24000
|
||||
allow_interruptions: bool = False
|
||||
enable_metrics: bool = False
|
||||
enable_tracing: bool = False
|
||||
enable_usage_metrics: bool = False
|
||||
interruption_strategies: List[BaseInterruptionStrategy] = field(default_factory=list)
|
||||
report_only_initial_ttfb: bool = False
|
||||
tracing_context: Optional["TracingContext"] = None
|
||||
|
||||
@@ -1010,16 +866,9 @@ class UserStartedSpeakingFrame(SystemFrame):
|
||||
|
||||
Emitted when the user turn starts, which usually means that some
|
||||
transcriptions are already available.
|
||||
|
||||
Parameters:
|
||||
emulated: Whether this event was emulated rather than detected by VAD.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
This field is deprecated and will be removed in a future version.
|
||||
|
||||
"""
|
||||
|
||||
emulated: bool = False
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1028,16 +877,9 @@ class UserStoppedSpeakingFrame(SystemFrame):
|
||||
|
||||
Emitted when the user turn ends. This usually coincides with the start of
|
||||
the bot turn.
|
||||
|
||||
Parameters:
|
||||
emulated: Whether this event was emulated rather than detected by VAD.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
This field is deprecated and will be removed in a future version.
|
||||
|
||||
"""
|
||||
|
||||
emulated: bool = False
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1072,56 +914,6 @@ class UserSpeakingFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmulateUserStartedSpeakingFrame(SystemFrame):
|
||||
"""Frame to emulate user started speaking behavior.
|
||||
|
||||
Emitted by internal processors upstream to emulate VAD behavior when a
|
||||
user starts speaking.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
This frame is deprecated and will be removed in a future version.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"EmulateUserStartedSpeakingFrame is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmulateUserStoppedSpeakingFrame(SystemFrame):
|
||||
"""Frame to emulate user stopped speaking behavior.
|
||||
|
||||
Emitted by internal processors upstream to emulate VAD behavior when a
|
||||
user stops speaking.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
This frame is deprecated and will be removed in a future version.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"EmulateUserStoppedSpeakingFrame is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VADUserStartedSpeakingFrame(SystemFrame):
|
||||
"""Frame emitted when VAD definitively detects user started speaking.
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Set, Tupl
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.frames.frames import (
|
||||
@@ -111,11 +110,6 @@ class PipelineParams(BaseModel):
|
||||
constructor arguments instead.
|
||||
|
||||
Parameters:
|
||||
allow_interruptions: Whether to allow pipeline interruptions.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
Use `LLMUserAggregator`'s new `user_turn_strategies` parameter instead.
|
||||
|
||||
audio_in_sample_rate: Input audio sample rate in Hz.
|
||||
audio_out_sample_rate: Output audio sample rate in Hz.
|
||||
enable_heartbeats: Whether to enable heartbeat monitoring.
|
||||
@@ -124,11 +118,6 @@ class PipelineParams(BaseModel):
|
||||
heartbeats_period_secs: Period between heartbeats in seconds.
|
||||
heartbeats_monitor_secs: Timeout (in seconds) before warning about
|
||||
missed heartbeats. Defaults to 10 seconds.
|
||||
interruption_strategies: [deprecated] Strategies for bot interruption behavior.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
Use `LLMUserAggregator`'s new `user_turn_strategies` parameter instead.
|
||||
|
||||
report_only_initial_ttfb: Whether to report only initial time to first byte.
|
||||
send_initial_empty_metrics: Whether to send initial empty metrics.
|
||||
start_metadata: Additional metadata for pipeline start.
|
||||
@@ -136,7 +125,6 @@ class PipelineParams(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
allow_interruptions: bool = True
|
||||
audio_in_sample_rate: int = 16000
|
||||
audio_out_sample_rate: int = 24000
|
||||
enable_heartbeats: bool = False
|
||||
@@ -144,7 +132,6 @@ class PipelineParams(BaseModel):
|
||||
enable_usage_metrics: bool = False
|
||||
heartbeats_period_secs: float = HEARTBEAT_SECS
|
||||
heartbeats_monitor_secs: float = HEARTBEAT_MONITOR_SECS
|
||||
interruption_strategies: List[BaseInterruptionStrategy] = Field(default_factory=list)
|
||||
report_only_initial_ttfb: bool = False
|
||||
send_initial_empty_metrics: bool = True
|
||||
start_metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -778,14 +765,12 @@ class PipelineTask(BasePipelineTask):
|
||||
self._maybe_start_idle_task()
|
||||
|
||||
start_frame = StartFrame(
|
||||
allow_interruptions=self._params.allow_interruptions,
|
||||
audio_in_sample_rate=self._params.audio_in_sample_rate,
|
||||
audio_out_sample_rate=self._params.audio_out_sample_rate,
|
||||
enable_metrics=self._params.enable_metrics,
|
||||
enable_tracing=self._enable_tracing,
|
||||
enable_usage_metrics=self._params.enable_usage_metrics,
|
||||
report_only_initial_ttfb=self._params.report_only_initial_ttfb,
|
||||
interruption_strategies=self._params.interruption_strategies,
|
||||
tracing_context=self._tracing_context,
|
||||
)
|
||||
start_frame.metadata = self._create_start_metadata()
|
||||
|
||||
@@ -50,7 +50,6 @@ from pipecat.frames.frames import (
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -74,7 +73,7 @@ from pipecat.processors.aggregators.llm_context_summarizer import (
|
||||
LLMContextSummarizer,
|
||||
SummaryAppliedEvent,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameCallback, FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.turns.user_idle_controller import UserIdleController
|
||||
from pipecat.turns.user_mute import BaseUserMuteStrategy
|
||||
@@ -82,7 +81,7 @@ from pipecat.turns.user_start import BaseUserTurnStartStrategy, UserTurnStartedP
|
||||
from pipecat.turns.user_stop import BaseUserTurnStopStrategy, UserTurnStoppedParams
|
||||
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionConfig
|
||||
from pipecat.turns.user_turn_controller import UserTurnController
|
||||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies, UserTurnStrategies
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.context.llm_context_summarization import (
|
||||
LLMAutoContextSummarizationConfig,
|
||||
LLMContextSummarizationConfig,
|
||||
@@ -468,11 +467,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
self._vad_controller.add_event_handler("on_push_frame", self._on_push_frame)
|
||||
self._vad_controller.add_event_handler("on_broadcast_frame", self._on_broadcast_frame)
|
||||
|
||||
# NOTE(aleix): Probably just needed temporarily. This was added to
|
||||
# prevent processing self-queued frames (SpeechControlParamsFrame)
|
||||
# pushed by strategies.
|
||||
self._self_queued_frames = set()
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up processor resources."""
|
||||
await super().cleanup()
|
||||
@@ -528,8 +522,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMSetToolChoiceFrame):
|
||||
self.set_tool_choice(frame.tool_choice)
|
||||
elif isinstance(frame, SpeechControlParamsFrame):
|
||||
await self._handle_speech_control_params(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -643,17 +635,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
if frame.run_llm:
|
||||
await self.push_context_frame()
|
||||
|
||||
async def _handle_speech_control_params(self, frame: SpeechControlParamsFrame):
|
||||
if frame.id in self._self_queued_frames:
|
||||
return
|
||||
|
||||
if not frame.turn_params:
|
||||
return
|
||||
|
||||
logger.warning(f"{self}: `turn_analyzer` in base input transport is deprecated.")
|
||||
|
||||
await self._user_turn_controller.update_strategies(ExternalUserTurnStrategies())
|
||||
|
||||
async def _handle_transcription(self, frame: TranscriptionFrame):
|
||||
text = frame.text
|
||||
|
||||
@@ -668,16 +649,6 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
)
|
||||
)
|
||||
|
||||
async def _internal_queue_frame(
|
||||
self,
|
||||
frame: Frame,
|
||||
direction: FrameDirection = FrameDirection.DOWNSTREAM,
|
||||
callback: Optional[FrameCallback] = None,
|
||||
):
|
||||
"""Queues the given frame to ourselves."""
|
||||
self._self_queued_frames.add(frame.id)
|
||||
await self.queue_frame(frame, direction, callback)
|
||||
|
||||
async def _queued_broadcast_frame(self, frame_cls: Type[Frame], **kwargs):
|
||||
"""Broadcasts a frame upstream and queues it for internal processing.
|
||||
|
||||
@@ -690,13 +661,13 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
**kwargs: Keyword arguments to be passed to the frame's constructor.
|
||||
|
||||
"""
|
||||
await self._internal_queue_frame(frame_cls(**kwargs))
|
||||
await self.queue_frame(frame_cls(**kwargs))
|
||||
await self.push_frame(frame_cls(**kwargs), FrameDirection.UPSTREAM)
|
||||
|
||||
async def _on_push_frame(
|
||||
self, controller, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
await self._internal_queue_frame(frame, direction)
|
||||
await self.queue_frame(frame, direction)
|
||||
|
||||
async def _on_broadcast_frame(self, controller, frame_cls: Type[Frame], **kwargs):
|
||||
await self._queued_broadcast_frame(frame_cls, **kwargs)
|
||||
@@ -731,7 +702,7 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
await self._user_idle_controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
if params.enable_interruptions:
|
||||
await self.broadcast_interruption()
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""User response aggregation for text frames.
|
||||
|
||||
This module provides an aggregator that collects user responses and outputs
|
||||
them as TextFrame objects, useful for capturing and processing user input
|
||||
in conversational pipelines.
|
||||
"""
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMUserAggregator
|
||||
|
||||
|
||||
class UserResponseAggregator(LLMUserAggregator):
|
||||
"""Aggregates user responses into TextFrame objects.
|
||||
|
||||
This aggregator extends LLMUserAggregator to specifically handle
|
||||
user input by collecting text responses and outputting them as TextFrame
|
||||
objects when the aggregation is complete.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the user response aggregator.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
`UserResponseAggregator` is deprecated and will be removed in a future version.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to parent LLMUserAggregator.
|
||||
"""
|
||||
super().__init__(context=LLMContext(), **kwargs)
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`UserResponseAggregator` is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Push the aggregated user response as a TextFrame.
|
||||
|
||||
Creates a TextFrame from the current aggregation if it contains content,
|
||||
resets the aggregation state, and pushes the frame downstream.
|
||||
"""
|
||||
if len(self._aggregation) > 0:
|
||||
frame = TextFrame(self._aggregation.strip())
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
await self.reset()
|
||||
@@ -1,243 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Speech-to-text (STT) muting control module.
|
||||
|
||||
This module provides functionality to control STT muting based on different strategies,
|
||||
such as during function calls, bot speech, or custom conditions. It helps manage when
|
||||
the STT service should be active or inactive during a conversation.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class STTMuteStrategy(Enum):
|
||||
"""Strategies determining when STT should be muted.
|
||||
|
||||
Each strategy defines different conditions under which speech-to-text
|
||||
processing should be temporarily disabled to prevent unwanted audio
|
||||
processing during specific conversation states.
|
||||
|
||||
Parameters:
|
||||
FIRST_SPEECH: Mute STT until the first bot speech is detected.
|
||||
MUTE_UNTIL_FIRST_BOT_COMPLETE: Mute STT until the first bot completes speaking,
|
||||
regardless of whether it is the first speech.
|
||||
FUNCTION_CALL: Mute STT during function calls to prevent interruptions.
|
||||
ALWAYS: Always mute STT when the bot is speaking.
|
||||
CUSTOM: Use a custom callback to determine muting logic dynamically.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`STTMuteStrategy` is deprecated and will be removed in a future version.
|
||||
Use `LLMUserAggregator`'s new `user_mute_strategies` instead.
|
||||
"""
|
||||
|
||||
FIRST_SPEECH = "first_speech"
|
||||
MUTE_UNTIL_FIRST_BOT_COMPLETE = "mute_until_first_bot_complete"
|
||||
FUNCTION_CALL = "function_call"
|
||||
ALWAYS = "always"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTMuteConfig:
|
||||
"""Configuration for STT muting behavior.
|
||||
|
||||
Defines which muting strategies to apply and provides optional custom
|
||||
callback for advanced muting logic. Multiple strategies can be combined
|
||||
to create sophisticated muting behavior.
|
||||
|
||||
Parameters:
|
||||
strategies: Set of muting strategies to apply simultaneously.
|
||||
should_mute_callback: Optional callback for custom muting logic.
|
||||
Only required when using STTMuteStrategy.CUSTOM. Called with
|
||||
the STTMuteFilter instance to determine muting state.
|
||||
|
||||
Note:
|
||||
MUTE_UNTIL_FIRST_BOT_COMPLETE and FIRST_SPEECH strategies should not be used together
|
||||
as they handle the first bot speech differently.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`STTMuteConfig` is deprecated and will be removed in a future version.
|
||||
Use `LLMUserAggregator`'s new `user_mute_strategies` instead.
|
||||
"""
|
||||
|
||||
strategies: set[STTMuteStrategy]
|
||||
should_mute_callback: Optional[Callable[["STTMuteFilter"], Awaitable[bool]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization.
|
||||
|
||||
Raises:
|
||||
ValueError: If incompatible strategies are used together.
|
||||
"""
|
||||
if (
|
||||
STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE in self.strategies
|
||||
and STTMuteStrategy.FIRST_SPEECH in self.strategies
|
||||
):
|
||||
raise ValueError(
|
||||
"MUTE_UNTIL_FIRST_BOT_COMPLETE and FIRST_SPEECH strategies should not be used together"
|
||||
)
|
||||
|
||||
|
||||
class STTMuteFilter(FrameProcessor):
|
||||
"""A processor that handles STT muting and interruption control.
|
||||
|
||||
This processor combines STT muting and interruption control as a coordinated
|
||||
feature. When STT is muted, interruptions are automatically disabled by
|
||||
suppressing VAD-related frames. This prevents unwanted speech detection
|
||||
during bot speech, function calls, or other specified conditions.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`STTMuteFilter` is deprecated and will be removed in a future version.
|
||||
Use `LLMUserAggregator`'s new `user_mute_strategies` instead.
|
||||
"""
|
||||
|
||||
def __init__(self, *, config: STTMuteConfig, **kwargs):
|
||||
"""Initialize the STT mute filter.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying muting strategies and behavior.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._config = config
|
||||
self._first_speech_handled = False
|
||||
self._bot_is_speaking = False
|
||||
self._function_call_in_progress = set()
|
||||
self._is_muted = False
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`STTMuteFilter` is deprecated and will be removed in a future version. "
|
||||
"Use `LLMUserAggregator`'s new `user_mute_strategies` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
async def _handle_mute_state(self, should_mute: bool):
|
||||
"""Handle STT muting and interruption control state changes."""
|
||||
if should_mute != self._is_muted:
|
||||
logger.debug(f"STTMuteFilter {'muting' if should_mute else 'unmuting'}")
|
||||
self._is_muted = should_mute
|
||||
# Note: We don't send STTMuteFrame to the STT service itself.
|
||||
# The filter blocks frames locally, but the STT service continues
|
||||
# processing audio to keep streaming connections alive (e.g., Google STT).
|
||||
|
||||
async def _should_mute(self) -> bool:
|
||||
"""Determine if STT should be muted based on current state and strategies."""
|
||||
for strategy in self._config.strategies:
|
||||
match strategy:
|
||||
case STTMuteStrategy.FUNCTION_CALL:
|
||||
if self._function_call_in_progress:
|
||||
return True
|
||||
|
||||
case STTMuteStrategy.ALWAYS:
|
||||
if self._bot_is_speaking:
|
||||
return True
|
||||
|
||||
case STTMuteStrategy.FIRST_SPEECH:
|
||||
if self._bot_is_speaking and not self._first_speech_handled:
|
||||
self._first_speech_handled = True
|
||||
return True
|
||||
|
||||
case STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE:
|
||||
if not self._first_speech_handled:
|
||||
return True
|
||||
|
||||
case STTMuteStrategy.CUSTOM:
|
||||
if self._bot_is_speaking and self._config.should_mute_callback:
|
||||
should_mute = await self._config.should_mute_callback(self)
|
||||
if should_mute:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process incoming frames and manage muting state.
|
||||
|
||||
Monitors conversation state through frame types and applies muting
|
||||
strategies accordingly. Suppresses VAD-related frames when muted
|
||||
while allowing other frames to pass through.
|
||||
|
||||
Args:
|
||||
frame: The incoming frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Determine if we need to change mute state based on frame type
|
||||
should_mute = None
|
||||
|
||||
# Process frames to determine mute state
|
||||
if isinstance(frame, StartFrame):
|
||||
should_mute = await self._should_mute()
|
||||
elif isinstance(frame, FunctionCallsStartedFrame):
|
||||
for f in frame.function_calls:
|
||||
self._function_call_in_progress.add(f.tool_call_id)
|
||||
should_mute = await self._should_mute()
|
||||
elif isinstance(frame, (FunctionCallCancelFrame, FunctionCallResultFrame)):
|
||||
self._function_call_in_progress.remove(frame.tool_call_id)
|
||||
should_mute = await self._should_mute()
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_is_speaking = True
|
||||
should_mute = await self._should_mute()
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_is_speaking = False
|
||||
if not self._first_speech_handled:
|
||||
self._first_speech_handled = True
|
||||
should_mute = await self._should_mute()
|
||||
|
||||
# Then push the original frame
|
||||
if isinstance(
|
||||
frame,
|
||||
(
|
||||
InterruptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
),
|
||||
):
|
||||
# Only pass VAD-related frames when not muted
|
||||
if not self._is_muted:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
logger.trace(f"{frame.__class__.__name__} suppressed - STT currently muted")
|
||||
else:
|
||||
# Pass all other frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# Finally handle mute state change if needed
|
||||
if should_mute is not None and should_mute != self._is_muted:
|
||||
await self._handle_mute_state(should_mute)
|
||||
@@ -23,14 +23,12 @@ from typing import (
|
||||
Coroutine,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
@@ -193,9 +191,6 @@ class FrameProcessor(BaseObject):
|
||||
self._enable_metrics = False
|
||||
self._enable_usage_metrics = False
|
||||
self._report_only_initial_ttfb = False
|
||||
# Other properties (deprecated)
|
||||
self._allow_interruptions = False
|
||||
self._interruption_strategies: List[BaseInterruptionStrategy] = []
|
||||
|
||||
# Indicates whether we have received the StartFrame.
|
||||
self.__started = False
|
||||
@@ -307,29 +302,6 @@ class FrameProcessor(BaseObject):
|
||||
"""
|
||||
return self._prev
|
||||
|
||||
@property
|
||||
def interruptions_allowed(self):
|
||||
"""Check if interruptions are allowed for this processor.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.
|
||||
|
||||
Returns:
|
||||
True if interruptions are allowed.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FrameProcessor.interruptions_allowed` is deprecated. "
|
||||
"Use `LLMUserAggregator`'s new `user_mute_strategies` parameter instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return self._allow_interruptions
|
||||
|
||||
@property
|
||||
def metrics_enabled(self):
|
||||
"""Check if metrics collection is enabled.
|
||||
@@ -357,19 +329,6 @@ class FrameProcessor(BaseObject):
|
||||
"""
|
||||
return self._report_only_initial_ttfb
|
||||
|
||||
@property
|
||||
def interruption_strategies(self) -> Sequence[BaseInterruptionStrategy]:
|
||||
"""Get the interruption strategies for this processor.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
This function is deprecated, use the new user and bot turn start
|
||||
strategies insted.
|
||||
|
||||
Returns:
|
||||
Sequence of interruption strategies.
|
||||
"""
|
||||
return self._interruption_strategies
|
||||
|
||||
@property
|
||||
def task_manager(self) -> BaseTaskManager:
|
||||
"""Get the task manager for this processor.
|
||||
@@ -819,10 +778,8 @@ class FrameProcessor(BaseObject):
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
self.__started = True
|
||||
self._allow_interruptions = frame.allow_interruptions
|
||||
self._enable_metrics = frame.enable_metrics
|
||||
self._enable_usage_metrics = frame.enable_usage_metrics
|
||||
self._interruption_strategies = frame.interruption_strategies
|
||||
self._report_only_initial_ttfb = frame.report_only_initial_ttfb
|
||||
|
||||
self.__create_process_task()
|
||||
|
||||
@@ -1,370 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Transcript processing utilities for conversation recording and analysis.
|
||||
|
||||
This module provides processors that convert speech and text frames into structured
|
||||
transcript messages with timestamps, enabling conversation history tracking and analysis.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
ThoughtTranscriptionMessage,
|
||||
TranscriptionFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.string import TextPartForConcatenation, concatenate_aggregated_text
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class BaseTranscriptProcessor(FrameProcessor):
|
||||
"""Base class for processing conversation transcripts.
|
||||
|
||||
Provides common functionality for handling transcript messages and updates.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize processor with empty message store.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._processed_messages: List[TranscriptionMessage] = []
|
||||
self._register_event_handler("on_transcript_update")
|
||||
|
||||
async def _emit_update(self, messages: List[TranscriptionMessage]):
|
||||
"""Emit transcript updates for new messages.
|
||||
|
||||
Args:
|
||||
messages: New messages to emit in update.
|
||||
"""
|
||||
if messages:
|
||||
self._processed_messages.extend(messages)
|
||||
update_frame = TranscriptionUpdateFrame(messages=messages)
|
||||
await self._call_event_handler("on_transcript_update", update_frame)
|
||||
await self.push_frame(update_frame)
|
||||
|
||||
|
||||
class UserTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Processes user transcription frames into timestamped conversation messages."""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process TranscriptionFrames into user conversation messages.
|
||||
|
||||
Args:
|
||||
frame: Input frame to process.
|
||||
direction: Frame processing direction.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
message = TranscriptionMessage(
|
||||
role="user", user_id=frame.user_id, content=frame.text, timestamp=frame.timestamp
|
||||
)
|
||||
await self._emit_update([message])
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Processes assistant TTS text frames and LLM thought frames into timestamped messages.
|
||||
|
||||
This processor aggregates both TTS text frames and LLM thought frames into
|
||||
complete utterances and thoughts, emitting them as transcript messages.
|
||||
|
||||
An assistant utterance is completed when:
|
||||
- The bot stops speaking (BotStoppedSpeakingFrame)
|
||||
- The bot is interrupted (InterruptionFrame)
|
||||
- The pipeline ends (EndFrame, CancelFrame)
|
||||
|
||||
A thought is completed when:
|
||||
- The thought ends (LLMThoughtEndFrame)
|
||||
- The bot is interrupted (InterruptionFrame)
|
||||
- The pipeline ends (EndFrame, CancelFrame)
|
||||
"""
|
||||
|
||||
def __init__(self, *, process_thoughts: bool = False, **kwargs):
|
||||
"""Initialize processor with aggregation state.
|
||||
|
||||
Args:
|
||||
process_thoughts: Whether to process LLM thought frames. Defaults to False.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._process_thoughts = process_thoughts
|
||||
self._current_assistant_text_parts: List[TextPartForConcatenation] = []
|
||||
self._assistant_text_start_time: Optional[str] = None
|
||||
|
||||
self._current_thought_parts: List[TextPartForConcatenation] = []
|
||||
self._thought_start_time: Optional[str] = None
|
||||
self._thought_active = False
|
||||
|
||||
async def _emit_aggregated_assistant_text(self):
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
This method aggregates text fragments that may arrive in multiple
|
||||
TTSTextFrame instances and emits them as a single TranscriptionMessage.
|
||||
"""
|
||||
if self._current_assistant_text_parts and self._assistant_text_start_time:
|
||||
content = concatenate_aggregated_text(self._current_assistant_text_parts)
|
||||
if content:
|
||||
logger.trace(f"Emitting aggregated assistant message: {content}")
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
timestamp=self._assistant_text_start_time,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
else:
|
||||
logger.trace("No content to emit after stripping whitespace")
|
||||
|
||||
# Reset aggregation state
|
||||
self._current_assistant_text_parts = []
|
||||
self._assistant_text_start_time = None
|
||||
|
||||
async def _emit_aggregated_thought(self):
|
||||
"""Aggregates and emits thought text fragments as a thought transcript message.
|
||||
|
||||
This method aggregates thought fragments that may arrive in multiple
|
||||
LLMThoughtTextFrame instances and emits them as a single ThoughtTranscriptionMessage.
|
||||
"""
|
||||
if self._current_thought_parts and self._thought_start_time:
|
||||
content = concatenate_aggregated_text(self._current_thought_parts)
|
||||
if content:
|
||||
logger.trace(f"Emitting aggregated thought message: {content}")
|
||||
message = ThoughtTranscriptionMessage(
|
||||
content=content,
|
||||
timestamp=self._thought_start_time,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
else:
|
||||
logger.trace("No thought content to emit after stripping whitespace")
|
||||
|
||||
# Reset aggregation state
|
||||
self._current_thought_parts = []
|
||||
self._thought_start_time = None
|
||||
self._thought_active = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames into assistant conversation messages and thought messages.
|
||||
|
||||
Handles different frame types:
|
||||
|
||||
- TTSTextFrame: Aggregates text for current utterance
|
||||
- LLMThoughtStartFrame: Begins aggregating a new thought
|
||||
- LLMThoughtTextFrame: Aggregates text for current thought
|
||||
- LLMThoughtEndFrame: Completes current thought
|
||||
- BotStoppedSpeakingFrame: Completes current utterance
|
||||
- InterruptionFrame: Completes current utterance and thought due to interruption
|
||||
- EndFrame: Completes current utterance and thought at pipeline end
|
||||
- CancelFrame: Completes current utterance and thought due to cancellation
|
||||
|
||||
Args:
|
||||
frame: Input frame to process.
|
||||
direction: Frame processing direction.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (InterruptionFrame, CancelFrame)):
|
||||
# Push frame first otherwise our emitted transcription update frame
|
||||
# might get cleaned up.
|
||||
await self.push_frame(frame, direction)
|
||||
# Emit accumulated text and thought with interruptions
|
||||
await self._emit_aggregated_assistant_text()
|
||||
if self._process_thoughts and self._thought_active:
|
||||
await self._emit_aggregated_thought()
|
||||
elif isinstance(frame, LLMThoughtStartFrame):
|
||||
# Start a new thought
|
||||
if self._process_thoughts:
|
||||
self._thought_active = True
|
||||
self._thought_start_time = time_now_iso8601()
|
||||
self._current_thought_parts = []
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMThoughtTextFrame):
|
||||
# Aggregate thought text if we have an active thought
|
||||
if self._process_thoughts and self._thought_active:
|
||||
self._current_thought_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMThoughtEndFrame):
|
||||
# Emit accumulated thought when thought ends
|
||||
if self._process_thoughts and self._thought_active:
|
||||
await self._emit_aggregated_thought()
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
# Start timestamp on first text part
|
||||
if not self._assistant_text_start_time:
|
||||
self._assistant_text_start_time = time_now_iso8601()
|
||||
|
||||
self._current_assistant_text_parts.append(
|
||||
TextPartForConcatenation(
|
||||
frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces
|
||||
)
|
||||
)
|
||||
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, (BotStoppedSpeakingFrame, EndFrame)):
|
||||
# Emit accumulated text when bot finishes speaking or pipeline ends.
|
||||
await self._emit_aggregated_assistant_text()
|
||||
# Emit accumulated thought at pipeline end if still active
|
||||
if isinstance(frame, EndFrame) and self._process_thoughts and self._thought_active:
|
||||
await self._emit_aggregated_thought()
|
||||
# Push frame.
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class TranscriptProcessor:
|
||||
"""Factory for creating and managing transcript processors.
|
||||
|
||||
Provides unified access to user and assistant transcript processors
|
||||
with shared event handling. The assistant processor handles both TTS text
|
||||
and LLM thought frames.
|
||||
|
||||
Example::
|
||||
|
||||
transcript = TranscriptProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
transcript.user(), # User transcripts
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
transcript.assistant(), # Assistant transcripts (including thoughts)
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def handle_update(processor, frame):
|
||||
print(f"New messages: {frame.messages}")
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
`TranscriptProcessor` is deprecated and will be removed in a future version.
|
||||
Use `LLMUserAggregator`'s and `LLMAssistantAggregator`'s new events instead.
|
||||
"""
|
||||
|
||||
def __init__(self, *, process_thoughts: bool = False):
|
||||
"""Initialize factory.
|
||||
|
||||
Args:
|
||||
process_thoughts: Whether the assistant processor should handle LLM thought
|
||||
frames. Defaults to False.
|
||||
"""
|
||||
self._process_thoughts = process_thoughts
|
||||
self._user_processor = None
|
||||
self._assistant_processor = None
|
||||
self._event_handlers = {}
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`TranscriptProcessor` is deprecated and will be removed in a future version. "
|
||||
"Use `LLMUserAggregator`'s and `LLMAssistantAggregator`'s new events instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def user(self, **kwargs) -> UserTranscriptProcessor:
|
||||
"""Get the user transcript processor.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments specific to UserTranscriptProcessor.
|
||||
|
||||
Returns:
|
||||
The user transcript processor instance.
|
||||
"""
|
||||
if self._user_processor is None:
|
||||
self._user_processor = UserTranscriptProcessor(**kwargs)
|
||||
# Apply any registered event handlers
|
||||
for event_name, handler in self._event_handlers.items():
|
||||
|
||||
@self._user_processor.event_handler(event_name)
|
||||
async def user_handler(processor, frame):
|
||||
return await handler(processor, frame)
|
||||
|
||||
return self._user_processor
|
||||
|
||||
def assistant(self, **kwargs) -> AssistantTranscriptProcessor:
|
||||
"""Get the assistant transcript processor.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments specific to AssistantTranscriptProcessor.
|
||||
|
||||
Returns:
|
||||
The assistant transcript processor instance.
|
||||
"""
|
||||
if self._assistant_processor is None:
|
||||
self._assistant_processor = AssistantTranscriptProcessor(
|
||||
process_thoughts=self._process_thoughts, **kwargs
|
||||
)
|
||||
# Apply any registered event handlers
|
||||
for event_name, handler in self._event_handlers.items():
|
||||
|
||||
@self._assistant_processor.event_handler(event_name)
|
||||
async def assistant_handler(processor, frame):
|
||||
return await handler(processor, frame)
|
||||
|
||||
return self._assistant_processor
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
"""Register event handler for both processors.
|
||||
|
||||
Args:
|
||||
event_name: Name of event to handle.
|
||||
|
||||
Returns:
|
||||
Decorator function that registers handler with both processors.
|
||||
"""
|
||||
|
||||
def decorator(handler):
|
||||
self._event_handlers[event_name] = handler
|
||||
|
||||
# Apply handler to existing processors if they exist
|
||||
if self._user_processor:
|
||||
|
||||
@self._user_processor.event_handler(event_name)
|
||||
async def user_handler(processor, frame):
|
||||
return await handler(processor, frame)
|
||||
|
||||
if self._assistant_processor:
|
||||
|
||||
@self._assistant_processor.event_handler(event_name)
|
||||
async def assistant_handler(processor, frame):
|
||||
return await handler(processor, frame)
|
||||
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
@@ -130,11 +130,6 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
Parameters:
|
||||
enable_prompt_caching: Whether to enable the prompt caching feature.
|
||||
enable_prompt_caching_beta (deprecated): Whether to enable the beta prompt caching feature.
|
||||
|
||||
.. deprecated:: 0.0.84
|
||||
Use the `enable_prompt_caching` parameter instead.
|
||||
|
||||
max_tokens: Maximum tokens to generate. Must be at least 1.
|
||||
temperature: Sampling temperature between 0.0 and 1.0.
|
||||
top_k: Top-k sampling parameter.
|
||||
@@ -147,7 +142,6 @@ class AnthropicLLMService(LLMService):
|
||||
"""
|
||||
|
||||
enable_prompt_caching: Optional[bool] = None
|
||||
enable_prompt_caching_beta: Optional[bool] = None
|
||||
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
|
||||
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
|
||||
@@ -157,18 +151,6 @@ class AnthropicLLMService(LLMService):
|
||||
)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def model_post_init(self, __context):
|
||||
"""Post-initialization to handle deprecated parameters."""
|
||||
if self.enable_prompt_caching_beta is not None:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"enable_prompt_caching_beta is deprecated. Use enable_prompt_caching instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -237,22 +219,8 @@ class AnthropicLLMService(LLMService):
|
||||
default_settings.thinking = params.thinking
|
||||
if isinstance(params.extra, dict):
|
||||
default_settings.extra = params.extra
|
||||
# Handle enable_prompt_caching / enable_prompt_caching_beta
|
||||
enable_prompt_caching = params.enable_prompt_caching
|
||||
if params.enable_prompt_caching_beta is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"enable_prompt_caching_beta is deprecated. "
|
||||
"Use enable_prompt_caching instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if enable_prompt_caching is None:
|
||||
enable_prompt_caching = params.enable_prompt_caching_beta
|
||||
default_settings.enable_prompt_caching = enable_prompt_caching or False
|
||||
if params.enable_prompt_caching is not None:
|
||||
default_settings.enable_prompt_caching = params.enable_prompt_caching
|
||||
|
||||
# 4. Apply settings delta (canonical API, always wins)
|
||||
if settings is not None:
|
||||
|
||||
@@ -36,6 +36,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMContextFrame,
|
||||
LLMEnablePromptCachingFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
UserImageRawFrame,
|
||||
@@ -66,11 +67,16 @@ class AWSBedrockLLMSettings(LLMSettings):
|
||||
Parameters:
|
||||
stop_sequences: List of strings that stop generation.
|
||||
latency: Performance mode - "standard" or "optimized".
|
||||
enable_prompt_caching: Whether to enable prompt caching by adding cachePoint
|
||||
markers to system prompts and tool definitions. Can reduce TTFT by up to
|
||||
85% for multi-turn conversations. See:
|
||||
https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
|
||||
additional_model_request_fields: Additional model-specific parameters.
|
||||
"""
|
||||
|
||||
stop_sequences: List[str] | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
latency: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
enable_prompt_caching: bool | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
additional_model_request_fields: Dict[str, Any] | _NotGiven = field(
|
||||
default_factory=lambda: NOT_GIVEN
|
||||
)
|
||||
@@ -174,6 +180,7 @@ class AWSBedrockLLMService(LLMService):
|
||||
user_turn_completion_config=None,
|
||||
stop_sequences=None,
|
||||
latency=None,
|
||||
enable_prompt_caching=False,
|
||||
additional_model_request_fields={},
|
||||
)
|
||||
|
||||
@@ -455,6 +462,24 @@ class AWSBedrockLLMService(LLMService):
|
||||
if self._settings.latency in ["standard", "optimized"]:
|
||||
request_params["performanceConfig"] = {"latency": self._settings.latency}
|
||||
|
||||
# Add cache checkpoints to system prompts and tool definitions.
|
||||
# This enables prompt caching for providers that support it (e.g.
|
||||
# Anthropic Claude on Bedrock), reducing TTFT by up to 85% on
|
||||
# multi-turn conversations where the system prompt stays constant.
|
||||
if self._settings.enable_prompt_caching:
|
||||
if "system" in request_params and request_params["system"]:
|
||||
system_list = request_params["system"]
|
||||
if not any("cachePoint" in item for item in system_list):
|
||||
system_list.append({"cachePoint": {"type": "default"}})
|
||||
if (
|
||||
"toolConfig" in request_params
|
||||
and "tools" in request_params["toolConfig"]
|
||||
and request_params["toolConfig"]["tools"]
|
||||
):
|
||||
tools_list = request_params["toolConfig"]["tools"]
|
||||
if not any("cachePoint" in t for t in tools_list):
|
||||
tools_list.append({"cachePoint": {"type": "default"}})
|
||||
|
||||
# Log request params with messages redacted for logging
|
||||
adapter = self.get_llm_adapter()
|
||||
messages_for_logging = adapter.get_messages_for_logging(context)
|
||||
@@ -566,6 +591,9 @@ class AWSBedrockLLMService(LLMService):
|
||||
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
await self._process_context(frame.context)
|
||||
elif isinstance(frame, LLMEnablePromptCachingFrame):
|
||||
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
|
||||
self._settings.enable_prompt_caching = frame.enable
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -258,7 +258,6 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
settings: Optional[Settings] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[ToolsSchema] = None,
|
||||
send_transcription_frames: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes the AWS Nova Sonic LLM service.
|
||||
@@ -302,12 +301,6 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=AWSNovaSonicLLMService.Settings(system_instruction=...)`` instead.
|
||||
tools: Available tools/functions for the model to use.
|
||||
send_transcription_frames: Whether to emit transcription frames.
|
||||
|
||||
.. deprecated:: 0.0.91
|
||||
This parameter is deprecated and will be removed in a future version.
|
||||
Transcription frames are always sent.
|
||||
|
||||
**kwargs: Additional arguments passed to the parent LLMService.
|
||||
"""
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
@@ -391,18 +384,6 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
)
|
||||
self._settings.endpointing_sensitivity = None
|
||||
|
||||
if not send_transcription_frames:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`send_transcription_frames` is deprecated and will be removed in a future version. "
|
||||
"Transcription frames are always sent.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self._context: Optional[LLMContext] = None
|
||||
self._stream: Optional[
|
||||
DuplexEventStream[
|
||||
@@ -1300,18 +1281,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# HACK: Check if this transcription was triggered by our own
|
||||
# assistant response trigger. If so, we need to wrap it with
|
||||
# UserStarted/StoppedSpeakingFrames; otherwise the user aggregator
|
||||
# would fire an EmulatedUserStartedSpeakingFrame, which would
|
||||
# trigger an interruption, which would prevent us from writing the
|
||||
# assistant response to context.
|
||||
#
|
||||
# Sending an EmulateUserStartedSpeakingFrame ourselves doesn't
|
||||
# work: it just causes the interruption we're trying to avoid.
|
||||
#
|
||||
# Setting enable_emulated_vad_interruptions also doesn't work: at
|
||||
# the time the user aggregator receives the TranscriptionFrame, it
|
||||
# doesn't yet know the assistant has started responding, so it
|
||||
# doesn't know that emulating the user starting to speak would
|
||||
# cause an interruption.
|
||||
# would trigger an interruption, which would prevent us from
|
||||
# writing the assistant response to context.
|
||||
should_wrap_in_user_started_stopped_speaking_frames = (
|
||||
self._waiting_for_trigger_transcription
|
||||
and self._user_text_buffer.strip().lower() == "ready"
|
||||
|
||||
@@ -369,31 +369,3 @@ class AWSPollyTTSService(TTSService):
|
||||
except (BotoCoreError, ClientError) as error:
|
||||
error_message = f"AWS Polly TTS error: {str(error)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
|
||||
class PollyTTSService(AWSPollyTTSService):
|
||||
"""Deprecated alias for AWSPollyTTSService.
|
||||
|
||||
.. deprecated:: 0.0.67
|
||||
`PollyTTSService` is deprecated, use `AWSPollyTTSService` instead.
|
||||
|
||||
"""
|
||||
|
||||
Settings = AWSPollyTTSSettings
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the deprecated PollyTTSService.
|
||||
|
||||
Args:
|
||||
**kwargs: All arguments passed to AWSPollyTTSService.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"'PollyTTSService' is deprecated, use 'AWSPollyTTSService' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ import base64
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import Any, AsyncGenerator, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -28,7 +28,6 @@ from pipecat.frames.frames import (
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven
|
||||
from pipecat.services.tts_service import TextAggregationMode, TTSService, WebsocketTTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -240,7 +239,6 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
container: str = "raw",
|
||||
params: Optional[InputParams] = None,
|
||||
settings: Optional[Settings] = None,
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
@@ -271,11 +269,6 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
|
||||
settings: Runtime-updatable settings. When provided alongside deprecated
|
||||
parameters, ``settings`` values take precedence.
|
||||
text_aggregator: Custom text aggregator for processing input text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Whether to aggregate sentences within the TTSService.
|
||||
|
||||
@@ -337,20 +330,18 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
pause_frame_processing=False,
|
||||
sample_rate=sample_rate,
|
||||
push_start_frame=True,
|
||||
text_aggregator=text_aggregator,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not text_aggregator:
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Cartesia SSML Tags is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator(
|
||||
[("<spell>", "</spell>")], aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Cartesia SSML Tags is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator(
|
||||
[("<spell>", "</spell>")], aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
@@ -599,6 +590,34 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
msg = self._build_msg(text="", continue_transcript=False, context_id=flush_id)
|
||||
await self._websocket.send(msg)
|
||||
|
||||
async def _update_settings(self, delta: CartesiaTTSSettings) -> dict[str, Any]:
|
||||
"""Apply a TTS settings delta, flushing the context if needed.
|
||||
|
||||
Voice, model, and language are locked per Cartesia context. If any of
|
||||
these change, the current context is flushed so the next sentence opens
|
||||
a fresh one with the updated settings.
|
||||
|
||||
Args:
|
||||
delta: A TTS settings delta.
|
||||
|
||||
Returns:
|
||||
Dict mapping changed field names to their previous values.
|
||||
"""
|
||||
changed = await super()._update_settings(delta)
|
||||
if not changed:
|
||||
return changed
|
||||
|
||||
if changed.keys() & {"voice", "model", "language"}:
|
||||
if self._turn_context_id and self.audio_context_available(self._turn_context_id):
|
||||
await self.flush_audio(context_id=self._turn_context_id)
|
||||
# Assign a new turn context ID so subsequent sentences in this
|
||||
# turn open a new Cartesia context with the updated settings.
|
||||
if self._turn_context_id:
|
||||
self._turn_context_id = None
|
||||
self._turn_context_id = self.create_context_id()
|
||||
|
||||
return changed
|
||||
|
||||
async def _process_messages(self):
|
||||
async for message in self._get_websocket():
|
||||
msg = json.loads(message)
|
||||
|
||||
@@ -314,7 +314,6 @@ class DeepgramSTTService(STTService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
url: str = "",
|
||||
base_url: str = "",
|
||||
encoding: str = "linear16",
|
||||
channels: int = 1,
|
||||
@@ -335,11 +334,6 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
Args:
|
||||
api_key: Deepgram API key for authentication.
|
||||
url: Custom Deepgram API base URL.
|
||||
|
||||
.. deprecated:: 0.0.64
|
||||
Parameter `url` is deprecated, use `base_url` instead.
|
||||
|
||||
base_url: Custom Deepgram API base URL.
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
channels: Number of audio channels. Defaults to 1.
|
||||
@@ -374,17 +368,6 @@ class DeepgramSTTService(STTService):
|
||||
Note:
|
||||
The `vad_events` option in LiveOptions is deprecated as of version 0.0.99 and will be removed in a future version. Please use the Silero VAD instead.
|
||||
"""
|
||||
if url:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'url' is deprecated, use 'base_url' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
base_url = url
|
||||
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
default_settings = self.Settings(
|
||||
model="nova-3-general",
|
||||
|
||||
@@ -110,7 +110,6 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
reference_id: Optional[str] = None, # This is the voice ID
|
||||
model: Optional[str] = None, # Deprecated
|
||||
model_id: Optional[str] = None,
|
||||
output_format: FishAudioOutputFormat = "pcm",
|
||||
sample_rate: Optional[int] = None,
|
||||
@@ -127,12 +126,6 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=FishAudioTTSService.Settings(voice=...)`` instead.
|
||||
|
||||
model: Deprecated. Reference ID of the voice model to use for synthesis.
|
||||
|
||||
.. deprecated:: 0.0.74
|
||||
The ``model`` parameter is deprecated and will be removed in version 0.1.0.
|
||||
Use ``reference_id`` instead to specify the voice model.
|
||||
|
||||
model_id: Specify which Fish Audio TTS model to use (e.g. "s1").
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
@@ -149,25 +142,6 @@ class FishAudioTTSService(InterruptibleTTSService):
|
||||
parameters, ``settings`` values take precedence.
|
||||
**kwargs: Additional arguments passed to the parent service.
|
||||
"""
|
||||
# Validation for model and reference_id parameters
|
||||
if model and reference_id:
|
||||
raise ValueError(
|
||||
"Cannot specify both 'model' and 'reference_id'. Use 'reference_id' only."
|
||||
)
|
||||
|
||||
if model:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'model' is deprecated and will be removed in a future version. "
|
||||
"Use 'reference_id' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
reference_id = model
|
||||
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
default_settings = self.Settings(
|
||||
model="s2-pro",
|
||||
|
||||
@@ -10,8 +10,6 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
class LanguageConfig(BaseModel):
|
||||
"""Configuration for language detection and handling.
|
||||
@@ -163,12 +161,6 @@ class GladiaInputParams(BaseModel):
|
||||
custom_metadata: Additional metadata to include with requests
|
||||
endpointing: Silence duration in seconds to mark end of speech
|
||||
maximum_duration_without_endpointing: Maximum utterance duration without silence
|
||||
language: Language code for transcription
|
||||
|
||||
.. deprecated:: 0.0.62
|
||||
The 'language' parameter is deprecated and will be removed in a future version.
|
||||
Use 'language_config' instead.
|
||||
|
||||
language_config: Detailed language configuration
|
||||
pre_processing: Audio pre-processing options
|
||||
realtime_processing: Real-time processing features
|
||||
@@ -184,7 +176,6 @@ class GladiaInputParams(BaseModel):
|
||||
custom_metadata: Optional[Dict[str, Any]] = None
|
||||
endpointing: Optional[float] = None
|
||||
maximum_duration_without_endpointing: Optional[int] = 5
|
||||
language: Optional[Language] = None # Deprecated
|
||||
language_config: Optional[LanguageConfig] = None
|
||||
pre_processing: Optional[PreProcessingConfig] = None
|
||||
realtime_processing: Optional[RealtimeProcessingConfig] = None
|
||||
|
||||
@@ -13,7 +13,6 @@ supporting multiple languages, custom vocabulary, and various audio processing o
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Literal, Optional
|
||||
|
||||
@@ -171,21 +170,6 @@ def language_to_gladia_language(language: Language) -> Optional[str]:
|
||||
|
||||
|
||||
# Deprecation warning for nested InputParams
|
||||
class _InputParamsDescriptor:
|
||||
"""Descriptor for backward compatibility with deprecation warning."""
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"GladiaSTTService.InputParams is deprecated and will be removed in a future version. "
|
||||
"Import and use GladiaInputParams directly instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return GladiaInputParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class GladiaSTTSettings(STTSettings):
|
||||
"""Settings for GladiaSTTService.
|
||||
@@ -225,17 +209,11 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
Provides automatic reconnection, audio buffering, and comprehensive error handling.
|
||||
|
||||
For complete API documentation, see: https://docs.gladia.io/api-reference/v2/live/init
|
||||
|
||||
.. deprecated:: 0.0.62
|
||||
Use :class:`~pipecat.services.gladia.config.GladiaInputParams` directly instead.
|
||||
"""
|
||||
|
||||
Settings = GladiaSTTSettings
|
||||
_settings: Settings
|
||||
|
||||
# Maintain backward compatibility
|
||||
InputParams = _InputParamsDescriptor()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -245,7 +223,6 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
encoding: str = "wav/pcm",
|
||||
bit_depth: int = 16,
|
||||
channels: int = 1,
|
||||
confidence: Optional[float] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
params: Optional[GladiaInputParams] = None,
|
||||
@@ -264,12 +241,6 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
encoding: Audio encoding format. Defaults to ``"wav/pcm"``.
|
||||
bit_depth: Audio bit depth. Defaults to 16.
|
||||
channels: Number of audio channels. Defaults to 1.
|
||||
confidence: Minimum confidence threshold for transcriptions (0.0-1.0).
|
||||
|
||||
.. deprecated:: 0.0.86
|
||||
The 'confidence' parameter is deprecated and will be removed in a future version.
|
||||
No confidence threshold is applied.
|
||||
|
||||
sample_rate: Audio sample rate in Hz. If None, uses service default.
|
||||
model: Model to use for transcription.
|
||||
|
||||
@@ -291,16 +262,6 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
|
||||
**kwargs: Additional arguments passed to the STTService parent class.
|
||||
"""
|
||||
if confidence:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The 'confidence' parameter is deprecated and will be removed in a future version. "
|
||||
"No confidence threshold is applied.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
default_settings = self.Settings(
|
||||
model="solaria-1",
|
||||
@@ -323,15 +284,6 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
# 3. Apply params overrides — only if settings not provided
|
||||
if params is not None:
|
||||
self._warn_init_param_moved_to_settings("params")
|
||||
if params.language is not None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"The 'language' parameter is deprecated and will be removed in a future "
|
||||
"version. Use 'language_config' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not settings:
|
||||
# Extract init-only fields from params
|
||||
if params.encoding is not None:
|
||||
@@ -349,15 +301,8 @@ class GladiaSTTService(WebsocketSTTService):
|
||||
default_settings.realtime_processing = params.realtime_processing
|
||||
default_settings.messages_config = params.messages_config
|
||||
default_settings.enable_vad = params.enable_vad
|
||||
# Resolve deprecated language → language_config at init time
|
||||
if params.language_config:
|
||||
default_settings.language_config = params.language_config
|
||||
elif params.language:
|
||||
language_code = self.language_to_service_language(params.language)
|
||||
if language_code:
|
||||
default_settings.language_config = LanguageConfig(
|
||||
languages=[language_code], code_switching=False
|
||||
)
|
||||
|
||||
# 4. Apply settings delta (canonical API, always wins)
|
||||
if settings is not None:
|
||||
|
||||
@@ -380,7 +380,6 @@ class GeminiLiveLLMService(LLMService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
voice_id: str = "Charon",
|
||||
start_audio_paused: bool = False,
|
||||
@@ -398,13 +397,6 @@ class GeminiLiveLLMService(LLMService):
|
||||
|
||||
Args:
|
||||
api_key: Google AI API key for authentication.
|
||||
base_url: API endpoint base URL. Defaults to the official Gemini Live endpoint.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
This parameter is deprecated and no longer has any effect.
|
||||
Please use `http_options` to customize requests made by the
|
||||
API client.
|
||||
|
||||
model: Model identifier to use.
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
@@ -431,18 +423,6 @@ class GeminiLiveLLMService(LLMService):
|
||||
http_options: HTTP options for the client.
|
||||
**kwargs: Additional arguments passed to parent LLMService.
|
||||
"""
|
||||
# Check for deprecated parameter usage
|
||||
if base_url is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'base_url' is deprecated and no longer has any effect. Please use 'http_options' to customize requests made by the API client.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
default_settings = self.Settings(
|
||||
model="models/gemini-2.5-flash-native-audio-preview-12-2025",
|
||||
@@ -515,13 +495,11 @@ class GeminiLiveLLMService(LLMService):
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
base_url=base_url,
|
||||
settings=default_settings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._last_sent_time = 0
|
||||
self._base_url = base_url
|
||||
|
||||
self._system_instruction_from_init = self._settings.system_instruction
|
||||
self._tools_from_init = tools
|
||||
|
||||
@@ -755,7 +755,7 @@ class GoogleSTTService(STTService):
|
||||
) -> None:
|
||||
"""Update service options dynamically.
|
||||
|
||||
.. deprecated::
|
||||
.. deprecated:: 0.0.104
|
||||
Use ``STTUpdateSettingsFrame`` with ``GoogleSTTService.Settings(...)``
|
||||
instead.
|
||||
|
||||
@@ -1004,7 +1004,7 @@ class GoogleSTTService(STTService):
|
||||
except Aborted as e:
|
||||
# Handle stream abort due to inactivity (409 error).
|
||||
# This occurs when no audio is sent to the stream for 10+ seconds,
|
||||
# which can happen when InputAudioRawFrames are blocked (e.g., by STTMuteFilter).
|
||||
# which can happen when InputAudioRawFrames are blocked.
|
||||
# Google's STT service automatically closes the stream in this case.
|
||||
# We log at DEBUG level (not ERROR) since this is recoverable, then re-raise
|
||||
# to trigger automatic reconnection in _stream_audio.
|
||||
|
||||
@@ -16,7 +16,6 @@ for natural voice control and multi-speaker conversations.
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -1259,7 +1258,6 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
@@ -1273,12 +1271,6 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
"""Initializes the Gemini TTS service.
|
||||
|
||||
Args:
|
||||
api_key:
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
The `api_key` parameter is deprecated. Use `credentials` or
|
||||
`credentials_path` instead for Google Cloud authentication.
|
||||
|
||||
model: Gemini TTS model to use. Must be a TTS model like
|
||||
"gemini-2.5-flash-tts" or "gemini-2.5-pro-tts".
|
||||
|
||||
@@ -1303,15 +1295,6 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
parameters, ``settings`` values take precedence.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
# Handle deprecated api_key parameter
|
||||
if api_key is not None:
|
||||
warnings.warn(
|
||||
"The 'api_key' parameter is deprecated and will be removed in a future version. "
|
||||
"Use 'credentials' or 'credentials_path' instead for Google Cloud authentication.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if sample_rate and sample_rate != self.GOOGLE_SAMPLE_RATE:
|
||||
logger.warning(
|
||||
f"Google TTS only supports {self.GOOGLE_SAMPLE_RATE}Hz sample rate. "
|
||||
|
||||
@@ -61,58 +61,14 @@ class GoogleVertexLLMService(GoogleLLMService):
|
||||
Settings = GoogleVertexLLMSettings
|
||||
_settings: Settings
|
||||
|
||||
class InputParams(GoogleLLMService.InputParams):
|
||||
"""Input parameters specific to Vertex AI.
|
||||
|
||||
Parameters:
|
||||
location: GCP region for Vertex AI endpoint (e.g., "us-east4").
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
Use `location` as a direct argument to
|
||||
`GoogleVertexLLMService.__init__()` instead.
|
||||
|
||||
project_id: Google Cloud project ID.
|
||||
|
||||
.. deprecated:: 0.0.90
|
||||
Use `project_id` as a direct argument to
|
||||
`GoogleVertexLLMService.__init__()` instead.
|
||||
"""
|
||||
|
||||
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations
|
||||
location: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initializes the InputParams."""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
if "location" in kwargs and kwargs["location"] is not None:
|
||||
warnings.warn(
|
||||
"GoogleVertexLLMService.InputParams.location is deprecated. "
|
||||
"Please provide 'location' as a direct argument to GoogleVertexLLMService.__init__() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if "project_id" in kwargs and kwargs["project_id"] is not None:
|
||||
warnings.warn(
|
||||
"GoogleVertexLLMService.InputParams.project_id is deprecated. "
|
||||
"Please provide 'project_id' as a direct argument to GoogleVertexLLMService.__init__() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
location: str = "us-east4",
|
||||
project_id: str,
|
||||
params: Optional[GoogleLLMService.InputParams] = None,
|
||||
settings: Optional[Settings] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
@@ -131,7 +87,7 @@ class GoogleVertexLLMService(GoogleLLMService):
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=GoogleVertexLLMService.Settings(model=...)`` instead.
|
||||
|
||||
location: GCP region for Vertex AI endpoint (e.g., "us-east4").
|
||||
location: GCP region for Vertex AI endpoint. Defaults to "us-east4".
|
||||
project_id: Google Cloud project ID.
|
||||
params: Input parameters for the model.
|
||||
|
||||
@@ -161,34 +117,6 @@ class GoogleVertexLLMService(GoogleLLMService):
|
||||
"Invalid parameter 'api_key'. Use 'credentials' or 'credentials_path' for Vertex AI authentication."
|
||||
)
|
||||
|
||||
# Handle deprecated InputParams fields (location/project_id extraction
|
||||
# must happen before validation, regardless of settings)
|
||||
if params and isinstance(params, GoogleVertexLLMService.InputParams):
|
||||
if project_id is None:
|
||||
project_id = params.project_id
|
||||
if location is None:
|
||||
location = params.location
|
||||
# Convert to base InputParams
|
||||
params = GoogleLLMService.InputParams(
|
||||
**params.model_dump(exclude={"location", "project_id"}, exclude_unset=True)
|
||||
)
|
||||
|
||||
# Validate project_id and location parameters
|
||||
# NOTE: once we remove Vertex-specific InputParams class, we can update
|
||||
# __init__() signature as follows:
|
||||
# - location: str = "us-east4",
|
||||
# - project_id: str,
|
||||
# But for now, we need them as-is to maintain proper backward
|
||||
# compatibility.
|
||||
if project_id is None:
|
||||
raise ValueError("project_id is required")
|
||||
if location is None:
|
||||
# If location is not provided, default to "us-east4".
|
||||
# Note: this is legacy behavior; ideally location would be
|
||||
# required.
|
||||
logger.warning("location is not provided. Defaulting to 'us-east4'.")
|
||||
location = "us-east4" # Default location if not provided
|
||||
|
||||
# These need to be set before calling super().__init__() because
|
||||
# super().__init__() invokes _create_client(), which needs these.
|
||||
self._credentials = self._get_credentials(credentials, credentials_path)
|
||||
|
||||
@@ -185,7 +185,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
def __init__(
|
||||
self,
|
||||
run_in_parallel: bool = True,
|
||||
function_call_timeout_secs: float = 10.0,
|
||||
function_call_timeout_secs: Optional[float] = None,
|
||||
settings: Optional[LLMSettings] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -194,8 +194,8 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
Args:
|
||||
run_in_parallel: Whether to run function calls in parallel or sequentially.
|
||||
Defaults to True.
|
||||
function_call_timeout_secs: Timeout in seconds for deferred function calls.
|
||||
Defaults to 10.0 seconds.
|
||||
function_call_timeout_secs: Optional timeout in seconds for deferred function
|
||||
calls.
|
||||
settings: The runtime-updatable settings for the LLM service.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
@@ -753,11 +753,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
# Start a timeout task for deferred function calls
|
||||
async def timeout_handler():
|
||||
try:
|
||||
effective_timeout = (
|
||||
item.timeout_secs
|
||||
if item.timeout_secs is not None
|
||||
else self._function_call_timeout_secs
|
||||
)
|
||||
effective_timeout = item.timeout_secs or self._function_call_timeout_secs
|
||||
await asyncio.sleep(effective_timeout)
|
||||
logger.warning(
|
||||
f"{self} Function call [{runner_item.function_name}:{runner_item.tool_call_id}] timed out after {effective_timeout} seconds."
|
||||
@@ -768,13 +764,15 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
timeout_task = self.create_task(timeout_handler())
|
||||
if item.timeout_secs or self._function_call_timeout_secs:
|
||||
timeout_task = self.create_task(timeout_handler())
|
||||
|
||||
# Yield to the event loop so the timeout task coroutine gets entered
|
||||
# before it could be cancelled. Without this, cancelling the task before
|
||||
# it starts would leave the coroutine in a "never awaited" state.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
try:
|
||||
# Yield to the event loop so the timeout task coroutine gets entered
|
||||
# before it could be cancelled. Without this, cancelling the task before
|
||||
# it starts would leave the coroutine in a "never awaited" state.
|
||||
await asyncio.sleep(0)
|
||||
if isinstance(item.handler, DirectFunctionWrapper):
|
||||
# Handler is a DirectFunctionWrapper
|
||||
await item.handler.invoke(
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
"""MCP (Model Context Protocol) client for integrating external tools with LLMs."""
|
||||
|
||||
import json
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
@@ -36,8 +37,14 @@ class MCPClient(BaseObject):
|
||||
"""Client for Model Context Protocol (MCP) servers.
|
||||
|
||||
Enables integration with MCP servers to provide external tools and resources
|
||||
to LLMs. Supports both stdio and SSE server connections with automatic tool
|
||||
registration and schema conversion.
|
||||
to LLMs. Supports stdio, SSE, and streamable HTTP server connections with
|
||||
automatic tool registration and schema conversion.
|
||||
|
||||
The client maintains a persistent connection to the MCP server. It must
|
||||
be used as an async context manager or explicitly started and closed::
|
||||
|
||||
async with MCPClient(server_params=...) as mcp:
|
||||
tools = await mcp.register_tools(llm)
|
||||
|
||||
Raises:
|
||||
TypeError: If server_params is not a supported parameter type.
|
||||
@@ -53,7 +60,7 @@ class MCPClient(BaseObject):
|
||||
"""Initialize the MCP client with server parameters.
|
||||
|
||||
Args:
|
||||
server_params: Server connection parameters (stdio or SSE).
|
||||
server_params: Server connection parameters (stdio, SSE, or streamable HTTP).
|
||||
tools_filter: Optional list of tool names to register. If None, all tools are registered.
|
||||
tools_output_filters: Optional dict mapping tool names to filter functions that process tool outputs.
|
||||
Each filter function receives the raw tool output (any type) and returns the processed output (any type).
|
||||
@@ -61,31 +68,84 @@ class MCPClient(BaseObject):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._server_params = server_params
|
||||
self._session = ClientSession
|
||||
self._tools_filter = tools_filter
|
||||
self._tools_output_filters = tools_output_filters or {}
|
||||
self._exit_stack: Optional[AsyncExitStack] = None
|
||||
self._active_session: Optional[ClientSession] = None
|
||||
|
||||
if isinstance(server_params, StdioServerParameters):
|
||||
self._client = stdio_client
|
||||
self._list_tools = self._stdio_list_tools
|
||||
self._tool_wrapper = self._stdio_tool_wrapper
|
||||
elif isinstance(server_params, SseServerParameters):
|
||||
self._client = sse_client
|
||||
self._list_tools = self._sse_list_tools
|
||||
self._tool_wrapper = self._sse_tool_wrapper
|
||||
elif isinstance(server_params, StreamableHttpParameters):
|
||||
self._client = streamablehttp_client
|
||||
self._list_tools = self._streamable_http_list_tools
|
||||
self._tool_wrapper = self._streamable_http_tool_wrapper
|
||||
else:
|
||||
if not isinstance(
|
||||
server_params,
|
||||
(StdioServerParameters, SseServerParameters, StreamableHttpParameters),
|
||||
):
|
||||
raise TypeError(
|
||||
f"{self} invalid argument type: `server_params` must be either StdioServerParameters, SseServerParameters, or StreamableHttpParameters."
|
||||
f"{self} invalid argument type: `server_params` must be either "
|
||||
"StdioServerParameters, SseServerParameters, or StreamableHttpParameters."
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start a persistent connection to the MCP server.
|
||||
|
||||
Opens the transport and initializes the MCP session. The session
|
||||
is reused for all subsequent tool calls and schema requests until
|
||||
close() is called.
|
||||
|
||||
Can also be used via async context manager::
|
||||
|
||||
async with MCPClient(server_params=...) as mcp:
|
||||
...
|
||||
"""
|
||||
if self._active_session:
|
||||
return
|
||||
|
||||
# We manage the exit stack manually (not via `async with`) so we can
|
||||
# clean up partial resources on failure before assigning to self.
|
||||
exit_stack = AsyncExitStack()
|
||||
await exit_stack.__aenter__()
|
||||
|
||||
try:
|
||||
if isinstance(self._server_params, StdioServerParameters):
|
||||
streams = await exit_stack.enter_async_context(stdio_client(self._server_params))
|
||||
read_stream, write_stream = streams[0], streams[1]
|
||||
elif isinstance(self._server_params, SseServerParameters):
|
||||
read_stream, write_stream = await exit_stack.enter_async_context(
|
||||
sse_client(**self._server_params.model_dump())
|
||||
)
|
||||
else: # StreamableHttpParameters (validated in __init__)
|
||||
read_stream, write_stream, _ = await exit_stack.enter_async_context(
|
||||
streamablehttp_client(**self._server_params.model_dump())
|
||||
)
|
||||
|
||||
session = await exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
|
||||
await session.initialize()
|
||||
|
||||
self._exit_stack = exit_stack
|
||||
self._active_session = session
|
||||
|
||||
except Exception:
|
||||
await exit_stack.aclose()
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the persistent MCP connection.
|
||||
|
||||
Safe to call multiple times or without having called start().
|
||||
"""
|
||||
self._active_session = None
|
||||
if self._exit_stack:
|
||||
await self._exit_stack.aclose()
|
||||
self._exit_stack = None
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def register_tools(self, llm: LLMService | LLMSwitcher) -> ToolsSchema:
|
||||
"""Register all available MCP tools with an LLM service.
|
||||
|
||||
Connects to the MCP server, discovers available tools, converts their
|
||||
Discovers available tools from the active session, converts their
|
||||
schemas to Pipecat format, and registers them with the LLM service.
|
||||
|
||||
This is the equivalent of calling get_tools_schema() followed by
|
||||
@@ -101,18 +161,26 @@ class MCPClient(BaseObject):
|
||||
await self.register_tools_schema(tools_schema, llm)
|
||||
return tools_schema
|
||||
|
||||
def _ensure_connected(self) -> ClientSession:
|
||||
"""Return the active session or raise if not connected."""
|
||||
if not self._active_session:
|
||||
raise RuntimeError(
|
||||
"MCPClient is not connected. Use 'async with MCPClient(...) as mcp:' "
|
||||
"or call 'await mcp.start()' before using MCPClient."
|
||||
)
|
||||
return self._active_session
|
||||
|
||||
async def get_tools_schema(self) -> ToolsSchema:
|
||||
"""Get the schema of all available MCP tools without registering them.
|
||||
|
||||
Connects to the MCP server, discovers available tools, and converts their
|
||||
schemas to Pipecat format.
|
||||
Requires the client to be started via start() or async with.
|
||||
|
||||
Returns:
|
||||
A ToolsSchema containing all available tools. This can be used for
|
||||
subsequent registration using register_tools_schema().
|
||||
"""
|
||||
tools_schema = await self._list_tools()
|
||||
return tools_schema
|
||||
session = self._ensure_connected()
|
||||
return await self._list_tools_helper(session)
|
||||
|
||||
async def register_tools_schema(
|
||||
self, tools_schema: ToolsSchema, llm: LLMService | LLMSwitcher
|
||||
@@ -154,107 +222,21 @@ class MCPClient(BaseObject):
|
||||
|
||||
return schema
|
||||
|
||||
async def _sse_list_tools(self) -> ToolsSchema:
|
||||
"""List all available mcp tools with the LLM service.
|
||||
|
||||
Returns:
|
||||
A ToolsSchema containing all registered tools
|
||||
"""
|
||||
logger.debug(f"SSE server parameters: {self._server_params}")
|
||||
logger.debug(f"Starting reading mcp tools")
|
||||
|
||||
async with self._client(**self._server_params.model_dump()) as (read, write):
|
||||
async with self._session(read, write) as session:
|
||||
await session.initialize()
|
||||
tools_schema = await self._list_tools_helper(session)
|
||||
return tools_schema
|
||||
|
||||
async def _sse_tool_wrapper(self, params: FunctionCallParams) -> None:
|
||||
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
|
||||
async def _tool_wrapper(self, params: FunctionCallParams) -> None:
|
||||
"""Execute an MCP tool call using the persistent session."""
|
||||
session = self._ensure_connected()
|
||||
logger.debug(f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}")
|
||||
logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
|
||||
try:
|
||||
async with self._client(**self._server_params.model_dump()) as (read, write):
|
||||
async with self._session(read, write) as session:
|
||||
await session.initialize()
|
||||
await self._call_tool(
|
||||
session, params.function_name, params.arguments, params.result_callback
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _stdio_list_tools(self) -> ToolsSchema:
|
||||
"""List all available mcp tools with the LLM service.
|
||||
|
||||
Returns:
|
||||
A ToolsSchema containing all available tools.
|
||||
"""
|
||||
logger.debug(f"Starting reading mcp tools")
|
||||
|
||||
async with self._client(self._server_params) as streams:
|
||||
async with self._session(streams[0], streams[1]) as session:
|
||||
await session.initialize()
|
||||
tools_schema = await self._list_tools_helper(session)
|
||||
return tools_schema
|
||||
|
||||
async def _stdio_tool_wrapper(self, params: FunctionCallParams) -> None:
|
||||
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
|
||||
logger.debug(f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}")
|
||||
logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
|
||||
try:
|
||||
async with self._client(self._server_params) as streams:
|
||||
async with self._session(streams[0], streams[1]) as session:
|
||||
await session.initialize()
|
||||
await self._call_tool(
|
||||
session, params.function_name, params.arguments, params.result_callback
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
await params.result_callback(error_msg)
|
||||
|
||||
async def _streamable_http_list_tools(self) -> ToolsSchema:
|
||||
"""List all available mcp tools with the LLM service using streamable HTTP.
|
||||
|
||||
Returns:
|
||||
A ToolsSchema containing all available tools.
|
||||
"""
|
||||
logger.debug(f"Starting reading mcp tools using streamable HTTP")
|
||||
|
||||
async with self._client(**self._server_params.model_dump()) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with self._session(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
tools_schema = await self._list_tools_helper(session)
|
||||
return tools_schema
|
||||
|
||||
async def _streamable_http_tool_wrapper(self, params: FunctionCallParams) -> None:
|
||||
"""Wrapper for mcp tool calls to match Pipecat's function call interface."""
|
||||
logger.debug(f"Executing tool '{params.function_name}' with call ID: {params.tool_call_id}")
|
||||
logger.trace(f"Tool arguments: {json.dumps(params.arguments, indent=2)}")
|
||||
try:
|
||||
async with self._client(**self._server_params.model_dump()) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
_,
|
||||
):
|
||||
async with self._session(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
await self._call_tool(
|
||||
session, params.function_name, params.arguments, params.result_callback
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling mcp tool {params.function_name}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
await params.result_callback(error_msg)
|
||||
await self._call_tool(
|
||||
session,
|
||||
params.function_name,
|
||||
params.arguments,
|
||||
params.result_callback,
|
||||
)
|
||||
|
||||
async def _call_tool(self, session, function_name, arguments, result_callback):
|
||||
logger.debug(f"Calling mcp tool '{function_name}'")
|
||||
results = None
|
||||
try:
|
||||
results = await session.call_tool(function_name, arguments=arguments)
|
||||
except Exception as e:
|
||||
|
||||
@@ -157,12 +157,6 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
pitch: Pitch adjustment (range: -12 to 12).
|
||||
emotion: Emotional tone (options: "happy", "sad", "angry", "fearful",
|
||||
"disgusted", "surprised", "calm", "fluent").
|
||||
english_normalization: Deprecated; use `text_normalization` instead
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
The `english_normalization` parameter is deprecated and will be removed in a future version.
|
||||
Use the `text_normalization` parameter instead.
|
||||
|
||||
text_normalization: Enable text normalization (Chinese/English).
|
||||
latex_read: Enable LaTeX formula reading.
|
||||
exclude_aggregated_audio: Whether to exclude aggregated audio in final chunk.
|
||||
@@ -173,7 +167,6 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
volume: Optional[float] = 1.0
|
||||
pitch: Optional[int] = 0
|
||||
emotion: Optional[str] = None
|
||||
english_normalization: Optional[bool] = None # Deprecated
|
||||
text_normalization: Optional[bool] = None
|
||||
latex_read: Optional[bool] = None
|
||||
exclude_aggregated_audio: Optional[bool] = None
|
||||
@@ -284,16 +277,6 @@ class MiniMaxHttpTTSService(TTSService):
|
||||
)
|
||||
|
||||
# Resolve text_normalization
|
||||
if params.english_normalization is not None:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter `english_normalization` is deprecated and will be removed in a future version. Use `text_normalization` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
default_settings.text_normalization = params.english_normalization
|
||||
if params.text_normalization is not None:
|
||||
default_settings.text_normalization = params.text_normalization
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ from pipecat.services.tts_service import (
|
||||
WebsocketTTSService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.skip_tags_aggregator import SkipTagsAggregator
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -176,7 +175,6 @@ class RimeTTSService(WebsocketTTSService):
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
settings: Optional[Settings] = None,
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
text_aggregation_mode: Optional[TextAggregationMode] = None,
|
||||
aggregate_sentences: Optional[bool] = None,
|
||||
**kwargs,
|
||||
@@ -204,11 +202,6 @@ class RimeTTSService(WebsocketTTSService):
|
||||
|
||||
settings: Runtime-updatable settings. When provided alongside deprecated
|
||||
parameters, ``settings`` values take precedence.
|
||||
text_aggregator: Custom text aggregator for processing input text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
text_aggregation_mode: How to aggregate incoming text before synthesis.
|
||||
aggregate_sentences: Deprecated. Use text_aggregation_mode instead.
|
||||
|
||||
@@ -282,15 +275,14 @@ class RimeTTSService(WebsocketTTSService):
|
||||
self._audio_format = "pcm"
|
||||
self._sampling_rate = 0 # updated in start()
|
||||
|
||||
if not text_aggregator:
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Rime spelling is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator(
|
||||
[("spell(", ")")], aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
# Always skip tags added for spelled-out text
|
||||
# Note: This is primarily to support backwards compatibility.
|
||||
# The preferred way of taking advantage of Rime spelling is
|
||||
# to use an LLMTextProcessor and/or a text_transformer to identify
|
||||
# and insert these tags for the purpose of the TTS service alone.
|
||||
self._text_aggregator = SkipTagsAggregator(
|
||||
[("spell(", ")")], aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
|
||||
# Store service configuration
|
||||
self._api_key = api_key
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
"""Simli video service for real-time avatar generation."""
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@@ -79,10 +78,8 @@ class SimliVideoService(AIService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
face_id: Optional[str] = None,
|
||||
simli_config: Optional[SimliConfig] = None,
|
||||
use_turn_server: bool = False,
|
||||
api_key: str,
|
||||
face_id: str,
|
||||
simli_url: str = "https://api.simli.ai",
|
||||
is_trinity_avatar: bool = False,
|
||||
params: Optional[InputParams] = None,
|
||||
@@ -98,18 +95,6 @@ class SimliVideoService(AIService):
|
||||
api_key: Simli API key for authentication.
|
||||
face_id: Simli Face ID. For Trinity avatars, specify "faceId/emotionId"
|
||||
to use a different emotion than the default.
|
||||
simli_config: Configuration object for Simli client settings.
|
||||
Use api_key and face_id instead.
|
||||
|
||||
.. deprecated:: 0.0.92
|
||||
The 'simli_config' parameter is deprecated and will be removed in a future version.
|
||||
Please use 'api_key' and 'face_id' parameters instead.
|
||||
|
||||
use_turn_server: Whether to use TURN server for connection. Defaults to False.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
The 'use_turn_server' parameter is deprecated and will be removed in a future version.
|
||||
|
||||
simli_url: URL of the simli servers. Can be changed for custom deployments
|
||||
of enterprise users.
|
||||
is_trinity_avatar: Boolean to tell simli client that this is a Trinity avatar
|
||||
@@ -147,49 +132,16 @@ class SimliVideoService(AIService):
|
||||
# 4. Call super
|
||||
super().__init__(settings=default_settings, **kwargs)
|
||||
|
||||
# Handle deprecated simli_config parameter
|
||||
if simli_config is not None:
|
||||
if api_key is not None or face_id is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both simli_config and api_key/face_id. "
|
||||
"Please use api_key and face_id (simli_config is deprecated)."
|
||||
)
|
||||
# Build SimliConfig from parameters
|
||||
config_kwargs = {
|
||||
"faceId": face_id,
|
||||
}
|
||||
if max_session_length is not None:
|
||||
config_kwargs["maxSessionLength"] = max_session_length
|
||||
if max_idle_time is not None:
|
||||
config_kwargs["maxIdleTime"] = max_idle_time
|
||||
|
||||
warnings.warn(
|
||||
"The 'simli_config' parameter is deprecated and will be removed in a future version. "
|
||||
"Please use 'api_key' and 'face_id' parameters instead, with optional 'params' for "
|
||||
"max_session_length and max_idle_time configuration.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Use the provided simli_config
|
||||
config = simli_config
|
||||
else:
|
||||
# Validate new parameters
|
||||
if api_key is None:
|
||||
raise ValueError("api_key is required")
|
||||
if face_id is None:
|
||||
raise ValueError("face_id is required")
|
||||
|
||||
# Build SimliConfig from new parameters
|
||||
# Only pass optional parameters if explicitly provided to use SimliConfig defaults
|
||||
config_kwargs = {
|
||||
"faceId": face_id,
|
||||
}
|
||||
if max_session_length is not None:
|
||||
config_kwargs["maxSessionLength"] = max_session_length
|
||||
if max_idle_time is not None:
|
||||
config_kwargs["maxIdleTime"] = max_idle_time
|
||||
|
||||
config = SimliConfig(**config_kwargs)
|
||||
|
||||
if use_turn_server:
|
||||
warnings.warn(
|
||||
"The 'use_turn_server' parameter is deprecated and will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
config = SimliConfig(**config_kwargs)
|
||||
|
||||
self._initialized = False
|
||||
# Add buffer time to session limits
|
||||
|
||||
@@ -58,7 +58,6 @@ from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import TTSSettings, is_given
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.text.base_text_aggregator import BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
from pipecat.utils.time import seconds_to_nanoseconds
|
||||
@@ -168,8 +167,6 @@ class TTSService(AIService):
|
||||
append_trailing_space: bool = False,
|
||||
# TTS output sample rate
|
||||
sample_rate: Optional[int] = None,
|
||||
# Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
|
||||
text_aggregator: Optional[BaseTextAggregator] = None,
|
||||
# Types of text aggregations that should not be spoken.
|
||||
skip_aggregator_types: Optional[List[str]] = [],
|
||||
# A list of callables to transform text before just before sending it to TTS.
|
||||
@@ -182,7 +179,6 @@ class TTSService(AIService):
|
||||
] = None,
|
||||
# Text filter executed after text has been aggregated.
|
||||
text_filters: Optional[Sequence[BaseTextFilter]] = None,
|
||||
text_filter: Optional[BaseTextFilter] = None,
|
||||
# Audio transport destination of the generated frames.
|
||||
transport_destination: Optional[str] = None,
|
||||
settings: Optional[TTSSettings] = None,
|
||||
@@ -215,11 +211,6 @@ class TTSService(AIService):
|
||||
append_trailing_space: Whether to append a trailing space to text before sending to TTS.
|
||||
This helps prevent some TTS services from vocalizing trailing punctuation (e.g., "dot").
|
||||
sample_rate: Output sample rate for generated audio.
|
||||
text_aggregator: Custom text aggregator for processing incoming text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
Use an LLMTextProcessor before the TTSService for custom text aggregation.
|
||||
|
||||
skip_aggregator_types: List of aggregation types that should not be spoken.
|
||||
text_transforms: A list of callables to transform text before just before sending it
|
||||
to TTS. Each callable takes the aggregated text and its type, and returns the
|
||||
@@ -227,11 +218,6 @@ class TTSService(AIService):
|
||||
(aggregation_type | '*', transform_function).
|
||||
|
||||
text_filters: Sequence of text filters to apply after aggregation.
|
||||
text_filter: Single text filter (deprecated, use text_filters).
|
||||
|
||||
.. deprecated:: 0.0.59
|
||||
Use `text_filters` instead, which allows multiple filters.
|
||||
|
||||
transport_destination: Destination for generated audio frames.
|
||||
settings: The runtime-updatable settings for the TTS service.
|
||||
reuse_context_id_within_turn: Whether the service should reuse context IDs within the
|
||||
@@ -300,18 +286,7 @@ class TTSService(AIService):
|
||||
self._append_trailing_space: bool = append_trailing_space
|
||||
self._init_sample_rate = sample_rate
|
||||
self._sample_rate = 0
|
||||
self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator(
|
||||
aggregation_type=self._text_aggregation_mode
|
||||
)
|
||||
if text_aggregator:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'text_aggregator' is deprecated. Use an LLMTextProcessor before the TTSService for custom text aggregation.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self._text_aggregator = SimpleTextAggregator(aggregation_type=self._text_aggregation_mode)
|
||||
|
||||
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
|
||||
self._text_transforms: List[
|
||||
@@ -320,16 +295,6 @@ class TTSService(AIService):
|
||||
# TODO: Deprecate _text_filters when added to LLMTextProcessor
|
||||
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
|
||||
self._transport_destination: Optional[str] = transport_destination
|
||||
if text_filter:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'text_filter' is deprecated, use 'text_filters' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self._text_filters = [text_filter]
|
||||
|
||||
self._resampler = create_stream_resampler()
|
||||
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Base notifier interface for Pipecat."""
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Package pipecat.sync is deprecated, use pipecat.utils.sync instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Event-based notifier implementation using asyncio Event primitives."""
|
||||
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Package pipecat.sync is deprecated, use pipecat.utils.sync instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -7,44 +7,24 @@
|
||||
"""Base input transport implementation for Pipecat.
|
||||
|
||||
This module provides the BaseInputTransport class which handles audio and video
|
||||
input processing, including VAD, turn analysis, and interruption management.
|
||||
input processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.base_turn_analyzer import (
|
||||
BaseTurnAnalyzer,
|
||||
EndOfTurnState,
|
||||
)
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EmulateUserStartedSpeakingFrame,
|
||||
EmulateUserStoppedSpeakingFrame,
|
||||
EndFrame,
|
||||
FilterUpdateSettingsFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputImageRawFrame,
|
||||
MetricsFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
StopFrame,
|
||||
SystemFrame,
|
||||
UserSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADParamsUpdateFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
@@ -93,29 +73,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
# them downstream until we get another `StartFrame`.
|
||||
self._paused = False
|
||||
|
||||
if self._params.turn_analyzer:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'turn_analyzer' is deprecated, use `LLMUserAggregator`'s new "
|
||||
"`user_turn_strategies` parameter instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if self._params.vad_analyzer:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Parameter 'vad_analyzer' is deprecated. Use `LLMUserAggregator`'s "
|
||||
"`vad_analyzer` parameter, or `VADProcessor` if no `LLMUserAggregator` "
|
||||
"is needed.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
def enable_audio_in_stream_on_start(self, enabled: bool) -> None:
|
||||
"""Enable or disable audio streaming on transport start.
|
||||
|
||||
@@ -141,52 +98,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
"""
|
||||
return self._sample_rate
|
||||
|
||||
@property
|
||||
def vad_analyzer(self) -> Optional[VADAnalyzer]:
|
||||
"""Get the Voice Activity Detection analyzer.
|
||||
|
||||
.. deprecated:: 0.0.101
|
||||
This method is deprecated and will be removed in a future version.
|
||||
Use `LLMUserAggregator`'s new `vad_analyzer` parameter instead.
|
||||
|
||||
Returns:
|
||||
The VAD analyzer instance if configured, None otherwise.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Method 'vad_analyzer' is deprecated. Use `LLMUserAggregator`'s new "
|
||||
"`vad_analyzer` parameter instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
return self._params.vad_analyzer
|
||||
|
||||
@property
|
||||
def turn_analyzer(self) -> Optional[BaseTurnAnalyzer]:
|
||||
"""Get the turn-taking analyzer.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
This method is deprecated and will be removed in a future version.
|
||||
Use `LLMUserAggregator`'s new `user_turn_strategies` parameter instead.
|
||||
|
||||
Returns:
|
||||
The turn analyzer instance if configured, None otherwise.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Method 'turn_analyzer' is deprecated. Use `LLMUserAggregator`'s new "
|
||||
"`user_turn_strategies` parameter instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
return self._params.turn_analyzer
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the input transport and initialize components.
|
||||
|
||||
@@ -202,26 +113,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
if self._params.audio_in_filter:
|
||||
await self._params.audio_in_filter.start(self._sample_rate)
|
||||
|
||||
###################################################################
|
||||
# DEPRECATED.
|
||||
|
||||
# Configure VAD analyzer.
|
||||
if self._params.vad_analyzer:
|
||||
self._params.vad_analyzer.set_sample_rate(self._sample_rate)
|
||||
|
||||
# Configure End of turn analyzer.
|
||||
if self._params.turn_analyzer:
|
||||
self._params.turn_analyzer.set_sample_rate(self._sample_rate)
|
||||
|
||||
if self._params.vad_analyzer or self._params.turn_analyzer:
|
||||
vad_params = self._params.vad_analyzer.params if self._params.vad_analyzer else None
|
||||
turn_params = self._params.turn_analyzer.params if self._params.turn_analyzer else None
|
||||
|
||||
await self.broadcast_frame(
|
||||
SpeechControlParamsFrame, vad_params=vad_params, turn_params=turn_params
|
||||
)
|
||||
###################################################################
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the input transport and cleanup resources.
|
||||
|
||||
@@ -307,18 +198,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
await self._deprecated_handle_bot_started_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._deprecated_handle_bot_stopped_speaking(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, EmulateUserStartedSpeakingFrame):
|
||||
logger.debug("Emulating user started speaking")
|
||||
await self._deprecated_handle_user_interruption(VADState.SPEAKING, emulated=True)
|
||||
elif isinstance(frame, EmulateUserStoppedSpeakingFrame):
|
||||
logger.debug("Emulating user stopped speaking")
|
||||
await self._deprecated_handle_user_interruption(VADState.QUIET, emulated=True)
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -331,19 +210,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
elif isinstance(frame, StopFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self.pause(frame)
|
||||
###################################################################
|
||||
# DEPRECATED.
|
||||
elif isinstance(frame, VADParamsUpdateFrame):
|
||||
if self._params.vad_analyzer:
|
||||
self._params.vad_analyzer.set_params(frame.params)
|
||||
await self.broadcast_frame(
|
||||
SpeechControlParamsFrame,
|
||||
vad_params=frame.params,
|
||||
turn_params=self._params.turn_analyzer.params
|
||||
if self._params.turn_analyzer
|
||||
else None,
|
||||
)
|
||||
###################################################################
|
||||
elif isinstance(frame, FilterUpdateSettingsFrame) and self._params.audio_in_filter:
|
||||
await self._params.audio_in_filter.process_frame(frame)
|
||||
# Other frames
|
||||
@@ -367,8 +233,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._audio_task = None
|
||||
|
||||
async def _audio_task_handler(self):
|
||||
"""Main audio processing task handler for VAD and turn analysis."""
|
||||
vad_state: VADState = VADState.QUIET
|
||||
"""Main audio processing task handler."""
|
||||
# Skip timeout handling until the first audio frame arrives (e.g. client
|
||||
# not yet connected).
|
||||
audio_received = False
|
||||
@@ -381,7 +246,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
# From now on, timeout should warn if there's no audio.
|
||||
audio_received = True
|
||||
|
||||
# If an audio filter is available, run it before VAD.
|
||||
# Filter audio, if an audio filter is available.
|
||||
if self._params.audio_in_filter:
|
||||
frame.audio = await self._params.audio_in_filter.filter(frame.audio)
|
||||
|
||||
@@ -390,22 +255,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._audio_in_queue.task_done()
|
||||
continue
|
||||
|
||||
###################################################################
|
||||
# DEPRECATED.
|
||||
#
|
||||
# Check VAD and push event if necessary. We just care about
|
||||
# changes from QUIET to SPEAKING and vice versa.
|
||||
previous_vad_state = vad_state
|
||||
if self._params.vad_analyzer:
|
||||
vad_state = await self._deprecated_handle_vad(frame, vad_state)
|
||||
|
||||
if self._params.turn_analyzer:
|
||||
await self._deprecated_run_turn_analyzer(frame, vad_state, previous_vad_state)
|
||||
|
||||
if self._params.vad_analyzer and vad_state == VADState.SPEAKING:
|
||||
await self._deprecated_user_currently_speaking()
|
||||
###################################################################
|
||||
|
||||
# Push audio downstream if passthrough is set.
|
||||
if self._params.audio_in_passthrough:
|
||||
await self.push_frame(frame)
|
||||
@@ -414,190 +263,3 @@ class BaseInputTransport(FrameProcessor):
|
||||
except asyncio.TimeoutError:
|
||||
if not audio_received:
|
||||
continue
|
||||
|
||||
###################################################################
|
||||
# DEPRECATED.
|
||||
if self._user_speaking:
|
||||
logger.warning(
|
||||
"Forcing VAD user stopped speaking due to timeout receiving audio frame!"
|
||||
)
|
||||
vad_state = VADState.QUIET
|
||||
if self._params.turn_analyzer:
|
||||
self._params.turn_analyzer.clear()
|
||||
|
||||
if self._params.turn_analyzer:
|
||||
await self._deprecated_handle_user_interruption(VADState.QUIET)
|
||||
else:
|
||||
stop_secs = (
|
||||
self._params.vad_analyzer.params.stop_secs
|
||||
if self._params.vad_analyzer
|
||||
else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame(stop_secs=stop_secs))
|
||||
###################################################################
|
||||
|
||||
#
|
||||
# DEPRECATED.
|
||||
#
|
||||
# The functions below are deprecated and should be removed once the old
|
||||
# interruption strategies and turn analyzer are removed.
|
||||
#
|
||||
|
||||
async def _deprecated_vad_analyze(self, audio_frame: InputAudioRawFrame) -> VADState:
|
||||
"""Analyze audio frame for voice activity."""
|
||||
state = VADState.QUIET
|
||||
if self._params.vad_analyzer:
|
||||
state = await self._params.vad_analyzer.analyze_audio(audio_frame.audio)
|
||||
return state
|
||||
|
||||
async def _deprecated_new_handle_vad(
|
||||
self, audio_frame: InputAudioRawFrame, vad_state: VADState
|
||||
) -> VADState:
|
||||
"""Handle Voice Activity Detection results and generate appropriate frames."""
|
||||
new_vad_state = await self._deprecated_vad_analyze(audio_frame)
|
||||
if (
|
||||
new_vad_state != vad_state
|
||||
and new_vad_state != VADState.STARTING
|
||||
and new_vad_state != VADState.STOPPING
|
||||
):
|
||||
if new_vad_state == VADState.SPEAKING:
|
||||
start_secs = (
|
||||
self._params.vad_analyzer.params.start_secs
|
||||
if self._params.vad_analyzer
|
||||
else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStartedSpeakingFrame(start_secs=start_secs))
|
||||
elif new_vad_state == VADState.QUIET:
|
||||
stop_secs = (
|
||||
self._params.vad_analyzer.params.stop_secs if self._params.vad_analyzer else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame(stop_secs=stop_secs))
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
|
||||
async def _deprecated_handle_vad(
|
||||
self, audio_frame: InputAudioRawFrame, vad_state: VADState
|
||||
) -> VADState:
|
||||
"""Handle Voice Activity Detection results and generate appropriate frames."""
|
||||
if self._params.turn_analyzer:
|
||||
return await self._deprecated_old_handle_vad(audio_frame, vad_state)
|
||||
else:
|
||||
return await self._deprecated_new_handle_vad(audio_frame, vad_state)
|
||||
|
||||
async def _deprecated_user_currently_speaking(self):
|
||||
"""Handle user speaking frame."""
|
||||
diff_time = time.time() - self._user_speaking_frame_time
|
||||
if diff_time >= self._user_speaking_frame_period:
|
||||
await self.broadcast_frame(UserSpeakingFrame)
|
||||
self._user_speaking_frame_time = time.time()
|
||||
|
||||
async def _deprecated_handle_bot_started_speaking(self, frame: BotStartedSpeakingFrame):
|
||||
"""Update bot speaking state when bot starts speaking."""
|
||||
self._bot_speaking = True
|
||||
|
||||
async def _deprecated_handle_bot_stopped_speaking(self, frame: BotStoppedSpeakingFrame):
|
||||
"""Update bot speaking state when bot stops speaking."""
|
||||
self._bot_speaking = False
|
||||
|
||||
async def _deprecated_handle_user_interruption(
|
||||
self, vad_state: VADState, emulated: bool = False
|
||||
):
|
||||
"""Handle user interruption events based on speaking state."""
|
||||
if vad_state == VADState.SPEAKING:
|
||||
logger.debug("User started speaking")
|
||||
self._user_speaking = True
|
||||
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame, emulated=emulated)
|
||||
|
||||
# Only push InterruptionFrame if:
|
||||
# 1. No interruption config is set, OR
|
||||
# 2. Interruption config is set but bot is not speaking
|
||||
should_push_immediate_interruption = (
|
||||
not self.interruption_strategies or not self._bot_speaking
|
||||
)
|
||||
|
||||
# Make sure we notify about interruptions quickly out-of-band.
|
||||
if should_push_immediate_interruption and self._allow_interruptions:
|
||||
await self.broadcast_interruption()
|
||||
elif self.interruption_strategies and self._bot_speaking:
|
||||
logger.debug(
|
||||
"User started speaking while bot is speaking with interruption config - "
|
||||
"deferring interruption to aggregator"
|
||||
)
|
||||
elif vad_state == VADState.QUIET:
|
||||
logger.debug("User stopped speaking")
|
||||
self._user_speaking = False
|
||||
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame, emulated=emulated)
|
||||
|
||||
async def _deprecated_old_handle_vad(
|
||||
self, audio_frame: InputAudioRawFrame, vad_state: VADState
|
||||
) -> VADState:
|
||||
"""Handle Voice Activity Detection results and generate appropriate frames."""
|
||||
new_vad_state = await self._deprecated_vad_analyze(audio_frame)
|
||||
if (
|
||||
new_vad_state != vad_state
|
||||
and new_vad_state != VADState.STARTING
|
||||
and new_vad_state != VADState.STOPPING
|
||||
):
|
||||
interruption_state = None
|
||||
|
||||
# If the turn analyser is enabled, this will prevent:
|
||||
# - Creating the UserStoppedSpeakingFrame
|
||||
# - Creating the UserStartedSpeakingFrame multiple times
|
||||
can_create_user_frames = (
|
||||
self._params.turn_analyzer is None
|
||||
or not self._params.turn_analyzer.speech_triggered
|
||||
)
|
||||
if new_vad_state == VADState.SPEAKING:
|
||||
start_secs = (
|
||||
self._params.vad_analyzer.params.start_secs
|
||||
if self._params.vad_analyzer
|
||||
else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStartedSpeakingFrame(start_secs=start_secs))
|
||||
if can_create_user_frames:
|
||||
interruption_state = VADState.SPEAKING
|
||||
elif new_vad_state == VADState.QUIET:
|
||||
stop_secs = (
|
||||
self._params.vad_analyzer.params.stop_secs if self._params.vad_analyzer else 0.0
|
||||
)
|
||||
await self.push_frame(VADUserStoppedSpeakingFrame(stop_secs=stop_secs))
|
||||
if can_create_user_frames:
|
||||
interruption_state = VADState.QUIET
|
||||
|
||||
if interruption_state:
|
||||
await self._deprecated_handle_user_interruption(interruption_state)
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
|
||||
async def _deprecated_handle_end_of_turn(self):
|
||||
"""Handle end-of-turn analysis and generate prediction results."""
|
||||
if self._params.turn_analyzer:
|
||||
state, prediction = await self._params.turn_analyzer.analyze_end_of_turn()
|
||||
await self._deprecated_handle_prediction_result(prediction)
|
||||
await self._deprecated_handle_end_of_turn_complete(state)
|
||||
|
||||
async def _deprecated_handle_end_of_turn_complete(self, state: EndOfTurnState):
|
||||
"""Handle completion of end-of-turn analysis."""
|
||||
if state == EndOfTurnState.COMPLETE:
|
||||
await self._deprecated_handle_user_interruption(VADState.QUIET)
|
||||
|
||||
async def _deprecated_handle_prediction_result(self, result: MetricsData):
|
||||
"""Handle a prediction result event from the turn analyzer."""
|
||||
await self.push_frame(MetricsFrame(data=[result]))
|
||||
|
||||
async def _deprecated_run_turn_analyzer(
|
||||
self, frame: InputAudioRawFrame, vad_state: VADState, previous_vad_state: VADState
|
||||
):
|
||||
"""Run turn analysis on audio frame and handle results."""
|
||||
is_speech = vad_state == VADState.SPEAKING or vad_state == VADState.STARTING
|
||||
# If silence exceeds threshold, we are going to receive EndOfTurnState.COMPLETE
|
||||
end_of_turn_state = self._params.turn_analyzer.append_audio(frame.audio, is_speech)
|
||||
if end_of_turn_state == EndOfTurnState.COMPLETE:
|
||||
await self._deprecated_handle_end_of_turn_complete(end_of_turn_state)
|
||||
# Otherwise we are going to trigger to check if the turn is completed based on the VAD
|
||||
elif vad_state == VADState.QUIET and vad_state != previous_vad_state:
|
||||
await self._deprecated_handle_end_of_turn()
|
||||
|
||||
@@ -517,9 +517,6 @@ class BaseOutputTransport(FrameProcessor):
|
||||
Args:
|
||||
_: The start interruption frame (unused).
|
||||
"""
|
||||
if not self._transport._allow_interruptions:
|
||||
return
|
||||
|
||||
# Cancel tasks.
|
||||
await self._cancel_audio_task()
|
||||
await self._cancel_clock_task()
|
||||
|
||||
@@ -54,13 +54,6 @@ class TransportParams(BaseModel):
|
||||
video_out_color_format: Video output color format string.
|
||||
video_out_codec: Preferred video codec for output (e.g., 'VP8', 'H264', 'H265').
|
||||
video_out_destinations: List of video output destination identifiers.
|
||||
vad_analyzer: Voice Activity Detection analyzer instance.
|
||||
|
||||
.. deprecated:: 0.0.101
|
||||
The `vad_analyzer` parameter is deprecated. Use `LLMUserAggregator`'s
|
||||
`vad_analyzer` parameter, or `VADProcessor` if no `LLMUserAggregator`
|
||||
is needed.
|
||||
|
||||
turn_analyzer: Turn-taking analyzer instance for conversation management.
|
||||
|
||||
.. deprecated:: 0.0.99
|
||||
@@ -95,8 +88,6 @@ class TransportParams(BaseModel):
|
||||
video_out_color_format: str = "RGB"
|
||||
video_out_codec: Optional[str] = None
|
||||
video_out_destinations: List[str] = Field(default_factory=list)
|
||||
vad_analyzer: Optional[VADAnalyzer] = None
|
||||
turn_analyzer: Optional[BaseTurnAnalyzer] = None
|
||||
|
||||
|
||||
class BaseTransport(BaseObject):
|
||||
|
||||
@@ -1711,17 +1711,6 @@ class DailyInputTransport(BaseInputTransport):
|
||||
# Audio task when using a virtual speaker (i.e. no user tracks).
|
||||
self._audio_in_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._vad_analyzer: Optional[VADAnalyzer] = params.vad_analyzer
|
||||
|
||||
@property
|
||||
def vad_analyzer(self) -> Optional[VADAnalyzer]:
|
||||
"""Get the Voice Activity Detection analyzer.
|
||||
|
||||
Returns:
|
||||
The VAD analyzer instance if configured.
|
||||
"""
|
||||
return self._vad_analyzer
|
||||
|
||||
async def start_audio_in_streaming(self):
|
||||
"""Start receiving audio from participants."""
|
||||
if not self._params.audio_in_enabled:
|
||||
|
||||
@@ -652,21 +652,11 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
|
||||
self._audio_in_task = None
|
||||
self._video_in_task = None
|
||||
self._vad_analyzer: Optional[VADAnalyzer] = params.vad_analyzer
|
||||
self._resampler = create_stream_resampler()
|
||||
|
||||
# Whether we have seen a StartFrame already.
|
||||
self._initialized = False
|
||||
|
||||
@property
|
||||
def vad_analyzer(self) -> Optional[VADAnalyzer]:
|
||||
"""Get the Voice Activity Detection analyzer.
|
||||
|
||||
Returns:
|
||||
The VAD analyzer instance if configured.
|
||||
"""
|
||||
return self._vad_analyzer
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the input transport and connect to LiveKit room.
|
||||
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""FastAPI WebSocket transport implementation for Pipecat.
|
||||
|
||||
This module provides WebSocket-based transport for real-time audio/video streaming
|
||||
using FastAPI and WebSocket connections. Supports binary and text serialization
|
||||
with configurable session timeouts and WAV header generation.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.websocket.fastapi import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.network.fastapi_websocket` is deprecated, "
|
||||
"use `pipecat.transports.websocket.fastapi` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Small WebRTC transport implementation for Pipecat.
|
||||
|
||||
This module provides a WebRTC transport implementation using aiortc for
|
||||
real-time audio and video communication. It supports bidirectional media
|
||||
streaming, application messaging, and client connection management.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.smallwebrtc.transport import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.network.small_webrtc` is deprecated, "
|
||||
"use `pipecat.transports.smallwebrtc.transport` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Small WebRTC connection implementation for Pipecat.
|
||||
|
||||
This module provides a WebRTC connection implementation using aiortc,
|
||||
with support for audio/video tracks, data channels, and signaling
|
||||
for real-time communication applications.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.smallwebrtc.connection import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.network.webrtc_connection` is deprecated, "
|
||||
"use `pipecat.transports.smallwebrtc.connection` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""WebSocket client transport implementation for Pipecat.
|
||||
|
||||
This module provides a WebSocket client transport that enables bidirectional
|
||||
communication over WebSocket connections, with support for audio streaming,
|
||||
frame serialization, and connection management.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.websocket.client import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.network.websocket_client` is deprecated, "
|
||||
"use `pipecat.transports.websocket.client` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""WebSocket server transport implementation for Pipecat.
|
||||
|
||||
This module provides WebSocket server transport functionality for real-time
|
||||
audio and data streaming, including client connection management, session
|
||||
handling, and frame serialization.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.websocket.server import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.network.websocket_server` is deprecated, "
|
||||
"use `pipecat.transports.websocket.server` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Daily transport implementation for Pipecat.
|
||||
|
||||
This module provides comprehensive Daily video conferencing integration including
|
||||
audio/video streaming, transcription, recording, dial-in/out functionality, and
|
||||
real-time communication features.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.daily.transport import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.services.daily` is deprecated, "
|
||||
"use `pipecat.transports.daily.transport` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,23 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Daily REST Helpers.
|
||||
|
||||
Methods that wrap the Daily API to create rooms, check room URLs, and get meeting tokens.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.daily.utils import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.services.helpers.daily_rest` is deprecated, "
|
||||
"use `pipecat.transports.daily.utils` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""LiveKit transport implementation for Pipecat.
|
||||
|
||||
This module provides comprehensive LiveKit real-time communication integration
|
||||
including audio streaming, data messaging, participant management, and room
|
||||
event handling for conversational AI applications.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.livekit.transport import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.services.livekit` is deprecated, "
|
||||
"use `pipecat.transports.livekit.transport` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tavus transport implementation for Pipecat.
|
||||
|
||||
This module provides integration with the Tavus platform for creating conversational
|
||||
AI applications with avatars. It manages conversation sessions and provides real-time
|
||||
audio/video streaming capabilities through the Tavus API.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.transports.tavus.transport import *
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Module `pipecat.transports.services.tavus` is deprecated, "
|
||||
"use `pipecat.transports.tavus.transport` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
@@ -181,7 +181,7 @@ class UserTurnProcessor(FrameProcessor):
|
||||
|
||||
await self._user_idle_controller.process_frame(UserStartedSpeakingFrame())
|
||||
|
||||
if params.enable_interruptions and self._allow_interruptions:
|
||||
if params.enable_interruptions:
|
||||
await self.broadcast_interruption()
|
||||
|
||||
await self._call_event_handler("on_user_turn_started", strategy)
|
||||
|
||||
@@ -161,46 +161,6 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
}
|
||||
return self
|
||||
|
||||
def add_pattern_pair(
|
||||
self, pattern_id: str, start_pattern: str, end_pattern: str, remove_match: bool = True
|
||||
):
|
||||
"""Add a pattern pair to detect in the text.
|
||||
|
||||
.. deprecated:: 0.0.95
|
||||
This function is deprecated and will be removed in a future version.
|
||||
Use `add_pattern` with a type and MatchAction instead.
|
||||
|
||||
This method calls `add_pattern` setting type with the provided pattern_id and action
|
||||
to either MatchAction.REMOVE or MatchAction.KEEP based on `remove_match`.
|
||||
|
||||
Args:
|
||||
pattern_id: Identifier for this pattern pair. Should be unique and ideally descriptive.
|
||||
(e.g., 'code', 'speaker', 'custom'). pattern_id can not be 'sentence' or 'word'
|
||||
as those arereserved for the default behavior.
|
||||
start_pattern: Pattern that marks the beginning of content.
|
||||
end_pattern: Pattern that marks the end of content.
|
||||
remove_match: If True, the matched pattern will be removed from the text. (Same as MatchAction.REMOVE)
|
||||
If False, it will be kept and treated as normal text. (Same as MatchAction.KEEP)
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once")
|
||||
warnings.warn(
|
||||
"add_pattern_pair with a pattern_id or remove_match is deprecated and will be"
|
||||
" removed in a future version. Use add_pattern with a type and MatchAction instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
action = MatchAction.REMOVE if remove_match else MatchAction.KEEP
|
||||
return self.add_pattern(
|
||||
type=pattern_id,
|
||||
start_pattern=start_pattern,
|
||||
end_pattern=end_pattern,
|
||||
action=action,
|
||||
)
|
||||
|
||||
def on_pattern_match(
|
||||
self, type: str, handler: Callable[[PatternMatch], Awaitable[None]]
|
||||
) -> "PatternPairAggregator":
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
# Portions Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Base OpenTelemetry tracing decorators and utilities for Pipecat.
|
||||
|
||||
.. deprecated:: 0.0.103
|
||||
This module is unused and will be removed in a future release.
|
||||
Service tracing is handled by the decorators in
|
||||
:mod:`pipecat.utils.tracing.service_decorators`.
|
||||
|
||||
This module provides class and method level tracing capabilities
|
||||
similar to the original NVIDIA implementation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Callable, Optional, TypeVar
|
||||
|
||||
warnings.warn(
|
||||
"pipecat.utils.tracing.class_decorators is deprecated and will be removed in a future "
|
||||
"release. Use pipecat.utils.tracing.service_decorators instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from pipecat.utils.tracing.setup import is_tracing_available
|
||||
|
||||
# Import OpenTelemetry if available
|
||||
if is_tracing_available():
|
||||
import opentelemetry.trace
|
||||
from opentelemetry import metrics, trace
|
||||
|
||||
# Type variables for better typing support
|
||||
T = TypeVar("T")
|
||||
C = TypeVar("C", bound=type)
|
||||
|
||||
|
||||
class AttachmentStrategy(enum.Enum):
|
||||
"""Controls how spans are attached to the trace hierarchy.
|
||||
|
||||
Parameters:
|
||||
CHILD: Attached to class span if no parent, otherwise to parent.
|
||||
LINK: Attached to class span with link to parent.
|
||||
NONE: Always attached to class span regardless of context.
|
||||
"""
|
||||
|
||||
CHILD = enum.auto()
|
||||
LINK = enum.auto()
|
||||
NONE = enum.auto()
|
||||
|
||||
|
||||
class Traceable:
|
||||
"""Base class for objects that can be traced with OpenTelemetry.
|
||||
|
||||
Provides the foundational tracing capabilities used by @traced methods.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, **kwargs):
|
||||
"""Initialize a traceable object.
|
||||
|
||||
Args:
|
||||
name: Name of the traceable object for the span.
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if not is_tracing_available():
|
||||
self._tracer = self._meter = self._parent_span_id = self._span = None
|
||||
return
|
||||
|
||||
self._tracer = trace.get_tracer("pipecat")
|
||||
self._meter = metrics.get_meter("pipecat")
|
||||
self._parent_span_id = trace.get_current_span().get_span_context().span_id
|
||||
self._span = self._tracer.start_span(name)
|
||||
self._span.end()
|
||||
|
||||
@property
|
||||
def meter(self):
|
||||
"""Get the OpenTelemetry meter instance.
|
||||
|
||||
Returns:
|
||||
The OpenTelemetry meter instance for this object.
|
||||
"""
|
||||
return self._meter
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def __traced_context_manager(
|
||||
self: Traceable, func: Callable, name: str | None, attachment_strategy: AttachmentStrategy
|
||||
):
|
||||
"""Internal context manager for the traced decorator.
|
||||
|
||||
Args:
|
||||
self: The Traceable instance.
|
||||
func: The function being traced.
|
||||
name: Custom span name or None to use function name.
|
||||
attachment_strategy: How to attach this span to the trace hierarchy.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If used in a class not inheriting from Traceable.
|
||||
"""
|
||||
if not isinstance(self, Traceable):
|
||||
raise RuntimeError(
|
||||
"@traced annotation can only be used in classes inheriting from Traceable"
|
||||
)
|
||||
|
||||
stack = contextlib.ExitStack()
|
||||
try:
|
||||
current_span = trace.get_current_span()
|
||||
is_span_class_parent_span = current_span.get_span_context().span_id == self._parent_span_id
|
||||
match attachment_strategy:
|
||||
case AttachmentStrategy.CHILD if not is_span_class_parent_span:
|
||||
stack.enter_context(
|
||||
self._tracer.start_as_current_span(func.__name__ if name is None else name) # type: ignore
|
||||
)
|
||||
case AttachmentStrategy.LINK:
|
||||
if is_span_class_parent_span:
|
||||
link = trace.Link(self._span.get_span_context()) # type: ignore
|
||||
else:
|
||||
link = trace.Link(current_span.get_span_context())
|
||||
stack.enter_context(
|
||||
opentelemetry.trace.use_span(span=self._span, end_on_exit=False) # type: ignore
|
||||
)
|
||||
stack.enter_context(
|
||||
self._tracer.start_as_current_span( # type: ignore
|
||||
func.__name__ if name is None else name, links=[link]
|
||||
)
|
||||
)
|
||||
case AttachmentStrategy.NONE | AttachmentStrategy.CHILD:
|
||||
stack.enter_context(
|
||||
opentelemetry.trace.use_span(span=self._span, end_on_exit=False) # type: ignore
|
||||
)
|
||||
stack.enter_context(
|
||||
self._tracer.start_as_current_span(func.__name__ if name is None else name) # type: ignore
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
stack.close()
|
||||
|
||||
|
||||
def __traced_decorator(func, name, attachment_strategy: AttachmentStrategy):
|
||||
"""Implementation of the traced decorator.
|
||||
|
||||
Args:
|
||||
func: The function to trace.
|
||||
name: Custom span name.
|
||||
attachment_strategy: How to attach this span.
|
||||
|
||||
Returns:
|
||||
The wrapped function with tracing capabilities.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def coroutine_wrapper(self: Traceable, *args, **kwargs):
|
||||
exception = None
|
||||
with __traced_context_manager(self, func, name, attachment_strategy):
|
||||
try:
|
||||
return await func(self, *args, **kwargs)
|
||||
except asyncio.CancelledError as e:
|
||||
exception = e
|
||||
if exception:
|
||||
raise exception
|
||||
|
||||
@functools.wraps(func)
|
||||
async def generator_wrapper(self: Traceable, *args, **kwargs):
|
||||
exception = None
|
||||
with __traced_context_manager(self, func, name, attachment_strategy):
|
||||
try:
|
||||
async for v in func(self, *args, **kwargs):
|
||||
yield v
|
||||
except asyncio.CancelledError as e:
|
||||
exception = e
|
||||
if exception:
|
||||
raise exception
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return coroutine_wrapper
|
||||
if inspect.isasyncgenfunction(func):
|
||||
return generator_wrapper
|
||||
|
||||
raise ValueError("@traced annotation can only be used on async or async generator functions")
|
||||
|
||||
|
||||
def traced(
|
||||
func: Optional[Callable] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
attachment_strategy: AttachmentStrategy = AttachmentStrategy.CHILD,
|
||||
) -> Callable:
|
||||
"""Add tracing to an async function in a Traceable class.
|
||||
|
||||
Args:
|
||||
func: The async function to trace.
|
||||
name: Custom span name. Defaults to function name.
|
||||
attachment_strategy: How to attach this span (CHILD, LINK, NONE).
|
||||
|
||||
Returns:
|
||||
Wrapped async function with tracing.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If used in a class not inheriting from Traceable.
|
||||
ValueError: If used on a non-async function.
|
||||
"""
|
||||
if not is_tracing_available():
|
||||
# Just return the original function or a simple decorator
|
||||
def decorator(f):
|
||||
return f
|
||||
|
||||
return decorator if func is None else func
|
||||
|
||||
if func is not None:
|
||||
return __traced_decorator(func, name=name, attachment_strategy=attachment_strategy)
|
||||
else:
|
||||
return functools.partial(
|
||||
__traced_decorator, name=name, attachment_strategy=attachment_strategy
|
||||
)
|
||||
|
||||
|
||||
def traceable(cls: C) -> C:
|
||||
"""Make a class traceable for OpenTelemetry.
|
||||
|
||||
Creates a new class that inherits from both the original class
|
||||
and Traceable, enabling tracing for class methods.
|
||||
|
||||
Args:
|
||||
cls: The class to make traceable.
|
||||
|
||||
Returns:
|
||||
A new class with tracing capabilities.
|
||||
"""
|
||||
if not is_tracing_available():
|
||||
return cls
|
||||
|
||||
@functools.wraps(cls, updated=())
|
||||
class TracedClass(cls, Traceable):
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the traced class instance.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments passed to parent classes.
|
||||
**kwargs: Keyword arguments passed to parent classes.
|
||||
"""
|
||||
cls.__init__(self, *args, **kwargs)
|
||||
if hasattr(self, "name"):
|
||||
Traceable.__init__(self, self.name)
|
||||
else:
|
||||
Traceable.__init__(self, cls.__name__)
|
||||
|
||||
return TracedClass
|
||||
@@ -100,11 +100,6 @@ def _get_parent_service_context(self):
|
||||
if not is_tracing_available():
|
||||
return None
|
||||
|
||||
# TODO: Remove this block and delete class_decorators.py once Traceable is removed.
|
||||
# Legacy: support for classes inheriting from Traceable (currently unused, deprecated).
|
||||
if hasattr(self, "_span") and self._span:
|
||||
return trace.set_span_in_context(self._span)
|
||||
|
||||
# Use the conversation context set by TurnTraceObserver via TracingContext.
|
||||
tracing_ctx = getattr(self, "_tracing_context", None)
|
||||
conversation_context = tracing_ctx.get_conversation_context() if tracing_ctx else None
|
||||
|
||||
@@ -28,6 +28,7 @@ from pipecat.frames.frames import (
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
SpeechControlParamsFrame,
|
||||
StartFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -67,7 +68,7 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
pipeline = Pipeline([LLMUserAggregator(context)])
|
||||
|
||||
frames_to_send = [LLMRunFrame()]
|
||||
expected_down_frames = [LLMContextFrame]
|
||||
expected_down_frames = [SpeechControlParamsFrame, LLMContextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -110,7 +111,7 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
run_llm=True,
|
||||
)
|
||||
]
|
||||
expected_down_frames = [LLMContextFrame]
|
||||
expected_down_frames = [SpeechControlParamsFrame, LLMContextFrame]
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
@@ -450,7 +451,7 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
(down_frames, _) = await run_test(
|
||||
pipeline,
|
||||
frames_to_send=[],
|
||||
expected_down_frames=[StartFrame, UserMuteStartedFrame],
|
||||
expected_down_frames=[StartFrame, UserMuteStartedFrame, SpeechControlParamsFrame],
|
||||
ignore_start=False,
|
||||
)
|
||||
|
||||
@@ -467,6 +468,7 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
# TranscriptionUserTurnStartStrategy, so we expect turn-related frames
|
||||
# but NOT the InterimTranscriptionFrame itself.
|
||||
expected_down_frames = [
|
||||
SpeechControlParamsFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InterruptionFrame,
|
||||
]
|
||||
@@ -485,11 +487,12 @@ class TestLLMUserAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
frames_to_send = [
|
||||
TranslationFrame(text="Hola!", user_id="", timestamp="now", language="es"),
|
||||
]
|
||||
# No downstream frames expected — translations are consumed.
|
||||
# Only the SpeechControlParamsFrame from the default turn strategy on
|
||||
# start is expected — the translation itself is consumed.
|
||||
await run_test(
|
||||
pipeline,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=[],
|
||||
expected_down_frames=[SpeechControlParamsFrame],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.audio.interruptions.min_words_interruption_strategy import MinWordsInterruptionStrategy
|
||||
|
||||
|
||||
class TestMinWordsInterruptionStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_min_words(self):
|
||||
strategy = MinWordsInterruptionStrategy(min_words=2)
|
||||
await strategy.append_text("Hello")
|
||||
self.assertEqual(await strategy.should_interrupt(), False)
|
||||
await strategy.append_text(" there!")
|
||||
self.assertEqual(await strategy.should_interrupt(), True)
|
||||
# Reset and check again
|
||||
await strategy.reset()
|
||||
await strategy.append_text("Hello!")
|
||||
self.assertEqual(await strategy.should_interrupt(), False)
|
||||
await strategy.append_text(" How are you?")
|
||||
self.assertEqual(await strategy.should_interrupt(), True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -21,8 +21,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.code_handler = AsyncMock()
|
||||
|
||||
# Add a test pattern
|
||||
self.aggregator.add_pattern_pair(
|
||||
pattern_id="test_pattern",
|
||||
self.aggregator.add_pattern(
|
||||
type="test_pattern",
|
||||
start_pattern="<test>",
|
||||
end_pattern="</test>",
|
||||
)
|
||||
|
||||
@@ -1,354 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
FunctionCallFromLLM,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_first_speech_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FIRST_SPEECH}))
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(), # First bot speech starts
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First bot speech ends
|
||||
BotStartedSpeakingFrame(), # Second bot speech
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame, # Now passes through
|
||||
UserStartedSpeakingFrame, # Now passes through
|
||||
InputAudioRawFrame, # Now passes through
|
||||
VADUserStoppedSpeakingFrame, # Now passes through
|
||||
UserStoppedSpeakingFrame, # Now passes through
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_always_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(), # First speech starts
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First speech ends
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Second speech starts
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed again
|
||||
UserStartedSpeakingFrame(), # Should be suppressed again
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed again
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed again
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed again
|
||||
BotStoppedSpeakingFrame(), # Second speech ends
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_transcription_frames_with_always_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
frames_to_send = [
|
||||
# Bot speaking - should mute
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(), # Wait for StartedSpeaking to process
|
||||
InterimTranscriptionFrame(
|
||||
user_id="user1", text="This should be suppressed", timestamp="1234567890"
|
||||
),
|
||||
TranscriptionFrame(
|
||||
user_id="user1", text="This should be suppressed", timestamp="1234567890"
|
||||
),
|
||||
SleepFrame(), # Wait for transcription frames to queue
|
||||
BotStoppedSpeakingFrame(),
|
||||
# Bot not speaking - should pass through
|
||||
InterimTranscriptionFrame(
|
||||
user_id="user1", text="This should pass", timestamp="1234567891"
|
||||
),
|
||||
TranscriptionFrame(
|
||||
user_id="user1", text="This should pass through", timestamp="1234567891"
|
||||
),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
InterimTranscriptionFrame, # Only passes through after bot stops speaking
|
||||
TranscriptionFrame, # Only passes through after bot stops speaking
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_function_call_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(), # Should pass through initially
|
||||
UserStartedSpeakingFrame(), # Should pass through initially
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through initially
|
||||
UserStoppedSpeakingFrame(), # Should pass through initially
|
||||
FunctionCallsStartedFrame(
|
||||
function_calls=[
|
||||
FunctionCallFromLLM(
|
||||
function_name="get_weather",
|
||||
tool_call_id="call_123",
|
||||
arguments='{"location": "San Francisco"}',
|
||||
context=None,
|
||||
)
|
||||
]
|
||||
), # Start function call
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
FunctionCallResultFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="call_123",
|
||||
arguments='{"location": "San Francisco"}',
|
||||
result={"temperature": 22},
|
||||
), # End function call
|
||||
SleepFrame(),
|
||||
VADUserStartedSpeakingFrame(), # Should pass through again
|
||||
UserStartedSpeakingFrame(), # Should pass through again
|
||||
VADUserStoppedSpeakingFrame(),
|
||||
UserStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
FunctionCallResultFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_mute_until_first_bot_complete_strategy(self):
|
||||
filter = STTMuteFilter(
|
||||
config=STTMuteConfig(strategies={STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE})
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed (starts muted)
|
||||
UserStartedSpeakingFrame(), # Should be suppressed (starts muted)
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStartedSpeakingFrame(), # First bot speech
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # First speech ends, unmutes
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Second speech
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_incompatible_strategies(self):
|
||||
with self.assertRaises(ValueError):
|
||||
STTMuteFilter(
|
||||
config=STTMuteConfig(
|
||||
strategies={
|
||||
STTMuteStrategy.FIRST_SPEECH,
|
||||
STTMuteStrategy.MUTE_UNTIL_FIRST_BOT_COMPLETE,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def test_custom_strategy(self):
|
||||
async def custom_mute_logic(processor: STTMuteFilter) -> bool:
|
||||
return processor._bot_is_speaking
|
||||
|
||||
filter = STTMuteFilter(
|
||||
config=STTMuteConfig(
|
||||
strategies={STTMuteStrategy.CUSTOM},
|
||||
should_mute_callback=custom_mute_logic,
|
||||
)
|
||||
)
|
||||
|
||||
frames_to_send = [
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
BotStartedSpeakingFrame(), # Bot starts speaking
|
||||
VADUserStartedSpeakingFrame(), # Should be suppressed
|
||||
UserStartedSpeakingFrame(), # Should be suppressed
|
||||
InputAudioRawFrame(
|
||||
audio=b"", sample_rate=16000, num_channels=1
|
||||
), # Should be suppressed
|
||||
VADUserStoppedSpeakingFrame(), # Should be suppressed
|
||||
UserStoppedSpeakingFrame(), # Should be suppressed
|
||||
BotStoppedSpeakingFrame(), # Bot stops speaking
|
||||
VADUserStartedSpeakingFrame(), # Should pass through
|
||||
UserStartedSpeakingFrame(), # Should pass through
|
||||
InputAudioRawFrame(audio=b"", sample_rate=16000, num_channels=1), # Should pass through
|
||||
VADUserStoppedSpeakingFrame(), # Should pass through
|
||||
UserStoppedSpeakingFrame(), # Should pass through
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_interruption_frame_suppressed_when_muted(self):
|
||||
"""Test that InterruptionFrame is suppressed when the filter is muted."""
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
InterruptionFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,798 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
InterruptionFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
ThoughtTranscriptionMessage,
|
||||
TranscriptionFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.processors.transcript_processor import (
|
||||
AssistantTranscriptProcessor,
|
||||
UserTranscriptProcessor,
|
||||
)
|
||||
from pipecat.tests.utils import SleepFrame, run_test
|
||||
|
||||
|
||||
class TestUserTranscriptProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for UserTranscriptProcessor"""
|
||||
|
||||
async def test_basic_transcription(self):
|
||||
"""Test basic transcription frame processing"""
|
||||
# Create processor
|
||||
processor = UserTranscriptProcessor()
|
||||
|
||||
# Create test timestamp
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Create frames to send
|
||||
frames_to_send = [
|
||||
TranscriptionFrame(text="Hello, world!", user_id="test_user", timestamp=timestamp)
|
||||
]
|
||||
|
||||
# Expected frames downstream - note the order:
|
||||
# 1. TranscriptionUpdateFrame (processor emits the update first)
|
||||
# 2. TranscriptionFrame (original frame is passed through)
|
||||
expected_down_frames = [TranscriptionUpdateFrame, TranscriptionFrame]
|
||||
|
||||
# Run test
|
||||
received_frames, _ = await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify the content of the TranscriptionUpdateFrame
|
||||
update_frame = cast(
|
||||
TranscriptionUpdateFrame, received_frames[0]
|
||||
) # Note: now checking first frame
|
||||
self.assertIsInstance(update_frame, TranscriptionUpdateFrame)
|
||||
self.assertEqual(len(update_frame.messages), 1)
|
||||
message = update_frame.messages[0]
|
||||
self.assertEqual(message.role, "user")
|
||||
self.assertEqual(message.content, "Hello, world!")
|
||||
self.assertEqual(message.user_id, "test_user")
|
||||
self.assertEqual(message.timestamp, timestamp)
|
||||
|
||||
async def test_event_handler(self):
|
||||
"""Test that event handlers are called with transcript updates"""
|
||||
# Create processor
|
||||
processor = UserTranscriptProcessor()
|
||||
|
||||
# Track received updates
|
||||
received_updates: List[TranscriptionMessage] = []
|
||||
|
||||
# Register event handler
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.extend(frame.messages)
|
||||
|
||||
# Create test data
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
frames_to_send = [
|
||||
TranscriptionFrame(text="First message", user_id="test_user", timestamp=timestamp),
|
||||
TranscriptionFrame(text="Second message", user_id="test_user", timestamp=timestamp),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
TranscriptionUpdateFrame,
|
||||
TranscriptionFrame, # First message
|
||||
TranscriptionUpdateFrame,
|
||||
TranscriptionFrame, # Second message
|
||||
]
|
||||
|
||||
# Run test
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify event handler received updates
|
||||
self.assertEqual(len(received_updates), 2)
|
||||
|
||||
# Check first message
|
||||
self.assertEqual(received_updates[0].role, "user")
|
||||
self.assertEqual(received_updates[0].content, "First message")
|
||||
self.assertEqual(received_updates[0].timestamp, timestamp)
|
||||
|
||||
# Check second message
|
||||
self.assertEqual(received_updates[1].role, "user")
|
||||
self.assertEqual(received_updates[1].content, "Second message")
|
||||
self.assertEqual(received_updates[1].timestamp, timestamp)
|
||||
|
||||
async def test_text_aggregation(self):
|
||||
"""Test that TTSTextFrames are properly aggregated into a single message"""
|
||||
# Create processor
|
||||
processor = AssistantTranscriptProcessor()
|
||||
|
||||
# Track received updates
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
# Create test frames simulating bot speaking multiple text chunks
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(), # Wait for StartedSpeaking to process
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="How", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="are", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="you?", aggregated_by=AggregationType.WORD),
|
||||
SleepFrame(), # Wait for text frames to queue
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
# Expected order:
|
||||
# 1. BotStartedSpeakingFrame (system frame, immediate)
|
||||
# 2. All queued TTSTextFrames
|
||||
# 3. BotStoppedSpeakingFrame (system frame, immediate)
|
||||
# 4. TranscriptionUpdateFrame (after aggregation)
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
# Run test
|
||||
received_frames, _ = await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify update was received
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
|
||||
# Get the update frame
|
||||
update_frame = received_updates[0]
|
||||
|
||||
# Should have one aggregated message
|
||||
self.assertEqual(len(update_frame.messages), 1)
|
||||
|
||||
message = update_frame.messages[0]
|
||||
self.assertEqual(message.role, "assistant")
|
||||
self.assertEqual(message.content, "Hello world! How are you?")
|
||||
|
||||
# Verify timestamp exists
|
||||
self.assertIsNotNone(message.timestamp)
|
||||
|
||||
# All frames should be passed through in order, with update at end
|
||||
downstream_update = cast(TranscriptionUpdateFrame, received_frames[-2])
|
||||
self.assertEqual(downstream_update.messages[0].content, "Hello world! How are you?")
|
||||
|
||||
async def test_empty_text_handling(self):
|
||||
"""Test that empty messages are not emitted"""
|
||||
processor = AssistantTranscriptProcessor()
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="", aggregated_by=AggregationType.WORD), # Empty text
|
||||
TTSTextFrame(text=" ", aggregated_by=AggregationType.WORD), # Just whitespace
|
||||
TTSTextFrame(text="\n", aggregated_by=AggregationType.WORD), # Just newline
|
||||
BotStoppedSpeakingFrame(),
|
||||
# Pipeline ends here; run_test will automatically send EndFrame
|
||||
]
|
||||
|
||||
# From our earlier tests, we know BotStoppedSpeakingFrame comes before TTSTextFrames
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
TTSTextFrame, # empty
|
||||
TTSTextFrame, # whitespace
|
||||
TTSTextFrame, # newline
|
||||
# No TranscriptionUpdateFrame since content is empty after stripping
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
self.assertEqual(len(received_updates), 0, "No updates should be emitted for empty content")
|
||||
|
||||
async def test_interruption_handling(self):
|
||||
"""Test that messages are properly captured when bot is interrupted"""
|
||||
processor = AssistantTranscriptProcessor()
|
||||
|
||||
# Track received updates
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
# Simulate bot being interrupted mid-sentence
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world!", aggregated_by=AggregationType.WORD),
|
||||
SleepFrame(),
|
||||
InterruptionFrame(), # User interrupts here
|
||||
SleepFrame(),
|
||||
BotStartedSpeakingFrame(),
|
||||
TTSTextFrame(text="New", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="response", aggregated_by=AggregationType.WORD),
|
||||
SleepFrame(),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
# Actual order of frames:
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
TTSTextFrame, # "Hello"
|
||||
TTSTextFrame, # "world!"
|
||||
InterruptionFrame,
|
||||
TranscriptionUpdateFrame, # First message (emitted due to interruption)
|
||||
BotStartedSpeakingFrame,
|
||||
TTSTextFrame, # "New"
|
||||
TTSTextFrame, # "response"
|
||||
TranscriptionUpdateFrame, # Second message
|
||||
BotStoppedSpeakingFrame,
|
||||
]
|
||||
|
||||
# Run test
|
||||
received_frames, _ = await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Should have received two updates
|
||||
self.assertEqual(len(received_updates), 2)
|
||||
|
||||
# First update should be interrupted message
|
||||
first_message = received_updates[0].messages[0]
|
||||
self.assertEqual(first_message.role, "assistant")
|
||||
self.assertEqual(first_message.content, "Hello world!")
|
||||
self.assertIsNotNone(first_message.timestamp)
|
||||
|
||||
# Second update should be new response
|
||||
second_message = received_updates[1].messages[0]
|
||||
self.assertEqual(second_message.role, "assistant")
|
||||
self.assertEqual(second_message.content, "New response")
|
||||
self.assertIsNotNone(second_message.timestamp)
|
||||
|
||||
# Verify timestamps are different
|
||||
self.assertNotEqual(first_message.timestamp, second_message.timestamp)
|
||||
|
||||
async def test_end_frame_handling(self):
|
||||
"""Test that final messages are captured when pipeline ends normally"""
|
||||
processor = AssistantTranscriptProcessor()
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
|
||||
# Pipeline ends here; run_test will automatically send EndFrame
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TranscriptionUpdateFrame, # Final message emitted due to EndFrame
|
||||
]
|
||||
|
||||
# Run test - EndFrame will be sent automatically
|
||||
received_frames, _ = await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
message = received_updates[0].messages[0]
|
||||
self.assertEqual(message.role, "assistant")
|
||||
self.assertEqual(message.content, "Hello world")
|
||||
|
||||
async def test_cancel_frame_handling(self):
|
||||
"""Test that messages are properly captured when pipeline is cancelled"""
|
||||
processor = AssistantTranscriptProcessor()
|
||||
|
||||
# Track updates with timestamps to verify order
|
||||
received_updates: List[Tuple[str, float]] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
# Record message content and time received
|
||||
received_updates.append((frame.messages[0].content, asyncio.get_event_loop().time()))
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Hello", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="world", aggregated_by=AggregationType.WORD),
|
||||
SleepFrame(), # Ensure messages are processed
|
||||
CancelFrame(),
|
||||
]
|
||||
|
||||
# We don't need to verify frame order, just that CancelFrame triggers message emission
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
CancelFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
send_end_frame=False,
|
||||
)
|
||||
|
||||
# Verify that we received an update
|
||||
self.assertEqual(len(received_updates), 1, "Should receive one update before cancellation")
|
||||
content, _ = received_updates[0]
|
||||
self.assertEqual(content, "Hello world")
|
||||
|
||||
async def test_transcript_processor_factory(self):
|
||||
"""Test that factory properly manages processors and event handlers"""
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
|
||||
factory = TranscriptProcessor()
|
||||
received_updates: List[TranscriptionMessage] = []
|
||||
|
||||
# Register handler with factory
|
||||
@factory.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.extend(frame.messages)
|
||||
|
||||
# Get processors and verify they're reused
|
||||
user_proc1 = factory.user()
|
||||
user_proc2 = factory.user()
|
||||
self.assertIs(user_proc1, user_proc2, "User processor should be reused")
|
||||
|
||||
asst_proc1 = factory.assistant()
|
||||
asst_proc2 = factory.assistant()
|
||||
self.assertIs(asst_proc1, asst_proc2, "Assistant processor should be reused")
|
||||
|
||||
# Test user processor
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
frames_to_send = [
|
||||
TranscriptionFrame(text="User message", user_id="user1", timestamp=timestamp)
|
||||
]
|
||||
|
||||
await run_test(
|
||||
user_proc1,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=[TranscriptionUpdateFrame, TranscriptionFrame],
|
||||
)
|
||||
|
||||
# Test assistant processor
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
TTSTextFrame(text="Assistant", aggregated_by=AggregationType.WORD),
|
||||
TTSTextFrame(text="message", aggregated_by=AggregationType.WORD),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
# The actual order we see in the output:
|
||||
await run_test(
|
||||
asst_proc1,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=[
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
],
|
||||
)
|
||||
|
||||
# Verify both processors triggered the same handler
|
||||
self.assertEqual(len(received_updates), 2)
|
||||
self.assertEqual(received_updates[0].role, "user")
|
||||
self.assertEqual(received_updates[0].content, "User message")
|
||||
self.assertEqual(received_updates[1].role, "assistant")
|
||||
self.assertEqual(received_updates[1].content, "Assistant message")
|
||||
|
||||
async def test_text_fragments_with_spaces(self):
|
||||
"""Test aggregating text fragments with various spacing patterns"""
|
||||
processor = AssistantTranscriptProcessor()
|
||||
|
||||
# Track received updates
|
||||
received_updates = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
# Test the specific pattern shared
|
||||
def make_tts_text_frame(text: str) -> TTSTextFrame:
|
||||
frame = TTSTextFrame(text=text, aggregated_by=AggregationType.WORD)
|
||||
frame.includes_inter_frame_spaces = True
|
||||
return frame
|
||||
|
||||
frames_to_send = [
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(),
|
||||
make_tts_text_frame("Hello"),
|
||||
make_tts_text_frame(" there"),
|
||||
make_tts_text_frame("!"),
|
||||
make_tts_text_frame(" How"),
|
||||
make_tts_text_frame("'s"),
|
||||
make_tts_text_frame(" it"),
|
||||
make_tts_text_frame(" going"),
|
||||
make_tts_text_frame("?"),
|
||||
BotStoppedSpeakingFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TTSTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
]
|
||||
|
||||
# Run test
|
||||
received_frames, _ = await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify result
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
message = received_updates[0].messages[0]
|
||||
self.assertEqual(message.role, "assistant")
|
||||
# Should be properly joined without extra spaces
|
||||
self.assertEqual(message.content, "Hello there! How's it going?")
|
||||
|
||||
|
||||
class TestThoughtTranscription(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for thought transcription in AssistantTranscriptProcessor"""
|
||||
|
||||
async def test_basic_thought_transcription(self):
|
||||
"""Test basic thought frame processing"""
|
||||
processor = AssistantTranscriptProcessor(process_thoughts=True)
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
# Create frames for a simple thought
|
||||
frames_to_send = [
|
||||
LLMThoughtStartFrame(),
|
||||
LLMThoughtTextFrame(text="Let me think about this..."),
|
||||
LLMThoughtEndFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
LLMThoughtEndFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify update was received
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
message = received_updates[0].messages[0]
|
||||
self.assertIsInstance(message, ThoughtTranscriptionMessage)
|
||||
self.assertEqual(message.content, "Let me think about this...")
|
||||
self.assertIsNotNone(message.timestamp)
|
||||
|
||||
async def test_thought_aggregation(self):
|
||||
"""Test that thought text frames are properly aggregated"""
|
||||
processor = AssistantTranscriptProcessor(process_thoughts=True)
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
# Create frames simulating chunked thought text
|
||||
frames_to_send = [
|
||||
LLMThoughtStartFrame(),
|
||||
LLMThoughtTextFrame(text="The user "),
|
||||
LLMThoughtTextFrame(text="is asking "),
|
||||
LLMThoughtTextFrame(text="about electric "),
|
||||
LLMThoughtTextFrame(text="cars."),
|
||||
LLMThoughtEndFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMThoughtTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
LLMThoughtEndFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify aggregation
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
message = received_updates[0].messages[0]
|
||||
self.assertIsInstance(message, ThoughtTranscriptionMessage)
|
||||
self.assertEqual(message.content, "The user is asking about electric cars.")
|
||||
|
||||
async def test_thought_with_interruption(self):
|
||||
"""Test that thoughts are properly captured when interrupted"""
|
||||
processor = AssistantTranscriptProcessor(process_thoughts=True)
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
frames_to_send = [
|
||||
LLMThoughtStartFrame(),
|
||||
LLMThoughtTextFrame(text="I need to consider "),
|
||||
LLMThoughtTextFrame(text="multiple factors"),
|
||||
SleepFrame(),
|
||||
InterruptionFrame(), # User interrupts
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMThoughtTextFrame,
|
||||
InterruptionFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify thought was captured on interruption
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
message = received_updates[0].messages[0]
|
||||
self.assertIsInstance(message, ThoughtTranscriptionMessage)
|
||||
self.assertEqual(message.content, "I need to consider multiple factors")
|
||||
|
||||
async def test_thought_with_cancel(self):
|
||||
"""Test that thoughts are properly captured when cancelled"""
|
||||
processor = AssistantTranscriptProcessor(process_thoughts=True)
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
frames_to_send = [
|
||||
LLMThoughtStartFrame(),
|
||||
LLMThoughtTextFrame(text="Starting analysis"),
|
||||
SleepFrame(),
|
||||
CancelFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
CancelFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
send_end_frame=False,
|
||||
)
|
||||
|
||||
# Verify thought was captured on cancellation
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
message = received_updates[0].messages[0]
|
||||
self.assertIsInstance(message, ThoughtTranscriptionMessage)
|
||||
self.assertEqual(message.content, "Starting analysis")
|
||||
|
||||
async def test_thought_with_end_frame(self):
|
||||
"""Test that thoughts are captured when pipeline ends normally"""
|
||||
processor = AssistantTranscriptProcessor(process_thoughts=True)
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
frames_to_send = [
|
||||
LLMThoughtStartFrame(),
|
||||
LLMThoughtTextFrame(text="Final thought"),
|
||||
# Pipeline ends here; run_test will automatically send EndFrame
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify thought was captured on EndFrame
|
||||
self.assertEqual(len(received_updates), 1)
|
||||
message = received_updates[0].messages[0]
|
||||
self.assertIsInstance(message, ThoughtTranscriptionMessage)
|
||||
self.assertEqual(message.content, "Final thought")
|
||||
|
||||
async def test_multiple_thoughts(self):
|
||||
"""Test multiple separate thoughts in sequence"""
|
||||
processor = AssistantTranscriptProcessor(process_thoughts=True)
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
frames_to_send = [
|
||||
# First thought
|
||||
LLMThoughtStartFrame(),
|
||||
LLMThoughtTextFrame(text="First consideration"),
|
||||
LLMThoughtEndFrame(),
|
||||
# Second thought
|
||||
LLMThoughtStartFrame(),
|
||||
LLMThoughtTextFrame(text="Second consideration"),
|
||||
LLMThoughtEndFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
LLMThoughtEndFrame,
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
TranscriptionUpdateFrame,
|
||||
LLMThoughtEndFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify both thoughts were captured
|
||||
self.assertEqual(len(received_updates), 2)
|
||||
|
||||
first_message = received_updates[0].messages[0]
|
||||
self.assertIsInstance(first_message, ThoughtTranscriptionMessage)
|
||||
self.assertEqual(first_message.content, "First consideration")
|
||||
|
||||
second_message = received_updates[1].messages[0]
|
||||
self.assertIsInstance(second_message, ThoughtTranscriptionMessage)
|
||||
self.assertEqual(second_message.content, "Second consideration")
|
||||
|
||||
async def test_empty_thought_handling(self):
|
||||
"""Test that empty thoughts are not emitted"""
|
||||
processor = AssistantTranscriptProcessor(process_thoughts=True)
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
frames_to_send = [
|
||||
LLMThoughtStartFrame(),
|
||||
LLMThoughtTextFrame(text=""), # Empty
|
||||
LLMThoughtTextFrame(text=" "), # Just whitespace
|
||||
LLMThoughtEndFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMThoughtStartFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMThoughtTextFrame,
|
||||
LLMThoughtEndFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify no updates emitted for empty content
|
||||
self.assertEqual(len(received_updates), 0)
|
||||
|
||||
async def test_thought_without_start_frame(self):
|
||||
"""Test that thought text without start frame is ignored"""
|
||||
processor = AssistantTranscriptProcessor(process_thoughts=True)
|
||||
|
||||
received_updates: List[TranscriptionUpdateFrame] = []
|
||||
|
||||
@processor.event_handler("on_transcript_update")
|
||||
async def handle_update(proc, frame: TranscriptionUpdateFrame):
|
||||
received_updates.append(frame)
|
||||
|
||||
# Send thought text without start frame
|
||||
frames_to_send = [
|
||||
LLMThoughtTextFrame(text="This should be ignored"),
|
||||
LLMThoughtEndFrame(),
|
||||
]
|
||||
|
||||
expected_down_frames = [
|
||||
LLMThoughtTextFrame,
|
||||
LLMThoughtEndFrame,
|
||||
]
|
||||
|
||||
await run_test(
|
||||
processor,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_down_frames,
|
||||
)
|
||||
|
||||
# Verify no updates since thought wasn't properly started
|
||||
self.assertEqual(len(received_updates), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user