Skip to content

Commit a9e30c5

Browse files
String dtype: map builtin str alias to StringDtype (#59685)
* String dtype: map builtin str alias to StringDtype
* fix tests
* fix datetimelike astype and more tests
* remove xfails
* try fix typing
* fix copy_view tests
* fix remaining tests with infer_string enabled
* ignore typing issue for now
* move to common.py
* simplify Categorical._str_get_dummies
* small cleanup
* fix ensure_string_array to not modify extension arrays inplace
* fix ensure_string_array once more + fix is_extension_array_dtype for str
* still xfail TestArrowArray::test_astype_str when not using infer_string
* ensure maybe_convert_objects copies object dtype input array when inferring StringDtype
* update test_1d_object_array_does_not_copy test
* update constructor copy test + do not copy in maybe_convert_objects?
* skip str.get_dummies test for now
* use pandas_dtype() instead of registry.find
* fix corner cases for calling pandas_dtype
* add TODO comment in ensure_string_array
1 parent 5b6997c commit a9e30c5

File tree

32 files changed

+185
-111
lines changed

32 files changed

+185
-111
lines changed

pandas/_libs/lib.pyx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,14 @@ cpdef ndarray[object] ensure_string_array(
754754

755755
if hasattr(arr, "to_numpy"):
756756

757-
if hasattr(arr, "dtype") and arr.dtype.kind in "mM":
757+
if (
758+
hasattr(arr, "dtype")
759+
and arr.dtype.kind in "mM"
760+
# TODO: we should add a custom ArrowExtensionArray.astype implementation
761+
# that handles astype(str) specifically, avoiding ending up here and
762+
# then we can remove the below check for `_pa_array` (for ArrowEA)
763+
and not hasattr(arr, "_pa_array")
764+
):
758765
# dtype check to exclude DataFrame
759766
# GH#41409 TODO: not a great place for this
760767
out = arr.astype(str).astype(object)

pandas/_testing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108

109109
COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
110110
if using_string_dtype():
111-
STRING_DTYPES: list[Dtype] = [str, "U"]
111+
STRING_DTYPES: list[Dtype] = ["U"]
112112
else:
113113
STRING_DTYPES: list[Dtype] = [str, "str", "U"] # type: ignore[no-redef]
114114
COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]

pandas/core/arrays/categorical.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2685,7 +2685,9 @@ def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
26852685
# sep may not be in categories. Just bail on this.
26862686
from pandas.core.arrays import NumpyExtensionArray
26872687

2688-
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep, dtype)
2688+
return NumpyExtensionArray(self.to_numpy(str, na_value="NaN"))._str_get_dummies(
2689+
sep, dtype
2690+
)
26892691

26902692
# ------------------------------------------------------------------------
26912693
# GroupBy Methods

pandas/core/arrays/datetimelike.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,10 +471,16 @@ def astype(self, dtype, copy: bool = True):
471471

472472
return self._box_values(self.asi8.ravel()).reshape(self.shape)
473473

474+
elif is_string_dtype(dtype):
475+
if isinstance(dtype, ExtensionDtype):
476+
arr_object = self._format_native_types(na_rep=dtype.na_value) # type: ignore[arg-type]
477+
cls = dtype.construct_array_type()
478+
return cls._from_sequence(arr_object, dtype=dtype, copy=False)
479+
else:
480+
return self._format_native_types()
481+
474482
elif isinstance(dtype, ExtensionDtype):
475483
return super().astype(dtype, copy=copy)
476-
elif is_string_dtype(dtype):
477-
return self._format_native_types()
478484
elif dtype.kind in "iu":
479485
# we deliberately ignore int32 vs. int64 here.
480486
# See https://github.com/pandas-dev/pandas/issues/24381 for more.

pandas/core/dtypes/common.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import numpy as np
1414

15+
from pandas._config import using_string_dtype
16+
1517
from pandas._libs import (
1618
Interval,
1719
Period,
@@ -1470,7 +1472,15 @@ def is_extension_array_dtype(arr_or_dtype) -> bool:
14701472
elif isinstance(dtype, np.dtype):
14711473
return False
14721474
else:
1473-
return registry.find(dtype) is not None
1475+
try:
1476+
with warnings.catch_warnings():
1477+
# pandas_dtype(..) can raise UserWarning for class input
1478+
warnings.simplefilter("ignore", UserWarning)
1479+
dtype = pandas_dtype(dtype)
1480+
except (TypeError, ValueError):
1481+
# np.dtype(..) can raise ValueError
1482+
return False
1483+
return isinstance(dtype, ExtensionDtype)
14741484

14751485

14761486
def is_ea_or_datetimelike_dtype(dtype: DtypeObj | None) -> bool:
@@ -1773,6 +1783,12 @@ def pandas_dtype(dtype) -> DtypeObj:
17731783
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
17741784
return dtype
17751785

1786+
# builtin aliases
1787+
if dtype is str and using_string_dtype():
1788+
from pandas.core.arrays.string_ import StringDtype
1789+
1790+
return StringDtype(na_value=np.nan)
1791+
17761792
# registered extension types
17771793
result = registry.find(dtype)
17781794
if result is not None:

pandas/core/indexes/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6262,7 +6262,11 @@ def _should_compare(self, other: Index) -> bool:
62626262
return False
62636263

62646264
dtype = _unpack_nested_dtype(other)
6265-
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)
6265+
return (
6266+
self._is_comparable_dtype(dtype)
6267+
or is_object_dtype(dtype)
6268+
or is_string_dtype(dtype)
6269+
)
62666270

62676271
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
62686272
"""

pandas/core/indexes/interval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
is_number,
5252
is_object_dtype,
5353
is_scalar,
54+
is_string_dtype,
5455
pandas_dtype,
5556
)
5657
from pandas.core.dtypes.dtypes import (
@@ -712,7 +713,7 @@ def _get_indexer(
712713
# left/right get_indexer, compare elementwise, equality -> match
713714
indexer = self._get_indexer_unique_sides(target)
714715

715-
elif not is_object_dtype(target.dtype):
716+
elif not (is_object_dtype(target.dtype) or is_string_dtype(target.dtype)):
716717
# homogeneous scalar index: use IntervalTree
717718
# we should always have self._should_partial_index(target) here
718719
target = self._maybe_convert_i8(target)

pandas/tests/arrays/floating/test_astype.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,9 @@ def test_astype_str(using_infer_string):
6868

6969
if using_infer_string:
7070
expected = pd.array(["0.1", "0.2", None], dtype=pd.StringDtype(na_value=np.nan))
71-
tm.assert_extension_array_equal(a.astype("str"), expected)
7271

73-
# TODO(infer_string) this should also be a string array like above
74-
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
75-
tm.assert_numpy_array_equal(a.astype(str), expected)
72+
tm.assert_extension_array_equal(a.astype(str), expected)
73+
tm.assert_extension_array_equal(a.astype("str"), expected)
7674
else:
7775
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
7876

pandas/tests/arrays/integer/test_dtypes.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,9 @@ def test_astype_str(using_infer_string):
281281

282282
if using_infer_string:
283283
expected = pd.array(["1", "2", None], dtype=pd.StringDtype(na_value=np.nan))
284-
tm.assert_extension_array_equal(a.astype("str"), expected)
285284

286-
# TODO(infer_string) this should also be a string array like above
287-
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
288-
tm.assert_numpy_array_equal(a.astype(str), expected)
285+
tm.assert_extension_array_equal(a.astype(str), expected)
286+
tm.assert_extension_array_equal(a.astype("str"), expected)
289287
else:
290288
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
291289

pandas/tests/arrays/sparse/test_astype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def test_astype_all(self, any_real_numpy_dtype):
8181
),
8282
(
8383
SparseArray([0, 1, 10]),
84-
str,
85-
SparseArray(["0", "1", "10"], dtype=SparseDtype(str, "0")),
84+
np.str_,
85+
SparseArray(["0", "1", "10"], dtype=SparseDtype(np.str_, "0")),
8686
),
8787
(SparseArray(["10", "20"]), float, SparseArray([10.0, 20.0])),
8888
(

0 commit comments

Comments
 (0)