diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 47086a687b8..5546444cf57 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -82,6 +82,8 @@ Bug fixes - Allow writing NetCDF files including only dimensionless variables using the distributed or multiprocessing scheduler (:issue:`7013`, :pull:`7040`). By `Francesco Nattino `_. +- Fix bug where subplot_kwargs were not working when plotting with figsize, size or aspect (:issue:`7078`, :pull:`7080`) + By `Michael Niklas `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index ae0adfff00b..ee616b9040e 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -280,7 +280,8 @@ def plot( col_wrap : int, optional Use together with ``col`` to wrap faceted plots. ax : matplotlib axes object, optional - If ``None``, use the current axes. Not applicable when using facets. + Axes on which to plot. By default, use the current axes. + Mutually exclusive with ``size``, ``figsize`` and facets. rtol : float, optional Relative tolerance used to determine if the indexes are uniformly spaced. Usually a small positive number. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index f106d56689c..11bd66a6945 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -30,6 +30,8 @@ if TYPE_CHECKING: + from matplotlib.axes import Axes + from ..core.dataarray import DataArray @@ -423,7 +425,13 @@ def _assert_valid_xy(darray: DataArray, xy: None | Hashable, name: str) -> None: ) -def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): +def get_axis( + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + **subplot_kws: Any, +) -> Axes: try: import matplotlib as mpl import matplotlib.pyplot as plt @@ -435,28 +443,32 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): raise ValueError("cannot provide both `figsize` and `ax` arguments") if size is not None: raise ValueError("cannot provide both `figsize` and `size` arguments") - _, ax = plt.subplots(figsize=figsize) - elif size is not None: + _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws) + return ax + + if size is not None: if ax is not None: raise ValueError("cannot provide both `size` and `ax` arguments") if aspect is None: width, height = mpl.rcParams["figure.figsize"] aspect = width / height figsize = (size * aspect, size) - _, ax = plt.subplots(figsize=figsize) - elif aspect is not None: + _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws) + return ax + + if aspect is not None: raise ValueError("cannot provide `aspect` argument without `size`") - if kwargs and ax is not None: + if subplot_kws and ax is not None: raise ValueError("cannot use subplot_kws with existing ax") if ax is None: - ax = _maybe_gca(**kwargs) + ax = _maybe_gca(**subplot_kws) return ax -def _maybe_gca(**kwargs): +def _maybe_gca(**subplot_kws: Any) -> Axes: import matplotlib.pyplot as plt @@ -468,7 +480,7 @@ def _maybe_gca(**kwargs): # can not pass kwargs to active axes return plt.gca() - return plt.axes(**kwargs) + return plt.axes(**subplot_kws) def _get_units_from_attrs(da) -> str: diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index f37c2fd7508..ca530bc9cce 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2955,9 +2955,8 @@ def test_facetgrid_single_contour(): @requires_matplotlib -def test_get_axis(): - # test get_axis works with different args combinations - # and return the right type +def test_get_axis_raises(): + # test get_axis raises an error if trying to do invalid things # cannot provide both ax and figsize with pytest.raises(ValueError, match="both `figsize` and `ax`"): @@ -2975,18 +2974,68 @@ def test_get_axis(): with pytest.raises(ValueError, match="`aspect` argument without `size`"): get_axis(figsize=None, size=None, aspect=4 / 3, ax=None) + # cannot provide axis and subplot_kws + with pytest.raises(ValueError, match="cannot use subplot_kws with existing ax"): + get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5) + + +@requires_matplotlib +@pytest.mark.parametrize( + ["figsize", "size", "aspect", "ax", "kwargs"], + [ + pytest.param((3, 2), None, None, False, {}, id="figsize"), + pytest.param( + (3.5, 2.5), None, None, False, {"label": "test"}, id="figsize_kwargs" + ), + pytest.param(None, 5, None, False, {}, id="size"), + pytest.param(None, 5.5, None, False, {"label": "test"}, id="size_kwargs"), + pytest.param(None, 5, 1, False, {}, id="size+aspect"), + pytest.param(None, None, None, True, {}, id="ax"), + pytest.param(None, None, None, False, {}, id="default"), + pytest.param(None, None, None, False, {"label": "test"}, id="default_kwargs"), + ], +) +def test_get_axis( + figsize: tuple[float, float] | None, + size: float | None, + aspect: float | None, + ax: bool, + kwargs: dict[str, Any], +) -> None: with figure_context(): - ax = get_axis() - assert isinstance(ax, mpl.axes.Axes) + inp_ax = plt.axes() if ax else None + out_ax = get_axis( + figsize=figsize, size=size, aspect=aspect, ax=inp_ax, **kwargs + ) + assert isinstance(out_ax, mpl.axes.Axes) +@requires_matplotlib @requires_cartopy -def test_get_axis_cartopy(): - +@pytest.mark.parametrize( + ["figsize", "size", "aspect"], + [ + pytest.param((3, 2), None, None, id="figsize"), + pytest.param(None, 5, None, id="size"), + pytest.param(None, 5, 1, id="size+aspect"), + pytest.param(None, None, None, id="default"), + ], +) +def test_get_axis_cartopy( + figsize: tuple[float, float] | None, size: float | None, aspect: float | None +) -> None: kwargs = {"projection": cartopy.crs.PlateCarree()} with figure_context(): - ax = get_axis(**kwargs) - assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot) + out_ax = get_axis(figsize=figsize, size=size, aspect=aspect, **kwargs) + assert isinstance(out_ax, cartopy.mpl.geoaxes.GeoAxesSubplot) + + +@requires_matplotlib +def test_get_axis_current() -> None: + with figure_context(): + _, ax = plt.subplots() + out_ax = get_axis() + assert ax is out_ax @requires_matplotlib