|
13 | 13 |
|
14 | 14 | import numpy as np
|
15 | 15 | import pandas as pd
|
| 16 | +from datetime import datetime |
16 | 17 |
|
17 | 18 | from .utils import _determine_cmap_params, _infer_xy_labels, get_axis
|
18 | 19 | from .facetgrid import FacetGrid
|
19 | 20 | from xarray.core.pycompat import basestring
|
20 | 21 |
|
21 | 22 |
|
22 |
| -# Maybe more appropriate to keep this in .utils |
23 |
| -def _right_dtype(arr, types): |
| 23 | +def _valid_numpy_subdtype(x, numpy_types): |
24 | 24 | """
|
25 |
| - Is the numpy array a sub dtype of anything in types? |
| 25 | + Is any dtype from numpy_types superior to the dtype of x? |
26 | 26 | """
|
27 |
| - return any(np.issubdtype(arr.dtype, t) for t in types) |
| 27 | + # If any of the types given in numpy_types is understood as numpy.generic, |
| 28 | + # all possible x will be considered valid. This is probably unwanted. |
| 29 | + for t in numpy_types: |
| 30 | + assert not np.issubdtype(np.generic, t) |
28 | 31 |
|
| 32 | + return any(np.issubdtype(x.dtype, t) for t in numpy_types) |
29 | 33 |
|
30 |
| -def _ensure_plottable(*args): |
| 34 | + |
| 35 | +def _valid_other_type(x, types): |
31 | 36 | """
|
32 |
| - Raise exception if there is anything in args that can't be plotted on |
33 |
| - an axis. |
| 37 | + Do all elements of x have a type from types? |
34 | 38 | """
|
35 |
| - plottypes = [np.floating, np.integer, np.timedelta64, np.datetime64] |
| 39 | + if not np.issubdtype(np.generic, x.dtype): |
| 40 | + return False |
| 41 | + else: |
| 42 | + return all(any(isinstance(el, t) for t in types) for el in np.ravel(x)) |
36 | 43 |
|
37 |
| - # Lists need to be converted to np.arrays here. |
38 |
| - if not any(_right_dtype(np.array(x), plottypes) for x in args): |
39 |
| - raise TypeError('Plotting requires coordinates to be numeric ' |
40 |
| - 'or dates.') |
41 | 44 |
|
| 45 | +def _ensure_plottable(*args): |
| 46 | + """ |
| 47 | + Raise exception if there is anything in args that can't be plotted on an |
| 48 | + axis. |
| 49 | + """ |
| 50 | + numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] |
| 51 | + other_types = [datetime] |
| 52 | + |
| 53 | + for x in args: |
| 54 | + if not (_valid_numpy_subdtype(np.array(x), numpy_types) |
| 55 | + or _valid_other_type(np.array(x), other_types)): |
| 56 | + raise TypeError('Plotting requires coordinates to be numeric ' |
| 57 | + 'or dates.') |
42 | 58 |
|
43 | 59 | def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None,
|
44 | 60 | col_wrap=None, sharex=True, sharey=True, aspect=None,
|
|
0 commit comments