Skip to content

Dataset.plot.quiver #4407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Feb 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
cb4d48f
First attempt at quiver plot.
dcherian Sep 13, 2019
fbca40e
Support for quiver key
dcherian Sep 5, 2020
fc8ad67
refactor out adjusting subplots for legend
dcherian Sep 5, 2020
190d788
Adjust figure for quiverkey.
dcherian Sep 5, 2020
51cb73e
Update quiverkey
dcherian Sep 5, 2020
9cb6a61
Autoscale quiver facetgrid
dcherian Sep 5, 2020
c2dcc1d
Support for hue
dcherian Sep 5, 2020
0393e66
fix tests.
dcherian Sep 5, 2020
8116ec9
Fix test
dcherian Sep 5, 2020
04f64c5
Adding docs
dcherian Sep 20, 2020
abcc008
small fixes
dcherian Sep 20, 2020
f1da0c0
Start adding tests
dcherian Oct 8, 2020
39545b1
Simple tests.
dcherian Jan 20, 2021
b6c1bc8
[skip-ci] Start adding docs.
dcherian Jan 20, 2021
7b55945
[skip-ci] add whats-new
dcherian Jan 20, 2021
0da3540
Add quiverkey test
dcherian Jan 20, 2021
bc78582
Add test for hue_style
dcherian Jan 20, 2021
d1e028c
small doc updates.
dcherian Jan 20, 2021
94f8396
remove commented out autoscaling code
dcherian Jan 20, 2021
de48d51
Raise if u,v are provided for non-quiver plots.
dcherian Jan 20, 2021
e795672
Fix tests
dcherian Jan 20, 2021
e0f227f
Merge remote-tracking branch 'upstream/master' into dataset/quiver
dcherian Jan 29, 2021
8a3912c
[skip-ci] fix bad merge.
dcherian Jan 29, 2021
3685f14
Merge branch 'master' into dataset/quiver
dcherian Feb 10, 2021
ca00bd4
[skip-ci] update whats-new
dcherian Feb 10, 2021
d2cf58c
[skip-ci] Apply suggestions from code review
dcherian Feb 12, 2021
b9bcada
review comments.
dcherian Feb 12, 2021
2b1bc32
Merge remote-tracking branch 'upstream/master' into dataset/quiver
dcherian Feb 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ Plotting
:template: autosummary/accessor_method.rst

Dataset.plot.scatter
Dataset.plot.quiver

DataArray
=========
Expand Down
24 changes: 24 additions & 0 deletions doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,9 @@ Consider this dataset
ds


Scatter
~~~~~~~

Suppose we want to scatter ``A`` against ``B``

.. ipython:: python
Expand Down Expand Up @@ -762,6 +765,27 @@ Faceting is also possible

For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``.

Quiver
~~~~~~

Visualizing vector fields is supported with quiver plots:

.. ipython:: python
:okwarning:

@savefig ds_simple_quiver.png
ds.isel(w=1, z=1).plot.quiver(x="x", y="y", u="A", v="B")


where ``u`` and ``v`` denote the x and y direction components of the arrow vectors. Again, faceting is also possible:

.. ipython:: python
:okwarning:

@savefig ds_facet_quiver.png
ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4)

``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.

.. _plot-maps:

Expand Down
8 changes: 4 additions & 4 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,11 @@ New Features
contain missing values; 8x faster in our benchmark, and 2x faster than pandas.
(:pull:`4746`);
By `Maximilian Roos <https://github.com/max-sixty>`_.

- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables.
By `Deepak Cherian <https://github.com/dcherian>`_
- Add :py:meth:`Dataset.plot.quiver` for quiver plots with :py:class:`Dataset` variables.
By `Deepak Cherian <https://github.com/dcherian>`_.
- add ``"drop_conflicts"`` to the strategies supported by the ``combine_attrs`` kwarg
(:issue:`4749`, :pull:`4827`).
By `Justus Magin <https://github.com/keewis>`_.
By `Deepak Cherian <https://github.com/dcherian>`_.
- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims
in the form of kwargs as well as a dict, like most similar methods.
By `Maximilian Roos <https://github.com/max-sixty>`_.
Expand Down Expand Up @@ -152,6 +150,8 @@ Internal Changes
all resources. (:pull:`#4809`), By `Alessandro Amici <https://github.com/alexamici>`_.
- Ensure warnings cannot be turned into exceptions in :py:func:`testing.assert_equal` and
the other ``assert_*`` functions (:pull:`4864`). By `Mathias Hauser <https://github.com/mathause>`_.
- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables.
By `Deepak Cherian <https://github.com/dcherian>`_

.. _whats-new.0.16.2:

Expand Down
86 changes: 81 additions & 5 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .facetgrid import _easy_facetgrid
from .utils import (
_add_colorbar,
_get_nice_quiver_magnitude,
_is_numeric,
_process_cmap_cbar_kwargs,
get_axis,
Expand All @@ -17,7 +18,7 @@
_MARKERSIZE_RANGE = np.array([18.0, 72.0])


def _infer_meta_data(ds, x, y, hue, hue_style, add_guide):
def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
dvars = set(ds.variables.keys())
error_msg = " must be one of ({:s})".format(", ".join(dvars))

Expand Down Expand Up @@ -48,11 +49,24 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide):
add_colorbar = False
add_legend = False
else:
if add_guide is True:
if add_guide is True and funcname != "quiver":
raise ValueError("Cannot set add_guide when hue is None.")
add_legend = False
add_colorbar = False

if (add_guide or add_guide is None) and funcname == "quiver":
add_quiverkey = True
if hue:
add_colorbar = True
if not hue_style:
hue_style = "continuous"
elif hue_style != "continuous":
raise ValueError(
"hue_style must be 'continuous' or None for .plot.quiver"
)
else:
add_quiverkey = False

if hue_style is not None and hue_style not in ["discrete", "continuous"]:
raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.")

Expand All @@ -66,6 +80,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide):
return {
"add_colorbar": add_colorbar,
"add_legend": add_legend,
"add_quiverkey": add_quiverkey,
"hue_label": hue_label,
"hue_style": hue_style,
"xlabel": label_from_attrs(ds[x]),
Expand Down Expand Up @@ -170,6 +185,8 @@ def _dsplot(plotfunc):
ds : Dataset
x, y : str
Variable names for x, y axis.
u, v : str, optional
Variable names for quiver plots
hue: str, optional
Variable by which to color scattered points
hue_style: str, optional
Expand Down Expand Up @@ -250,6 +267,8 @@ def newplotfunc(
ds,
x=None,
y=None,
u=None,
v=None,
hue=None,
hue_style=None,
col=None,
Expand Down Expand Up @@ -282,7 +301,9 @@ def newplotfunc(
if _is_facetgrid: # facetgrid call
meta_data = kwargs.pop("meta_data")
else:
meta_data = _infer_meta_data(ds, x, y, hue, hue_style, add_guide)
meta_data = _infer_meta_data(
ds, x, y, hue, hue_style, add_guide, funcname=plotfunc.__name__
)

hue_style = meta_data["hue_style"]

Expand Down Expand Up @@ -317,13 +338,18 @@ def newplotfunc(
else:
cmap_params_subset = {}

if (u is not None or v is not None) and plotfunc.__name__ != "quiver":
raise ValueError("u, v are only allowed for quiver plots.")

primitive = plotfunc(
ds=ds,
x=x,
y=y,
ax=ax,
u=u,
v=v,
hue=hue,
hue_style=hue_style,
ax=ax,
cmap_params=cmap_params_subset,
**kwargs,
)
Expand All @@ -344,13 +370,34 @@ def newplotfunc(
cbar_kwargs["label"] = meta_data.get("hue_label", None)
_add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params)

if meta_data["add_quiverkey"]:
magnitude = _get_nice_quiver_magnitude(ds[u], ds[v])
units = ds[u].attrs.get("units", "")
ax.quiverkey(
primitive,
X=0.85,
Y=0.9,
U=magnitude,
label=f"{magnitude}\n{units}",
labelpos="E",
coordinates="figure",
)

if plotfunc.__name__ == "quiver":
title = ds[u]._title_for_slice()
else:
title = ds[x]._title_for_slice()
ax.set_title(title)

return primitive

@functools.wraps(newplotfunc)
def plotmethod(
_PlotMethods_obj,
x=None,
y=None,
u=None,
v=None,
hue=None,
hue_style=None,
col=None,
Expand Down Expand Up @@ -398,7 +445,7 @@ def plotmethod(


@_dsplot
def scatter(ds, x, y, ax, **kwargs):
def scatter(ds, x, y, ax, u, v, **kwargs):
"""
Scatter Dataset data variables against each other.
"""
Expand Down Expand Up @@ -450,3 +497,32 @@ def scatter(ds, x, y, ax, **kwargs):
)

return primitive


@_dsplot
def quiver(ds, x, y, ax, u, v, **kwargs):
""" Quiver plot with Dataset variables."""
import matplotlib as mpl

if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for quiver plots.")

x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v])

args = [x.values, y.values, u.values, v.values]
hue = kwargs.pop("hue")
cmap_params = kwargs.pop("cmap_params")

if hue:
args.append(ds[hue].values)

# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
cmap_params["norm"] = mpl.colors.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)

kwargs.pop("hue_style")
kwargs.setdefault("pivot", "middle")
hdl = ax.quiver(*args, **kwargs, **cmap_params)
return hdl
73 changes: 55 additions & 18 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ..core.formatting import format_item
from .utils import (
_get_nice_quiver_magnitude,
_infer_xy_labels,
_process_cmap_cbar_kwargs,
import_matplotlib_pyplot,
Expand Down Expand Up @@ -195,7 +196,11 @@ def __init__(
self.axes = axes
self.row_names = row_names
self.col_names = col_names

# guides
self.figlegend = None
self.quiverkey = None
self.cbar = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically backwards incompatible but I'm not sure it affects anyone.


# Next the private variables
self._single_group = single_group
Expand Down Expand Up @@ -327,14 +332,15 @@ def map_dataset(
from .dataset_plot import _infer_meta_data, _parse_size

kwargs["add_guide"] = False
kwargs["_is_facetgrid"] = True

if kwargs.get("markersize", None):
kwargs["size_mapping"] = _parse_size(
self.data[kwargs["markersize"]], kwargs.pop("size_norm", None)
)

meta_data = _infer_meta_data(self.data, x, y, hue, hue_style, add_guide)
meta_data = _infer_meta_data(
self.data, x, y, hue, hue_style, add_guide, funcname=func.__name__
)
kwargs["meta_data"] = meta_data

if hue and meta_data["hue_style"] == "continuous":
Expand All @@ -344,6 +350,12 @@ def map_dataset(
kwargs["meta_data"]["cmap_params"] = cmap_params
kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs

kwargs["_is_facetgrid"] = True

if func.__name__ == "quiver" and "scale" not in kwargs:
raise ValueError("Please provide scale.")
# TODO: come up with an algorithm for reasonable scale choice

for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
if d is not None:
Expand All @@ -365,6 +377,9 @@ def map_dataset(
elif meta_data["add_colorbar"]:
self.add_colorbar(label=self._hue_label, **cbar_kwargs)

if meta_data["add_quiverkey"]:
self.add_quiverkey(kwargs["u"], kwargs["v"])

return self

def _finalize_grid(self, *axlabels):
Expand All @@ -380,37 +395,39 @@ def _finalize_grid(self, *axlabels):

self._finalized = True

def add_legend(self, **kwargs):
figlegend = self.fig.legend(
handles=self._mappables[-1],
labels=list(self._hue_var.values),
title=self._hue_label,
loc="center right",
**kwargs,
)

self.figlegend = figlegend
def _adjust_fig_for_guide(self, guide):
# Draw the plot to set the bounding boxes correctly
self.fig.draw(self.fig.canvas.get_renderer())
renderer = self.fig.canvas.get_renderer()
self.fig.draw(renderer)

# Calculate and set the new width of the figure so the legend fits
legend_width = figlegend.get_window_extent().width / self.fig.dpi
guide_width = guide.get_window_extent(renderer).width / self.fig.dpi
figure_width = self.fig.get_figwidth()
self.fig.set_figwidth(figure_width + legend_width)
self.fig.set_figwidth(figure_width + guide_width)

# Draw the plot again to get the new transformations
self.fig.draw(self.fig.canvas.get_renderer())
self.fig.draw(renderer)

# Now calculate how much space we need on the right side
legend_width = figlegend.get_window_extent().width / self.fig.dpi
space_needed = legend_width / (figure_width + legend_width) + 0.02
guide_width = guide.get_window_extent(renderer).width / self.fig.dpi
space_needed = guide_width / (figure_width + guide_width) + 0.02
# margin = .01
# _space_needed = margin + space_needed
right = 1 - space_needed

# Place the subplot axes to give space for the legend
self.fig.subplots_adjust(right=right)

def add_legend(self, **kwargs):
self.figlegend = self.fig.legend(
handles=self._mappables[-1],
labels=list(self._hue_var.values),
title=self._hue_label,
loc="center right",
**kwargs,
)
self._adjust_fig_for_guide(self.figlegend)

def add_colorbar(self, **kwargs):
"""Draw a colorbar"""
kwargs = kwargs.copy()
Expand All @@ -426,6 +443,26 @@ def add_colorbar(self, **kwargs):
)
return self

def add_quiverkey(self, u, v, **kwargs):
kwargs = kwargs.copy()

magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v])
units = self.data[u].attrs.get("units", "")
self.quiverkey = self.axes.flat[-1].quiverkey(
self._mappables[-1],
X=0.8,
Y=0.9,
U=magnitude,
label=f"{magnitude}\n{units}",
labelpos="E",
coordinates="figure",
)

# TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0
# https://github.com/matplotlib/matplotlib/issues/18530
# self._adjust_fig_for_guide(self.quiverkey.text)
return self

def set_axis_labels(self, x_var=None, y_var=None):
"""Set axis labels on the left column and bottom row of the grid."""
if x_var is not None:
Expand Down
Loading