From 1b5c4cfa2a1545ba09261c5c179e435165a08040 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 30 Apr 2026 16:16:17 -0400 Subject: [PATCH] feat: broaden tool_resources to app_resources Broaden `tool_resources` to `app_resources` for easy access not just in tool handlers but in other places like custom `FrameProcessor`s. Involves 3 changes: - A rename: `tool_resources` -> `app_resources` - A new property on `PipelineTask`: `app_resources` - A new property on `FrameProcessor`: `pipeline_task` Usage in tool handler: async def get_weather(params: FunctionCallParams): resources = cast(MyAppResources, params.app_resources) ... Usage in custom `FrameProcessor`: class MyProcessor(FrameProcessor): async def process_frame(self, frame, direction): await super().process_frame(frame, direction) if self.pipeline_task is not None: resources = cast(MyAppResources, self.pipeline_task.app_resources) ... The previous `tool_resources` aliases (on `PipelineTask`, `FunctionCallParams`, and `FrameProcessorSetup`) keep working but are deprecated as of 1.2.0 and emit `DeprecationWarning`s. --- .../features-app-resources.py} | 117 +++++-- src/pipecat/pipeline/task.py | 56 ++- src/pipecat/processors/frame_processor.py | 59 +++- src/pipecat/services/llm_service.py | 46 ++- tests/test_app_resources.py | 326 ++++++++++++++++++ tests/test_tool_resources.py | 140 -------- 6 files changed, 553 insertions(+), 191 deletions(-) rename examples/{function-calling/function-calling-tool-resources.py => features/features-app-resources.py} (65%) create mode 100644 tests/test_app_resources.py delete mode 100644 tests/test_tool_resources.py diff --git a/examples/function-calling/function-calling-tool-resources.py b/examples/features/features-app-resources.py similarity index 65% rename from examples/function-calling/function-calling-tool-resources.py rename to examples/features/features-app-resources.py index ac1a8df8e..dc07f30e5 100644 --- a/examples/function-calling/function-calling-tool-resources.py +++ b/examples/features/features-app-resources.py @@ -4,23 +4,33 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""Example demonstrating ``PipelineTask(tool_resources=...)``. +"""Example demonstrating ``PipelineTask(app_resources=...)``. -``tool_resources`` is an application-defined bag of anything you want every -tool handler in a session to share by reference: database handles, HTTP -clients, feature flags, per-user state, observability clients, in-memory -caches — whatever fits your app. Pipecat passes it through untouched as -``FunctionCallParams.tool_resources``. +``app_resources`` is an application-defined bag of anything your +application code may want to share across a session: database handles, +HTTP clients, feature flags, per-user state, observability clients, +in-memory caches — whatever fits your app. Pipecat passes it through +untouched and exposes it as ``task.app_resources``, so any code with a +handle on the task can read or mutate it. -This example uses a small ``ToolCallLogger`` as a stand-in for that "shared -thing". A real app might just as easily pass a Postgres pool, a Redis -client, a Stripe SDK instance, or any combination thereof. The mechanics -shown here — construct once, hand to the task, read it from each handler, -inspect it after the session — are the same regardless of what you put in. +Two of the convenience aliases exercised below: -We bundle resources in a typed ``SessionResources`` dataclass and cast back -to it at the top of each handler. Pipecat doesn't care what type you pass -(a plain dict works too), but a typed container gives you autocomplete and +- Tool handlers read it from ``FunctionCallParams.app_resources``. +- Custom ``FrameProcessor`` subclasses read it from + ``self.pipeline_task.app_resources``. + +This example uses two small loggers as stand-ins for that "shared thing": +``ToolCallLogger`` (written from tool handlers) and +``TranscriptionLogger`` (written from a custom ``FrameProcessor`` that +sits in the pipeline). A real app might just as easily pass a Postgres +pool, a Redis client, a Stripe SDK instance, or any combination thereof. +The mechanics shown here — construct once, hand to the task, read it +from each site, inspect it after the session — are the same regardless +of what you put in. + +We bundle resources in a typed ``AppResources`` dataclass and cast back +to it at each read site. Pipecat doesn't care what type you pass (a +plain dict works too), but a typed container gives you autocomplete and refactor safety instead of dict-by-string-key lookups. """ @@ -28,7 +38,7 @@ import json import os from collections.abc import Mapping from dataclasses import dataclass -from datetime import UTC, datetime, timezone +from datetime import UTC, datetime from typing import Any, cast from dotenv import load_dotenv @@ -37,7 +47,7 @@ from loguru import logger from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.schemas.tools_schema import ToolsSchema from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.frames.frames import LLMRunFrame, TTSSpeakFrame +from pipecat.frames.frames import Frame, LLMRunFrame, TranscriptionFrame, TTSSpeakFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask @@ -46,6 +56,7 @@ from pipecat.processors.aggregators.llm_response_universal import ( LLMContextAggregatorPair, LLMUserAggregatorParams, ) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.runner.types import RunnerArguments from pipecat.runner.utils import create_transport from pipecat.services.cartesia.tts import CartesiaTTSService @@ -86,30 +97,80 @@ class ToolCallLogger: return json.dumps(self._calls, indent=2) +class TranscriptionLogger: + """Records final user transcriptions — written from a custom FrameProcessor.""" + + def __init__(self): + """Initialize the logger with an empty list of recorded transcriptions.""" + self._entries: list[dict[str, Any]] = [] + + def log_transcription(self, text: str) -> None: + """Record a transcription. + + Args: + text: The transcribed user utterance. + """ + entry = { + "timestamp": datetime.now(UTC).isoformat(), + "text": text, + } + self._entries.append(entry) + logger.info(f"[TranscriptionLogger] {text!r}") + + def dump(self) -> str: + """Return all recorded transcriptions as a JSON string.""" + return json.dumps(self._entries, indent=2) + + @dataclass -class SessionResources: - """Typed container for everything the tool handlers in this session share. +class AppResources: + """Typed container for everything the app shares across this session. Add fields here as the app grows (e.g. ``db: AsyncConnection``, - ``http: httpx.AsyncClient``). Handlers ``cast()`` ``params.tool_resources`` - to this type to get autocomplete and refactor safety. + ``http: httpx.AsyncClient``). Read sites ``cast()`` to this type to + get autocomplete and refactor safety: + + - In tools: ``cast(AppResources, params.app_resources)``. + - In custom processors: ``cast(AppResources, self.pipeline_task.app_resources)``. """ tool_call_logger: ToolCallLogger + transcription_logger: TranscriptionLogger async def fetch_weather_from_api(params: FunctionCallParams): - resources = cast(SessionResources, params.tool_resources) + resources = cast(AppResources, params.app_resources) resources.tool_call_logger.log_tool_call(params.function_name, params.arguments) await params.result_callback({"conditions": "nice", "temperature": "75"}) async def fetch_restaurant_recommendation(params: FunctionCallParams): - resources = cast(SessionResources, params.tool_resources) + resources = cast(AppResources, params.app_resources) resources.tool_call_logger.log_tool_call(params.function_name, params.arguments) await params.result_callback({"name": "The Golden Dragon"}) +class TranscriptionLoggingProcessor(FrameProcessor): + """Logs each final user transcription into the shared app resources. + + Demonstrates the second read site for ``app_resources``: any custom + ``FrameProcessor`` can reach the same bag every tool handler sees by + going through ``self.pipeline_task.app_resources``. ``pipeline_task`` + is ``None`` until the task sets the processor up, so we guard against + that case. + """ + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Forward all frames; log final user transcriptions on the way through.""" + await super().process_frame(frame, direction) + + if isinstance(frame, TranscriptionFrame) and self.pipeline_task is not None: + resources = cast(AppResources, self.pipeline_task.app_resources) + resources.transcription_logger.log_transcription(frame.text) + + await self.push_frame(frame, direction) + + # We use lambdas to defer transport parameter creation until the transport # type is selected at runtime. transport_params = { @@ -203,6 +264,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): [ transport.input(), stt, + TranscriptionLoggingProcessor(), user_aggregator, llm, tts, @@ -211,10 +273,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): ] ) - # Keep a local handle so we can read collected state after the session + # Keep local handles so we can read collected state after the session # ends; Pipecat never copies or clears the object. tool_call_logger = ToolCallLogger() - resources = SessionResources(tool_call_logger=tool_call_logger) + transcription_logger = TranscriptionLogger() + resources = AppResources( + tool_call_logger=tool_call_logger, + transcription_logger=transcription_logger, + ) task = PipelineTask( pipeline, @@ -223,7 +289,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): enable_usage_metrics=True, ), idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, - tool_resources=resources, + app_resources=resources, ) @transport.event_handler("on_client_connected") @@ -246,6 +312,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): # The session has ended; read whatever state the handlers built up. logger.info(f"Tool calls logged during session:\n{tool_call_logger.dump()}") + logger.info(f"Transcriptions logged during session:\n{transcription_logger.dump()}") async def bot(runner_args: RunnerArguments): diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 8b1b7dfb6..2783da6ba 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -14,6 +14,7 @@ including heartbeats, idle detection, and observer integration. import asyncio import importlib.util import os +import warnings from collections.abc import AsyncIterable, Iterable from pathlib import Path from typing import Any, TypeVar @@ -193,6 +194,7 @@ class PipelineTask(BasePipelineTask): *, params: PipelineParams | None = None, additional_span_attributes: dict | None = None, + app_resources: Any = None, cancel_on_idle_timeout: bool = True, cancel_timeout_secs: float = CANCEL_TIMEOUT_SECS, check_dangling_tasks: bool = True, @@ -216,6 +218,14 @@ class PipelineTask(BasePipelineTask): params: Configuration parameters for the pipeline. additional_span_attributes: Optional dictionary of attributes to propagate as OpenTelemetry conversation span attributes. + app_resources: Optional application-defined bag of anything your + application code may want to share across this session (DB + handles, HTTP clients, etc.), passed by reference. Pipecat + passes it through untouched and exposes it on the task itself + as ``task.app_resources`` and passes it to tool handlers as + ``FunctionCallParams.app_resources``. The framework never + copies or clears this object; the caller retains their handle + and can read any mutations after the task finishes. cancel_on_idle_timeout: Whether the pipeline task should be cancelled if the idle timeout is reached. cancel_timeout_secs: Timeout (in seconds) to wait for cancellation to happen @@ -235,13 +245,24 @@ class PipelineTask(BasePipelineTask): rtvi_observer_params: The RTVI observer parameter to use if RTVI is enabled. rtvi_processor: The RTVI processor to add if RTVI is enabled. task_manager: Optional task manager for handling asyncio tasks. - tool_resources: Optional application-defined bag of resources (DB handles, - clients, state, etc.) passed by reference to every tool handler via - ``FunctionCallParams.tool_resources``. The framework never copies or - clears this object; the caller retains their handle and can read any - mutations after the task finishes. + tool_resources: Deprecated alias for ``app_resources``. + + .. deprecated:: 1.2.0 + Use ``app_resources`` instead. ``tool_resources`` will be + removed in a future version. """ super().__init__() + if tool_resources is not None: + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`PipelineTask(tool_resources=...)` is deprecated since 1.2.0, " + "use `app_resources` instead.", + DeprecationWarning, + stacklevel=2, + ) + if app_resources is None: + app_resources = tool_resources self._params = params or PipelineParams() self._additional_span_attributes = additional_span_attributes or {} self._cancel_on_idle_timeout = cancel_on_idle_timeout @@ -252,7 +273,7 @@ class PipelineTask(BasePipelineTask): self._enable_tracing = enable_tracing and is_tracing_available() self._enable_turn_tracking = enable_turn_tracking self._idle_timeout_secs = idle_timeout_secs - self._tool_resources = tool_resources + self._app_resources = app_resources observers = observers or [] self._turn_tracking_observer: TurnTrackingObserver | None = None self._user_bot_latency_observer: UserBotLatencyObserver | None = None @@ -391,6 +412,21 @@ class PipelineTask(BasePipelineTask): """ return self._params + @property + def app_resources(self) -> Any: + """Get the application-defined resources passed to this task. + + This is the same object passed to the constructor as + ``app_resources``. Tool handlers can also access it via + ``FunctionCallParams.app_resources``. The framework returns the + original reference; mutations are visible to all callers. + + Returns: + The application-defined resources, or ``None`` if none were + passed. + """ + return self._app_resources + @property def pipeline(self) -> BasePipeline: """Get the full pipeline managed by this pipeline task. @@ -730,7 +766,13 @@ class PipelineTask(BasePipelineTask): clock=self._clock, task_manager=self._task_manager, observer=self._observer, - tool_resources=self._tool_resources, + pipeline_task=self, + # Populate the deprecated `tool_resources` field for backwards + # compatibility with custom FrameProcessor subclasses whose + # ``setup()`` overrides still read it. Reading the field emits a + # DeprecationWarning; new code should read + # ``setup.pipeline_task.app_resources`` instead. + tool_resources=self._app_resources, ) await self._pipeline.setup(setup) diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index fe5f49b10..0ee22a642 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -16,10 +16,12 @@ from __future__ import annotations import asyncio import dataclasses import traceback +import warnings from collections.abc import Awaitable, Callable, Coroutine from dataclasses import dataclass from enum import Enum from typing import ( + TYPE_CHECKING, Any, Optional, ) @@ -47,6 +49,9 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager from pipecat.utils.base_object import BaseObject from pipecat.utils.frame_queue import FrameQueue +if TYPE_CHECKING: + from pipecat.pipeline.task import PipelineTask + class FrameDirection(Enum): """Direction of frame flow in the processing pipeline. @@ -71,15 +76,45 @@ class FrameProcessorSetup: clock: The clock instance for timing operations. task_manager: The task manager for handling async operations. observer: Optional observer for monitoring frame processing events. - tool_resources: Application-defined resources shared with processors - for this pipeline run. + pipeline_task: The :class:`PipelineTask` running this pipeline. Stored + on each processor as ``self.pipeline_task`` so processors can + reach task-scoped state (e.g. ``self.pipeline_task.app_resources``). + tool_resources: Deprecated. :class:`PipelineTask` continues to populate + this with ``app_resources`` so that custom :class:`FrameProcessor` + subclasses whose ``setup()`` overrides read ``setup.tool_resources`` + keep working. New code should read + ``setup.pipeline_task.app_resources`` instead. + + .. deprecated:: 1.2.0 + Reading this attribute emits a ``DeprecationWarning``. Read + ``setup.pipeline_task.app_resources`` instead. + ``tool_resources`` will be removed in a future version. """ clock: BaseClock task_manager: BaseTaskManager observer: BaseObserver | None = None + pipeline_task: PipelineTask | None = None tool_resources: Any = None + def __getattribute__(self, name: str) -> Any: + # Warn when user code reads the deprecated ``tool_resources`` field. + # Set is unaffected (goes through ``__setattr__``), so PipelineTask can + # populate it for backwards compat without tripping the warning. + if name == "tool_resources": + value = object.__getattribute__(self, "tool_resources") + if value is not None: + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`FrameProcessorSetup.tool_resources` is deprecated since 1.2.0; " + "read `setup.pipeline_task.app_resources` instead.", + DeprecationWarning, + stacklevel=2, + ) + return value + return object.__getattribute__(self, name) + class FrameProcessorQueue(asyncio.PriorityQueue): """A priority queue for systems frames and other frames. @@ -188,6 +223,9 @@ class FrameProcessor(BaseObject): # Observer self._observer: BaseObserver | None = None + # Pipeline Task + self._pipeline_task: PipelineTask | None = None + # Other properties self._enable_metrics = False self._enable_usage_metrics = False @@ -344,6 +382,22 @@ class FrameProcessor(BaseObject): raise Exception(f"{self} TaskManager is still not initialized.") return self._task_manager + @property + def pipeline_task(self) -> PipelineTask | None: + """Get the :class:`PipelineTask` this processor is running in. + + Provides access to task-scoped state from inside a processor — most + notably ``self.pipeline_task.app_resources`` for the application's + shared bag of resources (DB handles, clients, feature flags, etc.). + + Returns: + The :class:`PipelineTask` instance that set up this processor, + or ``None`` if the processor has not yet been set up by one + (for example, before the task has started, or when the processor + was instantiated in isolation). + """ + return self._pipeline_task + def processors_with_metrics(self): """Return processors that can generate metrics. @@ -495,6 +549,7 @@ class FrameProcessor(BaseObject): self._clock = setup.clock self._task_manager = setup.task_manager self._observer = setup.observer + self._pipeline_task = setup.pipeline_task # Create processing tasks. self.__create_input_task() diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 59689ba6f..6d8caacab 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -51,7 +51,7 @@ from pipecat.processors.aggregators.llm_context import ( LLMContext, LLMSpecificMessage, ) -from pipecat.processors.frame_processor import FrameDirection, FrameProcessorSetup +from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_service import AIService from pipecat.services.settings import LLMSettings, assert_given from pipecat.services.websocket_service import WebsocketService @@ -107,9 +107,10 @@ class FunctionCallParams: For async function calls (``cancel_on_interruption=False``), call it with ``properties=FunctionCallResultProperties(is_final=False)`` to push intermediate updates before the final result. - tool_resources: Application-defined bag of resources (DB handles, clients, - state, etc.) shared across tool calls for the pipeline session. Set - via ``PipelineTask(..., tool_resources=...)`` and passed by reference. + app_resources: The application-defined resources passed to + ``PipelineTask(..., app_resources=...)``. Same object — passed by + reference, not a copy. Use it to share DB handles, clients, state, + feature flags, etc. across all of a session's tool handlers. """ function_name: str @@ -118,7 +119,25 @@ class FunctionCallParams: llm: LLMService context: LLMContext result_callback: FunctionCallResultCallback - tool_resources: Any = None + app_resources: Any = None + + @property + def tool_resources(self) -> Any: + """Deprecated alias for :attr:`app_resources`. + + .. deprecated:: 1.2.0 + Use :attr:`app_resources` instead. ``tool_resources`` will be + removed in a future version. + """ + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`FunctionCallParams.tool_resources` is deprecated since 1.2.0, " + "use `app_resources` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.app_resources @dataclass @@ -256,7 +275,6 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): self._sequential_runner_task: asyncio.Task | None = None self._skip_tts: bool | None = None self._summary_task: asyncio.Task | None = None - self._tool_resources: Any = None self._register_event_handler("on_function_calls_started") self._register_event_handler("on_function_calls_cancelled") @@ -303,15 +321,6 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): """ raise NotImplementedError(f"run_inference() not supported by {self.__class__.__name__}") - async def setup(self, setup: FrameProcessorSetup): - """Set up the LLM service. - - Args: - setup: The frame processor setup data. - """ - await super().setup(setup) - self._tool_resources = setup.tool_resources - async def start(self, frame: StartFrame): """Start the LLM service. @@ -882,6 +891,9 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): # it starts would leave the coroutine in a "never awaited" state. await asyncio.sleep(0) + # _pipeline_task may be unset when the service is driven without a PipelineTask. + app_resources = self._pipeline_task.app_resources if self._pipeline_task else None + try: if isinstance(item.handler, DirectFunctionWrapper): # Handler is a DirectFunctionWrapper @@ -894,7 +906,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): llm=self, context=runner_item.context, result_callback=function_call_result_callback, - tool_resources=self._tool_resources, + app_resources=app_resources, ), ) else: @@ -906,7 +918,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): llm=self, context=runner_item.context, result_callback=function_call_result_callback, - tool_resources=self._tool_resources, + app_resources=app_resources, ) await item.handler(params) except Exception as e: diff --git a/tests/test_app_resources.py b/tests/test_app_resources.py new file mode 100644 index 000000000..646b8d397 --- /dev/null +++ b/tests/test_app_resources.py @@ -0,0 +1,326 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import unittest +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +from pipecat.adapters.schemas.direct_function import DirectFunctionWrapper +from pipecat.clocks.system_clock import SystemClock +from pipecat.frames.frames import EndFrame, Frame, StartFrame +from pipecat.pipeline.base_task import PipelineTaskParams +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.task import PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup +from pipecat.services.llm_service import ( + FunctionCallParams, + FunctionCallRegistryItem, + FunctionCallRunnerItem, + LLMService, +) +from pipecat.services.settings import LLMSettings +from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams + + +@dataclass +class _Resources: + user_name: str + db: dict[str, Any] = field(default_factory=dict) + + +def _complete_llm_settings() -> LLMSettings: + """Return an LLMSettings with every field set so test_service_init's + auto-discovered ``_MockLLMService`` doesn't fail its NOT_GIVEN check.""" + return LLMSettings( + model=None, + system_instruction=None, + temperature=None, + max_tokens=None, + top_p=None, + top_k=None, + frequency_penalty=None, + presence_penalty=None, + seed=None, + filter_incomplete_user_turns=None, + user_turn_completion_config=None, + ) + + +class _MockLLMService(LLMService): + def __init__(self, **kwargs): + super().__init__(settings=_complete_llm_settings(), **kwargs) + + +class TestFunctionCallParamsAppResources(unittest.TestCase): + def test_default_is_none(self): + params = FunctionCallParams( + function_name="f", + tool_call_id="1", + arguments={}, + llm=None, # type: ignore[arg-type] + context=LLMContext(), + result_callback=AsyncMock(), + ) + self.assertIsNone(params.app_resources) + + def test_holds_reference(self): + resources = _Resources(user_name="John") + params = FunctionCallParams( + function_name="f", + tool_call_id="1", + arguments={}, + llm=None, # type: ignore[arg-type] + context=LLMContext(), + result_callback=AsyncMock(), + app_resources=resources, + ) + self.assertIs(params.app_resources, resources) + + def test_tool_resources_property_warns_and_aliases_app_resources(self): + resources = _Resources(user_name="John") + params = FunctionCallParams( + function_name="f", + tool_call_id="1", + arguments={}, + llm=None, # type: ignore[arg-type] + context=LLMContext(), + result_callback=AsyncMock(), + app_resources=resources, + ) + with self.assertWarns(DeprecationWarning): + value = params.tool_resources + self.assertIs(value, resources) + + +class TestLLMServiceFunctionCallReadsAppResources(unittest.IsolatedAsyncioTestCase): + async def test_function_call_params_receives_app_resources(self): + service = _MockLLMService() + resources = _Resources(user_name="John") + # Stub the pipeline task with just the bit LLMService reads. + service._pipeline_task = SimpleNamespace(app_resources=resources) # type: ignore[assignment] + + captured: dict[str, Any] = {} + + async def handler(params: FunctionCallParams): + captured["params"] = params + params.app_resources.db["hit"] = True + await params.result_callback({"ok": True}) + + service._functions["lookup"] = FunctionCallRegistryItem( + function_name="lookup", + handler=handler, + cancel_on_interruption=True, + ) + service.broadcast_frame = AsyncMock() # type: ignore[method-assign] + + runner_item = FunctionCallRunnerItem( + registry_item=service._functions["lookup"], + function_name="lookup", + tool_call_id="call-1", + arguments={}, + context=LLMContext(), + ) + await service._run_function_call(runner_item) + + self.assertIs(captured["params"].app_resources, resources) + self.assertTrue(resources.db["hit"]) + + async def test_direct_function_params_receives_app_resources(self): + service = _MockLLMService() + resources = _Resources(user_name="John") + service._pipeline_task = SimpleNamespace(app_resources=resources) # type: ignore[assignment] + captured: dict[str, Any] = {} + + async def lookup(params: FunctionCallParams): + captured["params"] = params + + wrapper = DirectFunctionWrapper(lookup) + service._functions[wrapper.name] = FunctionCallRegistryItem( + function_name=wrapper.name, + handler=wrapper, + cancel_on_interruption=True, + ) + service.broadcast_frame = AsyncMock() # type: ignore[method-assign] + + runner_item = FunctionCallRunnerItem( + registry_item=service._functions[wrapper.name], + function_name=wrapper.name, + tool_call_id="call-1", + arguments={}, + context=LLMContext(), + ) + await service._run_function_call(runner_item) + + self.assertIs(captured["params"].app_resources, resources) + + async def test_app_resources_none_when_pipeline_task_unset(self): + service = _MockLLMService() + captured: dict[str, Any] = {} + + async def handler(params: FunctionCallParams): + captured["params"] = params + await params.result_callback({"ok": True}) + + service._functions["lookup"] = FunctionCallRegistryItem( + function_name="lookup", + handler=handler, + cancel_on_interruption=True, + ) + service.broadcast_frame = AsyncMock() # type: ignore[method-assign] + + runner_item = FunctionCallRunnerItem( + registry_item=service._functions["lookup"], + function_name="lookup", + tool_call_id="call-1", + arguments={}, + context=LLMContext(), + ) + await service._run_function_call(runner_item) + + self.assertIsNone(captured["params"].app_resources) + + async def test_frame_processor_setup_tool_resources_warns_on_read(self): + # ``FrameProcessorSetup.tool_resources`` is retained for backwards + # compatibility with custom FrameProcessors whose ``setup()`` overrides + # still read it. The field is populated, but reading it warns. + task_manager = TaskManager() + task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop())) + resources = _Resources(user_name="John") + + # Construction itself does not warn — only reads do. + setup = FrameProcessorSetup( + clock=SystemClock(), + task_manager=task_manager, + tool_resources=resources, + ) + + with self.assertWarns(DeprecationWarning): + value = setup.tool_resources + self.assertIs(value, resources) + + +class TestPipelineTaskAppResources(unittest.TestCase): + def test_getter_returns_constructor_value(self): + resources = _Resources(user_name="John") + task = PipelineTask(Pipeline([]), app_resources=resources) + self.assertIs(task.app_resources, resources) + + def test_default_app_resources_is_none(self): + task = PipelineTask(Pipeline([])) + self.assertIsNone(task.app_resources) + + def test_tool_resources_kwarg_warns_and_aliases_app_resources(self): + resources = _Resources(user_name="John") + with self.assertWarns(DeprecationWarning): + task = PipelineTask(Pipeline([]), tool_resources=resources) + self.assertIs(task.app_resources, resources) + + def test_app_resources_takes_precedence_over_tool_resources(self): + new = _Resources(user_name="new") + old = _Resources(user_name="old") + with self.assertWarns(DeprecationWarning): + task = PipelineTask(Pipeline([]), app_resources=new, tool_resources=old) + self.assertIs(task.app_resources, new) + + +class _RecordingProcessor(FrameProcessor): + """Records the pipeline_task it sees once StartFrame reaches it.""" + + def __init__(self): + super().__init__() + self.observed_task: Any = None + self.observed_app_resources: Any = None + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if isinstance(frame, StartFrame): + # setup() runs before any frame reaches us, so pipeline_task is wired up. + assert self.pipeline_task is not None + self.observed_task = self.pipeline_task + self.observed_app_resources = self.pipeline_task.app_resources + await self.push_frame(frame, direction) + + +class _LegacyToolResourcesReader(FrameProcessor): + """Custom processor that reads the deprecated ``setup.tool_resources``. + + Models a previously-written user FrameProcessor whose ``setup()`` + override hasn't been migrated yet. The field is populated by + ``PipelineTask`` for backwards compatibility; reading it emits a + DeprecationWarning. + """ + + def __init__(self): + super().__init__() + self.captured_tool_resources: Any = None + + async def setup(self, setup): + await super().setup(setup) + self.captured_tool_resources = setup.tool_resources + + async def process_frame(self, frame: Frame, direction: FrameDirection): + # Forward all frames so the EndFrame reaches the pipeline sink and + # ``task.run()`` can return cleanly. + await super().process_frame(frame, direction) + await self.push_frame(frame, direction) + + +class TestFrameProcessorSetupToolResourcesBackwardsCompat(unittest.IsolatedAsyncioTestCase): + async def test_legacy_processor_receives_value_via_app_resources(self): + resources = _Resources(user_name="John") + legacy = _LegacyToolResourcesReader() + pipeline = Pipeline([legacy]) + task = PipelineTask(pipeline, app_resources=resources) + + await task.queue_frame(EndFrame()) + with self.assertWarns(DeprecationWarning): + await task.run(PipelineTaskParams(loop=asyncio.get_event_loop())) + + self.assertIs(legacy.captured_tool_resources, resources) + + async def test_legacy_processor_receives_value_via_deprecated_tool_resources_kwarg( + self, + ): + # If the user is still constructing PipelineTask with the deprecated + # ``tool_resources`` kwarg (and hasn't migrated to ``app_resources``), + # legacy processors must still see the value too. + resources = _Resources(user_name="John") + legacy = _LegacyToolResourcesReader() + pipeline = Pipeline([legacy]) + with self.assertWarns(DeprecationWarning): + task = PipelineTask(pipeline, tool_resources=resources) + + await task.queue_frame(EndFrame()) + with self.assertWarns(DeprecationWarning): + await task.run(PipelineTaskParams(loop=asyncio.get_event_loop())) + + self.assertIs(legacy.captured_tool_resources, resources) + + +class TestFrameProcessorPipelineTaskAccess(unittest.IsolatedAsyncioTestCase): + async def test_processor_can_reach_pipeline_task_and_app_resources(self): + resources = _Resources(user_name="John") + recorder = _RecordingProcessor() + pipeline = Pipeline([recorder]) + task = PipelineTask(pipeline, app_resources=resources) + + await task.queue_frame(EndFrame()) + await task.run(PipelineTaskParams(loop=asyncio.get_event_loop())) + + self.assertIs(recorder.observed_task, task) + self.assertIs(recorder.observed_app_resources, resources) + + def test_pipeline_task_returns_none_when_not_set_up(self): + recorder = _RecordingProcessor() + self.assertIsNone(recorder.pipeline_task) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tool_resources.py b/tests/test_tool_resources.py deleted file mode 100644 index 3f5026ddc..000000000 --- a/tests/test_tool_resources.py +++ /dev/null @@ -1,140 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import asyncio -import unittest -from dataclasses import dataclass, field -from typing import Any -from unittest.mock import AsyncMock - -from pipecat.adapters.schemas.direct_function import DirectFunctionWrapper -from pipecat.clocks.system_clock import SystemClock -from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.frame_processor import FrameProcessorSetup -from pipecat.services.llm_service import ( - FunctionCallParams, - FunctionCallRegistryItem, - FunctionCallRunnerItem, - LLMService, -) -from pipecat.services.settings import LLMSettings -from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams - - -@dataclass -class _Resources: - user_name: str - db: dict[str, Any] = field(default_factory=dict) - - -class _MockLLMService(LLMService): - def __init__(self, **kwargs): - super().__init__(settings=LLMSettings(), **kwargs) - - -class TestFunctionCallParamsToolResources(unittest.TestCase): - def test_default_is_none(self): - params = FunctionCallParams( - function_name="f", - tool_call_id="1", - arguments={}, - llm=None, # type: ignore[arg-type] - context=LLMContext(), - result_callback=AsyncMock(), - ) - self.assertIsNone(params.tool_resources) - - def test_holds_reference(self): - resources = _Resources(user_name="John") - params = FunctionCallParams( - function_name="f", - tool_call_id="1", - arguments={}, - llm=None, # type: ignore[arg-type] - context=LLMContext(), - result_callback=AsyncMock(), - tool_resources=resources, - ) - self.assertIs(params.tool_resources, resources) - - -class TestLLMServiceCachesToolResources(unittest.IsolatedAsyncioTestCase): - async def test_setup_caches_tool_resources(self): - service = _MockLLMService() - resources = _Resources(user_name="John") - task_manager = TaskManager() - task_manager.setup(TaskManagerParams(loop=asyncio.get_running_loop())) - - await service.setup( - FrameProcessorSetup( - clock=SystemClock(), - task_manager=task_manager, - tool_resources=resources, - ) - ) - await asyncio.sleep(0) - await service.cleanup() - - self.assertIs(service._tool_resources, resources) - - async def test_function_call_params_receives_tool_resources(self): - service = _MockLLMService() - resources = _Resources(user_name="John") - service._tool_resources = resources - - captured: dict[str, Any] = {} - - async def handler(params: FunctionCallParams): - captured["params"] = params - params.tool_resources.db["hit"] = True - await params.result_callback({"ok": True}) - - service._functions["lookup"] = FunctionCallRegistryItem( - function_name="lookup", - handler=handler, - cancel_on_interruption=True, - ) - service.broadcast_frame = AsyncMock() # type: ignore[method-assign] - - runner_item = FunctionCallRunnerItem( - registry_item=service._functions["lookup"], - function_name="lookup", - tool_call_id="call-1", - arguments={}, - context=LLMContext(), - ) - await service._run_function_call(runner_item) - - self.assertIs(captured["params"].tool_resources, resources) - self.assertTrue(resources.db["hit"]) - - async def test_direct_function_params_receives_tool_resources(self): - service = _MockLLMService() - resources = _Resources(user_name="John") - service._tool_resources = resources - captured: dict[str, Any] = {} - - async def lookup(params: FunctionCallParams): - captured["params"] = params - - wrapper = DirectFunctionWrapper(lookup) - service._functions[wrapper.name] = FunctionCallRegistryItem( - function_name=wrapper.name, - handler=wrapper, - cancel_on_interruption=True, - ) - service.broadcast_frame = AsyncMock() # type: ignore[method-assign] - - runner_item = FunctionCallRunnerItem( - registry_item=service._functions[wrapper.name], - function_name=wrapper.name, - tool_call_id="call-1", - arguments={}, - context=LLMContext(), - ) - await service._run_function_call(runner_item) - - self.assertIs(captured["params"].tool_resources, resources)