Compare commits
19 Commits
jh/aws-aut
...
vs/deepgra
| Author | SHA1 | Date |
|---|---|---|
| | 8d4dbd4ac0 | |
| | 5ceed1e615 | |
| | 0623c6c79b | |
| | 6d66bbceeb | |
| | a27d9fc30b | |
| | 2a8f4734e0 | |
| | 48ac68e3c8 | |
| | c3ef199efa | |
| | 1b5c4cfa2a | |
| | 6e9dd1dbcc | |
| | 6487f895b3 | |
| | 351105a975 | |
| | 8ea963852d | |
| | 6f4458f21d | |
| | fb42a7dcf3 | |
| | 21547c8680 | |
| | 3e5aabc5f2 | |
| | e508642b0a | |
| | e546541e20 | |
1
changelog/4390.added.md
Normal file
@@ -0,0 +1 @@
- Added a `max_buffer_delay_ms` constructor argument to `CartesiaTTSService` for controlling Cartesia's server-side text buffering. When unset, Pipecat picks a sensible default based on `text_aggregation_mode`: `0` in `SENTENCE` mode (custom buffering — avoids stacking client-side aggregation on top of Cartesia's default 3000ms server buffer) and unset in `TOKEN` mode (Cartesia's managed buffering applies). Pass an explicit value (0–5000ms) to override.
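The default selection described in this entry can be sketched as a small pure function. This is a minimal illustration under stated assumptions, not Pipecat's implementation: `resolve_max_buffer_delay_ms` and the `sentence_mode` flag are hypothetical names, and the range check simply mirrors the 0–5000ms bound quoted above.

```python
from __future__ import annotations


def resolve_max_buffer_delay_ms(explicit: int | None, sentence_mode: bool) -> int | None:
    """Return the value to send to Cartesia, or None to omit the field.

    Hypothetical helper: the name and the range check are assumptions
    made for this sketch.
    """
    if explicit is not None:
        if not 0 <= explicit <= 5000:
            raise ValueError("max_buffer_delay_ms must be within 0-5000")
        return explicit
    # SENTENCE mode: the client already aggregates text, so disable the
    # server-side buffer instead of stacking the two.
    if sentence_mode:
        return 0
    # TOKEN mode: omit the field so Cartesia's managed buffering applies.
    return None


# Only add the key to the outgoing message when a value was resolved.
msg = {"transcript": "Hello there."}
delay = resolve_max_buffer_delay_ms(None, sentence_mode=True)
if delay is not None:
    msg["max_buffer_delay_ms"] = delay
```

Returning `None` for "leave unset" keeps the omit-the-field case distinct from an explicit `0`, which is the distinction the entry relies on.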
1
changelog/4390.changed.2.md
Normal file
@@ -0,0 +1 @@
- Default `cartesia_version` for `CartesiaTTSService` bumped from `2025-04-16` to `2026-03-01`, matching `CartesiaHttpTTSService` and unlocking the `use_normalized_timestamps` and `max_buffer_delay_ms` fields.
1
changelog/4390.changed.md
Normal file
@@ -0,0 +1 @@
- ⚠️ `CartesiaTTSService` now sends `use_normalized_timestamps: true` instead of the deprecated `use_original_timestamps` field. Word timestamps now reflect what was actually spoken (post text-normalization and pronunciation-dictionary substitution), matching the convention Pipecat uses for ElevenLabs. This is a behavior change for `sonic-3` users, who were previously receiving timestamps tied to the input transcript.
1
changelog/4390.fixed.2.md
Normal file
@@ -0,0 +1 @@
- Fixed `CartesiaHttpTTSService` pushing two `ErrorFrame`s on a non-200 response — one with the API's error text and a second, less informative "Unknown error" frame from the outer exception handler. It now pushes a single frame that includes the HTTP status code and returns cleanly.
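The single-frame behavior can be illustrated with a self-contained async generator. This is a sketch only: `ErrorFrame` here is a stand-in dataclass rather than Pipecat's class, `run_tts` takes a fake status and body instead of making an HTTP request, and the wording of the outer handler's message is an assumption.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class ErrorFrame:
    """Stand-in for Pipecat's ErrorFrame; only the message is modeled."""

    error: str


async def run_tts(status: int, body: str):
    """Yield audio on success, or exactly one ErrorFrame on failure."""
    try:
        if status != 200:
            yield ErrorFrame(error=f"Cartesia API error (status {status}): {body}")
            # Returning instead of raising keeps the outer handler from
            # emitting a second, generic error frame.
            return
        yield b"audio-bytes"
    except Exception as e:  # outer handler, reached only on real exceptions
        yield ErrorFrame(error=f"Unknown error occurred: {e}")


async def collect(status: int, body: str) -> list:
    """Drain the generator into a list so the frames can be inspected."""
    return [frame async for frame in run_tts(status, body)]


frames = asyncio.run(collect(500, "internal error"))
```

With the earlier raise-after-yield shape, the `except` branch would have produced a second frame for the same failure.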
1
changelog/4390.fixed.3.md
Normal file
@@ -0,0 +1 @@
- Fixed Cartesia tag helpers (`SPELL`, `EMOTION_TAG`, `PAUSE_TAG`, `VOLUME_TAG`, `SPEED_TAG`) raising `TypeError` when called on an instance (e.g. `tts.SPELL("hi")`). They're now `@staticmethod` and callable from both the class and an instance.
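The failure mode and the fix are plain Python behavior and can be reproduced without Pipecat. In the sketch below, `BrokenTags` and `FixedTags` are hypothetical stand-ins for the service class, and `call_on_instance` is a helper added only for the demonstration.

```python
class BrokenTags:
    # A plain function in a class body: calling it on an instance passes
    # the instance as the first positional argument.
    def SPELL(text: str) -> str:
        return f"<spell>{text}</spell>"


class FixedTags:
    @staticmethod
    def SPELL(text: str) -> str:
        return f"<spell>{text}</spell>"


def call_on_instance(cls) -> str:
    """Call SPELL on an instance and report a TypeError as text."""
    try:
        return cls().SPELL("hi")
    except TypeError as e:
        return f"TypeError: {e}"


broken_result = call_on_instance(BrokenTags)
fixed_result = call_on_instance(FixedTags)
```

Both classes work when called as `Class.SPELL("hi")`; only the instance call differs, which is why the bug was easy to miss.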
1
changelog/4390.fixed.md
Normal file
@@ -0,0 +1 @@
- Fixed `CartesiaTTSService` surfacing `flush_done` messages from Cartesia as `ErrorFrame`s. The latest API emits a `flush_done` per transcript when server-side buffering is disabled; Pipecat now consumes them silently since each turn already has its own `context_id`.
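A simplified dispatcher shows why an explicit `flush_done` branch is needed: without it the message falls through to the unknown-type error path. The message shapes and the `handle_message` function are assumptions for illustration; the real receive loop handles more message types and pushes frames rather than appending to lists.

```python
def handle_message(msg: dict, errors: list, audio: list) -> None:
    """Route one decoded websocket message (simplified sketch)."""
    kind = msg.get("type")
    if kind == "chunk":
        audio.append(msg["data"])
    elif kind == "error":
        errors.append(f"Error: {msg}")
    elif kind == "flush_done":
        # Per-transcript boundary marker; nothing to do because audio is
        # already grouped by context_id.
        pass
    else:
        errors.append(f"Error, unknown message type: {msg}")


errors: list = []
audio: list = []
for message in (
    {"type": "chunk", "data": b"\x00\x01", "context_id": "turn-1"},
    {"type": "flush_done", "context_id": "turn-1"},
    {"type": "mystery"},
):
    handle_message(message, errors, audio)
```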
1
changelog/4393.fixed.md
Normal file
@@ -0,0 +1 @@
- Fixed an issue where `LocalSmartTurnAnalyzerV3` was imported unconditionally for user turn stop strategies. It is now only imported when `default_user_turn_stop_strategies()` is called. This improves startup time and removes the `transformers` "PyTorch/TensorFlow/Flax not found" warning when the default stop strategies are not used.
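The fix follows the usual deferred-import pattern: the heavy module is imported inside the function that needs it rather than at module top level. The sketch below uses the standard-library `wave` module as a stand-in for the heavy dependency, and the returned strings are placeholders, not Pipecat's strategy objects.

```python
import importlib
import sys

HEAVY_MODULE = "wave"  # stdlib stand-in for a heavy optional dependency


def default_user_turn_stop_strategies() -> list:
    """Build the default strategies, importing the dependency on demand."""
    # Deferred import: the cost (and any import-time warnings) is paid
    # only by callers that actually ask for the defaults.
    heavy = importlib.import_module(HEAVY_MODULE)
    return [f"strategy backed by {heavy.__name__}"]


strategies = default_user_turn_stop_strategies()
loaded_after_call = HEAVY_MODULE in sys.modules
```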
1
changelog/4395.changed.md
Normal file
@@ -0,0 +1 @@
- Broadened `tool_resources` to `app_resources` for easy access not just in tool handlers but in other places like custom `FrameProcessor`s. Three changes: a rename (`tool_resources` → `app_resources`), a new `app_resources` property on `PipelineTask`, and a new `pipeline_task` property on `FrameProcessor`. Tool handlers now read `params.app_resources`; custom processors read `self.pipeline_task.app_resources`. The previous `tool_resources` aliases (on `PipelineTask`, `FunctionCallParams`, and `FrameProcessorSetup`) keep working but are deprecated as of 1.2.0 and emit `DeprecationWarning`s.
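The rename-with-alias approach can be sketched with a stand-in class. `Task` below is hypothetical and models only the resource plumbing; it is not `PipelineTask`, and the warning text is abbreviated.

```python
import warnings
from typing import Any


class Task:
    """Stand-in for PipelineTask; only the resource plumbing is modeled."""

    def __init__(self, *, app_resources: Any = None, tool_resources: Any = None):
        if tool_resources is not None:
            warnings.warn(
                "`tool_resources` is deprecated since 1.2.0, use `app_resources` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            # The new name wins if both are given.
            if app_resources is None:
                app_resources = tool_resources
        self._app_resources = app_resources

    @property
    def app_resources(self) -> Any:
        """Return the original object by reference, never a copy."""
        return self._app_resources


shared = {"calls": []}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = Task(tool_resources=shared)  # old spelling still works
modern = Task(app_resources=shared)
```

Both tasks hold the same `shared` object, so mutations made through either are visible to the caller after the session ends.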
1
changelog/4399.added.md
Normal file
@@ -0,0 +1 @@
- Added `mip_opt_out` to `DeepgramTTSSettings` (used by both `DeepgramTTSService` and `DeepgramHttpTTSService`) for opting out of Deepgram's Model Improvement Program. Pass it via `settings=DeepgramTTSService.Settings(mip_opt_out=True)` to mirror the existing flag on `DeepgramSTTService`.
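The flag is only sent when it has been explicitly set to `True` or `False`; an unset or `None` value leaves the request unchanged. A minimal sketch of that rule follows, using a local sentinel and a hypothetical `build_speak_query` helper rather than Pipecat's `NOT_GIVEN` and service code; the `linear16` encoding and 24000 sample rate are example values.

```python
from urllib.parse import urlencode


class _NotGiven:
    """Sentinel type meaning 'the caller did not set this field'."""


NOT_GIVEN = _NotGiven()


def build_speak_query(voice: str, encoding: str, sample_rate: int, mip_opt_out=NOT_GIVEN) -> str:
    """Build the query string for a speak request (illustrative only)."""
    params = {"model": voice, "encoding": encoding, "sample_rate": sample_rate}
    # Send the flag only when it was explicitly set to True or False.
    if not isinstance(mip_opt_out, _NotGiven) and mip_opt_out is not None:
        params["mip_opt_out"] = "true" if mip_opt_out else "false"
    return urlencode(params)


default_query = build_speak_query("aura-2-helena-en", "linear16", 24000)
opted_out_query = build_speak_query("aura-2-helena-en", "linear16", 24000, mip_opt_out=True)
```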
@@ -4,23 +4,33 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example demonstrating ``PipelineTask(tool_resources=...)``.
|
||||
"""Example demonstrating ``PipelineTask(app_resources=...)``.
|
||||
|
||||
``tool_resources`` is an application-defined bag of anything you want every
|
||||
tool handler in a session to share by reference: database handles, HTTP
|
||||
clients, feature flags, per-user state, observability clients, in-memory
|
||||
caches — whatever fits your app. Pipecat passes it through untouched as
|
||||
``FunctionCallParams.tool_resources``.
|
||||
``app_resources`` is an application-defined bag of anything your
|
||||
application code may want to share across a session: database handles,
|
||||
HTTP clients, feature flags, per-user state, observability clients,
|
||||
in-memory caches — whatever fits your app. Pipecat passes it through
|
||||
untouched and exposes it as ``task.app_resources``, so any code with a
|
||||
handle on the task can read or mutate it.
|
||||
|
||||
This example uses a small ``ToolCallLogger`` as a stand-in for that "shared
|
||||
thing". A real app might just as easily pass a Postgres pool, a Redis
|
||||
client, a Stripe SDK instance, or any combination thereof. The mechanics
|
||||
shown here — construct once, hand to the task, read it from each handler,
|
||||
inspect it after the session — are the same regardless of what you put in.
|
||||
Two of the convenience aliases exercised below:
|
||||
|
||||
We bundle resources in a typed ``SessionResources`` dataclass and cast back
|
||||
to it at the top of each handler. Pipecat doesn't care what type you pass
|
||||
(a plain dict works too), but a typed container gives you autocomplete and
|
||||
- Tool handlers read it from ``FunctionCallParams.app_resources``.
|
||||
- Custom ``FrameProcessor`` subclasses read it from
|
||||
``self.pipeline_task.app_resources``.
|
||||
|
||||
This example uses two small loggers as stand-ins for that "shared thing":
|
||||
``ToolCallLogger`` (written from tool handlers) and
|
||||
``TranscriptionLogger`` (written from a custom ``FrameProcessor`` that
|
||||
sits in the pipeline). A real app might just as easily pass a Postgres
|
||||
pool, a Redis client, a Stripe SDK instance, or any combination thereof.
|
||||
The mechanics shown here — construct once, hand to the task, read it
|
||||
from each site, inspect it after the session — are the same regardless
|
||||
of what you put in.
|
||||
|
||||
We bundle resources in a typed ``AppResources`` dataclass and cast back
|
||||
to it at each read site. Pipecat doesn't care what type you pass (a
|
||||
plain dict works too), but a typed container gives you autocomplete and
|
||||
refactor safety instead of dict-by-string-key lookups.
|
||||
"""
|
||||
|
||||
@@ -28,7 +38,7 @@ import json
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -37,7 +47,7 @@ from loguru import logger
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, TTSSpeakFrame
|
||||
from pipecat.frames.frames import Frame, LLMRunFrame, TranscriptionFrame, TTSSpeakFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -46,6 +56,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
@@ -86,30 +97,80 @@ class ToolCallLogger:
|
||||
return json.dumps(self._calls, indent=2)
|
||||
|
||||
|
||||
class TranscriptionLogger:
|
||||
"""Records final user transcriptions — written from a custom FrameProcessor."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the logger with an empty list of recorded transcriptions."""
|
||||
self._entries: list[dict[str, Any]] = []
|
||||
|
||||
def log_transcription(self, text: str) -> None:
|
||||
"""Record a transcription.
|
||||
|
||||
Args:
|
||||
text: The transcribed user utterance.
|
||||
"""
|
||||
entry = {
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"text": text,
|
||||
}
|
||||
self._entries.append(entry)
|
||||
logger.info(f"[TranscriptionLogger] {text!r}")
|
||||
|
||||
def dump(self) -> str:
|
||||
"""Return all recorded transcriptions as a JSON string."""
|
||||
return json.dumps(self._entries, indent=2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionResources:
|
||||
"""Typed container for everything the tool handlers in this session share.
|
||||
class AppResources:
|
||||
"""Typed container for everything the app shares across this session.
|
||||
|
||||
Add fields here as the app grows (e.g. ``db: AsyncConnection``,
|
||||
``http: httpx.AsyncClient``). Handlers ``cast()`` ``params.tool_resources``
|
||||
to this type to get autocomplete and refactor safety.
|
||||
``http: httpx.AsyncClient``). Read sites ``cast()`` to this type to
|
||||
get autocomplete and refactor safety:
|
||||
|
||||
- In tools: ``cast(AppResources, params.app_resources)``.
|
||||
- In custom processors: ``cast(AppResources, self.pipeline_task.app_resources)``.
|
||||
"""
|
||||
|
||||
tool_call_logger: ToolCallLogger
|
||||
transcription_logger: TranscriptionLogger
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
resources = cast(SessionResources, params.tool_resources)
|
||||
resources = cast(AppResources, params.app_resources)
|
||||
resources.tool_call_logger.log_tool_call(params.function_name, params.arguments)
|
||||
await params.result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def fetch_restaurant_recommendation(params: FunctionCallParams):
|
||||
resources = cast(SessionResources, params.tool_resources)
|
||||
resources = cast(AppResources, params.app_resources)
|
||||
resources.tool_call_logger.log_tool_call(params.function_name, params.arguments)
|
||||
await params.result_callback({"name": "The Golden Dragon"})
|
||||
|
||||
|
||||
class TranscriptionLoggingProcessor(FrameProcessor):
|
||||
"""Logs each final user transcription into the shared app resources.
|
||||
|
||||
Demonstrates the second read site for ``app_resources``: any custom
|
||||
``FrameProcessor`` can reach the same bag every tool handler sees by
|
||||
going through ``self.pipeline_task.app_resources``. ``pipeline_task``
|
||||
is ``None`` until the task sets the processor up, so we guard against
|
||||
that case.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Forward all frames; log final user transcriptions on the way through."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame) and self.pipeline_task is not None:
|
||||
resources = cast(AppResources, self.pipeline_task.app_resources)
|
||||
resources.transcription_logger.log_transcription(frame.text)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
@@ -203,6 +264,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
TranscriptionLoggingProcessor(),
|
||||
user_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
@@ -211,10 +273,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
]
|
||||
)
|
||||
|
||||
# Keep a local handle so we can read collected state after the session
|
||||
# Keep local handles so we can read collected state after the session
|
||||
# ends; Pipecat never copies or clears the object.
|
||||
tool_call_logger = ToolCallLogger()
|
||||
resources = SessionResources(tool_call_logger=tool_call_logger)
|
||||
transcription_logger = TranscriptionLogger()
|
||||
resources = AppResources(
|
||||
tool_call_logger=tool_call_logger,
|
||||
transcription_logger=transcription_logger,
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
@@ -223,7 +289,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
tool_resources=resources,
|
||||
app_resources=resources,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
@@ -246,6 +312,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
# The session has ended; read whatever state the handlers built up.
|
||||
logger.info(f"Tool calls logged during session:\n{tool_call_logger.dump()}")
|
||||
logger.info(f"Transcriptions logged during session:\n{transcription_logger.dump()}")
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
@@ -42,7 +42,6 @@
|
||||
"src/pipecat/services/azure/stt.py",
|
||||
"src/pipecat/services/azure/tts.py",
|
||||
"src/pipecat/services/cartesia/stt.py",
|
||||
"src/pipecat/services/cartesia/tts.py",
|
||||
"src/pipecat/services/deepgram/flux/base.py",
|
||||
"src/pipecat/services/deepgram/flux/sagemaker/stt.py",
|
||||
"src/pipecat/services/deepgram/flux/stt.py",
|
||||
|
||||
@@ -14,6 +14,7 @@ including heartbeats, idle detection, and observer integration.
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import AsyncIterable, Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
@@ -193,6 +194,7 @@ class PipelineTask(BasePipelineTask):
|
||||
*,
|
||||
params: PipelineParams | None = None,
|
||||
additional_span_attributes: dict | None = None,
|
||||
app_resources: Any = None,
|
||||
cancel_on_idle_timeout: bool = True,
|
||||
cancel_timeout_secs: float = CANCEL_TIMEOUT_SECS,
|
||||
check_dangling_tasks: bool = True,
|
||||
@@ -216,6 +218,14 @@ class PipelineTask(BasePipelineTask):
|
||||
params: Configuration parameters for the pipeline.
|
||||
additional_span_attributes: Optional dictionary of attributes to propagate as
|
||||
OpenTelemetry conversation span attributes.
|
||||
app_resources: Optional application-defined bag of anything your
|
||||
application code may want to share across this session (DB
|
||||
handles, HTTP clients, etc.), passed by reference. Pipecat
|
||||
passes it through untouched and exposes it on the task itself
|
||||
as ``task.app_resources`` and passes it to tool handlers as
|
||||
``FunctionCallParams.app_resources``. The framework never
|
||||
copies or clears this object; the caller retains their handle
|
||||
and can read any mutations after the task finishes.
|
||||
cancel_on_idle_timeout: Whether the pipeline task should be cancelled if
|
||||
the idle timeout is reached.
|
||||
cancel_timeout_secs: Timeout (in seconds) to wait for cancellation to happen
|
||||
@@ -235,13 +245,24 @@ class PipelineTask(BasePipelineTask):
|
||||
rtvi_observer_params: The RTVI observer parameter to use if RTVI is enabled.
|
||||
rtvi_processor: The RTVI processor to add if RTVI is enabled.
|
||||
task_manager: Optional task manager for handling asyncio tasks.
|
||||
tool_resources: Optional application-defined bag of resources (DB handles,
|
||||
clients, state, etc.) passed by reference to every tool handler via
|
||||
``FunctionCallParams.tool_resources``. The framework never copies or
|
||||
clears this object; the caller retains their handle and can read any
|
||||
mutations after the task finishes.
|
||||
tool_resources: Deprecated alias for ``app_resources``.
|
||||
|
||||
.. deprecated:: 1.2.0
|
||||
Use ``app_resources`` instead. ``tool_resources`` will be
|
||||
removed in a future version.
|
||||
"""
|
||||
super().__init__()
|
||||
if tool_resources is not None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`PipelineTask(tool_resources=...)` is deprecated since 1.2.0, "
|
||||
"use `app_resources` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if app_resources is None:
|
||||
app_resources = tool_resources
|
||||
self._params = params or PipelineParams()
|
||||
self._additional_span_attributes = additional_span_attributes or {}
|
||||
self._cancel_on_idle_timeout = cancel_on_idle_timeout
|
||||
@@ -252,7 +273,7 @@ class PipelineTask(BasePipelineTask):
|
||||
self._enable_tracing = enable_tracing and is_tracing_available()
|
||||
self._enable_turn_tracking = enable_turn_tracking
|
||||
self._idle_timeout_secs = idle_timeout_secs
|
||||
self._tool_resources = tool_resources
|
||||
self._app_resources = app_resources
|
||||
observers = observers or []
|
||||
self._turn_tracking_observer: TurnTrackingObserver | None = None
|
||||
self._user_bot_latency_observer: UserBotLatencyObserver | None = None
|
||||
@@ -391,6 +412,21 @@ class PipelineTask(BasePipelineTask):
|
||||
"""
|
||||
return self._params
|
||||
|
||||
@property
|
||||
def app_resources(self) -> Any:
|
||||
"""Get the application-defined resources passed to this task.
|
||||
|
||||
This is the same object passed to the constructor as
|
||||
``app_resources``. Tool handlers can also access it via
|
||||
``FunctionCallParams.app_resources``. The framework returns the
|
||||
original reference; mutations are visible to all callers.
|
||||
|
||||
Returns:
|
||||
The application-defined resources, or ``None`` if none were
|
||||
passed.
|
||||
"""
|
||||
return self._app_resources
|
||||
|
||||
@property
|
||||
def pipeline(self) -> BasePipeline:
|
||||
"""Get the full pipeline managed by this pipeline task.
|
||||
@@ -730,7 +766,13 @@ class PipelineTask(BasePipelineTask):
|
||||
clock=self._clock,
|
||||
task_manager=self._task_manager,
|
||||
observer=self._observer,
|
||||
tool_resources=self._tool_resources,
|
||||
pipeline_task=self,
|
||||
# Populate the deprecated `tool_resources` field for backwards
|
||||
# compatibility with custom FrameProcessor subclasses whose
|
||||
# ``setup()`` overrides still read it. Reading the field emits a
|
||||
# DeprecationWarning; new code should read
|
||||
# ``setup.pipeline_task.app_resources`` instead.
|
||||
tool_resources=self._app_resources,
|
||||
)
|
||||
await self._pipeline.setup(setup)
|
||||
|
||||
|
||||
@@ -16,10 +16,12 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import traceback
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Optional,
|
||||
)
|
||||
@@ -47,6 +49,9 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
from pipecat.utils.frame_queue import FrameQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
|
||||
|
||||
class FrameDirection(Enum):
|
||||
"""Direction of frame flow in the processing pipeline.
|
||||
@@ -71,15 +76,45 @@ class FrameProcessorSetup:
|
||||
clock: The clock instance for timing operations.
|
||||
task_manager: The task manager for handling async operations.
|
||||
observer: Optional observer for monitoring frame processing events.
|
||||
tool_resources: Application-defined resources shared with processors
|
||||
for this pipeline run.
|
||||
pipeline_task: The :class:`PipelineTask` running this pipeline. Stored
|
||||
on each processor as ``self.pipeline_task`` so processors can
|
||||
reach task-scoped state (e.g. ``self.pipeline_task.app_resources``).
|
||||
tool_resources: Deprecated. :class:`PipelineTask` continues to populate
|
||||
this with ``app_resources`` so that custom :class:`FrameProcessor`
|
||||
subclasses whose ``setup()`` overrides read ``setup.tool_resources``
|
||||
keep working. New code should read
|
||||
``setup.pipeline_task.app_resources`` instead.
|
||||
|
||||
.. deprecated:: 1.2.0
|
||||
Reading this attribute emits a ``DeprecationWarning``. Read
|
||||
``setup.pipeline_task.app_resources`` instead.
|
||||
``tool_resources`` will be removed in a future version.
|
||||
"""
|
||||
|
||||
clock: BaseClock
|
||||
task_manager: BaseTaskManager
|
||||
observer: BaseObserver | None = None
|
||||
pipeline_task: PipelineTask | None = None
|
||||
tool_resources: Any = None
|
||||
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
# Warn when user code reads the deprecated ``tool_resources`` field.
|
||||
# Set is unaffected (goes through ``__setattr__``), so PipelineTask can
|
||||
# populate it for backwards compat without tripping the warning.
|
||||
if name == "tool_resources":
|
||||
value = object.__getattribute__(self, "tool_resources")
|
||||
if value is not None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FrameProcessorSetup.tool_resources` is deprecated since 1.2.0; "
|
||||
"read `setup.pipeline_task.app_resources` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return value
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
|
||||
class FrameProcessorQueue(asyncio.PriorityQueue):
|
||||
"""A priority queue for systems frames and other frames.
|
||||
@@ -188,6 +223,9 @@ class FrameProcessor(BaseObject):
|
||||
# Observer
|
||||
self._observer: BaseObserver | None = None
|
||||
|
||||
# Pipeline Task
|
||||
self._pipeline_task: PipelineTask | None = None
|
||||
|
||||
# Other properties
|
||||
self._enable_metrics = False
|
||||
self._enable_usage_metrics = False
|
||||
@@ -344,6 +382,22 @@ class FrameProcessor(BaseObject):
|
||||
raise Exception(f"{self} TaskManager is still not initialized.")
|
||||
return self._task_manager
|
||||
|
||||
@property
|
||||
def pipeline_task(self) -> PipelineTask | None:
|
||||
"""Get the :class:`PipelineTask` this processor is running in.
|
||||
|
||||
Provides access to task-scoped state from inside a processor — most
|
||||
notably ``self.pipeline_task.app_resources`` for the application's
|
||||
shared bag of resources (DB handles, clients, feature flags, etc.).
|
||||
|
||||
Returns:
|
||||
The :class:`PipelineTask` instance that set up this processor,
|
||||
or ``None`` if the processor has not yet been set up by one
|
||||
(for example, before the task has started, or when the processor
|
||||
was instantiated in isolation).
|
||||
"""
|
||||
return self._pipeline_task
|
||||
|
||||
def processors_with_metrics(self):
|
||||
"""Return processors that can generate metrics.
|
||||
|
||||
@@ -495,6 +549,7 @@ class FrameProcessor(BaseObject):
|
||||
self._clock = setup.clock
|
||||
self._task_manager = setup.task_manager
|
||||
self._observer = setup.observer
|
||||
self._pipeline_task = setup.pipeline_task
|
||||
|
||||
# Create processing tasks.
|
||||
self.__create_input_task()
|
||||
|
||||
@@ -232,12 +232,13 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str | None = None,
|
||||
cartesia_version: str = "2025-04-16",
|
||||
cartesia_version: str = "2026-03-01",
|
||||
url: str = "wss://api.cartesia.ai/tts/websocket",
|
||||
model: str | None = None,
|
||||
sample_rate: int | None = None,
|
||||
encoding: str = "pcm_s16le",
|
||||
container: str = "raw",
|
||||
max_buffer_delay_ms: int | None = None,
|
||||
params: InputParams | None = None,
|
||||
settings: Settings | None = None,
|
||||
text_aggregation_mode: TextAggregationMode | None = None,
|
||||
@@ -263,6 +264,12 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
sample_rate: Audio sample rate. If None, uses default.
|
||||
encoding: Audio encoding format.
|
||||
container: Audio container format.
|
||||
max_buffer_delay_ms: Server-side buffering window before generation
|
||||
starts. ``0`` disables server buffering (custom buffering); any
|
||||
value in (0, 5000] enables managed buffering. If ``None``,
|
||||
derived from ``text_aggregation_mode``: ``0`` for ``SENTENCE``
|
||||
(avoids stacking client and server buffering), unset for
|
||||
``TOKEN`` (uses Cartesia's 3000ms default).
|
||||
params: Additional input parameters for voice customization.
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
@@ -353,6 +360,15 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
self._output_encoding = encoding
|
||||
self._output_sample_rate = 0 # Set in start() from self.sample_rate
|
||||
|
||||
# Cartesia warns against the "middle ground" of client-side sentence
|
||||
# aggregation plus the server's default 3000ms buffer. When the user
|
||||
# doesn't pick a value, send 0 in SENTENCE mode (custom buffering) and
|
||||
# leave it unset in TOKEN mode so the server default applies (managed
|
||||
# buffering).
|
||||
if max_buffer_delay_ms is None and not self._is_streaming_tokens:
|
||||
max_buffer_delay_ms = 0
|
||||
self._max_buffer_delay_ms = max_buffer_delay_ms
|
||||
|
||||
self._receive_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
@@ -375,22 +391,27 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
return language_to_cartesia_language(language)
|
||||
|
||||
# A set of Cartesia-specific helpers for text transformations
|
||||
@staticmethod
|
||||
def SPELL(text: str) -> str:
|
||||
"""Wrap text in Cartesia spell tag."""
|
||||
return f"<spell>{text}</spell>"
|
||||
|
||||
@staticmethod
|
||||
def EMOTION_TAG(emotion: CartesiaEmotion) -> str:
|
||||
"""Convenience method to create an emotion tag."""
|
||||
return f'<emotion value="{emotion}" />'
|
||||
|
||||
@staticmethod
|
||||
def PAUSE_TAG(seconds: float) -> str:
|
||||
"""Convenience method to create a pause tag."""
|
||||
return f'<break time="{seconds}s" />'
|
||||
|
||||
@staticmethod
|
||||
def VOLUME_TAG(volume: float) -> str:
|
||||
"""Convenience method to create a volume tag."""
|
||||
return f'<volume ratio="{volume}" />'
|
||||
|
||||
@staticmethod
|
||||
def SPEED_TAG(speed: float) -> str:
|
||||
"""Convenience method to create a speed tag."""
|
||||
return f'<speed ratio="{speed}" />'
|
||||
@@ -466,9 +487,12 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
"sample_rate": self._output_sample_rate,
|
||||
},
|
||||
"add_timestamps": add_timestamps,
|
||||
"use_original_timestamps": False if self._settings.model == "sonic" else True,
|
||||
"use_normalized_timestamps": True,
|
||||
}
|
||||
|
||||
if self._max_buffer_delay_ms is not None:
|
||||
msg["max_buffer_delay_ms"] = self._max_buffer_delay_ms
|
||||
|
||||
if self._settings.language:
|
||||
msg["language"] = self._settings.language
|
||||
|
||||
@@ -647,6 +671,13 @@ class CartesiaTTSService(WebsocketTTSService):
|
||||
await self.stop_all_metrics()
|
||||
await self.push_error(error_msg=f"Error: {msg}")
|
||||
self.reset_active_audio_context()
|
||||
elif msg["type"] == "flush_done":
|
||||
# Cartesia emits flush_done as a per-transcript boundary marker
|
||||
# within a context (e.g. when max_buffer_delay_ms=0 causes the
|
||||
# server to flush each submission). We don't need it: each turn
|
||||
# already has its own context_id and audio chunks are tagged
|
||||
# with it. Acknowledge silently.
|
||||
pass
|
||||
else:
|
||||
await self.push_error(error_msg=f"Error, unknown message type: {msg}")
|
||||
|
||||
@@ -885,6 +916,9 @@ class CartesiaHttpTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if self._session is None:
|
||||
raise RuntimeError("HTTP session is not initialized; call start() before run_tts()")
|
||||
|
||||
voice_config = {"mode": "id", "id": self._settings.voice}
|
||||
|
||||
output_format = {
|
||||
@@ -921,8 +955,10 @@ class CartesiaHttpTTSService(TTSService):
|
||||
async with self._session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
yield ErrorFrame(error=f"Cartesia API error: {error_text}")
|
||||
raise Exception(f"Cartesia API returned status {response.status}: {error_text}")
|
||||
yield ErrorFrame(
|
||||
error=f"Cartesia API error (status {response.status}): {error_text}"
|
||||
)
|
||||
return
|
||||
|
||||
audio_data = await response.read()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ for generating speech from text using various voice models.
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -27,7 +27,7 @@ from pipecat.frames.frames import (
|
||||
TTSAudioRawFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.settings import TTSSettings
|
||||
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, is_given
|
||||
from pipecat.services.tts_service import TTSService, WebsocketTTSService
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -44,9 +44,13 @@ except ModuleNotFoundError as e:
|
||||
|
||||
@dataclass
|
||||
class DeepgramTTSSettings(TTSSettings):
|
||||
"""Settings for DeepgramTTSService and DeepgramHttpTTSService."""
|
||||
"""Settings for DeepgramTTSService and DeepgramHttpTTSService.
|
||||
|
||||
pass
|
||||
Parameters:
|
||||
mip_opt_out: Opt out of Deepgram's Model Improvement Program.
|
||||
"""
|
||||
|
||||
mip_opt_out: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
|
||||
class DeepgramTTSService(WebsocketTTSService):
|
||||
@@ -102,6 +106,7 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
model=None,
|
||||
voice="aura-2-helena-en",
|
||||
language=None,
|
||||
mip_opt_out=None,
|
||||
)
|
||||
|
||||
# 2. Apply direct init arg overrides (deprecated)
|
||||
@@ -221,6 +226,8 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
params.append(f"model={self._settings.voice}")
|
||||
params.append(f"encoding={self._encoding}")
|
||||
params.append(f"sample_rate={self.sample_rate}")
|
||||
if is_given(self._settings.mip_opt_out) and self._settings.mip_opt_out is not None:
|
||||
params.append(f"mip_opt_out={'true' if self._settings.mip_opt_out else 'false'}")
|
||||
|
||||
url = f"{self._base_url}/v1/speak?{'&'.join(params)}"
|
||||
|
||||
@@ -405,6 +412,7 @@ class DeepgramHttpTTSService(TTSService):
|
||||
model=None,
|
||||
voice="aura-2-helena-en",
|
||||
language=None,
|
||||
mip_opt_out=None,
|
||||
)
|
||||
|
||||
# 2. Apply direct init arg overrides (deprecated)
|
||||
@@ -464,6 +472,8 @@ class DeepgramHttpTTSService(TTSService):
|
||||
"sample_rate": self.sample_rate,
|
||||
"container": "none",
|
||||
}
|
||||
if is_given(self._settings.mip_opt_out) and self._settings.mip_opt_out is not None:
|
||||
params["mip_opt_out"] = "true" if self._settings.mip_opt_out else "false"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
|
||||
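Review note: the two `mip_opt_out` hunks above apply the same tri-state rule. The parameter is sent only when the setting was explicitly given and is not `None`, and the boolean is serialized as the lowercase string `true` or `false`. A minimal, self-contained sketch of that rule follows. The names `SketchSettings` and `build_query_params` are hypothetical, and `_NotGiven`, `NOT_GIVEN`, and `is_given` are simplified stand-ins for Pipecat's own helpers rather than their real implementations.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


class _NotGiven:
    """Stand-in sentinel type meaning "the caller never set this field"."""

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = _NotGiven()


def is_given(value: Any) -> bool:
    """Return True unless the value is the NOT_GIVEN sentinel."""
    return not isinstance(value, _NotGiven)


@dataclass
class SketchSettings:
    # Three states: NOT_GIVEN (never set), None (explicitly unset), or a bool.
    mip_opt_out: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)


def build_query_params(settings: SketchSettings) -> dict[str, str]:
    """Build query parameters, adding mip_opt_out only for an explicit bool."""
    params = {"container": "none"}
    if is_given(settings.mip_opt_out) and settings.mip_opt_out is not None:
        params["mip_opt_out"] = "true" if settings.mip_opt_out else "false"
    return params
```

Checking both conditions keeps the sentinel from leaking into the URL as a string, which is the case the "explicit empty settings" tests in `tests/test_deepgram_tts.py` guard against.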
@@ -51,7 +51,7 @@ from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMSpecificMessage,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import LLMSettings, assert_given
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
@@ -107,9 +107,10 @@ class FunctionCallParams:
|
||||
For async function calls (``cancel_on_interruption=False``), call
|
||||
it with ``properties=FunctionCallResultProperties(is_final=False)``
|
||||
to push intermediate updates before the final result.
|
||||
tool_resources: Application-defined bag of resources (DB handles, clients,
|
||||
state, etc.) shared across tool calls for the pipeline session. Set
|
||||
via ``PipelineTask(..., tool_resources=...)`` and passed by reference.
|
||||
app_resources: The application-defined resources passed to
|
||||
``PipelineTask(..., app_resources=...)``. Same object — passed by
|
||||
reference, not a copy. Use it to share DB handles, clients, state,
|
||||
feature flags, etc. across all of a session's tool handlers.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
@@ -118,7 +119,25 @@ class FunctionCallParams:
|
||||
llm: LLMService
|
||||
context: LLMContext
|
||||
result_callback: FunctionCallResultCallback
|
||||
tool_resources: Any = None
|
||||
app_resources: Any = None
|
||||
|
||||
@property
|
||||
def tool_resources(self) -> Any:
|
||||
"""Deprecated alias for :attr:`app_resources`.
|
||||
|
||||
.. deprecated:: 1.2.0
|
||||
Use :attr:`app_resources` instead. ``tool_resources`` will be
|
||||
removed in a future version.
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FunctionCallParams.tool_resources` is deprecated since 1.2.0, "
|
||||
"use `app_resources` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.app_resources
|
||||
|
||||
|
||||
@dataclass
|
||||
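Review note: the `FunctionCallParams` hunk above replaces the `tool_resources` field with a read-only property that warns and returns `app_resources`. The pattern can be sketched in isolation as below. `SketchParams` is a hypothetical stand-in, not the real class, and the sketch omits the `warnings.catch_warnings()` / `simplefilter("always")` wrapper that the diff uses to force the warning to be shown.

```python
import warnings
from dataclasses import dataclass
from typing import Any


@dataclass
class SketchParams:
    function_name: str
    app_resources: Any = None

    @property
    def tool_resources(self) -> Any:
        """Deprecated alias for ``app_resources``."""
        warnings.warn(
            "`tool_resources` is deprecated, use `app_resources` instead.",
            DeprecationWarning,
            # stacklevel=2 attributes the warning to the caller's line,
            # not to this property body.
            stacklevel=2,
        )
        return self.app_resources
```

Because the property carries no annotation, `dataclass` does not treat it as a field, so `tool_resources` can no longer be passed to the constructor in this sketch. Reads still work and return the same object that `app_resources` holds.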
@@ -256,7 +275,6 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
self._sequential_runner_task: asyncio.Task | None = None
|
||||
self._skip_tts: bool | None = None
|
||||
self._summary_task: asyncio.Task | None = None
|
||||
self._tool_resources: Any = None
|
||||
|
||||
self._register_event_handler("on_function_calls_started")
|
||||
self._register_event_handler("on_function_calls_cancelled")
|
||||
@@ -303,15 +321,6 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
"""
|
||||
raise NotImplementedError(f"run_inference() not supported by {self.__class__.__name__}")
|
||||
|
||||
async def setup(self, setup: FrameProcessorSetup):
|
||||
"""Set up the LLM service.
|
||||
|
||||
Args:
|
||||
setup: The frame processor setup data.
|
||||
"""
|
||||
await super().setup(setup)
|
||||
self._tool_resources = setup.tool_resources
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the LLM service.
|
||||
|
||||
@@ -882,6 +891,9 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
# it starts would leave the coroutine in a "never awaited" state.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# _pipeline_task may be unset when the service is driven without a PipelineTask.
|
||||
app_resources = self._pipeline_task.app_resources if self._pipeline_task else None
|
||||
|
||||
try:
|
||||
if isinstance(item.handler, DirectFunctionWrapper):
|
||||
# Handler is a DirectFunctionWrapper
|
||||
@@ -894,7 +906,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
llm=self,
|
||||
context=runner_item.context,
|
||||
result_callback=function_call_result_callback,
|
||||
tool_resources=self._tool_resources,
|
||||
app_resources=app_resources,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -906,7 +918,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
|
||||
llm=self,
|
||||
context=runner_item.context,
|
||||
result_callback=function_call_result_callback,
|
||||
tool_resources=self._tool_resources,
|
||||
app_resources=app_resources,
|
||||
)
|
||||
await item.handler(params)
|
||||
except Exception as e:
|
||||
|
||||
@@ -18,10 +18,6 @@ class AlwaysUserMuteStrategy(BaseUserMuteStrategy):
|
||||
super().__init__()
|
||||
self._bot_speaking = False
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
self._bot_speaking = False
|
||||
|
||||
async def process_frame(self, frame: Frame) -> bool:
|
||||
"""Process an incoming frame.
|
||||
|
||||
|
||||
@@ -51,10 +51,6 @@ class BaseUserMuteStrategy(BaseObject):
|
||||
"""Cleanup the strategy."""
|
||||
pass
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame) -> bool:
|
||||
"""Process an incoming frame.
|
||||
|
||||
|
||||
@@ -29,11 +29,6 @@ class FirstSpeechUserMuteStrategy(BaseUserMuteStrategy):
|
||||
self._bot_speaking = False
|
||||
self._first_speech_handled = False
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
self._bot_speaking = False
|
||||
self._first_speech_handled = False
|
||||
|
||||
async def process_frame(self, frame: Frame) -> bool:
|
||||
"""Process an incoming frame.
|
||||
|
||||
|
||||
@@ -30,10 +30,6 @@ class FunctionCallUserMuteStrategy(BaseUserMuteStrategy):
|
||||
super().__init__()
|
||||
self._function_call_in_progress: set[str] = set()
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
self._function_call_in_progress = set()
|
||||
|
||||
async def process_frame(self, frame: Frame) -> bool:
|
||||
"""Process an incoming frame.
|
||||
|
||||
|
||||
@@ -30,10 +30,6 @@ class MuteUntilFirstBotCompleteUserMuteStrategy(BaseUserMuteStrategy):
|
||||
super().__init__()
|
||||
self._first_speech_handled = False
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the strategy to its initial state."""
|
||||
self._first_speech_handled = False
|
||||
|
||||
async def process_frame(self, frame: Frame) -> bool:
|
||||
"""Process an incoming frame.
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.turns.user_start import (
|
||||
BaseUserTurnStartStrategy,
|
||||
ExternalUserTurnStartStrategy,
|
||||
@@ -44,6 +43,8 @@ def default_user_turn_stop_strategies() -> list[BaseUserTurnStopStrategy]:
|
||||
Returns ``[TurnAnalyzerUserTurnStopStrategy(LocalSmartTurnAnalyzerV3)]``.
|
||||
Useful when building a custom strategy list that extends the defaults.
|
||||
"""
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
|
||||
return [TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
|
||||
|
||||
|
||||
326
tests/test_app_resources.py
Normal file
@@ -0,0 +1,326 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from dataclasses import dataclass, field
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.adapters.schemas.direct_function import DirectFunctionWrapper
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.frames.frames import EndFrame, Frame, StartFrame
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.services.llm_service import (
|
||||
FunctionCallParams,
|
||||
FunctionCallRegistryItem,
|
||||
FunctionCallRunnerItem,
|
||||
LLMService,
|
||||
)
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Resources:
|
||||
user_name: str
|
||||
db: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _complete_llm_settings() -> LLMSettings:
|
||||
"""Return an LLMSettings with every field set so test_service_init's
|
||||
auto-discovered ``_MockLLMService`` doesn't fail its NOT_GIVEN check."""
|
||||
return LLMSettings(
|
||||
model=None,
|
||||
system_instruction=None,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
seed=None,
|
||||
filter_incomplete_user_turns=None,
|
||||
user_turn_completion_config=None,
|
||||
)
|
||||
|
||||
|
||||
class _MockLLMService(LLMService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(settings=_complete_llm_settings(), **kwargs)
|
||||
|
||||
|
||||
class TestFunctionCallParamsAppResources(unittest.TestCase):
|
||||
def test_default_is_none(self):
|
||||
params = FunctionCallParams(
|
||||
function_name="f",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
llm=None, # type: ignore[arg-type]
|
||||
context=LLMContext(),
|
||||
result_callback=AsyncMock(),
|
||||
)
|
||||
self.assertIsNone(params.app_resources)
|
||||
|
||||
def test_holds_reference(self):
|
||||
resources = _Resources(user_name="John")
|
||||
params = FunctionCallParams(
|
||||
function_name="f",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
llm=None, # type: ignore[arg-type]
|
||||
context=LLMContext(),
|
||||
result_callback=AsyncMock(),
|
||||
app_resources=resources,
|
||||
)
|
||||
self.assertIs(params.app_resources, resources)
|
||||
|
||||
def test_tool_resources_property_warns_and_aliases_app_resources(self):
|
||||
resources = _Resources(user_name="John")
|
||||
params = FunctionCallParams(
|
||||
function_name="f",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
llm=None, # type: ignore[arg-type]
|
||||
context=LLMContext(),
|
||||
result_callback=AsyncMock(),
|
||||
app_resources=resources,
|
||||
)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
value = params.tool_resources
|
||||
self.assertIs(value, resources)
|
||||
|
||||
|
||||
class TestLLMServiceFunctionCallReadsAppResources(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_function_call_params_receives_app_resources(self):
|
||||
service = _MockLLMService()
|
||||
resources = _Resources(user_name="John")
|
||||
# Stub the pipeline task with just the bit LLMService reads.
|
||||
service._pipeline_task = SimpleNamespace(app_resources=resources) # type: ignore[assignment]
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def handler(params: FunctionCallParams):
|
||||
captured["params"] = params
|
||||
params.app_resources.db["hit"] = True
|
||||
await params.result_callback({"ok": True})
|
||||
|
||||
service._functions["lookup"] = FunctionCallRegistryItem(
|
||||
function_name="lookup",
|
||||
handler=handler,
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
service.broadcast_frame = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=service._functions["lookup"],
|
||||
function_name="lookup",
|
||||
tool_call_id="call-1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
await service._run_function_call(runner_item)
|
||||
|
||||
self.assertIs(captured["params"].app_resources, resources)
|
||||
self.assertTrue(resources.db["hit"])
|
||||
|
||||
async def test_direct_function_params_receives_app_resources(self):
|
||||
service = _MockLLMService()
|
||||
resources = _Resources(user_name="John")
|
||||
service._pipeline_task = SimpleNamespace(app_resources=resources) # type: ignore[assignment]
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def lookup(params: FunctionCallParams):
|
||||
captured["params"] = params
|
||||
|
||||
wrapper = DirectFunctionWrapper(lookup)
|
||||
service._functions[wrapper.name] = FunctionCallRegistryItem(
|
||||
function_name=wrapper.name,
|
||||
handler=wrapper,
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
service.broadcast_frame = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=service._functions[wrapper.name],
|
||||
function_name=wrapper.name,
|
||||
tool_call_id="call-1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
await service._run_function_call(runner_item)
|
||||
|
||||
self.assertIs(captured["params"].app_resources, resources)
|
||||
|
||||
async def test_app_resources_none_when_pipeline_task_unset(self):
|
||||
service = _MockLLMService()
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def handler(params: FunctionCallParams):
|
||||
captured["params"] = params
|
||||
await params.result_callback({"ok": True})
|
||||
|
||||
service._functions["lookup"] = FunctionCallRegistryItem(
|
||||
function_name="lookup",
|
||||
handler=handler,
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
service.broadcast_frame = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=service._functions["lookup"],
|
||||
function_name="lookup",
|
||||
tool_call_id="call-1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
await service._run_function_call(runner_item)
|
||||
|
||||
self.assertIsNone(captured["params"].app_resources)
|
||||
|
||||
async def test_frame_processor_setup_tool_resources_warns_on_read(self):
|
||||
# ``FrameProcessorSetup.tool_resources`` is retained for backwards
|
||||
# compatibility with custom FrameProcessors whose ``setup()`` overrides
|
||||
# still read it. The field is populated, but reading it warns.
|
||||
task_manager = TaskManager()
|
||||
task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
resources = _Resources(user_name="John")
|
||||
|
||||
# Construction itself does not warn — only reads do.
|
||||
setup = FrameProcessorSetup(
|
||||
clock=SystemClock(),
|
||||
task_manager=task_manager,
|
||||
tool_resources=resources,
|
||||
)
|
||||
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
value = setup.tool_resources
|
||||
self.assertIs(value, resources)
|
||||
|
||||
|
||||
class TestPipelineTaskAppResources(unittest.TestCase):
|
||||
def test_getter_returns_constructor_value(self):
|
||||
resources = _Resources(user_name="John")
|
||||
task = PipelineTask(Pipeline([]), app_resources=resources)
|
||||
self.assertIs(task.app_resources, resources)
|
||||
|
||||
def test_default_app_resources_is_none(self):
|
||||
task = PipelineTask(Pipeline([]))
|
||||
self.assertIsNone(task.app_resources)
|
||||
|
||||
def test_tool_resources_kwarg_warns_and_aliases_app_resources(self):
|
||||
resources = _Resources(user_name="John")
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
task = PipelineTask(Pipeline([]), tool_resources=resources)
|
||||
self.assertIs(task.app_resources, resources)
|
||||
|
||||
def test_app_resources_takes_precedence_over_tool_resources(self):
|
||||
new = _Resources(user_name="new")
|
||||
old = _Resources(user_name="old")
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
task = PipelineTask(Pipeline([]), app_resources=new, tool_resources=old)
|
||||
self.assertIs(task.app_resources, new)
|
||||
|
||||
|
||||
class _RecordingProcessor(FrameProcessor):
|
||||
"""Records the pipeline_task it sees once StartFrame reaches it."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observed_task: Any = None
|
||||
self.observed_app_resources: Any = None
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, StartFrame):
|
||||
# setup() runs before any frame reaches us, so pipeline_task is wired up.
|
||||
assert self.pipeline_task is not None
|
||||
self.observed_task = self.pipeline_task
|
||||
self.observed_app_resources = self.pipeline_task.app_resources
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class _LegacyToolResourcesReader(FrameProcessor):
|
||||
"""Custom processor that reads the deprecated ``setup.tool_resources``.
|
||||
|
||||
Models a previously-written user FrameProcessor whose ``setup()``
|
||||
override hasn't been migrated yet. The field is populated by
|
||||
``PipelineTask`` for backwards compatibility; reading it emits a
|
||||
DeprecationWarning.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.captured_tool_resources: Any = None
|
||||
|
||||
async def setup(self, setup):
|
||||
await super().setup(setup)
|
||||
self.captured_tool_resources = setup.tool_resources
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
# Forward all frames so the EndFrame reaches the pipeline sink and
|
||||
# ``task.run()`` can return cleanly.
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class TestFrameProcessorSetupToolResourcesBackwardsCompat(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_legacy_processor_receives_value_via_app_resources(self):
|
||||
resources = _Resources(user_name="John")
|
||||
legacy = _LegacyToolResourcesReader()
|
||||
pipeline = Pipeline([legacy])
|
||||
task = PipelineTask(pipeline, app_resources=resources)
|
||||
|
||||
await task.queue_frame(EndFrame())
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
|
||||
self.assertIs(legacy.captured_tool_resources, resources)
|
||||
|
||||
async def test_legacy_processor_receives_value_via_deprecated_tool_resources_kwarg(
|
||||
self,
|
||||
):
|
||||
# If the user is still constructing PipelineTask with the deprecated
|
||||
# ``tool_resources`` kwarg (and hasn't migrated to ``app_resources``),
|
||||
# legacy processors must still see the value too.
|
||||
resources = _Resources(user_name="John")
|
||||
legacy = _LegacyToolResourcesReader()
|
||||
pipeline = Pipeline([legacy])
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
task = PipelineTask(pipeline, tool_resources=resources)
|
||||
|
||||
await task.queue_frame(EndFrame())
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
|
||||
self.assertIs(legacy.captured_tool_resources, resources)
|
||||
|
||||
|
||||
class TestFrameProcessorPipelineTaskAccess(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_processor_can_reach_pipeline_task_and_app_resources(self):
|
||||
resources = _Resources(user_name="John")
|
||||
recorder = _RecordingProcessor()
|
||||
pipeline = Pipeline([recorder])
|
||||
task = PipelineTask(pipeline, app_resources=resources)
|
||||
|
||||
await task.queue_frame(EndFrame())
|
||||
await task.run(PipelineTaskParams(loop=asyncio.get_event_loop()))
|
||||
|
||||
self.assertIs(recorder.observed_task, task)
|
||||
self.assertIs(recorder.observed_app_resources, resources)
|
||||
|
||||
def test_pipeline_task_returns_none_when_not_set_up(self):
|
||||
recorder = _RecordingProcessor()
|
||||
self.assertIsNone(recorder.pipeline_task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
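Review note: `TestPipelineTaskAppResources` above pins down three behaviors for the constructor arguments: `tool_resources` still works but warns, `app_resources` wins when both are passed, and the default is `None`. One way to satisfy those tests is sketched below. `resolve_app_resources` is a hypothetical helper written for illustration. The actual `PipelineTask` implementation is not shown in this part of the diff and may differ.

```python
import warnings
from typing import Any


def resolve_app_resources(app_resources: Any = None, tool_resources: Any = None) -> Any:
    """Pick the effective resources object, preferring the new argument."""
    if tool_resources is not None:
        # The deprecated argument warns whenever it is supplied,
        # even if app_resources ends up taking precedence.
        warnings.warn(
            "`tool_resources` is deprecated, use `app_resources` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    return app_resources if app_resources is not None else tool_resources
```

The function returns the object it was given rather than a copy, which matches the `assertIs` checks in the tests.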
173
tests/test_deepgram_tts.py
Normal file
@@ -0,0 +1,173 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from pipecat.services.deepgram.tts import DeepgramHttpTTSService, DeepgramTTSService
|
||||
|
||||
|
||||
def _make_ws_service(**settings_kwargs) -> DeepgramTTSService:
|
||||
settings = DeepgramTTSService.Settings(**settings_kwargs) if settings_kwargs else None
|
||||
service = DeepgramTTSService(api_key="test-key", settings=settings)
|
||||
# Bypass start() lifecycle: sample_rate is the only field _connect_websocket reads.
|
||||
service._sample_rate = 16000
|
||||
return service
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_mip_opt_out_true_in_url():
|
||||
service = _make_ws_service(mip_opt_out=True)
|
||||
|
||||
fake_ws = MagicMock()
|
||||
fake_ws.response.headers = {}
|
||||
|
||||
with patch(
|
||||
"pipecat.services.deepgram.tts.websocket_connect",
|
||||
new=AsyncMock(return_value=fake_ws),
|
||||
) as mock_connect:
|
||||
await service._connect_websocket()
|
||||
|
||||
url = mock_connect.call_args.args[0]
|
||||
assert "mip_opt_out=true" in url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_mip_opt_out_false_in_url():
|
||||
service = _make_ws_service(mip_opt_out=False)
|
||||
|
||||
fake_ws = MagicMock()
|
||||
fake_ws.response.headers = {}
|
||||
|
||||
with patch(
|
||||
"pipecat.services.deepgram.tts.websocket_connect",
|
||||
new=AsyncMock(return_value=fake_ws),
|
||||
) as mock_connect:
|
||||
await service._connect_websocket()
|
||||
|
||||
url = mock_connect.call_args.args[0]
|
||||
assert "mip_opt_out=false" in url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_mip_opt_out_default_absent():
|
||||
service = _make_ws_service()
|
||||
|
||||
fake_ws = MagicMock()
|
||||
fake_ws.response.headers = {}
|
||||
|
||||
with patch(
|
||||
"pipecat.services.deepgram.tts.websocket_connect",
|
||||
new=AsyncMock(return_value=fake_ws),
|
||||
) as mock_connect:
|
||||
await service._connect_websocket()
|
||||
|
||||
url = mock_connect.call_args.args[0]
|
||||
assert "mip_opt_out" not in url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_explicit_empty_settings_omits_mip_opt_out():
|
||||
"""Explicit Settings() with no kwargs must not leak the NOT_GIVEN sentinel."""
|
||||
service = DeepgramTTSService(api_key="test-key", settings=DeepgramTTSService.Settings())
|
||||
# Bypass start() lifecycle: sample_rate is the only field _connect_websocket reads.
|
||||
service._sample_rate = 16000
|
||||
|
||||
fake_ws = MagicMock()
|
||||
fake_ws.response.headers = {}
|
||||
|
||||
with patch(
|
||||
"pipecat.services.deepgram.tts.websocket_connect",
|
||||
new=AsyncMock(return_value=fake_ws),
|
||||
) as mock_connect:
|
||||
await service._connect_websocket()
|
||||
|
||||
url = mock_connect.call_args.args[0]
|
||||
assert "mip_opt_out" not in url
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self):
|
||||
self.status = 200
|
||||
self.content = MagicMock()
|
||||
|
||||
async def _empty_iter(_chunk_size):
|
||||
return
|
||||
yield # unreachable; makes this an async generator
|
||||
|
||||
self.content.iter_chunked = _empty_iter
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def _make_http_service(**settings_kwargs) -> DeepgramHttpTTSService:
|
||||
settings = DeepgramHttpTTSService.Settings(**settings_kwargs) if settings_kwargs else None
|
||||
session = MagicMock(spec=aiohttp.ClientSession)
|
||||
service = DeepgramHttpTTSService(api_key="test-key", aiohttp_session=session, settings=settings)
|
||||
# Bypass start() lifecycle: sample_rate is the only field run_tts reads.
|
||||
service._sample_rate = 16000
|
||||
service._session.post = MagicMock(return_value=_FakeResponse())
|
||||
return service
|
||||
|
||||
|
||||
async def _drain(gen):
|
||||
async for _ in gen:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_mip_opt_out_true_in_params():
|
||||
service = _make_http_service(mip_opt_out=True)
|
||||
|
||||
await _drain(service.run_tts("hello", "ctx"))
|
||||
|
||||
params = service._session.post.call_args.kwargs["params"]
|
||||
assert params["mip_opt_out"] == "true"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_mip_opt_out_false_in_params():
|
||||
service = _make_http_service(mip_opt_out=False)
|
||||
|
||||
await _drain(service.run_tts("hello", "ctx"))
|
||||
|
||||
params = service._session.post.call_args.kwargs["params"]
|
||||
assert params["mip_opt_out"] == "false"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_mip_opt_out_default_absent():
|
||||
service = _make_http_service()
|
||||
|
||||
await _drain(service.run_tts("hello", "ctx"))
|
||||
|
||||
params = service._session.post.call_args.kwargs["params"]
|
||||
assert "mip_opt_out" not in params
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_explicit_empty_settings_omits_mip_opt_out():
|
||||
"""Explicit Settings() with no kwargs must not leak the NOT_GIVEN sentinel."""
|
||||
session = MagicMock(spec=aiohttp.ClientSession)
|
||||
service = DeepgramHttpTTSService(
|
||||
api_key="test-key",
|
||||
aiohttp_session=session,
|
||||
settings=DeepgramHttpTTSService.Settings(),
|
||||
)
|
||||
# Bypass start() lifecycle: sample_rate is the only field run_tts reads.
|
||||
service._sample_rate = 16000
|
||||
service._session.post = MagicMock(return_value=_FakeResponse())
|
||||
|
||||
await _drain(service.run_tts("hello", "ctx"))
|
||||
|
||||
params = service._session.post.call_args.kwargs["params"]
|
||||
assert "mip_opt_out" not in params
|
||||
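Review note: `_FakeResponse` in the test file above relies on two small idioms: a bare `return` followed by an unreachable `yield` to get an async generator that produces nothing, and `__aenter__` / `__aexit__` so the fake can be used with `async with`. A standalone sketch of both follows. `FakeResponse` and `collect_chunks` are illustrative names, and the sketch defines `iter_chunked` as a method directly instead of attaching it to a `MagicMock` as the test does.

```python
import asyncio


class FakeResponse:
    status = 200

    async def iter_chunked(self, _chunk_size: int):
        return
        yield  # unreachable; its presence makes this an async generator

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Returning False lets any exception from the block propagate.
        return False


async def collect_chunks() -> list:
    """Drain the fake response the way an HTTP client would."""
    chunks = []
    async with FakeResponse() as response:
        async for chunk in response.iter_chunked(1024):
            chunks.append(chunk)
    return chunks
```

Without the `yield`, `iter_chunked` would be a plain coroutine function and `async for` over its result would raise a `TypeError`.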
@@ -12,10 +12,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Mock package version check before importing pipecat (development mode)
|
||||
_version_patcher = patch("importlib.metadata.version", return_value="0.0.0-dev")
|
||||
_version_patcher.start()
|
||||
|
||||
# Mock krisp_audio before any pipecat import that loads krisp_instance / VIVA IP strategy
|
||||
mock_krisp_audio = MagicMock()
|
||||
mock_krisp_audio.SamplingRate.Sr8000Hz = 8000
|
||||
@@ -37,18 +33,22 @@ sys.modules["pipecat_ai_krisp"] = mock_pipecat_krisp
|
||||
sys.modules["pipecat_ai_krisp.audio"] = MagicMock()
|
||||
sys.modules["pipecat_ai_krisp.audio.krisp_processor"] = MagicMock()
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.types import ProcessFrameResult
|
||||
from pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy import (
|
||||
KrispVivaIPUserTurnStartStrategy,
|
||||
)
|
||||
# The version patch is scoped to just the import so it doesn't leak across the
|
||||
# test session and corrupt importlib.metadata.version for other tests
|
||||
# (e.g. transformers' import-time dependency checks).
|
||||
with patch("importlib.metadata.version", return_value="0.0.0-dev"):
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
TranscriptionFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.turns.types import ProcessFrameResult
|
||||
from pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy import (
|
||||
KrispVivaIPUserTurnStartStrategy,
|
||||
)
|
||||
|
||||
STRATEGY_MODULE = "pipecat.turns.user_start.krisp_viva_ip_user_turn_start_strategy"
|
||||
|
||||
|
||||
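Review note: the Krisp test hunks replace a module-level `patch(...).start()` that is never stopped with a `with patch(...)` block around the imports. The difference is that the context manager restores the original attribute on exit, so the fake version string cannot leak into other tests in the same session. A minimal demonstration using only the standard library follows. The package name passed to `version()` is arbitrary, since the patched function ignores its argument.

```python
import importlib.metadata
from unittest.mock import patch

original_version = importlib.metadata.version

# Inside the block, every lookup returns the fake development version.
with patch("importlib.metadata.version", return_value="0.0.0-dev"):
    patched_result = importlib.metadata.version("any-package-name")

# On exit the context manager puts the real function back.
restored = importlib.metadata.version is original_version
```

With the earlier `_version_patcher.start()` form, the patch stays active until `stop()` is called, which the removed code never did.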
@@ -11,11 +11,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Mock package version check before importing pipecat
|
||||
# This allows tests to run in development mode without installed package
|
||||
_version_patcher = patch("importlib.metadata.version", return_value="0.0.0-dev")
|
||||
_version_patcher.start()
|
||||
|
||||
# Mock krisp_audio module BEFORE any pipecat imports
|
||||
# This allows tests to run without krisp_audio installed
|
||||
mock_krisp_audio = MagicMock()
|
||||
@@ -48,12 +43,15 @@ sys.modules["pipecat_ai_krisp"] = mock_pipecat_krisp
|
||||
sys.modules["pipecat_ai_krisp.audio"] = MagicMock()
|
||||
sys.modules["pipecat_ai_krisp.audio.krisp_processor"] = MagicMock()
|
||||
|
||||
# Now we can safely import
|
||||
from pipecat.audio.krisp_instance import (
|
||||
KRISP_SAMPLE_RATES,
|
||||
KrispVivaSDKManager,
|
||||
int_to_krisp_sample_rate,
|
||||
)
|
||||
# Now we can safely import. The version patch is scoped to just the import so
|
||||
# it doesn't leak across the test session and corrupt importlib.metadata.version
|
||||
# for other tests (e.g. transformers' import-time dependency checks).
|
||||
with patch("importlib.metadata.version", return_value="0.0.0-dev"):
|
||||
from pipecat.audio.krisp_instance import (
|
||||
KRISP_SAMPLE_RATES,
|
||||
KrispVivaSDKManager,
|
||||
int_to_krisp_sample_rate,
|
||||
)
|
||||
|
||||
|
||||
class TestKrispVivaSDKManager:
|
||||
|
||||
@@ -13,11 +13,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Mock package version check before importing pipecat
|
||||
# This allows tests to run in development mode without installed package
|
||||
_version_patcher = patch("importlib.metadata.version", return_value="0.0.0-dev")
|
||||
_version_patcher.start()
|
||||
|
||||
# Mock krisp_audio module BEFORE any pipecat imports
|
||||
# This allows tests to run without krisp_audio installed
|
||||
mock_krisp_audio = MagicMock()
|
||||
@@ -42,9 +37,12 @@ sys.modules["pipecat_ai_krisp"] = mock_pipecat_krisp
|
||||
sys.modules["pipecat_ai_krisp.audio"] = MagicMock()
|
||||
sys.modules["pipecat_ai_krisp.audio.krisp_processor"] = MagicMock()
|
||||
|
||||
# Now we can safely import
|
||||
from pipecat.audio.filters.krisp_viva_filter import KrispVivaFilter
|
||||
from pipecat.frames.frames import FilterEnableFrame
|
||||
# Now we can safely import. The version patch is scoped to just the import so
|
||||
# it doesn't leak across the test session and corrupt importlib.metadata.version
|
||||
# for other tests (e.g. transformers' import-time dependency checks).
|
||||
with patch("importlib.metadata.version", return_value="0.0.0-dev"):
|
||||
from pipecat.audio.filters.krisp_viva_filter import KrispVivaFilter
|
||||
from pipecat.frames.frames import FilterEnableFrame
|
||||
|
||||
|
||||
class TestKrispVivaFilter(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
@@ -15,11 +15,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Mock package version check before importing pipecat
|
||||
# This allows tests to run in development mode without installed package
|
||||
_version_patcher = patch("importlib.metadata.version", return_value="0.0.0-dev")
|
||||
_version_patcher.start()
|
||||
|
||||
# Mock krisp_audio module BEFORE any pipecat imports
|
||||
# This allows tests to run without krisp_audio installed
|
||||
mock_krisp_audio = MagicMock()
|
||||
@@ -44,9 +39,12 @@ sys.modules["pipecat_ai_krisp"] = mock_pipecat_krisp
|
||||
sys.modules["pipecat_ai_krisp.audio"] = MagicMock()
|
||||
sys.modules["pipecat_ai_krisp.audio.krisp_processor"] = MagicMock()
|
||||
|
||||
# Now we can safely import
|
||||
from pipecat.audio.vad.krisp_viva_vad import KrispVivaVadAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
# Now we can safely import. The version patch is scoped to just the import so
|
||||
# it doesn't leak across the test session and corrupt importlib.metadata.version
|
||||
# for other tests (e.g. transformers' import-time dependency checks).
|
||||
with patch("importlib.metadata.version", return_value="0.0.0-dev"):
|
||||
from pipecat.audio.vad.krisp_viva_vad import KrispVivaVadAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
|
||||
|
||||
class TestKrispVivaVadAnalyzer(unittest.TestCase):
|
||||
|
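Note on the two Krisp hunks above: both replace a module-level `_version_patcher.start()` with a `with patch(...)` block around the pipecat imports. A minimal sketch of the difference, using only `unittest.mock` and `importlib.metadata` from the standard library; the package name passed to `version()` is a hypothetical placeholder, and the snippet is illustrative rather than taken from the test files:

```python
import importlib.metadata
from unittest.mock import patch

real_version = importlib.metadata.version

# Scoped: the patch is active only inside the with-block.
with patch("importlib.metadata.version", return_value="0.0.0-dev"):
    # "any-package" is a hypothetical name; the mock ignores its argument.
    scoped_value = importlib.metadata.version("any-package")

# On exiting the block, the original function is back in place.
restored_after_scope = importlib.metadata.version is real_version

# Module-level start(): the patch stays active until stop() is called.
patcher = patch("importlib.metadata.version", return_value="0.0.0-dev")
patcher.start()
leaked_while_started = importlib.metadata.version is not real_version
patcher.stop()  # without this call, every later test would see the mock
restored_after_stop = importlib.metadata.version is real_version
```

With the scoped form, anything imported after the block sees the real `importlib.metadata.version` again, which is the behaviour the new comments in the hunks describe.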
||||
@@ -1,140 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from pipecat.adapters.schemas.direct_function import DirectFunctionWrapper
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameProcessorSetup
|
||||
from pipecat.services.llm_service import (
|
||||
FunctionCallParams,
|
||||
FunctionCallRegistryItem,
|
||||
FunctionCallRunnerItem,
|
||||
LLMService,
|
||||
)
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Resources:
|
||||
user_name: str
|
||||
db: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class _MockLLMService(LLMService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(settings=LLMSettings(), **kwargs)
|
||||
|
||||
|
||||
class TestFunctionCallParamsToolResources(unittest.TestCase):
|
||||
def test_default_is_none(self):
|
||||
params = FunctionCallParams(
|
||||
function_name="f",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
llm=None, # type: ignore[arg-type]
|
||||
context=LLMContext(),
|
||||
result_callback=AsyncMock(),
|
||||
)
|
||||
self.assertIsNone(params.tool_resources)
|
||||
|
||||
def test_holds_reference(self):
|
||||
resources = _Resources(user_name="John")
|
||||
params = FunctionCallParams(
|
||||
function_name="f",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
llm=None, # type: ignore[arg-type]
|
||||
context=LLMContext(),
|
||||
result_callback=AsyncMock(),
|
||||
tool_resources=resources,
|
||||
)
|
||||
self.assertIs(params.tool_resources, resources)
|
||||
|
||||
|
||||
class TestLLMServiceCachesToolResources(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_setup_caches_tool_resources(self):
|
||||
service = _MockLLMService()
|
||||
resources = _Resources(user_name="John")
|
||||
task_manager = TaskManager()
|
||||
task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop()))
|
||||
|
||||
await service.setup(
|
||||
FrameProcessorSetup(
|
||||
clock=SystemClock(),
|
||||
task_manager=task_manager,
|
||||
tool_resources=resources,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await service.cleanup()
|
||||
|
||||
self.assertIs(service._tool_resources, resources)
|
||||
|
||||
async def test_function_call_params_receives_tool_resources(self):
|
||||
service = _MockLLMService()
|
||||
resources = _Resources(user_name="John")
|
||||
service._tool_resources = resources
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def handler(params: FunctionCallParams):
|
||||
captured["params"] = params
|
||||
params.tool_resources.db["hit"] = True
|
||||
await params.result_callback({"ok": True})
|
||||
|
||||
service._functions["lookup"] = FunctionCallRegistryItem(
|
||||
function_name="lookup",
|
||||
handler=handler,
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
service.broadcast_frame = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=service._functions["lookup"],
|
||||
function_name="lookup",
|
||||
tool_call_id="call-1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
await service._run_function_call(runner_item)
|
||||
|
||||
self.assertIs(captured["params"].tool_resources, resources)
|
||||
self.assertTrue(resources.db["hit"])
|
||||
|
||||
async def test_direct_function_params_receives_tool_resources(self):
|
||||
service = _MockLLMService()
|
||||
resources = _Resources(user_name="John")
|
||||
service._tool_resources = resources
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def lookup(params: FunctionCallParams):
|
||||
captured["params"] = params
|
||||
|
||||
wrapper = DirectFunctionWrapper(lookup)
|
||||
service._functions[wrapper.name] = FunctionCallRegistryItem(
|
||||
function_name=wrapper.name,
|
||||
handler=wrapper,
|
||||
cancel_on_interruption=True,
|
||||
)
|
||||
service.broadcast_frame = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=service._functions[wrapper.name],
|
||||
function_name=wrapper.name,
|
||||
tool_call_id="call-1",
|
||||
arguments={},
|
||||
context=LLMContext(),
|
||||
)
|
||||
await service._run_function_call(runner_item)
|
||||
|
||||
self.assertIs(captured["params"].tool_resources, resources)
|
||||