Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ Bug fixes
- Allow writing NetCDF files including only dimensionless variables using the distributed or multiprocessing scheduler
(:issue:`7013`, :pull:`7040`).
By `Francesco Nattino <https://github.com/fnattino>`_.
- Fix bug where subplot_kwargs were not working when plotting with figsize, size or aspect (:issue:`7078`, :pull:`7080`)
By `Michael Niklas <https://github.com/headtr1ck>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
3 changes: 2 additions & 1 deletion xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ def plot(
col_wrap : int, optional
Use together with ``col`` to wrap faceted plots.
ax : matplotlib axes object, optional
If ``None``, use the current axes. Not applicable when using facets.
Axes on which to plot. By default, use the current axes.
Mutually exclusive with ``size``, ``figsize`` and facets.
rtol : float, optional
Relative tolerance used to determine if the indexes
are uniformly spaced. Usually a small positive number.
Expand Down
30 changes: 21 additions & 9 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@


if TYPE_CHECKING:
from matplotlib.axes import Axes

from ..core.dataarray import DataArray


Expand Down Expand Up @@ -423,7 +425,13 @@ def _assert_valid_xy(darray: DataArray, xy: None | Hashable, name: str) -> None:
)


def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
def get_axis(
figsize: Iterable[float] | None = None,
size: float | None = None,
aspect: float | None = None,
ax: Axes | None = None,
**subplot_kws: Any,
) -> Axes:
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
Expand All @@ -435,28 +443,32 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
raise ValueError("cannot provide both `figsize` and `ax` arguments")
if size is not None:
raise ValueError("cannot provide both `figsize` and `size` arguments")
_, ax = plt.subplots(figsize=figsize)
elif size is not None:
_, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws)
return ax

if size is not None:
if ax is not None:
raise ValueError("cannot provide both `size` and `ax` arguments")
if aspect is None:
width, height = mpl.rcParams["figure.figsize"]
aspect = width / height
figsize = (size * aspect, size)
_, ax = plt.subplots(figsize=figsize)
elif aspect is not None:
_, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws)
return ax

if aspect is not None:
raise ValueError("cannot provide `aspect` argument without `size`")

if kwargs and ax is not None:
if subplot_kws and ax is not None:
raise ValueError("cannot use subplot_kws with existing ax")

if ax is None:
ax = _maybe_gca(**kwargs)
ax = _maybe_gca(**subplot_kws)

return ax


def _maybe_gca(**kwargs):
def _maybe_gca(**subplot_kws: Any) -> Axes:

import matplotlib.pyplot as plt

Expand All @@ -468,7 +480,7 @@ def _maybe_gca(**kwargs):
# can not pass kwargs to active axes
return plt.gca()

return plt.axes(**kwargs)
return plt.axes(**subplot_kws)


def _get_units_from_attrs(da) -> str:
Expand Down
67 changes: 58 additions & 9 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2955,9 +2955,8 @@ def test_facetgrid_single_contour():


@requires_matplotlib
def test_get_axis():
# test get_axis works with different args combinations
# and return the right type
def test_get_axis_raises():
# test get_axis raises an error if trying to do invalid things

# cannot provide both ax and figsize
with pytest.raises(ValueError, match="both `figsize` and `ax`"):
Expand All @@ -2975,18 +2974,68 @@ def test_get_axis():
with pytest.raises(ValueError, match="`aspect` argument without `size`"):
get_axis(figsize=None, size=None, aspect=4 / 3, ax=None)

# cannot provide axis and subplot_kws
with pytest.raises(ValueError, match="cannot use subplot_kws with existing ax"):
get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5)


@requires_matplotlib
@pytest.mark.parametrize(
["figsize", "size", "aspect", "ax", "kwargs"],
[
pytest.param((3, 2), None, None, False, {}, id="figsize"),
pytest.param(
(3.5, 2.5), None, None, False, {"label": "test"}, id="figsize_kwargs"
),
pytest.param(None, 5, None, False, {}, id="size"),
pytest.param(None, 5.5, None, False, {"label": "test"}, id="size_kwargs"),
pytest.param(None, 5, 1, False, {}, id="size+aspect"),
pytest.param(None, None, None, True, {}, id="ax"),
pytest.param(None, None, None, False, {}, id="default"),
pytest.param(None, None, None, False, {"label": "test"}, id="default_kwargs"),
],
)
def test_get_axis(
figsize: tuple[float, float] | None,
size: float | None,
aspect: float | None,
ax: bool,
kwargs: dict[str, Any],
) -> None:
with figure_context():
ax = get_axis()
assert isinstance(ax, mpl.axes.Axes)
inp_ax = plt.axes() if ax else None
out_ax = get_axis(
figsize=figsize, size=size, aspect=aspect, ax=inp_ax, **kwargs
)
assert isinstance(out_ax, mpl.axes.Axes)


@requires_matplotlib
@requires_cartopy
def test_get_axis_cartopy():

@pytest.mark.parametrize(
["figsize", "size", "aspect"],
[
pytest.param((3, 2), None, None, id="figsize"),
pytest.param(None, 5, None, id="size"),
pytest.param(None, 5, 1, id="size+aspect"),
pytest.param(None, None, None, id="default"),
],
)
def test_get_axis_cartopy(
figsize: tuple[float, float] | None, size: float | None, aspect: float | None
) -> None:
kwargs = {"projection": cartopy.crs.PlateCarree()}
with figure_context():
ax = get_axis(**kwargs)
assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot)
out_ax = get_axis(figsize=figsize, size=size, aspect=aspect, **kwargs)
assert isinstance(out_ax, cartopy.mpl.geoaxes.GeoAxesSubplot)


@requires_matplotlib
def test_get_axis_current() -> None:
with figure_context():
_, ax = plt.subplots()
out_ax = get_axis()
assert ax is out_ax


@requires_matplotlib
Expand Down