Update AWSNovaSonicLLMService to work with LLMContext and LLMContextAggregatorPair
This commit is contained in:
12
CHANGELOG.md
12
CHANGELOG.md
@@ -9,6 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Expanded support for universal `LLMContext` to `AWSNovaSonicLLMService`.
|
||||
As a reminder, the context-setup pattern when using `LLMContext` is:
|
||||
|
||||
```python
|
||||
context = LLMContext(messages, tools)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
```
|
||||
|
||||
(Note that even though `AWSNovaSonicLLMService` now supports the universal
|
||||
`LLMContext`, it is not meant to be swapped out for another LLM service at
|
||||
runtime.)
|
||||
|
||||
- Included cached token counts from OpenAI-based LLM services in `MetricsFrame`.
|
||||
|
||||
### Fixed
|
||||
|
||||
@@ -67,11 +67,11 @@ async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}"
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}"
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
messages = params.context.get_messages()
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
|
||||
@@ -68,12 +68,12 @@ async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}"
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}"
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
# todo: extract 'system' into the first message in the list
|
||||
messages = params.context.get_messages()
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
# remove the last message, which is the instruction we just gave to save the conversation
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
|
||||
@@ -86,12 +86,11 @@ async def save_conversation(params: FunctionCallParams):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
|
||||
filename = f"{BASE_FILENAME}{timestamp}.json"
|
||||
logger.debug(
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages(), indent=4)}"
|
||||
f"writing conversation to {filename}\n{json.dumps(params.context.get_messages_for_persistent_storage(), indent=4)}"
|
||||
)
|
||||
try:
|
||||
with open(filename, "w") as file:
|
||||
# todo: extract 'system' into the first message in the list
|
||||
messages = params.context.get_messages()
|
||||
messages = params.context.get_messages_for_persistent_storage()
|
||||
# remove the last message (the instruction to save the context)
|
||||
messages.pop()
|
||||
json.dump(messages, file, indent=2)
|
||||
|
||||
@@ -20,6 +20,8 @@ from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
@@ -223,13 +225,13 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm.register_function("get_saved_conversation_filenames", get_saved_conversation_filenames)
|
||||
llm.register_function("load_conversation", load_conversation)
|
||||
|
||||
context = OpenAILLMContext(
|
||||
context = LLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": f"{system_instruction}"},
|
||||
],
|
||||
tools=tools,
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
|
||||
@@ -18,7 +18,8 @@ from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.aws_nova_sonic import AWSNovaSonicLLMService
|
||||
@@ -119,9 +120,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
llm.register_function("get_current_weather", fetch_weather_from_api)
|
||||
|
||||
# Set up context and context management.
|
||||
# AWSNovaSonicService will adapt OpenAI LLM context objects with standard message format to
|
||||
# what's expected by Nova Sonic.
|
||||
context = OpenAILLMContext(
|
||||
context = LLMContext(
|
||||
messages=[
|
||||
{"role": "system", "content": f"{system_instruction}"},
|
||||
{
|
||||
@@ -131,7 +130,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
],
|
||||
tools=tools,
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
# Build the pipeline
|
||||
pipeline = Pipeline(
|
||||
|
||||
@@ -6,13 +6,47 @@
|
||||
|
||||
"""AWS Nova Sonic LLM adapter for Pipecat."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage
|
||||
|
||||
|
||||
class Role(Enum):
    """Speaker roles used when exchanging text with AWS Nova Sonic.

    Parameters:
        SYSTEM: Role for system-level messages (not used in conversation history).
        USER: Role for messages sent by the user.
        ASSISTANT: Role for messages sent by the assistant.
        TOOL: Role for messages sent by tools (not used in conversation history).
    """

    SYSTEM = "SYSTEM"
    USER = "USER"
    ASSISTANT = "ASSISTANT"
    TOOL = "TOOL"
|
||||
|
||||
|
||||
@dataclass
class AWSNovaSonicConversationHistoryMessage:
    """One entry of the conversation history sent to AWS Nova Sonic.

    Parameters:
        role: Who sent the message; only USER and ASSISTANT are expected here.
        text: The message's text content.
    """

    role: Role  # only USER and ASSISTANT
    text: str
|
||||
|
||||
|
||||
class AWSNovaSonicLLMInvocationParams(TypedDict):
|
||||
@@ -21,7 +55,9 @@ class AWSNovaSonicLLMInvocationParams(TypedDict):
|
||||
Holds the system instruction, conversation history messages, and tools extracted from a universal LLMContext.
|
||||
"""
|
||||
|
||||
pass
|
||||
system_instruction: Optional[str]
|
||||
messages: List[AWSNovaSonicConversationHistoryMessage]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
@@ -34,7 +70,7 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
@property
|
||||
def id_for_llm_specific_messages(self) -> str:
|
||||
"""Get the identifier used in LLMSpecificMessage instances for AWS Nova Sonic."""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
return "aws-nova-sonic"
|
||||
|
||||
def get_llm_invocation_params(self, context: LLMContext) -> AWSNovaSonicLLMInvocationParams:
|
||||
"""Get AWS Nova Sonic-specific LLM invocation parameters from a universal LLM context.
|
||||
@@ -47,7 +83,13 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
Returns:
|
||||
Dictionary of parameters for invoking AWS Nova Sonic's LLM API.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
messages = self._from_universal_context_messages(self.get_messages(context))
|
||||
return {
|
||||
"system_instruction": messages.system_instruction,
|
||||
"messages": messages.messages,
|
||||
# NOTE: LLMContext's tools are guaranteed to be a ToolsSchema (or NOT_GIVEN)
|
||||
"tools": self.from_standard_tools(context.tools) or [],
|
||||
}
|
||||
|
||||
def get_messages_for_logging(self, context) -> List[Dict[str, Any]]:
|
||||
"""Get messages from a universal LLM context in a format ready for logging about AWS Nova Sonic.
|
||||
@@ -62,7 +104,75 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]):
|
||||
Returns:
|
||||
List of messages in a format ready for logging about AWS Nova Sonic.
|
||||
"""
|
||||
raise NotImplementedError("Universal LLMContext is not yet supported for AWS Nova Sonic.")
|
||||
return self._from_universal_context_messages(self.get_messages(context)).messages
|
||||
|
||||
@dataclass
class ConvertedMessages:
    """Container for AWS Nova Sonic-formatted messages converted from universal context.

    Parameters:
        messages: Conversation history messages in AWS Nova Sonic format.
        system_instruction: System instruction pulled out of the universal
            context messages, or None if there was none.
    """

    messages: List[AWSNovaSonicConversationHistoryMessage]
    system_instruction: Optional[str] = None
|
||||
|
||||
def _from_universal_context_messages(
    self, universal_context_messages: List[LLMContextMessage]
) -> ConvertedMessages:
    """Split universal context messages into a system instruction and history.

    A leading message with role "system" is pulled out as the system
    instruction: its string content is used directly, or the "text" of its
    first content item when the content is a list. Every remaining message is
    passed through `_from_universal_context_message`, and the ones that convert
    become conversation history entries.

    Side effect: a non-empty system instruction is also stored on
    `self._system_instruction`.

    Args:
        universal_context_messages: Messages from the universal LLM context,
            in conversation order. The list and its messages are not modified.

    Returns:
        ConvertedMessages with the converted history and the system
        instruction (None if there was no usable leading system message).
    """
    # Bail if there are no messages. `messages` is a required field of
    # ConvertedMessages, so it has to be passed explicitly.
    if not universal_context_messages:
        return self.ConvertedMessages(messages=[])

    system_instruction = None
    remaining = universal_context_messages

    # If we have a "system" message as our first message, pull it out into the
    # system instruction. Slicing (rather than popping) leaves the caller's
    # list untouched without having to deep-copy every message.
    first = universal_context_messages[0]
    if first.get("role") == "system":
        remaining = universal_context_messages[1:]
        content = first.get("content")
        if isinstance(content, str):
            system_instruction = content
        elif isinstance(content, list) and content:
            # Guarded against an empty list, which has no first item to read.
            system_instruction = content[0].get("text")
        if system_instruction:
            self._system_instruction = system_instruction

    # Process remaining messages to fill out conversation history.
    # Nova Sonic supports "user" and "assistant" messages in history; anything
    # that doesn't convert comes back falsy and is dropped.
    converted = (self._from_universal_context_message(m) for m in remaining)
    messages = [message for message in converted if message]

    return self.ConvertedMessages(messages=messages, system_instruction=system_instruction)
|
||||
|
||||
def _from_universal_context_message(
    self, message
) -> Optional[AWSNovaSonicConversationHistoryMessage]:
    """Convert standard message format to Nova Sonic format.

    Only "user" and "assistant" messages with text content are convertible.
    When the content is a list, its "text" items are joined with single
    spaces; items of any other type are logged as errors and skipped.

    Args:
        message: Standard message dictionary to convert.

    Returns:
        Nova Sonic conversation history message, or None if not convertible.
    """
    role = message.get("role")

    # NOTE: we're ignoring messages with role "tool" (and any other role) since
    # they can't be loaded into AWS Nova Sonic conversation history.
    if role not in ("user", "assistant"):
        return None

    content = message.get("content")
    if isinstance(content, list):
        text_parts = []
        for part in content:
            if part.get("type") == "text":
                text = part.get("text")
                # Skip missing/empty text instead of failing on concatenation.
                if text:
                    text_parts.append(text)
            else:
                logger.error(
                    f"Unhandled content type in context message: {part.get('type')} - {message}"
                )
        # Joining avoids the stray leading space that incremental
        # concatenation would put in front of the first part.
        content = " ".join(text_parts)

    # There won't be content if this is an assistant tool call entry.
    # We're ignoring those since they can't be loaded into AWS Nova Sonic
    # conversation history.
    if not content:
        return None

    return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content)
|
||||
|
||||
@staticmethod
|
||||
def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]:
|
||||
|
||||
@@ -15,9 +15,10 @@ service-specific adapter.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, TypeAlias, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union
|
||||
|
||||
from loguru import logger
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
@@ -31,6 +32,9 @@ from PIL import Image
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import AudioRawFrame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
|
||||
# "Re-export" types from OpenAI that we're using as universal context types.
|
||||
# NOTE: if universal message types need to someday diverge from OpenAI's, we
|
||||
# should consider managing our own definitions. But we should do so carefully,
|
||||
@@ -65,6 +69,26 @@ class LLMContext:
|
||||
and content formatting.
|
||||
"""
|
||||
|
||||
@staticmethod
def from_openai_context(openai_context: "OpenAILLMContext") -> "LLMContext":
    """Build a universal LLM context out of an OpenAI-specific one.

    NOTE: this is intended for internal use only, to ease migration from
    OpenAILLMContext to LLMContext. New user code should construct an
    LLMContext directly.

    Args:
        openai_context: The OpenAI LLM context to convert.

    Returns:
        New LLMContext instance created from the OpenAI context's messages,
        tools, and tool choice.
    """
    messages = openai_context.get_messages()
    tools = openai_context.tools
    tool_choice = openai_context.tool_choice
    return LLMContext(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[LLMContextMessage]] = None,
|
||||
@@ -107,6 +131,21 @@ class LLMContext:
|
||||
)
|
||||
return filtered_messages
|
||||
|
||||
def get_messages_for_persistent_storage(
    self, system_instruction: Optional[str] = None
) -> List[LLMContextMessage]:
    """Return a copy of the context messages suitable for persistent storage.

    Args:
        system_instruction: Optional system instruction. When provided and
            the first message is not already a "system" message (or there
            are no messages at all), it is prepended as a "system" message.

    Returns:
        A deep copy of the context's messages, so changes made to the
        returned list do not affect the context.
    """
    stored = copy.deepcopy(self.get_messages())

    # Nothing to ensure when no system instruction was supplied.
    if not system_instruction:
        return stored

    starts_with_system = bool(stored) and stored[0].get("role") == "system"
    if not starts_with_system:
        stored.insert(0, {"role": "system", "content": system_instruction})

    return stored
|
||||
|
||||
@property
|
||||
def tools(self) -> ToolsSchema | NotGiven:
|
||||
"""Get the tools list.
|
||||
|
||||
@@ -25,7 +25,7 @@ from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter
|
||||
from pipecat.adapters.services.aws_nova_sonic_adapter import AWSNovaSonicLLMAdapter, Role
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -33,36 +33,36 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallFromLLM,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws_nova_sonic.context import (
|
||||
AWSNovaSonicAssistantContextAggregator,
|
||||
AWSNovaSonicContextAggregatorPair,
|
||||
AWSNovaSonicLLMContext,
|
||||
AWSNovaSonicUserContextAggregator,
|
||||
Role,
|
||||
)
|
||||
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIContextAggregatorPair,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
@@ -231,7 +231,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._system_instruction = system_instruction
|
||||
self._tools = tools
|
||||
self._send_transcription_frames = send_transcription_frames
|
||||
self._context: Optional[AWSNovaSonicLLMContext] = None
|
||||
self._context: Optional[LLMContext] = None
|
||||
self._stream: Optional[
|
||||
DuplexEventStream[
|
||||
InvokeModelWithBidirectionalStreamInput,
|
||||
@@ -244,12 +244,17 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._input_audio_content_name: Optional[str] = None
|
||||
self._content_being_received: Optional[CurrentContent] = None
|
||||
self._assistant_is_responding = False
|
||||
self._needs_re_push_assistant_text = False
|
||||
self._ready_to_send_context = False
|
||||
self._handling_bot_stopped_speaking = False
|
||||
self._triggering_assistant_response = False
|
||||
self._waiting_for_trigger_transcription = False
|
||||
self._disconnecting = False
|
||||
self._connected_time: Optional[float] = None
|
||||
self._wants_connection = False
|
||||
self._user_text_buffer = ""
|
||||
self._assistant_text_buffer = ""
|
||||
self._completed_tool_calls = set()
|
||||
|
||||
file_path = files("pipecat.services.aws_nova_sonic").joinpath("ready.wav")
|
||||
with wave.open(file_path.open("rb"), "rb") as wav_file:
|
||||
@@ -302,12 +307,12 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
logger.debug("Resetting conversation")
|
||||
await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=False)
|
||||
|
||||
# Carry over previous context through disconnect
|
||||
# Grab context to carry through disconnect/reconnect
|
||||
context = self._context
|
||||
await self._disconnect()
|
||||
self._context = context
|
||||
|
||||
await self._disconnect()
|
||||
await self._start_connecting()
|
||||
await self._handle_context(context)
|
||||
|
||||
#
|
||||
# frame processing
|
||||
@@ -322,28 +327,35 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
await self._handle_context(frame.context)
|
||||
elif isinstance(frame, LLMContextFrame):
|
||||
raise NotImplementedError(
|
||||
"Universal LLMContext is not yet supported for AWS Nova Sonic."
|
||||
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
|
||||
context = (
|
||||
frame.context
|
||||
if isinstance(frame, LLMContextFrame)
|
||||
else LLMContext.from_openai_context(frame.context)
|
||||
)
|
||||
await self._handle_context(context)
|
||||
elif isinstance(frame, InputAudioRawFrame):
|
||||
await self._handle_input_audio_frame(frame)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._handle_bot_stopped_speaking(delay_to_catch_trailing_assistant_text=True)
|
||||
elif isinstance(frame, AWSNovaSonicFunctionCallResultFrame):
|
||||
await self._handle_function_call_result(frame)
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption_frame()
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _handle_context(self, context: OpenAILLMContext):
|
||||
async def _handle_context(self, context: LLMContext):
|
||||
if self._disconnecting:
|
||||
return
|
||||
|
||||
if not self._context:
|
||||
# We got our initial context - try to finish connecting
|
||||
self._context = AWSNovaSonicLLMContext.upgrade_to_nova_sonic(
|
||||
context, self._system_instruction
|
||||
)
|
||||
# We got our initial context
|
||||
# Try to finish connecting
|
||||
self._context = context
|
||||
await self._finish_connecting_if_context_available()
|
||||
else:
|
||||
# We got an updated context
|
||||
# Send results for any newly-completed function calls
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
|
||||
async def _handle_input_audio_frame(self, frame: InputAudioRawFrame):
|
||||
# Wait until we're done sending the assistant response trigger audio before sending audio
|
||||
@@ -393,9 +405,9 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
else:
|
||||
await finalize_assistant_response()
|
||||
|
||||
async def _handle_function_call_result(self, frame: AWSNovaSonicFunctionCallResultFrame):
|
||||
result = frame.result_frame
|
||||
await self._send_tool_result(tool_call_id=result.tool_call_id, result=result.result)
|
||||
async def _handle_interruption_frame(self):
    """Flag the assistant text for re-pushing if interrupted mid-response.

    When an interruption arrives while the assistant is responding, set
    `_needs_re_push_assistant_text`; otherwise do nothing.
    """
    if not self._assistant_is_responding:
        return
    self._needs_re_push_assistant_text = True
|
||||
|
||||
#
|
||||
# LLM communication: lifecycle
|
||||
@@ -431,6 +443,17 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
logger.error(f"{self} initialization error: {e}")
|
||||
self._disconnect()
|
||||
|
||||
async def _process_completed_function_calls(self, send_new_results: bool):
    """Record completed function calls found in the context.

    Scans the context messages for entries that carry a `tool_call_id` and
    whose content is not the "IN_PROGRESS" placeholder. Each id not seen
    before is added to `_completed_tool_calls`.

    Args:
        send_new_results: If True, the content of each newly-found completed
            function call is sent to the service before the id is recorded.
            If False, the ids are only recorded.
    """
    for message in self._context.get_messages():
        # NOTE(review): this only requires *some* role to be set; it is the
        # `tool_call_id` lookup below that actually singles out tool results.
        if not message.get("role") or message.get("content") == "IN_PROGRESS":
            continue

        tool_call_id = message.get("tool_call_id")
        if not tool_call_id or tool_call_id in self._completed_tool_calls:
            continue

        # Found a newly-completed function call - send the result to the service
        if send_new_results:
            await self._send_tool_result(tool_call_id, message.get("content"))
        self._completed_tool_calls.add(tool_call_id)
|
||||
|
||||
async def _finish_connecting_if_context_available(self):
|
||||
# We can only finish connecting once we've gotten our initial context and we're ready to
|
||||
# send it
|
||||
@@ -439,30 +462,38 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
logger.info("Finishing connecting (setting up session)...")
|
||||
|
||||
# Initialize our bookkeeping of already-completed tool calls in the
|
||||
# context
|
||||
await self._process_completed_function_calls(send_new_results=False)
|
||||
|
||||
# Read context
|
||||
history = self._context.get_messages_for_initializing_history()
|
||||
adapter: AWSNovaSonicLLMAdapter = self.get_llm_adapter()
|
||||
llm_connection_params = adapter.get_llm_invocation_params(self._context)
|
||||
|
||||
# Send prompt start event, specifying tools.
|
||||
# Tools from context take priority over self._tools.
|
||||
tools = (
|
||||
self._context.tools
|
||||
if self._context.tools
|
||||
else self.get_llm_adapter().from_standard_tools(self._tools)
|
||||
llm_connection_params["tools"]
|
||||
if llm_connection_params["tools"]
|
||||
else adapter.from_standard_tools(self._tools)
|
||||
)
|
||||
logger.debug(f"Using tools: {tools}")
|
||||
await self._send_prompt_start_event(tools)
|
||||
|
||||
# Send system instruction.
|
||||
# Instruction from context takes priority over self._system_instruction.
|
||||
# (NOTE: this prioritizing occurred automatically behind the scenes: the context was
|
||||
# initialized with self._system_instruction and then updated itself from its messages when
|
||||
# get_messages_for_initializing_history() was called).
|
||||
logger.debug(f"Using system instruction: {history.system_instruction}")
|
||||
if history.system_instruction:
|
||||
await self._send_text_event(text=history.system_instruction, role=Role.SYSTEM)
|
||||
system_instruction = (
|
||||
llm_connection_params["system_instruction"]
|
||||
if llm_connection_params["system_instruction"]
|
||||
else self._system_instruction
|
||||
)
|
||||
logger.debug(f"Using system instruction: {system_instruction}")
|
||||
if system_instruction:
|
||||
await self._send_text_event(text=system_instruction, role=Role.SYSTEM)
|
||||
|
||||
# Send conversation history
|
||||
for message in history.messages:
|
||||
for message in llm_connection_params["messages"]:
|
||||
# logger.debug(f"Seeding conversation history with message: {message}")
|
||||
await self._send_text_event(text=message.text, role=message.role)
|
||||
|
||||
# Start audio input
|
||||
@@ -492,9 +523,12 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
await self._send_session_end_events()
|
||||
self._client = None
|
||||
|
||||
# Clean up context
|
||||
self._context = None
|
||||
|
||||
# Clean up stream
|
||||
if self._stream:
|
||||
await self._stream.input_stream.close()
|
||||
await self._stream.close()
|
||||
self._stream = None
|
||||
|
||||
# NOTE: see explanation of HACK, below
|
||||
@@ -510,15 +544,23 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
self._receive_task = None
|
||||
|
||||
# Reset remaining connection-specific state
|
||||
# Should be all private state except:
|
||||
# - _wants_connection
|
||||
# - _assistant_response_trigger_audio
|
||||
self._prompt_name = None
|
||||
self._input_audio_content_name = None
|
||||
self._content_being_received = None
|
||||
self._assistant_is_responding = False
|
||||
self._needs_re_push_assistant_text = False
|
||||
self._ready_to_send_context = False
|
||||
self._handling_bot_stopped_speaking = False
|
||||
self._triggering_assistant_response = False
|
||||
self._waiting_for_trigger_transcription = False
|
||||
self._disconnecting = False
|
||||
self._connected_time = None
|
||||
self._user_text_buffer = ""
|
||||
self._assistant_text_buffer = ""
|
||||
self._completed_tool_calls = set()
|
||||
|
||||
logger.info("Finished disconnecting")
|
||||
except Exception as e:
|
||||
@@ -830,6 +872,10 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
# Handle the LLM completion ending
|
||||
await self._handle_completion_end_event(event_json)
|
||||
except Exception as e:
|
||||
if self._disconnecting:
|
||||
# Errors are kind of expected while disconnecting, so just
|
||||
# ignore them and do nothing
|
||||
return
|
||||
logger.error(f"{self} error processing responses: {e}")
|
||||
if self._wants_connection:
|
||||
await self.reset_conversation()
|
||||
@@ -960,7 +1006,7 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
async def _report_assistant_response_started(self):
|
||||
logger.debug("Assistant response started")
|
||||
|
||||
# Report that the assistant has started their response.
|
||||
# Report the start of the assistant response.
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
# Report that equivalent of TTS (this is a speech-to-speech model) started
|
||||
@@ -972,23 +1018,16 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
logger.debug(f"Assistant response text added: {text}")
|
||||
|
||||
# Report some text added to the ongoing assistant response
|
||||
await self.push_frame(LLMTextFrame(text))
|
||||
|
||||
# Report some text added to the *equivalent* of TTS (this is a speech-to-speech model)
|
||||
# Report the text of the assistant response.
|
||||
await self.push_frame(TTSTextFrame(text))
|
||||
|
||||
# TODO: this is a (hopefully temporary) HACK. Here we directly manipulate the context rather
|
||||
# than relying on the frames pushed to the assistant context aggregator. The pattern of
|
||||
# receiving full-sentence text after the assistant has spoken does not easily fit with the
|
||||
# Pipecat expectation of chunks of text streaming in while the assistant is speaking.
|
||||
# Interruption handling was especially challenging. Rather than spend days trying to fit a
|
||||
# square peg in a round hole, I decided on this hack for the time being. We can most cleanly
|
||||
# abandon this hack if/when AWS Nova Sonic implements streaming smaller text chunks
|
||||
# interspersed with audio. Note that when we move away from this hack, we need to make sure
|
||||
# that on an interruption we avoid sending LLMFullResponseEndFrame, which gets the
|
||||
# LLMAssistantContextAggregator into a bad state.
|
||||
self._context.buffer_assistant_text(text)
|
||||
# HACK: here we're also buffering the assistant text ourselves as a
|
||||
# backup rather than relying solely on the assistant context aggregator
|
||||
# to do it, because the text arrives from Nova Sonic only after all the
|
||||
# assistant audio frames have been pushed, meaning that if an
|
||||
# interruption frame were to arrive we would lose all of it (the text
|
||||
# frames sitting in the queue would be wiped).
|
||||
self._assistant_text_buffer += text
|
||||
|
||||
async def _report_assistant_response_ended(self):
|
||||
if not self._context: # should never happen
|
||||
@@ -996,14 +1035,25 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
logger.debug("Assistant response ended")
|
||||
|
||||
# Report that the assistant has finished their response.
|
||||
# If an interruption frame arrived while the assistant was responding
|
||||
# we probably lost all of the assistant text (see HACK, above), so
|
||||
# re-push it downstream to the aggregator now.
|
||||
if self._needs_re_push_assistant_text:
|
||||
# We also need to re-push the LLMFullResponseStartFrame since the
|
||||
# TTSTextFrame would be ignored otherwise (the interruption frame
|
||||
# would have cleared the assistant aggregator state).
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(TTSTextFrame(self._assistant_text_buffer))
|
||||
self._needs_re_push_assistant_text = False
|
||||
|
||||
# Report the end of the assistant response.
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
# Report that equivalent of TTS (this is a speech-to-speech model) stopped.
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
|
||||
# For an explanation of this hack, see _report_assistant_response_text_added.
|
||||
self._context.flush_aggregated_assistant_text()
|
||||
# Clear out the buffered assistant text
|
||||
self._assistant_text_buffer = ""
|
||||
|
||||
#
|
||||
# user transcription reporting
|
||||
@@ -1020,33 +1070,71 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
|
||||
logger.debug(f"User transcription text added: {text}")
|
||||
|
||||
# Manually add new user transcription text to context.
|
||||
# We can't rely on the user context aggregator to do this since it's upstream from the LLM.
|
||||
self._context.buffer_user_text(text)
|
||||
|
||||
# Report that some new user transcription text is available.
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(text=text, user_id="", timestamp=time_now_iso8601())
|
||||
)
|
||||
# HACK: here we're buffering the user text ourselves rather than
|
||||
# relying on the upstream user context aggregator to do it, because the
|
||||
# text arrives in fairly large chunks spaced fairly far apart in time.
|
||||
# That means the user text would be split between different messages in
|
||||
# context. Even if we sent placeholder InterimTranscriptionFrames in
|
||||
# between each TranscriptionFrame to tell the aggregator to hold off on
|
||||
# finalizing the user message, the aggregator would likely get the last
|
||||
# chunk too late.
|
||||
self._user_text_buffer += f" {text}" if self._user_text_buffer else text
|
||||
|
||||
async def _report_user_transcription_ended(self):
|
||||
if not self._context: # should never happen
|
||||
return
|
||||
|
||||
# Manually add user transcription to context (if any has been buffered).
|
||||
# We can't rely on the user context aggregator to do this since it's upstream from the LLM.
|
||||
transcription = self._context.flush_aggregated_user_text()
|
||||
|
||||
if not transcription:
|
||||
return
|
||||
|
||||
logger.debug(f"User transcription ended")
|
||||
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(text=transcription, user_id="", timestamp=time_now_iso8601())
|
||||
# Report to the upstream user context aggregator that some new user
|
||||
# transcription text is available.
|
||||
|
||||
# HACK: Check if this transcription was triggered by our own
|
||||
# assistant response trigger. If so, we need to wrap it with
|
||||
# UserStarted/StoppedSpeakingFrames; otherwise the user aggregator
|
||||
# would fire an EmulatedUserStartedSpeakingFrame, which would
|
||||
# trigger an interruption, which would prevent us from writing the
|
||||
# assistant response to context.
|
||||
#
|
||||
# Sending an EmulateUserStartedSpeakingFrame ourselves doesn't
|
||||
# work: it just causes the interruption we're trying to avoid.
|
||||
#
|
||||
# Setting enable_emulated_vad_interruptions also doesn't work: at
|
||||
# the time the user aggregator receives the TranscriptionFrame, it
|
||||
# doesn't yet know the assistant has started responding, so it
|
||||
# doesn't know that emulating the user starting to speak would
|
||||
# cause an interruption.
|
||||
should_wrap_in_user_started_stopped_speaking_frames = (
|
||||
self._waiting_for_trigger_transcription
|
||||
and self._user_text_buffer.strip().lower() == "ready"
|
||||
)
|
||||
|
||||
# Start wrapping the upstream transcription in UserStarted/StoppedSpeakingFrames if needed
|
||||
if should_wrap_in_user_started_stopped_speaking_frames:
|
||||
logger.debug(
|
||||
"Wrapping assistant response trigger transcription with upstream UserStarted/StoppedSpeakingFrames"
|
||||
)
|
||||
await self.push_frame(UserStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
# Send the transcription upstream for the user context aggregator
|
||||
frame = TranscriptionFrame(
|
||||
text=self._user_text_buffer, user_id="", timestamp=time_now_iso8601()
|
||||
)
|
||||
await self.push_frame(frame, direction=FrameDirection.UPSTREAM)
|
||||
|
||||
# Finish wrapping the upstream transcription in UserStarted/StoppedSpeakingFrames if needed
|
||||
if should_wrap_in_user_started_stopped_speaking_frames:
|
||||
await self.push_frame(UserStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
# Also send the transcription downstream if requested
|
||||
if self._send_transcription_frames:
|
||||
await self.push_frame(frame, direction=FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Clear out the buffered user text
|
||||
self._user_text_buffer = ""
|
||||
|
||||
# We're no longer waiting for a trigger transcription
|
||||
self._waiting_for_trigger_transcription = False
|
||||
|
||||
#
|
||||
# context
|
||||
@@ -1058,23 +1146,26 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
*,
|
||||
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||
) -> AWSNovaSonicContextAggregatorPair:
|
||||
) -> LLMContextAggregatorPair:
|
||||
"""Create context aggregator pair for managing conversation context.
|
||||
|
||||
NOTE: this method exists only for backward compatibility. New code
|
||||
should instead do:
|
||||
context = LLMContext(...)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
Args:
|
||||
context: The OpenAI LLM context to upgrade.
|
||||
context: The OpenAI LLM context.
|
||||
user_params: Parameters for the user context aggregator.
|
||||
assistant_params: Parameters for the assistant context aggregator.
|
||||
|
||||
Returns:
|
||||
A pair of user and assistant context aggregators.
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
user = AWSNovaSonicUserContextAggregator(context=context, params=user_params)
|
||||
assistant = AWSNovaSonicAssistantContextAggregator(context=context, params=assistant_params)
|
||||
|
||||
return AWSNovaSonicContextAggregatorPair(user, assistant)
|
||||
context = LLMContext.from_openai_context(context)
|
||||
return LLMContextAggregatorPair(
|
||||
context, user_params=user_params, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
#
|
||||
# assistant response trigger (HACK)
|
||||
@@ -1112,6 +1203,8 @@ class AWSNovaSonicLLMService(LLMService):
|
||||
try:
|
||||
logger.debug("Sending assistant response trigger...")
|
||||
|
||||
self._waiting_for_trigger_transcription = True
|
||||
|
||||
chunk_duration = 0.02 # what we might get from InputAudioRawFrame
|
||||
chunk_size = int(
|
||||
chunk_duration
|
||||
|
||||
@@ -1,367 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Context management for AWS Nova Sonic LLM service.
|
||||
|
||||
This module provides specialized context aggregators and message handling for AWS Nova Sonic,
|
||||
including conversation history management and role-specific message processing.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
DataFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolChoiceFrame,
|
||||
LLMSetToolsFrame,
|
||||
TextFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.aws_nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
|
||||
class Role(Enum):
    """Message roles recognised in AWS Nova Sonic conversations.

    Each member's value is its own upper-case name, so a lower-case role
    string can be mapped to a member with ``Role[role.upper()]``.

    Parameters:
        SYSTEM: System-level messages (not used in conversation history).
        USER: Messages sent by the user.
        ASSISTANT: Messages sent by the assistant.
        TOOL: Messages sent by tools (not used in conversation history).
    """

    SYSTEM = "SYSTEM"
    USER = "USER"
    ASSISTANT = "ASSISTANT"
    TOOL = "TOOL"
|
||||
|
||||
|
||||
@dataclass
class AWSNovaSonicConversationHistoryMessage:
    """One entry of the conversation history handed to AWS Nova Sonic.

    Parameters:
        role: Who produced the message; only ``Role.USER`` and
            ``Role.ASSISTANT`` are expected here.
        text: The plain-text content of the message.
    """

    # Only USER and ASSISTANT are meaningful for history entries.
    role: Role
    text: str
|
||||
|
||||
|
||||
@dataclass
class AWSNovaSonicConversationHistory:
    """Conversation history used to initialize an AWS Nova Sonic session.

    Parameters:
        system_instruction: System-level instruction for the conversation.
            Defaults to ``None`` when no instruction has been provided.
        messages: Ordered user/assistant messages. Each instance gets its
            own fresh list by default.
    """

    # NOTE: the default is None even though the annotation says ``str``.
    system_instruction: str = None
    messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list)
|
||||
|
||||
|
||||
class AWSNovaSonicLLMContext(OpenAILLMContext):
    """Specialized LLM context for AWS Nova Sonic service.

    Extends the OpenAI context with Nova Sonic-specific message handling,
    conversation history management, and text buffering capabilities.
    """

    def __init__(self, messages=None, tools=None, **kwargs):
        """Initialize AWS Nova Sonic LLM context.

        Args:
            messages: Initial messages for the context.
            tools: Available tools for the context.
            **kwargs: Additional arguments passed to parent class.
        """
        super().__init__(messages=messages, tools=tools, **kwargs)
        self.__setup_local()

    def __setup_local(self, system_instruction: str = ""):
        """Reset the Nova Sonic-specific state kept on top of the base context.

        Args:
            system_instruction: Initial system instruction to remember.
        """
        self._assistant_text = ""
        self._user_text = ""
        self._system_instruction = system_instruction

    @staticmethod
    def upgrade_to_nova_sonic(
        obj: OpenAILLMContext, system_instruction: str
    ) -> "AWSNovaSonicLLMContext":
        """Upgrade an OpenAI context to AWS Nova Sonic context.

        The object is converted in place by swapping its class; an object that
        is already a Nova Sonic context is returned untouched.

        Args:
            obj: The OpenAI context to upgrade.
            system_instruction: System instruction for the context.

        Returns:
            The upgraded AWS Nova Sonic context (the same object as ``obj``).
        """
        if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext):
            obj.__class__ = AWSNovaSonicLLMContext
            obj.__setup_local(system_instruction)
        return obj

    # NOTE: this method has the side-effect of updating _system_instruction from messages
    def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory:
        """Get conversation history for initializing AWS Nova Sonic session.

        Processes stored messages and extracts system instruction and conversation
        history in the format expected by AWS Nova Sonic.

        Returns:
            Formatted conversation history with system instruction and messages.
        """
        history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction)

        # Bail if there are no messages
        if not self.messages:
            return history

        # Work on a copy so that popping the system message leaves the context intact.
        messages = copy.deepcopy(self.messages)

        # If we have a "system" message as our first message, let's pull that out into "instruction"
        if messages[0].get("role") == "system":
            system = messages.pop(0)
            content = system.get("content")
            if isinstance(content, str):
                history.system_instruction = content
            elif isinstance(content, list) and content:
                # Only the first content part is consulted; an empty list is
                # skipped rather than indexed, keeping the stored instruction.
                history.system_instruction = content[0].get("text")
            if history.system_instruction:
                self._system_instruction = history.system_instruction

        # Process remaining messages to fill out conversation history.
        # Nova Sonic supports "user" and "assistant" messages in history.
        for message in messages:
            history_message = self.from_standard_message(message)
            if history_message:
                history.messages.append(history_message)

        return history

    def get_messages_for_persistent_storage(self):
        """Get messages formatted for persistent storage.

        Returns:
            List of messages including system instruction if present.
        """
        messages = super().get_messages_for_persistent_storage()
        # If we have a system instruction and messages doesn't already contain it, add it
        if self._system_instruction and not (messages and messages[0].get("role") == "system"):
            messages.insert(0, {"role": "system", "content": self._system_instruction})
        return messages

    def from_standard_message(self, message) -> "AWSNovaSonicConversationHistoryMessage | None":
        """Convert standard message format to Nova Sonic format.

        Text parts of list-style content are joined with single spaces.
        Non-text parts are logged as errors and skipped.

        Args:
            message: Standard message dictionary to convert.

        Returns:
            Nova Sonic conversation history message, or None if not convertible
            (role other than "user"/"assistant", or no text content).
        """
        role = message.get("role")
        # NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova
        # Sonic conversation history
        if role not in ("user", "assistant"):
            return None

        content = message.get("content")
        if isinstance(content, list):
            text_parts = []
            for part in content:
                if part.get("type") == "text":
                    text = part.get("text")
                    # Skip parts without text instead of failing on None.
                    if text:
                        text_parts.append(text)
                else:
                    logger.error(
                        f"Unhandled content type in context message: {part.get('type')} - {message}"
                    )
            # Join without the stray leading space that per-part concatenation produced.
            content = " ".join(text_parts)

        # There won't be content if this is an assistant tool call entry.
        # We're ignoring those since they can't be loaded into AWS Nova Sonic conversation
        # history
        if content:
            return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content)
        return None

    def buffer_user_text(self, text):
        """Buffer user text for later flushing to context.

        Successive chunks are separated by a single space.

        Args:
            text: User text to buffer.
        """
        self._user_text += f" {text}" if self._user_text else text

    def flush_aggregated_user_text(self) -> str:
        """Flush buffered user text to context as a complete message.

        Returns:
            The flushed user text, or empty string if no text was buffered.
        """
        if not self._user_text:
            return ""
        user_text = self._user_text
        message = {
            "role": "user",
            "content": [{"type": "text", "text": user_text}],
        }
        # Clear the buffer before adding the message, as before.
        self._user_text = ""
        self.add_message(message)
        return user_text

    def buffer_assistant_text(self, text):
        """Buffer assistant text for later flushing to context.

        Chunks are concatenated as-is, with no separator.

        Args:
            text: Assistant text to buffer.
        """
        self._assistant_text += text

    def flush_aggregated_assistant_text(self):
        """Flush buffered assistant text to context as a complete message.

        Does nothing when no assistant text has been buffered.
        """
        if not self._assistant_text:
            return
        message = {
            "role": "assistant",
            "content": [{"type": "text", "text": self._assistant_text}],
        }
        self._assistant_text = ""
        self.add_message(message)
|
||||
|
||||
|
||||
@dataclass
class AWSNovaSonicMessagesUpdateFrame(DataFrame):
    """Data frame that carries an updated AWS Nova Sonic context.

    Parameters:
        context: The AWS Nova Sonic LLM context holding the updated messages.
    """

    context: AWSNovaSonicLLMContext
|
||||
|
||||
|
||||
class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator):
    """User-side context aggregator for AWS Nova Sonic conversations.

    Behaves like the OpenAI user context aggregator, and additionally emits
    an ``AWSNovaSonicMessagesUpdateFrame`` whenever an
    ``LLMMessagesUpdateFrame`` is processed.
    """

    async def process_frame(
        self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
    ):
        """Run the parent processing, then report context updates.

        Args:
            frame: The frame to process.
            direction: The direction the frame is traveling.
        """
        await super().process_frame(frame, direction)

        if not isinstance(frame, LLMMessagesUpdateFrame):
            return

        # The parent does not push LLMMessagesUpdateFrame, so announce the
        # updated context ourselves with the Nova Sonic-specific frame.
        update_frame = AWSNovaSonicMessagesUpdateFrame(context=self._context)
        await self.push_frame(update_frame)
|
||||
|
||||
|
||||
class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator):
    """Assistant-side context aggregator for AWS Nova Sonic conversations.

    Most frames the parent would aggregate are simply forwarded; function
    call handling is still delegated to the parent, and function call
    results are additionally reported upstream in a Nova Sonic-specific frame.
    """

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Forward pass-through frames and delegate everything else.

        Args:
            frame: The frame to process.
            direction: The direction the frame is traveling.
        """
        # HACK: For now, disable the context aggregator by making it just pass
        # through all frames that the parent handles (except the function call
        # stuff, which we still need). For an explanation of this hack, see
        # AWSNovaSonicLLMService._report_assistant_response_text_added.
        passthrough_types = (
            InterruptionFrame,
            LLMFullResponseStartFrame,
            LLMFullResponseEndFrame,
            TextFrame,
            LLMMessagesAppendFrame,
            LLMMessagesUpdateFrame,
            LLMSetToolsFrame,
            LLMSetToolChoiceFrame,
            UserImageRawFrame,
            BotStoppedSpeakingFrame,
        )

        if not isinstance(frame, passthrough_types):
            await super().process_frame(frame, direction)
            return

        await self.push_frame(frame, direction)

    async def handle_function_call_result(self, frame: FunctionCallResultFrame):
        """Record a function call result and relay it upstream.

        Args:
            frame: The function call result frame to handle.
        """
        await super().handle_function_call_result(frame)

        # The standard function callback code path pushes the
        # FunctionCallResultFrame from the LLM itself, so there was no chance
        # to add the result to the AWS Nova Sonic server-side context. Push a
        # special frame upstream to do that.
        nova_sonic_result = AWSNovaSonicFunctionCallResultFrame(result_frame=frame)
        await self.push_frame(nova_sonic_result, FrameDirection.UPSTREAM)
|
||||
|
||||
|
||||
@dataclass
class AWSNovaSonicContextAggregatorPair:
    """Bundle of the user and assistant context aggregators for AWS Nova Sonic.

    Parameters:
        _user: The user context aggregator.
        _assistant: The assistant context aggregator.
    """

    _user: AWSNovaSonicUserContextAggregator
    _assistant: AWSNovaSonicAssistantContextAggregator

    def user(self) -> AWSNovaSonicUserContextAggregator:
        """Return the user context aggregator held by this pair.

        Returns:
            The user context aggregator instance.
        """
        return self._user

    def assistant(self) -> AWSNovaSonicAssistantContextAggregator:
        """Return the assistant context aggregator held by this pair.

        Returns:
            The assistant context aggregator instance.
        """
        return self._assistant
|
||||
208
tests/test_llm_context.py
Normal file
208
tests/test_llm_context.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage
|
||||
|
||||
|
||||
class TestGetMessagesForPersistentStorage(unittest.TestCase):
    """Tests for ``LLMContext.get_messages_for_persistent_storage``."""

    # System instruction shared by the tests that pass a non-empty one.
    _INSTRUCTION = "You are a helpful assistant."

    def test_no_system_instruction_returns_messages_as_is(self):
        """Without a system instruction the stored messages come back unchanged."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        ctx = LLMContext(messages=conversation)

        stored = ctx.get_messages_for_persistent_storage()

        self.assertEqual(stored, conversation)
        self.assertEqual(len(stored), 2)

    def test_empty_messages_with_system_instruction_adds_system_message(self):
        """An empty context yields just the system message."""
        ctx = LLMContext()

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        self.assertEqual(len(stored), 1)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[0]["content"], self._INSTRUCTION)

    def test_non_system_first_message_prepends_system_instruction(self):
        """The instruction is prepended when the first message is not a system one."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        ctx = LLMContext(messages=conversation)

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        self.assertEqual(len(stored), 3)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[0]["content"], self._INSTRUCTION)
        self.assertEqual(stored[1], conversation[0])
        self.assertEqual(stored[2], conversation[1])

    def test_existing_system_message_not_duplicated(self):
        """An existing leading system message is kept and nothing is added."""
        conversation = [
            {"role": "system", "content": "Existing system message"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        ctx = LLMContext(messages=conversation)

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        self.assertEqual(len(stored), 3)
        self.assertEqual(stored, conversation)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[0]["content"], "Existing system message")

    def test_empty_system_instruction_does_not_add_message(self):
        """An empty-string instruction adds no system message."""
        conversation = [{"role": "user", "content": "Hello"}]
        ctx = LLMContext(messages=conversation)

        stored = ctx.get_messages_for_persistent_storage("")

        self.assertEqual(stored, conversation)
        self.assertEqual(len(stored), 1)

    def test_none_system_instruction_does_not_add_message(self):
        """A ``None`` instruction adds no system message."""
        conversation = [{"role": "user", "content": "Hello"}]
        ctx = LLMContext(messages=conversation)

        stored = ctx.get_messages_for_persistent_storage(None)

        self.assertEqual(stored, conversation)
        self.assertEqual(len(stored), 1)

    def test_whitespace_only_system_instruction_adds_message(self):
        """A whitespace-only instruction is still added as a system message."""
        conversation = [{"role": "user", "content": "Hello"}]
        ctx = LLMContext(messages=conversation)
        blank_instruction = " "

        stored = ctx.get_messages_for_persistent_storage(blank_instruction)

        self.assertEqual(len(stored), 2)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[0]["content"], blank_instruction)

    def test_with_llm_specific_messages(self):
        """``LLMSpecificMessage`` entries are carried through alongside standard ones."""
        specific = LLMSpecificMessage(
            llm="test-llm", message={"role": "user", "content": "Specific"}
        )
        conversation = [{"role": "user", "content": "Standard message"}, specific]
        ctx = LLMContext(messages=conversation)

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        self.assertEqual(len(stored), 3)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[0]["content"], self._INSTRUCTION)
        self.assertEqual(stored[1], conversation[0])
        self.assertEqual(stored[2], specific)

    def test_system_message_detection_case_sensitivity(self):
        """Detection of a leading system message is case sensitive."""
        conversation = [
            {"role": "System", "content": "Mixed case system"},  # Capital S
            {"role": "user", "content": "Hello"},
        ]
        ctx = LLMContext(messages=conversation)

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        # Should prepend because "System" != "system"
        self.assertEqual(len(stored), 3)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[0]["content"], self._INSTRUCTION)
        self.assertEqual(stored[1], conversation[0])

    def test_message_without_role_key_does_not_crash(self):
        """A first message lacking a 'role' key is handled without raising."""
        conversation = [{"content": "Message without role"}, {"role": "user", "content": "Hello"}]
        ctx = LLMContext(messages=conversation)

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        # Should prepend system instruction since first message doesn't have role="system"
        self.assertEqual(len(stored), 3)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[0]["content"], self._INSTRUCTION)

    def test_original_messages_not_modified(self):
        """The list given to the context is left untouched."""
        original_messages = [{"role": "user", "content": "Hello"}]
        ctx = LLMContext(messages=original_messages)

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        # Original messages should remain unchanged
        self.assertEqual(len(original_messages), 1)
        self.assertEqual(original_messages[0]["role"], "user")

        # Result should have system message prepended
        self.assertEqual(len(stored), 2)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[1], original_messages[0])

    def test_complex_message_structure_preserved(self):
        """Multi-part message content survives unchanged."""
        complex_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Complex message"},
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
            ],
        }
        ctx = LLMContext(messages=[complex_message])

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        self.assertEqual(len(stored), 2)
        self.assertEqual(stored[0]["role"], "system")
        self.assertEqual(stored[1], complex_message)
        self.assertEqual(stored[1]["content"], complex_message["content"])

    def test_deep_copy_prevents_nested_mutation(self):
        """Mutating nested content of the result does not affect the context."""
        nested_content = {"nested": {"data": "original"}}
        complex_message = {"role": "user", "content": nested_content}
        ctx = LLMContext(messages=[complex_message])

        stored = ctx.get_messages_for_persistent_storage(self._INSTRUCTION)

        # Modify the nested content in the result
        stored[1]["content"]["nested"]["data"] = "modified"

        # Original message should remain unchanged
        self.assertEqual(complex_message["content"]["nested"]["data"], "original")
        self.assertEqual(ctx.get_messages()[0]["content"]["nested"]["data"], "original")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user