[WIP] AWS Nova Sonic service - add tool calling
This commit is contained in:
@@ -5,14 +5,16 @@
|
||||
#
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
# import logging
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMMessagesAppendFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -31,6 +33,39 @@ load_dotenv(override=True)
|
||||
# )
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
temperature = 75 if args["format"] == "fahrenheit" else 24
|
||||
await result_callback(
|
||||
{
|
||||
"conditions": "nice",
|
||||
"temperature": temperature,
|
||||
"format": args["format"],
|
||||
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
weather_function = FunctionSchema(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather",
|
||||
properties={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
required=["location", "format"],
|
||||
)
|
||||
|
||||
# Create tools schema
|
||||
tools = ToolsSchema(standard_tools=[weather_function])
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
@@ -62,20 +97,27 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection):
|
||||
access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
region=os.getenv("AWS_REGION"),
|
||||
voice_id="tiffany", # matthew, tiffany, amy
|
||||
# instruction=system_instruction # could pass instruction here rather than context, below
|
||||
# instruction=system_instruction # you could pass instruction here rather than in context
|
||||
)
|
||||
|
||||
# Register function for function calls
|
||||
# you can either register a single function for all function calls, or specific functions
|
||||
# llm.register_function(None, fetch_weather_from_api)
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
# Set up context and context management.
|
||||
# AWSNovaSonicService will adapt OpenAI LLM context objects with standard message format to
|
||||
# what's expected by Nova Sonic.
|
||||
# TODO: since we can't trigger a response upon joining, this isn't particularly useful
|
||||
context = OpenAILLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": f"{system_instruction}"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Tell me hello! Don't wait for me to say anything else first!",
|
||||
"content": "Say hello!",
|
||||
},
|
||||
]
|
||||
],
|
||||
tools=tools,
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
|
||||
40
src/pipecat/adapters/services/aws_nova_sonic_adapter.py
Normal file
40
src/pipecat/adapters/services/aws_nova_sonic_adapter.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
import json
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
|
||||
|
||||
class AWSNovaSonicLLMAdapter(BaseLLMAdapter):
|
||||
@staticmethod
|
||||
def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||
return {
|
||||
"toolSpec": {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"inputSchema": {
|
||||
"json": json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": function.properties,
|
||||
"required": function.required,
|
||||
}
|
||||
)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]:
|
||||
"""Converts function schemas to Openai Realtime function-calling format.
|
||||
|
||||
:return: Openai Realtime formatted function call definition.
|
||||
"""
|
||||
|
||||
functions_schema = tools_schema.standard_tools
|
||||
return [self._to_aws_nova_sonic_function_format(func) for func in functions_schema]
|
||||
@@ -1,9 +1,15 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
from aws_sdk_bedrock_runtime.client import (
|
||||
BedrockRuntimeClient,
|
||||
@@ -22,6 +28,7 @@ from smithy_aws_core.credentials_resolvers.static import StaticCredentialsResolv
|
||||
from smithy_aws_core.identity import AWSCredentialsIdentity
|
||||
from smithy_core.aio.eventstream import DuplexEventStream
|
||||
|
||||
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -58,10 +65,15 @@ from pipecat.services.aws_nova_sonic.context import (
|
||||
AWSNovaSonicUserContextAggregator,
|
||||
Role,
|
||||
)
|
||||
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class AWSNovaSonicUnhandledFunctionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ContentType(Enum):
|
||||
AUDIO = "AUDIO"
|
||||
TEXT = "TEXT"
|
||||
@@ -91,6 +103,9 @@ class CurrentContent:
|
||||
|
||||
|
||||
class AWSNovaSonicLLMService(LLMService):
|
||||
# Override the default adapter to use the AWSNovaSonicLLMAdapter one
|
||||
adapter_class = AWSNovaSonicLLMAdapter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -162,6 +177,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
await self._send_user_audio_event(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._handle_bot_stopped_speaking()
|
||||
elif isinstance(frame, AWSNovaSonicFunctionCallResultFrame):
|
||||
await self._handle_function_call_result(frame)
|
||||
# TODO: do we need to do anything for the below four frame types?
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
# print("[pk] StartInterruptionFrame")
|
||||
@@ -206,6 +223,10 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._assistant_is_responding = False
|
||||
await self._report_assistant_response_ended()
|
||||
|
||||
async def _handle_function_call_result(self, frame: AWSNovaSonicFunctionCallResultFrame):
|
||||
result = frame.result_frame
|
||||
await self._send_tool_result(tool_call_id=result.tool_call_id, result=result.result)
|
||||
|
||||
#
|
||||
# LLM communication: lifecycle
|
||||
#
|
||||
@@ -228,8 +249,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
InvokeModelWithBidirectionalStreamOperationInput(model_id=self._model)
|
||||
)
|
||||
|
||||
# Send session start events
|
||||
await self._send_session_start_events()
|
||||
# Send session start event
|
||||
await self._send_session_start_event()
|
||||
|
||||
# Finish connecting
|
||||
self._ready_to_send_context = True
|
||||
@@ -247,6 +268,10 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# Read context
|
||||
history = self._context.get_messages_for_initializing_history()
|
||||
|
||||
# Send prompt start event, specifying tools
|
||||
tools = self._context.tools
|
||||
await self._send_prompt_start_event(tools)
|
||||
|
||||
# Send system instruction
|
||||
# Instruction from context takes priority
|
||||
instruction = history.instruction if history.instruction else self._instruction
|
||||
@@ -318,7 +343,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
#
|
||||
|
||||
# TODO: make params configurable?
|
||||
async def _send_session_start_events(self):
|
||||
async def _send_session_start_event(self):
|
||||
session_start = """
|
||||
{
|
||||
"event": {
|
||||
@@ -334,6 +359,20 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
"""
|
||||
await self._send_client_event(session_start)
|
||||
|
||||
async def _send_prompt_start_event(self, tools: List[Any]):
|
||||
tools_config = (
|
||||
f""",
|
||||
"toolUseOutputConfiguration": {{
|
||||
"mediaType": "application/json"
|
||||
}},
|
||||
"toolConfiguration": {{
|
||||
"tools": {json.dumps(tools)}
|
||||
}}
|
||||
"""
|
||||
if tools
|
||||
else ""
|
||||
)
|
||||
|
||||
prompt_start = f'''
|
||||
{{
|
||||
"event": {{
|
||||
@@ -350,7 +389,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
"voiceId": "{self._voice_id}",
|
||||
"encoding": "base64",
|
||||
"audioType": "SPEECH"
|
||||
}}
|
||||
}}{tools_config}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
@@ -382,6 +421,9 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
await self._send_client_event(audio_content_start)
|
||||
|
||||
async def _send_text_event(self, text: str, role: Role):
|
||||
if not self._stream:
|
||||
return
|
||||
|
||||
content_name = str(uuid.uuid4())
|
||||
|
||||
text_content_start = f'''
|
||||
@@ -469,6 +511,61 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
"""
|
||||
await self._send_client_event(session_end)
|
||||
|
||||
async def _send_tool_result(self, tool_call_id, result):
|
||||
if not self._stream:
|
||||
return
|
||||
|
||||
# print(f"[pk] sending tool result. tool call ID: {tool_call_id}, result: {result}")
|
||||
|
||||
content_name = str(uuid.uuid4())
|
||||
|
||||
result_content_start = f'''
|
||||
{{
|
||||
"event": {{
|
||||
"contentStart": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{content_name}",
|
||||
"interactive": false,
|
||||
"type": "TOOL",
|
||||
"role": "TOOL",
|
||||
"toolResultInputConfiguration": {{
|
||||
"toolUseId": "{tool_call_id}",
|
||||
"type": "TEXT",
|
||||
"textInputConfiguration": {{
|
||||
"mediaType": "text/plain"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
await self._send_client_event(result_content_start)
|
||||
|
||||
result_content = json.dumps(
|
||||
{
|
||||
"event": {
|
||||
"toolResult": {
|
||||
"promptName": self._prompt_name,
|
||||
"contentName": content_name,
|
||||
"content": json.dumps(result) if isinstance(result, dict) else result,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
await self._send_client_event(result_content)
|
||||
|
||||
result_content_end = f"""
|
||||
{{
|
||||
"event": {{
|
||||
"contentEnd": {{
|
||||
"promptName": "{self._prompt_name}",
|
||||
"contentName": "{content_name}"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
await self._send_client_event(result_content_end)
|
||||
|
||||
async def _send_client_event(self, event_json: str):
|
||||
event = InvokeModelWithBidirectionalStreamInputChunk(
|
||||
value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8"))
|
||||
@@ -515,6 +612,9 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
elif "audioOutput" in event_json:
|
||||
# Handle audio output content
|
||||
await self._handle_audio_output_event(event_json)
|
||||
elif "toolUse" in event_json:
|
||||
# Handle tool use
|
||||
await self._handle_tool_use_event(event_json)
|
||||
elif "contentEnd" in event_json:
|
||||
# Handle a piece of content ending
|
||||
await self._handle_content_end_event(event_json)
|
||||
@@ -593,6 +693,42 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_tool_use_event(self, event_json):
|
||||
# This should never happen
|
||||
if not self._content_being_received:
|
||||
return
|
||||
|
||||
# Get tool use details
|
||||
tool_use = event_json["toolUse"]
|
||||
function_name = tool_use["toolName"]
|
||||
tool_call_id = tool_use["toolUseId"]
|
||||
arguments = json.loads(tool_use["content"])
|
||||
|
||||
# print(
|
||||
# f"[pk] tool use - function_name: {function_name}, tool_call_id: {tool_call_id}, arguments: {arguments}"
|
||||
# )
|
||||
|
||||
# Call tool function
|
||||
if self.has_function(function_name):
|
||||
if function_name in self._functions.keys():
|
||||
await self.call_function(
|
||||
context=self._context,
|
||||
tool_call_id=tool_call_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
elif None in self._functions.keys():
|
||||
await self.call_function(
|
||||
context=self._context,
|
||||
tool_call_id=tool_call_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
else:
|
||||
raise AWSNovaSonicUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def _handle_content_end_event(self, event_json):
|
||||
# This should never happen
|
||||
if not self._content_being_received:
|
||||
@@ -671,6 +807,9 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> AWSNovaSonicContextAggregatorPair:
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
user = AWSNovaSonicUserContextAggregator(context=context, params=user_params)
|
||||
assistant = AWSNovaSonicAssistantContextAggregator(context=context, params=assistant_params)
|
||||
|
||||
return AWSNovaSonicContextAggregatorPair(user, assistant)
|
||||
|
||||
@@ -1,12 +1,25 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import DataFrame, Frame, LLMMessagesUpdateFrame, LLMSetToolsFrame
|
||||
from pipecat.frames.frames import (
|
||||
DataFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
@@ -106,7 +119,15 @@ class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
pass
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
# The standard function callback code path pushes the FunctionCallResultFrame from the llm itself,
|
||||
# so we didn't have a chance to add the result to the openai realtime api context. Let's push a
|
||||
# special frame to do that.
|
||||
await self.push_frame(
|
||||
AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
14
src/pipecat/services/aws_nova_sonic/frames.py
Normal file
14
src/pipecat/services/aws_nova_sonic/frames.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pipecat.frames.frames import DataFrame, FunctionCallResultFrame
|
||||
|
||||
|
||||
@dataclass
|
||||
class AWSNovaSonicFunctionCallResultFrame(DataFrame):
|
||||
result_frame: FunctionCallResultFrame
|
||||
Reference in New Issue
Block a user