Compare commits
11 Commits
mb/runner-
...
cb/test-se
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33d813ed8f | ||
|
|
20f4b0e8ff | ||
|
|
6feaf91789 | ||
|
|
91d3ae07b3 | ||
|
|
71841f71ef | ||
|
|
949b807023 | ||
|
|
4ad15f9a01 | ||
|
|
99d94fc625 | ||
|
|
a3d630c0d1 | ||
|
|
04b482c445 | ||
|
|
60e9817f16 |
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue in the runner where starting a DailyTransport room via
|
||||
`/start` didn't support using the `DAILY_SAMPLE_ROOM_URL` env var.
|
||||
|
||||
- Fixed an issue in `ServiceSwitcher` where the `STTService`s would result in
|
||||
all STT services producing `TranscriptionFrame`s.
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ class AnthropicLLMAdapter(BaseLLMAdapter[AnthropicLLMInvocationParams]):
|
||||
system = NOT_GIVEN
|
||||
messages = []
|
||||
|
||||
# first, map messages using self._from_universal_context_message(m)
|
||||
# First, map messages using self._from_universal_context_message(m)
|
||||
try:
|
||||
messages = [self._from_universal_context_message(m) for m in universal_context_messages]
|
||||
except Exception as e:
|
||||
|
||||
@@ -107,7 +107,7 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter[AWSBedrockLLMInvocationParams]):
|
||||
system = None
|
||||
messages = []
|
||||
|
||||
# first, map messages using self._from_universal_context_message(m)
|
||||
# First, map messages using self._from_universal_context_message(m)
|
||||
try:
|
||||
messages = [self._from_universal_context_message(m) for m in universal_context_messages]
|
||||
except Exception as e:
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict
|
||||
|
||||
from loguru import logger
|
||||
from openai import NotGiven
|
||||
@@ -133,6 +133,28 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
messages: List[Content]
|
||||
system_instruction: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class MessageConversionResult:
|
||||
"""Result of converting a single universal context message to Google format.
|
||||
|
||||
Either content (a Google Content object) or a system instruction string
|
||||
is guaranteed to be set.
|
||||
|
||||
Also returns a tool call ID to name mapping for any tool calls
|
||||
discovered in the message.
|
||||
"""
|
||||
|
||||
content: Optional[Content] = None
|
||||
system_instruction: Optional[str] = None
|
||||
tool_call_id_to_name_mapping: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@dataclass
|
||||
class MessageConversionParams:
|
||||
"""Parameters for converting a single universal context message to Google format."""
|
||||
|
||||
already_have_system_instruction: bool
|
||||
tool_call_id_to_name_mapping: Dict[str, str]
|
||||
|
||||
def _from_universal_context_messages(
|
||||
self, universal_context_messages: List[LLMContextMessage]
|
||||
) -> ConvertedMessages:
|
||||
@@ -156,24 +178,26 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
"""
|
||||
system_instruction = None
|
||||
messages = []
|
||||
tool_call_id_to_name_mapping = {}
|
||||
|
||||
# Process each message, preserving Google-formatted messages and converting others
|
||||
for message in universal_context_messages:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
# Assume that LLMSpecificMessage wraps a message in Google format
|
||||
messages.append(message.message)
|
||||
continue
|
||||
|
||||
# Convert standard format to Google format
|
||||
converted = self._from_standard_message(
|
||||
message, already_have_system_instruction=bool(system_instruction)
|
||||
result = self._from_universal_context_message(
|
||||
message,
|
||||
params=self.MessageConversionParams(
|
||||
already_have_system_instruction=bool(system_instruction),
|
||||
tool_call_id_to_name_mapping=tool_call_id_to_name_mapping,
|
||||
),
|
||||
)
|
||||
if isinstance(converted, Content):
|
||||
# Regular (non-system) message
|
||||
messages.append(converted)
|
||||
else:
|
||||
# System instruction
|
||||
system_instruction = converted
|
||||
# Each result is either a Content or a system instruction
|
||||
if result.content:
|
||||
messages.append(result.content)
|
||||
elif result.system_instruction:
|
||||
system_instruction = result.system_instruction
|
||||
|
||||
# Merge tool call ID to name mapping
|
||||
if result.tool_call_id_to_name_mapping:
|
||||
tool_call_id_to_name_mapping.update(result.tool_call_id_to_name_mapping)
|
||||
|
||||
# Check if we only have function-related messages (no regular text)
|
||||
has_regular_messages = any(
|
||||
@@ -193,9 +217,16 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
|
||||
return self.ConvertedMessages(messages=messages, system_instruction=system_instruction)
|
||||
|
||||
def _from_universal_context_message(
|
||||
self, message: LLMContextMessage, *, params: MessageConversionParams
|
||||
) -> MessageConversionResult:
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
return self.MessageConversionResult(content=message.message)
|
||||
return self._from_standard_message(message, params=params)
|
||||
|
||||
def _from_standard_message(
|
||||
self, message: LLMStandardMessage, already_have_system_instruction: bool
|
||||
) -> Content | str:
|
||||
self, message: LLMStandardMessage, *, params: MessageConversionParams
|
||||
) -> MessageConversionResult:
|
||||
"""Convert standard universal context message to Google Content object.
|
||||
|
||||
Handles conversion of text, images, and function calls to Google's
|
||||
@@ -205,10 +236,11 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Args:
|
||||
message: Message in standard universal context format.
|
||||
already_have_system_instruction: Whether we already have a system instruction
|
||||
params: Parameters for conversion.
|
||||
|
||||
Returns:
|
||||
Content object with role and parts, or a plain string for system
|
||||
messages.
|
||||
MessageConversionResult containing either a Content object or a
|
||||
system instruction string.
|
||||
|
||||
Examples:
|
||||
Standard text message::
|
||||
@@ -242,38 +274,48 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
Converts to Google Content with::
|
||||
|
||||
Content(
|
||||
role="model",
|
||||
role="user",
|
||||
parts=[Part(function_call=FunctionCall(name="search", args={"query": "test"}))]
|
||||
)
|
||||
"""
|
||||
role = message["role"]
|
||||
content = message.get("content", [])
|
||||
|
||||
if role == "system":
|
||||
if already_have_system_instruction:
|
||||
if params.already_have_system_instruction:
|
||||
role = "user" # Convert system message to user role if we already have a system instruction
|
||||
else:
|
||||
# System instructions are returned as plain text
|
||||
system_instruction: str = None
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
system_instruction = content
|
||||
elif isinstance(content, list):
|
||||
# If content is a list, we assume it's a list of text parts, per the standard
|
||||
return " ".join(part["text"] for part in content if part.get("type") == "text")
|
||||
system_instruction = " ".join(
|
||||
part["text"] for part in content if part.get("type") == "text"
|
||||
)
|
||||
if system_instruction:
|
||||
return self.MessageConversionResult(system_instruction=system_instruction)
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
tool_call_id_to_name_mapping = {}
|
||||
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
id = tc["id"]
|
||||
name = tc["function"]["name"]
|
||||
tool_call_id_to_name_mapping[id] = name
|
||||
parts.append(
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
name=name,
|
||||
args=json.loads(tc["function"]["arguments"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
role = "user"
|
||||
try:
|
||||
response = json.loads(message["content"])
|
||||
if isinstance(response, dict):
|
||||
@@ -284,12 +326,17 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
# Response might not be JSON-deserializable.
|
||||
# This occurs with a UserImageFrame, for example, where we get a plain "COMPLETED" string.
|
||||
response_dict = {"value": message["content"]}
|
||||
|
||||
# Get function name from mapping using tool_call_id, or fallback
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
function_name = "tool_call_result" # Default fallback
|
||||
if tool_call_id and tool_call_id in params.tool_call_id_to_name_mapping:
|
||||
function_name = params.tool_call_id_to_name_mapping[tool_call_id]
|
||||
|
||||
parts.append(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name="tool_call_result", # seems to work to hard-code the same name every time
|
||||
response=response_dict,
|
||||
)
|
||||
Part.from_function_response(
|
||||
name=function_name,
|
||||
response=response_dict,
|
||||
)
|
||||
)
|
||||
elif isinstance(content, str):
|
||||
@@ -312,4 +359,7 @@ class GeminiLLMAdapter(BaseLLMAdapter[GeminiLLMInvocationParams]):
|
||||
audio_bytes = base64.b64decode(input_audio["data"])
|
||||
parts.append(Part(inline_data=Blob(mime_type="audio/wav", data=audio_bytes)))
|
||||
|
||||
return Content(role=role, parts=parts)
|
||||
return self.MessageConversionResult(
|
||||
content=Content(role=role, parts=parts),
|
||||
tool_call_id_to_name_mapping=tool_call_id_to_name_mapping,
|
||||
)
|
||||
|
||||
@@ -189,7 +189,7 @@ class TaskObserver(BaseObserver):
|
||||
if isinstance(data, FramePushed):
|
||||
if on_push_frame_deprecated:
|
||||
await observer.on_push_frame(
|
||||
data.src, data.dst, data.frame, data.direction, data.timestamp
|
||||
data.source, data.destination, data.frame, data.direction, data.timestamp
|
||||
)
|
||||
else:
|
||||
await observer.on_push_frame(data)
|
||||
|
||||
@@ -573,8 +573,14 @@ def _setup_daily_routes(app: FastAPI):
|
||||
|
||||
bot_module = _get_bot_module()
|
||||
|
||||
existing_room_url = os.getenv("DAILY_SAMPLE_ROOM_URL")
|
||||
|
||||
result = None
|
||||
if create_daily_room:
|
||||
|
||||
# Configure room if:
|
||||
# 1. Explicitly requested via createDailyRoom in payload
|
||||
# 2. Using pre-configured room from DAILY_SAMPLE_ROOM_URL env var
|
||||
if create_daily_room or existing_room_url:
|
||||
import aiohttp
|
||||
|
||||
from pipecat.runner.daily import configure
|
||||
|
||||
@@ -1034,6 +1034,23 @@ class GoogleLLMService(LLMService):
|
||||
if context:
|
||||
await self._process_context(context)
|
||||
|
||||
async def stop(self, frame):
|
||||
"""Override stop to gracefully close the client."""
|
||||
await super().stop(frame)
|
||||
await self._close_client()
|
||||
|
||||
async def cancel(self, frame):
|
||||
"""Override cancel to gracefully close the client."""
|
||||
await super().cancel(frame)
|
||||
await self._close_client()
|
||||
|
||||
async def _close_client(self):
|
||||
try:
|
||||
await self._client.aio.aclose()
|
||||
except Exception:
|
||||
# Do nothing - we're shutting down anyway
|
||||
pass
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Public testing API for Pipecat frame processors."""
|
||||
|
||||
from .serialization import dict_to_frame, frame_to_dict, load_frames_from_json
|
||||
from .test_runner import run_test_from_file
|
||||
|
||||
__all__ = ["dict_to_frame", "frame_to_dict", "load_frames_from_json", "run_test_from_file"]
|
||||
|
||||
150
src/pipecat/tests/serialization.py
Normal file
150
src/pipecat/tests/serialization.py
Normal file
@@ -0,0 +1,150 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Frame serialization and deserialization for testing."""
|
||||
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pipecat.frames import frames
|
||||
|
||||
|
||||
def _get_frame_class(frame_type: str):
|
||||
"""Get a frame class by name from the frames module.
|
||||
|
||||
Args:
|
||||
frame_type: The name of the frame class (e.g., "TextFrame")
|
||||
|
||||
Returns:
|
||||
The frame class object
|
||||
|
||||
Raises:
|
||||
ValueError: If the frame type is not found
|
||||
"""
|
||||
if not hasattr(frames, frame_type):
|
||||
raise ValueError(f"Unknown frame type: {frame_type}")
|
||||
|
||||
cls = getattr(frames, frame_type)
|
||||
if not inspect.isclass(cls) or not issubclass(cls, frames.Frame):
|
||||
raise ValueError(f"{frame_type} is not a valid Frame class")
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
def dict_to_frame(data: Dict[str, Any]) -> frames.Frame:
|
||||
"""Convert a dictionary to a Frame object.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing frame data with a "type" key
|
||||
|
||||
Returns:
|
||||
A Frame instance
|
||||
|
||||
Raises:
|
||||
ValueError: If frame type is missing or invalid
|
||||
|
||||
Example:
|
||||
>>> dict_to_frame({"type": "TextFrame", "text": "hello"})
|
||||
TextFrame(text="hello")
|
||||
"""
|
||||
if "type" not in data:
|
||||
raise ValueError("Frame dictionary must contain a 'type' field")
|
||||
|
||||
frame_type = data["type"]
|
||||
frame_cls = _get_frame_class(frame_type)
|
||||
|
||||
# Build kwargs from data, excluding 'type'
|
||||
kwargs = {k: v for k, v in data.items() if k != "type"}
|
||||
|
||||
# Special handling for audio frames with base64 encoded audio
|
||||
if "audio" in kwargs and isinstance(kwargs["audio"], str):
|
||||
kwargs["audio"] = base64.b64decode(kwargs["audio"])
|
||||
|
||||
# Special handling for image frames with base64 encoded images
|
||||
if "image" in kwargs and isinstance(kwargs["image"], str):
|
||||
kwargs["image"] = base64.b64decode(kwargs["image"])
|
||||
|
||||
try:
|
||||
return frame_cls(**kwargs)
|
||||
except TypeError as e:
|
||||
raise ValueError(f"Failed to create {frame_type}: {e}")
|
||||
|
||||
|
||||
def load_frames_from_json(filepath: str) -> List[frames.Frame]:
|
||||
"""Load frames from a JSON file.
|
||||
|
||||
Args:
|
||||
filepath: Path to JSON file containing frame data
|
||||
|
||||
Returns:
|
||||
List of Frame objects
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file doesn't exist
|
||||
ValueError: If JSON is invalid or frames cannot be deserialized
|
||||
|
||||
Example JSON format:
|
||||
{
|
||||
"input_frames": [
|
||||
{"type": "TextFrame", "text": "hello"},
|
||||
{"type": "EndFrame"}
|
||||
]
|
||||
}
|
||||
"""
|
||||
path = Path(filepath)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Frame file not found: {filepath}")
|
||||
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("JSON must contain a dictionary")
|
||||
|
||||
if "input_frames" not in data:
|
||||
raise ValueError("JSON must contain an 'input_frames' key")
|
||||
|
||||
frame_dicts = data["input_frames"]
|
||||
if not isinstance(frame_dicts, list):
|
||||
raise ValueError("'input_frames' must be a list")
|
||||
|
||||
return [dict_to_frame(frame_dict) for frame_dict in frame_dicts]
|
||||
|
||||
|
||||
def frame_to_dict(frame: frames.Frame) -> Dict[str, Any]:
|
||||
"""Convert a Frame object to a dictionary.
|
||||
|
||||
Args:
|
||||
frame: Frame object to serialize
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the frame
|
||||
|
||||
Example:
|
||||
>>> frame_to_dict(TextFrame(text="hello"))
|
||||
{"type": "TextFrame", "text": "hello"}
|
||||
"""
|
||||
result = {"type": frame.__class__.__name__}
|
||||
|
||||
# Get all fields from the dataclass
|
||||
if hasattr(frame, "__dataclass_fields__"):
|
||||
for field_name in frame.__dataclass_fields__:
|
||||
# Skip internal fields from base Frame class
|
||||
if field_name in ("id", "name", "pts", "metadata", "transport_source", "transport_destination"):
|
||||
continue
|
||||
|
||||
value = getattr(frame, field_name, None)
|
||||
if value is not None:
|
||||
# Special handling for bytes (audio/image data)
|
||||
if isinstance(value, bytes):
|
||||
result[field_name] = base64.b64encode(value).decode("utf-8")
|
||||
else:
|
||||
result[field_name] = value
|
||||
|
||||
return result
|
||||
169
src/pipecat/tests/test_runner.py
Normal file
169
src/pipecat/tests/test_runner.py
Normal file
@@ -0,0 +1,169 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Test runner for frame processors from JSON test files."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
from .serialization import dict_to_frame, frame_to_dict, load_frames_from_json
|
||||
|
||||
|
||||
async def run_test_from_file(
|
||||
processor: FrameProcessor,
|
||||
test_file: str,
|
||||
) -> Tuple[List[Frame], Optional[List[Dict[str, Any]]], bool]:
|
||||
"""Run a processor test from a JSON test file.
|
||||
|
||||
Args:
|
||||
processor: The frame processor to test
|
||||
test_file: Path to JSON test file
|
||||
|
||||
Returns:
|
||||
Tuple of (output_frames, expected_output, passed)
|
||||
- output_frames: List of Frame objects that were output
|
||||
- expected_output: List of expected frame dicts (None if not specified)
|
||||
- passed: True if test passed, False if failed, None if no validation
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If test file doesn't exist
|
||||
ValueError: If test file is invalid
|
||||
|
||||
Example test file format:
|
||||
{
|
||||
"input_frames": [
|
||||
{"type": "TextFrame", "text": "hello"}
|
||||
],
|
||||
"expected_output": [
|
||||
{"type": "TextFrame"},
|
||||
{"type": "EndFrame"}
|
||||
]
|
||||
}
|
||||
"""
|
||||
path = Path(test_file)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Test file not found: {test_file}")
|
||||
|
||||
with open(path, "r") as f:
|
||||
test_data = json.load(f)
|
||||
|
||||
# Load input frames
|
||||
if "input_frames" not in test_data:
|
||||
raise ValueError("Test file must contain 'input_frames'")
|
||||
|
||||
input_frames = [dict_to_frame(frame_dict) for frame_dict in test_data["input_frames"]]
|
||||
|
||||
# Load expected output (optional)
|
||||
expected_output = test_data.get("expected_output", None)
|
||||
|
||||
# Run the test
|
||||
# Note: run_test() only collects frames if expected_down_frames is provided,
|
||||
# so we need to manually collect from the pipeline ourselves
|
||||
import asyncio
|
||||
from pipecat.frames.frames import EndFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.tests.utils import QueuedFrameProcessor
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineTask, PipelineParams
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
||||
# Set up the test pipeline manually
|
||||
received_down = asyncio.Queue()
|
||||
received_up = asyncio.Queue()
|
||||
source = QueuedFrameProcessor(
|
||||
queue=received_up,
|
||||
queue_direction=FrameDirection.UPSTREAM,
|
||||
ignore_start=True,
|
||||
)
|
||||
sink = QueuedFrameProcessor(
|
||||
queue=received_down,
|
||||
queue_direction=FrameDirection.DOWNSTREAM,
|
||||
ignore_start=True,
|
||||
)
|
||||
|
||||
pipeline = Pipeline([source, processor, sink])
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(),
|
||||
observers=[],
|
||||
cancel_on_idle_timeout=False,
|
||||
)
|
||||
|
||||
async def push_frames():
|
||||
await asyncio.sleep(0.01)
|
||||
for frame in input_frames:
|
||||
await task.queue_frame(frame)
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
await asyncio.gather(runner.run(task), push_frames())
|
||||
|
||||
# Collect all frames from the downstream queue
|
||||
downstream_frames = []
|
||||
while not received_down.empty():
|
||||
frame = await received_down.get()
|
||||
if not isinstance(frame, EndFrame):
|
||||
downstream_frames.append(frame)
|
||||
|
||||
# Validate if expected_output is provided
|
||||
passed = None
|
||||
if expected_output is not None:
|
||||
passed = _validate_output(downstream_frames, expected_output)
|
||||
|
||||
return downstream_frames, expected_output, passed
|
||||
|
||||
|
||||
def _validate_output(actual_frames: List[Frame], expected_output: List[Dict[str, Any]]) -> bool:
|
||||
"""Validate actual output frames against expected output.
|
||||
|
||||
Args:
|
||||
actual_frames: List of frames that were actually output
|
||||
expected_output: List of expected frame specifications
|
||||
|
||||
Returns:
|
||||
True if validation passed, False otherwise
|
||||
"""
|
||||
if len(actual_frames) != len(expected_output):
|
||||
return False
|
||||
|
||||
for actual, expected in zip(actual_frames, expected_output):
|
||||
# Check frame type
|
||||
if "type" not in expected:
|
||||
return False
|
||||
|
||||
expected_type = expected["type"]
|
||||
if actual.__class__.__name__ != expected_type:
|
||||
return False
|
||||
|
||||
# Check specific fields if provided
|
||||
for field_name, expected_value in expected.items():
|
||||
if field_name == "type":
|
||||
continue
|
||||
|
||||
if not hasattr(actual, field_name):
|
||||
return False
|
||||
|
||||
actual_value = getattr(actual, field_name)
|
||||
|
||||
# Special handling for different types
|
||||
if isinstance(expected_value, str) and isinstance(actual_value, str):
|
||||
# For string fields, support partial matching with "contains"
|
||||
if field_name.endswith("_contains"):
|
||||
base_field = field_name.replace("_contains", "")
|
||||
if hasattr(actual, base_field):
|
||||
actual_text = getattr(actual, base_field)
|
||||
if expected_value not in actual_text:
|
||||
return False
|
||||
elif actual_value != expected_value:
|
||||
return False
|
||||
elif actual_value != expected_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user