diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d4f299fa777..919b381dc26 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -328,9 +328,8 @@ Here's an example: @check_figures_equal() def test_my_plotting_case(): "Test that my plotting function works" - fig_ref = Figure() + fig_ref, fig_test = Figure(), Figure() fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo") - fig_test = Figure() fig_test.grdimage(grid, projection="W120/15c", cmap="geo") return fig_ref, fig_test ``` diff --git a/pygmt/helpers/testing.py b/pygmt/helpers/testing.py index 34db14ae06c..f762051d8a1 100644 --- a/pygmt/helpers/testing.py +++ b/pygmt/helpers/testing.py @@ -1,22 +1,20 @@ """ Helper functions for testing. """ - import inspect import os +import string from matplotlib.testing.compare import compare_images - from ..exceptions import GMTImageComparisonFailure -def check_figures_equal(*, tol=0.0, result_dir="result_images"): +def check_figures_equal(*, extensions=("png",), tol=0.0, result_dir="result_images"): """ Decorator for test cases that generate and compare two figures. - The decorated function must take two arguments, *fig_ref* and *fig_test*, - and draw the reference and test images on them. After the function - returns, the figures are saved and compared. + The decorated function must return two arguments, *fig_ref* and *fig_test*, + these two figures will then be saved and compared against each other. This decorator is practically identical to matplotlib's check_figures_equal function, but adapted for PyGMT figures. See also the original code at @@ -25,6 +23,8 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"): Parameters ---------- + extensions : list + The extensions to test. Default is ["png"]. tol : float The RMS threshold above which the test is considered failed. result_dir : str @@ -66,19 +66,30 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"): ... ) >>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass """ + # pylint: disable=invalid-name + ALLOWED_CHARS = set(string.digits + string.ascii_letters + "_-[]()") + KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY def decorator(func): + import pytest os.makedirs(result_dir, exist_ok=True) old_sig = inspect.signature(func) - def wrapper(*args, **kwargs): + @pytest.mark.parametrize("ext", extensions) + def wrapper(*args, ext="png", request=None, **kwargs): + if "ext" in old_sig.parameters: + kwargs["ext"] = ext + if "request" in old_sig.parameters: + kwargs["request"] = request + try: + file_name = "".join(c for c in request.node.name if c in ALLOWED_CHARS) + except AttributeError: # 'NoneType' object has no attribute 'node' + file_name = func.__name__ try: fig_ref, fig_test = func(*args, **kwargs) - ref_image_path = os.path.join( - result_dir, func.__name__ + "-expected.png" - ) - test_image_path = os.path.join(result_dir, func.__name__ + ".png") + ref_image_path = os.path.join(result_dir, f"{file_name}-expected.{ext}") + test_image_path = os.path.join(result_dir, f"{file_name}.{ext}") fig_ref.savefig(ref_image_path) fig_test.savefig(test_image_path) @@ -109,9 +120,18 @@ def wrapper(*args, **kwargs): for param in old_sig.parameters.values() if param.name not in {"fig_test", "fig_ref"} ] + if "ext" not in old_sig.parameters: + parameters += [inspect.Parameter("ext", KEYWORD_ONLY)] + if "request" not in old_sig.parameters: + parameters += [inspect.Parameter("request", KEYWORD_ONLY)] new_sig = old_sig.replace(parameters=parameters) wrapper.__signature__ = new_sig + # reach a bit into pytest internals to hoist the marks from + # our wrapped function + new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark + wrapper.pytestmark = new_marks + return wrapper return decorator