Skip to content

allow passing coordinate names as x and y to plot methods #608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 9, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ Simple Example
~~~~~~~~~~~~~~

The default method :py:meth:`xray.DataArray.plot` sees that the data is
2 dimensional. If the coordinates are uniformly spaced then it
calls :py:func:`xray.plot.imshow`.
2 dimensional and calls :py:func:`xray.plot.pcolormesh`.

.. ipython:: python

Expand All @@ -159,6 +158,14 @@ and ``xincrease``.
@savefig 2d_simple_yincrease.png width=4in
air2d.plot(yincrease=False)

.. note::

We use :py:func:`xray.plot.pcolormesh` as the default two-dimensional plot
method because it is more flexible than :py:func:`xray.plot.imshow`.
However, for large arrays, ``imshow`` can be much faster than ``pcolormesh``.
If speed is important to you and you are plotting a regular mesh, consider
using ``imshow``.

Missing Values
~~~~~~~~~~~~~~

Expand All @@ -176,9 +183,9 @@ Xray plots data with :ref:`missing_values`.
Nonuniform Coordinates
~~~~~~~~~~~~~~~~~~~~~~

It's not necessary for the coordinates to be evenly spaced. If not, then
:py:meth:`xray.DataArray.plot` produces a filled contour plot by calling
:py:func:`xray.plot.contourf`.
It's not necessary for the coordinates to be evenly spaced. Both
:py:func:`xray.plot.pcolormesh` (default) and :py:func:`xray.plot.contourf` can
produce plots with nonuniform coordinates.

.. ipython:: python

Expand All @@ -201,6 +208,7 @@ matplotlib is available.
plt.title('These colors prove North America\nhas fallen in the ocean')
plt.ylabel('latitude')
plt.xlabel('longitude')
plt.tight_layout()

@savefig plotting_2d_call_matplotlib.png width=4in
plt.show()
Expand Down Expand Up @@ -376,8 +384,8 @@ Faceted plotting supports other arguments common to xray 2d plots.
hasoutliers[-1, -1, -1] = 400

@savefig plot_facet_robust.png height=12in
g = hasoutliers.plot.imshow('lon', 'lat', col='time', col_wrap=3,
robust=True, cmap='viridis')
g = hasoutliers.plot.pcolormesh('lon', 'lat', col='time', col_wrap=3,
robust=True, cmap='viridis')

FacetGrid Objects
~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -473,14 +481,13 @@ plotting function based on the dimensions of the ``DataArray`` and whether
the coordinates are sorted and uniformly spaced. This table
describes what gets plotted:

=============== =========== ===========================
Dimensions Coordinates Plotting function
--------------- ----------- ---------------------------
1 :py:func:`xray.plot.line`
2 Uniform :py:func:`xray.plot.imshow`
2 Irregular :py:func:`xray.plot.contourf`
Anything else :py:func:`xray.plot.hist`
=============== =========== ===========================
=============== ===========================
Dimensions Plotting function
--------------- ---------------------------
1 :py:func:`xray.plot.line`
2 :py:func:`xray.plot.pcolormesh`
Anything else :py:func:`xray.plot.hist`
=============== ===========================

Coordinates
~~~~~~~~~~~
Expand Down
8 changes: 6 additions & 2 deletions xray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from ..core.formatting import format_item
from .utils import _determine_cmap_params
from .utils import _determine_cmap_params, _infer_xy_labels


# Overrides axes.labelsize, xtick.major.size, ytick.major.size
Expand Down Expand Up @@ -242,6 +242,10 @@ def map_dataarray(self, func, x, y, **kwargs):
defaults.update(cmap_params)
defaults.update(kwargs)

# Get x, y labels for the first subplot
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y)

for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
if d is not None:
Expand Down Expand Up @@ -270,7 +274,7 @@ def map_dataarray(self, func, x, y, **kwargs):
extend=cmap_params['extend'])

if self.data.name:
cbar.set_label(self.data.name, rotation=270,
cbar.set_label(self.data.name, rotation=90,
verticalalignment='bottom')

self._x_var = x
Expand Down
83 changes: 21 additions & 62 deletions xray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
import numpy as np
import pandas as pd

from .utils import _determine_cmap_params
from .utils import _determine_cmap_params, _infer_xy_labels
from .facetgrid import FacetGrid
from ..core.utils import is_uniform_spaced


# Maybe more appropriate to keep this in .utils
Expand All @@ -39,47 +38,6 @@ def _ensure_plottable(*args):
'or dates.')


def _infer_xy_labels(plotfunc, darray, x, y):
"""
Determine x and y labels when some are missing. For use in _plot2d

darray is a 2 dimensional data array.
"""
dims = list(darray.dims)

if len(dims) != 2:
raise ValueError('{type} plots are for 2 dimensional DataArrays. '
'Passed DataArray has {ndim} dimensions'
.format(type=plotfunc.__name__, ndim=len(dims)))

if x and x not in dims:
raise KeyError('{0} is not a dimension of this DataArray. Use '
'{1} or {2} for x'
.format(x, *dims))

if y and y not in dims:
raise KeyError('{0} is not a dimension of this DataArray. Use '
'{1} or {2} for y'
.format(y, *dims))

# Get label names
if x and y:
xlab = x
ylab = y
elif x and not y:
xlab = x
del dims[dims.index(x)]
ylab = dims.pop()
elif y and not x:
ylab = y
del dims[dims.index(y)]
xlab = dims.pop()
else:
ylab, xlab = dims

return xlab, ylab


def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, col_wrap=None,
aspect=1, size=3, subplot_kws=None, **kwargs):
"""
Expand All @@ -99,19 +57,18 @@ def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, col_wrap=None,
def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01,
subplot_kws=None, **kwargs):
"""
Default plot of DataArray using matplotlib / pylab.
Default plot of DataArray using matplotlib.pyplot.

Calls xray plotting function based on the dimensions of
darray.squeeze()

=============== =========== ===========================
Dimensions Coordinates Plotting function
--------------- ----------- ---------------------------
1 :py:func:`xray.plot.line`
2 Uniform :py:func:`xray.plot.imshow`
2 Irregular :py:func:`xray.plot.contourf`
Anything else :py:func:`xray.plot.hist`
=============== =========== ===========================
=============== ===========================
Dimensions Plotting function
--------------- ---------------------------
1 :py:func:`xray.plot.line`
2 :py:func:`xray.plot.pcolormesh`
Anything else :py:func:`xray.plot.hist`
=============== ===========================

Parameters
----------
Expand Down Expand Up @@ -156,9 +113,7 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01,
kwargs['col_wrap'] = col_wrap
kwargs['subplot_kws'] = subplot_kws

indexes = (darray.indexes[dim].values for dim in plot_dims)
uniform = all(is_uniform_spaced(i, rtol=rtol) for i in indexes)
plotfunc = imshow if uniform else contourf
plotfunc = pcolormesh
else:
if row or col:
raise ValueError(error_msg)
Expand Down Expand Up @@ -376,7 +331,7 @@ def _plot2d(plotfunc):

@functools.wraps(plotfunc)
def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
col_wrap=None, xincrease=None, yincrease=None,
col_wrap=None, xincrease=True, yincrease=True,
add_colorbar=True, add_labels=True, vmin=None, vmax=None,
cmap=None, center=None, robust=False, extend=None,
levels=None, colors=None, subplot_kws=None, **kwargs):
Expand Down Expand Up @@ -416,8 +371,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
if ax is None:
ax = plt.gca()

xlab, ylab = _infer_xy_labels(plotfunc=plotfunc, darray=darray,
x=x, y=y)
xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y)

# better to pass the ndarrays directly to plotting functions
xval = darray[xlab].values
Expand Down Expand Up @@ -471,7 +425,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
if add_colorbar:
cbar = plt.colorbar(primitive, ax=ax, extend=cmap_params['extend'])
if darray.name and add_labels:
cbar.set_label(darray.name)
cbar.set_label(darray.name, rotation=90)

_update_axes_limits(ax, xincrease, yincrease)

Expand All @@ -480,7 +434,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
# For use as DataArray.plot.plotmethod
@functools.wraps(newplotfunc)
def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None,
col=None, col_wrap=None, xincrease=None, yincrease=None,
col=None, col_wrap=None, xincrease=True, yincrease=True,
add_colorbar=True, add_labels=True, vmin=None, vmax=None,
cmap=None, colors=None, center=None, robust=False,
extend=None, levels=None, subplot_kws=None, **kwargs):
Expand All @@ -506,7 +460,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, ax=None, row=None,
@_plot2d
def imshow(x, y, z, ax, **kwargs):
"""
Image plot of 2d DataArray using matplotlib / pylab
Image plot of 2d DataArray using matplotlib.pyplot

Wraps matplotlib.pyplot.imshow

Expand All @@ -518,6 +472,11 @@ def imshow(x, y, z, ax, **kwargs):
The pixels are centered on the coordinates values. Ie, if the coordinate
value is 3.2 then the pixels for those coordinates will be centered on 3.2.
"""

if x.ndim != 1 or y.ndim != 1:
raise ValueError('imshow requires 1D coordinates, try using '
'pcolormesh or contour(f)')

# Centering the pixels- Assumes uniform spacing
xstep = (x[1] - x[0]) / 2.0
ystep = (y[1] - y[0]) / 2.0
Expand Down Expand Up @@ -589,7 +548,7 @@ def pcolormesh(x, y, z, ax, **kwargs):

# by default, pcolormesh picks "round" values for bounds
# this results in ugly looking plots with lots of surrounding whitespace
if not hasattr(ax, 'projection'):
if not hasattr(ax, 'projection') and x.ndim == 1 and y.ndim == 1:
# not a cartopy geoaxis
ax.set_xlim(x[0], x[-1])
ax.set_ylim(y[0], y[-1])
Expand Down
18 changes: 18 additions & 0 deletions xray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,21 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,

return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend,
levels=levels, cnorm=cnorm)


def _infer_xy_labels(darray, x, y):
"""
Determine x and y labels. For use in _plot2d

darray must be a 2 dimensional data array.
"""

if x is None and y is None:
if darray.ndim != 2:
raise ValueError('DataArray must be 2d')
y, x = darray.dims
elif x is None or y is None:
raise ValueError('cannot supply only one of x and y')
elif any(k not in darray.coords for k in (x, y)):
raise ValueError('x and y must be coordinate variables')
return x, y
Loading