Skip to content

Commit 83d2c2d

Browse files
willirathshoyer
authored andcommitted
Allow for plotting dummy netCDF4.datetime objects. (#1261)
* Allow for plotting dummy netCDF4.datetime objects. * Add decorator to skip if netCDF4 is missing. * Test line plot with netCDF4.datetime coordinate. * Rewrite checks for plottable data. Allow for numeric and date-like numpy data types and for datetime objects. * Address review by @shoyer - Use datetime instead of netCDF4.datetime. - Do not warn but assert in private function. * Line continuation after parenthesis. * Check for generic dtype to avoid expensive loop
1 parent 1cad803 commit 83d2c2d

File tree

2 files changed

+49
-13
lines changed

2 files changed

+49
-13
lines changed

xarray/plot/plot.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,48 @@
1313

1414
import numpy as np
1515
import pandas as pd
16+
from datetime import datetime
1617

1718
from .utils import _determine_cmap_params, _infer_xy_labels, get_axis
1819
from .facetgrid import FacetGrid
1920
from xarray.core.pycompat import basestring
2021

2122

22-
# Maybe more appropriate to keep this in .utils
23-
def _right_dtype(arr, types):
23+
def _valid_numpy_subdtype(x, numpy_types):
2424
"""
25-
Is the numpy array a sub dtype of anything in types?
25+
Is any dtype from numpy_types superior to the dtype of x?
2626
"""
27-
return any(np.issubdtype(arr.dtype, t) for t in types)
27+
# If any of the types given in numpy_types is understood as numpy.generic,
28+
# all possible x will be considered valid. This is probably unwanted.
29+
for t in numpy_types:
30+
assert not np.issubdtype(np.generic, t)
2831

32+
return any(np.issubdtype(x.dtype, t) for t in numpy_types)
2933

30-
def _ensure_plottable(*args):
34+
35+
def _valid_other_type(x, types):
3136
"""
32-
Raise exception if there is anything in args that can't be plotted on
33-
an axis.
37+
Do all elements of x have a type from types?
3438
"""
35-
plottypes = [np.floating, np.integer, np.timedelta64, np.datetime64]
39+
if not np.issubdtype(np.generic, x.dtype):
40+
return False
41+
else:
42+
return all(any(isinstance(el, t) for t in types) for el in np.ravel(x))
3643

37-
# Lists need to be converted to np.arrays here.
38-
if not any(_right_dtype(np.array(x), plottypes) for x in args):
39-
raise TypeError('Plotting requires coordinates to be numeric '
40-
'or dates.')
4144

45+
def _ensure_plottable(*args):
46+
"""
47+
Raise exception if there is anything in args that can't be plotted on an
48+
axis.
49+
"""
50+
numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64]
51+
other_types = [datetime]
52+
53+
for x in args:
54+
if not (_valid_numpy_subdtype(np.array(x), numpy_types)
55+
or _valid_other_type(np.array(x), other_types)):
56+
raise TypeError('Plotting requires coordinates to be numeric '
57+
'or dates.')
4258

4359
def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None,
4460
col_wrap=None, sharex=True, sharey=True, aspect=None,

xarray/tests/test_plot.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
import pandas as pd
19+
from datetime import datetime
1920

2021
from xarray import DataArray
2122

@@ -1192,7 +1193,7 @@ def test_facetgrid_polar(self):
11921193
subplot_kws=dict(projection='polar'),
11931194
sharex=False, sharey=False)
11941195

1195-
1196+
11961197
class TestFacetGrid4d(PlotTestCase):
11971198

11981199
def setUp(self):
@@ -1218,3 +1219,22 @@ def test_default_labels(self):
12181219
# Top row should be labeled
12191220
for label, ax in zip(self.darray.coords['col'].values, g.axes[0, :]):
12201221
self.assertTrue(substring_in_axes(label, ax))
1222+
1223+
1224+
class TestDatetimePlot(PlotTestCase):
1225+
1226+
def setUp(self):
1227+
'''
1228+
Create a DataArray with a time-axis that contains datetime objects.
1229+
'''
1230+
month = np.arange(1, 13, 1)
1231+
data = np.sin(2 * np.pi * month / 12.0)
1232+
1233+
darray = DataArray(data, dims=['time'])
1234+
darray.coords['time'] = np.array([datetime(2017, m, 1) for m in month])
1235+
1236+
self.darray = darray
1237+
1238+
def test_datetime_line_plot(self):
1239+
# test if line plot raises no Exception
1240+
self.darray.plot.line()

0 commit comments

Comments
 (0)