diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 641eb7b01f0b6..d7faf00113609 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -657,12 +657,11 @@ def factorize( pa_type = self._data.type if pa.types.is_duration(pa_type): # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323 - arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]")) - indices, uniques = arr.factorize(use_na_sentinel=use_na_sentinel) - uniques = uniques.astype(self.dtype) - return indices, uniques + data = self._data.cast(pa.int64()) + else: + data = self._data - encoded = self._data.dictionary_encode(null_encoding=null_encoding) + encoded = data.dictionary_encode(null_encoding=null_encoding) if encoded.length() == 0: indices = np.array([], dtype=np.intp) uniques = type(self)(pa.chunked_array([], type=encoded.type.value_type)) @@ -674,6 +673,9 @@ def factorize( np.intp, copy=False ) uniques = type(self)(encoded.chunk(0).dictionary) + + if pa.types.is_duration(pa_type): + uniques = cast(ArrowExtensionArray, uniques.astype(self.dtype)) return indices, uniques def reshape(self, *args, **kwargs): @@ -858,13 +860,20 @@ def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: ------- ArrowExtensionArray """ - if pa.types.is_duration(self._data.type): + pa_type = self._data.type + + if pa.types.is_duration(pa_type): # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323 - arr = cast(ArrowExtensionArrayT, self.astype("int64[pyarrow]")) - result = arr.unique() - return cast(ArrowExtensionArrayT, result.astype(self.dtype)) + data = self._data.cast(pa.int64()) + else: + data = self._data + + pa_result = pc.unique(data) - return type(self)(pc.unique(self._data)) + if pa.types.is_duration(pa_type): + pa_result = pa_result.cast(pa_type) + + return type(self)(pa_result) def value_counts(self, dropna: bool = True) -> Series: """ @@ -883,27 +892,30 @@ def value_counts(self, dropna: bool = True) -> Series: -------- Series.value_counts """ - if pa.types.is_duration(self._data.type): + pa_type = self._data.type + if pa.types.is_duration(pa_type): # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323 - arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]")) - result = arr.value_counts() - result.index = result.index.astype(self.dtype) - return result + data = self._data.cast(pa.int64()) + else: + data = self._data from pandas import ( Index, Series, ) - vc = self._data.value_counts() + vc = data.value_counts() values = vc.field(0) counts = vc.field(1) - if dropna and self._data.null_count > 0: + if dropna and data.null_count > 0: mask = values.is_valid() values = values.filter(mask) counts = counts.filter(mask) + if pa.types.is_duration(pa_type): + values = values.cast(pa_type) + # No missing values so we can adhere to the interface and return a numpy array. counts = np.array(counts) @@ -1231,12 +1243,29 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra """ if pa_version_under6p0: raise NotImplementedError("mode only supported for pyarrow version >= 6.0") - modes = pc.mode(self._data, pc.count_distinct(self._data).as_py()) + + pa_type = self._data.type + if pa.types.is_temporal(pa_type): + nbits = pa_type.bit_width + if nbits == 32: + data = self._data.cast(pa.int32()) + elif nbits == 64: + data = self._data.cast(pa.int64()) + else: + raise NotImplementedError(pa_type) + else: + data = self._data + + modes = pc.mode(data, pc.count_distinct(data).as_py()) values = modes.field(0) counts = modes.field(1) # counts sorted descending i.e counts[0] = max mask = pc.equal(counts, counts[0]) most_common = values.filter(mask) + + if pa.types.is_temporal(pa_type): + most_common = most_common.cast(pa_type) + return type(self)(most_common) def _maybe_convert_setitem_value(self, value): diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 889df3fee15c7..d643582e98095 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -336,15 +336,39 @@ def test_from_sequence_of_strings_pa_array(self, data, request): class TestGetitemTests(base.BaseGetitemTests): - @pytest.mark.xfail( - reason=( - "data.dtype.type return pyarrow.DataType " - "but this (intentionally) returns " - "Python scalars or pd.NA" - ) - ) def test_getitem_scalar(self, data): - super().test_getitem_scalar(data) + # In the base class we expect data.dtype.type; but this (intentionally) + # returns Python scalars or pd.NA + pa_type = data._data.type + if pa.types.is_integer(pa_type): + exp_type = int + elif pa.types.is_floating(pa_type): + exp_type = float + elif pa.types.is_string(pa_type): + exp_type = str + elif pa.types.is_binary(pa_type): + exp_type = bytes + elif pa.types.is_boolean(pa_type): + exp_type = bool + elif pa.types.is_duration(pa_type): + exp_type = timedelta + elif pa.types.is_timestamp(pa_type): + if pa_type.unit == "ns": + exp_type = pd.Timestamp + else: + exp_type = datetime + elif pa.types.is_date(pa_type): + exp_type = date + elif pa.types.is_time(pa_type): + exp_type = time + else: + raise NotImplementedError(data.dtype) + + result = data[0] + assert isinstance(result, exp_type), type(result) + + result = pd.Series(data)[0] + assert isinstance(result, exp_type), type(result) class TestBaseAccumulateTests(base.BaseAccumulateTests): @@ -1011,70 +1035,83 @@ def _patch_combine(self, obj, other, op): expected = pd.Series(pd_array) return expected - def test_arith_series_with_scalar( - self, data, all_arithmetic_operators, request, monkeypatch - ): - pa_dtype = data.dtype.pyarrow_dtype - - arrow_temporal_supported = not pa_version_under8p0 and ( - all_arithmetic_operators in ("__add__", "__radd__") + def _is_temporal_supported(self, opname, pa_dtype): + return not pa_version_under8p0 and ( + opname in ("__add__", "__radd__") and pa.types.is_duration(pa_dtype) - or all_arithmetic_operators in ("__sub__", "__rsub__") + or opname in ("__sub__", "__rsub__") and pa.types.is_temporal(pa_dtype) ) - if all_arithmetic_operators == "__rmod__" and ( - pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) - ): - pytest.skip("Skip testing Python string formatting") - elif all_arithmetic_operators in { + + def _get_scalar_exception(self, opname, pa_dtype): + arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype) + if opname in { "__mod__", "__rmod__", }: - self.series_scalar_exc = NotImplementedError + exc = NotImplementedError elif arrow_temporal_supported: - self.series_scalar_exc = None - elif not ( - pa.types.is_floating(pa_dtype) - or pa.types.is_integer(pa_dtype) - or arrow_temporal_supported - ): - self.series_scalar_exc = pa.ArrowNotImplementedError + exc = None + elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): + exc = pa.ArrowNotImplementedError else: - self.series_scalar_exc = None + exc = None + return exc + + def _get_arith_xfail_marker(self, opname, pa_dtype): + mark = None + + arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype) + if ( - all_arithmetic_operators == "__rpow__" + opname == "__rpow__" and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) and not pa_version_under6p0 ): - request.node.add_marker( - pytest.mark.xfail( - reason=( - f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " - f"for {pa_dtype}" - ) + mark = pytest.mark.xfail( + reason=( + f"GH#29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " + f"for {pa_dtype}" ) ) elif arrow_temporal_supported: - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - f"{all_arithmetic_operators} not supported between" - f"pd.NA and {pa_dtype} Python scalar" - ), - ) + mark = pytest.mark.xfail( + raises=TypeError, + reason=( + f"{opname} not supported between" + f"pd.NA and {pa_dtype} Python scalar" + ), ) elif ( - all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} + opname in {"__rtruediv__", "__rfloordiv__"} and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) and not pa_version_under6p0 ): - request.node.add_marker( - pytest.mark.xfail( - raises=pa.ArrowInvalid, - reason="divide by 0", - ) + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason="divide by 0", ) + + return mark + + def test_arith_series_with_scalar( + self, data, all_arithmetic_operators, request, monkeypatch + ): + pa_dtype = data.dtype.pyarrow_dtype + + if all_arithmetic_operators == "__rmod__" and ( + pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) + ): + pytest.skip("Skip testing Python string formatting") + + self.series_scalar_exc = self._get_scalar_exception( + all_arithmetic_operators, pa_dtype + ) + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.node.add_marker(mark) + if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype): # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does # not upcast @@ -1086,61 +1123,19 @@ def test_arith_frame_with_scalar( ): pa_dtype = data.dtype.pyarrow_dtype - arrow_temporal_supported = not pa_version_under8p0 and ( - all_arithmetic_operators in ("__add__", "__radd__") - and pa.types.is_duration(pa_dtype) - or all_arithmetic_operators in ("__sub__", "__rsub__") - and pa.types.is_temporal(pa_dtype) - ) if all_arithmetic_operators == "__rmod__" and ( pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) ): pytest.skip("Skip testing Python string formatting") - elif all_arithmetic_operators in { - "__mod__", - "__rmod__", - }: - self.frame_scalar_exc = NotImplementedError - elif arrow_temporal_supported: - self.frame_scalar_exc = None - elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): - self.frame_scalar_exc = pa.ArrowNotImplementedError - else: - self.frame_scalar_exc = None - if ( - all_arithmetic_operators == "__rpow__" - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - reason=( - f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " - f"for {pa_dtype}" - ) - ) - ) - elif arrow_temporal_supported: - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - f"{all_arithmetic_operators} not supported between" - f"pd.NA and {pa_dtype} Python scalar" - ), - ) - ) - elif ( - all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - raises=pa.ArrowInvalid, - reason="divide by 0", - ) - ) + + self.frame_scalar_exc = self._get_scalar_exception( + all_arithmetic_operators, pa_dtype + ) + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.node.add_marker(mark) + if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype): # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does # not upcast @@ -1152,37 +1147,11 @@ def test_arith_series_with_array( ): pa_dtype = data.dtype.pyarrow_dtype - arrow_temporal_supported = not pa_version_under8p0 and ( - all_arithmetic_operators in ("__add__", "__radd__") - and pa.types.is_duration(pa_dtype) - or all_arithmetic_operators in ("__sub__", "__rsub__") - and pa.types.is_temporal(pa_dtype) + self.series_array_exc = self._get_scalar_exception( + all_arithmetic_operators, pa_dtype ) - if all_arithmetic_operators in { - "__mod__", - "__rmod__", - }: - self.series_array_exc = NotImplementedError - elif arrow_temporal_supported: - self.series_array_exc = None - elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): - self.series_array_exc = pa.ArrowNotImplementedError - else: - self.series_array_exc = None + if ( - all_arithmetic_operators == "__rpow__" - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - reason=( - f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " - f"for {pa_dtype}" - ) - ) - ) - elif ( all_arithmetic_operators in ( "__sub__", @@ -1200,32 +1169,17 @@ def test_arith_series_with_array( ), ) ) - elif arrow_temporal_supported: - request.node.add_marker( - pytest.mark.xfail( - raises=TypeError, - reason=( - f"{all_arithmetic_operators} not supported between" - f"pd.NA and {pa_dtype} Python scalar" - ), - ) - ) - elif ( - all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} - and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) - and not pa_version_under6p0 - ): - request.node.add_marker( - pytest.mark.xfail( - raises=pa.ArrowInvalid, - reason="divide by 0", - ) - ) + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.node.add_marker(mark) + op_name = all_arithmetic_operators ser = pd.Series(data) # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray # since ser.iloc[0] is a python scalar other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype)) + if pa.types.is_floating(pa_dtype) or ( pa.types.is_integer(pa_dtype) and all_arithmetic_operators != "__truediv__" ): @@ -1321,6 +1275,25 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): @pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]]) def test_quantile(data, interpolation, quantile, request): pa_dtype = data.dtype.pyarrow_dtype + + data = data.take([0, 0, 0]) + ser = pd.Series(data) + + if ( + pa.types.is_string(pa_dtype) + or pa.types.is_binary(pa_dtype) + or pa.types.is_boolean(pa_dtype) + ): + # For string, bytes, and bool, we don't *expect* to have quantile work + # Note this matches the non-pyarrow behavior + if pa_version_under7p0: + msg = r"Function quantile has no kernel matching input types \(.*\)" + else: + msg = r"Function 'quantile' has no kernel matching input types \(.*\)" + with pytest.raises(pa.ArrowNotImplementedError, match=msg): + ser.quantile(q=quantile, interpolation=interpolation) + return + if not (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)): request.node.add_marker( pytest.mark.xfail( @@ -1355,11 +1328,7 @@ def test_quantile(data, interpolation, quantile, request): ) def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request): pa_dtype = data_for_grouping.dtype.pyarrow_dtype - if ( - pa.types.is_temporal(pa_dtype) - or pa.types.is_string(pa_dtype) - or pa.types.is_binary(pa_dtype) - ): + if pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype): request.node.add_marker( pytest.mark.xfail( raises=pa.ArrowNotImplementedError,