[WIP] AWS Nova Sonic service - add tool calling

This commit is contained in:
Paul Kompfner
2025-04-29 16:38:02 -04:00
parent f182eafb40
commit 2b7e1cb5b1
5 changed files with 267 additions and 11 deletions

View File

@@ -5,14 +5,16 @@
#
import os
from datetime import datetime
from dotenv import load_dotenv
from loguru import logger
# import logging
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import LLMMessagesAppendFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -31,6 +33,39 @@ load_dotenv(override=True)
# )
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
temperature = 75 if args["format"] == "fahrenheit" else 24
await result_callback(
{
"conditions": "nice",
"temperature": temperature,
"format": args["format"],
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
}
)
weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
required=["location", "format"],
)
# Create tools schema
tools = ToolsSchema(standard_tools=[weather_function])
async def run_bot(webrtc_connection: SmallWebRTCConnection):
logger.info(f"Starting bot")
@@ -62,20 +97,27 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection):
access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
region=os.getenv("AWS_REGION"),
voice_id="tiffany", # matthew, tiffany, amy
# instruction=system_instruction # could pass instruction here rather than context, below
# instruction=system_instruction # you could pass instruction here rather than in context
)
# Register function for function calls
# you can either register a single function for all function calls, or specific functions
# llm.register_function(None, fetch_weather_from_api)
llm.register_function("get_current_weather", fetch_weather_from_api)
# Set up context and context management.
# AWSNovaSonicService will adapt OpenAI LLM context objects with standard message format to
# what's expected by Nova Sonic.
# TODO: since we can't trigger a response upon joining, this isn't particularly useful
context = OpenAILLMContext(
messages=[
{"role": "system", "content": f"{system_instruction}"},
{
"role": "user",
"content": "Tell me hello! Don't wait for me to say anything else first!",
"content": "Say hello!",
},
]
],
tools=tools,
)
context_aggregator = llm.create_context_aggregator(context)

View File

@@ -0,0 +1,40 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import json
from typing import Any, Dict, List, Union
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
class AWSNovaSonicLLMAdapter(BaseLLMAdapter):
@staticmethod
def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]:
return {
"toolSpec": {
"name": function.name,
"description": function.description,
"inputSchema": {
"json": json.dumps(
{
"type": "object",
"properties": function.properties,
"required": function.required,
}
)
},
}
}
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
"""Converts function schemas to Openai Realtime function-calling format.
:return: Openai Realtime formatted function call definition.
"""
functions_schema = tools_schema.standard_tools
return [self._to_aws_nova_sonic_function_format(func) for func in functions_schema]

View File

@@ -1,9 +1,15 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import base64
import json
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import Any, List
from aws_sdk_bedrock_runtime.client import (
BedrockRuntimeClient,
@@ -22,6 +28,7 @@ from smithy_aws_core.credentials_resolvers.static import StaticCredentialsResolv
from smithy_aws_core.identity import AWSCredentialsIdentity
from smithy_core.aio.eventstream import DuplexEventStream
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
@@ -58,10 +65,15 @@ from pipecat.services.aws_nova_sonic.context import (
AWSNovaSonicUserContextAggregator,
Role,
)
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
from pipecat.services.llm_service import LLMService
from pipecat.utils.time import time_now_iso8601
class AWSNovaSonicUnhandledFunctionException(Exception):
pass
class ContentType(Enum):
AUDIO = "AUDIO"
TEXT = "TEXT"
@@ -91,6 +103,9 @@ class CurrentContent:
class AWSNovaSonicLLMService(LLMService):
# Override the default adapter to use the AWSNovaSonicLLMAdapter one
adapter_class = AWSNovaSonicLLMAdapter
def __init__(
self,
*,
@@ -162,6 +177,8 @@ class AWSNovaSonicLLMService(LLMService):
await self._send_user_audio_event(frame)
elif isinstance(frame, BotStoppedSpeakingFrame):
await self._handle_bot_stopped_speaking()
elif isinstance(frame, AWSNovaSonicFunctionCallResultFrame):
await self._handle_function_call_result(frame)
# TODO: do we need to do anything for the below four frame types?
elif isinstance(frame, StartInterruptionFrame):
# print("[pk] StartInterruptionFrame")
@@ -206,6 +223,10 @@ class AWSNovaSonicLLMService(LLMService):
self._assistant_is_responding = False
await self._report_assistant_response_ended()
async def _handle_function_call_result(self, frame: AWSNovaSonicFunctionCallResultFrame):
result = frame.result_frame
await self._send_tool_result(tool_call_id=result.tool_call_id, result=result.result)
#
# LLM communication: lifecycle
#
@@ -228,8 +249,8 @@ class AWSNovaSonicLLMService(LLMService):
InvokeModelWithBidirectionalStreamOperationInput(model_id=self._model)
)
# Send session start events
await self._send_session_start_events()
# Send session start event
await self._send_session_start_event()
# Finish connecting
self._ready_to_send_context = True
@@ -247,6 +268,10 @@ class AWSNovaSonicLLMService(LLMService):
# Read context
history = self._context.get_messages_for_initializing_history()
# Send prompt start event, specifying tools
tools = self._context.tools
await self._send_prompt_start_event(tools)
# Send system instruction
# Instruction from context takes priority
instruction = history.instruction if history.instruction else self._instruction
@@ -318,7 +343,7 @@ class AWSNovaSonicLLMService(LLMService):
#
# TODO: make params configurable?
async def _send_session_start_events(self):
async def _send_session_start_event(self):
session_start = """
{
"event": {
@@ -334,6 +359,20 @@ class AWSNovaSonicLLMService(LLMService):
"""
await self._send_client_event(session_start)
async def _send_prompt_start_event(self, tools: List[Any]):
tools_config = (
f""",
"toolUseOutputConfiguration": {{
"mediaType": "application/json"
}},
"toolConfiguration": {{
"tools": {json.dumps(tools)}
}}
"""
if tools
else ""
)
prompt_start = f'''
{{
"event": {{
@@ -350,7 +389,7 @@ class AWSNovaSonicLLMService(LLMService):
"voiceId": "{self._voice_id}",
"encoding": "base64",
"audioType": "SPEECH"
}}
}}{tools_config}
}}
}}
}}
@@ -382,6 +421,9 @@ class AWSNovaSonicLLMService(LLMService):
await self._send_client_event(audio_content_start)
async def _send_text_event(self, text: str, role: Role):
if not self._stream:
return
content_name = str(uuid.uuid4())
text_content_start = f'''
@@ -469,6 +511,61 @@ class AWSNovaSonicLLMService(LLMService):
"""
await self._send_client_event(session_end)
async def _send_tool_result(self, tool_call_id, result):
if not self._stream:
return
# print(f"[pk] sending tool result. tool call ID: {tool_call_id}, result: {result}")
content_name = str(uuid.uuid4())
result_content_start = f'''
{{
"event": {{
"contentStart": {{
"promptName": "{self._prompt_name}",
"contentName": "{content_name}",
"interactive": false,
"type": "TOOL",
"role": "TOOL",
"toolResultInputConfiguration": {{
"toolUseId": "{tool_call_id}",
"type": "TEXT",
"textInputConfiguration": {{
"mediaType": "text/plain"
}}
}}
}}
}}
}}
'''
await self._send_client_event(result_content_start)
result_content = json.dumps(
{
"event": {
"toolResult": {
"promptName": self._prompt_name,
"contentName": content_name,
"content": json.dumps(result) if isinstance(result, dict) else result,
}
}
}
)
await self._send_client_event(result_content)
result_content_end = f"""
{{
"event": {{
"contentEnd": {{
"promptName": "{self._prompt_name}",
"contentName": "{content_name}"
}}
}}
}}
"""
await self._send_client_event(result_content_end)
async def _send_client_event(self, event_json: str):
event = InvokeModelWithBidirectionalStreamInputChunk(
value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8"))
@@ -515,6 +612,9 @@ class AWSNovaSonicLLMService(LLMService):
elif "audioOutput" in event_json:
# Handle audio output content
await self._handle_audio_output_event(event_json)
elif "toolUse" in event_json:
# Handle tool use
await self._handle_tool_use_event(event_json)
elif "contentEnd" in event_json:
# Handle a piece of content ending
await self._handle_content_end_event(event_json)
@@ -593,6 +693,42 @@ class AWSNovaSonicLLMService(LLMService):
)
await self.push_frame(frame)
async def _handle_tool_use_event(self, event_json):
# This should never happen
if not self._content_being_received:
return
# Get tool use details
tool_use = event_json["toolUse"]
function_name = tool_use["toolName"]
tool_call_id = tool_use["toolUseId"]
arguments = json.loads(tool_use["content"])
# print(
# f"[pk] tool use - function_name: {function_name}, tool_call_id: {tool_call_id}, arguments: {arguments}"
# )
# Call tool function
if self.has_function(function_name):
if function_name in self._functions.keys():
await self.call_function(
context=self._context,
tool_call_id=tool_call_id,
function_name=function_name,
arguments=arguments,
)
elif None in self._functions.keys():
await self.call_function(
context=self._context,
tool_call_id=tool_call_id,
function_name=function_name,
arguments=arguments,
)
else:
raise AWSNovaSonicUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
)
async def _handle_content_end_event(self, event_json):
# This should never happen
if not self._content_being_received:
@@ -671,6 +807,9 @@ class AWSNovaSonicLLMService(LLMService):
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
) -> AWSNovaSonicContextAggregatorPair:
context.set_llm_adapter(self.get_llm_adapter())
user = AWSNovaSonicUserContextAggregator(context=context, params=user_params)
assistant = AWSNovaSonicAssistantContextAggregator(context=context, params=assistant_params)
return AWSNovaSonicContextAggregatorPair(user, assistant)

View File

@@ -1,12 +1,25 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import copy
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger
from pipecat.frames.frames import DataFrame, Frame, LLMMessagesUpdateFrame, LLMSetToolsFrame
from pipecat.frames.frames import (
DataFrame,
Frame,
FunctionCallResultFrame,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
@@ -106,7 +119,15 @@ class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
pass
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
await super().handle_function_call_result(frame)
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
# special frame to do that.
await self.push_frame(
AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
)
@dataclass

View File

@@ -0,0 +1,14 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from dataclasses import dataclass
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
@dataclass
class AWSNovaSonicFunctionCallResultFrame(DataFrame):
result_frame: FunctionCallResultFrame