Skip to content

Commit 603911d

Browse files
authored
Support for additional axis kwargs (#2294)
1. Support xscale, yscale, xticks, yticks, xlim, ylim kwargs. 2. Small fixes: 1. Forgot to replace autofmt_xdate for 2D plots. 2. Use matplotlib's axis inverting methods. 3. Don't automatically set histogram ylabel to be 'count'.
1 parent 5d8670f commit 603911d

File tree

4 files changed

+149
-28
lines changed

4 files changed

+149
-28
lines changed

doc/plotting.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,6 @@ If required, the automatic legend can be turned off using ``add_legend=False``.
212212
``hue`` can be passed directly to :py:func:`xarray.plot` as `air.isel(lon=10, lat=[19,21,22]).plot(hue='lat')`.
213213

214214

215-
216-
217215
Dimension along y-axis
218216
~~~~~~~~~~~~~~~~~~~~~~
219217

@@ -224,8 +222,8 @@ It is also possible to make line plots such that the data are on the x-axis and
224222
@savefig plotting_example_xy_kwarg.png
225223
air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon')
226224
227-
Changing Axes Direction
228-
-----------------------
225+
Other axes kwargs
226+
-----------------
229227

230228
The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction.
231229

@@ -234,6 +232,9 @@ The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes d
234232
@savefig plotting_example_xincrease_yincrease_kwarg.png
235233
air.isel(time=10, lon=[10, 11]).plot.line(y='lat', hue='lon', xincrease=False, yincrease=False)
236234
235+
In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, ``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively.
236+
237+
237238
Two Dimensions
238239
--------------
239240

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ Documentation
3636
Enhancements
3737
~~~~~~~~~~~~
3838

39+
- :py:meth:`plot()` now accepts the kwargs ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits.
40+
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`2224`)
41+
3942
- DataArray coordinates and Dataset coordinates and data variables are
4043
now displayed as `a b ... y z` rather than `a b c d ...`.
4144
(:issue:`1186`)

xarray/plot/plot.py

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ def line(darray, *args, **kwargs):
270270
Coordinates for x, y axis. Only one of these may be specified.
271271
The other coordinate plots values from the DataArray on which this
272272
plot method is called.
273+
xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional
274+
Specifies scaling for the x- and y-axes respectively
275+
xticks, yticks : Specify tick locations for x- and y-axes
276+
xlim, ylim : Specify x- and y-axes limits
273277
xincrease : None, True, or False, optional
274278
Should the values on the x axes be increasing from left to right?
275279
if None, use the default for the matplotlib function.
@@ -305,8 +309,14 @@ def line(darray, *args, **kwargs):
305309
hue = kwargs.pop('hue', None)
306310
x = kwargs.pop('x', None)
307311
y = kwargs.pop('y', None)
308-
xincrease = kwargs.pop('xincrease', True)
309-
yincrease = kwargs.pop('yincrease', True)
312+
xincrease = kwargs.pop('xincrease', None) # default needs to be None
313+
yincrease = kwargs.pop('yincrease', None)
314+
xscale = kwargs.pop('xscale', None) # default needs to be None
315+
yscale = kwargs.pop('yscale', None)
316+
xticks = kwargs.pop('xticks', None)
317+
yticks = kwargs.pop('yticks', None)
318+
xlim = kwargs.pop('xlim', None)
319+
ylim = kwargs.pop('ylim', None)
310320
add_legend = kwargs.pop('add_legend', True)
311321
_labels = kwargs.pop('_labels', True)
312322
if args is ():
@@ -343,7 +353,8 @@ def line(darray, *args, **kwargs):
343353
xlabels.set_rotation(30)
344354
xlabels.set_ha('right')
345355

346-
_update_axes_limits(ax, xincrease, yincrease)
356+
_update_axes(ax, xincrease, yincrease, xscale, yscale,
357+
xticks, yticks, xlim, ylim)
347358

348359
return primitive
349360

@@ -378,37 +389,69 @@ def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs):
378389
"""
379390
ax = get_axis(figsize, size, aspect, ax)
380391

392+
xincrease = kwargs.pop('xincrease', None) # default needs to be None
393+
yincrease = kwargs.pop('yincrease', None)
394+
xscale = kwargs.pop('xscale', None) # default needs to be None
395+
yscale = kwargs.pop('yscale', None)
396+
xticks = kwargs.pop('xticks', None)
397+
yticks = kwargs.pop('yticks', None)
398+
xlim = kwargs.pop('xlim', None)
399+
ylim = kwargs.pop('ylim', None)
400+
381401
no_nan = np.ravel(darray.values)
382402
no_nan = no_nan[pd.notnull(no_nan)]
383403

384404
primitive = ax.hist(no_nan, **kwargs)
385405

386-
ax.set_ylabel('Count')
387-
388406
ax.set_title('Histogram')
389407
ax.set_xlabel(label_from_attrs(darray))
390408

409+
_update_axes(ax, xincrease, yincrease, xscale, yscale,
410+
xticks, yticks, xlim, ylim)
411+
391412
return primitive
392413

393414

394-
def _update_axes_limits(ax, xincrease, yincrease):
415+
def _update_axes(ax, xincrease, yincrease,
416+
xscale=None, yscale=None,
417+
xticks=None, yticks=None,
418+
xlim=None, ylim=None):
395419
"""
396-
Update axes in place to increase or decrease
397-
For use in _plot2d
420+
Update axes with provided parameters
398421
"""
399422
if xincrease is None:
400423
pass
401-
elif xincrease:
402-
ax.set_xlim(sorted(ax.get_xlim()))
403-
elif not xincrease:
404-
ax.set_xlim(sorted(ax.get_xlim(), reverse=True))
424+
elif xincrease and ax.xaxis_inverted():
425+
ax.invert_xaxis()
426+
elif not xincrease and not ax.xaxis_inverted():
427+
ax.invert_xaxis()
405428

406429
if yincrease is None:
407430
pass
408-
elif yincrease:
409-
ax.set_ylim(sorted(ax.get_ylim()))
410-
elif not yincrease:
411-
ax.set_ylim(sorted(ax.get_ylim(), reverse=True))
431+
elif yincrease and ax.yaxis_inverted():
432+
ax.invert_yaxis()
433+
elif not yincrease and not ax.yaxis_inverted():
434+
ax.invert_yaxis()
435+
436+
# The default xscale, yscale needs to be None.
437+
# If we set a scale it resets the axes formatters,
438+
# This means that set_xscale('linear') on a datetime axis
439+
# will remove the date labels. So only set the scale when explicitly
440+
# asked to. https://github.com/matplotlib/matplotlib/issues/8740
441+
if xscale is not None:
442+
ax.set_xscale(xscale)
443+
if yscale is not None:
444+
ax.set_yscale(yscale)
445+
446+
if xticks is not None:
447+
ax.set_xticks(xticks)
448+
if yticks is not None:
449+
ax.set_yticks(yticks)
450+
451+
if xlim is not None:
452+
ax.set_xlim(xlim)
453+
if ylim is not None:
454+
ax.set_ylim(ylim)
412455

413456

414457
# MUST run before any 2d plotting functions are defined since
@@ -500,6 +543,10 @@ def _plot2d(plotfunc):
500543
If passed, make column faceted plots on this dimension name
501544
col_wrap : integer, optional
502545
Use together with ``col`` to wrap faceted plots
546+
xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional
547+
Specifies scaling for the x- and y-axes respectively
548+
xticks, yticks : Specify tick locations for x- and y-axes
549+
xlim, ylim : Specify x- and y-axes limits
503550
xincrease : None, True, or False, optional
504551
Should the values on the x axes be increasing from left to right?
505552
if None, use the default for the matplotlib function.
@@ -577,7 +624,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
577624
cmap=None, center=None, robust=False, extend=None,
578625
levels=None, infer_intervals=None, colors=None,
579626
subplot_kws=None, cbar_ax=None, cbar_kwargs=None,
580-
**kwargs):
627+
xscale=None, yscale=None, xticks=None, yticks=None,
628+
xlim=None, ylim=None, **kwargs):
581629
# All 2d plots in xarray share this function signature.
582630
# Method signature below should be consistent.
583631

@@ -723,11 +771,17 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
723771
raise ValueError("cbar_ax and cbar_kwargs can't be used with "
724772
"add_colorbar=False.")
725773

726-
_update_axes_limits(ax, xincrease, yincrease)
774+
_update_axes(ax, xincrease, yincrease, xscale, yscale,
775+
xticks, yticks, xlim, ylim)
727776

728777
# Rotate dates on xlabels
778+
# Do this without calling autofmt_xdate so that x-axes ticks
779+
# on other subplots (if any) are not deleted.
780+
# https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
729781
if np.issubdtype(xval.dtype, np.datetime64):
730-
ax.get_figure().autofmt_xdate()
782+
for xlabels in ax.get_xticklabels():
783+
xlabels.set_rotation(30)
784+
xlabels.set_ha('right')
731785

732786
return primitive
733787

@@ -739,7 +793,9 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None,
739793
add_labels=True, vmin=None, vmax=None, cmap=None,
740794
colors=None, center=None, robust=False, extend=None,
741795
levels=None, infer_intervals=None, subplot_kws=None,
742-
cbar_ax=None, cbar_kwargs=None, **kwargs):
796+
cbar_ax=None, cbar_kwargs=None,
797+
xscale=None, yscale=None, xticks=None, yticks=None,
798+
xlim=None, ylim=None, **kwargs):
743799
"""
744800
The method should have the same signature as the function.
745801

xarray/tests/test_plot.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,6 @@ def test_xlabel_uses_name(self):
426426
self.darray.plot.hist()
427427
assert 'testpoints [testunits]' == plt.gca().get_xlabel()
428428

429-
def test_ylabel_is_count(self):
430-
self.darray.plot.hist()
431-
assert 'Count' == plt.gca().get_ylabel()
432-
433429
def test_title_is_histogram(self):
434430
self.darray.plot.hist()
435431
assert 'Histogram' == plt.gca().get_title()
@@ -1675,3 +1671,68 @@ def test_plot_cftime_data_error():
16751671
data = DataArray(data, coords=[np.arange(5)], dims=['x'])
16761672
with raises_regex(NotImplementedError, 'cftime.datetime'):
16771673
data.plot()
1674+
1675+
1676+
test_da_list = [DataArray(easy_array((10, ))),
1677+
DataArray(easy_array((10, 3))),
1678+
DataArray(easy_array((10, 3, 2)))]
1679+
1680+
1681+
@requires_matplotlib
1682+
class TestAxesKwargs(object):
1683+
@pytest.mark.parametrize('da', test_da_list)
1684+
@pytest.mark.parametrize('xincrease', [True, False])
1685+
def test_xincrease_kwarg(self, da, xincrease):
1686+
plt.clf()
1687+
da.plot(xincrease=xincrease)
1688+
assert plt.gca().xaxis_inverted() == (not xincrease)
1689+
1690+
@pytest.mark.parametrize('da', test_da_list)
1691+
@pytest.mark.parametrize('yincrease', [True, False])
1692+
def test_yincrease_kwarg(self, da, yincrease):
1693+
plt.clf()
1694+
da.plot(yincrease=yincrease)
1695+
assert plt.gca().yaxis_inverted() == (not yincrease)
1696+
1697+
@pytest.mark.parametrize('da', test_da_list)
1698+
@pytest.mark.parametrize('xscale', ['linear', 'log', 'logit', 'symlog'])
1699+
def test_xscale_kwarg(self, da, xscale):
1700+
plt.clf()
1701+
da.plot(xscale=xscale)
1702+
assert plt.gca().get_xscale() == xscale
1703+
1704+
@pytest.mark.parametrize('da', [DataArray(easy_array((10, ))),
1705+
DataArray(easy_array((10, 3)))])
1706+
@pytest.mark.parametrize('yscale', ['linear', 'log', 'logit', 'symlog'])
1707+
def test_yscale_kwarg(self, da, yscale):
1708+
plt.clf()
1709+
da.plot(yscale=yscale)
1710+
assert plt.gca().get_yscale() == yscale
1711+
1712+
@pytest.mark.parametrize('da', test_da_list)
1713+
def test_xlim_kwarg(self, da):
1714+
plt.clf()
1715+
expected = (0.0, 1000.0)
1716+
da.plot(xlim=[0, 1000])
1717+
assert plt.gca().get_xlim() == expected
1718+
1719+
@pytest.mark.parametrize('da', test_da_list)
1720+
def test_ylim_kwarg(self, da):
1721+
plt.clf()
1722+
da.plot(ylim=[0, 1000])
1723+
expected = (0.0, 1000.0)
1724+
assert plt.gca().get_ylim() == expected
1725+
1726+
@pytest.mark.parametrize('da', test_da_list)
1727+
def test_xticks_kwarg(self, da):
1728+
plt.clf()
1729+
da.plot(xticks=np.arange(5))
1730+
expected = np.arange(5).tolist()
1731+
assert np.all(plt.gca().get_xticks() == expected)
1732+
1733+
@pytest.mark.parametrize('da', test_da_list)
1734+
def test_yticks_kwarg(self, da):
1735+
plt.clf()
1736+
da.plot(yticks=np.arange(5))
1737+
expected = np.arange(5)
1738+
assert np.all(plt.gca().get_yticks() == expected)

0 commit comments

Comments
 (0)