Skip to content

coords: retain str dtype #4759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 13, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
@@ -66,6 +66,9 @@ Bug fixes
By `Anderson Banihirwe <https://github.com/andersy005>`_
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`).
By `Alessandro Amici <https://github.com/alexamici>`_
- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations,
e.g. ``reindex``, ``align``, ``concat``, ``assign``, previously they were cast to an object dtype
(:issue:`2658` and :issue:`4543`) by `Mathias Hauser <https://github.com/mathause>`_.
- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling <https://github.com/illviljan>`_.
- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo <https://github.com/mesejo>`_.
- Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`).
12 changes: 8 additions & 4 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@

from . import dtypes, utils
from .indexing import get_indexer_nd
from .utils import is_dict_like, is_full_slice
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str
from .variable import IndexVariable, Variable

if TYPE_CHECKING:
@@ -278,10 +278,12 @@ def align(
return (obj.copy(deep=copy),)

all_indexes = defaultdict(list)
all_coords = defaultdict(list)
unlabeled_dim_sizes = defaultdict(set)
for obj in objects:
for dim in obj.dims:
if dim not in exclude:
all_coords[dim].append(obj.coords[dim])
try:
index = obj.indexes[dim]
except KeyError:
@@ -306,7 +308,7 @@ def align(
any(not index.equals(other) for other in matching_indexes)
or dim in unlabeled_dim_sizes
):
joined_indexes[dim] = index
joined_indexes[dim] = indexes[dim]
else:
if (
any(
@@ -318,9 +320,11 @@ def align(
if join == "exact":
raise ValueError(f"indexes along dimension {dim!r} are not equal")
index = joiner(matching_indexes)
# make sure str coords are not cast to object
index = maybe_coerce_to_str(index, all_coords[dim])
joined_indexes[dim] = index
else:
index = matching_indexes[0]
index = all_coords[dim][0]

if dim in unlabeled_dim_sizes:
unlabeled_sizes = unlabeled_dim_sizes[dim]
@@ -583,7 +587,7 @@ def reindex_variables(
args: tuple = (var.attrs, var.encoding)
else:
args = ()
reindexed[dim] = IndexVariable((dim,), target, *args)
reindexed[dim] = IndexVariable((dim,), indexers[dim], *args)

for dim in sizes:
if dim not in indexes and dim in indexers:
4 changes: 2 additions & 2 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
@@ -187,7 +187,7 @@ def concat(
array([[0, 1, 2],
[3, 4, 5]])
Coordinates:
* x (x) object 'a' 'b'
* x (x) <U1 'a' 'b'
* y (y) int64 10 20 30

>>> xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim")
@@ -503,7 +503,7 @@ def ensure_common_dims(vars):
for k in datasets[0].variables:
if k in concat_over:
try:
vars = ensure_common_dims([ds.variables[k] for ds in datasets])
vars = ensure_common_dims([ds[k].variable for ds in datasets])
except KeyError:
raise ValueError("%r is not present in all datasets." % k)
combined = concat_vars(vars, dim, positions)
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1325,8 +1325,8 @@ def broadcast_like(
[ 2.2408932 , 1.86755799, -0.97727788],
[ nan, nan, nan]])
Coordinates:
* x (x) object 'a' 'b' 'c'
* y (y) object 'a' 'b' 'c'
* x (x) <U1 'a' 'b' 'c'
* y (y) <U1 'a' 'b' 'c'
"""
if exclude is None:
exclude = set()
6 changes: 3 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
@@ -2565,7 +2565,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 nan 12.06 nan
pressure (station) float64 211.8 nan 218.8 nan
@@ -2576,7 +2576,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 0.0 12.06 0.0
pressure (station) float64 211.8 0.0 218.8 0.0
@@ -2589,7 +2589,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 0.0 12.06 0.0
pressure (station) float64 211.8 100.0 218.8 100.0
4 changes: 3 additions & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
@@ -930,9 +930,11 @@ def dataset_update_method(
if coord_names:
other[key] = value.drop_vars(coord_names)

# use ds.coords and not ds.indexes, else str coords are cast to object
indexes = {key: dataset.coords[key] for key in dataset.indexes.keys()}
return merge_core(
[dataset, other],
priority_arg=1,
indexes=dataset.indexes,
indexes=indexes,
combine_attrs="override",
)
19 changes: 19 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
@@ -31,6 +31,8 @@
import numpy as np
import pandas as pd

from . import dtypes

K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")
@@ -76,6 +78,23 @@ def maybe_cast_to_coords_dtype(label, coords_dtype):
return label


def maybe_coerce_to_str(index, original_coords):
"""maybe coerce a pandas Index back to a nunpy array of type str

pd.Index uses object-dtype to store str - try to avoid this for coords
"""

try:
result_type = dtypes.result_type(*original_coords)
except TypeError:
pass
else:
if result_type.kind in "SU":
index = np.asarray(index, dtype=result_type.type)

return index


def safe_cast_to_index(array: Any) -> pd.Index:
"""Given an array, safely cast it to a pandas.Index.

4 changes: 4 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@
ensure_us_time_resolution,
infix_dims,
is_duck_array,
maybe_coerce_to_str,
)

NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
@@ -2523,6 +2524,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False):
indices = nputils.inverse_permutation(np.concatenate(positions))
data = data.take(indices)

# keep as str if possible as pandas.Index uses object (converts to numpy array)
data = maybe_coerce_to_str(data, variables)

attrs = dict(first_var.attrs)
if not shortcut:
for var in variables:
44 changes: 44 additions & 0 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
@@ -376,6 +376,30 @@ def test_concat_fill_value(self, fill_value):
actual = concat(datasets, dim="t", fill_value=fill_value)
assert_identical(actual, expected)

@pytest.mark.parametrize("dtype", [str, bytes])
@pytest.mark.parametrize("dim", ["x1", "x2"])
def test_concat_str_dtype(self, dtype, dim):

data = np.arange(4).reshape([2, 2])

da1 = Dataset(
{
"data": (["x1", "x2"], data),
"x1": [0, 1],
"x2": np.array(["a", "b"], dtype=dtype),
}
)
da2 = Dataset(
{
"data": (["x1", "x2"], data),
"x1": np.array([1, 2]),
"x2": np.array(["c", "d"], dtype=dtype),
}
)
actual = concat([da1, da2], dim=dim)

assert np.issubdtype(actual.x2.dtype, dtype)


class TestConcatDataArray:
def test_concat(self):
@@ -525,6 +549,26 @@ def test_concat_combine_attrs_kwarg(self):
actual = concat([da1, da2], dim="x", combine_attrs=combine_attrs)
assert_identical(actual, expected[combine_attrs])

@pytest.mark.parametrize("dtype", [str, bytes])
@pytest.mark.parametrize("dim", ["x1", "x2"])
def test_concat_str_dtype(self, dtype, dim):

data = np.arange(4).reshape([2, 2])

da1 = DataArray(
data=data,
dims=["x1", "x2"],
coords={"x1": [0, 1], "x2": np.array(["a", "b"], dtype=dtype)},
)
da2 = DataArray(
data=data,
dims=["x1", "x2"],
coords={"x1": np.array([1, 2]), "x2": np.array(["c", "d"], dtype=dtype)},
)
actual = concat([da1, da2], dim=dim)

assert np.issubdtype(actual.x2.dtype, dtype)


@pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {}))
@pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {}))
33 changes: 33 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
@@ -1568,6 +1568,19 @@ def test_reindex_fill_value(self, fill_value):
)
assert_identical(expected, actual)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_reindex_str_dtype(self, dtype):

data = DataArray(
[1, 2], dims="x", coords={"x": np.array(["a", "b"], dtype=dtype)}
)

actual = data.reindex(x=data.x)
expected = data

assert_identical(expected, actual)
assert actual.dtype == expected.dtype

def test_rename(self):
renamed = self.dv.rename("bar")
assert_identical(renamed.to_dataset(), self.ds.rename({"foo": "bar"}))
@@ -3435,6 +3448,26 @@ def test_align_without_indexes_errors(self):
DataArray([1, 2], coords=[("x", [0, 1])]),
)

def test_align_str_dtype(self):

a = DataArray([0, 1], dims=["x"], coords={"x": ["a", "b"]})
b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]})

expected_a = DataArray(
[0, 1, np.NaN], dims=["x"], coords={"x": ["a", "b", "c"]}
)
expected_b = DataArray(
[np.NaN, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]}
)

actual_a, actual_b = xr.align(a, b, join="outer")

assert_identical(expected_a, actual_a)
assert expected_a.x.dtype == actual_a.x.dtype

assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast_arrays(self):
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")
34 changes: 34 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1949,6 +1949,16 @@ def test_reindex_like_fill_value(self, fill_value):
)
assert_identical(expected, actual)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_reindex_str_dtype(self, dtype):
data = Dataset({"data": ("x", [1, 2]), "x": np.array(["a", "b"], dtype=dtype)})

actual = data.reindex(x=data.x)
expected = data

assert_identical(expected, actual)
assert actual.x.dtype == expected.x.dtype

@pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": 2, "bar": 1}])
def test_align_fill_value(self, fill_value):
x = Dataset({"foo": DataArray([1, 2], dims=["x"], coords={"x": [1, 2]})})
@@ -2134,6 +2144,22 @@ def test_align_non_unique(self):
with raises_regex(ValueError, "cannot reindex or align"):
align(x, y)

def test_align_str_dtype(self):

a = Dataset({"foo": ("x", [0, 1]), "x": ["a", "b"]})
b = Dataset({"foo": ("x", [1, 2]), "x": ["b", "c"]})

expected_a = Dataset({"foo": ("x", [0, 1, np.NaN]), "x": ["a", "b", "c"]})
expected_b = Dataset({"foo": ("x", [np.NaN, 1, 2]), "x": ["a", "b", "c"]})

actual_a, actual_b = xr.align(a, b, join="outer")

assert_identical(expected_a, actual_a)
assert expected_a.x.dtype == actual_a.x.dtype

assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast(self):
ds = Dataset(
{"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])}
@@ -3420,6 +3446,14 @@ def test_setitem_align_new_indexes(self):
)
assert_identical(ds, expected)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_setitem_str_dtype(self, dtype):

ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)})
ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"])

assert np.issubdtype(ds.x.dtype, dtype)

def test_assign(self):
ds = Dataset()
actual = ds.assign(x=[0, 1, 2], y=2)
27 changes: 27 additions & 0 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -39,6 +39,33 @@ def test_safe_cast_to_index():
assert expected.dtype == actual.dtype


@pytest.mark.parametrize(
"a, b, expected", [["a", "b", np.array(["a", "b"])], [1, 2, pd.Index([1, 2])]]
)
def test_maybe_coerce_to_str(a, b, expected):

a = np.array([a])
b = np.array([b])
index = pd.Index(a).append(pd.Index(b))

actual = utils.maybe_coerce_to_str(index, [a, b])

assert_array_equal(expected, actual)
assert expected.dtype == actual.dtype


def test_maybe_coerce_to_str_minimal_str_dtype():

a = np.array(["a", "a_long_string"])
index = pd.Index(["a"])

actual = utils.maybe_coerce_to_str(index, [a])
expected = np.array("a")

assert_array_equal(expected, actual)
assert expected.dtype == actual.dtype


@requires_cftime
def test_safe_cast_to_index_cftimeindex():
date_types = _all_cftime_date_types()
11 changes: 11 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
@@ -2094,6 +2094,17 @@ def test_concat_multiindex(self):
assert_identical(actual, expected)
assert isinstance(actual.to_index(), pd.MultiIndex)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_concat_str_dtype(self, dtype):

a = IndexVariable("x", np.array(["a"], dtype=dtype))
b = IndexVariable("x", np.array(["b"], dtype=dtype))
expected = IndexVariable("x", np.array(["a", "b"], dtype=dtype))

actual = IndexVariable.concat([a, b])
assert actual.identical(expected)
assert np.issubdtype(actual.dtype, dtype)

def test_coordinate_alias(self):
with pytest.warns(Warning, match="deprecated"):
x = Coordinate("x", [1, 2, 3])