diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 73e772609..f7330683c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,13 @@ repos: - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.1 + - repo: local hooks: - id: ruff - language_version: python3 - args: [--fix] + name: ruff + entry: uv run ruff check --fix + language: system + types: [python] - id: ruff-format + name: ruff-format + entry: uv run ruff format + language: system + types: [python] diff --git a/changelog/4215.changed.md b/changelog/4215.changed.md new file mode 100644 index 000000000..b84a343dd --- /dev/null +++ b/changelog/4215.changed.md @@ -0,0 +1 @@ +- ⚠️ `BaseOpenAILLMService.get_chat_completions()` now accepts an `LLMContext` instead of `OpenAILLMInvocationParams`. If you override this method, update your signature accordingly. diff --git a/changelog/4215.removed.2.md b/changelog/4215.removed.2.md new file mode 100644 index 000000000..4d5f6c630 --- /dev/null +++ b/changelog/4215.removed.2.md @@ -0,0 +1,22 @@ +- ⚠️ Removed deprecated service-specific context and aggregator machinery, which was superseded by the universal `LLMContext` system. + + Service-specific classes removed: `AnthropicLLMContext`, `AnthropicContextAggregatorPair`, `AWSBedrockLLMContext`, `AWSBedrockContextAggregatorPair`, `OpenAIContextAggregatorPair`, and their user/assistant aggregators. Also removed `create_context_aggregator()` from `LLMService`, `OpenAILLMService`, `AnthropicLLMService`, and `AWSBedrockLLMService`. + + Base aggregator classes removed (from `pipecat.processors.aggregators.llm_response`): `BaseLLMResponseAggregator`, `LLMContextResponseAggregator`, `LLMUserContextAggregator`, `LLMAssistantContextAggregator`, `LLMUserResponseAggregator`, `LLMAssistantResponseAggregator`. + + From the developer's point of view, migrating will usually be a matter of going from this: + + ```python + context = OpenAILLMContext(messages, tools) + context_aggregator = llm.create_context_aggregator(context) + ``` + + To this: + + ```python + from pipecat.processors.aggregators.llm_context import LLMContext + from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair + + context = LLMContext(messages, tools) + context_aggregator = LLMContextAggregatorPair(context) + ``` diff --git a/changelog/4215.removed.3.md b/changelog/4215.removed.3.md new file mode 100644 index 000000000..4ef02fa89 --- /dev/null +++ b/changelog/4215.removed.3.md @@ -0,0 +1 @@ +- ⚠️ Removed deprecated frame types `LLMMessagesFrame` and `OpenAILLMContextAssistantTimestampFrame` from `pipecat.frames.frames`. Instead of `LLMMessagesFrame`, use `LLMContextFrame` with the new messages, or `LLMMessagesUpdateFrame` with `run_llm=True`. diff --git a/changelog/4215.removed.4.md b/changelog/4215.removed.4.md new file mode 100644 index 000000000..20bde1ce9 --- /dev/null +++ b/changelog/4215.removed.4.md @@ -0,0 +1 @@ +- ⚠️ Removed `GatedOpenAILLMContextAggregator` (from `pipecat.processors.aggregators.gated_open_ai_llm_context`). Use `GatedLLMContextAggregator` (from `pipecat.processors.aggregators.gated_llm_context`) instead. diff --git a/changelog/4215.removed.5.md b/changelog/4215.removed.5.md new file mode 100644 index 000000000..cbd84018d --- /dev/null +++ b/changelog/4215.removed.5.md @@ -0,0 +1 @@ +- ⚠️ Removed `VisionImageFrameAggregator` (from `pipecat.processors.aggregators.vision_image_frame`). Vision/image handling is now built into `LLMContext` (from `pipecat.processors.aggregators.llm_context`). See the `12*` examples for the recommended replacement pattern. diff --git a/changelog/4215.removed.6.md b/changelog/4215.removed.6.md new file mode 100644 index 000000000..28062359a --- /dev/null +++ b/changelog/4215.removed.6.md @@ -0,0 +1 @@ +- ⚠️ Removed deprecated compatibility modules: `pipecat.services.openai_realtime_beta` (use `pipecat.services.openai.realtime`), `pipecat.services.openai_realtime.context`, `pipecat.services.openai_realtime.frames`, `pipecat.services.openai.realtime.context`, `pipecat.services.openai.realtime.frames`, `pipecat.services.gemini_multimodal_live` (use `pipecat.services.google.gemini_live`), `pipecat.services.aws_nova_sonic.context` (use `pipecat.services.aws.nova_sonic`), `pipecat.services.google.openai` and `pipecat.services.google.llm_openai` (use `pipecat.services.google.llm`). diff --git a/changelog/4215.removed.md b/changelog/4215.removed.md new file mode 100644 index 000000000..96cb03c71 --- /dev/null +++ b/changelog/4215.removed.md @@ -0,0 +1,18 @@ +- ⚠️ Removed `OpenAILLMContext`, `OpenAILLMContextFrame`, and `OpenAILLMContext.from_messages()`. Use `LLMContext` (from `pipecat.processors.aggregators.llm_context`) and `LLMContextFrame` (from `pipecat.frames.frames`) instead. All services now exclusively use the universal `LLMContext`. + + From the developer's point of view, migrating will usually be a matter of going from this: + + ```python + context = OpenAILLMContext(messages, tools) + context_aggregator = llm.create_context_aggregator(context) + ``` + + To this: + + ```python + from pipecat.processors.aggregators.llm_context import LLMContext + from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair + + context = LLMContext(messages, tools) + context_aggregator = LLMContextAggregatorPair(context) + ``` diff --git a/scripts/pre-commit.sh b/scripts/pre-commit.sh deleted file mode 100755 index 44a17cc19..000000000 --- a/scripts/pre-commit.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash - -# Color codes for output -RED='\033[0;31m' -GREEN='\033[0;32m' -NC='\033[0m' # No Color - -echo "🔍 Running pre-commit checks..." - -# Change to project root (one level up from scripts/) -cd "$(dirname "$0")/.." - -# Format check -echo "📝 Checking code formatting..." -if ! NO_COLOR=1 uv run ruff format --diff --check; then - echo -e "${RED}❌ Code formatting issues found. Run 'uv run ruff format' to fix.${NC}" - exit 1 -fi - -# Lint check -echo "🔍 Running linter..." -if ! uv run ruff check; then - echo -e "${RED}❌ Linting issues found.${NC}" - exit 1 -fi - -echo -e "${GREEN}✅ All pre-commit checks passed!${NC}" \ No newline at end of file diff --git a/src/pipecat/adapters/schemas/tools_schema.py b/src/pipecat/adapters/schemas/tools_schema.py index 298cef56d..e3940f6cb 100644 --- a/src/pipecat/adapters/schemas/tools_schema.py +++ b/src/pipecat/adapters/schemas/tools_schema.py @@ -22,12 +22,9 @@ class AdapterType(Enum): Parameters: GEMINI: Google Gemini adapter - currently the only service supporting custom tools. - SHIM: Backward compatibility shim for creating ToolsSchemas from lists of tools in - any format, used by LLMContext.from_openai_context. """ GEMINI = "gemini" # that is the only service where we are able to add custom tools for now - SHIM = "shim" # for use as backward compatibility shim for creating ToolsSchemas from list of tools in any format class ToolsSchema: diff --git a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py index 492e02db6..e481f9971 100644 --- a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py +++ b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py @@ -222,18 +222,4 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter[AWSNovaSonicLLMInvocationParams]): List of dictionaries in AWS Nova Sonic function format. """ functions_schema = tools_schema.standard_tools - standard_tools = [ - self._to_aws_nova_sonic_function_format(func) for func in functions_schema - ] - - # For backward compatibility, AWS Nova Sonic can still be used with - # tools in dict format, even though it always uses `LLMContext` under - # the hood (via `LLMContext.from_openai_context()`). - # To support this behavior, we use "shimmed" custom tools here. - # (We maintain this backward compatibility because users aren't - # *knowingly* opting into the new `LLMContext`.) - shimmed_tools = [] - if tools_schema.custom_tools: - shimmed_tools = tools_schema.custom_tools.get(AdapterType.SHIM, []) - - return standard_tools + shimmed_tools + return [self._to_aws_nova_sonic_function_format(func) for func in functions_schema] diff --git a/src/pipecat/adapters/services/grok_realtime_adapter.py b/src/pipecat/adapters/services/grok_realtime_adapter.py index 4e5e2a8c5..1c1f7336d 100644 --- a/src/pipecat/adapters/services/grok_realtime_adapter.py +++ b/src/pipecat/adapters/services/grok_realtime_adapter.py @@ -256,11 +256,4 @@ class GrokRealtimeLLMAdapter(BaseLLMAdapter): """ # Convert standard function tools functions_schema = tools_schema.standard_tools - standard_tools = [self._to_grok_function_format(func) for func in functions_schema] - - # Support shimmed custom tools for backward compatibility - shimmed_tools = [] - if tools_schema.custom_tools: - shimmed_tools = tools_schema.custom_tools.get(AdapterType.SHIM, []) - - return standard_tools + shimmed_tools + return [self._to_grok_function_format(func) for func in functions_schema] diff --git a/src/pipecat/adapters/services/open_ai_realtime_adapter.py b/src/pipecat/adapters/services/open_ai_realtime_adapter.py index 3c394d99b..4732f70b7 100644 --- a/src/pipecat/adapters/services/open_ai_realtime_adapter.py +++ b/src/pipecat/adapters/services/open_ai_realtime_adapter.py @@ -236,18 +236,4 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter): List of function definitions in OpenAI Realtime format. """ functions_schema = tools_schema.standard_tools - standard_tools = [ - self._to_openai_realtime_function_format(func) for func in functions_schema - ] - - # For backward compatibility, OpenAI Realtime can still be used with - # tools in dict format, even though it always uses `LLMContext` under - # the hood (via `LLMContext.from_openai_context()`). - # To support this behavior, we use "shimmed" custom tools here. - # (We maintain this backward compatibility because users aren't - # *knowingly* opting into the new `LLMContext`.) - shimmed_tools = [] - if tools_schema.custom_tools: - shimmed_tools = tools_schema.custom_tools.get(AdapterType.SHIM, []) - - return standard_tools + shimmed_tools + return [self._to_openai_realtime_function_format(func) for func in functions_schema] diff --git a/src/pipecat/extensions/ivr/ivr_navigator.py b/src/pipecat/extensions/ivr/ivr_navigator.py index 64a6d0942..a2ff0cde5 100644 --- a/src/pipecat/extensions/ivr/ivr_navigator.py +++ b/src/pipecat/extensions/ivr/ivr_navigator.py @@ -31,7 +31,6 @@ from pipecat.frames.frames import ( VADParamsUpdateFrame, ) from pipecat.pipeline.pipeline import Pipeline -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.services.llm_service import LLMService from pipecat.utils.text.pattern_pair_aggregator import ( @@ -444,7 +443,7 @@ Remember: Respond with `NUMBER` (single or multiple for sequences), frame: The frame to process. direction: The direction of frame flow in the pipeline. """ - if isinstance(frame, (OpenAILLMContextFrame, LLMContextFrame)): + if isinstance(frame, LLMContextFrame): # Extract messages and pass to IVR processor all_messages = frame.context.get_messages() diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 29369f03a..bf5319405 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -451,36 +451,6 @@ class TranslationFrame(TextFrame): return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})" -@dataclass -class OpenAILLMContextAssistantTimestampFrame(DataFrame): - """Timestamp information for assistant messages in LLM context. - - .. deprecated:: 0.0.99 - `OpenAILLMContextAssistantTimestampFrame` is deprecated and will be removed in a future version. - Use `LLMContextAssistantTimestampFrame` with the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Parameters: - timestamp: Timestamp when the assistant message was created. - """ - - timestamp: str - - def __post_init__(self): - super().__post_init__() - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "OpenAILLMContextAssistantTimestampFrame is deprecated and will be removed in a future version. " - "Use LLMContextAssistantTimestampFrame with the universal LLMContext and LLMContextAggregatorPair instead. " - "See OpenAILLMContext docstring for migration guide.", - DeprecationWarning, - stacklevel=2, - ) - - @dataclass class LLMContextAssistantTimestampFrame(DataFrame): """Timestamp information for assistant messages in LLM context. @@ -706,44 +676,6 @@ class LLMThoughtEndFrame(ControlFrame): return f"{self.name}(pts: {pts}, signature: {self.signature})" -@dataclass -class LLMMessagesFrame(DataFrame): - """Frame containing LLM messages for chat completion. - - .. deprecated:: 0.0.79 - This class is deprecated and will be removed in a future version. - Instead, use either: - - `LLMMessagesUpdateFrame` with `run_llm=True` - - `OpenAILLMContextFrame` with desired messages in a new context - - A frame containing a list of LLM messages. Used to signal that an LLM - service should run a chat completion and emit an LLMFullResponseStartFrame, - TextFrames and an LLMFullResponseEndFrame. Note that the `messages` - property in this class is mutable, and will be updated by various - aggregators. - - Parameters: - messages: List of message dictionaries in LLM format. - """ - - messages: List[dict] - - def __post_init__(self): - super().__post_init__() - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "LLMMessagesFrame is deprecated and will be removed in a future version. " - "Instead, use either " - "`LLMMessagesUpdateFrame` with `run_llm=True`, or " - "`OpenAILLMContextFrame` with desired messages in a new context", - DeprecationWarning, - stacklevel=2, - ) - - @dataclass class LLMRunFrame(DataFrame): """Frame to trigger LLM processing with current context. diff --git a/src/pipecat/observers/loggers/llm_log_observer.py b/src/pipecat/observers/loggers/llm_log_observer.py index 07212e4c9..31fd0142e 100644 --- a/src/pipecat/observers/loggers/llm_log_observer.py +++ b/src/pipecat/observers/loggers/llm_log_observer.py @@ -14,11 +14,9 @@ from pipecat.frames.frames import ( LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, - LLMMessagesFrame, LLMTextFrame, ) from pipecat.observers.base_observer import BaseObserver, FramePushed -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import LLMService @@ -32,8 +30,6 @@ class LLMLogObserver(BaseObserver): - LLMFullResponseEndFrame - LLMTextFrame - FunctionCallInProgressFrame - - LLMMessagesFrame - - OpenAILLMContextFrame This allows you to track when the LLM starts responding, what it generates, and when it finishes. @@ -74,18 +70,9 @@ class LLMLogObserver(BaseObserver): logger.debug( f"🧠 {src} {arrow} LLM FUNCTION CALL ({frame.tool_call_id}): {frame.function_name!r}({frame.arguments}) at {time_sec:.2f}s" ) - # Log LLMMessagesFrame (input) - elif isinstance(frame, LLMMessagesFrame): - logger.debug( - f"🧠 {arrow} {dst} LLM MESSAGES FRAME: {frame.messages} at {time_sec:.2f}s" - ) - # Log OpenAILLMContextFrame (input) - elif isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): - messages = ( - frame.context.messages - if isinstance(frame, OpenAILLMContextFrame) - else frame.context.get_messages() - ) + # Log LLMContextFrame (input) + elif isinstance(frame, LLMContextFrame): + messages = frame.context.get_messages() logger.debug(f"🧠 {arrow} {dst} LLM CONTEXT FRAME: {messages} at {time_sec:.2f}s") # Log function call result (input) elif isinstance(frame, FunctionCallResultFrame): diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index d33fcbd2c..e5a6ad7ff 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -48,7 +48,6 @@ from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.base_task import BasePipelineTask, PipelineTaskParams from pipecat.pipeline.pipeline import Pipeline, PipelineSink, PipelineSource from pipecat.pipeline.task_observer import TaskObserver -from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIObserverParams, RTVIProcessor from pipecat.utils.asyncio.task_manager import BaseTaskManager, TaskManager, TaskManagerParams @@ -1028,10 +1027,6 @@ class PipelineTask(BasePipelineTask): """Build and return start metadata including user-provided values.""" start_metadata = {} - # NOTE(aleix): Remove when OpenAILLMContext/LLMUserContextAggregator is removed. - if self._find_processor(self._pipeline, LLMUserContextAggregator): - start_metadata["deprecated_openaillmcontext"] = True - # Update with user provided metadata. start_metadata.update(self._params.start_metadata) diff --git a/src/pipecat/processors/aggregators/gated_llm_context.py b/src/pipecat/processors/aggregators/gated_llm_context.py index cd1c587e9..fdcc9a025 100644 --- a/src/pipecat/processors/aggregators/gated_llm_context.py +++ b/src/pipecat/processors/aggregators/gated_llm_context.py @@ -7,7 +7,6 @@ """Gated LLM context aggregator for controlled message flow.""" from pipecat.frames.frames import CancelFrame, EndFrame, Frame, LLMContextFrame, StartFrame -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.utils.sync.base_notifier import BaseNotifier @@ -49,7 +48,7 @@ class GatedLLMContextAggregator(FrameProcessor): if isinstance(frame, (EndFrame, CancelFrame)): await self._stop() await self.push_frame(frame) - elif isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): + elif isinstance(frame, LLMContextFrame): if self._start_open: self._start_open = False await self.push_frame(frame, direction) diff --git a/src/pipecat/processors/aggregators/gated_open_ai_llm_context.py b/src/pipecat/processors/aggregators/gated_open_ai_llm_context.py deleted file mode 100644 index 0cdb366d3..000000000 --- a/src/pipecat/processors/aggregators/gated_open_ai_llm_context.py +++ /dev/null @@ -1,12 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""Gated OpenAI LLM context aggregator for controlled message flow.""" - -from pipecat.processors.aggregators.gated_llm_context import GatedLLMContextAggregator - -# Alias for backward compatibility with the previous name -GatedOpenAILLMContextAggregator = GatedLLMContextAggregator diff --git a/src/pipecat/processors/aggregators/llm_context.py b/src/pipecat/processors/aggregators/llm_context.py index 1375b8297..eb98571cc 100644 --- a/src/pipecat/processors/aggregators/llm_context.py +++ b/src/pipecat/processors/aggregators/llm_context.py @@ -33,9 +33,6 @@ from PIL import Image from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema from pipecat.frames.frames import AudioRawFrame -if TYPE_CHECKING: - from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext - # "Re-export" types from OpenAI that we're using as universal context types. # NOTE: if universal message types need to someday diverge from OpenAI's, we # should consider managing our own definitions. But we should do so carefully, @@ -70,51 +67,6 @@ class LLMContext: and content formatting. """ - @staticmethod - def from_openai_context(openai_context: "OpenAILLMContext") -> "LLMContext": - """Create a universal LLM context from an OpenAI-specific context. - - NOTE: this should only be used internally, for facilitating migration - from OpenAILLMContext to LLMContext. New user code should use - LLMContext directly. - - .. deprecated:: 0.0.99 - `from_openai_context()` is deprecated and will be removed in a future version. - Directly use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Args: - openai_context: The OpenAI LLM context to convert. - - Returns: - New LLMContext instance with converted messages and settings. - """ - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "from_openai_context() (likely invoked by create_context_aggregator()) is deprecated and will be removed in a future version. " - "Directly use the universal LLMContext and LLMContextAggregatorPair instead. " - "See OpenAILLMContext docstring for migration guide.", - DeprecationWarning, - stacklevel=2, - ) - - # Convert tools to ToolsSchema if needed. - # If the tools are already a ToolsSchema, this is a no-op. - # Otherwise, we wrap them in a shim ToolsSchema. - converted_tools = openai_context.tools - if isinstance(converted_tools, list): - converted_tools = ToolsSchema( - standard_tools=[], custom_tools={AdapterType.SHIM: converted_tools} - ) - return LLMContext( - messages=openai_context.get_messages(), - tools=converted_tools, - tool_choice=openai_context.tool_choice, - ) - def __init__( self, messages: Optional[List[LLMContextMessage]] = None, @@ -246,33 +198,6 @@ class LLMContext: """ return self.get_messages() - def get_messages_for_persistent_storage(self) -> List[LLMContextMessage]: - """Get messages suitable for persistent storage. - - NOTE: the only reason this method exists is because we're "silently" - switching from OpenAILLMContext to LLMContext under the hood in some - services and don't want to trip up users who may have been relying on - this method, which is part of the public API of OpenAILLMContext but - doesn't need to be for LLMContext. - - .. deprecated:: 0.0.92 - Use `get_messages()` instead. - - Returns: - List of conversation messages. - """ - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "get_messages_for_persistent_storage() is deprecated, use get_messages() instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.get_messages() - def get_messages(self, llm_specific_filter: Optional[str] = None) -> List[LLMContextMessage]: """Get the current messages list. diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 7c246b209..a1d4399cb 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -4,106 +4,17 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""LLM response aggregators for handling conversation context and message aggregation. +"""LLM response aggregator for collecting complete LLM responses.""" -This module provides aggregators that process and accumulate LLM responses, user inputs, -and conversation context. These aggregators handle the flow between speech-to-text, -LLM processing, and text-to-speech components in conversational AI pipelines. -""" - -import asyncio -import warnings -from abc import abstractmethod -from dataclasses import dataclass -from typing import Dict, List, Literal, Optional, Set - -from loguru import logger - -from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy -from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams -from pipecat.audio.vad.vad_analyzer import VADParams from pipecat.frames.frames import ( - BotStartedSpeakingFrame, - BotStoppedSpeakingFrame, - CancelFrame, - EmulateUserStartedSpeakingFrame, - EmulateUserStoppedSpeakingFrame, - EndFrame, Frame, - FunctionCallCancelFrame, - FunctionCallInProgressFrame, - FunctionCallResultFrame, - FunctionCallsStartedFrame, - InputAudioRawFrame, - InterimTranscriptionFrame, InterruptionFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, - LLMMessagesAppendFrame, - LLMMessagesFrame, - LLMMessagesUpdateFrame, - LLMRunFrame, - LLMSetToolChoiceFrame, - LLMSetToolsFrame, LLMTextFrame, - OpenAILLMContextAssistantTimestampFrame, - SpeechControlParamsFrame, - StartFrame, TextFrame, - TranscriptionFrame, - UserImageRawFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, -) -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.utils.time import time_now_iso8601 - - -@dataclass -class LLMUserAggregatorParams: - """Parameters for configuring LLM user aggregation behavior. - - .. deprecated:: 0.0.99 - This class is deprecated, use the new universal `LLMContext` and - `LLMContextAggregatorPair`. - - Parameters: - aggregation_timeout: Maximum time in seconds to wait for additional - transcription content before pushing aggregated result. This - timeout is used only when the transcription is slow to arrive. - turn_emulated_vad_timeout: Maximum time in seconds to wait for emulated - VAD when using turn-based analysis. Applied when transcription is - received but VAD didn't detect speech (e.g., whispered utterances). - enable_emulated_vad_interruptions: When True, allows emulated VAD events - to interrupt the bot when it's speaking. When False, emulated speech - is ignored while the bot is speaking. - """ - - aggregation_timeout: float = 0.5 - turn_emulated_vad_timeout: float = 0.8 - enable_emulated_vad_interruptions: bool = False - - -@dataclass -class LLMAssistantAggregatorParams: - """Parameters for configuring LLM assistant aggregation behavior. - - .. deprecated:: 0.0.99 - This class is deprecated, use the new universal `LLMContext` and - `LLMContextAggregatorPair`. - - Parameters: - expect_stripped_words: Whether to expect and handle stripped words - in text frames by adding spaces between tokens. This parameter is - ignored when used with the newer LLMAssistantAggregator, which - handles word spacing automatically. - """ - - expect_stripped_words: bool = True class LLMFullResponseAggregator(FrameProcessor): @@ -173,993 +84,3 @@ class LLMFullResponseAggregator(FrameProcessor): if not self._started: return self._aggregation += frame.text - - -class BaseLLMResponseAggregator(FrameProcessor): - """Base class for all LLM response aggregators. - - These aggregators process incoming frames and aggregate content until they are - ready to push the aggregation downstream. They maintain conversation state - and handle message flow between different components in the pipeline. - - The aggregators keep a store (e.g. message list or LLM context) of the current - conversation, storing messages from both users and the bot. - - .. deprecated:: 0.0.99 - `BaseLLMResponseAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__(self, **kwargs): - """Initialize the base LLM response aggregator. - - Args: - **kwargs: Additional arguments passed to parent FrameProcessor. - - .. deprecated:: 0.0.99 - `BaseLLMResponseAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - f"{self.__class__.__name__} (likely created with create_context_aggregator()) is deprecated and will be removed in a future version. " - "Use the universal LLMContext and LLMContextAggregatorPair instead. " - "See OpenAILLMContext docstring for migration guide.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(**kwargs) - - @property - @abstractmethod - def messages(self) -> List[dict]: - """Get the messages from the current conversation. - - Returns: - List of message dictionaries representing the conversation history. - """ - pass - - @property - @abstractmethod - def role(self) -> str: - """Get the role for this aggregator. - - Returns: - The role string (e.g. "user", "assistant") for this aggregator. - """ - pass - - @abstractmethod - def add_messages(self, messages): - """Add the given messages to the conversation. - - Args: - messages: Messages to append to the conversation history. - """ - pass - - @abstractmethod - def set_messages(self, messages): - """Reset the conversation with the given messages. - - Args: - messages: Messages to replace the current conversation history. - """ - pass - - @abstractmethod - def set_tools(self, tools): - """Set LLM tools to be used in the current conversation. - - Args: - tools: List of tool definitions for the LLM to use. - """ - pass - - @abstractmethod - def set_tool_choice(self, tool_choice): - """Set the tool choice for the LLM. - - Args: - tool_choice: Tool choice configuration for the LLM context. - """ - pass - - @abstractmethod - async def reset(self): - """Reset the internal state of this aggregator. - - This should clear aggregation state but not modify the conversation messages. - """ - pass - - @abstractmethod - async def handle_aggregation(self, aggregation: str): - """Add the given aggregation to the conversation store. - - Args: - aggregation: The aggregated text content to add to the conversation. - """ - pass - - @abstractmethod - async def push_aggregation(self): - """Push the current aggregation downstream. - - The specific frame type pushed depends on the aggregator implementation - (e.g. context frame, messages frame). - """ - pass - - -class LLMContextResponseAggregator(BaseLLMResponseAggregator): - """Base LLM aggregator that uses an OpenAI LLM context for conversation storage. - - This aggregator maintains conversation state using an OpenAILLMContext and - pushes OpenAILLMContextFrame objects as aggregation frames. It provides - common functionality for context-based conversation management. - - .. deprecated:: 0.0.99 - `LLMContextResponseAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__(self, *, context: OpenAILLMContext, role: str, **kwargs): - """Initialize the context response aggregator. - - Args: - context: The OpenAI LLM context to use for conversation storage. - role: The role this aggregator represents (e.g. "user", "assistant"). - **kwargs: Additional arguments passed to parent class. - - .. deprecated:: 0.0.99 - `LLMContextResponseAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMUserAggregator` and `LLMAssistantAggregator` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - # Super handles deprecation warning - super().__init__(**kwargs) - self._context = context - self._role = role - - self._aggregation: str = "" - - @property - def messages(self) -> List[dict]: - """Get messages from the LLM context. - - Returns: - List of message dictionaries from the context. - """ - return self._context.get_messages() - - @property - def role(self) -> str: - """Get the role for this aggregator. - - Returns: - The role string for this aggregator. - """ - return self._role - - @property - def context(self): - """Get the OpenAI LLM context. - - Returns: - The OpenAILLMContext instance used by this aggregator. - """ - return self._context - - def get_context_frame(self) -> OpenAILLMContextFrame: - """Create a context frame with the current context. - - .. deprecated:: 0.0.82 - This method is deprecated and will be removed in a future version. - - Returns: - LLMContextFrame containing the current context. - """ - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "get_context_frame() is deprecated and will be removed in a future version. To trigger an LLM response, use LLMRunFrame instead.", - DeprecationWarning, - stacklevel=2, - ) - return self._get_context_frame() - - def _get_context_frame(self) -> OpenAILLMContextFrame: - return OpenAILLMContextFrame(context=self._context) - - async def push_context_frame(self, direction: FrameDirection = FrameDirection.DOWNSTREAM): - """Push a context frame in the specified direction. - - Args: - direction: The direction to push the frame (upstream or downstream). - """ - frame = self._get_context_frame() - await self.push_frame(frame, direction) - - def add_messages(self, messages): - """Add messages to the context. - - Args: - messages: Messages to add to the conversation context. - """ - self._context.add_messages(messages) - - def set_messages(self, messages): - """Set the context messages. - - Args: - messages: Messages to replace the current context messages. - """ - self._context.set_messages(messages) - - def set_tools(self, tools: List): - """Set tools in the context. - - Args: - tools: List of tool definitions to set in the context. - """ - self._context.set_tools(tools) - - def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict): - """Set tool choice in the context. - - Args: - tool_choice: Tool choice configuration for the context. - """ - self._context.set_tool_choice(tool_choice) - - async def reset(self): - """Reset the aggregation state.""" - self._aggregation = "" - - -class LLMUserContextAggregator(LLMContextResponseAggregator): - """User LLM aggregator that processes speech-to-text transcriptions. - - This aggregator handles the complex logic of aggregating user speech transcriptions - from STT services. It manages multiple scenarios including: - - - Transcriptions received between VAD events - - Transcriptions received outside VAD events - - Interim vs final transcriptions - - User interruptions during bot speech - - Emulated VAD for whispered or short utterances - - The aggregator uses timeouts to handle cases where transcriptions arrive - after VAD events or when no VAD is available. - - .. deprecated:: 0.0.99 - `LLMUserContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__( - self, - context: OpenAILLMContext, - *, - params: Optional[LLMUserAggregatorParams] = None, - **kwargs, - ): - """Initialize the user context aggregator. - - Args: - context: The OpenAI LLM context for conversation storage. - params: Configuration parameters for aggregation behavior. - **kwargs: Additional arguments. Supports deprecated 'aggregation_timeout'. - - .. deprecated:: 0.0.99 - `LLMUserContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - # Super handles deprecation warning - super().__init__(context=context, role="user", **kwargs) - self._params = params or LLMUserAggregatorParams() - self._vad_params: Optional[VADParams] = None - self._turn_params: Optional[SmartTurnParams] = None - - if "aggregation_timeout" in kwargs: - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "Parameter 'aggregation_timeout' is deprecated, use 'params' instead.", - DeprecationWarning, - ) - - self._params.aggregation_timeout = kwargs["aggregation_timeout"] - - self._user_speaking = False - self._bot_speaking = False - self._was_bot_speaking = False - self._emulating_vad = False - self._seen_interim_results = False - self._waiting_for_aggregation = False - - self._aggregation_event = asyncio.Event() - self._aggregation_task = None - - async def reset(self): - """Reset the aggregation state and interruption strategies.""" - await super().reset() - self._was_bot_speaking = False - self._seen_interim_results = False - self._waiting_for_aggregation = False - [await s.reset() for s in self._interruption_strategies] - - async def handle_aggregation(self, aggregation: str): - """Add the aggregated user text to the context. - - Args: - aggregation: The aggregated user text to add as a user message. - """ - self._context.add_message({"role": self.role, "content": aggregation}) - - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process frames for user speech aggregation and context management. - - Args: - frame: The frame to process. - direction: The direction of frame flow in the pipeline. - """ - await super().process_frame(frame, direction) - - if isinstance(frame, StartFrame): - # Push StartFrame before start(), because we want StartFrame to be - # processed by every processor before any other frame is processed. - await self.push_frame(frame, direction) - await self._start(frame) - elif isinstance(frame, EndFrame): - # Push EndFrame before stop(), because stop() waits on the task to - # finish and the task finishes when EndFrame is processed. - await self.push_frame(frame, direction) - await self._stop(frame) - elif isinstance(frame, CancelFrame): - await self._cancel(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, InputAudioRawFrame): - await self._handle_input_audio(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, UserStartedSpeakingFrame): - await self._handle_user_started_speaking(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, UserStoppedSpeakingFrame): - await self._handle_user_stopped_speaking(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, BotStartedSpeakingFrame): - await self._handle_bot_started_speaking(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, BotStoppedSpeakingFrame): - await self._handle_bot_stopped_speaking(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, TranscriptionFrame): - await self._handle_transcription(frame) - elif isinstance(frame, InterimTranscriptionFrame): - await self._handle_interim_transcription(frame) - elif isinstance(frame, LLMRunFrame): - await self._handle_llm_run(frame) - elif isinstance(frame, LLMMessagesAppendFrame): - await self._handle_llm_messages_append(frame) - elif isinstance(frame, LLMMessagesUpdateFrame): - await self._handle_llm_messages_update(frame) - elif isinstance(frame, LLMSetToolsFrame): - self.set_tools(frame.tools) - elif isinstance(frame, LLMSetToolChoiceFrame): - self.set_tool_choice(frame.tool_choice) - elif isinstance(frame, SpeechControlParamsFrame): - self._vad_params = frame.vad_params - self._turn_params = frame.turn_params - await self.push_frame(frame, direction) - else: - await self.push_frame(frame, direction) - - async def _process_aggregation(self): - """Process the current aggregation and push it downstream.""" - aggregation = self._aggregation - await self.reset() - await self.handle_aggregation(aggregation) - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) - - async def push_aggregation(self): - """Push the current aggregation based on interruption strategies and conditions.""" - if len(self._aggregation) > 0: - if self.interruption_strategies and self._bot_speaking: - should_interrupt = await self._should_interrupt_based_on_strategies() - - if should_interrupt: - logger.debug( - "Interruption conditions met - pushing interruption and aggregation" - ) - await self.broadcast_interruption() - await self._process_aggregation() - else: - logger.debug("Interruption conditions not met - not pushing aggregation") - # Don't process aggregation, just reset it - await self.reset() - else: - # No interruption config - normal behavior (always push aggregation) - await self._process_aggregation() - # Handles the case where both the user and the bot are not speaking, - # and the bot was previously speaking before the user interruption. - # Normally, when the user stops speaking, new text is expected, - # which triggers the bot to respond. However, if no new text - # is received, this safeguard ensures - # the bot doesn't hang indefinitely while waiting to speak again. - elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking: - logger.warning("User stopped speaking but no new aggregation received.") - # Resetting it so we don't trigger this twice - self._was_bot_speaking = False - # TODO: we are not enabling this for now, due to some STT services which can take as long as 2 seconds two return a transcription - # So we need more tests and probably make this feature configurable, disabled it by default. - # We are just pushing the same previous context to be processed again in this case - # await self.push_frame(OpenAILLMContextFrame(self._context)) - - async def _should_interrupt_based_on_strategies(self) -> bool: - """Check if interruption should occur based on configured strategies. - - Returns: - True if any interruption strategy indicates interruption should occur. - """ - - async def should_interrupt(strategy: BaseInterruptionStrategy): - await strategy.append_text(self._aggregation) - return await strategy.should_interrupt() - - return any([await should_interrupt(s) for s in self._interruption_strategies]) - - async def _start(self, frame: StartFrame): - self._create_aggregation_task() - - async def _stop(self, frame: EndFrame): - await self._cancel_aggregation_task() - - async def _cancel(self, frame: CancelFrame): - await self._cancel_aggregation_task() - - async def _handle_llm_run(self, frame: LLMRunFrame): - await self.push_context_frame() - - async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame): - self.add_messages(frame.messages) - if frame.run_llm: - await self.push_context_frame() - - async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame): - self.set_messages(frame.messages) - if frame.run_llm: - await self.push_context_frame() - - async def _handle_input_audio(self, frame: InputAudioRawFrame): - for s in self.interruption_strategies: - await s.append_audio(frame.audio, frame.sample_rate) - - async def _handle_user_started_speaking(self, frame: UserStartedSpeakingFrame): - self._user_speaking = True - self._waiting_for_aggregation = True - self._was_bot_speaking = self._bot_speaking - - # If we get a non-emulated UserStartedSpeakingFrame but we are in the - # middle of emulating VAD, let's stop emulating VAD (i.e. don't send the - # EmulateUserStoppedSpeakingFrame). - if not frame.emulated and self._emulating_vad: - self._emulating_vad = False - - async def _handle_user_stopped_speaking(self, _: UserStoppedSpeakingFrame): - self._user_speaking = False - # We just stopped speaking. Let's see if there's some aggregation to - # push. If the last thing we saw is an interim transcription, let's wait - # pushing the aggregation as we will probably get a final transcription. - if len(self._aggregation) > 0: - if not self._seen_interim_results: - await self.push_aggregation() - # Handles the case where both the user and the bot are not speaking, - # and the bot was previously speaking before the user interruption. - # So in this case we are resetting the aggregation timer - elif not self._seen_interim_results and self._was_bot_speaking and not self._bot_speaking: - # Reset aggregation timer. - self._aggregation_event.set() - - async def _handle_bot_started_speaking(self, _: BotStartedSpeakingFrame): - self._bot_speaking = True - - async def _handle_bot_stopped_speaking(self, _: BotStoppedSpeakingFrame): - self._bot_speaking = False - - async def _handle_transcription(self, frame: TranscriptionFrame): - text = frame.text - - # Make sure we really have some text. - if not text.strip(): - return - - self._aggregation += f" {text}" if self._aggregation else text - # We just got a final result, so let's reset interim results. - self._seen_interim_results = False - # Reset aggregation timer. - self._aggregation_event.set() - - async def _handle_interim_transcription(self, _: InterimTranscriptionFrame): - self._seen_interim_results = True - - def _create_aggregation_task(self): - if not self._aggregation_task: - self._aggregation_task = self.create_task(self._aggregation_task_handler()) - - async def _cancel_aggregation_task(self): - if self._aggregation_task: - await self.cancel_task(self._aggregation_task) - self._aggregation_task = None - - async def _aggregation_task_handler(self): - while True: - try: - # The _aggregation_task_handler handles two distinct timeout scenarios: - # - # 1. When emulating_vad=True: Wait for emulated VAD timeout before - # pushing aggregation (simulating VAD behavior when no actual VAD - # detection occurred). - # - # 2. When emulating_vad=False: Use aggregation_timeout as a buffer - # to wait for potential late-arriving transcription frames after - # a real VAD event. - # - # For emulated VAD scenarios, the timeout strategy depends on whether - # a turn analyzer is configured: - # - # - WITH turn analyzer: Use turn_emulated_vad_timeout parameter because - # the VAD's stop_secs is set very low (e.g. 0.2s) for rapid speech - # chunking to feed the turn analyzer. This low value is too fast - # for emulated VAD scenarios where we need to allow users time to - # finish speaking (e.g. 0.8s). - # - # - WITHOUT turn analyzer: Use VAD's stop_secs directly to maintain - # consistent user experience between real VAD detection and - # emulated VAD scenarios. - if not self._emulating_vad: - timeout = self._params.aggregation_timeout - elif self._turn_params: - timeout = self._params.turn_emulated_vad_timeout - else: - # Use VAD stop_secs when no turn analyzer is present, fallback if no VAD params - timeout = ( - self._vad_params.stop_secs - if self._vad_params - else self._params.turn_emulated_vad_timeout - ) - await asyncio.wait_for(self._aggregation_event.wait(), timeout=timeout) - await self._maybe_emulate_user_speaking() - except asyncio.TimeoutError: - if not self._user_speaking: - await self.push_aggregation() - - # If we are emulating VAD we still need to send the user stopped - # speaking frame. - if self._emulating_vad: - await self.push_frame( - EmulateUserStoppedSpeakingFrame(), FrameDirection.UPSTREAM - ) - self._emulating_vad = False - finally: - self._aggregation_event.clear() - - async def _maybe_emulate_user_speaking(self): - """Maybe emulate user speaking based on transcription. - - Emulate user speaking if we got a transcription but it was not - detected by VAD. Behavior when bot is speaking depends on the - enable_emulated_vad_interruptions parameter. - """ - # Check if we received a transcription but VAD was not able to detect - # voice (e.g. when you whisper a short utterance). In that case, we need - # to emulate VAD (i.e. user start/stopped speaking). - if ( - not self._user_speaking - and not self._waiting_for_aggregation - and len(self._aggregation) > 0 - ): - if self._bot_speaking and not self._params.enable_emulated_vad_interruptions: - # If emulated VAD interruptions are disabled and bot is speaking, ignore - logger.debug("Ignoring user speaking emulation, bot is speaking.") - await self.reset() - else: - # Either bot is not speaking, or emulated VAD interruptions are enabled - # - trigger user speaking emulation. - await self.push_frame(EmulateUserStartedSpeakingFrame(), FrameDirection.UPSTREAM) - self._emulating_vad = True - - -class LLMAssistantContextAggregator(LLMContextResponseAggregator): - """Assistant LLM aggregator that processes bot responses and function calls. - - This aggregator handles the complex logic of processing assistant responses including: - - - Text frame aggregation between response start/end markers - - Function call lifecycle management - - Context updates with timestamps - - Tool execution and result handling - - Interruption handling during responses - - The aggregator manages function calls in progress and coordinates between - text generation and tool execution phases of LLM responses. - - .. deprecated:: 0.0.99 - `LLMAssistantContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__( - self, - context: OpenAILLMContext, - *, - params: Optional[LLMAssistantAggregatorParams] = None, - **kwargs, - ): - """Initialize the assistant context aggregator. - - Args: - context: The OpenAI LLM context for conversation storage. - params: Configuration parameters for aggregation behavior. - **kwargs: Additional arguments. Supports deprecated 'expect_stripped_words'. - - .. deprecated:: 0.0.99 - `LLMAssistantContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - # Super handles deprecation warning - super().__init__(context=context, role="assistant", **kwargs) - self._params = params or LLMAssistantAggregatorParams() - - if "expect_stripped_words" in kwargs: - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "Parameter 'expect_stripped_words' is deprecated, use 'params' instead.", - DeprecationWarning, - ) - - self._params.expect_stripped_words = kwargs["expect_stripped_words"] - - self._started = 0 - self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {} - self._context_updated_tasks: Set[asyncio.Task] = set() - - @property - def has_function_calls_in_progress(self) -> bool: - """Check if there are any function calls currently in progress. - - Returns: - True if function calls are in progress, False otherwise. - """ - return bool(self._function_calls_in_progress) - - async def handle_aggregation(self, aggregation: str): - """Add the aggregated assistant text to the context. - - Args: - aggregation: The aggregated assistant text to add as an assistant message. - """ - self._context.add_message({"role": "assistant", "content": aggregation}) - - async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): - """Handle a function call that is in progress. - - Args: - frame: The function call in progress frame to handle. - """ - pass - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - """Handle the result of a completed function call. - - Args: - frame: The function call result frame to handle. - """ - pass - - async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): - """Handle cancellation of a function call. - - Args: - frame: The function call cancel frame to handle. - """ - pass - - async def handle_user_image_frame(self, frame: UserImageRawFrame): - """Handle a user image frame associated with a function call. - - Args: - frame: The user image frame to handle. - """ - pass - - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process frames for assistant response aggregation and function call management. - - Args: - frame: The frame to process. - direction: The direction of frame flow in the pipeline. - """ - await super().process_frame(frame, direction) - - if isinstance(frame, InterruptionFrame): - await self._handle_interruptions(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, LLMFullResponseStartFrame): - await self._handle_llm_start(frame) - elif isinstance(frame, LLMFullResponseEndFrame): - await self._handle_llm_end(frame) - elif isinstance(frame, TextFrame): - await self._handle_text(frame) - elif isinstance(frame, LLMRunFrame): - await self._handle_llm_run(frame) - elif isinstance(frame, LLMMessagesAppendFrame): - await self._handle_llm_messages_append(frame) - elif isinstance(frame, LLMMessagesUpdateFrame): - await self._handle_llm_messages_update(frame) - elif isinstance(frame, LLMSetToolsFrame): - self.set_tools(frame.tools) - elif isinstance(frame, LLMSetToolChoiceFrame): - self.set_tool_choice(frame.tool_choice) - elif isinstance(frame, FunctionCallsStartedFrame): - await self._handle_function_calls_started(frame) - elif isinstance(frame, FunctionCallInProgressFrame): - await self._handle_function_call_in_progress(frame) - elif isinstance(frame, FunctionCallResultFrame): - await self._handle_function_call_result(frame) - elif isinstance(frame, FunctionCallCancelFrame): - await self._handle_function_call_cancel(frame) - elif isinstance(frame, UserImageRawFrame) and frame.request and frame.request.tool_call_id: - await self._handle_user_image_frame(frame) - elif isinstance(frame, BotStoppedSpeakingFrame): - await self.push_aggregation() - await self.push_frame(frame, direction) - else: - await self.push_frame(frame, direction) - - async def push_aggregation(self): - """Push the current assistant aggregation with timestamp.""" - if not self._aggregation: - return - - aggregation = self._aggregation.strip() - await self.reset() - - if aggregation: - await self.handle_aggregation(aggregation) - - # Push context frame - await self.push_context_frame() - - # Push timestamp frame with current time - timestamp_frame = OpenAILLMContextAssistantTimestampFrame(timestamp=time_now_iso8601()) - await self.push_frame(timestamp_frame) - - async def _handle_llm_run(self, frame: LLMRunFrame): - await self.push_context_frame(FrameDirection.UPSTREAM) - - async def _handle_llm_messages_append(self, frame: LLMMessagesAppendFrame): - self.add_messages(frame.messages) - if frame.run_llm: - await self.push_context_frame(FrameDirection.UPSTREAM) - - async def _handle_llm_messages_update(self, frame: LLMMessagesUpdateFrame): - self.set_messages(frame.messages) - if frame.run_llm: - await self.push_context_frame(FrameDirection.UPSTREAM) - - async def _handle_interruptions(self, frame: InterruptionFrame): - await self.push_aggregation() - self._started = 0 - await self.reset() - - async def _handle_function_calls_started(self, frame: FunctionCallsStartedFrame): - function_names = [f"{f.function_name}:{f.tool_call_id}" for f in frame.function_calls] - logger.debug(f"{self} FunctionCallsStartedFrame: {function_names}") - for function_call in frame.function_calls: - self._function_calls_in_progress[function_call.tool_call_id] = None - - async def _handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): - logger.debug( - f"{self} FunctionCallInProgressFrame: [{frame.function_name}:{frame.tool_call_id}]" - ) - await self.handle_function_call_in_progress(frame) - self._function_calls_in_progress[frame.tool_call_id] = frame - - async def _handle_function_call_result(self, frame: FunctionCallResultFrame): - logger.debug( - f"{self} FunctionCallResultFrame: [{frame.function_name}:{frame.tool_call_id}]" - ) - if frame.tool_call_id not in self._function_calls_in_progress: - logger.warning( - f"FunctionCallResultFrame tool_call_id [{frame.tool_call_id}] is not running" - ) - return - - del self._function_calls_in_progress[frame.tool_call_id] - - properties = frame.properties - - await self.handle_function_call_result(frame) - - run_llm = False - - # Run inference if the function call result requires it. - if frame.result: - if properties and properties.run_llm is not None: - # If the tool call result has a run_llm property, use it. - run_llm = properties.run_llm - elif frame.run_llm is not None: - # If the frame is indicating we should run the LLM, do it. - run_llm = frame.run_llm - else: - # If this is the last function call in progress, run the LLM. - run_llm = not bool(self._function_calls_in_progress) - - if run_llm: - await self.push_context_frame(FrameDirection.UPSTREAM) - - # Call the `on_context_updated` callback once the function call result - # is added to the context. Also, run this in a separate task to make - # sure we don't block the pipeline. - if properties and properties.on_context_updated: - task_name = f"{frame.function_name}:{frame.tool_call_id}:on_context_updated" - task = self.create_task(properties.on_context_updated(), task_name) - self._context_updated_tasks.add(task) - task.add_done_callback(self._context_updated_task_finished) - - async def _handle_function_call_cancel(self, frame: FunctionCallCancelFrame): - logger.debug( - f"{self} FunctionCallCancelFrame: [{frame.function_name}:{frame.tool_call_id}]" - ) - function_call = self._function_calls_in_progress.get(frame.tool_call_id) - if function_call and function_call.cancel_on_interruption: - await self.handle_function_call_cancel(frame) - del self._function_calls_in_progress[frame.tool_call_id] - - async def _handle_user_image_frame(self, frame: UserImageRawFrame): - logger.debug( - f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]" - ) - - if frame.request.tool_call_id not in self._function_calls_in_progress: - logger.warning( - f"UserImageRawFrame tool_call_id [{frame.request.tool_call_id}] is not running" - ) - return - - del self._function_calls_in_progress[frame.request.tool_call_id] - - # Call the result_callback if provided. This signals that the image - # has been retrieved and the function call can now complete. - if frame.request and frame.request.result_callback: - await frame.request.result_callback(None) - - await self.handle_user_image_frame(frame) - await self.push_aggregation() - await self.push_context_frame(FrameDirection.UPSTREAM) - - async def _handle_llm_start(self, _: LLMFullResponseStartFrame): - self._started += 1 - - async def _handle_llm_end(self, _: LLMFullResponseEndFrame): - self._started -= 1 - await self.push_aggregation() - - async def _handle_text(self, frame: TextFrame): - if not frame.append_to_context: - return - - if self._params.expect_stripped_words: - self._aggregation += f" {frame.text}" if self._aggregation else frame.text - else: - self._aggregation += frame.text - - def _context_updated_task_finished(self, task: asyncio.Task): - self._context_updated_tasks.discard(task) - - -class LLMUserResponseAggregator(LLMUserContextAggregator): - """User response aggregator that outputs LLMMessagesFrame instead of context frames. - - .. deprecated:: 0.0.79 - This class is deprecated and will be removed in a future version. - Use `LLMUserContextAggregator` or another LLM-specific subclass instead. - - This aggregator extends LLMUserContextAggregator but pushes LLMMessagesFrame - objects downstream instead of OpenAILLMContextFrame objects. This is useful - when you need message-based output rather than context-based output. - """ - - def __init__( - self, - messages: Optional[List[dict]] = None, - *, - params: Optional[LLMUserAggregatorParams] = None, - **kwargs, - ): - """Initialize the user response aggregator. - - Args: - messages: Initial messages for the conversation context. - params: Configuration parameters for aggregation behavior. - **kwargs: Additional arguments passed to parent class. - """ - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "LLMUserResponseAggregator is deprecated and will be removed in a future version. " - "Use LLMUserContextAggregator or another LLM-specific subclass instead.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs) - - async def _process_aggregation(self): - """Process the current aggregation and push it downstream.""" - aggregation = self._aggregation - await self.reset() - await self.handle_aggregation(aggregation) - frame = LLMMessagesFrame(self._context.messages) - await self.push_frame(frame) - - -class LLMAssistantResponseAggregator(LLMAssistantContextAggregator): - """Assistant response aggregator that outputs LLMMessagesFrame instead of context frames. - - .. deprecated:: 0.0.79 - This class is deprecated and will be removed in a future version. - Use `LLMAssistantContextAggregator` or another LLM-specific subclass instead. - - This aggregator extends LLMAssistantContextAggregator but pushes LLMMessagesFrame - objects downstream instead of OpenAILLMContextFrame objects. This is useful - when you need message-based output rather than context-based output. - """ - - def __init__( - self, - messages: Optional[List[dict]] = None, - *, - params: Optional[LLMAssistantAggregatorParams] = None, - **kwargs, - ): - """Initialize the assistant response aggregator. - - Args: - messages: Initial messages for the conversation context. - params: Configuration parameters for aggregation behavior. - **kwargs: Additional arguments passed to parent class. - """ - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "LLMAssistantResponseAggregator is deprecated and will be removed in a future version. " - "Use LLMAssistantContextAggregator or another LLM-specific subclass instead.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs) - - async def push_aggregation(self): - """Push the aggregated assistant response as an LLMMessagesFrame.""" - if len(self._aggregation) > 0: - await self.handle_aggregation(self._aggregation) - - # Reset the aggregation. Reset it before pushing it down, otherwise - # if the tasks gets cancelled we won't be able to clear things up. - await self.reset() - - frame = LLMMessagesFrame(self._context.messages) - await self.push_frame(frame) diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py deleted file mode 100644 index f75625156..000000000 --- a/src/pipecat/processors/aggregators/openai_llm_context.py +++ /dev/null @@ -1,413 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""OpenAI LLM context management for Pipecat. - -This module provides classes for managing OpenAI-specific conversation contexts, -including message handling, tool management, and image/audio processing capabilities. - -.. deprecated:: 0.0.99 - This module is deprecated. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. -""" - -import base64 -import copy -import io -import json -import warnings -from dataclasses import dataclass -from typing import Any, Dict, List, Optional - -from openai._types import NOT_GIVEN, NotGiven -from openai.types.chat import ( - ChatCompletionMessageParam, - ChatCompletionToolChoiceOptionParam, - ChatCompletionToolParam, -) -from PIL import Image - -from pipecat.adapters.base_llm_adapter import BaseLLMAdapter -from pipecat.adapters.schemas.tools_schema import ToolsSchema -from pipecat.frames.frames import AudioRawFrame, Frame - -# JSON custom encoder to handle bytes arrays so that we can log contexts -# with images to the console. - - -class CustomEncoder(json.JSONEncoder): - """Custom JSON encoder for handling special data types in logging. - - Provides specialized encoding for io.BytesIO objects to display - readable representations in log output instead of raw binary data. - """ - - def default(self, obj): - """Encode special objects for JSON serialization. - - Args: - obj: The object to encode. - - Returns: - Encoded representation of the object. - """ - if isinstance(obj, io.BytesIO): - # Convert the first 8 bytes to an ASCII hex string - return f"{obj.getbuffer()[0:8].hex()}..." - return super().default(obj) - - -class OpenAILLMContext: - """Manages conversation context for OpenAI LLM interactions. - - Handles message history, tool definitions, tool choices, and multimedia content - for OpenAI API conversations. Provides methods for message manipulation, - content formatting, and integration with various LLM adapters. - - .. deprecated:: 0.0.99 - `OpenAILLMContext` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - - **Before:** - - context = OpenAILLMContext(messages, tools) - context_aggregator = llm.create_context_aggregator(context) - - **After:** - - context = LLMContext(messages, tools) - context_aggregator = LLMContextAggregatorPair(context) - """ - - def __init__( - self, - messages: Optional[List[ChatCompletionMessageParam]] = None, - tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = NOT_GIVEN, - tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN, - ): - """Initialize the OpenAI LLM context. - - Args: - messages: Initial list of conversation messages. - tools: Available tools for the LLM to use. - tool_choice: Tool selection strategy for the LLM. - - .. deprecated:: 0.0.99 - `OpenAILLMContext` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "OpenAILLMContext is deprecated and will be removed in a future version. " - "Use the universal LLMContext and LLMContextAggregatorPair instead. " - "See OpenAILLMContext docstring for migration guide.", - DeprecationWarning, - stacklevel=2, - ) - self._messages: List[ChatCompletionMessageParam] = messages if messages else [] - self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice - self._tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = tools - self._llm_adapter: Optional[BaseLLMAdapter] = None - - def get_llm_adapter(self) -> Optional[BaseLLMAdapter]: - """Get the current LLM adapter. - - Returns: - The currently set LLM adapter, or None if not set. - """ - return self._llm_adapter - - def set_llm_adapter(self, llm_adapter: BaseLLMAdapter): - """Set the LLM adapter for context processing. - - Args: - llm_adapter: The LLM adapter to use for tool conversion. - """ - self._llm_adapter = llm_adapter - - @staticmethod - def from_messages(messages: List[dict]) -> "OpenAILLMContext": - """Create a context from a list of message dictionaries. - - Args: - messages: List of message dictionaries to convert to context. - - Returns: - New OpenAILLMContext instance with the provided messages. - """ - context = OpenAILLMContext() - - for message in messages: - context.add_message(message) - return context - - @property - def messages(self) -> List[ChatCompletionMessageParam]: - """Get the current messages list. - - Returns: - List of conversation messages. - """ - return self._messages - - @property - def tools(self) -> List[ChatCompletionToolParam] | NotGiven | List[Any]: - """Get the tools list, converting through adapter if available. - - Returns: - Tools list, potentially converted by the LLM adapter. - """ - if self._llm_adapter: - return self._llm_adapter.from_standard_tools(self._tools) - return self._tools - - @property - def tool_choice(self) -> ChatCompletionToolChoiceOptionParam | NotGiven: - """Get the current tool choice setting. - - Returns: - The tool choice configuration. - """ - return self._tool_choice - - def add_message(self, message: ChatCompletionMessageParam): - """Add a single message to the context. - - Args: - message: The message to add to the conversation history. - """ - self._messages.append(message) - - def add_messages(self, messages: List[ChatCompletionMessageParam]): - """Add multiple messages to the context. - - Args: - messages: List of messages to add to the conversation history. - """ - self._messages.extend(messages) - - def set_messages(self, messages: List[ChatCompletionMessageParam]): - """Replace all messages in the context. - - Args: - messages: New list of messages to replace the current history. - """ - self._messages[:] = messages - - def get_messages(self) -> List[ChatCompletionMessageParam]: - """Get a copy of the current messages list. - - Returns: - List of all messages in the conversation history. - """ - return self._messages - - def get_messages_json(self) -> str: - """Get messages as a formatted JSON string. - - Returns: - JSON string representation of all messages with custom encoding. - """ - return json.dumps(self._messages, cls=CustomEncoder, ensure_ascii=False, indent=2) - - def get_messages_for_logging(self) -> List[Dict[str, Any]]: - """Get sanitized messages suitable for logging. - - Removes or truncates sensitive data like image content for safe logging. - - Returns: - List of messages in a format ready for logging. - """ - msgs = [] - for message in self.messages: - msg = copy.deepcopy(message) - if "content" in msg: - if isinstance(msg["content"], list): - for item in msg["content"]: - if item["type"] == "image_url": - if item["image_url"]["url"].startswith("data:image/"): - item["image_url"]["url"] = "data:image/..." - if "mime_type" in msg and msg["mime_type"].startswith("image/"): - msg["data"] = "..." - msgs.append(msg) - return msgs - - def from_standard_message(self, message): - """Convert from OpenAI message format to OpenAI message format (passthrough). - - OpenAI's format allows both simple string content and structured content:: - - Simple: {"role": "user", "content": "Hello"} - Structured: {"role": "user", "content": [{"type": "text", "text": "Hello"}]} - - Since OpenAI is our standard format, this is a passthrough function. - - Args: - message: Message in OpenAI format. - - Returns: - Same message, unchanged. - """ - return message - - def to_standard_messages(self, obj) -> list: - """Convert from OpenAI message format to OpenAI message format (passthrough). - - OpenAI's format is our standard format throughout Pipecat. This function - returns a list containing the original message to maintain consistency with - other LLM services that may need to return multiple messages. - - Args: - obj: Message in OpenAI format with either simple string content - or structured list content. - - Returns: - List containing the original messages, preserving the content format. - """ - return [obj] - - def get_messages_for_initializing_history(self): - """Get messages for initializing conversation history. - - Returns: - List of messages suitable for history initialization. - """ - return self._messages - - def get_messages_for_persistent_storage(self): - """Get messages formatted for persistent storage. - - Returns: - List of messages converted to standard format for storage. - """ - messages = [] - for m in self._messages: - standard_messages = self.to_standard_messages(m) - messages.extend(standard_messages) - return messages - - def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven): - """Set the tool choice configuration. - - Args: - tool_choice: Tool selection strategy for the LLM. - """ - self._tool_choice = tool_choice - - def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven | ToolsSchema = NOT_GIVEN): - """Set the available tools for the LLM. - - Args: - tools: List of tools available to the LLM, or NOT_GIVEN to disable tools. - """ - if tools != NOT_GIVEN and isinstance(tools, list) and len(tools) == 0: - tools = NOT_GIVEN - self._tools = tools - - def add_image_frame_message( - self, *, format: str, size: tuple[int, int], image: bytes, text: str = None - ): - """Add a message containing an image frame. - - Args: - format: Image format (e.g., 'RGB', 'RGBA'). - size: Image dimensions as (width, height) tuple. - image: Raw image bytes. - text: Optional text to include with the image. - """ - buffer = io.BytesIO() - Image.frombytes(format, size, image).save(buffer, format="JPEG") - encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - - content = [] - if text: - content.append({"type": "text", "text": text}) - content.append( - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, - ) - self.add_message({"role": "user", "content": content}) - - def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None): - """Add a message containing audio frames. - - Args: - audio_frames: List of audio frame objects to include. - text: Optional text to include with the audio. - - Note: - This method is currently a placeholder for future implementation. - """ - # todo: implement for OpenAI models and others - pass - - def create_wav_header(self, sample_rate, num_channels, bits_per_sample, data_size): - """Create a WAV file header for audio data. - - Args: - sample_rate: Audio sample rate in Hz. - num_channels: Number of audio channels. - bits_per_sample: Bits per audio sample. - data_size: Size of audio data in bytes. - - Returns: - WAV header as a bytearray. - """ - # RIFF chunk descriptor - header = bytearray() - header.extend(b"RIFF") # ChunkID - header.extend((data_size + 36).to_bytes(4, "little")) # ChunkSize: total size - 8 - header.extend(b"WAVE") # Format - # "fmt " sub-chunk - header.extend(b"fmt ") # Subchunk1ID - header.extend((16).to_bytes(4, "little")) # Subchunk1Size (16 for PCM) - header.extend((1).to_bytes(2, "little")) # AudioFormat (1 for PCM) - header.extend(num_channels.to_bytes(2, "little")) # NumChannels - header.extend(sample_rate.to_bytes(4, "little")) # SampleRate - # Calculate byte rate and block align - byte_rate = sample_rate * num_channels * (bits_per_sample // 8) - block_align = num_channels * (bits_per_sample // 8) - header.extend(byte_rate.to_bytes(4, "little")) # ByteRate - header.extend(block_align.to_bytes(2, "little")) # BlockAlign - header.extend(bits_per_sample.to_bytes(2, "little")) # BitsPerSample - # "data" sub-chunk - header.extend(b"data") # Subchunk2ID - header.extend(data_size.to_bytes(4, "little")) # Subchunk2Size - return header - - -@dataclass -class OpenAILLMContextFrame(Frame): - """Frame containing OpenAI-specific LLM context. - - Like an LLMMessagesFrame, but with extra context specific to the OpenAI - API. The context in this message is also mutable, and will be changed by the - OpenAIContextAggregator frame processor. - - .. deprecated:: 0.0.99 - `OpenAILLMContextFrame` is deprecated and will be removed in a future version. - Use `LLMContextFrame` with the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Parameters: - context: The OpenAI LLM context containing messages, tools, and configuration. - """ - - context: OpenAILLMContext - - def __post_init__(self): - super().__post_init__() - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "OpenAILLMContextFrame is deprecated and will be removed in a future version. " - "Use LLMContextFrame with the universal `LLMContext` and `LLMContextAggregatorPair` instead. " - "See OpenAILLMContext docstring for migration guide.", - DeprecationWarning, - stacklevel=2, - ) diff --git a/src/pipecat/processors/aggregators/vision_image_frame.py b/src/pipecat/processors/aggregators/vision_image_frame.py deleted file mode 100644 index e7360c07e..000000000 --- a/src/pipecat/processors/aggregators/vision_image_frame.py +++ /dev/null @@ -1,81 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""Vision image frame aggregation for Pipecat. - -This module provides frame aggregation functionality to combine text and image -frames into vision frames for multimodal processing. -""" - -from pipecat.frames.frames import Frame, InputImageRawFrame, TextFrame -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor - - -class VisionImageFrameAggregator(FrameProcessor): - """Aggregates consecutive text and image frames into vision frames. - - .. deprecated:: 0.0.85 - VisionImageRawFrame has been removed in favor of context frames - (LLMContextFrame or OpenAILLMContextFrame), so this aggregator is not - needed anymore. See the 12* examples for the new recommended pattern. - - This aggregator waits for a consecutive TextFrame and an InputImageRawFrame. - After the InputImageRawFrame arrives it will output a VisionImageRawFrame - combining both the text and image data for multimodal processing. - """ - - def __init__(self): - """Initialize the vision image frame aggregator. - - The aggregator starts with no cached text, waiting for the first - TextFrame to arrive before it can create vision frames. - """ - import warnings - - warnings.warn( - "VisionImageFrameAggregator is deprecated. " - "VisionImageRawFrame has been removed in favor of context frames " - "(LLMContextFrame or OpenAILLMContextFrame), so this aggregator is " - "not needed anymore. See the 12* examples for the new recommended " - "pattern.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__() - self._describe_text = None - - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process incoming frames and aggregate text with images. - - Caches TextFrames and combines them with subsequent InputImageRawFrames - to create VisionImageRawFrames. Other frames are passed through unchanged. - - Args: - frame: The incoming frame to process. - direction: The direction of frame flow in the pipeline. - """ - await super().process_frame(frame, direction) - - if isinstance(frame, TextFrame): - self._describe_text = frame.text - elif isinstance(frame, InputImageRawFrame): - if self._describe_text: - context = OpenAILLMContext() - context.add_image_frame_message( - text=self._describe_text, - image=frame.image, - size=frame.size, - format=frame.format, - ) - frame = OpenAILLMContextFrame(context) - await self.push_frame(frame) - self._describe_text = None - else: - await self.push_frame(frame, direction) diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 8efb1353d..29b613527 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -196,7 +196,6 @@ class FrameProcessor(BaseObject): # Other properties (deprecated) self._allow_interruptions = False self._interruption_strategies: List[BaseInterruptionStrategy] = [] - self._deprecated_openaillmcontext = False # Indicates whether we have received the StartFrame. self.__started = False @@ -826,9 +825,6 @@ class FrameProcessor(BaseObject): self._interruption_strategies = frame.interruption_strategies self._report_only_initial_ttfb = frame.report_only_initial_ttfb - # NOTE(aleix): Remove when OpenAILLMContext/LLMUserContextAggregator is removed. - self._deprecated_openaillmcontext = "deprecated_openaillmcontext" in frame.metadata - self.__create_process_task() async def __cancel(self, frame: CancelFrame): diff --git a/src/pipecat/processors/frameworks/langchain.py b/src/pipecat/processors/frameworks/langchain.py index e569758ce..165f749ea 100644 --- a/src/pipecat/processors/frameworks/langchain.py +++ b/src/pipecat/processors/frameworks/langchain.py @@ -17,7 +17,6 @@ from pipecat.frames.frames import ( LLMFullResponseStartFrame, TextFrame, ) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor try: @@ -65,15 +64,11 @@ class LangchainProcessor(FrameProcessor): """ await super().process_frame(frame, direction) - if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): + if isinstance(frame, LLMContextFrame): # Messages are accumulated on the context as a list of messages. # The last one by the human is the one we want to send to the LLM. logger.debug(f"Got transcription frame {frame}") - messages = ( - frame.context.messages - if isinstance(frame, OpenAILLMContextFrame) - else frame.context.get_messages() - ) + messages = frame.context.get_messages() text: str = messages[-1]["content"] await self._ainvoke(text.strip()) diff --git a/src/pipecat/processors/frameworks/rtvi/observer.py b/src/pipecat/processors/frameworks/rtvi/observer.py index 7944dd3eb..958ba8841 100644 --- a/src/pipecat/processors/frameworks/rtvi/observer.py +++ b/src/pipecat/processors/frameworks/rtvi/observer.py @@ -59,7 +59,6 @@ from pipecat.metrics.metrics import ( TTSUsageMetricsData, ) from pipecat.observers.base_observer import BaseObserver, FramePushed -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.processors.frameworks.rtvi.frames import ( RTVIServerMessageFrame, @@ -358,10 +357,7 @@ class RTVIObserver(BaseObserver): and self._params.user_transcription_enabled ): await self._handle_user_transcriptions(frame) - elif ( - isinstance(frame, (OpenAILLMContextFrame, LLMContextFrame)) - and self._params.user_llm_enabled - ): + elif isinstance(frame, LLMContextFrame) and self._params.user_llm_enabled: await self._handle_context(frame) elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled: await self.send_rtvi_message(RTVI.BotLLMStartedMessage()) @@ -575,13 +571,10 @@ class RTVIObserver(BaseObserver): if message: await self.send_rtvi_message(message) - async def _handle_context(self, frame: OpenAILLMContextFrame | LLMContextFrame): + async def _handle_context(self, frame: LLMContextFrame): """Process LLM context frames to extract user messages for the RTVI client.""" try: - if isinstance(frame, OpenAILLMContextFrame): - messages = frame.context.messages - else: - messages = frame.context.get_messages() + messages = frame.context.get_messages() if not messages: return diff --git a/src/pipecat/processors/frameworks/strands_agents.py b/src/pipecat/processors/frameworks/strands_agents.py index 8022f5387..eb1edbfdc 100644 --- a/src/pipecat/processors/frameworks/strands_agents.py +++ b/src/pipecat/processors/frameworks/strands_agents.py @@ -16,7 +16,6 @@ from pipecat.frames.frames import ( LLMTextFrame, ) from pipecat.metrics.metrics import LLMTokenUsage -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor try: @@ -72,7 +71,7 @@ class StrandsAgentsProcessor(FrameProcessor): direction: The direction of frame flow in the pipeline. """ await super().process_frame(frame, direction) - if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): + if isinstance(frame, LLMContextFrame): messages = frame.context.get_messages() if messages: last_message = messages[-1] diff --git a/src/pipecat/services/anthropic/llm.py b/src/pipecat/services/anthropic/llm.py index f35d6226c..9c07a8444 100644 --- a/src/pipecat/services/anthropic/llm.py +++ b/src/pipecat/services/anthropic/llm.py @@ -37,7 +37,6 @@ from pipecat.frames.frames import ( LLMEnablePromptCachingFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, - LLMMessagesFrame, LLMThoughtEndFrame, LLMThoughtStartFrame, LLMThoughtTextFrame, @@ -45,16 +44,6 @@ from pipecat.frames.frames import ( ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMAssistantContextAggregator, - LLMUserAggregatorParams, - LLMUserContextAggregator, -) -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import FunctionCallFromLLM, LLMService from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN @@ -115,44 +104,6 @@ class AnthropicLLMSettings(LLMSettings): return instance -@dataclass -class AnthropicContextAggregatorPair: - """Pair of context aggregators for Anthropic conversations. - - Encapsulates both user and assistant context aggregators - to manage conversation flow and message formatting. - - .. deprecated:: 0.0.99 - `AnthropicContextAggregatorPair` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Parameters: - _user: The user context aggregator. - _assistant: The assistant context aggregator. - """ - - # Aggregators handle deprecation warnings - _user: "AnthropicUserContextAggregator" - _assistant: "AnthropicAssistantContextAggregator" - - def user(self) -> "AnthropicUserContextAggregator": - """Get the user context aggregator. - - Returns: - The user context aggregator instance. - """ - return self._user - - def assistant(self) -> "AnthropicAssistantContextAggregator": - """Get the assistant context aggregator. - - Returns: - The assistant context aggregator instance. - """ - return self._assistant - - class AnthropicLLMService(LLMService): """LLM service for Anthropic's Claude models. @@ -351,7 +302,7 @@ class AnthropicLLMService(LLMService): async def run_inference( self, - context: LLMContext | OpenAILLMContext, + context: LLMContext, max_tokens: Optional[int] = None, system_instruction: Optional[str] = None, ) -> Optional[str]: @@ -371,21 +322,15 @@ class AnthropicLLMService(LLMService): system = NOT_GIVEN tools = [] effective_instruction = system_instruction or self._settings.system_instruction - if isinstance(context, LLMContext): - adapter: AnthropicLLMAdapter = self.get_llm_adapter() - invocation_params = adapter.get_llm_invocation_params( - context, - enable_prompt_caching=self._settings.enable_prompt_caching, - system_instruction=effective_instruction, - ) - messages = invocation_params["messages"] - system = invocation_params["system"] - tools = invocation_params["tools"] - else: - context = AnthropicLLMContext.upgrade_to_anthropic(context) - messages = context.messages - system = getattr(context, "system", NOT_GIVEN) - tools = context.tools or [] + adapter: AnthropicLLMAdapter = self.get_llm_adapter() + invocation_params = adapter.get_llm_invocation_params( + context, + enable_prompt_caching=self._settings.enable_prompt_caching, + system_instruction=effective_instruction, + ) + messages = invocation_params["messages"] + system = invocation_params["system"] + tools = invocation_params["tools"] # Build params using the same method as streaming completions params = { @@ -410,70 +355,17 @@ class AnthropicLLMService(LLMService): return next((block.text for block in response.content if hasattr(block, "text")), None) - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> AnthropicContextAggregatorPair: - """Create Anthropic-specific context aggregators. - - Creates a pair of context aggregators optimized for Anthropic's message format, - including support for function calls, tool usage, and image handling. - - Args: - context: The LLM context. - user_params: User aggregator parameters. - assistant_params: Assistant aggregator parameters. - - Returns: - A pair of context aggregators, one for the user and one for the assistant, - encapsulated in an AnthropicContextAggregatorPair. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - context.set_llm_adapter(self.get_llm_adapter()) - - if isinstance(context, OpenAILLMContext): - context = AnthropicLLMContext.from_openai_context(context) - - # Aggregators handle deprecation warnings - user = AnthropicUserContextAggregator(context, params=user_params) - assistant = AnthropicAssistantContextAggregator(context, params=assistant_params) - - return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) - - def _get_llm_invocation_params( - self, context: OpenAILLMContext | LLMContext - ) -> AnthropicLLMInvocationParams: - # Universal LLMContext - if isinstance(context, LLMContext): - adapter: AnthropicLLMAdapter = self.get_llm_adapter() - params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params( - context, - enable_prompt_caching=self._settings.enable_prompt_caching, - system_instruction=self._settings.system_instruction, - ) - return params - - # Anthropic-specific context - messages = ( - context.get_messages_with_cache_control_markers() - if self._settings.enable_prompt_caching - else context.messages - ) - return AnthropicLLMInvocationParams( - system=context.system, - messages=messages, - tools=context.tools or [], + def _get_llm_invocation_params(self, context: LLMContext) -> AnthropicLLMInvocationParams: + adapter: AnthropicLLMAdapter = self.get_llm_adapter() + params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params( + context, + enable_prompt_caching=self._settings.enable_prompt_caching, + system_instruction=self._settings.system_instruction, ) + return params @traced_llm - async def _process_context(self, context: OpenAILLMContext | LLMContext): + async def _process_context(self, context: LLMContext): # Usage tracking. We track the usage reported by Anthropic in prompt_tokens and # completion_tokens. We also estimate the completion tokens from output text # and use that estimate if we are interrupted, because we almost certainly won't @@ -491,15 +383,10 @@ class AnthropicLLMService(LLMService): params_from_context = self._get_llm_invocation_params(context) - if isinstance(context, LLMContext): - adapter = self.get_llm_adapter() - context_type_for_logging = "universal" - messages_for_logging = adapter.get_messages_for_logging(context) - else: - context_type_for_logging = "LLM-specific" - messages_for_logging = context.get_messages_for_logging() + adapter = self.get_llm_adapter() + messages_for_logging = adapter.get_messages_for_logging(context) logger.debug( - f"{self}: Generating chat from {context_type_for_logging} context [{params_from_context['system']}] | {messages_for_logging}" + f"{self}: Generating chat from context [{params_from_context['system']}] | {messages_for_logging}" ) await self.start_ttfb_metrics() @@ -665,24 +552,14 @@ class AnthropicLLMService(LLMService): """ await super().process_frame(frame, direction) - context = None - if isinstance(frame, OpenAILLMContextFrame): - context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context) - elif isinstance(frame, LLMContextFrame): - context = frame.context - elif isinstance(frame, LLMMessagesFrame): - # NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal - # LLMContext with it - context = AnthropicLLMContext.from_messages(frame.messages) + if isinstance(frame, LLMContextFrame): + await self._process_context(frame.context) elif isinstance(frame, LLMEnablePromptCachingFrame): logger.debug(f"Setting enable prompt caching to: [{frame.enable}]") self._settings.enable_prompt_caching = frame.enable else: await self.push_frame(frame, direction) - if context: - await self._process_context(context) - def _estimate_tokens(self, text: str) -> int: return int(len(re.split(r"[^\w]+", text)) * 1.3) @@ -707,581 +584,3 @@ class AnthropicLLMService(LLMService): total_tokens=prompt_tokens + completion_tokens, ) await self.start_llm_usage_metrics(tokens) - - -class AnthropicLLMContext(OpenAILLMContext): - """LLM context specialized for Anthropic's message format and features. - - Extends OpenAILLMContext to handle Anthropic-specific features like - system messages, prompt caching, and message format conversions. - Manages conversation state and message history formatting. - - .. deprecated:: 0.0.99 - `AnthropicLLMContext` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__( - self, - messages: Optional[List[dict]] = None, - tools: Optional[List[dict]] = None, - tool_choice: Optional[dict] = None, - *, - system: Union[str, NotGiven] = NOT_GIVEN, - ): - """Initialize the Anthropic LLM context. - - Args: - messages: Initial list of conversation messages. - tools: Available function calling tools. - tool_choice: Tool selection preference. - system: System message content. - """ - # Super handles deprecation warning - super().__init__(messages=messages, tools=tools, tool_choice=tool_choice) - self.__setup_local() - self.system = system - - def __setup_local(self): - # For beta prompt caching. This is a counter that tracks the number of turns - # we've seen above the cache threshold. We reset this when we reset the - # messages list. We only care about this number being 0, 1, or 2. But - # it's easiest just to treat it as a counter. - self.turns_above_cache_threshold = 0 - return - - @staticmethod - def upgrade_to_anthropic(obj: OpenAILLMContext) -> "AnthropicLLMContext": - """Upgrade an OpenAI context to Anthropic format. - - Converts message format and restructures content for Anthropic compatibility. - - Args: - obj: The OpenAI context to upgrade. - - Returns: - The upgraded Anthropic context. - """ - logger.debug(f"Upgrading to Anthropic: {obj}") - if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AnthropicLLMContext): - obj.__class__ = AnthropicLLMContext - obj.__setup_local() - obj._restructure_from_openai_messages() - return obj - - @classmethod - def from_openai_context(cls, openai_context: OpenAILLMContext): - """Create Anthropic context from OpenAI context. - - Args: - openai_context: The OpenAI context to convert. - - Returns: - New Anthropic context with converted messages. - """ - self = cls( - messages=openai_context.messages, - tools=openai_context.tools, - tool_choice=openai_context.tool_choice, - ) - self.set_llm_adapter(openai_context.get_llm_adapter()) - self._restructure_from_openai_messages() - return self - - @classmethod - def from_messages(cls, messages: List[dict]) -> "AnthropicLLMContext": - """Create context from a list of messages. - - Args: - messages: List of conversation messages. - - Returns: - New Anthropic context with the provided messages. - """ - self = cls(messages=messages) - self._restructure_from_openai_messages() - return self - - def set_messages(self, messages: List): - """Set the messages list and reset cache tracking. - - Args: - messages: New list of messages to set. - """ - self.turns_above_cache_threshold = 0 - self._messages[:] = messages - self._restructure_from_openai_messages() - - def to_standard_messages(self, obj): - """Convert Anthropic message format to standard structured format. - - Handles text content and function calls for both user and assistant messages. - - Args: - obj: Message in Anthropic format. - - Returns: - List of messages in standard format. - - Examples: - Input Anthropic format:: - - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Hello"}, - {"type": "tool_use", "id": "123", "name": "search", "input": {"q": "test"}} - ] - } - - Output standard format:: - - [ - {"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}, - { - "role": "assistant", - "tool_calls": [ - { - "type": "function", - "id": "123", - "function": {"name": "search", "arguments": '{"q": "test"}'} - } - ] - } - ] - """ - # todo: image format (?) - # tool_use - role = obj.get("role") - content = obj.get("content") - if role == "assistant": - if isinstance(content, str): - return [{"role": role, "content": [{"type": "text", "text": content}]}] - elif isinstance(content, list): - text_items = [] - tool_items = [] - for item in content: - if item["type"] == "text": - text_items.append({"type": "text", "text": item["text"]}) - elif item["type"] == "tool_use": - tool_items.append( - { - "type": "function", - "id": item["id"], - "function": { - "name": item["name"], - "arguments": json.dumps(item["input"]), - }, - } - ) - messages = [] - if text_items: - messages.append({"role": role, "content": text_items}) - if tool_items: - messages.append({"role": role, "tool_calls": tool_items}) - return messages - elif role == "user": - if isinstance(content, str): - return [{"role": role, "content": [{"type": "text", "text": content}]}] - elif isinstance(content, list): - text_items = [] - tool_items = [] - for item in content: - if item["type"] == "text": - text_items.append({"type": "text", "text": item["text"]}) - elif item["type"] == "tool_result": - tool_items.append( - { - "role": "tool", - "tool_call_id": item["tool_use_id"], - "content": item["content"], - } - ) - messages = [] - if text_items: - messages.append({"role": role, "content": text_items}) - messages.extend(tool_items) - return messages - - def from_standard_message(self, message): - """Convert standard format message to Anthropic format. - - Handles conversion of text content, tool calls, and tool results. - Empty text content is converted to "(empty)". - - Args: - message: Message in standard format. - - Returns: - Message in Anthropic format. - - Examples: - Input standard format:: - - { - "role": "assistant", - "tool_calls": [ - { - "id": "123", - "function": {"name": "search", "arguments": '{"q": "test"}'} - } - ] - } - - Output Anthropic format:: - - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": "123", - "name": "search", - "input": {"q": "test"} - } - ] - } - """ - # todo: image messages (?) - if message["role"] == "tool": - return { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": message["tool_call_id"], - "content": message["content"], - }, - ], - } - if message.get("tool_calls"): - tc = message["tool_calls"] - ret = {"role": "assistant", "content": []} - for tool_call in tc: - function = tool_call["function"] - arguments = json.loads(function["arguments"]) - new_tool_use = { - "type": "tool_use", - "id": tool_call["id"], - "name": function["name"], - "input": arguments, - } - ret["content"].append(new_tool_use) - return ret - # check for empty text strings - content = message.get("content") - if isinstance(content, str): - if content == "": - content = "(empty)" - elif isinstance(content, list): - for item in content: - if item["type"] == "text" and item["text"] == "": - item["text"] = "(empty)" - - return message - - def add_image_frame_message( - self, *, format: str, size: tuple[int, int], image: bytes, text: str = None - ): - """Add an image message to the context. - - Converts the image to base64 JPEG format and adds it as a user message - with optional accompanying text. - - Args: - format: The image format (e.g., 'RGB', 'RGBA'). - size: Image dimensions as (width, height). - image: Raw image bytes. - text: Optional text to accompany the image. - """ - buffer = io.BytesIO() - Image.frombytes(format, size, image).save(buffer, format="JPEG") - encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - - # Anthropic docs say that the image should be the first content block in the message. - content = [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": encoded_image, - }, - } - ] - if text: - content.append({"type": "text", "text": text}) - self.add_message({"role": "user", "content": content}) - - def add_message(self, message): - """Add a message to the context, merging with previous message if same role. - - Anthropic requires alternating roles, so consecutive messages from the same - role are merged together. - - Args: - message: The message to add to the context. - """ - try: - if self.messages: - # Anthropic requires that roles alternate. If this message's role is the same as the - # last message, we should add this message's content to the last message. - if self.messages[-1]["role"] == message["role"]: - # if the last message has just a content string, convert it to a list - # in the proper format - if isinstance(self.messages[-1]["content"], str): - self.messages[-1]["content"] = [ - {"type": "text", "text": self.messages[-1]["content"]} - ] - # if this message has just a content string, convert it to a list - # in the proper format - if isinstance(message["content"], str): - message["content"] = [{"type": "text", "text": message["content"]}] - # append the content of this message to the last message - self.messages[-1]["content"].extend(message["content"]) - else: - self.messages.append(message) - else: - self.messages.append(message) - except Exception as e: - logger.error(f"Error adding message: {e}") - - def get_messages_with_cache_control_markers(self) -> List[dict]: - """Get messages with prompt caching markers applied. - - Adds cache control markers to appropriate messages based on the - number of turns above the cache threshold. - - Returns: - List of messages with cache control markers added. - """ - try: - messages = copy.deepcopy(self.messages) - if self.turns_above_cache_threshold >= 1 and messages[-1]["role"] == "user": - if isinstance(messages[-1]["content"], str): - messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}] - messages[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"} - if ( - self.turns_above_cache_threshold >= 2 - and len(messages) > 2 - and messages[-3]["role"] == "user" - ): - if isinstance(messages[-3]["content"], str): - messages[-3]["content"] = [{"type": "text", "text": messages[-3]["content"]}] - messages[-3]["content"][-1]["cache_control"] = {"type": "ephemeral"} - return messages - except Exception as e: - logger.error(f"Error adding cache control marker: {e}") - return self.messages - - def _restructure_from_openai_messages(self): - # first, map across self._messages calling self.from_standard_message(m) to modify messages in place - try: - self._messages[:] = [self.from_standard_message(m) for m in self._messages] - except Exception as e: - logger.error(f"Error mapping messages: {e}") - - # See if we should pull the system message out of our context.messages list. (For - # compatibility with Open AI messages format.) - if self.messages and self.messages[0]["role"] == "system": - if len(self.messages) == 1: - # If we have only have a system message in the list, all we can really do - # without introducing too much magic is change the role to "user". - self.messages[0]["role"] = "user" - else: - # If we have more than one message, we'll pull the system message out of the - # list. - self.system = self.messages[0]["content"] - self.messages.pop(0) - - # Merge consecutive messages with the same role. - i = 0 - while i < len(self.messages) - 1: - current_message = self.messages[i] - next_message = self.messages[i + 1] - if current_message["role"] == next_message["role"]: - # Convert content to list of dictionaries if it's a string - if isinstance(current_message["content"], str): - current_message["content"] = [ - {"type": "text", "text": current_message["content"]} - ] - if isinstance(next_message["content"], str): - next_message["content"] = [{"type": "text", "text": next_message["content"]}] - # Concatenate the content - current_message["content"].extend(next_message["content"]) - # Remove the next message from the list - self.messages.pop(i + 1) - else: - i += 1 - - # Avoid empty content in messages - for message in self.messages: - if isinstance(message["content"], str) and message["content"] == "": - message["content"] = "(empty)" - elif isinstance(message["content"], list) and len(message["content"]) == 0: - message["content"] = [{"type": "text", "text": "(empty)"}] - - def get_messages_for_persistent_storage(self): - """Get messages formatted for persistent storage. - - Includes system message at the beginning if present. - - Returns: - List of messages suitable for storage. - """ - messages = super().get_messages_for_persistent_storage() - if self.system: - messages.insert(0, {"role": "system", "content": self.system}) - return messages - - def get_messages_for_logging(self) -> List[Dict[str, Any]]: - """Get messages formatted for logging with sensitive data redacted. - - Replaces image data with placeholder text for cleaner logs. - - Returns: - List of messages in a format ready for logging. - """ - msgs = [] - for message in self.messages: - msg = copy.deepcopy(message) - if "content" in msg: - if isinstance(msg["content"], list): - for item in msg["content"]: - if item["type"] == "image": - item["source"]["data"] = "..." - msgs.append(msg) - return msgs - - -class AnthropicUserContextAggregator(LLMUserContextAggregator): - """Anthropic-specific user context aggregator. - - Handles aggregation of user messages for Anthropic LLM services. - Inherits all functionality from the base LLMUserContextAggregator. - - .. deprecated:: 0.0.99 - `AnthropicUserContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - # Super handles deprecation warning - pass - - -# -# Claude returns a text content block along with a tool use content block. This works quite nicely -# with streaming. We get the text first, so we can start streaming it right away. Then we get the -# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call. -# -# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's -# chattiness about it's tool thinking. -# - - -class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): - """Context aggregator for assistant messages in Anthropic conversations. - - Handles function call lifecycle management including in-progress tracking, - result handling, and cancellation for Anthropic's tool use format. - - .. deprecated:: 0.0.99 - `AnthropicAssistantContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - # Super handles deprecation warning - - async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): - """Handle a function call that is starting. - - Creates tool use message and placeholder tool result for tracking. - - Args: - frame: Frame containing function call details. - """ - assistant_message = {"role": "assistant", "content": []} - assistant_message["content"].append( - { - "type": "tool_use", - "id": frame.tool_call_id, - "name": frame.function_name, - "input": frame.arguments, - } - ) - self._context.add_message(assistant_message) - self._context.add_message( - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": frame.tool_call_id, - "content": "IN_PROGRESS", - } - ], - } - ) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - """Handle the result of a completed function call. - - Updates the tool result with actual return value or completion status. - - Args: - frame: Frame containing function call result. - """ - if frame.result: - result = json.dumps(frame.result, ensure_ascii=False) - await self._update_function_call_result(frame.function_name, frame.tool_call_id, result) - else: - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "COMPLETED" - ) - - async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): - """Handle cancellation of a function call. - - Updates the tool result to indicate cancellation. - - Args: - frame: Frame containing function call cancellation details. - """ - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "CANCELLED" - ) - - async def _update_function_call_result( - self, function_name: str, tool_call_id: str, result: Any - ): - for message in self._context.messages: - if message["role"] == "user": - for content in message["content"]: - if ( - isinstance(content, dict) - and content["type"] == "tool_result" - and content["tool_use_id"] == tool_call_id - ): - content["content"] = result - - async def handle_user_image_frame(self, frame: UserImageRawFrame): - """Handle a user image frame with function call context. - - Marks the associated function call as completed and adds the image - to the conversation context. - - Args: - frame: User image frame with request context. - """ - await self._update_function_call_result( - frame.request.function_name, frame.request.tool_call_id, "COMPLETED" - ) - self._context.add_image_frame_message( - format=frame.format, - size=frame.size, - image=frame.image, - text=frame.request.context, - ) diff --git a/src/pipecat/services/aws/agent_core.py b/src/pipecat/services/aws/agent_core.py index f21654584..d66af9c5b 100644 --- a/src/pipecat/services/aws/agent_core.py +++ b/src/pipecat/services/aws/agent_core.py @@ -26,15 +26,11 @@ from pipecat.frames.frames import ( LLMTextFrame, ) from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor def default_context_to_payload_transformer( - context: LLMContext | OpenAILLMContext, + context: LLMContext, ) -> Optional[str]: """Default transformer to create AgentCore payload from LLM context. @@ -118,9 +114,7 @@ class AWSAgentCoreProcessor(FrameProcessor): aws_secret_key: Optional[str] = None, aws_session_token: Optional[str] = None, aws_region: Optional[str] = None, - context_to_payload_transformer: Optional[ - Callable[[LLMContext | OpenAILLMContext], Optional[str]] - ] = None, + context_to_payload_transformer: Optional[Callable[[LLMContext], Optional[str]]] = None, response_to_output_transformer: Optional[Callable[[str], Optional[str]]] = None, **kwargs, ): @@ -200,7 +194,7 @@ class AWSAgentCoreProcessor(FrameProcessor): direction: The direction of frame flow in the pipeline. """ await super().process_frame(frame, direction) - if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): + if isinstance(frame, LLMContextFrame): # Create payload to invoke AgentCore agent payload = self._context_to_payload_transformer(frame.context) diff --git a/src/pipecat/services/aws/llm.py b/src/pipecat/services/aws/llm.py index cda186167..722072592 100644 --- a/src/pipecat/services/aws/llm.py +++ b/src/pipecat/services/aws/llm.py @@ -38,21 +38,10 @@ from pipecat.frames.frames import ( LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, - LLMMessagesFrame, UserImageRawFrame, ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMAssistantContextAggregator, - LLMUserAggregatorParams, - LLMUserContextAggregator, -) -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import LLMService from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven @@ -87,657 +76,6 @@ class AWSBedrockLLMSettings(LLMSettings): ) -@dataclass -class AWSBedrockContextAggregatorPair: - """Container for AWS Bedrock context aggregators. - - Provides convenient access to both user and assistant context aggregators - for AWS Bedrock LLM operations. - - .. deprecated:: 0.0.99 - `AWSBedrockContextAggregatorPair` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Parameters: - _user: The user context aggregator instance. - _assistant: The assistant context aggregator instance. - """ - - # Aggregators handle deprecation warnings - _user: "AWSBedrockUserContextAggregator" - _assistant: "AWSBedrockAssistantContextAggregator" - - def user(self) -> "AWSBedrockUserContextAggregator": - """Get the user context aggregator. - - Returns: - The user context aggregator instance. - """ - return self._user - - def assistant(self) -> "AWSBedrockAssistantContextAggregator": - """Get the assistant context aggregator. - - Returns: - The assistant context aggregator instance. - """ - return self._assistant - - -class AWSBedrockLLMContext(OpenAILLMContext): - """AWS Bedrock-specific LLM context implementation. - - Extends OpenAI LLM context to handle AWS Bedrock's specific message format - and system message handling. Manages conversion between OpenAI and Bedrock - message formats. - - .. deprecated:: 0.0.99 - `AWSBedrockLLMContext` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__( - self, - messages: Optional[List[dict]] = None, - tools: Optional[List[dict]] = None, - tool_choice: Optional[dict] = None, - *, - system: Optional[str] = None, - ): - """Initialize AWS Bedrock LLM context. - - Args: - messages: List of conversation messages in OpenAI format. - tools: List of available function calling tools. - tool_choice: Tool selection strategy or specific tool choice. - system: System message content for AWS Bedrock. - """ - # Super handles deprecation warning - super().__init__(messages=messages, tools=tools, tool_choice=tool_choice) - self.system = system - - @staticmethod - def upgrade_to_bedrock(obj: OpenAILLMContext) -> "AWSBedrockLLMContext": - """Upgrade an OpenAI LLM context to AWS Bedrock format. - - Args: - obj: The OpenAI LLM context to upgrade. - - Returns: - The upgraded AWS Bedrock LLM context. - """ - logger.debug(f"Upgrading to AWS Bedrock: {obj}") - if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSBedrockLLMContext): - obj.__class__ = AWSBedrockLLMContext - obj._restructure_from_openai_messages() - else: - obj._restructure_from_bedrock_messages() - return obj - - @classmethod - def from_openai_context(cls, openai_context: OpenAILLMContext): - """Create AWS Bedrock context from OpenAI context. - - Args: - openai_context: The OpenAI LLM context to convert. - - Returns: - New AWS Bedrock LLM context instance. - """ - self = cls( - messages=openai_context.messages, - tools=openai_context.tools, - tool_choice=openai_context.tool_choice, - ) - self.set_llm_adapter(openai_context.get_llm_adapter()) - self._restructure_from_openai_messages() - return self - - @classmethod - def from_messages(cls, messages: List[dict]) -> "AWSBedrockLLMContext": - """Create AWS Bedrock context from message list. - - Args: - messages: List of messages in OpenAI format. - - Returns: - New AWS Bedrock LLM context instance. - """ - self = cls(messages=messages) - self._restructure_from_openai_messages() - return self - - def set_messages(self, messages: List): - """Set the messages list and restructure for Bedrock format. - - Args: - messages: List of messages to set. - """ - self._messages[:] = messages - self._restructure_from_openai_messages() - - def to_standard_messages(self, obj): - """Convert AWS Bedrock message format to standard structured format. - - Handles text content and function calls for both user and assistant messages. - - Args: - obj: Message in AWS Bedrock format. - - Returns: - List of messages in standard format. - - Examples: - AWS Bedrock format input:: - - { - "role": "assistant", - "content": [ - {"text": "Hello"}, - {"toolUse": {"toolUseId": "123", "name": "search", "input": {"q": "test"}}} - ] - } - - Standard format output:: - - [ - {"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}, - { - "role": "assistant", - "tool_calls": [ - { - "type": "function", - "id": "123", - "function": {"name": "search", "arguments": '{"q": "test"}'} - } - ] - } - ] - """ - role = obj.get("role") - content = obj.get("content") - - if role == "assistant": - if isinstance(content, str): - return [{"role": role, "content": [{"type": "text", "text": content}]}] - elif isinstance(content, list): - text_items = [] - tool_items = [] - for item in content: - if "text" in item: - text_items.append({"type": "text", "text": item["text"]}) - elif "toolUse" in item: - tool_use = item["toolUse"] - tool_items.append( - { - "type": "function", - "id": tool_use["toolUseId"], - "function": { - "name": tool_use["name"], - "arguments": json.dumps(tool_use["input"]), - }, - } - ) - messages = [] - if text_items: - messages.append({"role": role, "content": text_items}) - if tool_items: - messages.append({"role": role, "tool_calls": tool_items}) - return messages - elif role == "user": - if isinstance(content, str): - return [{"role": role, "content": [{"type": "text", "text": content}]}] - elif isinstance(content, list): - text_items = [] - tool_items = [] - for item in content: - if "text" in item: - text_items.append({"type": "text", "text": item["text"]}) - elif "toolResult" in item: - tool_result = item["toolResult"] - # Extract content from toolResult - result_content = "" - if isinstance(tool_result["content"], list): - for content_item in tool_result["content"]: - if "text" in content_item: - result_content = content_item["text"] - elif "json" in content_item: - result_content = json.dumps(content_item["json"]) - else: - result_content = tool_result["content"] - - tool_items.append( - { - "role": "tool", - "tool_call_id": tool_result["toolUseId"], - "content": result_content, - } - ) - messages = [] - if text_items: - messages.append({"role": role, "content": text_items}) - messages.extend(tool_items) - return messages - - def from_standard_message(self, message): - """Convert standard format message to AWS Bedrock format. - - Handles conversion of text content, tool calls, and tool results. - Empty text content is converted to "(empty)". - - Args: - message: Message in standard format. - - Returns: - Message in AWS Bedrock format. - - Examples: - Standard format input:: - - { - "role": "assistant", - "tool_calls": [ - { - "id": "123", - "function": {"name": "search", "arguments": '{"q": "test"}'} - } - ] - } - - AWS Bedrock format output:: - - { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "123", - "name": "search", - "input": {"q": "test"} - } - } - ] - } - """ - if message["role"] == "tool": - # Try to parse the content as JSON if it looks like JSON - try: - if message["content"].strip().startswith("{") and message[ - "content" - ].strip().endswith("}"): - content_json = json.loads(message["content"]) - tool_result_content = [{"json": content_json}] - else: - tool_result_content = [{"text": message["content"]}] - except (json.JSONDecodeError, ValueError, AttributeError): - tool_result_content = [{"text": message["content"]}] - - return { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": message["tool_call_id"], - "content": tool_result_content, - }, - }, - ], - } - - if message.get("tool_calls"): - tc = message["tool_calls"] - ret = {"role": "assistant", "content": []} - for tool_call in tc: - function = tool_call["function"] - arguments = json.loads(function["arguments"]) - new_tool_use = { - "toolUse": { - "toolUseId": tool_call["id"], - "name": function["name"], - "input": arguments, - } - } - ret["content"].append(new_tool_use) - return ret - - # Handle text content - content = message.get("content") - if isinstance(content, str): - if content == "": - return {"role": message["role"], "content": [{"text": "(empty)"}]} - else: - return {"role": message["role"], "content": [{"text": content}]} - elif isinstance(content, list): - new_content = [] - for item in content: - # fix empty text - if item.get("type", "") == "text": - text_content = item["text"] if item["text"] != "" else "(empty)" - new_content.append({"text": text_content}) - # handle image_url -> image conversion - if item["type"] == "image_url": - new_item = { - "image": { - "format": "jpeg", - "source": { - "bytes": base64.b64decode(item["image_url"]["url"].split(",")[1]) - }, - } - } - new_content.append(new_item) - # In the case where there's a single image in the list (like what - # would result from a UserImageRawFrame), ensure that the image - # comes before text - image_indices = [i for i, item in enumerate(new_content) if "image" in item] - text_indices = [i for i, item in enumerate(new_content) if "text" in item] - if len(image_indices) == 1 and text_indices: - img_idx = image_indices[0] - first_txt_idx = text_indices[0] - if img_idx > first_txt_idx: - # Move image before the first text - image_item = new_content.pop(img_idx) - new_content.insert(first_txt_idx, image_item) - return {"role": message["role"], "content": new_content} - - return message - - def add_image_frame_message( - self, *, format: str, size: tuple[int, int], image: bytes, text: str = None - ): - """Add an image message to the context. - - Args: - format: The image format (e.g., 'RGB', 'RGBA'). - size: The image dimensions as (width, height). - image: The raw image data as bytes. - text: Optional text to accompany the image. - """ - buffer = io.BytesIO() - Image.frombytes(format, size, image).save(buffer, format="JPEG") - encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - - # Image should be the first content block in the message - content = [{"type": "image", "format": "jpeg", "source": {"bytes": encoded_image}}] - if text: - content.append({"text": text}) - self.add_message({"role": "user", "content": content}) - - def add_message(self, message): - """Add a message to the context, merging with previous message if same role. - - AWS Bedrock requires alternating roles, so consecutive messages from the - same role are merged together. - - Args: - message: The message to add to the context. - """ - try: - if self.messages: - # AWS Bedrock requires that roles alternate. If this message's - # role is the same as the last message, we should add this - # message's content to the last message. - if self.messages[-1]["role"] == message["role"]: - # if the last message has just a content string, convert it to a list - # in the proper format - if isinstance(self.messages[-1]["content"], str): - self.messages[-1]["content"] = [{"text": self.messages[-1]["content"]}] - # if this message has just a content string, convert it to a list - # in the proper format - if isinstance(message["content"], str): - message["content"] = [{"text": message["content"]}] - # append the content of this message to the last message - self.messages[-1]["content"].extend(message["content"]) - else: - self.messages.append(message) - else: - self.messages.append(message) - except Exception as e: - logger.error(f"Error adding message: {e}") - - def _restructure_from_bedrock_messages(self): - """Restructure messages in AWS Bedrock format. - - Handles system messages, merging consecutive messages with the same role, - and ensuring proper content formatting. - """ - # Handle system message if present at the beginning - if self.messages and self.messages[0]["role"] == "system": - if len(self.messages) == 1: - self.messages[0]["role"] = "user" - else: - system_content = self.messages.pop(0)["content"] - if isinstance(system_content, str): - system_content = [{"text": system_content}] - - if self.system: - if isinstance(self.system, str): - self.system = [{"text": self.system}] - self.system.extend(system_content) - else: - self.system = system_content - - # Ensure content is properly formatted - for msg in self.messages: - if isinstance(msg["content"], str): - msg["content"] = [{"text": msg["content"]}] - elif not msg["content"]: - msg["content"] = [{"text": "(empty)"}] - elif isinstance(msg["content"], list): - for idx, item in enumerate(msg["content"]): - if isinstance(item, dict) and "text" in item and item["text"] == "": - item["text"] = "(empty)" - elif isinstance(item, str) and item == "": - msg["content"][idx] = {"text": "(empty)"} - - # Merge consecutive messages with the same role - merged_messages = [] - for msg in self.messages: - if merged_messages and merged_messages[-1]["role"] == msg["role"]: - merged_messages[-1]["content"].extend(msg["content"]) - else: - merged_messages.append(msg) - - self.messages.clear() - self.messages.extend(merged_messages) - - def _restructure_from_openai_messages(self): - # first, map across self._messages calling self.from_standard_message(m) to modify messages in place - try: - self._messages[:] = [self.from_standard_message(m) for m in self._messages] - except Exception as e: - logger.error(f"Error mapping messages: {e}") - - # See if we should pull the system message out of our context.messages list. (For - # compatibility with Open AI messages format.) - if self.messages and self.messages[0]["role"] == "system": - self.system = self.messages[0]["content"] - self.messages.pop(0) - - # Merge consecutive messages with the same role. - i = 0 - while i < len(self.messages) - 1: - current_message = self.messages[i] - next_message = self.messages[i + 1] - if current_message["role"] == next_message["role"]: - # Convert content to list of dictionaries if it's a string - if isinstance(current_message["content"], str): - current_message["content"] = [ - {"type": "text", "text": current_message["content"]} - ] - if isinstance(next_message["content"], str): - next_message["content"] = [{"type": "text", "text": next_message["content"]}] - # Concatenate the content - current_message["content"].extend(next_message["content"]) - # Remove the next message from the list - self.messages.pop(i + 1) - else: - i += 1 - - # Avoid empty content in messages - for message in self.messages: - if isinstance(message["content"], str) and message["content"] == "": - message["content"] = "(empty)" - elif isinstance(message["content"], list) and len(message["content"]) == 0: - message["content"] = [{"type": "text", "text": "(empty)"}] - - def get_messages_for_persistent_storage(self): - """Get messages formatted for persistent storage. - - Returns: - List of messages including system message if present. - """ - messages = super().get_messages_for_persistent_storage() - if self.system: - messages.insert(0, {"role": "system", "content": self.system}) - return messages - - def get_messages_for_logging(self) -> List[Dict[str, Any]]: - """Get messages formatted for logging with sensitive data redacted. - - Returns: - List of messages in a format ready for logging. - """ - msgs = [] - for message in self.messages: - msg = copy.deepcopy(message) - if "content" in msg: - if isinstance(msg["content"], list): - for item in msg["content"]: - if item.get("image"): - item["image"]["source"]["bytes"] = "..." - msgs.append(msg) - return msgs - - -class AWSBedrockUserContextAggregator(LLMUserContextAggregator): - """User context aggregator for AWS Bedrock LLM service. - - Handles aggregation of user messages and frames for AWS Bedrock format. - Inherits all functionality from the base LLM user context aggregator. - - .. deprecated:: 0.0.99 - `AWSBedrockUserContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Args: - context: The LLM context to aggregate messages into. - params: Configuration parameters for the aggregator. - """ - - # Super handles deprecation warning - pass - - -class AWSBedrockAssistantContextAggregator(LLMAssistantContextAggregator): - """Assistant context aggregator for AWS Bedrock LLM service. - - Handles aggregation of assistant responses and function calls for AWS Bedrock - format, including tool use and tool result handling. - - .. deprecated:: 0.0.99 - `AWSBedrockAssistantContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Args: - context: The LLM context to aggregate messages into. - params: Configuration parameters for the aggregator. - """ - - # Super handles deprecation warning - - async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): - """Handle function call in progress frame. - - Args: - frame: The function call in progress frame to handle. - """ - # Format tool use according to AWS Bedrock API - self._context.add_message( - { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": frame.tool_call_id, - "name": frame.function_name, - "input": frame.arguments if frame.arguments else {}, - } - } - ], - } - ) - self._context.add_message( - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": frame.tool_call_id, - "content": [{"text": "IN_PROGRESS"}], - } - } - ], - } - ) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - """Handle function call result frame. - - Args: - frame: The function call result frame to handle. - """ - if frame.result: - result = json.dumps(frame.result, ensure_ascii=False) - await self._update_function_call_result(frame.function_name, frame.tool_call_id, result) - else: - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "COMPLETED" - ) - - async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): - """Handle function call cancel frame. - - Args: - frame: The function call cancel frame to handle. - """ - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "CANCELLED" - ) - - async def _update_function_call_result( - self, function_name: str, tool_call_id: str, result: Any - ): - for message in self._context.messages: - if message["role"] == "user": - for content in message["content"]: - if ( - isinstance(content, dict) - and content.get("toolResult") - and content["toolResult"]["toolUseId"] == tool_call_id - ): - content["toolResult"]["content"] = [{"text": result}] - - async def handle_user_image_frame(self, frame: UserImageRawFrame): - """Handle user image frame. - - Args: - frame: The user image frame to handle. - """ - await self._update_function_call_result( - frame.request.function_name, frame.request.tool_call_id, "COMPLETED" - ) - self._context.add_image_frame_message( - format=frame.format, - size=frame.size, - image=frame.image, - text=frame.request.context, - ) - - class AWSBedrockLLMService(LLMService): """AWS Bedrock Large Language Model service implementation. @@ -924,7 +262,7 @@ class AWSBedrockLLMService(LLMService): async def run_inference( self, - context: LLMContext | OpenAILLMContext, + context: LLMContext, max_tokens: Optional[int] = None, system_instruction: Optional[str] = None, ) -> Optional[str]: @@ -943,17 +281,12 @@ class AWSBedrockLLMService(LLMService): messages = [] system = [] effective_instruction = system_instruction or self._settings.system_instruction - if isinstance(context, LLMContext): - adapter: AWSBedrockLLMAdapter = self.get_llm_adapter() - params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params( - context, system_instruction=effective_instruction - ) - messages = params["messages"] - system = params["system"] # [{"text": "system message"}] or None - else: - context = AWSBedrockLLMContext.upgrade_to_bedrock(context) - messages = context.messages - system = getattr(context, "system", None) # [{"text": "system message"}] + adapter: AWSBedrockLLMAdapter = self.get_llm_adapter() + params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params( + context, system_instruction=effective_instruction + ) + messages = params["messages"] + system = params["system"] # [{"text": "system message"}] or None # Prepare request parameters using the same method as streaming inference_config = self._build_inference_config() @@ -1021,44 +354,6 @@ class AWSBedrockLLMService(LLMService): response = await client.converse_stream(**request_params) return response - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> AWSBedrockContextAggregatorPair: - """Create AWS Bedrock-specific context aggregators. - - Creates a pair of context aggregators optimized for AWS Bedrocks's message - format, including support for function calls, tool usage, and image handling. - - Args: - context: The LLM context to create aggregators for. - user_params: Parameters for user message aggregation. - assistant_params: Parameters for assistant message aggregation. - - Returns: - AWSBedrockContextAggregatorPair: A pair of context aggregators, one for - the user and one for the assistant, encapsulated in an - AWSBedrockContextAggregatorPair. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - context.set_llm_adapter(self.get_llm_adapter()) - - if isinstance(context, OpenAILLMContext): - context = AWSBedrockLLMContext.from_openai_context(context) - - # Aggregators handle deprecation warnings - user = AWSBedrockUserContextAggregator(context, params=user_params) - assistant = AWSBedrockAssistantContextAggregator(context, params=assistant_params) - - return AWSBedrockContextAggregatorPair(_user=user, _assistant=assistant) - def _create_no_op_tool(self): """Create a no-operation tool for AWS Bedrock when tool content exists but no tools are defined. @@ -1074,27 +369,15 @@ class AWSBedrockLLMService(LLMService): } } - def _get_llm_invocation_params( - self, context: OpenAILLMContext | LLMContext - ) -> AWSBedrockLLMInvocationParams: - # Universal LLMContext - if isinstance(context, LLMContext): - adapter: AWSBedrockLLMAdapter = self.get_llm_adapter() - params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params( - context, system_instruction=self._settings.system_instruction - ) - return params - - # AWS Bedrock-specific context - return AWSBedrockLLMInvocationParams( - system=getattr(context, "system", None), - messages=context.messages, - tools=context.tools or [], - tool_choice=context.tool_choice, + def _get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams: + adapter: AWSBedrockLLMAdapter = self.get_llm_adapter() + params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params( + context, system_instruction=self._settings.system_instruction ) + return params @traced_llm - async def _process_context(self, context: AWSBedrockLLMContext | LLMContext): + async def _process_context(self, context: LLMContext): # Usage tracking prompt_tokens = 0 completion_tokens = 0 @@ -1173,15 +456,10 @@ class AWSBedrockLLMService(LLMService): request_params["performanceConfig"] = {"latency": self._settings.latency} # Log request params with messages redacted for logging - if isinstance(context, LLMContext): - adapter = self.get_llm_adapter() - context_type_for_logging = "universal" - messages_for_logging = adapter.get_messages_for_logging(context) - else: - context_type_for_logging = "LLM-specific" - messages_for_logging = context.get_messages_for_logging() + adapter = self.get_llm_adapter() + messages_for_logging = adapter.get_messages_for_logging(context) logger.debug( - f"{self}: Generating chat from {context_type_for_logging} context [{system}] | {messages_for_logging}" + f"{self}: Generating chat from context [{system}] | {messages_for_logging}" ) async with self._aws_session.client( @@ -1286,21 +564,11 @@ class AWSBedrockLLMService(LLMService): """ await super().process_frame(frame, direction) - context = None - if isinstance(frame, OpenAILLMContextFrame): - context = AWSBedrockLLMContext.upgrade_to_bedrock(frame.context) if isinstance(frame, LLMContextFrame): - context = frame.context - elif isinstance(frame, LLMMessagesFrame): - # NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal - # LLMContext with it - context = AWSBedrockLLMContext.from_messages(frame.messages) + await self._process_context(frame.context) else: await self.push_frame(frame, direction) - if context: - await self._process_context(context) - def _estimate_tokens(self, text: str) -> int: return int(len(re.split(r"[^\w]+", text)) * 1.3) diff --git a/src/pipecat/services/aws/nova_sonic/context.py b/src/pipecat/services/aws/nova_sonic/context.py deleted file mode 100644 index a87564c5b..000000000 --- a/src/pipecat/services/aws/nova_sonic/context.py +++ /dev/null @@ -1,460 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""Context management for AWS Nova Sonic LLM service. - -This module provides specialized context aggregators and message handling for AWS Nova Sonic, -including conversation history management and role-specific message processing. - -.. deprecated:: 0.0.91 - AWS Nova Sonic no longer uses types from this module under the hood. - It now uses ``LLMContext`` and ``LLMContextAggregatorPair``. - Using the new patterns should allow you to not need types from this module. - - BEFORE:: - - # Setup - context = OpenAILLMContext(messages, tools) - context_aggregator = llm.create_context_aggregator(context) - - # Context frame type - frame: OpenAILLMContextFrame - - # Context type - context: AWSNovaSonicLLMContext - # or - context: OpenAILLMContext - - AFTER:: - - # Setup - context = LLMContext(messages, tools) - context_aggregator = LLMContextAggregatorPair(context) - - # Context frame type - frame: LLMContextFrame - - # Context type - context: LLMContext -""" - -import warnings - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "Types in pipecat.services.aws.nova_sonic.context (or " - "pipecat.services.aws_nova_sonic.context) are deprecated. \n" - "AWS Nova Sonic no longer uses types from this module under the hood. \n" - "It now uses `LLMContext` and `LLMContextAggregatorPair`. \n" - "Using the new patterns should allow you to not need types from this module.\n\n" - "BEFORE:\n" - "```\n" - "# Setup\n" - "context = OpenAILLMContext(messages, tools)\n" - "context_aggregator = llm.create_context_aggregator(context)\n\n" - "# Context frame type\n" - "frame: OpenAILLMContextFrame\n\n" - "# Context type\n" - "context: AWSNovaSonicLLMContext\n" - "# or\n" - "context: OpenAILLMContext\n\n" - "```\n\n" - "AFTER:\n" - "```\n" - "# Setup\n" - "context = LLMContext(messages, tools)\n" - "context_aggregator = LLMContextAggregatorPair(context)\n\n" - "# Context frame type\n" - "frame: LLMContextFrame\n\n" - "# Context type\n" - "context: LLMContext\n\n" - "```", - DeprecationWarning, - stacklevel=2, - ) - -import copy -from dataclasses import dataclass, field -from enum import Enum - -from loguru import logger - -from pipecat.frames.frames import ( - BotStoppedSpeakingFrame, - DataFrame, - Frame, - FunctionCallResultFrame, - InterruptionFrame, - LLMFullResponseEndFrame, - LLMFullResponseStartFrame, - LLMMessagesAppendFrame, - LLMMessagesUpdateFrame, - LLMSetToolChoiceFrame, - LLMSetToolsFrame, - TextFrame, - UserImageRawFrame, -) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.aws.nova_sonic.frames import AWSNovaSonicFunctionCallResultFrame -from pipecat.services.openai.llm import ( - OpenAIAssistantContextAggregator, - OpenAIUserContextAggregator, -) - - -class Role(Enum): - """Roles supported in AWS Nova Sonic conversations. - - Parameters: - SYSTEM: System-level messages (not used in conversation history). - USER: Messages sent by the user. - ASSISTANT: Messages sent by the assistant. - TOOL: Messages sent by tools (not used in conversation history). - """ - - SYSTEM = "SYSTEM" - USER = "USER" - ASSISTANT = "ASSISTANT" - TOOL = "TOOL" - - -@dataclass -class AWSNovaSonicConversationHistoryMessage: - """A single message in AWS Nova Sonic conversation history. - - Parameters: - role: The role of the message sender (USER or ASSISTANT only). - text: The text content of the message. - """ - - role: Role # only USER and ASSISTANT - text: str - - -@dataclass -class AWSNovaSonicConversationHistory: - """Complete conversation history for AWS Nova Sonic initialization. - - Parameters: - system_instruction: System-level instruction for the conversation. - messages: List of conversation messages between user and assistant. - """ - - system_instruction: str = None - messages: list[AWSNovaSonicConversationHistoryMessage] = field(default_factory=list) - - -class AWSNovaSonicLLMContext(OpenAILLMContext): - """Specialized LLM context for AWS Nova Sonic service. - - Extends OpenAI context with Nova Sonic-specific message handling, - conversation history management, and text buffering capabilities. - - .. deprecated:: 0.0.99 - `AWSNovaSonicLLMContext` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__(self, messages=None, tools=None, **kwargs): - """Initialize AWS Nova Sonic LLM context. - - Args: - messages: Initial messages for the context. - tools: Available tools for the context. - **kwargs: Additional arguments passed to parent class. - """ - # Super handles deprecation warning - super().__init__(messages=messages, tools=tools, **kwargs) - self.__setup_local() - - def __setup_local(self, system_instruction: str = ""): - self._assistant_text = "" - self._user_text = "" - self._system_instruction = system_instruction - - @staticmethod - def upgrade_to_nova_sonic( - obj: OpenAILLMContext, system_instruction: str - ) -> "AWSNovaSonicLLMContext": - """Upgrade an OpenAI context to AWS Nova Sonic context. - - Args: - obj: The OpenAI context to upgrade. - system_instruction: System instruction for the context. - - Returns: - The upgraded AWS Nova Sonic context. - """ - if isinstance(obj, OpenAILLMContext) and not isinstance(obj, AWSNovaSonicLLMContext): - obj.__class__ = AWSNovaSonicLLMContext - obj.__setup_local(system_instruction) - return obj - - # NOTE: this method has the side-effect of updating _system_instruction from messages - def get_messages_for_initializing_history(self) -> AWSNovaSonicConversationHistory: - """Get conversation history for initializing AWS Nova Sonic session. - - Processes stored messages and extracts system instruction and conversation - history in the format expected by AWS Nova Sonic. - - Returns: - Formatted conversation history with system instruction and messages. - """ - history = AWSNovaSonicConversationHistory(system_instruction=self._system_instruction) - - # Bail if there are no messages - if not self.messages: - return history - - messages = copy.deepcopy(self.messages) - - # If we have a "system" message as our first message, let's pull that out into "instruction" - if messages[0].get("role") == "system": - system = messages.pop(0) - content = system.get("content") - if isinstance(content, str): - history.system_instruction = content - elif isinstance(content, list): - history.system_instruction = content[0].get("text") - if history.system_instruction: - self._system_instruction = history.system_instruction - - # Process remaining messages to fill out conversation history. - # Nova Sonic supports "user" and "assistant" messages in history. - for message in messages: - history_message = self.from_standard_message(message) - if history_message: - history.messages.append(history_message) - - return history - - def get_messages_for_persistent_storage(self): - """Get messages formatted for persistent storage. - - Returns: - List of messages including system instruction if present. - """ - messages = super().get_messages_for_persistent_storage() - # If we have a system instruction and messages doesn't already contain it, add it - if self._system_instruction and not (messages and messages[0].get("role") == "system"): - messages.insert(0, {"role": "system", "content": self._system_instruction}) - return messages - - def from_standard_message(self, message) -> AWSNovaSonicConversationHistoryMessage: - """Convert standard message format to Nova Sonic format. - - Args: - message: Standard message dictionary to convert. - - Returns: - Nova Sonic conversation history message, or None if not convertible. - """ - role = message.get("role") - if message.get("role") == "user" or message.get("role") == "assistant": - content = message.get("content") - if isinstance(message.get("content"), list): - content = "" - for c in message.get("content"): - if c.get("type") == "text": - content += " " + c.get("text") - else: - logger.error( - f"Unhandled content type in context message: {c.get('type')} - {message}" - ) - # There won't be content if this is an assistant tool call entry. - # We're ignoring those since they can't be loaded into AWS Nova Sonic conversation - # history - if content: - return AWSNovaSonicConversationHistoryMessage(role=Role[role.upper()], text=content) - # NOTE: we're ignoring messages with role "tool" since they can't be loaded into AWS Nova - # Sonic conversation history - - def buffer_user_text(self, text): - """Buffer user text for later flushing to context. - - Args: - text: User text to buffer. - """ - self._user_text += f" {text}" if self._user_text else text - # logger.debug(f"User text buffered: {self._user_text}") - - def flush_aggregated_user_text(self) -> str: - """Flush buffered user text to context as a complete message. - - Returns: - The flushed user text, or empty string if no text was buffered. - """ - if not self._user_text: - return "" - user_text = self._user_text - message = { - "role": "user", - "content": [{"type": "text", "text": user_text}], - } - self._user_text = "" - self.add_message(message) - # logger.debug(f"Context updated (user): {self.get_messages_for_logging()}") - return user_text - - def buffer_assistant_text(self, text): - """Buffer assistant text for later flushing to context. - - Args: - text: Assistant text to buffer. - """ - self._assistant_text += text - # logger.debug(f"Assistant text buffered: {self._assistant_text}") - - def flush_aggregated_assistant_text(self): - """Flush buffered assistant text to context as a complete message.""" - if not self._assistant_text: - return - message = { - "role": "assistant", - "content": [{"type": "text", "text": self._assistant_text}], - } - self._assistant_text = "" - self.add_message(message) - # logger.debug(f"Context updated (assistant): {self.get_messages_for_logging()}") - - -@dataclass -class AWSNovaSonicMessagesUpdateFrame(DataFrame): - """Frame containing updated AWS Nova Sonic context. - - Parameters: - context: The updated AWS Nova Sonic LLM context. - """ - - context: AWSNovaSonicLLMContext - - -class AWSNovaSonicUserContextAggregator(OpenAIUserContextAggregator): - """Context aggregator for user messages in AWS Nova Sonic conversations. - - Extends the OpenAI user context aggregator to emit Nova Sonic-specific - context update frames. - - .. deprecated:: 0.0.99 - `AWSNovaSonicUserContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - # Super handles deprecation warning - - async def process_frame( - self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM - ): - """Process frames and emit Nova Sonic-specific context updates. - - Args: - frame: The frame to process. - direction: The direction the frame is traveling. - """ - await super().process_frame(frame, direction) - - # Parent does not push LLMMessagesUpdateFrame - if isinstance(frame, LLMMessagesUpdateFrame): - await self.push_frame(AWSNovaSonicMessagesUpdateFrame(context=self._context)) - - -class AWSNovaSonicAssistantContextAggregator(OpenAIAssistantContextAggregator): - """Context aggregator for assistant messages in AWS Nova Sonic conversations. - - Provides specialized handling for assistant responses and function calls - in AWS Nova Sonic context, with custom frame processing logic. - - .. deprecated:: 0.0.99 - `AWSNovaSonicAssistantContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - # Super handles deprecation warning - - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process frames with Nova Sonic-specific logic. - - Args: - frame: The frame to process. - direction: The direction the frame is traveling. - """ - # HACK: For now, disable the context aggregator by making it just pass through all frames - # that the parent handles (except the function call stuff, which we still need). - # For an explanation of this hack, see - # AWSNovaSonicLLMService._report_assistant_response_text_added. - if isinstance( - frame, - ( - InterruptionFrame, - LLMFullResponseStartFrame, - LLMFullResponseEndFrame, - TextFrame, - LLMMessagesAppendFrame, - LLMMessagesUpdateFrame, - LLMSetToolsFrame, - LLMSetToolChoiceFrame, - UserImageRawFrame, - BotStoppedSpeakingFrame, - ), - ): - await self.push_frame(frame, direction) - else: - await super().process_frame(frame, direction) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - """Handle function call results for AWS Nova Sonic. - - Args: - frame: The function call result frame to handle. - """ - await super().handle_function_call_result(frame) - - # The standard function callback code path pushes the FunctionCallResultFrame from the LLM - # itself, so we didn't have a chance to add the result to the AWS Nova Sonic server-side - # context. Let's push a special frame to do that. - await self.push_frame( - AWSNovaSonicFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM - ) - - -@dataclass -class AWSNovaSonicContextAggregatorPair: - """Pair of user and assistant context aggregators for AWS Nova Sonic. - - .. deprecated:: 0.0.99 - `AWSNovaSonicContextAggregatorPair` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Parameters: - _user: The user context aggregator. - _assistant: The assistant context aggregator. - """ - - # Aggregators handle deprecation warnings - _user: AWSNovaSonicUserContextAggregator - _assistant: AWSNovaSonicAssistantContextAggregator - - def user(self) -> AWSNovaSonicUserContextAggregator: - """Get the user context aggregator. - - Returns: - The user context aggregator instance. - """ - return self._user - - def assistant(self) -> AWSNovaSonicAssistantContextAggregator: - """Get the assistant context aggregator. - - Returns: - The assistant context aggregator instance. - """ - return self._assistant diff --git a/src/pipecat/services/aws/nova_sonic/frames.py b/src/pipecat/services/aws/nova_sonic/frames.py deleted file mode 100644 index f392eba7f..000000000 --- a/src/pipecat/services/aws/nova_sonic/frames.py +++ /dev/null @@ -1,25 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""Custom frames for AWS Nova Sonic LLM service.""" - -from dataclasses import dataclass - -from pipecat.frames.frames import DataFrame, FunctionCallResultFrame - - -@dataclass -class AWSNovaSonicFunctionCallResultFrame(DataFrame): - """Frame containing function call result for AWS Nova Sonic processing. - - This frame wraps a standard function call result frame to enable - AWS Nova Sonic-specific handling and context updates. - - Parameters: - result_frame: The underlying function call result frame. - """ - - result_frame: FunctionCallResultFrame diff --git a/src/pipecat/services/aws/nova_sonic/llm.py b/src/pipecat/services/aws/nova_sonic/llm.py index 3541946c7..9aa36c5db 100644 --- a/src/pipecat/services/aws/nova_sonic/llm.py +++ b/src/pipecat/services/aws/nova_sonic/llm.py @@ -49,15 +49,7 @@ from pipecat.frames.frames import ( UserStoppedSpeakingFrame, ) from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, -) from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import LLMService from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven @@ -531,13 +523,8 @@ class AWSNovaSonicLLMService(LLMService): """ await super().process_frame(frame, direction) - if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): - context = ( - frame.context - if isinstance(frame, LLMContextFrame) - else LLMContext.from_openai_context(frame.context) - ) - await self._handle_context(context) + if isinstance(frame, LLMContextFrame): + await self._handle_context(frame.context) elif isinstance(frame, InputAudioRawFrame): await self._handle_input_audio_frame(frame) elif isinstance(frame, InterruptionFrame): @@ -1353,44 +1340,6 @@ class AWSNovaSonicLLMService(LLMService): # We're no longer waiting for a trigger transcription self._waiting_for_trigger_transcription = False - # - # context - # - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> LLMContextAggregatorPair: - """Create context aggregator pair for managing conversation context. - - NOTE: this method exists only for backward compatibility. New code - should instead do:: - - context = LLMContext(...) - context_aggregator = LLMContextAggregatorPair(context) - - Args: - context: The OpenAI LLM context. - user_params: Parameters for the user context aggregator. - assistant_params: Parameters for the assistant context aggregator. - - Returns: - A pair of user and assistant context aggregators. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - # from_openai_context handles deprecation warning - context = LLMContext.from_openai_context(context) - return LLMContextAggregatorPair( - context, user_params=user_params, assistant_params=assistant_params - ) - # # assistant response trigger # HACK: only needed for the older Nova Sonic (as opposed to Nova 2 Sonic) model diff --git a/src/pipecat/services/google/gemini_live/llm.py b/src/pipecat/services/google/gemini_live/llm.py index 846a614a2..1bd7174b2 100644 --- a/src/pipecat/services/google/gemini_live/llm.py +++ b/src/pipecat/services/google/gemini_live/llm.py @@ -59,23 +59,11 @@ from pipecat.frames.frames import ( ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, -) from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.google.frames import LLMSearchOrigin, LLMSearchResponseFrame, LLMSearchResult from pipecat.services.google.utils import update_google_client_http_options from pipecat.services.llm_service import FunctionCallFromLLM, LLMService -from pipecat.services.openai.llm import ( - OpenAIAssistantContextAggregator, - OpenAIUserContextAggregator, -) from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.string import match_endofsentence @@ -224,274 +212,6 @@ def language_to_gemini_language(language: Language) -> Optional[str]: return resolve_language(language, LANGUAGE_MAP, use_base_code=False) -class GeminiLiveContext(OpenAILLMContext): - """Extended OpenAI context for Gemini Live API. - - Provides Gemini-specific context management including system instruction - extraction and message format conversion for the Live API. - - .. deprecated:: 0.0.92 - Gemini Live no longer uses `GeminiLiveContext` under the hood. - It now uses `LLMContext`. - """ - - @staticmethod - def upgrade(obj: OpenAILLMContext) -> "GeminiLiveContext": - """Upgrade an OpenAI context to Gemini context. - - Args: - obj: The OpenAI context to upgrade. - - Returns: - The upgraded Gemini context instance. - """ - # This warning is here rather than `__init__` since `upgrade()` was the - # "main" way that GeminiLiveContext instances were created. - # Almost no users should be seeing this message anyway, as - # GeminiLiveContext instances were typically created under the hood: - # the user would pass an OpenAILLMContext instance, which would be - # upgraded without them necessarily knowing. - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "GeminiLiveContext is deprecated. " - "Gemini Live no longer uses GeminiLiveContext under the hood. " - "It now uses LLMContext.", - DeprecationWarning, - stacklevel=2, - ) - - if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GeminiLiveContext): - logger.debug(f"Upgrading to Gemini Live Context: {obj}") - obj.__class__ = GeminiLiveContext - obj._restructure_from_openai_messages() - return obj - - def _restructure_from_openai_messages(self): - pass - - def extract_system_instructions(self): - """Extract system instructions from context messages. - - Returns: - Combined system instruction text from all system messages. - """ - system_instruction = "" - for item in self.messages: - if item.get("role") == "system": - content = item.get("content", "") - if content: - if system_instruction and not system_instruction.endswith("\n"): - system_instruction += "\n" - system_instruction += str(content) - return system_instruction - - def add_file_reference(self, file_uri: str, mime_type: str, text: Optional[str] = None): - """Add a file reference to the context. - - This adds a user message with a file reference that will be sent during context initialization. - - Args: - file_uri: URI of the uploaded file - mime_type: MIME type of the file - text: Optional text prompt to accompany the file - """ - # Create parts list with file reference - parts = [] - if text: - parts.append({"type": "text", "text": text}) - - # Add file reference part - parts.append( - {"type": "file_data", "file_data": {"mime_type": mime_type, "file_uri": file_uri}} - ) - - # Add to messages - message = {"role": "user", "content": parts} - self.messages.append(message) - logger.info(f"Added file reference to context: {file_uri}") - - def get_messages_for_initializing_history(self) -> List[Content]: - """Get messages formatted for Gemini history initialization. - - Returns: - List of messages in Gemini format for conversation history. - """ - messages: List[Content] = [] - for item in self.messages: - role = item.get("role") - - if role == "system": - continue - - elif role == "assistant": - role = "model" - - content = item.get("content") - parts: List[Part] = [] - if isinstance(content, str): - parts = [Part(text=content)] - elif isinstance(content, list): - for part in content: - if part.get("type") == "text": - parts.append(Part(text=part.get("text"))) - elif part.get("type") == "file_data": - file_data = part.get("file_data", {}) - parts.append( - Part( - file_data=FileData( - mime_type=file_data.get("mime_type"), - file_uri=file_data.get("file_uri"), - ) - ) - ) - else: - logger.warning(f"Unsupported content type: {str(part)[:80]}") - else: - logger.warning(f"Unsupported content type: {str(content)[:80]}") - messages.append(Content(role=role, parts=parts)) - return messages - - -class GeminiLiveUserContextAggregator(OpenAIUserContextAggregator): - """User context aggregator for Gemini Live. - - Extends OpenAI user aggregator to handle Gemini-specific message passing - while maintaining compatibility with the standard aggregation pipeline. - - .. deprecated:: 0.0.92 - Gemini Live no longer expects a `GeminiLiveUserContextAggregator`. - It now expects a `LLMUserAggregator`. - """ - - def __init__(self, *args, **kwargs): - """Initialize Gemini Live user context aggregator.""" - # Almost no users should be seeing this message, as - # `GeminiLiveUserContextAggregator`` instances were typically created - # under the hood, as part of `llm.create_context_aggregator()`. - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "GeminiLiveUserContextAggregator is deprecated. " - "Gemini Live no longer expects a GeminiLiveUserContextAggregator. " - "It now expects a LLMUserAggregator.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - async def process_frame(self, frame, direction): - """Process incoming frames for user context aggregation. - - Args: - frame: The frame to process. - direction: The frame processing direction. - """ - await super().process_frame(frame, direction) - # kind of a hack just to pass the LLMMessagesAppendFrame through, but it's fine for now - if isinstance(frame, LLMMessagesAppendFrame): - await self.push_frame(frame, direction) - - -class GeminiLiveAssistantContextAggregator(OpenAIAssistantContextAggregator): - """Assistant context aggregator for Gemini Live. - - Handles assistant response aggregation while filtering out LLMTextFrames - to prevent duplicate context entries, as Gemini Live pushes both - LLMTextFrames and TTSTextFrames. - - .. deprecated:: 0.0.92 - Gemini Live no longer uses `GeminiLiveAssistantContextAggregator` under the hood. - It now uses `LLMAssistantAggregator`. - """ - - def __init__(self, *args, **kwargs): - """Initialize Gemini Live assistant context aggregator.""" - # Almost no users should be seeing this message, as - # `GeminiLiveAssistantContextAggregator` instances were typically - # created under the hood, as part of `llm.create_context_aggregator()`. - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "GeminiLiveAssistantContextAggregator is deprecated. " - "Gemini Live no longer uses GeminiLiveAssistantContextAggregator under the hood. " - "It now uses LLMAssistantAggregator.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) - - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process incoming frames for assistant context aggregation. - - Args: - frame: The frame to process. - direction: The frame processing direction. - """ - # The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output, - # but the GeminiLiveAssistantContextAggregator pushes LLMTextFrames and TTSTextFrames. We - # need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames - # are process. This ensures that the context gets only one set of messages. - if not isinstance(frame, LLMTextFrame): - await super().process_frame(frame, direction) - - async def handle_user_image_frame(self, frame: UserImageRawFrame): - """Handle user image frames. - - Args: - frame: The user image frame to handle. - """ - # We don't want to store any images in the context. Revisit this later - # when the API evolves. - pass - - -@dataclass -class GeminiLiveContextAggregatorPair: - """Pair of user and assistant context aggregators for Gemini Live. - - .. deprecated:: 0.0.92 - `GeminiLiveContextAggregatorPair` is deprecated. - Use `LLMContextAggregatorPair` instead. - - Parameters: - _user: The user context aggregator instance. - _assistant: The assistant context aggregator instance. - """ - - _user: GeminiLiveUserContextAggregator - _assistant: GeminiLiveAssistantContextAggregator - - def __post_init__(self): - # Almost no users should be seeing this message, as - # `GeminiLiveContextAggregatorPair` instances were typically created - # under the hood, with `llm.create_context_aggregator()`. - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "GeminiLiveContextAggregatorPair is deprecated. " - "Use LLMContextAggregatorPair instead.", - DeprecationWarning, - stacklevel=2, - ) - - def user(self) -> GeminiLiveUserContextAggregator: - """Get the user context aggregator. - - Returns: - The user context aggregator instance. - """ - return self._user - - def assistant(self) -> GeminiLiveAssistantContextAggregator: - """Get the assistant context aggregator. - - Returns: - The assistant context aggregator instance. - """ - return self._assistant - - class GeminiModalities(Enum): """Supported modalities for Gemini Live. @@ -945,23 +665,6 @@ class GeminiLiveLLMService(LLMService): self._settings.language = self._language_code logger.info(f"Set Gemini language to: {self._language_code}") - async def set_context(self, context: OpenAILLMContext): - """Set the context explicitly from outside the pipeline. - - This is useful when initializing a conversation because in server-side VAD mode we might not have a - way to trigger the pipeline. This sends the history to the server. The `inference_on_context_initialization` - flag controls whether to set the turnComplete flag when we do this. Without that flag, the model will - not respond. This is often what we want when setting the context at the beginning of a conversation. - - Args: - context: The OpenAI LLM context to set. - """ - if self._context: - logger.error("Context already set. Can only set up Gemini Live context once.") - return - self._context = GeminiLiveContext.upgrade(context) - await self._create_initial_response() - # # standard AIService frame handling # @@ -1053,13 +756,8 @@ class GeminiLiveLLMService(LLMService): if isinstance(frame, TranscriptionFrame): await self.push_frame(frame, direction) - elif isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): - context = ( - frame.context - if isinstance(frame, LLMContextFrame) - else LLMContext.from_openai_context(frame.context) - ) - await self._handle_context(context) + elif isinstance(frame, LLMContextFrame): + await self._handle_context(frame.context) elif isinstance(frame, InputTextRawFrame): await self._send_user_text(frame.text) await self.push_frame(frame, direction) @@ -2078,40 +1776,3 @@ class GeminiLiveLLMService(LLMService): # cost/stability implications for a service cluster, let's just treat a # send-side error as fatal. await self.push_error(error_msg=f"Send error: {error}") - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> LLMContextAggregatorPair: - """Create an instance of GeminiLiveContextAggregatorPair from an OpenAILLMContext. - - Constructor keyword arguments for both the user and assistant aggregators can be provided. - - NOTE: this method exists only for backward compatibility. New code - should instead do:: - - context = LLMContext(...) - context_aggregator = LLMContextAggregatorPair(context) - - Args: - context: The LLM context to use. - user_params: User aggregator parameters. Defaults to LLMUserAggregatorParams(). - assistant_params: Assistant aggregator parameters. Defaults to LLMAssistantAggregatorParams(). - - Returns: - A pair of user and assistant context aggregators. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - # from_openai_context handles deprecation warning - context = LLMContext.from_openai_context(context) - assistant_params.expect_stripped_words = False - return LLMContextAggregatorPair( - context, user_params=user_params, assistant_params=assistant_params - ) diff --git a/src/pipecat/services/google/llm.py b/src/pipecat/services/google/llm.py index c8d54830e..b336fcef5 100644 --- a/src/pipecat/services/google/llm.py +++ b/src/pipecat/services/google/llm.py @@ -34,29 +34,16 @@ from pipecat.frames.frames import ( LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesAppendFrame, - LLMMessagesFrame, LLMThoughtEndFrame, LLMThoughtStartFrame, LLMThoughtTextFrame, ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, -) -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.google.frames import LLMSearchResponseFrame from pipecat.services.google.utils import update_google_client_http_options from pipecat.services.llm_service import FunctionCallFromLLM, LLMService -from pipecat.services.openai.llm import ( - OpenAIAssistantContextAggregator, - OpenAIUserContextAggregator, -) from pipecat.services.settings import ( NOT_GIVEN, LLMSettings, @@ -90,595 +77,6 @@ except ModuleNotFoundError as e: raise Exception(f"Missing module: {e}") -class GoogleUserContextAggregator(OpenAIUserContextAggregator): - """Google-specific user context aggregator. - - Extends OpenAI user context aggregator to handle Google AI's specific - Content and Part message format for user messages. - - .. deprecated:: 0.0.99 - `OpenAIUserContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - # Super handles deprecation warning - - async def handle_aggregation(self, aggregation: str): - """Add the aggregated user text to the context as a Google Content message. - - Args: - aggregation: The aggregated user text to add as a user message. - """ - self._context.add_message(Content(role="user", parts=[Part(text=aggregation)])) - - -class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): - """Google-specific assistant context aggregator. - - Extends OpenAI assistant context aggregator to handle Google AI's specific - Content and Part message format for assistant responses and function calls. - - .. deprecated:: 0.0.99 - `GoogleAssistantContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - # Super handles deprecation warning - - async def handle_aggregation(self, aggregation: str): - """Handle aggregated assistant text response. - - Args: - aggregation: The aggregated text response from the assistant. - """ - self._context.add_message(Content(role="model", parts=[Part(text=aggregation)])) - - async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): - """Handle function call in progress frame. - - Args: - frame: Frame containing function call details. - """ - self._context.add_message( - Content( - role="model", - parts=[ - Part( - function_call=FunctionCall( - id=frame.tool_call_id, name=frame.function_name, args=frame.arguments - ) - ) - ], - ) - ) - self._context.add_message( - Content( - role="user", - parts=[ - Part( - function_response=FunctionResponse( - id=frame.tool_call_id, - name=frame.function_name, - response={"response": "IN_PROGRESS"}, - ) - ) - ], - ) - ) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - """Handle function call result frame. - - Args: - frame: Frame containing function call result. - """ - if frame.result: - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, frame.result - ) - else: - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "COMPLETED" - ) - - async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): - """Handle function call cancellation frame. - - Args: - frame: Frame containing function call cancellation details. - """ - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "CANCELLED" - ) - - async def _update_function_call_result( - self, function_name: str, tool_call_id: str, result: Any - ): - for message in self._context.messages: - if message.role == "user": - for part in message.parts: - if part.function_response and part.function_response.id == tool_call_id: - part.function_response.response = { - "value": json.dumps(result, ensure_ascii=False) - } - - -@dataclass -class GoogleContextAggregatorPair: - """Pair of Google context aggregators for user and assistant messages. - - .. deprecated:: 0.0.99 - `GoogleContextAggregatorPair` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Parameters: - _user: User context aggregator for handling user messages. - _assistant: Assistant context aggregator for handling assistant responses. - """ - - # Aggregators handle deprecation warnings - _user: GoogleUserContextAggregator - _assistant: GoogleAssistantContextAggregator - - def user(self) -> GoogleUserContextAggregator: - """Get the user context aggregator. - - Returns: - The user context aggregator instance. - """ - return self._user - - def assistant(self) -> GoogleAssistantContextAggregator: - """Get the assistant context aggregator. - - Returns: - The assistant context aggregator instance. - """ - return self._assistant - - -class GoogleLLMContext(OpenAILLMContext): - """Google AI LLM context that extends OpenAI context for Google-specific formatting. - - This class handles conversion between OpenAI-style messages and Google AI's - Content/Part format, including system messages, function calls, and media. - - .. deprecated:: 0.0.99 - `GoogleLLMContext` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__( - self, - messages: Optional[List[dict]] = None, - tools: Optional[List[dict]] = None, - tool_choice: Optional[dict] = None, - ): - """Initialize GoogleLLMContext. - - Args: - messages: Initial messages in OpenAI format. - tools: Available tools/functions for the model. - tool_choice: Tool choice configuration. - """ - # Super handles deprecation warning - super().__init__(messages=messages, tools=tools, tool_choice=tool_choice) - self.system_message = None - - @staticmethod - def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext": - """Upgrade an OpenAI context to a Google context. - - Args: - obj: OpenAI LLM context to upgrade. - - Returns: - GoogleLLMContext instance with converted messages. - """ - if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext): - logger.debug(f"Upgrading to Google: {obj}") - obj.__class__ = GoogleLLMContext - obj._restructure_from_openai_messages() - return obj - - def set_messages(self, messages: List): - """Set messages and restructure them for Google format. - - Args: - messages: List of messages to set. - """ - self._messages[:] = messages - self._restructure_from_openai_messages() - - def add_messages(self, messages: List): - """Add messages to the context, converting to Google format as needed. - - Args: - messages: List of messages to add (can be mixed formats). - """ - # Convert each message individually - converted_messages = [] - for msg in messages: - if isinstance(msg, Content): - # Already in Gemini format - converted_messages.append(msg) - else: - # Convert from standard format to Gemini format - converted = self.from_standard_message(msg) - if converted is not None: - converted_messages.append(converted) - - # Add the converted messages to our existing messages - self._messages.extend(converted_messages) - - def get_messages_for_logging(self) -> List[Dict[str, Any]]: - """Get messages formatted for logging with sensitive data redacted. - - Returns: - List of messages in a format ready for logging. - """ - msgs = [] - for message in self.messages: - obj = message.to_json_dict() - try: - if "parts" in obj: - for part in obj["parts"]: - if "inline_data" in part: - part["inline_data"]["data"] = "..." - except Exception as e: - logger.debug(f"Error: {e}") - msgs.append(obj) - return msgs - - def add_image_frame_message( - self, *, format: str, size: tuple[int, int], image: bytes, text: str = None - ): - """Add an image message to the context. - - Args: - format: Image format (e.g., 'RGB', 'RGBA'). - size: Image dimensions as (width, height). - image: Raw image bytes. - text: Optional text to accompany the image. - """ - buffer = io.BytesIO() - Image.frombytes(format, size, image).save(buffer, format="JPEG") - - parts = [] - if text: - parts.append(Part(text=text)) - parts.append(Part(inline_data=Blob(mime_type="image/jpeg", data=buffer.getvalue()))) - - self.add_message(Content(role="user", parts=parts)) - - def add_audio_frames_message( - self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows" - ): - """Add audio frames as a message to the context. - - Args: - audio_frames: List of audio frames to add. - text: Text description of the audio content. - """ - if not audio_frames: - return - - sample_rate = audio_frames[0].sample_rate - num_channels = audio_frames[0].num_channels - - parts = [] - data = b"".join(frame.audio for frame in audio_frames) - # NOTE(aleix): According to the docs only text or inline_data should be needed. - # (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference) - parts.append(Part(text=text)) - parts.append( - Part( - inline_data=Blob( - mime_type="audio/wav", - data=( - bytes( - self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data - ) - ), - ) - ), - ) - self.add_message(Content(role="user", parts=parts)) - # message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))} - # self.add_message(message) - - def from_standard_message(self, message): - """Convert standard format message to Google Content object. - - Handles conversion of text, images, and function calls to Google's format. - System messages are stored separately and return None. - - Args: - message: Message in standard format. - - Returns: - Content object with role and parts, or None for system messages. - - Examples: - Standard text message:: - - { - "role": "user", - "content": "Hello there" - } - - Converts to Google Content with:: - - Content( - role="user", - parts=[Part(text="Hello there")] - ) - - Standard function call message:: - - { - "role": "assistant", - "tool_calls": [ - { - "function": { - "name": "search", - "arguments": '{"query": "test"}' - } - } - ] - } - - Converts to Google Content with:: - - Content( - role="model", - parts=[Part(function_call=FunctionCall(name="search", args={"query": "test"}))] - ) - - System message returns None and stores content in self.system_message. - """ - role = message["role"] - content = message.get("content", []) - if role == "system": - # System instructions are returned as plain text - if isinstance(content, str): - self.system_message = content - elif isinstance(content, list): - # If content is a list, we assume it's a list of text parts, per the standard - self.system_message = " ".join( - part["text"] for part in content if part.get("type") == "text" - ) - return None - elif role == "assistant": - role = "model" - - parts = [] - if message.get("tool_calls"): - for tc in message["tool_calls"]: - parts.append( - Part( - function_call=FunctionCall( - name=tc["function"]["name"], - args=json.loads(tc["function"]["arguments"]), - ) - ) - ) - elif role == "tool": - role = "model" - try: - response = json.loads(message["content"]) - if isinstance(response, dict): - response_dict = response - else: - response_dict = {"value": response} - except Exception as e: - # Response might not be JSON-deserializable (e.g. plain text). - response_dict = {"value": message["content"]} - parts.append( - Part( - function_response=FunctionResponse( - name="tool_call_result", # seems to work to hard-code the same name every time - response=response_dict, - ) - ) - ) - elif isinstance(content, str): - parts.append(Part(text=content)) - elif isinstance(content, list): - for c in content: - if c["type"] == "text": - parts.append(Part(text=c["text"])) - elif c["type"] == "image_url": - # Extract MIME type from data URL (format: "data:image/jpeg;base64,...") - url = c["image_url"]["url"] - mime_type = ( - url.split(":")[1].split(";")[0] if url.startswith("data:") else "image/jpeg" - ) - parts.append( - Part( - inline_data=Blob( - mime_type=mime_type, - data=base64.b64decode(url.split(",")[1]), - ) - ) - ) - - message = Content(role=role, parts=parts) - return message - - def to_standard_messages(self, obj) -> list: - """Convert Google Content object to standard structured format. - - Handles text, images, and function calls from Google's Content/Part objects. - - Args: - obj: Google Content object with role and parts. - - Returns: - List containing a single message in standard format. - - Examples: - Google Content with text:: - - Content( - role="user", - parts=[Part(text="Hello")] - ) - - Converts to:: - - [ - { - "role": "user", - "content": [{"type": "text", "text": "Hello"}] - } - ] - - Google Content with function call:: - - Content( - role="model", - parts=[Part(function_call=FunctionCall(name="search", args={"q": "test"}))] - ) - - Converts to:: - - [ - { - "role": "assistant", - "tool_calls": [ - { - "id": "search", - "type": "function", - "function": { - "name": "search", - "arguments": '{"q": "test"}' - } - } - ] - } - ] - - Google Content with image:: - - Content( - role="user", - parts=[Part(inline_data=Blob(mime_type="image/jpeg", data=bytes_data))] - ) - - Converts to:: - - [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": {"url": "data:image/jpeg;base64,"} - } - ] - } - ] - """ - msg = {"role": obj.role, "content": []} - if msg["role"] == "model": - msg["role"] = "assistant" - - for part in obj.parts: - if part.text: - msg["content"].append({"type": "text", "text": part.text}) - elif part.inline_data: - encoded = base64.b64encode(part.inline_data.data).decode("utf-8") - msg["content"].append( - { - "type": "image_url", - "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"}, - } - ) - elif part.function_call: - args = part.function_call.args if hasattr(part.function_call, "args") else {} - msg["tool_calls"] = [ - { - "id": part.function_call.name, - "type": "function", - "function": { - "name": part.function_call.name, - "arguments": json.dumps(args), - }, - } - ] - - elif part.function_response: - msg["role"] = "tool" - resp = ( - part.function_response.response - if hasattr(part.function_response, "response") - else {} - ) - msg["tool_call_id"] = part.function_response.name - msg["content"] = json.dumps(resp) - - # there might be no content parts for tool_calls messages - if not msg["content"]: - del msg["content"] - return [msg] - - def _restructure_from_openai_messages(self): - """Restructures messages to ensure proper Google format and message ordering. - - This method handles conversion of OpenAI-formatted messages to Google format, - with special handling for function calls, function responses, and system messages. - System messages are added back to the context as user messages when needed. - - The final message order is preserved as: - 1. Function calls (from model) - 2. Function responses (from user) - 3. Text messages (converted from system messages) - - Note: - System messages are only added back when there are no regular text - messages in the context, ensuring proper conversation continuity - after function calls. - """ - self.system_message = None - converted_messages = [] - - # Process each message, preserving Google-formatted messages and converting others - for message in self._messages: - if isinstance(message, Content): - # Keep existing Google-formatted messages (e.g., function calls/responses) - converted_messages.append(message) - continue - - # Convert OpenAI format to Google format, system messages return None - converted = self.from_standard_message(message) - if converted is not None: - converted_messages.append(converted) - - # Update message list - self._messages[:] = converted_messages - - # Check if we only have function-related messages (no regular text) - has_regular_messages = any( - len(msg.parts) == 1 - and getattr(msg.parts[0], "text", None) - and not getattr(msg.parts[0], "function_call", None) - and not getattr(msg.parts[0], "function_response", None) - for msg in self._messages - ) - - # Add system message back as a user message if we only have function messages - if self.system_message and not has_regular_messages: - self._messages.append(Content(role="user", parts=[Part(text=self.system_message)])) - - # Remove any empty messages - self._messages = [m for m in self._messages if m.parts] - - class GoogleThinkingConfig(BaseModel): """Configuration for controlling the model's internal "thinking" process used before generating a response. @@ -741,8 +139,7 @@ class GoogleLLMService(LLMService): """Google AI (Gemini) LLM service implementation. This class implements inference with Google's AI models, translating internally - from an OpenAILLMContext or a universal LLMContext to the messages format - expected by the Google AI model. + from an LLMContext to the messages format expected by the Google AI model. """ Settings = GoogleLLMSettings @@ -885,7 +282,7 @@ class GoogleLLMService(LLMService): async def run_inference( self, - context: LLMContext | OpenAILLMContext, + context: LLMContext, max_tokens: Optional[int] = None, system_instruction: Optional[str] = None, ) -> Optional[str]: @@ -905,19 +302,13 @@ class GoogleLLMService(LLMService): system = [] tools = [] effective_instruction = system_instruction or self._settings.system_instruction - if isinstance(context, LLMContext): - adapter = self.get_llm_adapter() - params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params( - context, system_instruction=effective_instruction - ) - messages = params["messages"] - system = params["system_instruction"] - tools = params["tools"] - else: - context = GoogleLLMContext.upgrade_to_google(context) - messages = context.messages - system = getattr(context, "system_message", None) - tools = context.tools or [] + adapter = self.get_llm_adapter() + params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params( + context, system_instruction=effective_instruction + ) + messages = params["messages"] + system = params["system_instruction"] + tools = params["tools"] # Build generation config using the same method as streaming generation_params = self._build_generation_params( @@ -1004,17 +395,24 @@ class GoogleLLMService(LLMService): except Exception as e: logger.error(f"Failed to unset thinking budget: {e}") - async def _stream_content( - self, params_from_context: GeminiLLMInvocationParams - ) -> AsyncIterator[GenerateContentResponse]: - messages = params_from_context["messages"] + async def _stream_content(self, context: LLMContext) -> AsyncIterator[GenerateContentResponse]: + adapter = self.get_llm_adapter() + params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params( + context, system_instruction=self._settings.system_instruction + ) + + logger.debug( + f"{self}: Generating chat from context [{params['system_instruction']}] | {adapter.get_messages_for_logging(context)}" + ) + + messages = params["messages"] # The adapter already resolved system_instruction vs context system message. - system_instruction = params_from_context["system_instruction"] + system_instruction = params["system_instruction"] tools = [] - if params_from_context["tools"]: - tools = params_from_context["tools"] + if params["tools"]: + tools = params["tools"] elif self._tools: tools = self._tools tool_config = None @@ -1040,37 +438,8 @@ class GoogleLLMService(LLMService): config=generation_config, ) - async def _stream_content_specific_context( - self, context: OpenAILLMContext - ) -> AsyncIterator[GenerateContentResponse]: - logger.debug( - f"{self}: Generating chat from LLM-specific context [{context.system_message}] | {context.get_messages_for_logging()}" - ) - - params = GeminiLLMInvocationParams( - messages=context.messages, - system_instruction=context.system_message, - tools=context.tools, - ) - - return await self._stream_content(params) - - async def _stream_content_universal_context( - self, context: LLMContext - ) -> AsyncIterator[GenerateContentResponse]: - adapter = self.get_llm_adapter() - params: GeminiLLMInvocationParams = adapter.get_llm_invocation_params( - context, system_instruction=self._settings.system_instruction - ) - - logger.debug( - f"{self}: Generating chat from universal context [{params['system_instruction']}] | {adapter.get_messages_for_logging(context)}" - ) - - return await self._stream_content(params) - @traced_llm - async def _process_context(self, context: OpenAILLMContext | LLMContext): + async def _process_context(self, context: LLMContext): await self.push_frame(LLMFullResponseStartFrame()) prompt_tokens = 0 @@ -1083,12 +452,8 @@ class GoogleLLMService(LLMService): accumulated_text = "" try: - # Generate content using either OpenAILLMContext or universal LLMContext - response = await ( - self._stream_content_specific_context(context) - if isinstance(context, OpenAILLMContext) - else self._stream_content_universal_context(context) - ) + # Generate content from LLMContext + response = await self._stream_content(context) function_calls = [] async for chunk in response: @@ -1272,23 +637,11 @@ class GoogleLLMService(LLMService): """ await super().process_frame(frame, direction) - context = None - - if isinstance(frame, OpenAILLMContextFrame): - context = GoogleLLMContext.upgrade_to_google(frame.context) - elif isinstance(frame, LLMContextFrame): - # Handle universal (LLM-agnostic) LLM context frames - context = frame.context - elif isinstance(frame, LLMMessagesFrame): - # NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal - # LLMContext with it - context = GoogleLLMContext(frame.messages) + if isinstance(frame, LLMContextFrame): + await self._process_context(frame.context) else: await self.push_frame(frame, direction) - if context: - await self._process_context(context) - async def stop(self, frame): """Override stop to gracefully close the client.""" await super().stop(frame) @@ -1305,41 +658,3 @@ class GoogleLLMService(LLMService): except Exception: # Do nothing - we're shutting down anyway pass - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> GoogleContextAggregatorPair: - """Create Google-specific context aggregators. - - Creates a pair of context aggregators optimized for Google's message format, - including support for function calls, tool usage, and image handling. - - Args: - context: The LLM context to create aggregators for. - user_params: Parameters for user message aggregation. - assistant_params: Parameters for assistant message aggregation. - - Returns: - GoogleContextAggregatorPair: A pair of context aggregators, one for - the user and one for the assistant, encapsulated in an - GoogleContextAggregatorPair. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - context.set_llm_adapter(self.get_llm_adapter()) - - if isinstance(context, OpenAILLMContext): - context = GoogleLLMContext.upgrade_to_google(context) - - # Aggregators handle deprecation warnings - user = GoogleUserContextAggregator(context, params=user_params) - assistant = GoogleAssistantContextAggregator(context, params=assistant_params) - - return GoogleContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index f2e247de9..e4f96c388 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -55,11 +55,6 @@ from pipecat.processors.aggregators.llm_context import ( LLMContext, LLMSpecificMessage, ) -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, -) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_service import AIService from pipecat.services.settings import LLMSettings @@ -110,7 +105,7 @@ class FunctionCallParams: tool_call_id: str arguments: Mapping[str, Any] llm: "LLMService" - context: OpenAILLMContext | LLMContext + context: LLMContext result_callback: FunctionCallResultCallback @@ -153,7 +148,7 @@ class FunctionCallRunnerItem: function_name: str tool_call_id: str arguments: Mapping[str, Any] - context: OpenAILLMContext | LLMContext + context: LLMContext run_llm: Optional[bool] = None @@ -247,7 +242,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): async def run_inference( self, - context: LLMContext | OpenAILLMContext, + context: LLMContext, max_tokens: Optional[int] = None, system_instruction: Optional[str] = None, ) -> Optional[str]: @@ -267,41 +262,6 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): """ raise NotImplementedError(f"run_inference() not supported by {self.__class__.__name__}") - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> Any: - """Create a context aggregator for managing LLM conversation context. - - Must be implemented by subclasses. - - Args: - context: The LLM context to create an aggregator for. - user_params: Parameters for user message aggregation. - assistant_params: Parameters for assistant message aggregation. - - Returns: - A context aggregator instance. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "create_context_aggregator() is deprecated and will be removed in a future version. " - "Use the universal LLMContext and LLMContextAggregatorPair directly instead. " - "See OpenAILLMContext docstring for migration guide.", - DeprecationWarning, - stacklevel=2, - ) - pass - async def start(self, frame: StartFrame): """Start the LLM service. diff --git a/src/pipecat/services/mem0/memory.py b/src/pipecat/services/mem0/memory.py index 754e2be0a..91396cab4 100644 --- a/src/pipecat/services/mem0/memory.py +++ b/src/pipecat/services/mem0/memory.py @@ -17,12 +17,8 @@ from typing import Any, Dict, List, Optional from loguru import logger from pydantic import BaseModel, Field -from pipecat.frames.frames import Frame, LLMContextFrame, LLMMessagesFrame +from pipecat.frames.frames import Frame, LLMContextFrame from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor try: @@ -227,9 +223,7 @@ class Mem0MemoryService(FrameProcessor): logger.error(f"Error retrieving memories from Mem0: {e}") return [] - async def _enhance_context_with_memories( - self, context: LLMContext | OpenAILLMContext, query: str - ): + async def _enhance_context_with_memories(self, context: LLMContext, query: str): """Enhance the LLM context with relevant memories. Args: @@ -271,16 +265,8 @@ class Mem0MemoryService(FrameProcessor): """ await super().process_frame(frame, direction) - context = None - messages = None - - if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): + if isinstance(frame, LLMContextFrame): context = frame.context - elif isinstance(frame, LLMMessagesFrame): - messages = frame.messages - context = LLMContext(messages) - - if context: try: # Get the latest user message to use as a query for memory retrieval context_messages = context.get_messages() @@ -302,17 +288,12 @@ class Mem0MemoryService(FrameProcessor): # Store the conversation in Mem0 as a background task self.create_task(self._store_messages(messages_to_store), name="mem0_store") - # If we received an LLMMessagesFrame, create a new one with the enhanced messages - if messages is not None: - await self.push_frame(LLMMessagesFrame(context.get_messages())) - else: - # Otherwise, pass the enhanced context frame downstream - await self.push_frame(frame) + # Pass the enhanced context frame downstream + await self.push_frame(frame) except Exception as e: await self.push_error( error_msg=f"Error processing with Mem0: {str(e)}", exception=e ) await self.push_frame(frame) # Still pass the original frame through else: - # For non-context frames, just pass them through await self.push_frame(frame, direction) diff --git a/src/pipecat/services/nvidia/llm.py b/src/pipecat/services/nvidia/llm.py index 66bbd4402..a06dfd4da 100644 --- a/src/pipecat/services/nvidia/llm.py +++ b/src/pipecat/services/nvidia/llm.py @@ -15,7 +15,6 @@ from typing import Optional from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.openai.base_llm import BaseOpenAILLMService from pipecat.services.openai.llm import OpenAILLMService @@ -84,7 +83,7 @@ class NvidiaLLMService(OpenAILLMService): self._has_reported_prompt_tokens = False self._is_processing = False - async def _process_context(self, context: OpenAILLMContext | LLMContext): + async def _process_context(self, context: LLMContext): """Process a context through the LLM and accumulate token usage metrics. This method overrides the parent class implementation to handle NVIDIA's diff --git a/src/pipecat/services/openai/base_llm.py b/src/pipecat/services/openai/base_llm.py index 5b444aea1..76fab8f5f 100644 --- a/src/pipecat/services/openai/base_llm.py +++ b/src/pipecat/services/openai/base_llm.py @@ -31,15 +31,10 @@ from pipecat.frames.frames import ( LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, - LLMMessagesFrame, LLMTextFrame, ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import FunctionCallFromLLM, LLMService from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN @@ -61,11 +56,10 @@ class OpenAILLMSettings(LLMSettings): class BaseOpenAILLMService(LLMService): """Base class for all services that use the AsyncOpenAI client. - This service consumes OpenAILLMContextFrame or LLMContextFrame frames, - which contain a reference to an OpenAILLMContext or LLMContext object. The - context defines what is sent to the LLM for completion, including user, - assistant, and system messages, as well as tool choices and function call - configurations. + This service consumes LLMContextFrame frames, which contain a reference to + an LLMContext object. The context defines what is sent to the LLM for + completion, including user, assistant, and system messages, as well as tool + choices and function call configurations. """ Settings = OpenAILLMSettings @@ -274,19 +268,27 @@ class BaseOpenAILLMService(LLMService): """ return self._full_model_name - async def get_chat_completions( - self, params_from_context: OpenAILLMInvocationParams - ) -> AsyncStream[ChatCompletionChunk]: + async def get_chat_completions(self, context: LLMContext) -> AsyncStream[ChatCompletionChunk]: """Get streaming chat completions from OpenAI API with optional timeout and retry. Args: - params_from_context: Parameters, derived from the LLM context, to - use for the chat completion. Contains messages, tools, and tool - choice. + context: Context to use for the chat completion. + Contains messages, tools, and tool choice. Returns: Async stream of chat completion chunks. """ + adapter = self.get_llm_adapter() + logger.debug( + f"{self}: Generating chat from context {adapter.get_messages_for_logging(context)}" + ) + + params_from_context: OpenAILLMInvocationParams = adapter.get_llm_invocation_params( + context, + system_instruction=self._settings.system_instruction, + convert_developer_to_user=not self.supports_developer_role, + ) + params = self.build_chat_completion_params(params_from_context) if self._retry_on_timeout: @@ -340,7 +342,7 @@ class BaseOpenAILLMService(LLMService): async def run_inference( self, - context: LLMContext | OpenAILLMContext, + context: LLMContext, max_tokens: Optional[int] = None, system_instruction: Optional[str] = None, ) -> Optional[str]: @@ -357,17 +359,12 @@ class BaseOpenAILLMService(LLMService): The LLM's response as a string, or None if no response is generated. """ effective_instruction = system_instruction or self._settings.system_instruction - if isinstance(context, LLMContext): - adapter = self.get_llm_adapter() - invocation_params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params( - context, - system_instruction=effective_instruction, - convert_developer_to_user=not self.supports_developer_role, - ) - else: - invocation_params = OpenAILLMInvocationParams( - messages=context.messages, tools=context.tools, tool_choice=context.tool_choice - ) + adapter = self.get_llm_adapter() + invocation_params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params( + context, + system_instruction=effective_instruction, + convert_developer_to_user=not self.supports_developer_role, + ) # Build params using the same method as streaming completions params = self.build_chat_completion_params(invocation_params) @@ -389,59 +386,8 @@ class BaseOpenAILLMService(LLMService): return response.choices[0].message.content - async def _stream_chat_completions_specific_context( - self, context: OpenAILLMContext - ) -> AsyncStream[ChatCompletionChunk]: - logger.debug( - f"{self}: Generating chat from LLM-specific context {context.get_messages_for_logging()}" - ) - - messages: List[ChatCompletionMessageParam] = context.get_messages() - - # base64 encode any images - for message in messages: - if message.get("mime_type") == "image/jpeg": - # Avoid .getvalue() which makes a full copy of BytesIO - raw_bytes = message["data"].read() - encoded_image = base64.b64encode(raw_bytes).decode("utf-8") - text = message.get("content", "") - message["content"] = [ - {"type": "text", "text": text}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}, - }, - ] - # Explicit cleanup - del message["data"] - del message["mime_type"] - - params = OpenAILLMInvocationParams( - messages=messages, tools=context.tools, tool_choice=context.tool_choice - ) - chunks = await self.get_chat_completions(params) - - return chunks - - async def _stream_chat_completions_universal_context( - self, context: LLMContext - ) -> AsyncStream[ChatCompletionChunk]: - adapter = self.get_llm_adapter() - logger.debug( - f"{self}: Generating chat from universal context {adapter.get_messages_for_logging(context)}" - ) - - params: OpenAILLMInvocationParams = adapter.get_llm_invocation_params( - context, - system_instruction=self._settings.system_instruction, - convert_developer_to_user=not self.supports_developer_role, - ) - chunks = await self.get_chat_completions(params) - - return chunks - @traced_llm - async def _process_context(self, context: OpenAILLMContext | LLMContext): + async def _process_context(self, context: LLMContext): functions_list = [] arguments_list = [] tool_id_list = [] @@ -452,12 +398,8 @@ class BaseOpenAILLMService(LLMService): await self.start_ttfb_metrics() - # Generate chat completions using either OpenAILLMContext or universal LLMContext - chunk_stream = await ( - self._stream_chat_completions_specific_context(context) - if isinstance(context, OpenAILLMContext) - else self._stream_chat_completions_universal_context(context) - ) + # Generate chat completions from LLMContext + chunk_stream = await self.get_chat_completions(context) # Ensure stream and its async iterator are closed on cancellation/exception # to prevent socket leaks and uvloop crashes. Closing the iterator first @@ -586,9 +528,7 @@ class BaseOpenAILLMService(LLMService): async def process_frame(self, frame: Frame, direction: FrameDirection): """Process frames for LLM completion requests. - Handles OpenAILLMContextFrame, LLMContextFrame, LLMMessagesFrame, - and LLMUpdateSettingsFrame to trigger LLM completions and manage - settings. + Handles LLMContextFrame to trigger LLM completions. Args: frame: The frame to process. @@ -596,25 +536,11 @@ class BaseOpenAILLMService(LLMService): """ await super().process_frame(frame, direction) - context = None - if isinstance(frame, OpenAILLMContextFrame): - # Handle OpenAI-specific context frames - context = frame.context - elif isinstance(frame, LLMContextFrame): - # Handle universal (LLM-agnostic) LLM context frames - context = frame.context - elif isinstance(frame, LLMMessagesFrame): - # NOTE: LLMMessagesFrame is deprecated, so we don't support the newer universal - # LLMContext with it - context = OpenAILLMContext.from_messages(frame.messages) - else: - await self.push_frame(frame, direction) - - if context: + if isinstance(frame, LLMContextFrame): try: await self.push_frame(LLMFullResponseStartFrame()) await self.start_processing_metrics() - await self._process_context(context) + await self._process_context(frame.context) except httpx.TimeoutException as e: await self._call_event_handler("on_completion_timeout") await self.push_error(error_msg="LLM completion timeout", exception=e) @@ -623,3 +549,5 @@ class BaseOpenAILLMService(LLMService): finally: await self.stop_processing_metrics() await self.push_frame(LLMFullResponseEndFrame()) + else: + await self.push_frame(frame, direction) diff --git a/src/pipecat/services/openai/llm.py b/src/pipecat/services/openai/llm.py index 553733922..24c2134fc 100644 --- a/src/pipecat/services/openai/llm.py +++ b/src/pipecat/services/openai/llm.py @@ -18,51 +18,9 @@ from pipecat.frames.frames import ( FunctionCallResultFrame, UserImageRawFrame, ) -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMAssistantContextAggregator, - LLMUserAggregatorParams, - LLMUserContextAggregator, -) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.openai.base_llm import BaseOpenAILLMService -@dataclass -class OpenAIContextAggregatorPair: - """Pair of OpenAI context aggregators for user and assistant messages. - - .. deprecated:: 0.0.99 - `OpenAIContextAggregatorPair` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Parameters: - _user: User context aggregator for processing user messages. - _assistant: Assistant context aggregator for processing assistant messages. - """ - - # Aggregators handle deprecation warnings - _user: "OpenAIUserContextAggregator" - _assistant: "OpenAIAssistantContextAggregator" - - def user(self) -> "OpenAIUserContextAggregator": - """Get the user context aggregator. - - Returns: - The user context aggregator instance. - """ - return self._user - - def assistant(self) -> "OpenAIAssistantContextAggregator": - """Get the assistant context aggregator. - - Returns: - The assistant context aggregator instance. - """ - return self._assistant - - class OpenAILLMService(BaseOpenAILLMService): """OpenAI LLM service implementation. @@ -145,161 +103,3 @@ class OpenAILLMService(BaseOpenAILLMService): default_settings.apply_update(settings) super().__init__(service_tier=service_tier, settings=default_settings, **kwargs) - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> OpenAIContextAggregatorPair: - """Create OpenAI-specific context aggregators. - - Creates a pair of context aggregators optimized for OpenAI's message format, - including support for function calls, tool usage, and image handling. - - Args: - context: The LLM context to create aggregators for. - user_params: Parameters for user message aggregation. - assistant_params: Parameters for assistant message aggregation. - - Returns: - OpenAIContextAggregatorPair: A pair of context aggregators, one for - the user and one for the assistant, encapsulated in an - OpenAIContextAggregatorPair. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - context.set_llm_adapter(self.get_llm_adapter()) - - # Aggregators handle deprecation warnings - user = OpenAIUserContextAggregator(context, params=user_params) - assistant = OpenAIAssistantContextAggregator(context, params=assistant_params) - - return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) - - -class OpenAIUserContextAggregator(LLMUserContextAggregator): - """OpenAI-specific user context aggregator. - - Handles aggregation of user messages for OpenAI LLM services. - Inherits all functionality from the base LLMUserContextAggregator. - - .. deprecated:: 0.0.99 - `OpenAIUserContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - # Super handles deprecation warning - pass - - -class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): - """OpenAI-specific assistant context aggregator. - - Handles aggregation of assistant messages for OpenAI LLM services, - with specialized support for OpenAI's function calling format, - tool usage tracking, and image message handling. - - .. deprecated:: 0.0.99 - `OpenAIAssistantContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - # Super handles deprecation warning - - async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): - """Handle a function call in progress. - - Adds the function call to the context with an IN_PROGRESS status - to track ongoing function execution. - - Args: - frame: Frame containing function call progress information. - """ - self._context.add_message( - { - "role": "assistant", - "tool_calls": [ - { - "id": frame.tool_call_id, - "function": { - "name": frame.function_name, - "arguments": json.dumps(frame.arguments), - }, - "type": "function", - } - ], - } - ) - self._context.add_message( - { - "role": "tool", - "content": "IN_PROGRESS", - "tool_call_id": frame.tool_call_id, - } - ) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - """Handle the result of a function call. - - Updates the context with the function call result, replacing any - previous IN_PROGRESS status. - - Args: - frame: Frame containing the function call result. - """ - if frame.result: - result = json.dumps(frame.result, ensure_ascii=False) - await self._update_function_call_result(frame.function_name, frame.tool_call_id, result) - else: - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "COMPLETED" - ) - - async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): - """Handle a cancelled function call. - - Updates the context to mark the function call as cancelled. - - Args: - frame: Frame containing the function call cancellation information. - """ - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "CANCELLED" - ) - - async def _update_function_call_result( - self, function_name: str, tool_call_id: str, result: Any - ): - for message in self._context.messages: - if ( - message["role"] == "tool" - and message["tool_call_id"] - and message["tool_call_id"] == tool_call_id - ): - message["content"] = result - - async def handle_user_image_frame(self, frame: UserImageRawFrame): - """Handle a user image frame from a function call request. - - Marks the associated function call as completed and adds the image - to the context for processing. - - Args: - frame: Frame containing the user image and request context. - """ - await self._update_function_call_result( - frame.request.function_name, frame.request.tool_call_id, "COMPLETED" - ) - self._context.add_image_frame_message( - format=frame.format, - size=frame.size, - image=frame.image, - text=frame.request.context, - ) diff --git a/src/pipecat/services/openai/realtime/context.py b/src/pipecat/services/openai/realtime/context.py deleted file mode 100644 index 7870cc519..000000000 --- a/src/pipecat/services/openai/realtime/context.py +++ /dev/null @@ -1,368 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""OpenAI Realtime LLM context and aggregator implementations. - -.. deprecated:: 0.0.92 - OpenAI Realtime no longer uses types from this module under the hood. - It now uses ``LLMContext`` and ``LLMContextAggregatorPair``. - Using the new patterns should allow you to not need types from this module. - - BEFORE:: - - # Setup - context = OpenAILLMContext(messages, tools) - context_aggregator = llm.create_context_aggregator(context) - - # Context aggregator type - context_aggregator: OpenAIContextAggregatorPair - - # Context frame type - frame: OpenAILLMContextFrame - - # Context type - context: OpenAIRealtimeLLMContext - # or - context: OpenAILLMContext - - AFTER:: - - # Setup - context = LLMContext(messages, tools) - context_aggregator = LLMContextAggregatorPair(context) - - # Context aggregator type - context_aggregator: LLMContextAggregatorPair - - # Context frame type - frame: LLMContextFrame - - # Context type - context: LLMContext -""" - -import warnings - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "Types in pipecat.services.openai.realtime.llm (or " - "pipecat.services.openai_realtime.llm) are deprecated. \n" - "OpenAI Realtime no longer uses types from this module under the hood. \n" - "It now uses `LLMContext` and `LLMContextAggregatorPair`. \n" - "Using the new patterns should allow you to not need types from this module.\n\n" - "BEFORE:\n" - "```\n" - "# Setup\n" - "context = OpenAILLMContext(messages, tools)\n" - "context_aggregator = llm.create_context_aggregator(context)\n\n" - "# Context aggregator type\n" - "context_aggregator: OpenAIContextAggregatorPair\n\n" - "# Context frame type\n" - "frame: OpenAILLMContextFrame\n\n" - "# Context type\n" - "context: OpenAIRealtimeLLMContext\n" - "# or\n" - "context: OpenAILLMContext\n\n" - "```\n\n" - "AFTER:\n" - "```\n" - "# Setup\n" - "context = LLMContext(messages, tools)\n" - "context_aggregator = LLMContextAggregatorPair(context)\n\n" - "# Context aggregator type\n" - "context_aggregator: LLMContextAggregatorPair\n\n" - "# Context frame type\n" - "frame: LLMContextFrame\n\n" - "# Context type\n" - "context: LLMContext\n\n" - "```\n", - ) - -import copy -import json - -from loguru import logger - -from pipecat.frames.frames import ( - Frame, - FunctionCallResultFrame, - InterimTranscriptionFrame, - LLMMessagesUpdateFrame, - LLMSetToolsFrame, - LLMTextFrame, - TranscriptionFrame, -) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.openai.llm import ( - OpenAIAssistantContextAggregator, - OpenAIUserContextAggregator, -) - -from . import events -from .frames import RealtimeFunctionCallResultFrame, RealtimeMessagesUpdateFrame - - -class OpenAIRealtimeLLMContext(OpenAILLMContext): - """OpenAI Realtime LLM context with session management and message conversion. - - Extends the standard OpenAI LLM context to support real-time session properties, - instruction management, and conversion between standard message formats and - realtime conversation items. - - .. deprecated:: 0.0.99 - `OpenAIRealtimeLLMContext` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - - def __init__(self, messages=None, tools=None, **kwargs): - """Initialize the OpenAIRealtimeLLMContext. - - Args: - messages: Initial conversation messages. Defaults to None. - tools: Available function tools. Defaults to None. - **kwargs: Additional arguments passed to parent OpenAILLMContext. - """ - # Super handles deprecation warning - super().__init__(messages=messages, tools=tools, **kwargs) - self.__setup_local() - - def __setup_local(self): - self.llm_needs_settings_update = True - self.llm_needs_initial_messages = True - self._session_instructions = "" - - return - - @staticmethod - def upgrade_to_realtime(obj: OpenAILLMContext) -> "OpenAIRealtimeLLMContext": - """Upgrade a standard OpenAI LLM context to a realtime context. - - Args: - obj: The OpenAILLMContext instance to upgrade. - - Returns: - The upgraded OpenAIRealtimeLLMContext instance. - """ - if isinstance(obj, OpenAILLMContext) and not isinstance(obj, OpenAIRealtimeLLMContext): - obj.__class__ = OpenAIRealtimeLLMContext - obj.__setup_local() - return obj - - # todo - # - finish implementing all frames - - def from_standard_message(self, message): - """Convert a standard message format to a realtime conversation item. - - Args: - message: The standard message dictionary to convert. - - Returns: - A ConversationItem instance for the realtime API. - """ - if message.get("role") == "user": - content = message.get("content") - if isinstance(message.get("content"), list): - content = "" - for c in message.get("content"): - if c.get("type") == "text": - content += " " + c.get("text") - else: - logger.error( - f"Unhandled content type in context message: {c.get('type')} - {message}" - ) - return events.ConversationItem( - role="user", - type="message", - content=[events.ItemContent(type="input_text", text=content)], - ) - if message.get("role") == "assistant" and message.get("tool_calls"): - tc = message.get("tool_calls")[0] - return events.ConversationItem( - type="function_call", - call_id=tc["id"], - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ) - logger.error(f"Unhandled message type in from_standard_message: {message}") - - def get_messages_for_initializing_history(self): - """Get conversation items for initializing the realtime session history. - - Converts the context's messages to a format suitable for the realtime API, - handling system instructions and conversation history packaging. - - Returns: - List of conversation items for session initialization. - """ - # We can't load a long conversation history into the openai realtime api yet. (The API/model - # forgets that it can do audio, if you do a series of `conversation.item.create` calls.) So - # our general strategy until this is fixed is just to put everything into a first "user" - # message as a single input. - if not self.messages: - return [] - - messages = copy.deepcopy(self.messages) - - # If we have a "system" message as our first message, let's pull that out into session - # "instructions" - if messages[0].get("role") == "system": - self.llm_needs_settings_update = True - system = messages.pop(0) - content = system.get("content") - if isinstance(content, str): - self._session_instructions = content - elif isinstance(content, list): - self._session_instructions = content[0].get("text") - if not messages: - return [] - - # If we have just a single "user" item, we can just send it normally - if len(messages) == 1 and messages[0].get("role") == "user": - return [self.from_standard_message(messages[0])] - - # Otherwise, let's pack everything into a single "user" message with a bit of - # explanation for the LLM - intro_text = """ - This is a previously saved conversation. Please treat this conversation history as a - starting point for the current conversation.""" - - trailing_text = """ - This is the end of the previously saved conversation. Please continue the conversation - from here. If the last message is a user instruction or question, act on that instruction - or answer the question. If the last message is an assistant response, simple say that you - are ready to continue the conversation.""" - - return [ - { - "role": "user", - "type": "message", - "content": [ - { - "type": "input_text", - "text": "\n\n".join( - [intro_text, json.dumps(messages, indent=2), trailing_text] - ), - } - ], - } - ] - - def add_user_content_item_as_message(self, item): - """Add a user content item as a standard message to the context. - - Args: - item: The conversation item to add as a user message. - """ - message = { - "role": "user", - "content": [{"type": "text", "text": item.content[0].transcript}], - } - self.add_message(message) - - -class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator): - """User context aggregator for OpenAI Realtime API. - - Handles user input frames and generates appropriate context updates - for the realtime conversation, including message updates and tool settings. - - .. deprecated:: 0.0.99 - `OpenAIRealtimeUserContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Args: - context: The OpenAI realtime LLM context. - **kwargs: Additional arguments passed to parent aggregator. - """ - - # Super handles deprecation warning - - async def process_frame( - self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM - ): - """Process incoming frames and handle realtime-specific frame types. - - Args: - frame: The frame to process. - direction: The direction of frame flow in the pipeline. - """ - await super().process_frame(frame, direction) - # Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline, - # messages are only processed by the user context aggregator, which is generally what we want. But - # we also need to send new messages over the websocket, so the openai realtime API has them - # in its context. - if isinstance(frame, LLMMessagesUpdateFrame): - await self.push_frame(RealtimeMessagesUpdateFrame(context=self._context)) - - # Parent also doesn't push the LLMSetToolsFrame. - if isinstance(frame, LLMSetToolsFrame): - await self.push_frame(frame, direction) - - async def push_aggregation(self): - """Push user input aggregation. - - Currently ignores all user input coming into the pipeline as realtime - audio input is handled directly by the service. - """ - # for the moment, ignore all user input coming into the pipeline. - # todo: think about whether/how to fix this to allow for text input from - # upstream (transport/transcription, or other sources) - pass - - -class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator): - """Assistant context aggregator for OpenAI Realtime API. - - Handles assistant output frames from the realtime service, filtering - out duplicate text frames and managing function call results. - - .. deprecated:: 0.0.99 - `OpenAIRealtimeAssistantContextAggregator` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Args: - context: The OpenAI realtime LLM context. - **kwargs: Additional arguments passed to parent aggregator. - """ - - # Super handles deprecation warning - - # The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output, - # but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We - # need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames - # are process. This ensures that the context gets only one set of messages. - # OpenAIRealtimeLLMService also pushes TranscriptionFrames and InterimTranscriptionFrames, - # so we need to ignore pushing those as well, as they're also TextFrames. - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process assistant frames, filtering out duplicate text content. - - Args: - frame: The frame to process. - direction: The direction of frame flow in the pipeline. - """ - if not isinstance(frame, (LLMTextFrame, TranscriptionFrame, InterimTranscriptionFrame)): - await super().process_frame(frame, direction) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - """Handle function call result and notify the realtime service. - - Args: - frame: The function call result frame to handle. - """ - await super().handle_function_call_result(frame) - - # The standard function callback code path pushes the FunctionCallResultFrame from the llm itself, - # so we didn't have a chance to add the result to the openai realtime api context. Let's push a - # special frame to do that. - await self.push_frame( - RealtimeFunctionCallResultFrame(result_frame=frame), FrameDirection.UPSTREAM - ) diff --git a/src/pipecat/services/openai/realtime/frames.py b/src/pipecat/services/openai/realtime/frames.py deleted file mode 100644 index 0d3eef974..000000000 --- a/src/pipecat/services/openai/realtime/frames.py +++ /dev/null @@ -1,58 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""Custom frame types for OpenAI Realtime API integration. - -.. deprecated:: 0.0.92 - OpenAI Realtime no longer uses types from this module under the hood. - - It now works more like most LLM services in Pipecat, relying on updates to - its context, pushed by context aggregators, to update its internal state. - - Listen for ``LLMContextFrame`` s for context updates. -""" - -import warnings - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "Types in pipecat.services.openai.realtime.frames are deprecated. \n" - "OpenAI Realtime no longer uses types from this module under the hood. \n\n" - "It now works more like other LLM services in Pipecat, relying on updates to \n" - "its context, pushed by context aggregators, to update its internal state.\n\n" - "Listen for `LLMContextFrame`s for context updates.\n" - ) - -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from pipecat.frames.frames import DataFrame, FunctionCallResultFrame - -if TYPE_CHECKING: - from pipecat.services.openai.realtime.context import OpenAIRealtimeLLMContext - - -@dataclass -class RealtimeMessagesUpdateFrame(DataFrame): - """Frame indicating that the realtime context messages have been updated. - - Parameters: - context: The updated OpenAI realtime LLM context. - """ - - context: "OpenAIRealtimeLLMContext" - - -@dataclass -class RealtimeFunctionCallResultFrame(DataFrame): - """Frame containing function call results for the realtime service. - - Parameters: - result_frame: The function call result frame to send to the realtime API. - """ - - result_frame: FunctionCallResultFrame diff --git a/src/pipecat/services/openai/realtime/llm.py b/src/pipecat/services/openai/realtime/llm.py index 293ee0d87..a1e593ecd 100644 --- a/src/pipecat/services/openai/realtime/llm.py +++ b/src/pipecat/services/openai/realtime/llm.py @@ -48,15 +48,7 @@ from pipecat.frames.frames import ( ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, -) from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import FunctionCallFromLLM, LLMService from pipecat.services.settings import ( @@ -564,13 +556,8 @@ class OpenAIRealtimeLLMService(LLMService): if isinstance(frame, TranscriptionFrame): pass - elif isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): - context = ( - frame.context - if isinstance(frame, LLMContextFrame) - else LLMContext.from_openai_context(frame.context) - ) - await self._handle_context(context) + elif isinstance(frame, LLMContextFrame): + await self._handle_context(frame.context) elif isinstance(frame, InputAudioRawFrame): if not self._audio_input_paused: await self._send_user_audio(frame) @@ -1133,74 +1120,3 @@ class OpenAIRealtimeLLMService(LLMService): output=json.dumps(result, ensure_ascii=False), ) await self.send_client_event(events.ConversationItemCreateEvent(item=item)) - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> LLMContextAggregatorPair: - """Create an instance of OpenAIContextAggregatorPair from an OpenAILLMContext. - - NOTE: this method exists only for backward compatibility. New code - should instead do:: - - context = LLMContext(...) - context_aggregator = LLMContextAggregatorPair(context) - - Constructor keyword arguments for both the user and assistant aggregators can be provided. - - Args: - context: The LLM context. - user_params: User aggregator parameters. - assistant_params: Assistant aggregator parameters. - - Returns: - OpenAIContextAggregatorPair: A pair of context aggregators, one for - the user and one for the assistant, encapsulated in an - OpenAIContextAggregatorPair. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - # Log warning about transcription frame direction change in 0.0.92. - # We're putting this warning here rather than in the constructor so - # that it shows up for folks who haven't updated their code at all - # since 0.0.92, gives them a way to acknowledge and dismiss the - # warning, and encourages adoption of a new preferred pattern. - logger.warning( - "As of version 0.0.92, TranscriptionFrames and InterimTranscriptionFrames " - "now go upstream from OpenAIRealtimeLLMService, so if you're using " - "TranscriptProcessor, say, you'll want to adjust accordingly:\n\n" - "pipeline = Pipeline(\n" - " [\n" - " transport.input(),\n" - " context_aggregator.user(),\n\n" - " # BEFORE\n" - " llm,\n" - " transcript.user(),\n\n" - " # AFTER\n" - " transcript.user(),\n" - " llm,\n\n" - " transport.output(),\n" - " transcript.assistant(),\n" - " context_aggregator.assistant(),\n" - " ]\n" - ")\n\n" - "Also, LLMTextFrames are no longer pushed from " - "OpenAIRealtimeLLMService when it's configured with " - "output_modalities=['audio']. Listen for TTSTextFrames instead.\n\n" - "Once you've made the appropriate changes (if needed), you can " - "dismiss this warning by updating to the new context-setup pattern:\n\n" - " context = LLMContext(messages, tools)\n" - " context_aggregator = LLMContextAggregatorPair(context)\n" - ) - # from_openai_context handles deprecation warning already - context = LLMContext.from_openai_context(context) - assistant_params.expect_stripped_words = False - return LLMContextAggregatorPair( - context, user_params=user_params, assistant_params=assistant_params - ) diff --git a/src/pipecat/services/openai/responses/llm.py b/src/pipecat/services/openai/responses/llm.py index fce6b46d8..e1b4ace78 100644 --- a/src/pipecat/services/openai/responses/llm.py +++ b/src/pipecat/services/openai/responses/llm.py @@ -674,17 +674,11 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService, WebsocketLLMServ """ await super().process_frame(frame, direction) - context = None if isinstance(frame, LLMContextFrame): - context = frame.context - else: - await self.push_frame(frame, direction) - - if context: try: await self.push_frame(LLMFullResponseStartFrame()) await self.start_processing_metrics() - await self._process_context(context) + await self._process_context(frame.context) except asyncio.CancelledError: # The pipeline cancelled us (e.g. due to an interruption). # Ask the server to stop generating and flag that we need @@ -717,6 +711,8 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService, WebsocketLLMServ finally: await self.stop_processing_metrics() await self.push_frame(LLMFullResponseEndFrame()) + else: + await self.push_frame(frame, direction) # -- core inference ------------------------------------------------------- @@ -960,17 +956,11 @@ class OpenAIResponsesHttpLLMService(_BaseOpenAIResponsesLLMService): """ await super().process_frame(frame, direction) - context = None if isinstance(frame, LLMContextFrame): - context = frame.context - else: - await self.push_frame(frame, direction) - - if context: try: await self.push_frame(LLMFullResponseStartFrame()) await self.start_processing_metrics() - await self._process_context(context) + await self._process_context(frame.context) except httpx.TimeoutException as e: await self._call_event_handler("on_completion_timeout") await self.push_error(error_msg="LLM completion timeout", exception=e) @@ -979,6 +969,8 @@ class OpenAIResponsesHttpLLMService(_BaseOpenAIResponsesLLMService): finally: await self.stop_processing_metrics() await self.push_frame(LLMFullResponseEndFrame()) + else: + await self.push_frame(frame, direction) @traced_llm async def _process_context(self, context: LLMContext): diff --git a/src/pipecat/services/perplexity/llm.py b/src/pipecat/services/perplexity/llm.py index 50953691a..95ff23889 100644 --- a/src/pipecat/services/perplexity/llm.py +++ b/src/pipecat/services/perplexity/llm.py @@ -20,7 +20,6 @@ from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams from pipecat.adapters.services.perplexity_adapter import PerplexityLLMAdapter from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.openai.base_llm import BaseOpenAILLMService from pipecat.services.openai.llm import OpenAILLMService @@ -126,7 +125,7 @@ class PerplexityLLMService(OpenAILLMService): return params - async def _process_context(self, context: OpenAILLMContext | LLMContext): + async def _process_context(self, context: LLMContext): """Process a context through the LLM and accumulate token usage metrics. This method overrides the parent class implementation to handle diff --git a/src/pipecat/services/sambanova/llm.py b/src/pipecat/services/sambanova/llm.py index 63f37cb43..66f3d52bf 100644 --- a/src/pipecat/services/sambanova/llm.py +++ b/src/pipecat/services/sambanova/llm.py @@ -20,7 +20,6 @@ from pipecat.frames.frames import ( ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.llm_service import FunctionCallFromLLM from pipecat.services.openai.base_llm import BaseOpenAILLMService from pipecat.services.openai.llm import OpenAILLMService @@ -138,9 +137,7 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore return params @traced_llm # type: ignore - async def _process_context( - self, context: OpenAILLMContext | LLMContext - ) -> AsyncStream[ChatCompletionChunk]: + async def _process_context(self, context: LLMContext) -> AsyncStream[ChatCompletionChunk]: """Process OpenAI LLM context and stream chat completion chunks. This method handles the streaming response from SambaNova API, including @@ -163,11 +160,7 @@ class SambaNovaLLMService(OpenAILLMService): # type: ignore await self.start_ttfb_metrics() - chunk_stream = await ( - self._stream_chat_completions_specific_context(context) - if isinstance(context, OpenAILLMContext) - else self._stream_chat_completions_universal_context(context) - ) + chunk_stream = await self.get_chat_completions(context) # Use context manager to ensure stream is closed on cancellation/exception. # Without this, CancelledError during iteration leaves the underlying socket open. diff --git a/src/pipecat/services/ultravox/llm.py b/src/pipecat/services/ultravox/llm.py index fe8a97549..525da40c4 100644 --- a/src/pipecat/services/ultravox/llm.py +++ b/src/pipecat/services/ultravox/llm.py @@ -46,15 +46,7 @@ from pipecat.frames.frames import ( VADUserStoppedSpeakingFrame, ) from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, -) from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import FunctionCallFromLLM, LLMService from pipecat.services.settings import NOT_GIVEN, LLMSettings, _NotGiven @@ -404,13 +396,8 @@ class UltravoxRealtimeLLMService(LLMService): """ await super().process_frame(frame, direction) - if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): - context = ( - frame.context - if isinstance(frame, LLMContextFrame) - else LLMContext.from_openai_context(frame.context) - ) - await self._handle_context(context) + if isinstance(frame, LLMContextFrame): + await self._handle_context(frame.context) elif isinstance(frame, InterruptionFrame): await self.stop_all_metrics() await self.push_frame(frame, direction) @@ -629,40 +616,3 @@ class UltravoxRealtimeLLMService(LLMService): await self.push_frame(LLMFullResponseStartFrame()) self._bot_responding = "text" await self.push_frame(LLMTextFrame(text=text or delta)) - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> LLMContextAggregatorPair: - """Create an instance of LLMContextAggregatorPair from an OpenAILLMContext. - - Constructor keyword arguments for both the user and assistant aggregators can be provided. - - NOTE: this method exists only for backward compatibility. New code - should instead do:: - - context = LLMContext(...) - context_aggregator = LLMContextAggregatorPair(context) - - Args: - context: The LLM context to use. - user_params: User aggregator parameters. Defaults to LLMUserAggregatorParams(). - assistant_params: Assistant aggregator parameters. Defaults to LLMAssistantAggregatorParams(). - - Returns: - A pair of user and assistant context aggregators. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - # from_openai_context handles deprecation warning - context = LLMContext.from_openai_context(context) - assistant_params.expect_stripped_words = False - return LLMContextAggregatorPair( - context, user_params=user_params, assistant_params=assistant_params - ) diff --git a/src/pipecat/services/xai/llm.py b/src/pipecat/services/xai/llm.py index 6fc274a6d..0bbfb62b3 100644 --- a/src/pipecat/services/xai/llm.py +++ b/src/pipecat/services/xai/llm.py @@ -18,57 +18,12 @@ from loguru import logger from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, -) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.openai.base_llm import BaseOpenAILLMService from pipecat.services.openai.llm import ( - OpenAIAssistantContextAggregator, OpenAILLMService, - OpenAIUserContextAggregator, ) -@dataclass -class GrokContextAggregatorPair: - """Pair of context aggregators for user and assistant interactions. - - Provides a convenient container for managing both user and assistant - context aggregators together for Grok LLM interactions. - - .. deprecated:: 0.0.99 - `GrokContextAggregatorPair` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - - Parameters: - _user: The user context aggregator instance. - _assistant: The assistant context aggregator instance. - """ - - # Aggregators handle deprecation warnings - _user: OpenAIUserContextAggregator - _assistant: OpenAIAssistantContextAggregator - - def user(self) -> OpenAIUserContextAggregator: - """Get the user context aggregator. - - Returns: - The user context aggregator instance. - """ - return self._user - - def assistant(self) -> OpenAIAssistantContextAggregator: - """Get the assistant context aggregator. - - Returns: - The assistant context aggregator instance. - """ - return self._assistant - - @dataclass class GrokLLMSettings(BaseOpenAILLMService.Settings): """Settings for GrokLLMService.""" @@ -149,7 +104,7 @@ class GrokLLMService(OpenAILLMService): logger.debug(f"Creating Grok client with api {base_url}") return super().create_client(api_key, base_url, **kwargs) - async def _process_context(self, context: OpenAILLMContext | LLMContext): + async def _process_context(self, context: LLMContext): """Process a context through the LLM and accumulate token usage metrics. This method overrides the parent class implementation to handle Grok's @@ -215,38 +170,3 @@ class GrokLLMService(OpenAILLMService): if tokens.reasoning_tokens is not None: self._reasoning_tokens = tokens.reasoning_tokens - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> GrokContextAggregatorPair: - """Create an instance of GrokContextAggregatorPair from an OpenAILLMContext. - - Constructor keyword arguments for both the user and assistant aggregators - can be provided. - - Args: - context: The LLM context to create aggregators for. - user_params: Parameters for configuring the user aggregator. - assistant_params: Parameters for configuring the assistant aggregator. - - Returns: - GrokContextAggregatorPair: A pair of context aggregators, one for - the user and one for the assistant, encapsulated in an - GrokContextAggregatorPair. - - .. deprecated:: 0.0.99 - `create_context_aggregator()` is deprecated and will be removed in a future version. - Use the universal `LLMContext` and `LLMContextAggregatorPair` instead. - See `OpenAILLMContext` docstring for migration guide. - """ - context.set_llm_adapter(self.get_llm_adapter()) - - # Aggregators handle deprecation warnings - user = OpenAIUserContextAggregator(context, params=user_params) - assistant = OpenAIAssistantContextAggregator(context, params=assistant_params) - - return GrokContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/xai/realtime/llm.py b/src/pipecat/services/xai/realtime/llm.py index 1317e7269..33dbe2f63 100644 --- a/src/pipecat/services/xai/realtime/llm.py +++ b/src/pipecat/services/xai/realtime/llm.py @@ -46,14 +46,9 @@ from pipecat.frames.frames import ( ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, -) from pipecat.processors.aggregators.llm_response_universal import ( LLMContextAggregatorPair, ) -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frame_processor import FrameDirection from pipecat.services.llm_service import FunctionCallFromLLM, LLMService from pipecat.services.settings import ( @@ -946,26 +941,3 @@ class GrokRealtimeLLMService(LLMService): output=json.dumps(result, ensure_ascii=False), ) await self.send_client_event(events.ConversationItemCreateEvent(item=item)) - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(), - assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(), - ) -> LLMContextAggregatorPair: - """Create context aggregators for the Grok Realtime service. - - Args: - context: The LLM context. - user_params: User aggregator parameters. - assistant_params: Assistant aggregator parameters. - - Returns: - LLMContextAggregatorPair for user and assistant context aggregation. - """ - context = LLMContext.from_openai_context(context) - assistant_params.expect_stripped_words = False - return LLMContextAggregatorPair( - context, user_params=user_params, assistant_params=assistant_params - ) diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index d0cd5212c..d4a1241f3 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -480,7 +480,7 @@ class BaseInputTransport(FrameProcessor): self, audio_frame: InputAudioRawFrame, vad_state: VADState ) -> VADState: """Handle Voice Activity Detection results and generate appropriate frames.""" - if self._params.turn_analyzer or self._deprecated_openaillmcontext: + if self._params.turn_analyzer: return await self._deprecated_old_handle_vad(audio_frame, vad_state) else: return await self._deprecated_new_handle_vad(audio_frame, vad_state) diff --git a/src/pipecat/utils/tracing/service_decorators.py b/src/pipecat/utils/tracing/service_decorators.py index cc353e1a3..d87955755 100644 --- a/src/pipecat/utils/tracing/service_decorators.py +++ b/src/pipecat/utils/tracing/service_decorators.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: from opentelemetry import trace from pipecat.processors.aggregators.llm_context import NOT_GIVEN, LLMContext -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.utils.tracing.service_attributes import ( add_gemini_live_span_attributes, add_llm_span_attributes, @@ -459,40 +458,30 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) - self.push_frame = traced_push_frame # Get messages for logging - # For OpenAILLMContext: use context's own get_messages_for_logging() method - # For LLMContext: use adapter's get_messages_for_logging() which returns + # Use adapter's get_messages_for_logging() which returns # messages in provider's native format with sensitive data sanitized messages = None serialized_messages = None - if isinstance(context, OpenAILLMContext): - # OpenAILLMContext and subclasses have their own method - messages = context.get_messages_for_logging() - elif isinstance(context, LLMContext): - # Universal LLMContext - use adapter for provider-native format - if hasattr(self, "get_llm_adapter"): - adapter = self.get_llm_adapter() - messages = adapter.get_messages_for_logging(context) + # Use adapter for provider-native format + if hasattr(self, "get_llm_adapter"): + adapter = self.get_llm_adapter() + messages = adapter.get_messages_for_logging(context) # Serialize messages if available if messages: serialized_messages = json.dumps(messages) # Get tools - # For OpenAILLMContext: tools may need adapter conversion if set - # For LLMContext: use adapter's from_standard_tools() to convert ToolsSchema + # Use adapter's from_standard_tools() to convert ToolsSchema tools = None serialized_tools = None tool_count = 0 - if isinstance(context, OpenAILLMContext): - # OpenAILLMContext: tools property handles adapter conversion internally - tools = context.tools - elif isinstance(context, LLMContext): - # Universal LLMContext - use adapter to convert ToolsSchema - if hasattr(self, "get_llm_adapter") and hasattr(context, "tools"): - adapter = self.get_llm_adapter() - tools = adapter.from_standard_tools(context.tools) + # Use adapter to convert ToolsSchema + if hasattr(self, "get_llm_adapter") and hasattr(context, "tools"): + adapter = self.get_llm_adapter() + tools = adapter.from_standard_tools(context.tools) # Serialize and count tools if available # Check if tools is not None and not NOT_GIVEN @@ -501,36 +490,28 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) - tool_count = len(tools) if isinstance(tools, list) else 1 # Handle system message for different services + # settings.system_instruction takes priority (matches service behavior) system_message = None - if isinstance(context, LLMContext): - # settings.system_instruction takes priority (matches service behavior) - if hasattr(self, "_settings") and getattr( - self._settings, "system_instruction", None - ): - system_message = self._settings.system_instruction - else: - # Fall back to extracting from context messages - ctx_messages = context.get_messages() - if ctx_messages: - first = ctx_messages[0] - if ( - isinstance(first, dict) - and first.get("role") == "system" - ): - content = first.get("content") - if isinstance(content, str): - system_message = content - elif isinstance(content, list): - system_message = " ".join( - part.get("text", "") - for part in content - if isinstance(part, dict) - and part.get("type") == "text" - ) - elif hasattr(context, "system"): - system_message = context.system - elif hasattr(context, "system_message"): - system_message = context.system_message + if hasattr(self, "_settings") and getattr( + self._settings, "system_instruction", None + ): + system_message = self._settings.system_instruction + else: + # Fall back to extracting from context messages + ctx_messages = context.get_messages() + if ctx_messages: + first = ctx_messages[0] + if isinstance(first, dict) and first.get("role") == "system": + content = first.get("content") + if isinstance(content, str): + system_message = content + elif isinstance(content, list): + system_message = " ".join( + part.get("text", "") + for part in content + if isinstance(part, dict) + and part.get("type") == "text" + ) # Use given_fields() defensively in case a service doesn't # initialize all settings. diff --git a/tests/test_context_aggregators.py b/tests/test_context_aggregators.py deleted file mode 100644 index 37d36bfef..000000000 --- a/tests/test_context_aggregators.py +++ /dev/null @@ -1,1060 +0,0 @@ -# -# Copyright (c) 2024-2026, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import json -import unittest -from typing import Any, Optional - -from pipecat.audio.interruptions.min_words_interruption_strategy import MinWordsInterruptionStrategy -from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams -from pipecat.audio.vad.vad_analyzer import VADParams -from pipecat.frames.frames import ( - BotStartedSpeakingFrame, - EmulateUserStartedSpeakingFrame, - EmulateUserStoppedSpeakingFrame, - Frame, - FunctionCallInProgressFrame, - FunctionCallResultFrame, - FunctionCallResultProperties, - InterimTranscriptionFrame, - InterruptionFrame, - LLMContextAssistantTimestampFrame, - LLMContextFrame, - LLMFullResponseEndFrame, - LLMFullResponseStartFrame, - OpenAILLMContextAssistantTimestampFrame, - SpeechControlParamsFrame, - TextFrame, - TranscriptionFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, -) -from pipecat.pipeline.pipeline import Pipeline -from pipecat.pipeline.task import PipelineParams -from pipecat.processors.aggregators.llm_context import LLMContext -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantAggregatorParams, - LLMUserAggregatorParams, - LLMUserContextAggregator, -) -from pipecat.processors.aggregators.llm_response_universal import LLMAssistantAggregator -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.services.anthropic.llm import ( - AnthropicAssistantContextAggregator, - AnthropicLLMContext, - AnthropicUserContextAggregator, -) -from pipecat.services.aws.llm import ( - AWSBedrockAssistantContextAggregator, - AWSBedrockLLMContext, - AWSBedrockUserContextAggregator, -) -from pipecat.services.google.llm import ( - GoogleAssistantContextAggregator, - GoogleLLMContext, - GoogleUserContextAggregator, -) -from pipecat.services.openai.llm import ( - OpenAIAssistantContextAggregator, - OpenAIUserContextAggregator, -) -from pipecat.tests.utils import SleepFrame, run_test - -AGGREGATION_TIMEOUT = 0.1 -AGGREGATION_SLEEP = 0.15 - - -class BaseTestUserContextAggregator: - CONTEXT_CLASS = None # To be set in subclasses - AGGREGATOR_CLASS = None # To be set in subclasses - EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame] - - def check_message_content(self, context: OpenAILLMContext, index: int, content: str): - assert context.messages[index]["content"] == content - - def check_message_multi_content( - self, context: OpenAILLMContext, content_index: int, index: int, content: str - ): - assert context.messages[index]["content"] == content - - async def test_se(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()] - expected_down_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - - async def test_ste(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [ - UserStartedSpeakingFrame(), - TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - UserStoppedSpeakingFrame, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "Hello!") - - async def test_site(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame(text="Hello", user_id="cat", timestamp=""), - TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - UserStoppedSpeakingFrame, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "Hello Pipecat!") - - async def test_st1iest2e(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [ - UserStartedSpeakingFrame(), - TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), - InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - UserStartedSpeakingFrame(), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - UserStartedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - UserStoppedSpeakingFrame, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "Hello Pipecat! How are you?") - - async def test_siet(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) - ) - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "How are you?") - - async def test_sieit(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) - ) - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "How are you?") - - async def test_set(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) - ) - frames_to_send = [ - UserStartedSpeakingFrame(), - UserStoppedSpeakingFrame(), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "How are you?") - - async def test_seit(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) - ) - frames_to_send = [ - UserStartedSpeakingFrame(), - UserStoppedSpeakingFrame(), - InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "How are you?") - - async def test_st1et2(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) - ) - frames_to_send = [ - SpeechControlParamsFrame(vad_params=VADParams(stop_secs=AGGREGATION_TIMEOUT)), - UserStartedSpeakingFrame(), - TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - SpeechControlParamsFrame, - UserStartedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_multi_content(context, 0, 0, "Hello Pipecat!") - self.check_message_multi_content(context, 0, 1, "How are you?") - - async def test_set1t2(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) - ) - frames_to_send = [ - UserStartedSpeakingFrame(), - UserStoppedSpeakingFrame(), - TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "Hello Pipecat! How are you?") - - async def test_siet1it2(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) - ) - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), - InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "Hello Pipecat! How are you?") - - async def test_t(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context - ) # No aggregation timeout; this tests VAD emulation - - frames_to_send = [ - SpeechControlParamsFrame(vad_params=VADParams(stop_secs=AGGREGATION_TIMEOUT)), - TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - SpeechControlParamsFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] - - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - expected_up_frames=expected_up_frames, - ) - self.check_message_content(context, 0, "Hello!") - - async def test_t_with_turn_analyzer(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(turn_emulated_vad_timeout=AGGREGATION_TIMEOUT) - ) - - frames_to_send = [ - SpeechControlParamsFrame( - vad_params=VADParams(stop_secs=0.2), - turn_params=SmartTurnParams(stop_secs=3.0), # Turn analyzer present - ), - TranscriptionFrame(text="Hello!", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - SpeechControlParamsFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] - - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - expected_up_frames=expected_up_frames, - ) - self.check_message_content(context, 0, "Hello!") - - async def test_it(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context - ) # No aggregation timeout; this tests VAD emulation - frames_to_send = [ - SpeechControlParamsFrame(vad_params=VADParams(stop_secs=AGGREGATION_TIMEOUT)), - InterimTranscriptionFrame(text="Hello ", user_id="cat", timestamp=""), - SleepFrame(), - TranscriptionFrame(text="Hello Pipecat!", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [SpeechControlParamsFrame, *self.EXPECTED_CONTEXT_FRAMES] - expected_up_frames = [EmulateUserStartedSpeakingFrame, EmulateUserStoppedSpeakingFrame] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - expected_up_frames=expected_up_frames, - ) - self.check_message_content(context, 0, "Hello Pipecat!") - - async def test_sie_delay_it(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=LLMUserAggregatorParams(aggregation_timeout=AGGREGATION_TIMEOUT) - ) - frames_to_send = [ - UserStartedSpeakingFrame(), - InterimTranscriptionFrame(text="How ", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - SleepFrame(AGGREGATION_SLEEP), - InterimTranscriptionFrame(text="are you?", user_id="cat", timestamp=""), - TranscriptionFrame(text="How are you?", user_id="cat", timestamp=""), - SleepFrame(sleep=AGGREGATION_SLEEP), - ] - expected_down_frames = [ - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "How are you?") - - async def test_min_words_interruption_strategy_one_word(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - class ContextProcessor(FrameProcessor): - def __init__(self): - super().__init__() - self.context_received = False - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - if isinstance(frame, OpenAILLMContextFrame): - self.context_received = True - - await self.push_frame(frame, direction) - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - context_processor = ContextProcessor() - pipeline = Pipeline([aggregator, context_processor]) - - frames_to_send = [ - BotStartedSpeakingFrame(), - UserStartedSpeakingFrame(), - TranscriptionFrame(text="Can", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - ] - expected_down_frames = [ - BotStartedSpeakingFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame, - ] - await run_test( - pipeline, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - pipeline_params=PipelineParams( - interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)] - ), - ) - assert not context_processor.context_received - - async def test_min_words_interruption_strategy_two_words(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - class ContextProcessor(FrameProcessor): - def __init__(self): - super().__init__() - self.context_received = False - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - if isinstance(frame, OpenAILLMContextFrame): - self.context_received = True - elif isinstance(frame, InterruptionFrame): - self.context_received = False - - await self.push_frame(frame, direction) - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - context_processor = ContextProcessor() - pipeline = Pipeline([aggregator, context_processor]) - - frames_to_send = [ - BotStartedSpeakingFrame(), - UserStartedSpeakingFrame(), - TranscriptionFrame(text="Can you", user_id="cat", timestamp=""), - SleepFrame(), - UserStoppedSpeakingFrame(), - ] - expected_up_frames = [InterruptionFrame] - expected_down_frames = [ - BotStartedSpeakingFrame, - UserStartedSpeakingFrame, - InterruptionFrame, - UserStoppedSpeakingFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - pipeline, - frames_to_send=frames_to_send, - expected_up_frames=expected_up_frames, - expected_down_frames=expected_down_frames, - pipeline_params=PipelineParams( - interruption_strategies=[MinWordsInterruptionStrategy(min_words=2)] - ), - ) - self.check_message_content(context, 0, "Can you") - # If the context is not received or it has been cleared by the - # interruption then we have an issue. - assert context_processor.context_received - - -class BaseTestAssistantContextAggregator: - CONTEXT_CLASS = None # To be set in subclasses - AGGREGATOR_CLASS = None # To be set in subclasses - EXPECTED_CONTEXT_FRAMES = None # To be set in subclasses - - def create_assistant_aggregator_params( - self, **kwargs - ) -> Optional[LLMAssistantAggregatorParams]: - return LLMAssistantAggregatorParams(**kwargs) - - def check_message_content(self, context: OpenAILLMContext, index: int, content: str): - assert context.messages[index]["content"] == content - - def check_message_multi_content( - self, context: OpenAILLMContext, content_index: int, index: int, content: str - ): - assert context.messages[index]["content"] == content - - def check_function_call_result(self, context: OpenAILLMContext, index: int, content: str): - assert json.loads(context.messages[index]["content"]) == content - - async def test_empty(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [LLMFullResponseStartFrame(), LLMFullResponseEndFrame()] - expected_down_frames = [] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - - async def test_single_text(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [ - LLMFullResponseStartFrame(), - TextFrame(text="Hello Pipecat!"), - LLMFullResponseEndFrame(), - ] - expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "Hello Pipecat!") - - async def test_multiple_text(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=self.create_assistant_aggregator_params(expect_stripped_words=False) - ) - - # The newer LLMAssistantAggregator expects TextFrames to declare - # when they include inter-frame spaces. - def make_text_frame(text: str) -> TextFrame: - frame = TextFrame(text=text) - frame.includes_inter_frame_spaces = True - return frame - - frames_to_send = [ - LLMFullResponseStartFrame(), - make_text_frame("Hello "), - make_text_frame("Pipecat. "), - make_text_frame("How are "), - make_text_frame("you?"), - LLMFullResponseEndFrame(), - ] - expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "Hello Pipecat. How are you?") - - async def test_multiple_text_stripped(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [ - LLMFullResponseStartFrame(), - TextFrame(text="Hello"), - TextFrame(text="Pipecat."), - TextFrame(text="How are"), - TextFrame(text="you?"), - LLMFullResponseEndFrame(), - ] - expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content(context, 0, "Hello Pipecat. How are you?") - - async def test_multiple_llm_responses(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=self.create_assistant_aggregator_params(expect_stripped_words=False) - ) - - # The newer LLMAssistantAggregator expects TextFrames to declare - # when they include inter-frame spaces. - def make_text_frame(text: str) -> TextFrame: - frame = TextFrame(text=text) - frame.includes_inter_frame_spaces = True - return frame - - frames_to_send = [ - LLMFullResponseStartFrame(), - make_text_frame("Hello "), - make_text_frame("Pipecat."), - LLMFullResponseEndFrame(), - LLMFullResponseStartFrame(), - make_text_frame(text="How are "), - make_text_frame(text="you?"), - LLMFullResponseEndFrame(), - ] - expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES, *self.EXPECTED_CONTEXT_FRAMES] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_multi_content(context, 0, 0, "Hello Pipecat.") - self.check_message_multi_content(context, 0, 1, "How are you?") - - async def test_multiple_llm_responses_interruption(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=self.create_assistant_aggregator_params(expect_stripped_words=False) - ) - - # The newer LLMAssistantAggregator expects TextFrames to declare - # when they include inter-frame spaces. - def make_text_frame(text: str) -> TextFrame: - frame = TextFrame(text=text) - frame.includes_inter_frame_spaces = True - return frame - - frames_to_send = [ - LLMFullResponseStartFrame(), - make_text_frame("Hello "), - make_text_frame("Pipecat."), - LLMFullResponseEndFrame(), - SleepFrame(AGGREGATION_SLEEP), - InterruptionFrame(), - LLMFullResponseStartFrame(), - make_text_frame("How are "), - make_text_frame("you?"), - LLMFullResponseEndFrame(), - ] - expected_down_frames = [ - *self.EXPECTED_CONTEXT_FRAMES, - InterruptionFrame, - *self.EXPECTED_CONTEXT_FRAMES, - ] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_multi_content(context, 0, 0, "Hello Pipecat.") - self.check_message_multi_content(context, 0, 1, "How are you?") - - async def test_function_call(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [ - FunctionCallInProgressFrame( - function_name="get_weather", - tool_call_id="1", - arguments={"location": "Los Angeles"}, - cancel_on_interruption=False, - ), - SleepFrame(), - FunctionCallResultFrame( - function_name="get_weather", - tool_call_id="1", - arguments={"location": "Los Angeles"}, - result={"conditions": "Sunny"}, - ), - ] - expected_down_frames = [] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_function_call_result(context, -1, {"conditions": "Sunny"}) - - async def test_function_call_on_context_updated(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context_updated = False - - async def on_context_updated(): - nonlocal context_updated - context_updated = True - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS(context) - frames_to_send = [ - FunctionCallInProgressFrame( - function_name="get_weather", - tool_call_id="1", - arguments={"location": "Los Angeles"}, - cancel_on_interruption=False, - ), - SleepFrame(), - FunctionCallResultFrame( - function_name="get_weather", - tool_call_id="1", - arguments={"location": "Los Angeles"}, - result={"conditions": "Sunny"}, - properties=FunctionCallResultProperties(on_context_updated=on_context_updated), - ), - SleepFrame(), - ] - expected_down_frames = [] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_function_call_result(context, -1, {"conditions": "Sunny"}) - assert context_updated - - -# -# LLMUserContextAggregator -# - - -class TestLLMUserContextAggregator(BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase): - CONTEXT_CLASS = OpenAILLMContext - AGGREGATOR_CLASS = LLMUserContextAggregator - - -# -# Anthropic -# - - -class TestAnthropicUserContextAggregator( - BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = AnthropicLLMContext - AGGREGATOR_CLASS = AnthropicUserContextAggregator - - def check_message_multi_content( - self, context: OpenAILLMContext, content_index: int, index: int, content: str - ): - messages = context.messages[content_index] - assert messages["content"][index]["text"] == content - - -class TestAnthropicAssistantContextAggregator( - BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = AnthropicLLMContext - AGGREGATOR_CLASS = AnthropicAssistantContextAggregator - EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] - - def check_message_multi_content( - self, context: OpenAILLMContext, content_index: int, index: int, content: str - ): - messages = context.messages[content_index] - assert messages["content"][index]["text"] == content - - def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any): - assert context.messages[index]["content"][0]["content"] == json.dumps(content) - - -# -# AWS (Bedrock) -# - - -class TestAWSBedrockUserContextAggregator( - BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = AWSBedrockLLMContext - AGGREGATOR_CLASS = AWSBedrockUserContextAggregator - - def check_message_multi_content( - self, context: OpenAILLMContext, content_index: int, index: int, content: str - ): - messages = context.messages[content_index] - assert messages["content"][index]["text"] == content - - -class TestAWSBedrockAssistantContextAggregator( - BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = AWSBedrockLLMContext - AGGREGATOR_CLASS = AWSBedrockAssistantContextAggregator - EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] - - def check_message_multi_content( - self, context: OpenAILLMContext, content_index: int, index: int, content: str - ): - messages = context.messages[content_index] - assert messages["content"][index]["text"] == content - - def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any): - assert context.messages[index]["content"][0]["toolResult"]["content"][0][ - "text" - ] == json.dumps(content) - - -# -# Google -# - - -class TestGoogleUserContextAggregator( - BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = GoogleLLMContext - AGGREGATOR_CLASS = GoogleUserContextAggregator - - def check_message_content(self, context: OpenAILLMContext, index: int, content: str): - obj = context.messages[index].to_json_dict() - assert obj["parts"][0]["text"] == content - - def check_message_multi_content( - self, context: OpenAILLMContext, content_index: int, index: int, content: str - ): - obj = context.messages[index].to_json_dict() - assert obj["parts"][0]["text"] == content - - -class TestGoogleAssistantContextAggregator( - BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = GoogleLLMContext - AGGREGATOR_CLASS = GoogleAssistantContextAggregator - EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] - - def check_message_content(self, context: OpenAILLMContext, index: int, content: str): - obj = context.messages[index].to_json_dict() - assert obj["parts"][0]["text"] == content - - def check_message_multi_content( - self, context: OpenAILLMContext, content_index: int, index: int, content: str - ): - obj = context.messages[index].to_json_dict() - assert obj["parts"][0]["text"] == content - - def check_function_call_result(self, context: OpenAILLMContext, index: int, content: Any): - obj = context.messages[index].to_json_dict() - assert obj["parts"][0]["function_response"]["response"]["value"] == json.dumps(content) - - -# -# OpenAI -# - - -class TestOpenAIUserContextAggregator( - BaseTestUserContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = OpenAILLMContext - AGGREGATOR_CLASS = OpenAIUserContextAggregator - - -class TestOpenAIAssistantContextAggregator( - BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = OpenAILLMContext - AGGREGATOR_CLASS = OpenAIAssistantContextAggregator - EXPECTED_CONTEXT_FRAMES = [OpenAILLMContextFrame, OpenAILLMContextAssistantTimestampFrame] - - -# -# Universal -# -class TestLLMAssistantAggregator( - BaseTestAssistantContextAggregator, unittest.IsolatedAsyncioTestCase -): - CONTEXT_CLASS = LLMContext - AGGREGATOR_CLASS = LLMAssistantAggregator - EXPECTED_CONTEXT_FRAMES = [LLMContextFrame, LLMContextAssistantTimestampFrame] - - # Override to remove 'expect_stripped_words' parameter, which is deprecated - # for LLMAssistantAggregator - def create_assistant_aggregator_params( - self, **kwargs - ) -> Optional[LLMAssistantAggregatorParams]: - kwargs.pop("expect_stripped_words", None) - return LLMAssistantAggregatorParams(**kwargs) if kwargs else None - - async def test_multiple_text_mixed(self): - assert self.CONTEXT_CLASS is not None, "CONTEXT_CLASS must be set in a subclass" - assert self.AGGREGATOR_CLASS is not None, "AGGREGATOR_CLASS must be set in a subclass" - - context = self.CONTEXT_CLASS() - aggregator = self.AGGREGATOR_CLASS( - context, params=self.create_assistant_aggregator_params(expect_stripped_words=False) - ) - - # The newer LLMAssistantAggregator expects TextFrames to declare - # when they include inter-frame spaces. - def make_text_frame(text: str, includes_spaces: bool) -> TextFrame: - frame = TextFrame(text=text) - frame.includes_inter_frame_spaces = includes_spaces - return frame - - frames_to_send = [ - LLMFullResponseStartFrame(), - make_text_frame("Hello ", includes_spaces=True), - make_text_frame("Pipecat. ", includes_spaces=True), - make_text_frame("Here's some", includes_spaces=True), - make_text_frame( - " code:", includes_spaces=True - ), # Validates ending includes_inter_frame_spaces run with no space - make_text_frame("```python\nprint('Hello, World!')\n```", includes_spaces=False), - make_text_frame( - "```javascript\nconsole.log('Hello, World!');\n```", includes_spaces=False - ), - make_text_frame( - " And some more: ", includes_spaces=True - ), # Validates starting includes_inter_frame_spaces run with a space and ending it with no space - make_text_frame("```html\n
Hello, World!
\n```", includes_spaces=False), - make_text_frame( - "Hope that ", includes_spaces=True - ), # Validates starting includes_inter_frame_spaces run with no space - make_text_frame("helps!", includes_spaces=True), - LLMFullResponseEndFrame(), - ] - expected_down_frames = [*self.EXPECTED_CONTEXT_FRAMES] - await run_test( - aggregator, - frames_to_send=frames_to_send, - expected_down_frames=expected_down_frames, - ) - self.check_message_content( - context, - 0, - "Hello Pipecat. Here's some code: ```python\nprint('Hello, World!')\n``` ```javascript\nconsole.log('Hello, World!');\n``` And some more: ```html\n
Hello, World!
\n``` Hope that helps!", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_context_aggregators_universal.py b/tests/test_context_aggregators_universal.py index 7d6f73f2d..59aab4745 100644 --- a/tests/test_context_aggregators_universal.py +++ b/tests/test_context_aggregators_universal.py @@ -4,13 +4,16 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import json import unittest from pipecat.frames.frames import ( BotStartedSpeakingFrame, BotStoppedSpeakingFrame, FunctionCallFromLLM, + FunctionCallInProgressFrame, FunctionCallResultFrame, + FunctionCallResultProperties, FunctionCallsStartedFrame, InterimTranscriptionFrame, InterruptionFrame, @@ -26,6 +29,7 @@ from pipecat.frames.frames import ( LLMThoughtStartFrame, LLMThoughtTextFrame, StartFrame, + TextFrame, TranscriptionFrame, TranslationFrame, UserMuteStartedFrame, @@ -588,6 +592,165 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase): self.assertTrue(should_stop) self.assertEqual(stop_message.content, "Hello from Pipecat!") + async def test_multiple_text_with_spaces(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + def make_text_frame(text: str) -> TextFrame: + frame = TextFrame(text=text) + frame.includes_inter_frame_spaces = True + return frame + + frames_to_send = [ + LLMFullResponseStartFrame(), + make_text_frame("Hello "), + make_text_frame("Pipecat. "), + make_text_frame("How are "), + make_text_frame("you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [LLMContextFrame, LLMContextAssistantTimestampFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == "Hello Pipecat. How are you?" + + async def test_multiple_text_stripped(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + frames_to_send = [ + LLMFullResponseStartFrame(), + TextFrame(text="Hello"), + TextFrame(text="Pipecat."), + TextFrame(text="How are"), + TextFrame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [LLMContextFrame, LLMContextAssistantTimestampFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == "Hello Pipecat. How are you?" + + async def test_multiple_text_mixed_spaces(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + def make_text_frame(text: str, includes_spaces: bool) -> TextFrame: + frame = TextFrame(text=text) + frame.includes_inter_frame_spaces = includes_spaces + return frame + + frames_to_send = [ + LLMFullResponseStartFrame(), + make_text_frame("Hello ", includes_spaces=True), + make_text_frame("Pipecat. ", includes_spaces=True), + make_text_frame("Here's some", includes_spaces=True), + make_text_frame( + " code:", includes_spaces=True + ), # Validates ending includes_inter_frame_spaces run with no space + make_text_frame("```python\nprint('Hello, World!')\n```", includes_spaces=False), + make_text_frame( + "```javascript\nconsole.log('Hello, World!');\n```", includes_spaces=False + ), + make_text_frame( + " And some more: ", includes_spaces=True + ), # Validates starting includes_inter_frame_spaces run with a space and ending it with no space + make_text_frame("```html\n
Hello, World!
\n```", includes_spaces=False), + make_text_frame( + "Hope that ", includes_spaces=True + ), # Validates starting includes_inter_frame_spaces run with no space + make_text_frame("helps!", includes_spaces=True), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [LLMContextFrame, LLMContextAssistantTimestampFrame] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == ( + "Hello Pipecat. Here's some code: " + "```python\nprint('Hello, World!')\n``` " + "```javascript\nconsole.log('Hello, World!');\n``` " + "And some more: " + "```html\n
Hello, World!
\n``` " + "Hope that helps!" + ) + + async def test_multiple_responses(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + def make_text_frame(text: str) -> TextFrame: + frame = TextFrame(text=text) + frame.includes_inter_frame_spaces = True + return frame + + frames_to_send = [ + LLMFullResponseStartFrame(), + make_text_frame("Hello "), + make_text_frame("Pipecat."), + LLMFullResponseEndFrame(), + LLMFullResponseStartFrame(), + make_text_frame(text="How are "), + make_text_frame(text="you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [ + LLMContextFrame, + LLMContextAssistantTimestampFrame, + LLMContextFrame, + LLMContextAssistantTimestampFrame, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == "Hello Pipecat." + assert context.messages[1]["content"] == "How are you?" + + async def test_multiple_responses_interruption(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + + def make_text_frame(text: str) -> TextFrame: + frame = TextFrame(text=text) + frame.includes_inter_frame_spaces = True + return frame + + frames_to_send = [ + LLMFullResponseStartFrame(), + make_text_frame("Hello "), + make_text_frame("Pipecat."), + LLMFullResponseEndFrame(), + SleepFrame(0.15), + InterruptionFrame(), + LLMFullResponseStartFrame(), + make_text_frame("How are "), + make_text_frame("you?"), + LLMFullResponseEndFrame(), + ] + expected_down_frames = [ + LLMContextFrame, + LLMContextAssistantTimestampFrame, + InterruptionFrame, + LLMContextFrame, + LLMContextAssistantTimestampFrame, + ] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert context.messages[0]["content"] == "Hello Pipecat." + assert context.messages[1]["content"] == "How are you?" + async def test_interruption(self): context = LLMContext() @@ -635,6 +798,67 @@ class TestLLMAssistantAggregator(unittest.IsolatedAsyncioTestCase): self.assertEqual(stop_messages[0].content, "Hello") self.assertEqual(stop_messages[1].content, "Hello there!") + async def test_function_call(self): + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + frames_to_send = [ + FunctionCallInProgressFrame( + function_name="get_weather", + tool_call_id="1", + arguments={"location": "Los Angeles"}, + cancel_on_interruption=False, + ), + SleepFrame(), + FunctionCallResultFrame( + function_name="get_weather", + tool_call_id="1", + arguments={"location": "Los Angeles"}, + result={"conditions": "Sunny"}, + ), + ] + expected_down_frames = [] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert json.loads(context.messages[-1]["content"]) == {"conditions": "Sunny"} + + async def test_function_call_on_context_updated(self): + context_updated = False + + async def on_context_updated(): + nonlocal context_updated + context_updated = True + + context = LLMContext() + aggregator = LLMAssistantAggregator(context) + frames_to_send = [ + FunctionCallInProgressFrame( + function_name="get_weather", + tool_call_id="1", + arguments={"location": "Los Angeles"}, + cancel_on_interruption=False, + ), + SleepFrame(), + FunctionCallResultFrame( + function_name="get_weather", + tool_call_id="1", + arguments={"location": "Los Angeles"}, + result={"conditions": "Sunny"}, + properties=FunctionCallResultProperties(on_context_updated=on_context_updated), + ), + SleepFrame(), + ] + expected_down_frames = [] + await run_test( + aggregator, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + assert json.loads(context.messages[-1]["content"]) == {"conditions": "Sunny"} + assert context_updated + async def test_thought(self): context = LLMContext() diff --git a/tests/test_novita_llm.py b/tests/test_novita_llm.py index e0f2c71b4..9aab54cab 100644 --- a/tests/test_novita_llm.py +++ b/tests/test_novita_llm.py @@ -51,8 +51,7 @@ async def test_novita_llm_stream_closed_on_cancellation(): mock_stream = MockAsyncStream() - service._stream_chat_completions_specific_context = AsyncMock(return_value=mock_stream) - service._stream_chat_completions_universal_context = AsyncMock(return_value=mock_stream) + service.get_chat_completions = AsyncMock(return_value=mock_stream) service.start_ttfb_metrics = AsyncMock() service.stop_ttfb_metrics = AsyncMock() service.start_llm_usage_metrics = AsyncMock() diff --git a/tests/test_openai_llm_timeout.py b/tests/test_openai_llm_timeout.py index 8ee776a9b..61264cbf5 100644 --- a/tests/test_openai_llm_timeout.py +++ b/tests/test_openai_llm_timeout.py @@ -171,8 +171,7 @@ async def test_openai_llm_stream_closed_on_cancellation(): mock_stream = MockAsyncStream() # Mock the stream creation methods - service._stream_chat_completions_specific_context = AsyncMock(return_value=mock_stream) - service._stream_chat_completions_universal_context = AsyncMock(return_value=mock_stream) + service.get_chat_completions = AsyncMock(return_value=mock_stream) service.start_ttfb_metrics = AsyncMock() service.stop_ttfb_metrics = AsyncMock() service.start_llm_usage_metrics = AsyncMock() @@ -281,8 +280,7 @@ async def test_openai_llm_async_iterator_closed_on_stream_end(): mock_iterator = MockAsyncIterator() mock_stream = MockAsyncStream(mock_iterator) - service._stream_chat_completions_specific_context = AsyncMock(return_value=mock_stream) - service._stream_chat_completions_universal_context = AsyncMock(return_value=mock_stream) + service.get_chat_completions = AsyncMock(return_value=mock_stream) service.start_ttfb_metrics = AsyncMock() service.stop_ttfb_metrics = AsyncMock() service.start_llm_usage_metrics = AsyncMock() diff --git a/tests/test_run_inference.py b/tests/test_run_inference.py index ae9d30f8f..2a7f7695e 100644 --- a/tests/test_run_inference.py +++ b/tests/test_run_inference.py @@ -84,61 +84,6 @@ async def test_openai_run_inference_with_llm_context(): ) -@pytest.mark.asyncio -async def test_openai_run_inference_with_openai_llm_context(): - """Test run_inference with OpenAILLMContext returns expected response.""" - # Create service with mocked client and specific parameters - with patch.object(OpenAILLMService, "create_client"): - from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext - from pipecat.services.openai.base_llm import BaseOpenAILLMService - - params = BaseOpenAILLMService.InputParams( - temperature=0.8, max_completion_tokens=150, presence_penalty=0.3, top_p=0.9 - ) - service = OpenAILLMService(model="gpt-4", params=params) - service._client = AsyncMock() - - # Create OpenAILLMContext - context = OpenAILLMContext( - messages=[ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello, world!"}, - ], - tools=OPENAI_NOT_GIVEN, - tool_choice=OPENAI_NOT_GIVEN, - ) - - # Mock response - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = "Hello! How can I help you today?" - service._client.chat.completions.create.return_value = mock_response - - # Execute - result = await service.run_inference(context) - - # Verify - assert result == "Hello! How can I help you today?" - service._client.chat.completions.create.assert_called_once_with( - model="gpt-4", - stream=False, - frequency_penalty=OPENAI_NOT_GIVEN, - presence_penalty=0.3, - seed=OPENAI_NOT_GIVEN, - temperature=0.8, - top_p=0.9, - max_tokens=OPENAI_NOT_GIVEN, - max_completion_tokens=150, - service_tier=OPENAI_NOT_GIVEN, - messages=[ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello, world!"}, - ], - tools=OPENAI_NOT_GIVEN, - tool_choice=OPENAI_NOT_GIVEN, - ) - - @pytest.mark.asyncio async def test_openai_run_inference_client_exception(): """Test that exceptions from the client are propagated.""" @@ -209,54 +154,6 @@ async def test_anthropic_run_inference_with_llm_context(): ) -@pytest.mark.asyncio -async def test_anthropic_run_inference_with_openai_llm_context(): - """Test run_inference with OpenAILLMContext returns expected response for Anthropic.""" - # Create service with mocked client and specific parameters - from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext - from pipecat.services.anthropic.llm import AnthropicLLMService - - params = AnthropicLLMService.InputParams(max_tokens=1024, temperature=0.7, top_k=40, top_p=0.9) - service = AnthropicLLMService( - api_key="test-key", model="claude-3-sonnet-20240229", params=params - ) - service._client = AsyncMock() - - # Create OpenAILLMContext - context = OpenAILLMContext( - messages=[ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello, world!"}, - ], - tools=NOT_GIVEN, - tool_choice=NOT_GIVEN, - ) - - # Mock response - mock_response = MagicMock() - mock_response.content = [MagicMock()] - mock_response.content[0].text = "Hello! How can I help you today?" - service._client.beta.messages.create.return_value = mock_response - - # Execute - result = await service.run_inference(context) - - # Verify - assert result == "Hello! How can I help you today?" - service._client.beta.messages.create.assert_called_once_with( - model="claude-3-sonnet-20240229", - max_tokens=1024, - stream=False, - temperature=0.7, - top_k=40, - top_p=0.9, - messages=[{"role": "user", "content": "Hello, world!"}], - system="You are a helpful assistant", - tools=[], - betas=["interleaved-thinking-2025-05-14"], - ) - - @pytest.mark.asyncio async def test_anthropic_run_inference_client_exception(): """Test that exceptions from the Anthropic client are propagated.""" @@ -336,61 +233,6 @@ async def test_google_run_inference_client_exception(): await service.run_inference(mock_context) -@pytest.mark.asyncio -async def test_google_run_inference_with_openai_llm_context(): - """Test run_inference with OpenAILLMContext returns expected response for Google.""" - # Create service with mocked client and specific parameters - from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext - - params = GoogleLLMService.InputParams(max_tokens=256, temperature=0.4, top_k=30, top_p=0.75) - service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash", params=params) - service._client = AsyncMock() - - # Create OpenAILLMContext - context = OpenAILLMContext( - messages=[ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello, world!"}, - ], - tools=NOT_GIVEN, - tool_choice=NOT_GIVEN, - ) - - # Mock response - mock_response = MagicMock() - mock_response.candidates = [MagicMock()] - mock_response.candidates[0].content = MagicMock() - mock_response.candidates[0].content.parts = [MagicMock()] - mock_response.candidates[0].content.parts[0].text = "Hello! How can I help you today?" - service._client.aio = AsyncMock() - service._client.aio.models = AsyncMock() - service._client.aio.models.generate_content = AsyncMock(return_value=mock_response) - - # Execute - result = await service.run_inference(context) - - # Verify - assert result == "Hello! How can I help you today?" - - # Verify the call includes configured parameters - call_kwargs = service._client.aio.models.generate_content.call_args.kwargs - assert call_kwargs["model"] == "gemini-2.0-flash" - # Contents is a Google Content object, so check its structure - contents = call_kwargs["contents"] - assert len(contents) == 1 - assert contents[0].role == "user" - assert len(contents[0].parts) == 1 - assert contents[0].parts[0].text == "Hello, world!" - assert "config" in call_kwargs - config = call_kwargs["config"] - # Config is a GenerateContentConfig object, so access attributes - assert config.system_instruction == "You are a helpful assistant" - assert config.temperature == 0.4 - assert config.top_k == 30 - assert config.top_p == 0.75 - assert config.max_output_tokens == 256 - - @pytest.mark.asyncio async def test_aws_bedrock_run_inference_with_llm_context(): """Test run_inference with LLMContext returns expected response for AWS Bedrock.""" @@ -445,57 +287,6 @@ async def test_aws_bedrock_run_inference_with_llm_context(): assert call_kwargs["inferenceConfig"]["topP"] == 0.85 -@pytest.mark.asyncio -async def test_aws_bedrock_run_inference_with_openai_llm_context(): - """Test run_inference with OpenAILLMContext returns expected response for AWS Bedrock.""" - # Create service with specific parameters - from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext - from pipecat.services.aws.llm import AWSBedrockLLMService - - params = AWSBedrockLLMService.InputParams(max_tokens=512, temperature=0.8, top_p=0.95) - service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0", params=params) - - # Create OpenAILLMContext - context = OpenAILLMContext( - messages=[ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello, world!"}, - ], - tools=NOT_GIVEN, - tool_choice=NOT_GIVEN, - ) - - # Mock the client and response - mock_client = AsyncMock() - mock_response = { - "output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}} - } - mock_client.converse.return_value = mock_response - - # Patch the _aws_session.client method to be an async context manager - mock_context_manager = AsyncMock() - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - - with patch.object(service._aws_session, "client", return_value=mock_context_manager): - # Execute - result = await service.run_inference(context) - - # Verify - assert result == "Hello! How can I help you today?" - - # Verify the call includes configured parameters - call_kwargs = mock_client.converse.call_args.kwargs - assert call_kwargs["modelId"] == "anthropic.claude-3-sonnet-20240229-v1:0" - assert call_kwargs["messages"] == [{"role": "user", "content": [{"text": "Hello, world!"}]}] - assert call_kwargs["system"] == [{"text": "You are a helpful assistant"}] - assert call_kwargs["additionalModelRequestFields"] == {} - assert "inferenceConfig" in call_kwargs - assert call_kwargs["inferenceConfig"]["maxTokens"] == 512 - assert call_kwargs["inferenceConfig"]["temperature"] == 0.8 - assert call_kwargs["inferenceConfig"]["topP"] == 0.95 - - @pytest.mark.asyncio async def test_aws_bedrock_run_inference_client_exception(): """Test that exceptions from the AWS Bedrock client are propagated.""" diff --git a/tests/test_sambanova_llm.py b/tests/test_sambanova_llm.py index 6632951fc..57e2e8c50 100644 --- a/tests/test_sambanova_llm.py +++ b/tests/test_sambanova_llm.py @@ -56,8 +56,7 @@ async def test_sambanova_llm_stream_closed_on_cancellation(): mock_stream = MockAsyncStream() - service._stream_chat_completions_specific_context = AsyncMock(return_value=mock_stream) - service._stream_chat_completions_universal_context = AsyncMock(return_value=mock_stream) + service.get_chat_completions = AsyncMock(return_value=mock_stream) service.start_ttfb_metrics = AsyncMock() service.stop_ttfb_metrics = AsyncMock() service.start_llm_usage_metrics = AsyncMock()