Skip to content

Commit f75c3be

Browse files
znicholls and Illviljan authored
Numpy string coding (#5264)
* Add failing test
* Try fix
* Lint
* Require netCDF4 for test
* Move fix to infer dtype
* Update thanks to @shoyer
* Whats new
* Move test and add comment
* Update whats-new.rst
* Update whats-new.rst

Co-authored-by: Illviljan <[email protected]>
1 parent b14e2d8 commit f75c3be

File tree

5 files changed

+38
-2
lines changed

5 files changed

+38
-2
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ Deprecations
3737

3838
Bug fixes
3939
~~~~~~~~~
40+
- Subclasses of ``bytes`` and ``str`` (e.g. ``np.str_`` and ``np.bytes_``) will now serialise to disk rather than raising a ``ValueError: unsupported dtype for netCDF4 variable: object`` as they did previously (:pull:`5264`).
41+
By `Zeb Nicholls <https://github.com/znicholls>`_.
42+
4043
- Fix applying function with non-xarray arguments using :py:func:`xr.map_blocks`.
4144
By `Cindy Chiao <https://github.com/tcchiao>`_.
4245

xarray/coding/strings.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818

1919
def create_vlen_dtype(element_type):
20+
if element_type not in (str, bytes):
21+
raise TypeError("unsupported type for vlen_dtype: {!r}".format(element_type))
2022
# based on h5py.special_dtype
2123
return np.dtype("O", metadata={"element_type": element_type})
2224

xarray/conventions.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,12 @@ def _infer_dtype(array, name=None):
157157
return np.dtype(float)
158158

159159
element = array[(0,) * array.ndim]
160-
if isinstance(element, (bytes, str)):
161-
return strings.create_vlen_dtype(type(element))
160+
# We use the base types to avoid subclasses of bytes and str (which might
161+
# not play nice with e.g. hdf5 datatypes), such as those from numpy
162+
if isinstance(element, bytes):
163+
return strings.create_vlen_dtype(bytes)
164+
elif isinstance(element, str):
165+
return strings.create_vlen_dtype(str)
162166

163167
dtype = np.array(element).dtype
164168
if dtype.kind != "O":

xarray/tests/test_backends.py

+21
Original file line numberDiff line numberDiff line change
@@ -5393,3 +5393,24 @@ def test_h5netcdf_entrypoint(tmp_path):
53935393
assert entrypoint.guess_can_open("something-local.nc4")
53945394
assert entrypoint.guess_can_open("something-local.cdf")
53955395
assert not entrypoint.guess_can_open("not-found-and-no-extension")
5396+
5397+
5398+
@requires_netCDF4
5399+
@pytest.mark.parametrize("str_type", (str, np.str_))
5400+
def test_write_file_from_np_str(str_type, tmpdir) -> None:
5401+
# https://github.com/pydata/xarray/pull/5264
5402+
scenarios = [str_type(v) for v in ["scenario_a", "scenario_b", "scenario_c"]]
5403+
years = range(2015, 2100 + 1)
5404+
tdf = pd.DataFrame(
5405+
data=np.random.random((len(scenarios), len(years))),
5406+
columns=years,
5407+
index=scenarios,
5408+
)
5409+
tdf.index.name = "scenario"
5410+
tdf.columns.name = "year"
5411+
tdf = tdf.stack()
5412+
tdf.name = "tas"
5413+
5414+
txr = tdf.to_xarray()
5415+
5416+
txr.to_netcdf(tmpdir.join("test.nc"))

xarray/tests/test_coding_strings.py

+6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def test_vlen_dtype() -> None:
2929
assert strings.check_vlen_dtype(np.dtype(object)) is None
3030

3131

32+
@pytest.mark.parametrize("numpy_str_type", (np.str_, np.bytes_))
33+
def test_numpy_subclass_handling(numpy_str_type) -> None:
34+
with pytest.raises(TypeError, match="unsupported type for vlen_dtype"):
35+
strings.create_vlen_dtype(numpy_str_type)
36+
37+
3238
def test_EncodedStringCoder_decode() -> None:
3339
coder = strings.EncodedStringCoder()
3440

0 commit comments

Comments
 (0)