Compare commits
3 Commits
filipi/sma
...
pk/nova-so
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a8cd5cee5 | ||
|
|
1bb0dc1d4f | ||
|
|
7d3726a74b |
1
changelog/4425.added.md
Normal file
1
changelog/4425.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added support to `AWSNovaSonicLLMService` for the new "async tool call" mechanism activated by `cancel_on_interruption=False`, which includes delivering results asynchronously, delivering result streams, and cancelling running async tools. Support for the other major realtime services (`GeminiLiveLLMService`, `OpenAIRealtimeLLMService`) will be added in a follow-up PR.
|
||||
1
changelog/4425.fixed.md
Normal file
1
changelog/4425.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed a regression in `AWSNovaSonicLLMService` where `cancel_on_interruption=False` (which previously worked under the old async-tool-call mechanism, by simply avoiding discarding tool calls on interruptions) stopped working after the introduction of the new "async tool call" mechanism.
|
||||
@@ -74,6 +74,7 @@ async def track_current_location(params: FunctionCallParams):
|
||||
|
||||
# Second update: revised city estimate.
|
||||
await asyncio.sleep(10)
|
||||
# await asyncio.sleep(20)
|
||||
gps = {"lat": 33.96003, "lng": -118.40639}
|
||||
await params.result_callback(
|
||||
{"gps": gps, "city": "Los Angeles"},
|
||||
@@ -82,6 +83,7 @@ async def track_current_location(params: FunctionCallParams):
|
||||
|
||||
# Final result: confirmed city.
|
||||
await asyncio.sleep(10)
|
||||
# await asyncio.sleep(20)
|
||||
gps = {"lat": 32.743569, "lng": -117.20466}
|
||||
await params.result_callback({"gps": gps, "city": "San Diego"})
|
||||
|
||||
|
||||
182
examples/realtime/realtime-aws-nova-sonic-async-stream-tool.py
Normal file
182
examples/realtime/realtime-aws-nova-sonic-async-stream-tool.py
Normal file
@@ -0,0 +1,182 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example: streaming async function call with the AWS Nova Sonic LLM service.
|
||||
|
||||
The ``track_current_location`` tool simulates a GPS tracker reporting the
|
||||
device's position during a road trip from San Francisco to San Diego. It
|
||||
sends two intermediate updates (via ``params.result_callback`` with
|
||||
``is_final=False``) as the vehicle passes through cities along the way, then
|
||||
delivers the final destination.
|
||||
|
||||
The placeholder is sent as a formal Nova Sonic ``toolResult``; each
|
||||
intermediate result is forwarded as a cross-modal user-role text input event
|
||||
so the model can fold each update into its next turn.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import FunctionCallResultProperties, LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws.nova_sonic.llm import AWSNovaSonicLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def track_current_location(params: FunctionCallParams):
|
||||
"""Simulate a GPS tracker reporting position during a road trip."""
|
||||
gps = {"lat": 37.7310, "lng": -122.4527}
|
||||
await params.result_callback(
|
||||
{"gps": gps, "city": "San Francisco"},
|
||||
properties=FunctionCallResultProperties(is_final=False),
|
||||
)
|
||||
|
||||
await asyncio.sleep(10)
|
||||
gps = {"lat": 33.96003, "lng": -118.40639}
|
||||
await params.result_callback(
|
||||
{"gps": gps, "city": "Los Angeles"},
|
||||
properties=FunctionCallResultProperties(is_final=False),
|
||||
)
|
||||
|
||||
await asyncio.sleep(10)
|
||||
gps = {"lat": 32.743569, "lng": -117.20466}
|
||||
await params.result_callback({"gps": gps, "city": "San Diego"})
|
||||
|
||||
|
||||
location_function = FunctionSchema(
|
||||
name="track_current_location",
|
||||
description=(
|
||||
"Start tracking the user's current GPS location, reporting position "
|
||||
"updates until the user reaches their destination. "
|
||||
"Once this tracker is started, it doesn't need to be started again for subsequent updates; "
|
||||
"just call this function once to kick it off and the updates will come in automatically."
|
||||
),
|
||||
properties={},
|
||||
required=[],
|
||||
)
|
||||
|
||||
tools = ToolsSchema(standard_tools=[location_function])
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
system_instruction = (
|
||||
"You are a friendly assistant. The user and you will engage in a spoken "
|
||||
"dialog exchanging the transcripts of a natural real-time conversation. "
|
||||
"Keep your responses short, generally two or three sentences for chatty "
|
||||
"scenarios. You have access to a function that starts tracking the user's "
|
||||
"location and provides regular updates on it. Narrate each position "
|
||||
"update to the user as it arrives (city only, no coordinates). "
|
||||
"When you receive the final location, tell the user the destination has "
|
||||
"been reached."
|
||||
)
|
||||
|
||||
llm = AWSNovaSonicLLMService(
|
||||
secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
|
||||
access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
|
||||
region=os.environ["AWS_REGION"],
|
||||
session_token=os.getenv("AWS_SESSION_TOKEN"),
|
||||
settings=AWSNovaSonicLLMService.Settings(
|
||||
voice="tiffany",
|
||||
system_instruction=system_instruction,
|
||||
),
|
||||
)
|
||||
|
||||
llm.register_function(
|
||||
"track_current_location",
|
||||
track_current_location,
|
||||
cancel_on_interruption=False,
|
||||
)
|
||||
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
user_aggregator,
|
||||
llm,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
context.add_message(
|
||||
{"role": "developer", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
183
examples/realtime/realtime-aws-nova-sonic-async-tool.py
Normal file
183
examples/realtime/realtime-aws-nova-sonic-async-tool.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example: async function call with the AWS Nova Sonic LLM service.
|
||||
|
||||
The ``get_current_weather`` tool is registered with
|
||||
``cancel_on_interruption=False`` and simulates a slow API call (20s sleep).
|
||||
While the call is in flight the conversation continues; the result arrives
|
||||
later via the async-tool mechanism and is forwarded to Nova Sonic so the
|
||||
model can integrate it naturally into its next turn.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws.nova_sonic.llm import AWSNovaSonicLLMService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
# Simulate a long-running API call so we can demonstrate that the
|
||||
# conversation continues while the tool is in flight.
|
||||
await asyncio.sleep(20)
|
||||
temperature = (
|
||||
random.randint(60, 85)
|
||||
if params.arguments["format"] == "fahrenheit"
|
||||
else random.randint(15, 30)
|
||||
)
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"location": params.arguments["location"],
|
||||
"format": params.arguments["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
system_instruction = (
|
||||
"You are a friendly assistant. The user and you will engage in a spoken "
|
||||
"dialog exchanging the transcripts of a natural real-time conversation. "
|
||||
"Keep your responses short, generally two or three sentences for chatty "
|
||||
"scenarios. When the user asks for the weather, call get_current_weather. "
|
||||
"While you wait for the result, keep chatting with the user. When the "
|
||||
"result arrives, share it with the user naturally."
|
||||
)
|
||||
|
||||
llm = AWSNovaSonicLLMService(
|
||||
secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
|
||||
access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
|
||||
region=os.environ["AWS_REGION"],
|
||||
session_token=os.getenv("AWS_SESSION_TOKEN"),
|
||||
settings=AWSNovaSonicLLMService.Settings(
|
||||
voice="tiffany",
|
||||
system_instruction=system_instruction,
|
||||
),
|
||||
)
|
||||
|
||||
llm.register_function(
|
||||
"get_current_weather",
|
||||
fetch_weather_from_api,
|
||||
cancel_on_interruption=False,
|
||||
)
|
||||
|
||||
context = LLMContext(tools=tools)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
user_aggregator,
|
||||
llm,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
context.add_message(
|
||||
{"role": "developer", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -5,7 +5,6 @@
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
@@ -46,11 +45,6 @@ async def fetch_weather_from_api(params: FunctionCallParams):
|
||||
if params.arguments["format"] == "fahrenheit"
|
||||
else random.randint(15, 30)
|
||||
)
|
||||
# Simulate a long network delay.
|
||||
# You can continue chatting while waiting for this to complete.
|
||||
# With Nova 2 Sonic (the default model), the assistant will respond
|
||||
# appropriately once the function call is complete.
|
||||
await asyncio.sleep(5)
|
||||
await params.result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
@@ -150,9 +144,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Register function for function calls
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function(
|
||||
"get_current_weather", fetch_weather_from_api, cancel_on_interruption=False
|
||||
)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
# Set up context and context management.
|
||||
context = LLMContext(tools=tools)
|
||||
|
||||
370
src/pipecat/processors/aggregators/async_tool_messages.py
Normal file
370
src/pipecat/processors/aggregators/async_tool_messages.py
Normal file
@@ -0,0 +1,370 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Helpers for the async-tool message protocol used in LLM contexts.
|
||||
|
||||
When a function is registered with ``cancel_on_interruption=False``, the
|
||||
``LLMUserContextAggregator`` / ``LLMAssistantContextAggregator`` pair appends
|
||||
async-tool messages to the conversation context as the underlying task
|
||||
progresses:
|
||||
|
||||
- A ``started`` message (``role="tool"``) is appended immediately when the
|
||||
tool starts running.
|
||||
- An ``intermediate`` message (``role="developer"``) is appended each time an
|
||||
intermediate result is reported via
|
||||
``result_callback(..., FunctionCallResultProperties(is_final=False))``.
|
||||
- A ``final`` message (``role="developer"``) is appended when the task
|
||||
finishes.
|
||||
|
||||
This module is the single source of truth for the on-the-wire payload shape:
|
||||
|
||||
- The aggregator uses the ``build_*_message`` functions when injecting messages.
|
||||
- Realtime LLM services use ``parse_message`` when scanning the context for
|
||||
async-tool messages to forward to their providers, then
|
||||
``prepare_message_payload_for_realtime`` to produce a wire-ready string.
|
||||
|
||||
Internally, ``AsyncToolMessagePayload`` is the canonical structured form;
|
||||
the on-the-wire JSON string is always derived from it (never stored) so the
|
||||
two representations can't drift.
|
||||
|
||||
Consumers are expected to import the module rather than its individual
|
||||
functions, e.g.::
|
||||
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
...
|
||||
async_tool_messages.build_started_message(tool_call_id)
|
||||
async_tool_messages.parse_message(msg)
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from pipecat.processors.aggregators.llm_context import LLMStandardMessage
|
||||
|
||||
AsyncToolMessageKind = Literal["started", "intermediate", "final"]
|
||||
|
||||
# --- Payload shape (private; canonical source of truth) ---------------------
|
||||
|
||||
# The ``type`` field that identifies an async-tool message payload. Both the
|
||||
# builders and the parser use this constant; do not duplicate the literal.
|
||||
_PAYLOAD_TYPE = "async_tool"
|
||||
|
||||
# Status value for started / intermediate messages (task still running).
|
||||
_STATUS_RUNNING = "running"
|
||||
|
||||
# Status value for the final message (task complete).
|
||||
_STATUS_FINISHED = "finished"
|
||||
|
||||
# Description shipped on the started message. The text is intentionally
|
||||
# self-explanatory so a model reading the context can tell what's about to
|
||||
# happen even without out-of-band knowledge of the protocol.
|
||||
_STARTED_DESCRIPTION = (
|
||||
"An asynchronous task associated with this tool_call_id has started "
|
||||
"running. Expect results to arrive later as developer messages that look "
|
||||
"roughly like this one (with 'type=async_tool' and a matching tool_call_id) "
|
||||
"but with a 'result' field. Note that there *may* be more than one result "
|
||||
"(i.e., a stream of results), but there doesn't have to be (there may be "
|
||||
"only one). The last result will come in a message with 'status=finished'."
|
||||
)
|
||||
|
||||
# Description shipped on each intermediate-result message.
|
||||
_INTERMEDIATE_DESCRIPTION = (
|
||||
"This is an intermediate result for the asynchronous task associated with "
|
||||
"this tool_call_id. The task is still running. More intermediate results "
|
||||
"may follow, or the next result may be the final one with "
|
||||
"'status=finished'."
|
||||
)
|
||||
|
||||
# Description shipped on the final-result message.
|
||||
_FINAL_DESCRIPTION = (
|
||||
"This is the final result for the asynchronous task associated with this "
|
||||
"tool_call_id. The task has completed. No further results will arrive for "
|
||||
"this tool_call_id."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AsyncToolMessagePayload:
|
||||
"""The structured contents of an async-tool message in an LLM context.
|
||||
|
||||
Parameters:
|
||||
kind: Which of the three async-tool message stages this is.
|
||||
tool_call_id: The id of the tool invocation this payload relates to.
|
||||
status: ``"running"`` for started/intermediate, ``"finished"`` for
|
||||
the final message.
|
||||
description: Human-readable description from the payload. May be empty.
|
||||
result: For ``intermediate`` and ``final`` messages, the JSON-encoded
|
||||
result string (or the literal ``"COMPLETED"`` if the function
|
||||
returned no value). ``None`` for ``started`` messages.
|
||||
"""
|
||||
|
||||
kind: AsyncToolMessageKind
|
||||
tool_call_id: str
|
||||
status: Literal["running", "finished"]
|
||||
description: str
|
||||
result: str | None
|
||||
|
||||
|
||||
# --- Internal: payload ↔ on-the-wire forms -----------------------------------
|
||||
|
||||
|
||||
def _payload_to_json(payload: AsyncToolMessagePayload) -> str:
|
||||
"""Serialize a payload to its on-the-wire JSON string form.
|
||||
|
||||
Fields that don't apply to the payload's kind are omitted (notably
|
||||
``result`` is left out of ``started`` payloads, since the task hasn't
|
||||
produced a result yet).
|
||||
"""
|
||||
obj: dict[str, Any] = {
|
||||
"type": _PAYLOAD_TYPE,
|
||||
"status": payload.status,
|
||||
"tool_call_id": payload.tool_call_id,
|
||||
"description": payload.description,
|
||||
}
|
||||
if payload.result is not None:
|
||||
obj["result"] = payload.result
|
||||
return json.dumps(obj)
|
||||
|
||||
|
||||
def _payload_to_message(payload: AsyncToolMessagePayload) -> LLMStandardMessage:
|
||||
"""Wrap a payload in the LLM context message shape that matches its kind.
|
||||
|
||||
- ``started``: ``role="tool"`` plus ``tool_call_id`` at the top level
|
||||
(so the message can sit alongside other regular tool-result messages).
|
||||
- ``intermediate`` / ``final``: ``role="developer"``; ``tool_call_id``
|
||||
lives only inside the JSON payload.
|
||||
"""
|
||||
content = _payload_to_json(payload)
|
||||
if payload.kind == "started":
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
"tool_call_id": payload.tool_call_id,
|
||||
}
|
||||
return {
|
||||
"role": "developer",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
# --- Builders ----------------------------------------------------------------
|
||||
|
||||
|
||||
def build_started_message(tool_call_id: str) -> LLMStandardMessage:
|
||||
"""Build a ``started`` async-tool message for an LLM context.
|
||||
|
||||
Append the returned message to the LLM context immediately when an async
|
||||
function call (registered with ``cancel_on_interruption=False``) starts
|
||||
running. The message lets the model know a task is in flight and that its
|
||||
results will arrive later in subsequent ``developer``-role messages.
|
||||
|
||||
Args:
|
||||
tool_call_id: The id of the tool invocation this message is for.
|
||||
|
||||
Returns:
|
||||
A message ready to pass to ``LLMContext.add_message``.
|
||||
"""
|
||||
return _payload_to_message(
|
||||
AsyncToolMessagePayload(
|
||||
kind="started",
|
||||
tool_call_id=tool_call_id,
|
||||
status=_STATUS_RUNNING,
|
||||
description=_STARTED_DESCRIPTION,
|
||||
result=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_intermediate_result_message(tool_call_id: str, result: str) -> LLMStandardMessage:
|
||||
"""Build an intermediate-result async-tool message for an LLM context.
|
||||
|
||||
Append the returned message to the LLM context each time the running async
|
||||
function reports a non-final result via
|
||||
``result_callback(..., FunctionCallResultProperties(is_final=False))``.
|
||||
|
||||
Args:
|
||||
tool_call_id: The id of the tool invocation the result is for.
|
||||
result: The JSON-encoded result string (caller is responsible for
|
||||
encoding the function's actual return value, typically via
|
||||
``json.dumps``).
|
||||
|
||||
Returns:
|
||||
A message ready to pass to ``LLMContext.add_message``.
|
||||
"""
|
||||
return _payload_to_message(
|
||||
AsyncToolMessagePayload(
|
||||
kind="intermediate",
|
||||
tool_call_id=tool_call_id,
|
||||
status=_STATUS_RUNNING,
|
||||
description=_INTERMEDIATE_DESCRIPTION,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_final_result_message(tool_call_id: str, result: str) -> LLMStandardMessage:
|
||||
"""Build a final-result async-tool message for an LLM context.
|
||||
|
||||
Append the returned message to the LLM context when the async function
|
||||
finishes. After this message no further async-tool messages will arrive
|
||||
for this ``tool_call_id``.
|
||||
|
||||
Args:
|
||||
tool_call_id: The id of the tool invocation the result is for.
|
||||
result: The JSON-encoded result string, or the literal ``"COMPLETED"``
|
||||
sentinel when the function returned ``None`` (matching the same
|
||||
convention used for synchronous tool calls).
|
||||
|
||||
Returns:
|
||||
A message ready to pass to ``LLMContext.add_message``.
|
||||
"""
|
||||
return _payload_to_message(
|
||||
AsyncToolMessagePayload(
|
||||
kind="final",
|
||||
tool_call_id=tool_call_id,
|
||||
status=_STATUS_FINISHED,
|
||||
description=_FINAL_DESCRIPTION,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# --- Parsing -----------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_message(message: LLMStandardMessage) -> AsyncToolMessagePayload | None:
|
||||
"""Decode an async-tool message payload, or return None if not async-tool.
|
||||
|
||||
Args:
|
||||
message: A standard message from the LLM context. Callers iterating
|
||||
over ``LLMContext.get_messages()`` should filter out
|
||||
``LLMSpecificMessage`` entries first; only ``LLMStandardMessage``
|
||||
values can carry async-tool payloads.
|
||||
|
||||
Returns:
|
||||
An ``AsyncToolMessagePayload`` if the message is a recognized
|
||||
async-tool payload, otherwise ``None``.
|
||||
"""
|
||||
role = message.get("role")
|
||||
if role not in ("tool", "developer"):
|
||||
return None
|
||||
content = message.get("content")
|
||||
if not isinstance(content, str):
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return None
|
||||
if not isinstance(payload, dict) or payload.get("type") != _PAYLOAD_TYPE:
|
||||
return None
|
||||
tool_call_id = payload.get("tool_call_id")
|
||||
status = payload.get("status")
|
||||
if not isinstance(tool_call_id, str) or status not in (_STATUS_RUNNING, _STATUS_FINISHED):
|
||||
return None
|
||||
description = payload.get("description", "")
|
||||
if not isinstance(description, str):
|
||||
description = ""
|
||||
result = payload.get("result")
|
||||
if result is not None and not isinstance(result, str):
|
||||
result = None
|
||||
if result is None:
|
||||
kind: AsyncToolMessageKind = "started"
|
||||
elif status == _STATUS_FINISHED:
|
||||
kind = "final"
|
||||
else:
|
||||
kind = "intermediate"
|
||||
return AsyncToolMessagePayload(
|
||||
kind=kind,
|
||||
tool_call_id=tool_call_id,
|
||||
status=status,
|
||||
description=description,
|
||||
result=result,
|
||||
)
|
||||
|
||||
|
||||
# --- Realtime preparation ----------------------------------------------------
|
||||
|
||||
|
||||
def prepare_message_payload_for_realtime(
|
||||
payload: AsyncToolMessagePayload,
|
||||
*,
|
||||
template: str | None = None,
|
||||
) -> str:
|
||||
"""Prepare an async-tool message payload for sending to a realtime LLM service.
|
||||
|
||||
Realtime services that fully honor the async-tool mechanism send the
|
||||
``started`` payload via the formal tool-result channel and the subsequent
|
||||
``intermediate`` / ``final`` payloads as text injected mid-conversation;
|
||||
this function returns the string to send in either case, and callers
|
||||
route it to the appropriate channel.
|
||||
|
||||
Args:
|
||||
payload: The parsed async-tool message payload.
|
||||
template: Optional format string. If provided, the rendered output is
|
||||
``template.format(tool_call_id=…, status=…, result=…, description=…)``.
|
||||
If ``None``, the payload is serialized to its canonical JSON
|
||||
form. Per-kind helpers ultimately decide what to do with the
|
||||
template, so future per-kind tweaks (e.g. raising for a kind
|
||||
that shouldn't accept templates) can be added without changing
|
||||
this signature.
|
||||
|
||||
Returns:
|
||||
The prepared string, ready to be sent to the realtime service.
|
||||
"""
|
||||
if payload.kind == "started":
|
||||
return _prepare_started_message_payload_for_realtime(payload, template=template)
|
||||
if payload.kind == "intermediate":
|
||||
return _prepare_intermediate_result_message_payload_for_realtime(payload, template=template)
|
||||
if payload.kind == "final":
|
||||
return _prepare_final_result_message_payload_for_realtime(payload, template=template)
|
||||
raise ValueError(f"Unknown async-tool message payload kind: {payload.kind!r}")
|
||||
|
||||
|
||||
def _prepare_started_message_payload_for_realtime(
|
||||
payload: AsyncToolMessagePayload,
|
||||
*,
|
||||
template: str | None = None,
|
||||
) -> str:
|
||||
if template is None:
|
||||
return _payload_to_json(payload)
|
||||
return _format_with_template(payload, template)
|
||||
|
||||
|
||||
def _prepare_intermediate_result_message_payload_for_realtime(
|
||||
payload: AsyncToolMessagePayload,
|
||||
*,
|
||||
template: str | None = None,
|
||||
) -> str:
|
||||
if template is None:
|
||||
return _payload_to_json(payload)
|
||||
return _format_with_template(payload, template)
|
||||
|
||||
|
||||
def _prepare_final_result_message_payload_for_realtime(
|
||||
payload: AsyncToolMessagePayload,
|
||||
*,
|
||||
template: str | None = None,
|
||||
) -> str:
|
||||
if template is None:
|
||||
return _payload_to_json(payload)
|
||||
return _format_with_template(payload, template)
|
||||
|
||||
|
||||
def _format_with_template(payload: AsyncToolMessagePayload, template: str) -> str:
|
||||
"""Render a payload via a caller-supplied template.
|
||||
|
||||
Available substitution keys: ``tool_call_id``, ``status``, ``result``,
|
||||
``description``. Note that ``result`` is empty for ``started`` payloads
|
||||
(no result has been produced yet); callers building templates intended
|
||||
for ``started`` should not rely on it.
|
||||
"""
|
||||
return template.format(
|
||||
tool_call_id=payload.tool_call_id,
|
||||
status=payload.status,
|
||||
result=payload.result or "",
|
||||
description=payload.description,
|
||||
)
|
||||
@@ -67,6 +67,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMContextMessage,
|
||||
@@ -1075,23 +1076,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
is_async = not frame.cancel_on_interruption
|
||||
if is_async:
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"type": "async_tool",
|
||||
"status": "running",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
"description": "An asynchronous task associated with this tool_call_id has started running. "
|
||||
+ "Expect results to arrive later as developer messages that look roughly like this one (with 'type=async_tool' and a matching tool_call_id) but with a 'result' field. "
|
||||
+ "Note that there *may* be more than one result (i.e., a stream of results), but there doesn't have to be (there may be only one). "
|
||||
+ "The last result will come in a message with 'status=finished'.",
|
||||
}
|
||||
),
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
self._context.add_message(async_tool_messages.build_started_message(frame.tool_call_id))
|
||||
else:
|
||||
self._context.add_message(
|
||||
{
|
||||
@@ -1204,19 +1189,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
result = json.dumps(frame.result, ensure_ascii=False)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "developer",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"type": "async_tool",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
"status": "running",
|
||||
"description": "This is an intermediate result for the asynchronous task associated with this tool_call_id. "
|
||||
+ "The task is still running. More intermediate results may follow, or the next result may be the final one with 'status=finished'.",
|
||||
"result": result,
|
||||
}
|
||||
),
|
||||
}
|
||||
async_tool_messages.build_intermediate_result_message(frame.tool_call_id, result)
|
||||
)
|
||||
|
||||
async def _handle_function_call_finished(
|
||||
@@ -1237,19 +1210,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
# notified of the completed result instead of updating the IN_PROGRESS
|
||||
# tool message.
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "developer",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"type": "async_tool",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
"status": "finished",
|
||||
"description": "This is the final result for the asynchronous task associated with this tool_call_id. "
|
||||
+ "The task has completed. No further results will arrive for this tool_call_id.",
|
||||
"result": result,
|
||||
}
|
||||
),
|
||||
}
|
||||
async_tool_messages.build_final_result_message(frame.tool_call_id, result)
|
||||
)
|
||||
else:
|
||||
self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
|
||||
@@ -49,6 +49,8 @@ from pipecat.frames.frames import (
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
from pipecat.processors.aggregators.async_tool_messages import AsyncToolMessagePayload
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws.nova_sonic.session_continuation import (
|
||||
@@ -235,6 +237,26 @@ class AWSNovaSonicLLMSettings(LLMSettings):
|
||||
endpointing_sensitivity: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
|
||||
# Bracketed plain-text template Nova Sonic uses when injecting async-tool
|
||||
# result updates onto the cross-modal user-text channel.
|
||||
#
|
||||
# Note that this template intentionally drops the payload's ``description``
|
||||
# field (the protocol-level explanation of what async-tool messages are and
|
||||
# how they work) and only carries ``tool_call_id``, ``status``, and
|
||||
# ``result``. Counterintuitively, this short framing — minus the verbose
|
||||
# protocol description, minus a JSON envelope altogether — empirically
|
||||
# yields much better Nova Sonic behavior: noticeably fewer spurious
|
||||
# re-invocations of the same tool than when the full JSON envelope (with
|
||||
# its description) was injected as text. We don't fully understand why; one
|
||||
# plausible explanation is that the model treats long, instruction-shaped
|
||||
# description text as content demanding a response, where a terse
|
||||
# bracketed status update reads more like ambient state. Worth revisiting
|
||||
# if Nova Sonic's text-channel handling changes.
|
||||
_ASYNC_TOOL_RESULT_TEXT_TEMPLATE = (
|
||||
"[Async tool update for tool_call_id={tool_call_id}, status={status}] {result}"
|
||||
)
|
||||
|
||||
|
||||
class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]):
|
||||
"""AWS Nova Sonic speech-to-speech LLM service.
|
||||
|
||||
@@ -414,6 +436,7 @@ class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]):
|
||||
self._wants_connection = False
|
||||
self._user_text_buffer = ""
|
||||
self._completed_tool_calls = set()
|
||||
self._async_tool_text_dispatched: set[str] = set()
|
||||
self._audio_input_started = False
|
||||
|
||||
# Session continuation helper. The service itself implements the
|
||||
@@ -620,6 +643,13 @@ class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]):
|
||||
# standard tool-result messages — skip them.
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
continue
|
||||
|
||||
async_info = async_tool_messages.parse_message(message)
|
||||
if async_info is not None:
|
||||
# Async-tool message — dispatch per the configured support tier.
|
||||
await self._dispatch_async_tool_message(async_info, send_new_results)
|
||||
continue
|
||||
|
||||
if message.get("role") == "tool" and message.get("content") not in [
|
||||
"IN_PROGRESS",
|
||||
"CANCELLED",
|
||||
@@ -631,6 +661,82 @@ class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]):
|
||||
await self._send_tool_result(tool_call_id, message.get("content"))
|
||||
self._completed_tool_calls.add(tool_call_id)
|
||||
|
||||
async def _dispatch_async_tool_message(
|
||||
self, info: AsyncToolMessagePayload, send_new_results: bool
|
||||
):
|
||||
"""Dispatch an async-tool message to Nova Sonic.
|
||||
|
||||
The ``started`` message is sent as a formal ``toolResult``; subsequent
|
||||
intermediate/final results are injected as cross-modal user-role text
|
||||
input events (supporting streaming async results).
|
||||
"""
|
||||
logger.trace(
|
||||
f"{self}: async_tool dispatch: kind={info.kind} "
|
||||
f"tool_call_id={info.tool_call_id} status={info.status} "
|
||||
f"send_new_results={send_new_results}"
|
||||
)
|
||||
|
||||
if info.kind == "started":
|
||||
if info.tool_call_id in self._completed_tool_calls:
|
||||
logger.trace(
|
||||
f"{self}: async_tool started already sent: tool_call_id={info.tool_call_id}"
|
||||
)
|
||||
return
|
||||
if send_new_results:
|
||||
payload = async_tool_messages.prepare_message_payload_for_realtime(info)
|
||||
logger.debug(
|
||||
f"{self}: async_tool send started as tool result: "
|
||||
f"tool_call_id={info.tool_call_id} payload={payload!r}"
|
||||
)
|
||||
await self._send_tool_result(info.tool_call_id, payload)
|
||||
else:
|
||||
logger.trace(
|
||||
f"{self}: async_tool started mark-handled (no send): "
|
||||
f"tool_call_id={info.tool_call_id}"
|
||||
)
|
||||
self._completed_tool_calls.add(info.tool_call_id)
|
||||
return
|
||||
|
||||
# info.kind in ("intermediate", "final")
|
||||
signature = self._async_tool_message_signature(info)
|
||||
if signature in self._async_tool_text_dispatched:
|
||||
logger.trace(
|
||||
f"{self}: async_tool {info.kind} already dispatched: "
|
||||
f"tool_call_id={info.tool_call_id}"
|
||||
)
|
||||
return
|
||||
if send_new_results:
|
||||
text = async_tool_messages.prepare_message_payload_for_realtime(
|
||||
info, template=_ASYNC_TOOL_RESULT_TEXT_TEMPLATE
|
||||
)
|
||||
logger.debug(
|
||||
f"{self}: async_tool send {info.kind} as text input: "
|
||||
f"tool_call_id={info.tool_call_id} text={text!r}"
|
||||
)
|
||||
await self._send_async_tool_text(text)
|
||||
else:
|
||||
logger.trace(
|
||||
f"{self}: async_tool {info.kind} mark-handled (no send): "
|
||||
f"tool_call_id={info.tool_call_id}"
|
||||
)
|
||||
self._async_tool_text_dispatched.add(signature)
|
||||
|
||||
@staticmethod
|
||||
def _async_tool_message_signature(info: AsyncToolMessagePayload) -> str:
|
||||
return f"{info.tool_call_id}|{info.status}|{info.result or ''}"
|
||||
|
||||
async def _send_async_tool_text(self, text: str):
|
||||
"""Inject mid-conversation text via Nova Sonic's cross-modal user text input.
|
||||
|
||||
Used to forward intermediate/final async-tool results to the provider.
|
||||
Sends a USER-role text content block (contentStart/textInput/contentEnd)
|
||||
with ``interactive=True``, the documented cross-modal pattern for mid-
|
||||
conversation text injection.
|
||||
"""
|
||||
if not self._stream or not self._prompt_name or not text:
|
||||
return
|
||||
await self._send_text_event(text=text, role=Role.USER, interactive=True)
|
||||
|
||||
async def _finish_connecting_if_context_available(self):
|
||||
# We can only finish connecting once we've gotten our initial context and we're ready to
|
||||
# send it
|
||||
@@ -758,6 +864,7 @@ class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]):
|
||||
self._connected_time = None
|
||||
self._user_text_buffer = ""
|
||||
self._completed_tool_calls = set()
|
||||
self._async_tool_text_dispatched = set()
|
||||
self._audio_input_started = False
|
||||
self._pending_speculative_text = None
|
||||
|
||||
|
||||
321
tests/test_async_tool_messages.py
Normal file
321
tests/test_async_tool_messages.py
Normal file
@@ -0,0 +1,321 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
|
||||
# The parser tests intentionally exercise the parser via the canonical
|
||||
# builders, so a drift between the two sides will surface as a parse failure
|
||||
# in CI rather than as a silent contract break in production.
|
||||
|
||||
|
||||
def _started_message(tool_call_id: str = "call_123") -> dict:
|
||||
return async_tool_messages.build_started_message(tool_call_id)
|
||||
|
||||
|
||||
def _intermediate_message(
|
||||
tool_call_id: str = "call_123",
|
||||
result: str = '"intermediate-1"',
|
||||
) -> dict:
|
||||
return async_tool_messages.build_intermediate_result_message(tool_call_id, result)
|
||||
|
||||
|
||||
def _final_message(
|
||||
tool_call_id: str = "call_123",
|
||||
result: str = '"final-result"',
|
||||
) -> dict:
|
||||
return async_tool_messages.build_final_result_message(tool_call_id, result)
|
||||
|
||||
|
||||
class TestParseMessage(unittest.TestCase):
|
||||
def test_parses_started(self):
|
||||
info = async_tool_messages.parse_message(_started_message("abc"))
|
||||
assert info is not None
|
||||
assert info.kind == "started"
|
||||
assert info.tool_call_id == "abc"
|
||||
assert info.status == "running"
|
||||
assert info.result is None
|
||||
assert "asynchronous task" in info.description
|
||||
|
||||
def test_parses_intermediate(self):
|
||||
info = async_tool_messages.parse_message(_intermediate_message("abc", '"hello"'))
|
||||
assert info is not None
|
||||
assert info.kind == "intermediate"
|
||||
assert info.tool_call_id == "abc"
|
||||
assert info.status == "running"
|
||||
assert info.result == '"hello"'
|
||||
|
||||
def test_parses_final(self):
|
||||
info = async_tool_messages.parse_message(_final_message("abc", '"done"'))
|
||||
assert info is not None
|
||||
assert info.kind == "final"
|
||||
assert info.tool_call_id == "abc"
|
||||
assert info.status == "finished"
|
||||
assert info.result == '"done"'
|
||||
|
||||
def test_parses_completed_sentinel_result(self):
|
||||
# When a function returns no value, the aggregator sets the result to
|
||||
# the literal "COMPLETED" — same convention used for synchronous tool
|
||||
# calls. The parser doesn't treat it specially; it's just a string.
|
||||
info = async_tool_messages.parse_message(_final_message("abc", "COMPLETED"))
|
||||
assert info is not None
|
||||
assert info.kind == "final"
|
||||
assert info.result == "COMPLETED"
|
||||
|
||||
def test_returns_none_for_regular_user_message(self):
|
||||
assert async_tool_messages.parse_message({"role": "user", "content": "hello"}) is None
|
||||
|
||||
def test_returns_none_for_regular_assistant_message(self):
|
||||
assert async_tool_messages.parse_message({"role": "assistant", "content": "hi"}) is None
|
||||
|
||||
def test_returns_none_for_regular_tool_message(self):
|
||||
# IN_PROGRESS / regular tool result string content.
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{"role": "tool", "tool_call_id": "x", "content": "IN_PROGRESS"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{"role": "tool", "tool_call_id": "x", "content": "weather: sunny"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_developer_message_without_payload(self):
|
||||
# role=developer is also used for non-async-tool things (potentially).
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{"role": "developer", "content": "some other developer note"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_invalid_json_content(self):
|
||||
assert async_tool_messages.parse_message({"role": "tool", "content": "{not json"}) is None
|
||||
|
||||
def test_returns_none_for_non_dict_json(self):
|
||||
assert async_tool_messages.parse_message({"role": "tool", "content": "[1, 2, 3]"}) is None
|
||||
|
||||
def test_returns_none_for_wrong_payload_type(self):
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({"type": "something_else", "tool_call_id": "x"}),
|
||||
}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_when_tool_call_id_missing(self):
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({"type": "async_tool", "status": "running"}),
|
||||
}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_when_status_invalid(self):
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps(
|
||||
{"type": "async_tool", "tool_call_id": "x", "status": "weird"}
|
||||
),
|
||||
}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_non_string_content(self):
|
||||
# A multimodal message with content as a list would not be an async-tool message.
|
||||
assert (
|
||||
async_tool_messages.parse_message(
|
||||
{"role": "tool", "content": [{"type": "text", "text": "hi"}]}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_missing_role(self):
|
||||
assert async_tool_messages.parse_message({"content": "{}"}) is None
|
||||
|
||||
|
||||
class TestBuilders(unittest.TestCase):
|
||||
"""Verify the builders produce the canonical payload shape and round-trip cleanly."""
|
||||
|
||||
def test_started_message_shape(self):
|
||||
msg = async_tool_messages.build_started_message("call_42")
|
||||
# Top-level: role=tool plus the tool_call_id (so the message can sit
|
||||
# alongside other regular tool messages in the context).
|
||||
assert msg["role"] == "tool"
|
||||
assert msg["tool_call_id"] == "call_42"
|
||||
payload = json.loads(msg["content"])
|
||||
assert payload["type"] == "async_tool"
|
||||
assert payload["status"] == "running"
|
||||
assert payload["tool_call_id"] == "call_42"
|
||||
assert "result" not in payload
|
||||
assert isinstance(payload["description"], str) and payload["description"]
|
||||
|
||||
def test_intermediate_message_shape(self):
|
||||
msg = async_tool_messages.build_intermediate_result_message("call_99", '"step-1"')
|
||||
# Intermediate/final use role=developer and don't carry tool_call_id at
|
||||
# the top level (that's only inside the payload).
|
||||
assert msg["role"] == "developer"
|
||||
assert "tool_call_id" not in msg
|
||||
payload = json.loads(msg["content"])
|
||||
assert payload["type"] == "async_tool"
|
||||
assert payload["status"] == "running"
|
||||
assert payload["tool_call_id"] == "call_99"
|
||||
assert payload["result"] == '"step-1"'
|
||||
assert isinstance(payload["description"], str) and payload["description"]
|
||||
|
||||
def test_final_message_shape(self):
|
||||
msg = async_tool_messages.build_final_result_message("call_7", '"all-done"')
|
||||
assert msg["role"] == "developer"
|
||||
assert "tool_call_id" not in msg
|
||||
payload = json.loads(msg["content"])
|
||||
assert payload["type"] == "async_tool"
|
||||
assert payload["status"] == "finished"
|
||||
assert payload["tool_call_id"] == "call_7"
|
||||
assert payload["result"] == '"all-done"'
|
||||
assert isinstance(payload["description"], str) and payload["description"]
|
||||
|
||||
def test_final_message_with_completed_sentinel(self):
|
||||
# The aggregator passes the literal "COMPLETED" string when the
|
||||
# function returned no value (same convention as for synchronous
|
||||
# tool calls). The builder doesn't treat it specially; it just
|
||||
# round-trips as the result.
|
||||
msg = async_tool_messages.build_final_result_message("call_1", "COMPLETED")
|
||||
payload = json.loads(msg["content"])
|
||||
assert payload["result"] == "COMPLETED"
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
assert info.kind == "final"
|
||||
assert info.result == "COMPLETED"
|
||||
|
||||
def test_started_round_trip(self):
|
||||
msg = async_tool_messages.build_started_message("call_x")
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
assert info.kind == "started"
|
||||
assert info.tool_call_id == "call_x"
|
||||
assert info.status == "running"
|
||||
assert info.result is None
|
||||
|
||||
def test_intermediate_round_trip(self):
|
||||
msg = async_tool_messages.build_intermediate_result_message("call_x", '{"step": 1}')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
assert info.kind == "intermediate"
|
||||
assert info.tool_call_id == "call_x"
|
||||
assert info.status == "running"
|
||||
assert info.result == '{"step": 1}'
|
||||
|
||||
def test_final_round_trip(self):
|
||||
msg = async_tool_messages.build_final_result_message("call_x", '{"answer": 42}')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
assert info.kind == "final"
|
||||
assert info.tool_call_id == "call_x"
|
||||
assert info.status == "finished"
|
||||
assert info.result == '{"answer": 42}'
|
||||
|
||||
|
||||
class TestPrepareMessagePayloadForRealtime(unittest.TestCase):
|
||||
"""Verify the realtime preparation behavior across kinds and template usage."""
|
||||
|
||||
# --- Default (no template) → raw JSON pass-through -----------------------
|
||||
|
||||
def test_started_default_is_raw_json(self):
|
||||
msg = async_tool_messages.build_started_message("call_42")
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
text = async_tool_messages.prepare_message_payload_for_realtime(info)
|
||||
decoded = json.loads(text)
|
||||
assert decoded["type"] == "async_tool"
|
||||
assert decoded["tool_call_id"] == "call_42"
|
||||
assert decoded["status"] == "running"
|
||||
# Started payloads have no result field.
|
||||
assert "result" not in decoded
|
||||
|
||||
def test_intermediate_default_is_raw_json(self):
|
||||
msg = async_tool_messages.build_intermediate_result_message("call_42", '"step-1"')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
text = async_tool_messages.prepare_message_payload_for_realtime(info)
|
||||
decoded = json.loads(text)
|
||||
assert decoded["type"] == "async_tool"
|
||||
assert decoded["tool_call_id"] == "call_42"
|
||||
assert decoded["status"] == "running"
|
||||
assert decoded["result"] == '"step-1"'
|
||||
|
||||
def test_final_default_is_raw_json(self):
|
||||
msg = async_tool_messages.build_final_result_message("call_42", '"the answer"')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
text = async_tool_messages.prepare_message_payload_for_realtime(info)
|
||||
decoded = json.loads(text)
|
||||
assert decoded["type"] == "async_tool"
|
||||
assert decoded["tool_call_id"] == "call_42"
|
||||
assert decoded["status"] == "finished"
|
||||
assert decoded["result"] == '"the answer"'
|
||||
|
||||
# --- Caller-supplied template applied across kinds -----------------------
|
||||
|
||||
def test_template_applied_to_started(self):
|
||||
msg = async_tool_messages.build_started_message("call_42")
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
text = async_tool_messages.prepare_message_payload_for_realtime(
|
||||
info,
|
||||
template="[{tool_call_id} {status}] {result}",
|
||||
)
|
||||
# Started has no result; substitution yields empty string after the bracket.
|
||||
assert text == "[call_42 running] "
|
||||
|
||||
def test_template_applied_to_intermediate(self):
|
||||
msg = async_tool_messages.build_intermediate_result_message("call_42", '"step-1"')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
text = async_tool_messages.prepare_message_payload_for_realtime(
|
||||
info,
|
||||
template="[{tool_call_id} {status}] {result}",
|
||||
)
|
||||
assert text == '[call_42 running] "step-1"'
|
||||
|
||||
def test_template_applied_to_final(self):
|
||||
msg = async_tool_messages.build_final_result_message("call_42", '"the answer"')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
text = async_tool_messages.prepare_message_payload_for_realtime(
|
||||
info,
|
||||
template="[{tool_call_id} {status}] {result}",
|
||||
)
|
||||
assert text == '[call_42 finished] "the answer"'
|
||||
|
||||
def test_template_can_use_description_field(self):
|
||||
msg = async_tool_messages.build_intermediate_result_message("call_42", '"step-1"')
|
||||
info = async_tool_messages.parse_message(msg)
|
||||
assert info is not None
|
||||
text = async_tool_messages.prepare_message_payload_for_realtime(
|
||||
info,
|
||||
template="{description} >> {result}",
|
||||
)
|
||||
# The intermediate description text is preserved verbatim.
|
||||
assert "intermediate result" in text
|
||||
assert text.endswith('>> "step-1"')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user