# Source: pipecat/tests/test_run_inference.py
# Snapshot: 2026-01-30 10:07:34 -08:00 (516 lines, 21 KiB, Python)

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from anthropic import NOT_GIVEN
from openai import NotGiven
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMInvocationParams
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMInvocationParams
from pipecat.adapters.services.gemini_adapter import GeminiLLMInvocationParams
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
@pytest.mark.asyncio
async def test_openai_run_inference_with_llm_context():
    """Verify OpenAI ``run_inference`` against a mocked universal ``LLMContext``.

    The adapter is stubbed to return fixed invocation params; the test checks
    the returned text and the exact keyword arguments passed to
    ``chat.completions.create``.
    """
    from pipecat.services.openai.base_llm import BaseOpenAILLMService

    expected_reply = "Hello! How can I help you today?"
    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello, world!"},
    ]

    # Build the service with ``create_client`` patched out, then swap in an
    # AsyncMock client.
    with patch.object(OpenAILLMService, "create_client"):
        llm = OpenAILLMService(
            model="gpt-4",
            params=BaseOpenAILLMService.InputParams(
                temperature=0.7, max_tokens=100, frequency_penalty=0.5, seed=42
            ),
        )
        llm._client = AsyncMock()

    # Adapter stub that yields the canned invocation params.
    context_stub = MagicMock(spec=LLMContext)
    adapter_stub = MagicMock()
    adapter_stub.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
        messages=conversation, tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
    )
    llm.get_llm_adapter = MagicMock(return_value=adapter_stub)

    # Canned chat completion carrying the expected text.
    choice = MagicMock()
    choice.message.content = expected_reply
    completion = MagicMock()
    completion.choices = [choice]
    create_call = llm._client.chat.completions.create
    create_call.return_value = completion

    reply = await llm.run_inference(context_stub)

    assert reply == expected_reply
    llm.get_llm_adapter.assert_called_once()
    adapter_stub.get_llm_invocation_params.assert_called_once_with(context_stub)

    # Parameters not set on InputParams are expected to be forwarded as NOT_GIVEN.
    expected_kwargs = {
        "model": "gpt-4",
        "stream": False,
        "frequency_penalty": 0.5,
        "presence_penalty": OPENAI_NOT_GIVEN,
        "seed": 42,
        "temperature": 0.7,
        "top_p": OPENAI_NOT_GIVEN,
        "max_tokens": 100,
        "max_completion_tokens": OPENAI_NOT_GIVEN,
        "service_tier": OPENAI_NOT_GIVEN,
        "messages": conversation,
        "tools": OPENAI_NOT_GIVEN,
        "tool_choice": OPENAI_NOT_GIVEN,
    }
    create_call.assert_called_once_with(**expected_kwargs)
@pytest.mark.asyncio
async def test_openai_run_inference_with_openai_llm_context():
    """Verify OpenAI ``run_inference`` when given a real ``OpenAILLMContext``.

    No adapter is mocked here: the context is passed straight to the service
    and the test checks the returned text plus the keyword arguments that
    reach ``chat.completions.create``.
    """
    from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
    from pipecat.services.openai.base_llm import BaseOpenAILLMService

    expected_reply = "Hello! How can I help you today?"

    # Build the service with ``create_client`` patched out, then swap in an
    # AsyncMock client.
    with patch.object(OpenAILLMService, "create_client"):
        llm = OpenAILLMService(
            model="gpt-4",
            params=BaseOpenAILLMService.InputParams(
                temperature=0.8, max_completion_tokens=150, presence_penalty=0.3, top_p=0.9
            ),
        )
        llm._client = AsyncMock()

    openai_context = OpenAILLMContext(
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello, world!"},
        ],
        tools=OPENAI_NOT_GIVEN,
        tool_choice=OPENAI_NOT_GIVEN,
    )

    # Canned chat completion carrying the expected text.
    choice = MagicMock()
    choice.message.content = expected_reply
    completion = MagicMock()
    completion.choices = [choice]
    create_call = llm._client.chat.completions.create
    create_call.return_value = completion

    reply = await llm.run_inference(openai_context)

    assert reply == expected_reply

    # Spelled out as a separate literal so the comparison does not depend on
    # the list object held by the context.
    expected_messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello, world!"},
    ]
    expected_kwargs = {
        "model": "gpt-4",
        "stream": False,
        "frequency_penalty": OPENAI_NOT_GIVEN,
        "presence_penalty": 0.3,
        "seed": OPENAI_NOT_GIVEN,
        "temperature": 0.8,
        "top_p": 0.9,
        "max_tokens": OPENAI_NOT_GIVEN,
        "max_completion_tokens": 150,
        "service_tier": OPENAI_NOT_GIVEN,
        "messages": expected_messages,
        "tools": OPENAI_NOT_GIVEN,
        "tool_choice": OPENAI_NOT_GIVEN,
    }
    create_call.assert_called_once_with(**expected_kwargs)
@pytest.mark.asyncio
async def test_openai_run_inference_client_exception():
    """Verify that an error raised by the OpenAI client escapes ``run_inference``."""
    # Build the service with ``create_client`` patched out, then swap in an
    # AsyncMock client.
    with patch.object(OpenAILLMService, "create_client"):
        llm = OpenAILLMService(model="gpt-4")
        llm._client = AsyncMock()

    # Adapter stub returning empty invocation params.
    adapter_stub = MagicMock()
    adapter_stub.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
        messages=[], tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
    )
    llm.get_llm_adapter = MagicMock(return_value=adapter_stub)
    context_stub = MagicMock(spec=LLMContext)

    # Make the mocked completion call fail.
    create_call = llm._client.chat.completions.create
    create_call.side_effect = Exception("API Error")

    with pytest.raises(Exception, match="API Error"):
        await llm.run_inference(context_stub)
@pytest.mark.asyncio
async def test_anthropic_run_inference_with_llm_context():
    """Verify Anthropic ``run_inference`` against a mocked universal ``LLMContext``.

    The adapter is stubbed to return fixed invocation params; the test checks
    the returned text, that the adapter is queried with
    ``enable_prompt_caching=False``, and the exact keyword arguments passed to
    ``beta.messages.create``.
    """
    # AnthropicLLMService comes from the module-level import; the former
    # function-scope re-import was redundant.
    params = AnthropicLLMService.InputParams(max_tokens=2048, temperature=0.6, top_k=50, top_p=0.95)
    service = AnthropicLLMService(
        api_key="test-key", model="claude-3-sonnet-20240229", params=params
    )
    service._client = AsyncMock()

    # Adapter stub that yields the canned invocation params.
    mock_context = MagicMock(spec=LLMContext)
    mock_adapter = MagicMock()
    test_messages = [{"role": "user", "content": "Hello, world!"}]
    test_system = "You are a helpful assistant"
    mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
        messages=test_messages, system=test_system, tools=[]
    )
    service.get_llm_adapter = MagicMock(return_value=mock_adapter)

    # Canned Anthropic response with a single text content block.
    mock_response = MagicMock()
    mock_response.content = [MagicMock()]
    mock_response.content[0].text = "Hello! How can I help you today?"
    service._client.beta.messages.create.return_value = mock_response

    result = await service.run_inference(mock_context)

    assert result == "Hello! How can I help you today?"
    service.get_llm_adapter.assert_called_once()
    mock_adapter.get_llm_invocation_params.assert_called_once_with(
        mock_context, enable_prompt_caching=False
    )
    service._client.beta.messages.create.assert_called_once_with(
        model="claude-3-sonnet-20240229",
        max_tokens=2048,
        stream=False,
        temperature=0.6,
        top_k=50,
        top_p=0.95,
        messages=test_messages,
        system=test_system,
        tools=[],
        betas=["interleaved-thinking-2025-05-14"],
    )
@pytest.mark.asyncio
async def test_anthropic_run_inference_with_openai_llm_context():
    """Verify Anthropic ``run_inference`` when given an ``OpenAILLMContext``.

    The context carries a system and a user message; the test expects the
    client call to receive the system text via ``system`` and only the user
    message in ``messages``.
    """
    # OpenAILLMContext is not imported at module level, so it stays local.
    # AnthropicLLMService comes from the module-level import; the former
    # function-scope re-import was redundant.
    from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext

    params = AnthropicLLMService.InputParams(max_tokens=1024, temperature=0.7, top_k=40, top_p=0.9)
    service = AnthropicLLMService(
        api_key="test-key", model="claude-3-sonnet-20240229", params=params
    )
    service._client = AsyncMock()

    context = OpenAILLMContext(
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello, world!"},
        ],
        tools=NOT_GIVEN,
        tool_choice=NOT_GIVEN,
    )

    # Canned Anthropic response with a single text content block.
    mock_response = MagicMock()
    mock_response.content = [MagicMock()]
    mock_response.content[0].text = "Hello! How can I help you today?"
    service._client.beta.messages.create.return_value = mock_response

    result = await service.run_inference(context)

    assert result == "Hello! How can I help you today?"
    service._client.beta.messages.create.assert_called_once_with(
        model="claude-3-sonnet-20240229",
        max_tokens=1024,
        stream=False,
        temperature=0.7,
        top_k=40,
        top_p=0.9,
        messages=[{"role": "user", "content": "Hello, world!"}],
        system="You are a helpful assistant",
        tools=[],
        betas=["interleaved-thinking-2025-05-14"],
    )
@pytest.mark.asyncio
async def test_anthropic_run_inference_client_exception():
    """Verify that an error raised by the Anthropic client escapes ``run_inference``."""
    llm = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
    llm._client = AsyncMock()

    # Adapter stub returning minimal invocation params.
    adapter_stub = MagicMock()
    adapter_stub.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
        messages=[], system="Test system", tools=[]
    )
    llm.get_llm_adapter = MagicMock(return_value=adapter_stub)
    context_stub = MagicMock(spec=LLMContext)

    # Make the mocked message-creation call fail.
    create_call = llm._client.beta.messages.create
    create_call.side_effect = Exception("Anthropic API Error")

    with pytest.raises(Exception, match="Anthropic API Error"):
        await llm.run_inference(context_stub)
@pytest.mark.asyncio
async def test_google_run_inference_with_llm_context():
    """Verify Google ``run_inference`` against a mocked universal ``LLMContext``.

    The adapter is stubbed to return fixed invocation params; the test checks
    the returned text and that ``aio.models.generate_content`` is called
    exactly once.
    """
    expected_reply = "Hello! How can I help you today?"

    llm = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
    llm._client = AsyncMock()

    # Adapter stub that yields the canned invocation params.
    context_stub = MagicMock(spec=LLMContext)
    adapter_stub = MagicMock()
    conversation = [{"role": "user", "content": "Hello, world!"}]
    system_text = "You are a helpful assistant"
    adapter_stub.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
        messages=conversation, system_instruction=system_text, tools=NotGiven()
    )
    llm.get_llm_adapter = MagicMock(return_value=adapter_stub)

    # Canned response shaped as candidates[0].content.parts[0].text.
    part = MagicMock()
    part.text = expected_reply
    candidate = MagicMock()
    candidate.content = MagicMock()
    candidate.content.parts = [part]
    response = MagicMock()
    response.candidates = [candidate]

    # Wire the async client path down to generate_content.
    aio = AsyncMock()
    llm._client.aio = aio
    aio.models = AsyncMock()
    aio.models.generate_content = AsyncMock(return_value=response)

    reply = await llm.run_inference(context_stub)

    assert reply == expected_reply
    llm.get_llm_adapter.assert_called_once()
    adapter_stub.get_llm_invocation_params.assert_called_once_with(context_stub)
    llm._client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_google_run_inference_client_exception():
    """Verify that an error raised by the Google client escapes ``run_inference``."""
    llm = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
    llm._client = AsyncMock()

    # Adapter stub returning minimal invocation params.
    adapter_stub = MagicMock()
    adapter_stub.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
        messages=[], system_instruction="Test system", tools=NotGiven()
    )
    llm.get_llm_adapter = MagicMock(return_value=adapter_stub)
    context_stub = MagicMock(spec=LLMContext)

    # Wire the async client path so generate_content raises.
    aio = AsyncMock()
    llm._client.aio = aio
    aio.models = AsyncMock()
    failure = Exception("Google API Error")
    aio.models.generate_content = AsyncMock(side_effect=failure)

    with pytest.raises(Exception, match="Google API Error"):
        await llm.run_inference(context_stub)
@pytest.mark.asyncio
async def test_google_run_inference_with_openai_llm_context():
    """Verify Google ``run_inference`` when given an ``OpenAILLMContext``.

    Checks the returned text and inspects the keyword arguments passed to
    ``aio.models.generate_content``: the model name, the single user content
    entry, and the generation config built from the service's input params.
    """
    from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext

    expected_reply = "Hello! How can I help you today?"

    llm = GoogleLLMService(
        api_key="test-key",
        model="gemini-2.0-flash",
        params=GoogleLLMService.InputParams(
            max_tokens=256, temperature=0.4, top_k=30, top_p=0.75
        ),
    )
    llm._client = AsyncMock()

    openai_context = OpenAILLMContext(
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello, world!"},
        ],
        tools=NOT_GIVEN,
        tool_choice=NOT_GIVEN,
    )

    # Canned response shaped as candidates[0].content.parts[0].text.
    part = MagicMock()
    part.text = expected_reply
    candidate = MagicMock()
    candidate.content = MagicMock()
    candidate.content.parts = [part]
    response = MagicMock()
    response.candidates = [candidate]

    # Wire the async client path down to generate_content.
    aio = AsyncMock()
    llm._client.aio = aio
    aio.models = AsyncMock()
    aio.models.generate_content = AsyncMock(return_value=response)

    reply = await llm.run_inference(openai_context)

    assert reply == expected_reply

    # Inspect what was handed to the client.
    sent = llm._client.aio.models.generate_content.call_args.kwargs
    assert sent["model"] == "gemini-2.0-flash"

    # Only the user message is expected in ``contents``.
    sent_contents = sent["contents"]
    assert len(sent_contents) == 1
    user_entry = sent_contents[0]
    assert user_entry.role == "user"
    assert len(user_entry.parts) == 1
    assert user_entry.parts[0].text == "Hello, world!"

    # The system message and sampling params are expected on ``config``.
    assert "config" in sent
    sent_config = sent["config"]
    assert sent_config.system_instruction == "You are a helpful assistant"
    assert sent_config.temperature == 0.4
    assert sent_config.top_k == 30
    assert sent_config.top_p == 0.75
    assert sent_config.max_output_tokens == 256
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_with_llm_context():
    """Verify AWS Bedrock ``run_inference`` against a mocked universal ``LLMContext``.

    The adapter is stubbed to return fixed invocation params and the session's
    ``client`` factory is patched to yield a mock client through an async
    context manager. The test checks the returned text and the keyword
    arguments passed to ``converse``.
    """
    # AWSBedrockLLMService comes from the module-level import; the former
    # function-scope re-import was redundant.
    params = AWSBedrockLLMService.InputParams(max_tokens=1024, temperature=0.5, top_p=0.85)
    service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0", params=params)

    # Adapter stub that yields the canned invocation params.
    mock_context = MagicMock(spec=LLMContext)
    mock_adapter = MagicMock()
    test_messages = [{"role": "user", "content": [{"text": "Hello, world!"}]}]
    test_system = [{"text": "You are a helpful assistant"}]
    mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
        messages=test_messages, system=test_system, tools=[], tool_choice=None
    )
    service.get_llm_adapter = MagicMock(return_value=mock_adapter)

    # Mock client whose ``converse`` returns a canned response dict.
    mock_client = AsyncMock()
    mock_response = {
        "output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
    }
    mock_client.converse.return_value = mock_response

    # ``_aws_session.client(...)`` is used as an async context manager, so the
    # patched return value must provide __aenter__/__aexit__.
    mock_context_manager = AsyncMock()
    mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
    mock_context_manager.__aexit__ = AsyncMock(return_value=None)

    with patch.object(service._aws_session, "client", return_value=mock_context_manager):
        result = await service.run_inference(mock_context)

    assert result == "Hello! How can I help you today?"
    service.get_llm_adapter.assert_called_once()
    mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)

    # Verify the call includes the configured parameters.
    call_kwargs = mock_client.converse.call_args.kwargs
    assert call_kwargs["modelId"] == "anthropic.claude-3-sonnet-20240229-v1:0"
    assert call_kwargs["messages"] == test_messages
    assert call_kwargs["system"] == test_system
    assert call_kwargs["additionalModelRequestFields"] == {}
    assert "inferenceConfig" in call_kwargs
    assert call_kwargs["inferenceConfig"]["maxTokens"] == 1024
    assert call_kwargs["inferenceConfig"]["temperature"] == 0.5
    assert call_kwargs["inferenceConfig"]["topP"] == 0.85
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_with_openai_llm_context():
    """Verify AWS Bedrock ``run_inference`` when given an ``OpenAILLMContext``.

    The session's ``client`` factory is patched to yield a mock client through
    an async context manager. The test checks the returned text and that
    ``converse`` receives the user message, the system text, and the
    configured inference parameters.
    """
    # OpenAILLMContext is not imported at module level, so it stays local.
    # AWSBedrockLLMService comes from the module-level import; the former
    # function-scope re-import was redundant.
    from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext

    params = AWSBedrockLLMService.InputParams(max_tokens=512, temperature=0.8, top_p=0.95)
    service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0", params=params)

    context = OpenAILLMContext(
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello, world!"},
        ],
        tools=NOT_GIVEN,
        tool_choice=NOT_GIVEN,
    )

    # Mock client whose ``converse`` returns a canned response dict.
    mock_client = AsyncMock()
    mock_response = {
        "output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
    }
    mock_client.converse.return_value = mock_response

    # ``_aws_session.client(...)`` is used as an async context manager, so the
    # patched return value must provide __aenter__/__aexit__.
    mock_context_manager = AsyncMock()
    mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
    mock_context_manager.__aexit__ = AsyncMock(return_value=None)

    with patch.object(service._aws_session, "client", return_value=mock_context_manager):
        result = await service.run_inference(context)

    assert result == "Hello! How can I help you today?"

    # Verify the call includes the configured parameters.
    call_kwargs = mock_client.converse.call_args.kwargs
    assert call_kwargs["modelId"] == "anthropic.claude-3-sonnet-20240229-v1:0"
    assert call_kwargs["messages"] == [{"role": "user", "content": [{"text": "Hello, world!"}]}]
    assert call_kwargs["system"] == [{"text": "You are a helpful assistant"}]
    assert call_kwargs["additionalModelRequestFields"] == {}
    assert "inferenceConfig" in call_kwargs
    assert call_kwargs["inferenceConfig"]["maxTokens"] == 512
    assert call_kwargs["inferenceConfig"]["temperature"] == 0.8
    assert call_kwargs["inferenceConfig"]["topP"] == 0.95
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_client_exception():
    """Verify that an error raised by the Bedrock client escapes ``run_inference``."""
    llm = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")

    # Adapter stub returning minimal invocation params.
    adapter_stub = MagicMock()
    adapter_stub.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
        messages=[], system=[{"text": "Test system"}], tools=[], tool_choice=None
    )
    llm.get_llm_adapter = MagicMock(return_value=adapter_stub)
    context_stub = MagicMock(spec=LLMContext)

    # Mock client whose ``converse`` raises.
    failing_client = AsyncMock()
    failing_client.converse.side_effect = Exception("Bedrock API Error")

    # ``_aws_session.client(...)`` is used as an async context manager, so the
    # patched return value must provide __aenter__/__aexit__.
    client_cm = AsyncMock()
    client_cm.__aenter__ = AsyncMock(return_value=failing_client)
    client_cm.__aexit__ = AsyncMock(return_value=None)

    session_patch = patch.object(llm._aws_session, "client", return_value=client_cm)
    with session_patch, pytest.raises(Exception, match="Bedrock API Error"):
        await llm.run_inference(context_stub)
if __name__ == "__main__":
    # The tests in this module are pytest-style coroutine functions marked with
    # ``@pytest.mark.asyncio``, so run them through pytest. The previous
    # ``unittest.main()`` call raised NameError (only ``unittest.mock`` names
    # are imported here) and unittest would not collect plain test functions.
    raise SystemExit(pytest.main([__file__]))