diff --git a/doc/plotting.rst b/doc/plotting.rst index 9710cb5d9c8..e1b466a6e7a 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -141,8 +141,7 @@ Simple Example ~~~~~~~~~~~~~~ The default method :py:meth:`xray.DataArray.plot` sees that the data is -2 dimensional. If the coordinates are uniformly spaced then it -calls :py:func:`xray.plot.imshow`. +2 dimensional and calls :py:func:`xray.plot.pcolormesh`. .. ipython:: python @@ -159,6 +158,14 @@ and ``xincrease``. @savefig 2d_simple_yincrease.png width=4in air2d.plot(yincrease=False) +.. note:: + + We use :py:func:`xray.plot.pcolormesh` as the default two-dimensional plot + method because it is more flexible than :py:func:`xray.plot.imshow`. + However, for large arrays, ``imshow`` can be much faster than ``pcolormesh``. + If speed is important to you and you are plotting a regular mesh, consider + using ``imshow``. + Missing Values ~~~~~~~~~~~~~~ @@ -176,9 +183,9 @@ Xray plots data with :ref:`missing_values`. Nonuniform Coordinates ~~~~~~~~~~~~~~~~~~~~~~ -It's not necessary for the coordinates to be evenly spaced. If not, then -:py:meth:`xray.DataArray.plot` produces a filled contour plot by calling -:py:func:`xray.plot.contourf`. +It's not necessary for the coordinates to be evenly spaced. Both +:py:func:`xray.plot.pcolormesh` (default) and :py:func:`xray.plot.contourf` can +produce plots with nonuniform coordinates. .. ipython:: python @@ -201,6 +208,7 @@ matplotlib is available. plt.title('These colors prove North America\nhas fallen in the ocean') plt.ylabel('latitude') plt.xlabel('longitude') + plt.tight_layout() @savefig plotting_2d_call_matplotlib.png width=4in plt.show() @@ -376,8 +384,8 @@ Faceted plotting supports other arguments common to xray 2d plots. hasoutliers[-1, -1, -1] = 400 @savefig plot_facet_robust.png height=12in - g = hasoutliers.plot.imshow('lon', 'lat', col='time', col_wrap=3, - robust=True, cmap='viridis') + g = hasoutliers.plot.pcolormesh('lon', 'lat', col='time', col_wrap=3, + robust=True, cmap='viridis') FacetGrid Objects ~~~~~~~~~~~~~~~~~ @@ -473,14 +481,13 @@ plotting function based on the dimensions of the ``DataArray`` and whether the coordinates are sorted and uniformly spaced. This table describes what gets plotted: -=============== =========== =========================== -Dimensions Coordinates Plotting function ---------------- ----------- --------------------------- -1 :py:func:`xray.plot.line` -2 Uniform :py:func:`xray.plot.imshow` -2 Irregular :py:func:`xray.plot.contourf` -Anything else :py:func:`xray.plot.hist` -=============== =========== =========================== +=============== =========================== +Dimensions Plotting function +--------------- --------------------------- +1 :py:func:`xray.plot.line` +2 :py:func:`xray.plot.pcolormesh` +Anything else :py:func:`xray.plot.hist` +=============== =========================== Coordinates ~~~~~~~~~~~ diff --git a/xray/plot/facetgrid.py b/xray/plot/facetgrid.py index 93d139d7607..bc2ffe7b353 100644 --- a/xray/plot/facetgrid.py +++ b/xray/plot/facetgrid.py @@ -7,7 +7,7 @@ import numpy as np from ..core.formatting import format_item -from .utils import _determine_cmap_params +from .utils import _determine_cmap_params, _infer_xy_labels # Overrides axes.labelsize, xtick.major.size, ytick.major.size @@ -242,6 +242,10 @@ def map_dataarray(self, func, x, y, **kwargs): defaults.update(cmap_params) defaults.update(kwargs) + # Get x, y labels for the first subplot + x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]], + x=x, y=y) + for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value if d is not None: @@ -270,7 +274,7 @@ def map_dataarray(self, func, x, y, **kwargs): extend=cmap_params['extend']) if self.data.name: - cbar.set_label(self.data.name, rotation=270, + cbar.set_label(self.data.name, rotation=90, verticalalignment='bottom') self._x_var = x diff --git a/xray/plot/plot.py b/xray/plot/plot.py index 1317cc54eeb..1312e4ef172 100644 --- a/xray/plot/plot.py +++ b/xray/plot/plot.py @@ -13,9 +13,8 @@ import numpy as np import pandas as pd -from .utils import _determine_cmap_params +from .utils import _determine_cmap_params, _infer_xy_labels from .facetgrid import FacetGrid -from ..core.utils import is_uniform_spaced # Maybe more appropriate to keep this in .utils @@ -39,47 +38,6 @@ def _ensure_plottable(*args): 'or dates.') -def _infer_xy_labels(plotfunc, darray, x, y): - """ - Determine x and y labels when some are missing. For use in _plot2d - - darray is a 2 dimensional data array. - """ - dims = list(darray.dims) - - if len(dims) != 2: - raise ValueError('{type} plots are for 2 dimensional DataArrays. ' - 'Passed DataArray has {ndim} dimensions' - .format(type=plotfunc.__name__, ndim=len(dims))) - - if x and x not in dims: - raise KeyError('{0} is not a dimension of this DataArray. Use ' - '{1} or {2} for x' - .format(x, *dims)) - - if y and y not in dims: - raise KeyError('{0} is not a dimension of this DataArray. Use ' - '{1} or {2} for y' - .format(y, *dims)) - - # Get label names - if x and y: - xlab = x - ylab = y - elif x and not y: - xlab = x - del dims[dims.index(x)] - ylab = dims.pop() - elif y and not x: - ylab = y - del dims[dims.index(y)] - xlab = dims.pop() - else: - ylab, xlab = dims - - return xlab, ylab - - def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, col_wrap=None, aspect=1, size=3, subplot_kws=None, **kwargs): """ @@ -99,19 +57,18 @@ def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, col_wrap=None, def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, subplot_kws=None, **kwargs): """ - Default plot of DataArray using matplotlib / pylab. + Default plot of DataArray using matplotlib.pyplot. Calls xray plotting function based on the dimensions of darray.squeeze() - =============== =========== =========================== - Dimensions Coordinates Plotting function - --------------- ----------- --------------------------- - 1 :py:func:`xray.plot.line` - 2 Uniform :py:func:`xray.plot.imshow` - 2 Irregular :py:func:`xray.plot.contourf` - Anything else :py:func:`xray.plot.hist` - =============== =========== =========================== + =============== =========================== + Dimensions Plotting function + --------------- --------------------------- + 1 :py:func:`xray.plot.line` + 2 :py:func:`xray.plot.pcolormesh` + Anything else :py:func:`xray.plot.hist` + =============== =========================== Parameters ---------- @@ -156,9 +113,7 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, kwargs['col_wrap'] = col_wrap kwargs['subplot_kws'] = subplot_kws - indexes = (darray.indexes[dim].values for dim in plot_dims) - uniform = all(is_uniform_spaced(i, rtol=rtol) for i in indexes) - plotfunc = imshow if uniform else contourf + plotfunc = pcolormesh else: if row or col: raise ValueError(error_msg) @@ -376,7 +331,7 @@ def _plot2d(plotfunc): @functools.wraps(plotfunc) def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None, - col_wrap=None, xincrease=None, yincrease=None, + col_wrap=None, xincrease=True, yincrease=True, add_colorbar=True, add_labels=True, vmin=None, vmax=None, cmap=None, center=None, robust=False, extend=None, levels=None, colors=None, subplot_kws=None, **kwargs): @@ -416,8 +371,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None, if ax is None: ax = plt.gca() - xlab, ylab = _infer_xy_labels(plotfunc=plotfunc, darray=darray, - x=x, y=y) + xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y) # better to pass the ndarrays directly to plotting functions xval = darray[xlab].values @@ -471,7 +425,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None, if add_colorbar: cbar = plt.colorbar(primitive, ax=ax, extend=cmap_params['extend']) if darray.name and add_labels: - cbar.set_label(darray.name) + cbar.set_label(darray.name, rotation=90) _update_axes_limits(ax, xincrease, yincrease) @@ -480,7 +434,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None, # For use as DataArray.plot.plotmethod @functools.wraps(newplotfunc) def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None, - col=None, col_wrap=None, xincrease=None, yincrease=None, + col=None, col_wrap=None, xincrease=True, yincrease=True, add_colorbar=True, add_labels=True, vmin=None, vmax=None, cmap=None, colors=None, center=None, robust=False, extend=None, levels=None, subplot_kws=None, **kwargs): @@ -506,7 +460,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None, @_plot2d def imshow(x, y, z, ax, **kwargs): """ - Image plot of 2d DataArray using matplotlib / pylab + Image plot of 2d DataArray using matplotlib.pyplot Wraps matplotlib.pyplot.imshow @@ -518,6 +472,11 @@ def imshow(x, y, z, ax, **kwargs): The pixels are centered on the coordinates values. Ie, if the coordinate value is 3.2 then the pixels for those coordinates will be centered on 3.2. """ + + if x.ndim != 1 or y.ndim != 1: + raise ValueError('imshow requires 1D coordinates, try using ' + 'pcolormesh or contour(f)') + # Centering the pixels- Assumes uniform spacing xstep = (x[1] - x[0]) / 2.0 ystep = (y[1] - y[0]) / 2.0 @@ -589,7 +548,7 @@ def pcolormesh(x, y, z, ax, **kwargs): # by default, pcolormesh picks "round" values for bounds # this results in ugly looking plots with lots of surrounding whitespace - if not hasattr(ax, 'projection'): + if not hasattr(ax, 'projection') and x.ndim == 1 and y.ndim == 1: # not a cartopy geoaxis ax.set_xlim(x[0], x[-1]) ax.set_ylim(y[0], y[-1]) diff --git a/xray/plot/utils.py b/xray/plot/utils.py index 87f1c995dfe..60055f0bd52 100644 --- a/xray/plot/utils.py +++ b/xray/plot/utils.py @@ -177,3 +177,21 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, cnorm=cnorm) + + +def _infer_xy_labels(darray, x, y): + """ + Determine x and y labels. For use in _plot2d + + darray must be a 2 dimensional data array. + """ + + if x is None and y is None: + if darray.ndim != 2: + raise ValueError('DataArray must be 2d') + y, x = darray.dims + elif x is None or y is None: + raise ValueError('cannot supply only one of x and y') + elif any(k not in darray.coords for k in (x, y)): + raise ValueError('x and y must be coordinate variables') + return x, y diff --git a/xray/test/test_plot.py b/xray/test/test_plot.py index 64e559f4600..c24a6f146d7 100644 --- a/xray/test/test_plot.py +++ b/xray/test/test_plot.py @@ -89,12 +89,12 @@ def test_2d_before_squeeze(self): a.plot() def test2d_uniform_calls_imshow(self): - self.assertTrue(self.imshow_called(self.darray[:, :, 0].plot)) + self.assertTrue(self.imshow_called(self.darray[:, :, 0].plot.imshow)) def test2d_nonuniform_calls_contourf(self): a = self.darray[:, :, 0] a.coords['dim_1'] = [2, 1, 89] - self.assertTrue(self.contourf_called(a.plot)) + self.assertTrue(self.contourf_called(a.plot.contourf)) def test3d(self): self.darray.plot() @@ -399,8 +399,16 @@ class Common2dMixin: """ def setUp(self): - self.darray = DataArray(easy_array( + da = DataArray(easy_array( (10, 15), start=-1), dims=['y', 'x']) + # add 2d coords + ds = da.to_dataset(name='testvar') + x, y = np.meshgrid(da.x.values, da.y.values) + ds['x2d'] = DataArray(x, dims=['y', 'x']) + ds['y2d'] = DataArray(y, dims=['y', 'x']) + ds.set_coords(['x2d', 'y2d'], inplace=True) + # set darray and plot method + self.darray = ds.testvar self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__) def test_label_names(self): @@ -409,12 +417,12 @@ def test_label_names(self): self.assertEqual('y', plt.gca().get_ylabel()) def test_1d_raises_valueerror(self): - with self.assertRaisesRegexp(ValueError, r'[Dd]im'): + with self.assertRaisesRegexp(ValueError, r'DataArray must be 2d'): self.plotfunc(self.darray[0, :]) def test_3d_raises_valueerror(self): a = DataArray(easy_array((2, 3, 4))) - with self.assertRaisesRegexp(ValueError, r'[Dd]im'): + with self.assertRaisesRegexp(ValueError, r'DataArray must be 2d'): self.plotfunc(a) def test_nonnumeric_index_raises_typeerror(self): @@ -484,26 +492,25 @@ def test_xy_strings(self): self.assertEqual('y', ax.get_xlabel()) self.assertEqual('x', ax.get_ylabel()) - def test_positional_x_string(self): - self.plotmethod('y') - ax = plt.gca() - self.assertEqual('y', ax.get_xlabel()) - self.assertEqual('x', ax.get_ylabel()) - - def test_y_string(self): - self.plotmethod(y='x') - ax = plt.gca() - self.assertEqual('y', ax.get_xlabel()) - self.assertEqual('x', ax.get_ylabel()) + def test_positional_coord_string(self): + with self.assertRaisesRegexp(ValueError, 'cannot supply only one'): + self.plotmethod('y') + with self.assertRaisesRegexp(ValueError, 'cannot supply only one'): + self.plotmethod(y='x') def test_bad_x_string_exception(self): - with self.assertRaisesRegexp(KeyError, r'y'): - self.plotmethod('not_a_real_dim') - + with self.assertRaisesRegexp(ValueError, 'x and y must be coordinate'): + self.plotmethod('not_a_real_dim', 'y') self.darray.coords['z'] = 100 - with self.assertRaisesRegexp(KeyError, r'y'): + with self.assertRaisesRegexp(ValueError, 'cannot supply only one'): self.plotmethod('z') + def test_coord_strings(self): + # 1d coords (same as dims) + self.assertIn('x', self.darray.coords) + self.assertIn('y', self.darray.coords) + self.plotmethod(y='y', x='x') + def test_default_title(self): a = DataArray(easy_array((4, 3, 2)), dims=['a', 'b', 'c']) a.coords['d'] = u'foo' @@ -545,8 +552,30 @@ def test_convenient_facetgrid(self): g = self.plotfunc(d, x='x', y='y', col='z', col_wrap=2) self.assertArrayEqual(g.axes.shape, [2, 2]) - for ax in g.axes.flat: + for (y, x), ax in np.ndenumerate(g.axes): + self.assertTrue(ax.has_data()) + if x == 0: + self.assertEqual('y', ax.get_ylabel()) + else: + self.assertEqual('', ax.get_ylabel()) + if y == 1: + self.assertEqual('x', ax.get_xlabel()) + else: + self.assertEqual('', ax.get_xlabel()) + + # Infering labels + g = self.plotfunc(d, col='z', col_wrap=2) + self.assertArrayEqual(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): self.assertTrue(ax.has_data()) + if x == 0: + self.assertEqual('y', ax.get_ylabel()) + else: + self.assertEqual('', ax.get_ylabel()) + if y == 1: + self.assertEqual('x', ax.get_xlabel()) + else: + self.assertEqual('', ax.get_xlabel()) def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) @@ -598,6 +627,13 @@ def test_extend(self): artist = self.plotmethod(vmin=-10, vmax=0) self.assertEqual(artist.extend, 'max') + def test_2d_coord_names(self): + self.plotmethod(x='x2d', y='y2d') + # make sure labels came out ok + ax = plt.gca() + self.assertEqual('x2d', ax.get_xlabel()) + self.assertEqual('y2d', ax.get_ylabel()) + def test_levels(self): artist = self.plotmethod(levels=[-0.5, -0.4, 0.1]) self.assertEqual(artist.extend, 'both') @@ -633,6 +669,13 @@ def list_of_colors_in_cmap_deprecated(self): with self.assertRaises(Exception): self.plotmethod(cmap=['k', 'b']) + def test_2d_coord_names(self): + self.plotmethod(x='x2d', y='y2d') + # make sure labels came out ok + ax = plt.gca() + self.assertEqual('x2d', ax.get_xlabel()) + self.assertEqual('y2d', ax.get_ylabel()) + class TestPcolormesh(Common2dMixin, PlotTestCase): @@ -646,6 +689,13 @@ def test_everything_plotted(self): artist = self.plotmethod() self.assertEqual(artist.get_array().size, self.darray.size) + def test_2d_coord_names(self): + self.plotmethod(x='x2d', y='y2d') + # make sure labels came out ok + ax = plt.gca() + self.assertEqual('x2d', ax.get_xlabel()) + self.assertEqual('y2d', ax.get_ylabel()) + class TestImshow(Common2dMixin, PlotTestCase): @@ -657,7 +707,7 @@ def test_imshow_called(self): self.assertTrue(self.imshow_called(self.darray.plot.imshow)) def test_xy_pixel_centered(self): - self.darray.plot.imshow() + self.darray.plot.imshow(yincrease=False) self.assertTrue(np.allclose([-0.5, 14.5], plt.gca().get_xlim())) self.assertTrue(np.allclose([9.5, -0.5], plt.gca().get_ylim())) @@ -681,6 +731,9 @@ def test_seaborn_palette_needs_levels(self): except ImportError: pass + def test_2d_coord_names(self): + with self.assertRaisesRegexp(ValueError, 'requires 1D coordinates'): + self.plotmethod(x='x2d', y='y2d') class TestFacetGrid(PlotTestCase):