Skip to content

Commit 28a3cb3

Browse files
authored
Merge pull request numpy#26317 from ngoldbaum/astype-stringdtype-fix
BUG: use PyArray_SafeCast in array_astype
2 parents da95f8e + 9d56834 commit 28a3cb3

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

numpy/_core/src/multiarray/methods.c

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -811,8 +811,8 @@ array_astype(PyArrayObject *self,
811811

812812
/*
813813
* If the memory layout matches and, data types are equivalent,
814-
* and it's not a subtype if subok is False, then we
815-
* can skip the copy.
814+
* it's not a subtype if subok is False, and if the cast says
815+
* view are possible, we can skip the copy.
816816
*/
817817
if (forcecopy != NPY_AS_TYPE_COPY_ALWAYS &&
818818
(order == NPY_KEEPORDER ||
@@ -823,11 +823,15 @@ array_astype(PyArrayObject *self,
823823
PyArray_IS_C_CONTIGUOUS(self)) ||
824824
(order == NPY_FORTRANORDER &&
825825
PyArray_IS_F_CONTIGUOUS(self))) &&
826-
(subok || PyArray_CheckExact(self)) &&
827-
PyArray_EquivTypes(dtype, PyArray_DESCR(self))) {
828-
Py_DECREF(dtype);
829-
Py_INCREF(self);
830-
return (PyObject *)self;
826+
(subok || PyArray_CheckExact(self))) {
827+
npy_intp view_offset;
828+
npy_intp is_safe = PyArray_SafeCast(dtype, PyArray_DESCR(self),
829+
&view_offset, NPY_NO_CASTING, 1);
830+
if (is_safe && (view_offset != NPY_MIN_INTP)) {
831+
Py_DECREF(dtype);
832+
Py_INCREF(self);
833+
return (PyObject *)self;
834+
}
831835
}
832836

833837
if (!PyArray_CanCastArrayTo(self, dtype, casting)) {

numpy/_core/tests/test_stringdtype.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,13 @@ def test_create_with_copy_none(string_list):
493493
assert arr_view is arr
494494

495495

496+
def test_astype_copy_false():
497+
orig_dt = StringDType()
498+
arr = np.array(["hello", "world"], dtype=StringDType())
499+
assert not arr.astype(StringDType(coerce=False), copy=False).dtype.coerce
500+
501+
assert arr.astype(orig_dt, copy=False).dtype is orig_dt
502+
496503
@pytest.mark.parametrize(
497504
"strings",
498505
[

0 commit comments

Comments
 (0)