Skip to content

New infer_intervals keyword for pcolormesh #1079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 10, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,56 @@ the values on the y axis are decreasing with -0.5 on the top. This is because
the pixels are centered over their coordinates, and the
axis labels and ranges correspond to the values of the
coordinates.

Multidimensional coordinates
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

See also: :ref:`examples.multidim`.

You can plot irregular grids defined by multidimensional coordinates with
xarray, but you'll have to tell the plot function to use these coordinates
instead of the default ones:

.. ipython:: python

lon, lat = np.meshgrid(np.linspace(-20, 20, 5), np.linspace(0, 30, 4))
lon += lat/10
lat += lon/10
da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['y', 'x'],
coords = {'lat': (('y', 'x'), lat),
'lon': (('y', 'x'), lon)})

@savefig plotting_example_2d_irreg.png width=4in
da.plot.pcolormesh('lon', 'lat');

Note that in this case, xarray still follows the pixel centered convention.
This might be undesirable in some cases, for example when your data is defined
on a polar projection (:issue:`781`). This is why the default is to not follow
this convention when plotting on a map:

.. ipython:: python

import cartopy.crs as ccrs
ax = plt.subplot(projection=ccrs.PlateCarree());
da.plot.pcolormesh('lon', 'lat', ax=ax);
ax.scatter(lon, lat, transform=ccrs.PlateCarree());
@savefig plotting_example_2d_irreg_map.png width=4in
ax.coastlines(); ax.gridlines(draw_labels=True);

You can however decide to infer the cell boundaries and use the
``infer_intervals`` keyword:

.. ipython:: python

ax = plt.subplot(projection=ccrs.PlateCarree());
da.plot.pcolormesh('lon', 'lat', ax=ax, infer_intervals=True);
ax.scatter(lon, lat, transform=ccrs.PlateCarree());
@savefig plotting_example_2d_irreg_map_infer.png width=4in
ax.coastlines(); ax.gridlines(draw_labels=True);

.. note::
The data model of xarray does not support datasets with `cell boundaries`_
yet. If you want to use these coordinates, you'll have to make the plots
outside the xarray framework.

.. _cell boundaries: http://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#cell-boundaries
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ Bug fixes
- ``Dataset.concat()`` now preserves variables order (:issue:`1027`).
By `Fabien Maussion <https://github.com/fmaussion>`_.

- Fixed an issue with pcolormesh (:issue:`781`). A new
``infer_intervals`` keyword gives control on whether the cell intervals
should be computed or not.
By `Fabien Maussion <https://github.com/fmaussion>`_.

.. _whats-new.0.8.2:

v0.8.2 (18 August 2016)
Expand Down
59 changes: 46 additions & 13 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,12 @@ def _plot2d(plotfunc):
provided, extend is inferred from vmin, vmax and the data limits.
levels : int or list-like object, optional
Split the colormap (cmap) into discrete color intervals.
infer_intervals : bool, optional
Only applies to pcolormesh. If True, the coordinate intervals are
passed to pcolormesh. If False, the original coordinates are used
(this can be useful for certain map projections). The default is to
always infer intervals, unless the mesh is irregular and plotted on
a map projection.
subplot_kws : dict, optional
Dictionary of keyword arguments for matplotlib subplots. Only applies
to FacetGrid plotting.
Expand All @@ -341,8 +347,9 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
col_wrap=None, xincrease=True, yincrease=True,
add_colorbar=None, add_labels=True, vmin=None, vmax=None,
cmap=None, center=None, robust=False, extend=None,
levels=None, colors=None, subplot_kws=None,
cbar_ax=None, cbar_kwargs=None, **kwargs):
levels=None, infer_intervals=None, colors=None,
subplot_kws=None, cbar_ax=None, cbar_kwargs=None,
**kwargs):
# All 2d plots in xarray share this function signature.
# Method signature below should be consistent.

Expand Down Expand Up @@ -416,6 +423,9 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
kwargs['extend'] = cmap_params['extend']
kwargs['levels'] = cmap_params['levels']

if 'pcolormesh' == plotfunc.__name__:
kwargs['infer_intervals'] = infer_intervals

# This allows the user to pass in a custom norm coming via kwargs
kwargs.setdefault('norm', cmap_params['norm'])

Expand Down Expand Up @@ -456,8 +466,8 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None,
col=None, col_wrap=None, xincrease=True, yincrease=True,
add_colorbar=None, add_labels=True, vmin=None, vmax=None,
cmap=None, colors=None, center=None, robust=False,
extend=None, levels=None, subplot_kws=None,
cbar_ax=None, cbar_kwargs=None, **kwargs):
extend=None, levels=None, infer_intervals=None,
subplot_kws=None, cbar_ax=None, cbar_kwargs=None, **kwargs):
"""
The method should have the same signature as the function.

Expand Down Expand Up @@ -542,29 +552,52 @@ def contourf(x, y, z, ax, **kwargs):
return ax, primitive


def _infer_interval_breaks(coord):
def _infer_interval_breaks(coord, axis=0):
"""
>>> _infer_interval_breaks(np.arange(5))
array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
>>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1)
array([[-0.5, 0.5, 1.5],
[ 2.5, 3.5, 4.5]])
"""
coord = np.asarray(coord)
deltas = 0.5 * (coord[1:] - coord[:-1])
first = coord[0] - deltas[0]
last = coord[-1] + deltas[-1]
return np.r_[[first], coord[:-1] + deltas, [last]]
deltas = 0.5 * np.diff(coord, axis=axis)
first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis)
last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis)
trim_last = tuple(slice(None, -1) if n == axis else slice(None)
for n in range(coord.ndim))
return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nifty! Loved the use of axis here!



@_plot2d
def pcolormesh(x, y, z, ax, **kwargs):
def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
"""
Pseudocolor plot of 2d DataArray

Wraps matplotlib.pyplot.pcolormesh
"""

if not hasattr(ax, 'projection'):
x = _infer_interval_breaks(x)
y = _infer_interval_breaks(y)
# decide on a default for infer_intervals (GH781)
x = np.asarray(x)
if infer_intervals is None:
if hasattr(ax, 'projection'):
if len(x.shape) == 1:
infer_intervals = True
else:
infer_intervals = False
else:
infer_intervals = True

if infer_intervals:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have mixed feelings about this change. I'm not sure I like infer_intervals to default to True. This is going to break a bunch of our plotting code. My preference, and I'm open to other ideas at this point, would be to default to infer_intervals=None with something like this logic:

if (infer_intervals is None) and (hasattr(ax, 'projection'))
    infer_intervals = False
    warnings.warn("infer_intervals=None is deprecated, please specify either True or False",
                  DeprecationWarning)
if infer_intervals:
    ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure to understand... infer_intervals=True was the de-facto default until V0.7.2, and from the example above it is clear that setting it to False is almost never what you want to do. Can you provide an example where this is a good thing?

I agree your concerns about breaking old code though. How about keeping the current behavior for one more version, but issuing a warning that in the next version (probably 0.9.1 or 0.10), it is going to be set to True per default?

For my purposes this is going to imply that I'm going to use imshow() and not pcolormesh anymore for the time being.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting infer_intervals=False is exactly what you want if you have irregularly spaced coordinates (see example here: #781 (comment)).

I think something like the logic I showed above may be a good compromise. I'm happy to transition to always specifying infer_intervals=False we should be more transparent about the change (hence the warning along with a None for a default value. Would you be happier with this logic?

if (infer_intervals is None):
    if (hasattr(ax, 'projection'))
        infer_intervals = False
        warnings.warn("infer_intervals=None is deprecated, please specify either True or False",
                      DeprecationWarning)
    else:
        infer_intervals = True
if infer_intervals:
    ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, with irregular coordinates none of the solutions is satisfying. infer_intervals=True crops one column, while infer_intervals=False crop one column and one line of data:

import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import xarray as xr

lon, lat = np.linspace(-20, 20, 5), np.linspace(0, 30, 4)
lon, lat = np.meshgrid(lon, lat)

da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['y', 'x'],
                  coords = {'lat': (('y', 'x'), lat),
                            'lon': (('y', 'x'), lon)})

f = plt.figure()

ax = plt.subplot(2, 1, 1, projection=ccrs.PlateCarree())
da.plot.pcolormesh('lon', 'lat', ax=ax, transform=ccrs.PlateCarree(), infer_intervals=True)
ax.scatter(lon, lat, transform=ccrs.PlateCarree())
ax.coastlines()
ax.set_title('pcolormesh')

ax = plt.subplot(2, 1, 2, projection=ccrs.PlateCarree())
da.plot.pcolormesh('lon', 'lat', ax=ax, transform=ccrs.PlateCarree(), infer_intervals=False)
ax.scatter(lon, lat, transform=ccrs.PlateCarree())
ax.coastlines()
ax.set_title('pcolormesh - no intervals')

plt.show()

figure_1-1

I'll try to have a closer look later next week

if len(x.shape) == 1:
x = _infer_interval_breaks(x)
y = _infer_interval_breaks(y)
else:
# we have to infer the intervals on both axes
x = _infer_interval_breaks(x, axis=1)
x = _infer_interval_breaks(x, axis=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clever!

y = _infer_interval_breaks(y, axis=1)
y = _infer_interval_breaks(y, axis=0)

primitive = ax.pcolormesh(x, y, z, **kwargs)

Expand Down
13 changes: 13 additions & 0 deletions xarray/test/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import division
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add this to the top of literally every single Python file in xarray?

That will keep this easier to reason about.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:D sure

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes OK. shouldn't I do this in a separate PR? I'll do it tomorrow though, it's getting late in Innsbruck...


import inspect

import numpy as np
Expand Down Expand Up @@ -112,6 +114,17 @@ def test__infer_interval_breaks(self):
self.assertArrayEqual(pd.date_range('20000101', periods=4) - np.timedelta64(12, 'h'),
_infer_interval_breaks(pd.date_range('20000101', periods=3)))

# make a bounded 2D array that we will center and re-infer
xref, yref = np.meshgrid(np.arange(6), np.arange(5))
cx = (xref[1:, 1:] + xref[:-1, :-1]) / 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On Python 2, I think this will still have integer dtype. Can you divide by 2.0 instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a from __future__ import division instead

cy = (yref[1:, 1:] + yref[:-1, :-1]) / 2
x = _infer_interval_breaks(cx, axis=1)
x = _infer_interval_breaks(x, axis=0)
y = _infer_interval_breaks(cy, axis=1)
y = _infer_interval_breaks(y, axis=0)
np.testing.assert_allclose(xref, x)
np.testing.assert_allclose(yref, y)

def test_datetime_dimension(self):
nrow = 3
ncol = 4
Expand Down