fix: restore cancel_on_interruption=False support in AWS Nova Sonic and OpenAI Realtime
Before the new async-tool mechanism landed, AWSNovaSonicLLMService and OpenAIRealtimeLLMService honored cancel_on_interruption=False by simply not cancelling in-flight function calls on interruption — the eventual result then flowed through the same channel as any synchronous tool result. The new mechanism (which appends started/intermediate/final messages to the LLM context as the underlying task progresses) broke that path: the realtime services didn't know how to interpret those messages, and the eventual result was never delivered to the provider. Restore the flag's behavior by teaching both services to detect async-tool messages in the context and route them appropriately: - started → skipped silently. The provider already issued the tool call and natively awaits a result; nothing to send for the started marker. - final → delivered via the formal tool-result channel. Same path as a synchronous tool result, just delayed. Streamed intermediate results (FunctionCallResultProperties(is_final= False)) are not supported on these realtime services. An intermediate result is logged as an error and surfaced via push_error, then dropped. Use a non-realtime LLM service if a tool needs to stream intermediate results. (Docstrings on register_function, register_direct_function, and FunctionCallResultProperties.is_final updated to call this out.) A new shared module pipecat.processors.aggregators.async_tool_messages is the single source of truth for the on-the-wire payload shape: the aggregator uses its build_*_message functions when injecting messages, and the realtime services use parse_message when scanning the context. Adds two example files exercising a network-delayed weather tool with each service. The plain realtime-aws-nova-sonic.py example is also reverted to a synchronous tool call now that the async variant lives in its own file. Similar fixes for other realtime services are forthcoming.
This commit is contained in:
1
changelog/4441.fixed.md
Normal file
1
changelog/4441.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Restored `cancel_on_interruption=False` support for `AWSNovaSonicLLMService` and `OpenAIRealtimeLLMService`. These services previously honored the flag by simply not cancelling in-flight function calls on interruption; the introduction of the new async-tool mechanism (which threads started/intermediate/final messages through the LLM context) broke that path because the realtime services didn't know how to interpret those messages. Note that new-style streamed intermediate results (`FunctionCallResultProperties(is_final=False)`) are not supported on these realtime services. Similar fixes for other impacted realtime services are forthcoming.
|
||||
184
examples/realtime/realtime-aws-nova-sonic-async-tool.py
Normal file
184
examples/realtime/realtime-aws-nova-sonic-async-tool.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example: async function call with the AWS Nova Sonic LLM service.
|
||||
|
||||
The ``get_current_weather`` tool is registered with
|
||||
``cancel_on_interruption=False`` and simulates a slow API call (10s sleep).
|
||||
While the call is in flight the conversation continues; the result arrives
|
||||
later via the async-tool mechanism and is forwarded to Nova Sonic via the
|
||||
formal toolResult channel so the model can integrate it naturally into its
|
||||
next turn.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws.nova_sonic.llm import AWSNovaSonicLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
# Simulate a long-running API call so we can demonstrate that the
|
||||
# conversation continues while the tool is in flight.
|
||||
await asyncio.sleep(10)
|
||||
temperature = (
|
||||
random.randint(60, 85)
|
||||
if params.arguments["format"] == "fahrenheit"
|
||||
else random.randint(15, 30)
|
||||
)
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"location": params.arguments["location"],
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
system_instruction = (
|
||||
"You are a friendly assistant. The user and you will engage in a spoken "
|
||||
"dialog exchanging the transcripts of a natural real-time conversation. "
|
||||
"Keep your responses short, generally two or three sentences for chatty "
|
||||
"scenarios. When the user asks for the weather, call get_current_weather. "
|
||||
"While you wait for the result, keep chatting with the user. When the "
|
||||
"result arrives, share it with the user naturally."
|
||||
)
|
||||
|
||||
llm = AWSNovaSonicLLMService(
|
||||
secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
|
||||
access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
|
||||
region=os.environ["AWS_REGION"],
|
||||
session_token=os.getenv("AWS_SESSION_TOKEN"),
|
||||
settings=AWSNovaSonicLLMService.Settings(
|
||||
voice="tiffany",
|
||||
system_instruction=system_instruction,
|
||||
),
|
||||
)
|
||||
|
||||
llm.register_function(
|
||||
"get_current_weather",
|
||||
fetch_weather_from_api,
|
||||
cancel_on_interruption=False,
|
||||
)
|
||||
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
user_aggregator,
|
||||
llm,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
context.add_message(
|
||||
{"role": "developer", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -46,11 +46,6 @@ async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
if params.arguments["format"] == "fahrenheit"
|
||||
else random.randint(15, 30)
|
||||
)
|
||||
# Simulate a long network delay.
|
||||
# You can continue chatting while waiting for this to complete.
|
||||
# With Nova 2 Sonic (the default model), the assistant will respond
|
||||
# appropriately once the function call is complete.
|
||||
await asyncio.sleep(5)
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
@@ -150,9 +145,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Register function for function calls
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function(
|
||||
"get_current_weather", fetch_weather_from_api, cancel_on_interruption=False
|
||||
)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
# Set up context and context management.
|
||||
context = LLMContext(tools=tools)
|
||||
|
||||
198
examples/realtime/realtime-openai-async-tool.py
Normal file
198
examples/realtime/realtime-openai-async-tool.py
Normal file
@@ -0,0 +1,198 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example: async function call with the OpenAI Realtime LLM service.
|
||||
|
||||
The ``get_current_weather`` tool is registered with
|
||||
``cancel_on_interruption=False`` and simulates a slow API call (10s sleep).
|
||||
While the call is in flight the conversation continues; the result arrives
|
||||
later via the async-tool mechanism and is forwarded to OpenAI Realtime as a
|
||||
``function_call_output`` so the model can integrate it naturally into its
|
||||
next turn.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
AudioConfiguration,
|
||||
AudioInput,
|
||||
InputAudioNoiseReduction,
|
||||
InputAudioTranscription,
|
||||
SemanticTurnDetection,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
# Simulate a long-running API call so we can demonstrate that the
|
||||
# conversation continues while the tool is in flight.
|
||||
await asyncio.sleep(10)
|
||||
temperature = (
|
||||
random.randint(60, 85)
|
||||
if params.arguments["format"] == "fahrenheit"
|
||||
else random.randint(15, 30)
|
||||
)
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"location": params.arguments["location"],
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
|
||||
system_instruction = (
|
||||
"You are a friendly assistant. The user and you will engage in a spoken "
|
||||
"dialog exchanging the transcripts of a natural real-time conversation. "
|
||||
"Keep your responses short, generally two or three sentences for chatty "
|
||||
"scenarios. When the user asks for the weather, call get_current_weather. "
|
||||
"While you wait for the result, keep chatting with the user. When the "
|
||||
"result arrives, share it with the user naturally."
|
||||
)
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
llm = OpenAIRealtimeLLMService(
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
settings=OpenAIRealtimeLLMService.Settings(
|
||||
system_instruction=system_instruction,
|
||||
session_properties=SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(),
|
||||
turn_detection=SemanticTurnDetection(),
|
||||
noise_reduction=InputAudioNoiseReduction(type="near_field"),
|
||||
)
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
llm.register_function(
|
||||
"get_current_weather",
|
||||
fetch_weather_from_api,
|
||||
cancel_on_interruption=False,
|
||||
)
|
||||
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
user_aggregator,
|
||||
llm,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
context.add_message(
|
||||
{"role": "developer", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -695,6 +695,11 @@ class FunctionCallResultProperties:
|
||||
is_final: Whether this is the final result for the function call. When
|
||||
``False`` the result is treated as an intermediate update. Defaults to ``True``.
|
||||
Only meaningful for async function calls (``cancel_on_interruption=False``).
|
||||
Note: realtime LLM services do not support streamed intermediate
|
||||
results; they deliver only the final result to the provider. An
|
||||
intermediate result reported to a realtime service is dropped
|
||||
and an error is raised. Use a non-realtime LLM service if your
|
||||
tool needs to stream intermediate results.
|
||||
"""
|
||||
|
||||
run_llm: bool | None = None
|
||||
|
||||
286
src/pipecat/processors/aggregators/async_tool_messages.py
Normal file
286
src/pipecat/processors/aggregators/async_tool_messages.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Helpers for the async-tool message protocol used in LLM contexts.
|
||||
|
||||
When a function is registered with ``cancel_on_interruption=False``, the
|
||||
``LLMUserContextAggregator`` / ``LLMAssistantContextAggregator`` pair appends
|
||||
async-tool messages to the conversation context as the underlying task
|
||||
progresses:
|
||||
|
||||
- A ``started`` message (``role="tool"``) is appended immediately when the
|
||||
tool starts running.
|
||||
- An ``intermediate`` message (``role="developer"``) is appended each time an
|
||||
intermediate result is reported via
|
||||
``result_callback(..., FunctionCallResultProperties(is_final=False))``.
|
||||
- A ``final`` message (``role="developer"``) is appended when the task
|
||||
finishes.
|
||||
|
||||
This module is the single source of truth for the on-the-wire payload shape:
|
||||
|
||||
- The aggregator uses the ``build_*_message`` functions when injecting messages.
|
||||
- Realtime LLM services use ``parse_message`` to detect async-tool messages
|
||||
while iterating the context, then read ``payload.result`` and deliver it via
|
||||
their formal tool-result channel.
|
||||
|
||||
Internally, ``AsyncToolMessagePayload`` is the canonical structured form;
|
||||
the on-the-wire JSON string is always derived from it (never stored) so the
|
||||
two representations can't drift.
|
||||
|
||||
Consumers are expected to import the module rather than its individual
|
||||
functions, e.g.::
|
||||
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
...
|
||||
async_tool_messages.build_started_message(tool_call_id)
|
||||
async_tool_messages.parse_message(msg)
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from pipecat.processors.aggregators.llm_context import LLMStandardMessage
|
||||
|
||||
AsyncToolMessageKind = Literal["started", "intermediate", "final"]
|
||||
|
||||
# --- Payload shape (private; canonical source of truth) ---------------------
|
||||
|
||||
# The ``type`` field that identifies an async-tool message payload. Both the
|
||||
# builders and the parser use this constant; do not duplicate the literal.
|
||||
_PAYLOAD_TYPE = "async_tool"
|
||||
|
||||
# Status value for started / intermediate messages (task still running).
|
||||
_STATUS_RUNNING = "running"
|
||||
|
||||
# Status value for the final message (task complete).
|
||||
_STATUS_FINISHED = "finished"
|
||||
|
||||
# Description shipped on the started message. The text is intentionally
|
||||
# self-explanatory so a model reading the context can tell what's about to
|
||||
# happen even without out-of-band knowledge of the protocol.
|
||||
_STARTED_DESCRIPTION = (
|
||||
"An asynchronous task associated with this tool_call_id has started "
|
||||
"running. Expect results to arrive later as developer messages that look "
|
||||
"roughly like this one (with 'type=async_tool' and a matching tool_call_id) "
|
||||
"but with a 'result' field. Note that there *may* be more than one result "
|
||||
"(i.e., a stream of results), but there doesn't have to be (there may be "
|
||||
"only one). The last result will come in a message with 'status=finished'."
|
||||
)
|
||||
|
||||
# Description shipped on each intermediate-result message.
|
||||
_INTERMEDIATE_DESCRIPTION = (
|
||||
"This is an intermediate result for the asynchronous task associated with "
|
||||
"this tool_call_id. The task is still running. More intermediate results "
|
||||
"may follow, or the next result may be the final one with "
|
||||
"'status=finished'."
|
||||
)
|
||||
|
||||
# Description shipped on the final-result message.
|
||||
_FINAL_DESCRIPTION = (
|
||||
"This is the final result for the asynchronous task associated with this "
|
||||
"tool_call_id. The task has completed. No further results will arrive for "
|
||||
"this tool_call_id."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AsyncToolMessagePayload:
|
||||
"""The structured contents of an async-tool message in an LLM context.
|
||||
|
||||
Parameters:
|
||||
kind: Which of the three async-tool message stages this is.
|
||||
tool_call_id: The id of the tool invocation this payload relates to.
|
||||
status: ``"running"`` for started/intermediate, ``"finished"`` for
|
||||
the final message.
|
||||
description: Human-readable description from the payload. May be empty.
|
||||
result: For ``intermediate`` and ``final`` messages, the JSON-encoded
|
||||
result string (or the literal ``"COMPLETED"`` if the function
|
||||
returned no value). ``None`` for ``started`` messages.
|
||||
"""
|
||||
|
||||
kind: AsyncToolMessageKind
|
||||
tool_call_id: str
|
||||
status: Literal["running", "finished"]
|
||||
description: str
|
||||
result: str | None
|
||||
|
||||
|
||||
# --- Internal: payload ↔ on-the-wire forms -----------------------------------
|
||||
|
||||
|
||||
def _payload_to_json(payload: AsyncToolMessagePayload) -> str:
|
||||
"""Serialize a payload to its on-the-wire JSON string form.
|
||||
|
||||
Fields that don't apply to the payload's kind are omitted (notably
|
||||
``result`` is left out of ``started`` payloads, since the task hasn't
|
||||
produced a result yet).
|
||||
"""
|
||||
obj: dict[str, Any] = {
|
||||
"type": _PAYLOAD_TYPE,
|
||||
"status": payload.status,
|
||||
"tool_call_id": payload.tool_call_id,
|
||||
"description": payload.description,
|
||||
}
|
||||
if payload.result is not None:
|
||||
obj["result"] = payload.result
|
||||
return json.dumps(obj)
|
||||
|
||||
|
||||
def _payload_to_message(payload: AsyncToolMessagePayload) -> LLMStandardMessage:
|
||||
"""Wrap a payload in the LLM context message shape that matches its kind.
|
||||
|
||||
- ``started``: ``role="tool"`` plus ``tool_call_id`` at the top level
|
||||
(so the message can sit alongside other regular tool-result messages).
|
||||
- ``intermediate`` / ``final``: ``role="developer"``; ``tool_call_id``
|
||||
lives only inside the JSON payload.
|
||||
"""
|
||||
content = _payload_to_json(payload)
|
||||
if payload.kind == "started":
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
"tool_call_id": payload.tool_call_id,
|
||||
}
|
||||
return {
|
||||
"role": "developer",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
# --- Builders ----------------------------------------------------------------
|
||||
|
||||
|
||||
def build_started_message(tool_call_id: str) -> LLMStandardMessage:
|
||||
"""Build a ``started`` async-tool message for an LLM context.
|
||||
|
||||
Append the returned message to the LLM context immediately when an async
|
||||
function call (registered with ``cancel_on_interruption=False``) starts
|
||||
running. The message lets the model know a task is in flight and that its
|
||||
results will arrive later in subsequent ``developer``-role messages.
|
||||
|
||||
Args:
|
||||
tool_call_id: The id of the tool invocation this message is for.
|
||||
|
||||
Returns:
|
||||
A message ready to pass to ``LLMContext.add_message``.
|
||||
"""
|
||||
return _payload_to_message(
|
||||
AsyncToolMessagePayload(
|
||||
kind="started",
|
||||
tool_call_id=tool_call_id,
|
||||
status=_STATUS_RUNNING,
|
||||
description=_STARTED_DESCRIPTION,
|
||||
result=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_intermediate_result_message(tool_call_id: str, result: str) -> LLMStandardMessage:
|
||||
"""Build an intermediate-result async-tool message for an LLM context.
|
||||
|
||||
Append the returned message to the LLM context each time the running async
|
||||
function reports a non-final result via
|
||||
``result_callback(..., FunctionCallResultProperties(is_final=False))``.
|
||||
|
||||
Args:
|
||||
tool_call_id: The id of the tool invocation the result is for.
|
||||
result: The JSON-encoded result string (caller is responsible for
|
||||
encoding the function's actual return value, typically via
|
||||
``json.dumps``).
|
||||
|
||||
Returns:
|
||||
A message ready to pass to ``LLMContext.add_message``.
|
||||
"""
|
||||
return _payload_to_message(
|
||||
AsyncToolMessagePayload(
|
||||
kind="intermediate",
|
||||
tool_call_id=tool_call_id,
|
||||
status=_STATUS_RUNNING,
|
||||
description=_INTERMEDIATE_DESCRIPTION,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_final_result_message(tool_call_id: str, result: str) -> LLMStandardMessage:
|
||||
"""Build a final-result async-tool message for an LLM context.
|
||||
|
||||
Append the returned message to the LLM context when the async function
|
||||
finishes. After this message no further async-tool messages will arrive
|
||||
for this ``tool_call_id``.
|
||||
|
||||
Args:
|
||||
tool_call_id: The id of the tool invocation the result is for.
|
||||
result: The JSON-encoded result string, or the literal ``"COMPLETED"``
|
||||
sentinel when the function returned ``None`` (matching the same
|
||||
convention used for synchronous tool calls).
|
||||
|
||||
Returns:
|
||||
A message ready to pass to ``LLMContext.add_message``.
|
||||
"""
|
||||
return _payload_to_message(
|
||||
AsyncToolMessagePayload(
|
||||
kind="final",
|
||||
tool_call_id=tool_call_id,
|
||||
status=_STATUS_FINISHED,
|
||||
description=_FINAL_DESCRIPTION,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# --- Parsing -----------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_message(message: LLMStandardMessage) -> AsyncToolMessagePayload | None:
|
||||
"""Decode an async-tool message payload, or return None if not async-tool.
|
||||
|
||||
Args:
|
||||
message: A standard message from the LLM context. Callers iterating
|
||||
over ``LLMContext.get_messages()`` should filter out
|
||||
``LLMSpecificMessage`` entries first; only ``LLMStandardMessage``
|
||||
values can carry async-tool payloads.
|
||||
|
||||
Returns:
|
||||
An ``AsyncToolMessagePayload`` if the message is a recognized
|
||||
async-tool payload, otherwise ``None``.
|
||||
"""
|
||||
role = message.get("role")
|
||||
if role not in ("tool", "developer"):
|
||||
return None
|
||||
content = message.get("content")
|
||||
if not isinstance(content, str):
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return None
|
||||
if not isinstance(payload, dict) or payload.get("type") != _PAYLOAD_TYPE:
|
||||
return None
|
||||
tool_call_id = payload.get("tool_call_id")
|
||||
status = payload.get("status")
|
||||
if not isinstance(tool_call_id, str) or status not in (_STATUS_RUNNING, _STATUS_FINISHED):
|
||||
return None
|
||||
description = payload.get("description", "")
|
||||
if not isinstance(description, str):
|
||||
description = ""
|
||||
result = payload.get("result")
|
||||
if result is not None and not isinstance(result, str):
|
||||
result = None
|
||||
if result is None:
|
||||
kind: AsyncToolMessageKind = "started"
|
||||
elif status == _STATUS_FINISHED:
|
||||
kind = "final"
|
||||
else:
|
||||
kind = "intermediate"
|
||||
return AsyncToolMessagePayload(
|
||||
kind=kind,
|
||||
tool_call_id=tool_call_id,
|
||||
status=status,
|
||||
description=description,
|
||||
result=result,
|
||||
)
|
||||
@@ -67,6 +67,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
@@ -1278,23 +1279,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
is_async = not frame.cancel_on_interruption
|
||||
if is_async:
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"type": "async_tool",
|
||||
"status": "running",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
"description": "An asynchronous task associated with this tool_call_id has started running. "
|
||||
+ "Expect results to arrive later as developer messages that look roughly like this one (with 'type=async_tool' and a matching tool_call_id) but with a 'result' field. "
|
||||
+ "Note that there *may* be more than one result (i.e., a stream of results), but there doesn't have to be (there may be only one). "
|
||||
+ "The last result will come in a message with 'status=finished'.",
|
||||
}
|
||||
),
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
self._context.add_message(async_tool_messages.build_started_message(frame.tool_call_id))
|
||||
else:
|
||||
self._context.add_message(
|
||||
{
|
||||
@@ -1407,19 +1392,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
result = json.dumps(frame.result, ensure_ascii=False)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "developer",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"type": "async_tool",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
"status": "running",
|
||||
"description": "This is an intermediate result for the asynchronous task associated with this tool_call_id. "
|
||||
+ "The task is still running. More intermediate results may follow, or the next result may be the final one with 'status=finished'.",
|
||||
"result": result,
|
||||
}
|
||||
),
|
||||
}
|
||||
async_tool_messages.build_intermediate_result_message(frame.tool_call_id, result)
|
||||
)
|
||||
|
||||
async def _handle_function_call_finished(
|
||||
@@ -1440,19 +1413,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
# notified of the completed result instead of updating the IN_PROGRESS
|
||||
# tool message.
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "developer",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"type": "async_tool",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
"status": "finished",
|
||||
"description": "This is the final result for the asynchronous task associated with this tool_call_id. "
|
||||
+ "The task has completed. No further results will arrive for this tool_call_id.",
|
||||
"result": result,
|
||||
}
|
||||
),
|
||||
}
|
||||
async_tool_messages.build_final_result_message(frame.tool_call_id, result)
|
||||
)
|
||||
else:
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
|
||||
@@ -49,6 +49,7 @@ from pipecat.frames.frames import (
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.nova_sonic.session_continuation import (
|
||||
@@ -620,6 +621,38 @@ class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]):
|
||||
# standard tool-result messages — skip them.
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
continue
|
||||
|
||||
# Async-tool messages live alongside regular tool messages in the
|
||||
# context; detect and route them before the regular logic so we
|
||||
# don't try to send the async-tool envelope JSON as a tool result.
|
||||
async_payload = async_tool_messages.parse_message(message)
|
||||
if async_payload is not None:
|
||||
if async_payload.tool_call_id in self._completed_tool_calls:
|
||||
continue
|
||||
if async_payload.kind == "started":
|
||||
# The provider already issued the tool call and natively
|
||||
# awaits a result; nothing to send for the started marker.
|
||||
continue
|
||||
if async_payload.kind == "intermediate":
|
||||
logger.error(
|
||||
f"{self}: Nova Sonic does not support streamed async "
|
||||
f"tool results; dropping intermediate result for "
|
||||
f"tool_call_id={async_payload.tool_call_id}. Use a "
|
||||
f"non-realtime LLM service if your tool needs to "
|
||||
f"stream intermediate results."
|
||||
)
|
||||
await self.push_error(
|
||||
error_msg="Nova Sonic does not support streamed async tool results.",
|
||||
)
|
||||
continue
|
||||
# kind == "final": deliver via the formal toolResult channel
|
||||
# — same path as a synchronous tool result, just delayed.
|
||||
if send_new_results:
|
||||
await self._send_tool_result(async_payload.tool_call_id, async_payload.result)
|
||||
self._completed_tool_calls.add(async_payload.tool_call_id)
|
||||
continue
|
||||
|
||||
# Look for newly-completed "regular" (as opposed to async-tool) results
|
||||
if message.get("role") == "tool" and message.get("content") not in [
|
||||
"IN_PROGRESS",
|
||||
"CANCELLED",
|
||||
@@ -875,6 +908,8 @@ class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]):
|
||||
if not self._stream or not self._prompt_name:
|
||||
return
|
||||
|
||||
logger.debug(f"Sending tool result to Nova Sonic for tool_call_id={tool_call_id}")
|
||||
|
||||
content_name = str(uuid.uuid4())
|
||||
|
||||
result_content_start = f'''
|
||||
|
||||
@@ -649,7 +649,12 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
|
||||
interruption occurs. When ``False`` the call is treated as
|
||||
asynchronous: the LLM continues the conversation immediately
|
||||
without waiting for the result, and the result is injected later
|
||||
via a developer message. Defaults to True.
|
||||
via a developer message. Defaults to True. Note: realtime
|
||||
LLM services deliver only the final result to the provider;
|
||||
intermediate streamed results (reported via
|
||||
``FunctionCallResultProperties(is_final=False)``) are
|
||||
dropped and an error is raised. Use a non-realtime LLM
|
||||
service if your tool needs to stream intermediate results.
|
||||
timeout_secs: Optional per-tool timeout in seconds. Overrides the global
|
||||
``function_call_timeout_secs`` for this specific function. Defaults to
|
||||
None, which uses the global timeout.
|
||||
@@ -687,7 +692,12 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]
|
||||
interruption occurs. When ``False`` the call is treated as
|
||||
asynchronous: the LLM continues the conversation immediately
|
||||
without waiting for the result, and the result is injected later
|
||||
via a developer message. Defaults to True.
|
||||
via a developer message. Defaults to True. Note: realtime
|
||||
LLM services deliver only the final result to the provider;
|
||||
intermediate streamed results (reported via
|
||||
``FunctionCallResultProperties(is_final=False)``) are
|
||||
dropped and an error is raised. Use a non-realtime LLM
|
||||
service if your tool needs to stream intermediate results.
|
||||
timeout_secs: Optional per-tool timeout in seconds. Overrides the global
|
||||
``function_call_timeout_secs`` for this specific function. Defaults to
|
||||
None, which uses the global timeout.
|
||||
|
||||
@@ -48,7 +48,8 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.settings import (
|
||||
@@ -1039,6 +1040,43 @@ class OpenAIRealtimeLLMService(LLMService[OpenAIRealtimeLLMAdapter]):
|
||||
# Check for set of completed function calls in the context
|
||||
sent_new_result = False
|
||||
for message in self._context.get_messages():
|
||||
# LLMSpecificMessages are opaque provider-specific payloads, not
|
||||
# standard tool-result messages — skip them.
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
continue
|
||||
|
||||
# Async-tool messages live alongside regular tool messages in the
|
||||
# context; detect and route them before the regular logic so we
|
||||
# don't try to send the async-tool envelope JSON as a tool result.
|
||||
async_payload = async_tool_messages.parse_message(message)
|
||||
if async_payload is not None:
|
||||
if async_payload.tool_call_id in self._completed_tool_calls:
|
||||
continue
|
||||
if async_payload.kind == "started":
|
||||
# The provider already issued the tool call and natively
|
||||
# awaits a result; nothing to send for the started marker.
|
||||
continue
|
||||
if async_payload.kind == "intermediate":
|
||||
logger.error(
|
||||
f"{self}: OpenAI Realtime does not support streamed async "
|
||||
f"tool results; dropping intermediate result for "
|
||||
f"tool_call_id={async_payload.tool_call_id}. Use a "
|
||||
f"non-realtime LLM service if your tool needs to "
|
||||
f"stream intermediate results."
|
||||
)
|
||||
await self.push_error(
|
||||
error_msg="OpenAI Realtime does not support streamed async tool results.",
|
||||
)
|
||||
continue
|
||||
# kind == "final": deliver via the formal tool-result channel
|
||||
# — same path as a synchronous tool result, just delayed.
|
||||
if send_new_results:
|
||||
sent_new_result = True
|
||||
await self._send_tool_result(async_payload.tool_call_id, async_payload.result)
|
||||
self._completed_tool_calls.add(async_payload.tool_call_id)
|
||||
continue
|
||||
|
||||
# Look for newly-completed "regular" (as opposed to async-tool) results
|
||||
if message.get("role") and message.get("content") != "IN_PROGRESS":
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if tool_call_id and tool_call_id not in self._completed_tool_calls:
|
||||
@@ -1101,6 +1139,7 @@ class OpenAIRealtimeLLMService(LLMService[OpenAIRealtimeLLMAdapter]):
|
||||
await self.push_error(error_msg=f"Send error: {e}")
|
||||
|
||||
async def _send_tool_result(self, tool_call_id: str, result: str):
|
||||
logger.debug(f"Sending tool result to OpenAI Realtime for tool_call_id={tool_call_id}")
|
||||
item = events.ConversationItem(
|
||||
type="function_call_output",
|
||||
call_id=tool_call_id,
|
||||
|
||||
236
tests/test_async_tool_messages.py
Normal file
236
tests/test_async_tool_messages.py
Normal file
@@ -0,0 +1,236 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
|
||||
# The parser tests intentionally exercise the parser via the canonical
|
||||
# builders, so a drift between the two sides will surface as a parse failure
|
||||
# in CI rather than as a silent contract break in production.
|
||||
|
||||
|
||||
def _started_message(tool_call_id: str = "call_123") -> dict:
|
||||
return async_tool_messages.build_started_message(tool_call_id)
|
||||
|
||||
|
||||
def _intermediate_message(
|
||||
tool_call_id: str = "call_123",
|
||||
result: str = '"intermediate-1"',
|
||||
) -> dict:
|
||||
return async_tool_messages.build_intermediate_result_message(tool_call_id, result)
|
||||
|
||||
|
||||
def _final_message(
|
||||
tool_call_id: str = "call_123",
|
||||
result: str = '"final-result"',
|
||||
) -> dict:
|
||||
return async_tool_messages.build_final_result_message(tool_call_id, result)
|
||||
|
||||
|
||||
class TestParseMessage(unittest.TestCase):
|
||||
def test_parses_started(self):
|
||||
info = async_tool_messages.parse_message(_started_message("abc"))
|
||||
assert info is not None
|
||||
assert info.kind == "started"
|
||||
assert info.tool_call_id == "abc"
|
||||
assert info.status == "running"
|
||||
assert info.result is None
|
||||
assert "asynchronous task" in info.description
|
||||
|
||||
def test_parses_intermediate(self):
|
||||
info = async_tool_messages.parse_message(_intermediate_message("abc", '"hello"'))
|
||||
assert info is not None
|
||||
assert info.kind == "intermediate"
|
||||
assert info.tool_call_id == "abc"
|
||||
assert info.status == "running"
|
||||
assert info.result == '"hello"'
|
||||
|
||||
def test_parses_final(self):
|
||||
info = async_tool_messages.parse_message(_final_message("abc", '"done"'))
|
||||
assert info is not None
|
||||
assert info.kind == "final"
|
||||
assert info.tool_call_id == "abc"
|
||||
assert info.status == "finished"
|
||||
assert info.result == '"done"'
|
||||
|
||||
def test_parses_completed_sentinel_result(self):
|
||||
# When a function returns no value, the aggregator sets the result to
|
||||
# the literal "COMPLETED" — same convention used for synchronous tool
|
||||
# calls. The parser doesn't treat it specially; it's just a string.
|
||||
info = async_tool_messages.parse_message(_final_message("abc", "COMPLETED"))
|
||||
assert info is not None
|
||||
assert info.kind == "final"
|
||||
assert info.result == "COMPLETED"
|
||||
|
||||
def test_returns_none_for_regular_user_message(self):
|
||||
assert async_tool_messages.parse_message({"role": "user", "content": "hello"}) is None
|
||||
|
||||
def test_returns_none_for_regular_assistant_message(self):
|
||||
assert async_tool_messages.parse_message({"role": "assistant", "content": "hi"}) is None
|
||||
|
||||
def test_returns_none_for_regular_tool_message(self):
|
||||
# IN_PROGRESS / regular tool result string content.
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{"role": "tool", "tool_call_id": "x", "content": "IN_PROGRESS"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{"role": "tool", "tool_call_id": "x", "content": "weather: sunny"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_developer_message_without_payload(self):
|
||||
# role=developer is also used for non-async-tool things (potentially).
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{"role": "developer", "content": "some other developer note"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_invalid_json_content(self):
|
||||
assert async_tool_messages.parse_message({"role": "tool", "content": "{not json"}) is None
|
||||
|
||||
def test_returns_none_for_non_dict_json(self):
|
||||
assert async_tool_messages.parse_message({"role": "tool", "content": "[1, 2, 3]"}) is None
|
||||
|
||||
def test_returns_none_for_wrong_payload_type(self):
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({"type": "something_else", "tool_call_id": "x"}),
|
||||
}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_when_tool_call_id_missing(self):
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({"type": "async_tool", "status": "running"}),
|
||||
}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_when_status_invalid(self):
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps(
|
||||
{"type": "async_tool", "tool_call_id": "x", "status": "weird"}
|
||||
),
|
||||
}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_non_string_content(self):
|
||||
# A multimodal message with content as a list would not be an async-tool message.
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{"role": "tool", "content": [{"type": "text", "text": "hi"}]}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_missing_role(self):
|
||||
assert async_tool_messages.parse_message({"content": "{}"}) is None
|
||||
|
||||
|
||||
class TestBuilders(unittest.TestCase):
|
||||
"""Verify the builders produce the canonical payload shape and round-trip cleanly."""
|
||||
|
||||
def test_started_message_shape(self):
|
||||
msg = async_tool_messages.build_started_message("call_42")
|
||||
# Top-level: role=tool plus the tool_call_id (so the message can sit
|
||||
# alongside other regular tool messages in the context).
|
||||
assert msg["role"] == "tool"
|
||||
assert msg["tool_call_id"] == "call_42"
|
||||
payload = json.loads(msg["content"])
|
||||
assert payload["type"] == "async_tool"
|
||||
assert payload["status"] == "running"
|
||||
assert payload["tool_call_id"] == "call_42"
|
||||
assert "result" not in payload
|
||||
assert isinstance(payload["description"], str) and payload["description"]
|
||||
|
||||
def test_intermediate_message_shape(self):
|
||||
msg = async_tool_messages.build_intermediate_result_message("call_99", '"step-1"')
|
||||
# Intermediate/final use role=developer and don't carry tool_call_id at
|
||||
# the top level (that's only inside the payload).
|
||||
assert msg["role"] == "developer"
|
||||
assert "tool_call_id" not in msg
|
||||
payload = json.loads(msg["content"])
|
||||
assert payload["type"] == "async_tool"
|
||||
assert payload["status"] == "running"
|
||||
assert payload["tool_call_id"] == "call_99"
|
||||
assert payload["result"] == '"step-1"'
|
||||
assert isinstance(payload["description"], str) and payload["description"]
|
||||
|
||||
def test_final_message_shape(self):
|
||||
msg = async_tool_messages.build_final_result_message("call_7", '"all-done"')
|
||||
assert msg["role"] == "developer"
|
||||
assert "tool_call_id" not in msg
|
||||
payload = json.loads(msg["content"])
|
||||
assert payload["type"] == "async_tool"
|
||||
assert payload["status"] == "finished"
|
||||
assert payload["tool_call_id"] == "call_7"
|
||||
assert payload["result"] == '"all-done"'
|
||||
assert isinstance(payload["description"], str) and payload["description"]
|
||||
|
||||
def test_final_message_with_completed_sentinel(self):
|
||||
# The aggregator passes the literal "COMPLETED" string when the
|
||||
# function returned no value (same convention as for synchronous
|
||||
# tool calls). The builder doesn't treat it specially; it just
|
||||
# round-trips as the result.
|
||||
msg = async_tool_messages.build_final_result_message("call_1", "COMPLETED")
|
||||
payload = json.loads(msg["content"])
|
||||
assert payload["result"] == "COMPLETED"
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
assert info.kind == "final"
|
||||
assert info.result == "COMPLETED"
|
||||
|
||||
def test_started_round_trip(self):
|
||||
msg = async_tool_messages.build_started_message("call_x")
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
assert info.kind == "started"
|
||||
assert info.tool_call_id == "call_x"
|
||||
assert info.status == "running"
|
||||
assert info.result is None
|
||||
|
||||
def test_intermediate_round_trip(self):
|
||||
msg = async_tool_messages.build_intermediate_result_message("call_x", '{"step": 1}')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
assert info.kind == "intermediate"
|
||||
assert info.tool_call_id == "call_x"
|
||||
assert info.status == "running"
|
||||
assert info.result == '{"step": 1}'
|
||||
|
||||
def test_final_round_trip(self):
|
||||
msg = async_tool_messages.build_final_result_message("call_x", '{"answer": 42}')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
assert info.kind == "final"
|
||||
assert info.tool_call_id == "call_x"
|
||||
assert info.status == "finished"
|
||||
assert info.result == '{"answer": 42}'
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user