Merge pull request #2644 from pipecat-ai/pk/run-inference-unit-tests

`run_inference` unit tests
This commit is contained in:
kompfner
2025-09-15 16:26:10 -04:00
committed by GitHub
2 changed files with 304 additions and 48 deletions

View File

@@ -811,60 +811,55 @@ class AWSBedrockLLMService(LLMService):
Returns:
The LLM's response as a string, or None if no response is generated.
"""
try:
messages = []
system = []
if isinstance(context, LLMContext):
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
messages = params["messages"]
system = params["system"] # [{"text": "system message"}]
else:
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
messages = context.messages
system = getattr(context, "system", None) # [{"text": "system message"}]
messages = []
system = []
if isinstance(context, LLMContext):
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(context)
messages = params["messages"]
system = params["system"] # [{"text": "system message"}]
else:
context = AWSBedrockLLMContext.upgrade_to_bedrock(context)
messages = context.messages
system = getattr(context, "system", None) # [{"text": "system message"}]
# Determine if we're using Claude or Nova based on model ID
model_id = self.model_name
# Determine if we're using Claude or Nova based on model ID
model_id = self.model_name
# Prepare request parameters
request_params = {
"modelId": model_id,
"messages": messages,
"inferenceConfig": {
"maxTokens": 8192,
"temperature": 0.7,
"topP": 0.9,
},
}
# Prepare request parameters
request_params = {
"modelId": model_id,
"messages": messages,
"inferenceConfig": {
"maxTokens": 8192,
"temperature": 0.7,
"topP": 0.9,
},
}
if system:
request_params["system"] = system
if system:
request_params["system"] = system
async with self._aws_session.client(
service_name="bedrock-runtime", **self._aws_params
) as client:
# Call Bedrock without streaming
response = await client.converse(**request_params)
async with self._aws_session.client(
service_name="bedrock-runtime", **self._aws_params
) as client:
# Call Bedrock without streaming
response = await client.converse(**request_params)
# Extract the response text
if (
"output" in response
and "message" in response["output"]
and "content" in response["output"]["message"]
):
content = response["output"]["message"]["content"]
if isinstance(content, list):
for item in content:
if item.get("text"):
return item["text"]
elif isinstance(content, str):
return content
# Extract the response text
if (
"output" in response
and "message" in response["output"]
and "content" in response["output"]["message"]
):
content = response["output"]["message"]["content"]
if isinstance(content, list):
for item in content:
if item.get("text"):
return item["text"]
elif isinstance(content, str):
return content
return None
except Exception as e:
logger.error(f"Bedrock summary generation failed: {e}", exc_info=True)
return None
async def _create_converse_stream(self, client, request_params):

261
tests/test_run_inference.py Normal file
View File

@@ -0,0 +1,261 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from anthropic import NOT_GIVEN
from openai import NotGiven
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMInvocationParams
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMInvocationParams
from pipecat.adapters.services.gemini_adapter import GeminiLLMInvocationParams
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.aws.llm import AWSBedrockLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
@pytest.mark.asyncio
async def test_openai_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response."""
# Create service with mocked client
with patch.object(OpenAILLMService, "create_client"):
service = OpenAILLMService(model="gpt-4")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello, world!"},
]
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
messages=test_messages, tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Hello! How can I help you today?"
service._client.chat.completions.create.return_value = mock_response
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
service._client.chat.completions.create.assert_called_once_with(
model="gpt-4",
messages=test_messages,
stream=False,
)
@pytest.mark.asyncio
async def test_openai_run_inference_client_exception():
"""Test that exceptions from the client are propagated."""
with patch.object(OpenAILLMService, "create_client"):
service = OpenAILLMService(model="gpt-4")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = OpenAILLMInvocationParams(
messages=[], tools=OPENAI_NOT_GIVEN, tool_choice=OPENAI_NOT_GIVEN
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.chat.completions.create.side_effect = Exception("API Error")
with pytest.raises(Exception, match="API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_anthropic_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for Anthropic."""
# Create service with mocked client
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": "Hello, world!"}]
test_system = "You are a helpful assistant"
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
messages=test_messages, system=test_system, tools=[]
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Hello! How can I help you today?"
service._client.messages.create.return_value = mock_response
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(
mock_context, enable_prompt_caching=False
)
service._client.messages.create.assert_called_once_with(
model="claude-3-sonnet-20240229",
messages=test_messages,
system=test_system,
max_tokens=8192,
stream=False,
)
@pytest.mark.asyncio
async def test_anthropic_run_inference_client_exception():
"""Test that exceptions from the Anthropic client are propagated."""
service = AnthropicLLMService(api_key="test-key", model="claude-3-sonnet-20240229")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = AnthropicLLMInvocationParams(
messages=[], system="Test system", tools=[]
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.messages.create.side_effect = Exception("Anthropic API Error")
with pytest.raises(Exception, match="Anthropic API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_google_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for Google."""
# Create service with mocked client
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
service._client = AsyncMock()
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": "Hello, world!"}]
test_system = "You are a helpful assistant"
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
messages=test_messages, system_instruction=test_system, tools=NotGiven()
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock response
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content = MagicMock()
mock_response.candidates[0].content.parts = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = "Hello! How can I help you today?"
service._client.aio = AsyncMock()
service._client.aio.models = AsyncMock()
service._client.aio.models.generate_content = AsyncMock(return_value=mock_response)
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
service._client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_google_run_inference_client_exception():
"""Test that exceptions from the Google client are propagated."""
service = GoogleLLMService(api_key="test-key", model="gemini-2.0-flash")
service._client = AsyncMock()
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = GeminiLLMInvocationParams(
messages=[], system_instruction="Test system", tools=NotGiven()
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
service._client.aio = AsyncMock()
service._client.aio.models = AsyncMock()
service._client.aio.models.generate_content = AsyncMock(
side_effect=Exception("Google API Error")
)
with pytest.raises(Exception, match="Google API Error"):
await service.run_inference(mock_context)
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_with_llm_context():
"""Test run_inference with LLMContext returns expected response for AWS Bedrock."""
# Create service and patch the session client method
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
# Setup mocks
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
test_messages = [{"role": "user", "content": [{"text": "Hello, world!"}]}]
test_system = [{"text": "You are a helpful assistant"}]
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
messages=test_messages, system=test_system, tools=[], tool_choice=None
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock the client and response
mock_client = AsyncMock()
mock_response = {
"output": {"message": {"content": [{"text": "Hello! How can I help you today?"}]}}
}
mock_client.converse.return_value = mock_response
# Patch the _aws_session.client method to be an async context manager
async def mock_client_cm(*args, **kwargs):
return mock_client
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
# Execute
result = await service.run_inference(mock_context)
# Verify
assert result == "Hello! How can I help you today?"
service.get_llm_adapter.assert_called_once()
mock_adapter.get_llm_invocation_params.assert_called_once_with(mock_context)
mock_client.converse.assert_called_once()
@pytest.mark.asyncio
async def test_aws_bedrock_run_inference_client_exception():
"""Test that exceptions from the AWS Bedrock client are propagated."""
service = AWSBedrockLLMService(model="anthropic.claude-3-sonnet-20240229-v1:0")
mock_context = MagicMock(spec=LLMContext)
mock_adapter = MagicMock()
mock_adapter.get_llm_invocation_params.return_value = AWSBedrockLLMInvocationParams(
messages=[], system=[{"text": "Test system"}], tools=[], tool_choice=None
)
service.get_llm_adapter = MagicMock(return_value=mock_adapter)
# Mock AWS client to raise exception
mock_client = AsyncMock()
mock_client.converse.side_effect = Exception("Bedrock API Error")
# Patch the _aws_session.client method to be an async context manager
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
with patch.object(service._aws_session, "client", return_value=mock_context_manager):
with pytest.raises(Exception, match="Bedrock API Error"):
await service.run_inference(mock_context)