Skip to content

Commit 69086b3

Browse files
NotSqrtshoyer
authored andcommitted
Fix maybe_promote (#1953)
* Fix maybe_promote With tests for every possible dtype: (numpy docs say `biufcmMOSUV` only) ``` for letter in string.ascii_letters: try: print(letter, np.dtype(letter)) except TypeError as exc: pass ``` * Check issubdtype of floating before timedelta64 In order to hit this branch more often * Improve maybe_promote test
1 parent 8378d3a commit 69086b3

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

xarray/core/dtypes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ def maybe_promote(dtype):
8080
# N.B. these casting rules should match pandas
8181
if np.issubdtype(dtype, np.floating):
8282
fill_value = np.nan
83+
elif np.issubdtype(dtype, np.timedelta64):
84+
# See https://github.com/numpy/numpy/issues/10685
85+
# np.timedelta64 is a subclass of np.integer
86+
# Check np.timedelta64 before np.integer
87+
fill_value = np.timedelta64('NaT')
8388
elif np.issubdtype(dtype, np.integer):
8489
if dtype.itemsize <= 2:
8590
dtype = np.float32
@@ -90,8 +95,6 @@ def maybe_promote(dtype):
9095
fill_value = np.nan + np.nan * 1j
9196
elif np.issubdtype(dtype, np.datetime64):
9297
fill_value = np.datetime64('NaT')
93-
elif np.issubdtype(dtype, np.timedelta64):
94-
fill_value = np.timedelta64('NaT')
9598
else:
9699
dtype = object
97100
fill_value = np.nan

xarray/tests/test_dtypes.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,39 @@ def error():
5050
def test_inf(obj):
5151
assert dtypes.INF > obj
5252
assert dtypes.NINF < obj
53+
54+
55+
@pytest.mark.parametrize("kind, expected", [
56+
('a', (np.dtype('O'), 'nan')), # dtype('S')
57+
('b', (np.float32, 'nan')), # dtype('int8')
58+
('B', (np.float32, 'nan')), # dtype('uint8')
59+
('c', (np.dtype('O'), 'nan')), # dtype('S1')
60+
('D', (np.complex128, '(nan+nanj)')), # dtype('complex128')
61+
('d', (np.float64, 'nan')), # dtype('float64')
62+
('e', (np.float16, 'nan')), # dtype('float16')
63+
('F', (np.complex64, '(nan+nanj)')), # dtype('complex64')
64+
('f', (np.float32, 'nan')), # dtype('float32')
65+
('h', (np.float32, 'nan')), # dtype('int16')
66+
('H', (np.float32, 'nan')), # dtype('uint16')
67+
('i', (np.float64, 'nan')), # dtype('int32')
68+
('I', (np.float64, 'nan')), # dtype('uint32')
69+
('l', (np.float64, 'nan')), # dtype('int64')
70+
('L', (np.float64, 'nan')), # dtype('uint64')
71+
('m', (np.timedelta64, 'NaT')), # dtype('<m8')
72+
('M', (np.datetime64, 'NaT')), # dtype('<M8')
73+
('O', (np.dtype('O'), 'nan')), # dtype('O')
74+
('p', (np.float64, 'nan')), # dtype('int64')
75+
('P', (np.float64, 'nan')), # dtype('uint64')
76+
('q', (np.float64, 'nan')), # dtype('int64')
77+
('Q', (np.float64, 'nan')), # dtype('uint64')
78+
('S', (np.dtype('O'), 'nan')), # dtype('S')
79+
('U', (np.dtype('O'), 'nan')), # dtype('<U')
80+
('V', (np.dtype('O'), 'nan')), # dtype('V')
81+
])
82+
def test_maybe_promote(kind, expected):
83+
# 'g': np.float128 is not tested : not available on all platforms
84+
# 'G': np.complex256 is not tested : not available on all platforms
85+
86+
actual = dtypes.maybe_promote(np.dtype(kind))
87+
assert actual[0] == expected[0]
88+
assert str(actual[1]) == expected[1]

0 commit comments

Comments
 (0)