From cb4d48f53b1ce4aa0949623cc86b349cb300f534 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 13 Sep 2019 15:10:01 -0600 Subject: [PATCH 01/25] First attempt at quiver plot. --- xarray/plot/dataset_plot.py | 39 +++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 6d942e1b0fa..8a4bba912d5 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -170,6 +170,8 @@ def _dsplot(plotfunc): ds : Dataset x, y : str Variable names for x, y axis. + u, v : string + Variable names for quiver plots hue: str, optional Variable by which to color scattered points hue_style: str, optional @@ -250,6 +252,8 @@ def newplotfunc( ds, x=None, y=None, + u=None, + v=None, hue=None, hue_style=None, col=None, @@ -321,9 +325,11 @@ def newplotfunc( ds=ds, x=x, y=y, + ax=ax, + u=u, + v=v, hue=hue, hue_style=hue_style, - ax=ax, cmap_params=cmap_params_subset, **kwargs, ) @@ -351,6 +357,8 @@ def plotmethod( _PlotMethods_obj, x=None, y=None, + u=None, + v=None, hue=None, hue_style=None, col=None, @@ -398,7 +406,34 @@ def plotmethod( @_dsplot -def scatter(ds, x, y, ax, **kwargs): +def quiver(ds, x, y, ax, u, v, **kwargs): + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for quiver plots.") + + x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + + kwargs.pop("cmap_params") + kwargs.pop("hue") + kwargs.pop("hue_style") + hdl = ax.quiver(x.values, y.values, u.values, v.values, **kwargs) + + # if plotfunc.__name__ == "quiver": + # trans = ax._set_transform() + # span, _y = trans.inverted().transform_point( + # (ax.bbox.width, ax.bbox.height)) + # # matplotlib autoscaling algorithm + # scale = kwargs.pop("scale") + # if scale is None: + # npts = ds.dims[x] * ds.dims[y] + # # crude auto-scaling + # # scale is typical arrow length as a multiple of the arrow width + # scale = 1.8 * ds.mean() * np.max(10, np.sqrt(npts)) / span + + return hdl + + +@_dsplot +def scatter(ds, x, y, ax, u, v, **kwargs): """ Scatter Dataset data variables against each other. """ From fbca40ea2ca936a9211e5243398ec013864dbfd6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 5 Sep 2020 13:25:21 -0600 Subject: [PATCH 02/25] Support for quiver key --- xarray/plot/dataset_plot.py | 32 +++++++++++++++++++++++++++++--- xarray/plot/facetgrid.py | 27 ++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 8a4bba912d5..5f4af6f9c2e 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -1,5 +1,6 @@ import functools +import matplotlib as mpl import numpy as np import pandas as pd @@ -17,7 +18,7 @@ _MARKERSIZE_RANGE = np.array([18.0, 72.0]) -def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): +def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): dvars = set(ds.variables.keys()) error_msg = " must be one of ({:s})".format(", ".join(dvars)) @@ -48,11 +49,16 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): add_colorbar = False add_legend = False else: - if add_guide is True: + if add_guide is True and funcname != "quiver": raise ValueError("Cannot set add_guide when hue is None.") add_legend = False add_colorbar = False + if (add_guide or add_guide is None) and funcname == "quiver": + add_quiverkey = True + else: + add_quiverkey = False + if hue_style is not None and hue_style not in ["discrete", "continuous"]: raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") @@ -66,6 +72,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): return { "add_colorbar": add_colorbar, "add_legend": add_legend, + "add_quiverkey": add_quiverkey, "hue_label": hue_label, "hue_style": hue_style, "xlabel": label_from_attrs(ds[x]), @@ -286,7 +293,9 @@ def newplotfunc( if _is_facetgrid: # facetgrid call meta_data = kwargs.pop("meta_data") else: - meta_data = _infer_meta_data(ds, x, y, hue, hue_style, add_guide) + meta_data = _infer_meta_data( + ds, x, y, hue, hue_style, add_guide, funcname=plotfunc.__name__ + ) hue_style = meta_data["hue_style"] @@ -350,6 +359,23 @@ def newplotfunc( cbar_kwargs["label"] = meta_data.get("hue_label", None) _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + if meta_data["add_quiverkey"]: + ticker = mpl.ticker.MaxNLocator(3) + median = np.median(np.hypot(ds[u].values, ds[v].values)) + magnitude = ticker.tick_values(0, median)[-2] + units = ds[u].attrs.get("units", "") + ax.quiverkey( + primitive, + X=0.85, + Y=0.9, + U=magnitude, + label=f"{magnitude}\n{units}", + labelpos="E", + coordinates="figure", + ) + + ax.set_title(ds[u]._title_for_slice()) + return primitive @functools.wraps(newplotfunc) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 58b38251352..d251df2cb5b 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -2,6 +2,7 @@ import itertools import warnings +import matplotlib as mpl import numpy as np from ..core.formatting import format_item @@ -334,7 +335,9 @@ def map_dataset( self.data[kwargs["markersize"]], kwargs.pop("size_norm", None) ) - meta_data = _infer_meta_data(self.data, x, y, hue, hue_style, add_guide) + meta_data = _infer_meta_data( + self.data, x, y, hue, hue_style, add_guide, funcname=func.__name__ + ) kwargs["meta_data"] = meta_data if hue and meta_data["hue_style"] == "continuous": @@ -365,6 +368,9 @@ def map_dataset( elif meta_data["add_colorbar"]: self.add_colorbar(label=self._hue_label, **cbar_kwargs) + if meta_data["add_quiverkey"]: + self.add_quiverkey(kwargs["u"], kwargs["v"]) + return self def _finalize_grid(self, *axlabels): @@ -426,6 +432,25 @@ def add_colorbar(self, **kwargs): ) return self + def add_quiverkey(self, u, v, **kwargs): + kwargs = kwargs.copy() + + ticker = mpl.ticker.MaxNLocator(3) + median = np.median(np.hypot(self.data[u].values, self.data[v].values)) + magnitude = ticker.tick_values(0, median)[-2] + units = self.data[u].attrs.get("units", "") + self.axes.flat[-1].quiverkey( + self._mappables[-1], + X=0.95, + Y=0.9, + U=magnitude, + label=f"{magnitude}\n{units}", + labelpos="E", + coordinates="figure", + ) + + return self + def set_axis_labels(self, x_var=None, y_var=None): """Set axis labels on the left column and bottom row of the grid.""" if x_var is not None: From fc8ad67cf4ee5381de0b9e58d607b190ca63798c Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 5 Sep 2020 14:12:03 -0600 Subject: [PATCH 03/25] refactor out adjusting subplots for legend --- xarray/plot/facetgrid.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index d251df2cb5b..fc123a9eba0 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -386,30 +386,22 @@ def _finalize_grid(self, *axlabels): self._finalized = True - def add_legend(self, **kwargs): - figlegend = self.fig.legend( - handles=self._mappables[-1], - labels=list(self._hue_var.values), - title=self._hue_label, - loc="center right", - **kwargs, - ) - - self.figlegend = figlegend + def _adjust_fig_for_guide(self, guide): # Draw the plot to set the bounding boxes correctly - self.fig.draw(self.fig.canvas.get_renderer()) + renderer = self.fig.canvas.get_renderer() + self.fig.draw(renderer) # Calculate and set the new width of the figure so the legend fits - legend_width = figlegend.get_window_extent().width / self.fig.dpi + guide_width = guide.get_window_extent(renderer).width / self.fig.dpi figure_width = self.fig.get_figwidth() - self.fig.set_figwidth(figure_width + legend_width) + self.fig.set_figwidth(figure_width + guide_width) # Draw the plot again to get the new transformations - self.fig.draw(self.fig.canvas.get_renderer()) + self.fig.draw(renderer) # Now calculate how much space we need on the right side - legend_width = figlegend.get_window_extent().width / self.fig.dpi - space_needed = legend_width / (figure_width + legend_width) + 0.02 + guide_width = guide.get_window_extent(renderer).width / self.fig.dpi + space_needed = guide_width / (figure_width + guide_width) + 0.02 # margin = .01 # _space_needed = margin + space_needed right = 1 - space_needed @@ -417,6 +409,16 @@ def add_legend(self, **kwargs): # Place the subplot axes to give space for the legend self.fig.subplots_adjust(right=right) + def add_legend(self, **kwargs): + self.figlegend = self.fig.legend( + handles=self._mappables[-1], + labels=list(self._hue_var.values), + title=self._hue_label, + loc="center right", + **kwargs, + ) + self._adjust_fig_for_guide(self.figlegend) + def add_colorbar(self, **kwargs): """Draw a colorbar""" kwargs = kwargs.copy() From 190d788d89ee59780f13912d063b57087be2f92c Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 5 Sep 2020 14:12:17 -0600 Subject: [PATCH 04/25] Adjust figure for quiverkey. This does not work because quiverkey.get_window_extent(...) return 0... --- xarray/plot/facetgrid.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index fc123a9eba0..ce54c59d31d 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -441,16 +441,18 @@ def add_quiverkey(self, u, v, **kwargs): median = np.median(np.hypot(self.data[u].values, self.data[v].values)) magnitude = ticker.tick_values(0, median)[-2] units = self.data[u].attrs.get("units", "") - self.axes.flat[-1].quiverkey( + self.quiverkey = self.axes.flat[-1].quiverkey( self._mappables[-1], X=0.95, - Y=0.9, + Y=0.5, U=magnitude, label=f"{magnitude}\n{units}", labelpos="E", coordinates="figure", ) + self._adjust_fig_for_guide(self.quiverkey) + return self def set_axis_labels(self, x_var=None, y_var=None): From 51cb73e93cc4f2dbc7565bec7c47375de207d06a Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 5 Sep 2020 15:01:15 -0600 Subject: [PATCH 05/25] Update quiverkey --- xarray/plot/dataset_plot.py | 6 ++---- xarray/plot/facetgrid.py | 18 +++++++++--------- xarray/plot/utils.py | 9 +++++++++ 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 5f4af6f9c2e..57b544b69cf 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -1,6 +1,5 @@ import functools -import matplotlib as mpl import numpy as np import pandas as pd @@ -8,6 +7,7 @@ from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _get_nice_quiver_magnitude, _is_numeric, _process_cmap_cbar_kwargs, get_axis, @@ -360,9 +360,7 @@ def newplotfunc( _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) if meta_data["add_quiverkey"]: - ticker = mpl.ticker.MaxNLocator(3) - median = np.median(np.hypot(ds[u].values, ds[v].values)) - magnitude = ticker.tick_values(0, median)[-2] + magnitude = _get_nice_quiver_magnitude(ds[u], ds[v]) units = ds[u].attrs.get("units", "") ax.quiverkey( primitive, diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index ce54c59d31d..42044c30281 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -2,11 +2,11 @@ import itertools import warnings -import matplotlib as mpl import numpy as np from ..core.formatting import format_item from .utils import ( + _get_nice_quiver_magnitude, _infer_xy_labels, _process_cmap_cbar_kwargs, import_matplotlib_pyplot, @@ -328,7 +328,6 @@ def map_dataset( from .dataset_plot import _infer_meta_data, _parse_size kwargs["add_guide"] = False - kwargs["_is_facetgrid"] = True if kwargs.get("markersize", None): kwargs["size_mapping"] = _parse_size( @@ -347,6 +346,8 @@ def map_dataset( kwargs["meta_data"]["cmap_params"] = cmap_params kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs + kwargs["_is_facetgrid"] = True + for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value if d is not None: @@ -437,21 +438,20 @@ def add_colorbar(self, **kwargs): def add_quiverkey(self, u, v, **kwargs): kwargs = kwargs.copy() - ticker = mpl.ticker.MaxNLocator(3) - median = np.median(np.hypot(self.data[u].values, self.data[v].values)) - magnitude = ticker.tick_values(0, median)[-2] + magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v]) units = self.data[u].attrs.get("units", "") self.quiverkey = self.axes.flat[-1].quiverkey( self._mappables[-1], - X=0.95, - Y=0.5, + X=0.85, + Y=1.03, U=magnitude, label=f"{magnitude}\n{units}", labelpos="E", - coordinates="figure", + coordinates="axes", ) - self._adjust_fig_for_guide(self.quiverkey) + # TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0 + # self._adjust_fig_for_guide(self.quiverkey) return self diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 16c67e154fc..278ed964073 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -840,3 +840,12 @@ def _process_cmap_cbar_kwargs( } return cmap_params, cbar_kwargs + + +def _get_nice_quiver_magnitude(u, v): + import matplotlib as mpl + + ticker = mpl.ticker.MaxNLocator(3) + median = np.median(np.hypot(u.values, v.values)) + magnitude = ticker.tick_values(0, median)[-2] + return magnitude From 9cb6a61afbced13b4c0123573d70e69ca295af6c Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 5 Sep 2020 15:01:31 -0600 Subject: [PATCH 06/25] Autoscale quiver facetgrid --- xarray/plot/facetgrid.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 42044c30281..20d1b214e7c 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -348,6 +348,21 @@ def map_dataset( kwargs["_is_facetgrid"] = True + if func.__name__ == "quiver" and "scale" not in kwargs: + if "scale_units" in kwargs and kwargs["scale_units"] is not None: + raise NotImplementedError("Can't pass only scale_units.") + # autoscaling + ax = self.axes[0, 0] + magnitude = _get_nice_quiver_magnitude( + self.data[kwargs["u"]], self.data[kwargs["v"]] + ) + # matplotlib autoscaling algorithm + span = ax.get_transform().inverted().transform_bbox(ax.bbox).width + npts = self.data.sizes[x] * self.data.sizes[y] + # scale is typical arrow length as a multiple of the arrow width + scale = 1.8 * magnitude * max(10, np.sqrt(npts)) / span + kwargs["scale"] = 1 / scale # TODO: why? + for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value if d is not None: From c2dcc1d760dc4617057f7dcd62979363363e4ede Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 5 Sep 2020 15:01:43 -0600 Subject: [PATCH 07/25] Support for hue --- xarray/plot/dataset_plot.py | 39 ++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 57b544b69cf..7a48d7e7870 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -56,6 +56,14 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): if (add_guide or add_guide is None) and funcname == "quiver": add_quiverkey = True + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver" + ) else: add_quiverkey = False @@ -431,28 +439,27 @@ def plotmethod( @_dsplot def quiver(ds, x, y, ax, u, v, **kwargs): + import matplotlib as mpl + if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for quiver plots.") x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) - kwargs.pop("cmap_params") - kwargs.pop("hue") - kwargs.pop("hue_style") - hdl = ax.quiver(x.values, y.values, u.values, v.values, **kwargs) - - # if plotfunc.__name__ == "quiver": - # trans = ax._set_transform() - # span, _y = trans.inverted().transform_point( - # (ax.bbox.width, ax.bbox.height)) - # # matplotlib autoscaling algorithm - # scale = kwargs.pop("scale") - # if scale is None: - # npts = ds.dims[x] * ds.dims[y] - # # crude auto-scaling - # # scale is typical arrow length as a multiple of the arrow width - # scale = 1.8 * ds.mean() * np.max(10, np.sqrt(npts)) / span + args = [x.values, y.values, u.values, v.values] + hue = kwargs.pop("hue") + if hue: + args.append(ds[hue].values) + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + cmap_params = kwargs.pop("cmap_params") + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + hdl = ax.quiver(*args, **kwargs, **cmap_params) return hdl From 0393e66602d84391d0b5746c89077e9c94ed1a86 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 5 Sep 2020 16:23:55 -0600 Subject: [PATCH 08/25] fix tests. --- xarray/tests/test_plot.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 47b15446f1d..549a3c49c1f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2194,7 +2194,13 @@ def test_accessor(self): def test_add_guide(self, add_guide, hue_style, legend, colorbar): meta_data = _infer_meta_data( - self.ds, x="A", y="B", hue="hue", hue_style=hue_style, add_guide=add_guide + self.ds, + x="A", + y="B", + hue="hue", + hue_style=hue_style, + add_guide=add_guide, + funcname="scatter", ) assert meta_data["add_legend"] is legend assert meta_data["add_colorbar"] is colorbar From 8116ec95c98505eaf79ee87ba18226dd8b1ea2d3 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 5 Sep 2020 16:28:52 -0600 Subject: [PATCH 09/25] Fix test --- xarray/plot/dataset_plot.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 7a48d7e7870..1aabb218df7 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -380,7 +380,11 @@ def newplotfunc( coordinates="figure", ) - ax.set_title(ds[u]._title_for_slice()) + if plotfunc.__name__ == "quiver": + title = ds[u]._title_for_slice() + else: + title = ds[x]._title_for_slice() + ax.set_title(title) return primitive From 04f64c562be7f7ffd25c34f1991da5f1d6cf3ba9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 19 Sep 2020 18:03:36 -0600 Subject: [PATCH 10/25] Adding docs --- doc/plotting.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/doc/plotting.rst b/doc/plotting.rst index 3699f794ae8..767a4ca58cd 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -715,6 +715,9 @@ Consider this dataset ds +Scatter +~~~~~~~ + Suppose we want to scatter ``A`` against ``B`` .. ipython:: python @@ -762,6 +765,24 @@ Faceting is also possible For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. +Quiver +~~~~~~ + +.. ipython:: python + :okwarning: + + @savefig ds_simple_quiver.png + ds.isel(w=1, z=1).plot.quiver(x="x", y="y", u="A", v="B") + + +Again, faceting is also possible + +.. ipython:: python + :okwarning: + + @savefig ds_facet_quiver.png + ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4) + .. _plot-maps: From abcc0081a4c37b60dfb35653301f47cebce0029b Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 19 Sep 2020 18:03:59 -0600 Subject: [PATCH 11/25] small fixes --- xarray/plot/dataset_plot.py | 54 +++++++++++++++++++------------------ xarray/plot/facetgrid.py | 39 ++++++++++++++------------- xarray/plot/utils.py | 2 +- 3 files changed, 50 insertions(+), 45 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 1aabb218df7..a9a4cd54d3b 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -441,32 +441,6 @@ def plotmethod( return newplotfunc -@_dsplot -def quiver(ds, x, y, ax, u, v, **kwargs): - import matplotlib as mpl - - if x is None or y is None or u is None or v is None: - raise ValueError("Must specify x, y, u, v for quiver plots.") - - x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) - - args = [x.values, y.values, u.values, v.values] - hue = kwargs.pop("hue") - if hue: - args.append(ds[hue].values) - - # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params - cmap_params = kwargs.pop("cmap_params") - if not cmap_params["norm"]: - cmap_params["norm"] = mpl.colors.Normalize( - cmap_params.pop("vmin"), cmap_params.pop("vmax") - ) - - kwargs.pop("hue_style") - hdl = ax.quiver(*args, **kwargs, **cmap_params) - return hdl - - @_dsplot def scatter(ds, x, y, ax, u, v, **kwargs): """ @@ -520,3 +494,31 @@ def scatter(ds, x, y, ax, u, v, **kwargs): ) return primitive + + +@_dsplot +def quiver(ds, x, y, ax, u, v, **kwargs): + import matplotlib as mpl + + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for quiver plots.") + + x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + + args = [x.values, y.values, u.values, v.values] + hue = kwargs.pop("hue") + cmap_params = kwargs.pop("cmap_params") + + if hue: + args.append(ds[hue].values) + + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + kwargs.setdefault("pivot", "middle") + hdl = ax.quiver(*args, **kwargs, **cmap_params) + return hdl diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 20d1b214e7c..0e03f56670c 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -349,19 +349,22 @@ def map_dataset( kwargs["_is_facetgrid"] = True if func.__name__ == "quiver" and "scale" not in kwargs: - if "scale_units" in kwargs and kwargs["scale_units"] is not None: - raise NotImplementedError("Can't pass only scale_units.") - # autoscaling - ax = self.axes[0, 0] - magnitude = _get_nice_quiver_magnitude( - self.data[kwargs["u"]], self.data[kwargs["v"]] - ) - # matplotlib autoscaling algorithm - span = ax.get_transform().inverted().transform_bbox(ax.bbox).width - npts = self.data.sizes[x] * self.data.sizes[y] - # scale is typical arrow length as a multiple of the arrow width - scale = 1.8 * magnitude * max(10, np.sqrt(npts)) / span - kwargs["scale"] = 1 / scale # TODO: why? + raise ValueError("Please provide scale.") + # TODO: come up with an algorithm for reasonable scale choice + # if "scale_units" in kwargs and kwargs["scale_units"] is not None: + # raise NotImplementedError("Can't pass only scale_units.") + # # autoscaling + # ax = self.axes[0, 0] + # magnitude = _get_nice_quiver_magnitude( + # self.data[kwargs["u"]], self.data[kwargs["v"]] + # ) + # # matplotlib autoscaling algorithm + # span = ax.get_transform().inverted().transform_bbox(ax.bbox).width + # npts = self.data.sizes[x] * self.data.sizes[y] + # # scale is typical arrow length as a multiple of the arrow width + # print(magnitude, np.sqrt(npts), span) + # kwargs["scale"] = 1.8 * magnitude * min(10, np.sqrt(npts)) / span + # print(kwargs["scale"]) for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value @@ -457,17 +460,17 @@ def add_quiverkey(self, u, v, **kwargs): units = self.data[u].attrs.get("units", "") self.quiverkey = self.axes.flat[-1].quiverkey( self._mappables[-1], - X=0.85, - Y=1.03, + X=0.8, + Y=0.9, U=magnitude, label=f"{magnitude}\n{units}", labelpos="E", - coordinates="axes", + coordinates="figure", ) # TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0 - # self._adjust_fig_for_guide(self.quiverkey) - + # https://github.com/matplotlib/matplotlib/issues/18530 + # self._adjust_fig_for_guide(self.quiverkey.text) return self def set_axis_labels(self, x_var=None, y_var=None): diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 278ed964073..adb60c628d7 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -846,6 +846,6 @@ def _get_nice_quiver_magnitude(u, v): import matplotlib as mpl ticker = mpl.ticker.MaxNLocator(3) - median = np.median(np.hypot(u.values, v.values)) + median = np.mean(np.hypot(u.values, v.values)) magnitude = ticker.tick_values(0, median)[-2] return magnitude From f1da0c0c28fa6b4599f880f102613aa92934eb46 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 8 Oct 2020 14:03:42 -0600 Subject: [PATCH 12/25] Start adding tests --- xarray/tests/test_plot.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 549a3c49c1f..a83ac5635af 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2152,6 +2152,29 @@ def test_wrong_num_of_dimensions(self): self.darray.plot.line(row="row", hue="hue") +@requires_matplotlib +class TestDatasetQuiverPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + das = [ + DataArray( + np.random.randn(3, 3, 4, 4), + dims=["x", "row", "col", "hue"], + coords=[range(k) for k in [3, 3, 4, 4]], + ) + for _ in [1, 2] + ] + ds = Dataset({"A": das[0], "B": das[1]}) + ds.hue.name = "huename" + ds.hue.attrs["units"] = "hunits" + ds.x.attrs["units"] = "xunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.A.attrs["units"] = "Aunits" + ds.B.attrs["units"] = "Bunits" + self.ds = ds + + @requires_matplotlib class TestDatasetScatterPlots(PlotTestCase): @pytest.fixture(autouse=True) From 39545b1aa0f7d901b093961c1193de0bfb871ebf Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 19 Jan 2021 17:14:14 -0700 Subject: [PATCH 13/25] Simple tests. --- xarray/tests/test_plot.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index a83ac5635af..d3aade4d0bf 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2159,21 +2159,38 @@ def setUp(self): das = [ DataArray( np.random.randn(3, 3, 4, 4), - dims=["x", "row", "col", "hue"], + dims=["x", "y", "row", "col"], coords=[range(k) for k in [3, 3, 4, 4]], ) for _ in [1, 2] ] - ds = Dataset({"A": das[0], "B": das[1]}) - ds.hue.name = "huename" - ds.hue.attrs["units"] = "hunits" + ds = Dataset({"u": das[0], "v": das[1]}) ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" ds.col.attrs["units"] = "colunits" ds.row.attrs["units"] = "rowunits" - ds.A.attrs["units"] = "Aunits" - ds.B.attrs["units"] = "Bunits" + ds.u.attrs["units"] = "uunits" + ds.v.attrs["units"] = "vunits" + ds["mag"] = np.hypot(ds.u, ds.v) self.ds = ds + def test_quiver(self): + with figure_context(): + hdl = self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u", v="v") + assert isinstance(hdl, mpl.quiver.Quiver) + with raises_regex(ValueError, "specify x, y, u, v"): + self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u") + + def test_facetgrid(self): + with figure_context(): + fg = self.ds.plot.quiver( + x="x", y="y", u="u", v="v", row="row", col="col", scale=1, hue="mag" + ) + for handle in fg._mappables: + assert isinstance(handle, mpl.quiver.Quiver) + with raises_regex(ValueError, "Please provide scale"): + self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col") + @requires_matplotlib class TestDatasetScatterPlots(PlotTestCase): From b6c1bc81a17a27c360211dac23f8086c3fe22183 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 19 Jan 2021 17:21:36 -0700 Subject: [PATCH 14/25] [skip-ci] Start adding docs. --- doc/api.rst | 1 + xarray/plot/dataset_plot.py | 1 + 2 files changed, 2 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index ceab7dcc976..ce866093db8 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -240,6 +240,7 @@ Plotting :template: autosummary/accessor_method.rst Dataset.plot.scatter + Dataset.plot.quiver DataArray ========= diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index a9a4cd54d3b..fb850be0641 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -498,6 +498,7 @@ def scatter(ds, x, y, ax, u, v, **kwargs): @_dsplot def quiver(ds, x, y, ax, u, v, **kwargs): + """ Quiver plot with Dataset variables.""" import matplotlib as mpl if x is None or y is None or u is None or v is None: From 7b559453a88eb1fc2b425b3dd3d54e295674006c Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 19 Jan 2021 17:25:26 -0700 Subject: [PATCH 15/25] [skip-ci] add whats-new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 88994a5bfc0..502181487b9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -47,6 +47,8 @@ New Features ~~~~~~~~~~~~ - Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. By `Deepak Cherian `_ +- Add :py:meth:`Dataset.plot.quiver` for quiver plots with :py:class:`Dataset` variables. + By `Deepak Cherian `_ Bug fixes ~~~~~~~~~ From 0da3540161ec6a1cac17caa4fa40ea2442636d45 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 20 Jan 2021 11:44:54 -0700 Subject: [PATCH 16/25] Add quiverkey test --- xarray/plot/facetgrid.py | 4 ++++ xarray/tests/test_plot.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 0e03f56670c..95be05a4a25 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -196,7 +196,11 @@ def __init__( self.axes = axes self.row_names = row_names self.col_names = col_names + + # guides self.figlegend = None + self.quiverkey = None + self.cbar = None # Next the private variables self._single_group = single_group diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index d3aade4d0bf..c1201782791 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2188,6 +2188,21 @@ def test_facetgrid(self): ) for handle in fg._mappables: assert isinstance(handle, mpl.quiver.Quiver) + assert "uunits" in fg.quiverkey.text.get_text() + + with figure_context(): + fg = self.ds.plot.quiver( + x="x", + y="y", + u="u", + v="v", + row="row", + col="col", + scale=1, + hue="mag", + add_guide=False, + ) + assert fg.quiverkey is None with raises_regex(ValueError, "Please provide scale"): self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col") From bc785826af6ba0b6ae728acc02f5cbac2382012d Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 20 Jan 2021 11:48:05 -0700 Subject: [PATCH 17/25] Add test for hue_style --- xarray/tests/test_plot.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index c1201782791..d8bf6632914 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2181,6 +2181,11 @@ def test_quiver(self): with raises_regex(ValueError, "specify x, y, u, v"): self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u") + with raises_regex(ValueError, "hue_style"): + self.ds.isel(row=0, col=0).plot.quiver( + x="x", y="y", u="u", hue_style="discrete" + ) + def test_facetgrid(self): with figure_context(): fg = self.ds.plot.quiver( From d1e028ccbb3adbad20bce423bf7d95b07cf5905e Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 20 Jan 2021 11:50:20 -0700 Subject: [PATCH 18/25] small doc updates. --- doc/plotting.rst | 2 ++ xarray/plot/dataset_plot.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/plotting.rst b/doc/plotting.rst index 767a4ca58cd..a6866f98b11 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -768,6 +768,8 @@ For more advanced scatter plots, we recommend converting the relevant data varia Quiver ~~~~~~ +Visualizing vector fields using a quiver plot is easy. + .. ipython:: python :okwarning: diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index fb850be0641..e71091eb9cb 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -185,7 +185,7 @@ def _dsplot(plotfunc): ds : Dataset x, y : str Variable names for x, y axis. - u, v : string + u, v : str, optional Variable names for quiver plots hue: str, optional Variable by which to color scattered points From 94f83969379ce19981c5365a2198183b05dee2a4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 20 Jan 2021 11:50:55 -0700 Subject: [PATCH 19/25] remove commented out autoscaling code --- xarray/plot/facetgrid.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 95be05a4a25..cf76000e806 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -355,20 +355,6 @@ def map_dataset( if func.__name__ == "quiver" and "scale" not in kwargs: raise ValueError("Please provide scale.") # TODO: come up with an algorithm for reasonable scale choice - # if "scale_units" in kwargs and kwargs["scale_units"] is not None: - # raise NotImplementedError("Can't pass only scale_units.") - # # autoscaling - # ax = self.axes[0, 0] - # magnitude = _get_nice_quiver_magnitude( - # self.data[kwargs["u"]], self.data[kwargs["v"]] - # ) - # # matplotlib autoscaling algorithm - # span = ax.get_transform().inverted().transform_bbox(ax.bbox).width - # npts = self.data.sizes[x] * self.data.sizes[y] - # # scale is typical arrow length as a multiple of the arrow width - # print(magnitude, np.sqrt(npts), span) - # kwargs["scale"] = 1.8 * magnitude * min(10, np.sqrt(npts)) / span - # print(kwargs["scale"]) for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value From de48d51fe5a13e190404d08ed9cd6b2fda474716 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 20 Jan 2021 11:55:51 -0700 Subject: [PATCH 20/25] Raise if u,v are provided for non-quiver plots. --- xarray/plot/dataset_plot.py | 3 +++ xarray/tests/test_plot.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index e71091eb9cb..59d3ca98f23 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -338,6 +338,9 @@ def newplotfunc( else: cmap_params_subset = {} + if (u is not None or v is not None) and plotfunc.__name__ != "quiver": + raise ValueError("u, v are only allowed for quiver plots.") + primitive = plotfunc( ds=ds, x=x, diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index d8bf6632914..5f735a6769e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2339,6 +2339,9 @@ def test_facetgrid_hue_style(self): def test_scatter(self, x, y, hue, markersize): self.ds.plot.scatter(x, y, hue=hue, markersize=markersize) + with raises_regex(ValueError, "u, v"): + self.ds.plot.scatter(x, y, u="col", v="row") + def test_non_numeric_legend(self): ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] From e795672b6dc01725ff0e7d7157fc1ec5d5cddf43 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 20 Jan 2021 14:12:35 -0700 Subject: [PATCH 21/25] Fix tests --- xarray/tests/test_plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 5f735a6769e..705b2d5e2e7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1475,7 +1475,7 @@ def test_facetgrid_cbar_kwargs(self): ) # catch contour case - if hasattr(g, "cbar"): + if g.cbar is not None: assert get_colorbar_label(g.cbar) == "test_label" def test_facetgrid_no_cbar_ax(self): @@ -2183,7 +2183,7 @@ def test_quiver(self): with raises_regex(ValueError, "hue_style"): self.ds.isel(row=0, col=0).plot.quiver( - x="x", y="y", u="u", hue_style="discrete" + x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" ) def test_facetgrid(self): From 8a3912c7218e894ca2663d7fffceb146cdf93d9b Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 29 Jan 2021 16:28:20 -0700 Subject: [PATCH 22/25] [skip-ci] fix bad merge. --- doc/whats-new.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bb81426e4dd..3ef8009ba86 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -60,7 +60,6 @@ New Features - Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. By `Deepak Cherian `_ - Add :py:meth:`Dataset.plot.quiver` for quiver plots with :py:class:`Dataset` variables. - By `Deepak Cherian `_ By `Deepak Cherian `_. - :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims in the form of kwargs as well as a dict, like most similar methods. From ca00bd4d39d0fae8af35f446cfe0575eaa1d3f7f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 10 Feb 2021 15:19:42 -0700 Subject: [PATCH 23/25] [skip-ci] update whats-new --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 47c35458acb..9cdfe1517c4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -145,7 +145,7 @@ Internal Changes the other ``assert_*`` functions (:pull:`4864`). By `Mathias Hauser `_. - Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. By `Deepak Cherian `_ - + .. _whats-new.0.16.2: v0.16.2 (30 Nov 2020) From d2cf58c4c71cde14d2b86febdbe3f96428a06eec Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 12 Feb 2021 07:04:40 -0700 Subject: [PATCH 24/25] [skip-ci] Apply suggestions from code review Co-authored-by: Mathias Hauser --- doc/plotting.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/plotting.rst b/doc/plotting.rst index a6866f98b11..2ada3e25431 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -768,7 +768,7 @@ For more advanced scatter plots, we recommend converting the relevant data varia Quiver ~~~~~~ -Visualizing vector fields using a quiver plot is easy. +Visualizing vector fields is supported with quiver plots: .. ipython:: python :okwarning: @@ -777,7 +777,7 @@ Visualizing vector fields using a quiver plot is easy. ds.isel(w=1, z=1).plot.quiver(x="x", y="y", u="A", v="B") -Again, faceting is also possible +where ``u`` and ``v`` denote the x and y direction components of the arrow vectors. Again, faceting is also possible: .. ipython:: python :okwarning: @@ -785,6 +785,7 @@ Again, faceting is also possible @savefig ds_facet_quiver.png ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4) +``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer. .. _plot-maps: From b9bcadaa41bec701666debdfda6b5aac9930d41b Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 12 Feb 2021 07:05:15 -0700 Subject: [PATCH 25/25] review comments. --- xarray/plot/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 534a00bbd10..5510cf7f219 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -847,6 +847,6 @@ def _get_nice_quiver_magnitude(u, v): import matplotlib as mpl ticker = mpl.ticker.MaxNLocator(3) - median = np.mean(np.hypot(u.values, v.values)) - magnitude = ticker.tick_values(0, median)[-2] + mean = np.mean(np.hypot(u.values, v.values)) + magnitude = ticker.tick_values(0, mean)[-2] return magnitude