diff --git a/xarray/conventions.py b/xarray/conventions.py
index 75f816e6cb4..8c7d6be2309 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -52,16 +52,35 @@ def _var_as_tuple(var: Variable) -> T_VarTuple:
     return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
 
 
 def _infer_dtype(array, name: T_Name = None) -> np.dtype:
-    """Given an object array with no missing values, infer its dtype from its
-    first element
-    """
+    """Infer the dtype of an object array that has no missing values.
+
+    Every element is inspected: all elements must share one native Python
+    type, except that a mix of ``bytes`` and ``str`` is tolerated. An empty
+    array is reported as ``float``.
+
+    Raises
+    ------
+    TypeError
+        If ``array`` is not a ``dtype=object`` array.
+    ValueError
+        If the elements have mixed native types.
+    """
     if array.dtype.kind != "O":
         raise TypeError("infer_type must be called on a dtype=object array")
 
     if array.size == 0:
         return np.dtype(float)
 
+    native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
+    if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
+        # Sort the names so the error message does not depend on set order.
+        type_names = ", ".join(sorted(x.__name__ for x in native_dtypes))
+        raise ValueError(
+            f"unable to infer dtype on variable {name!r}; object array "
+            f"contains mixed native types: {type_names}"
+        )
+
     element = array[(0,) * array.ndim]
     # We use the base types to avoid subclasses of bytes and str (which might
     # not play nice with e.g. hdf5 datatypes), such as those from numpy
diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py
index d6d1303a696..be6e949edf8 100644
--- a/xarray/tests/test_conventions.py
+++ b/xarray/tests/test_conventions.py
@@ -495,6 +495,18 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
         pass
 
 
+@pytest.mark.parametrize(
+    "data",
+    [
+        np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
+        np.array([["x", 1], ["y", 2]], dtype="object"),
+    ],
+)
+def test_infer_dtype_error_on_mixed_types(data) -> None:
+    with pytest.raises(ValueError, match="unable to infer dtype on variable"):
+        conventions._infer_dtype(data, "test")
+
+
 class TestDecodeCFVariableWithArrayUnits:
     def test_decode_cf_variable_with_array_units(self) -> None:
         v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})