From 5ecede5af8096b7e5a05aeff544c262e567ef577 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 6 Dec 2023 11:50:00 -0800 Subject: [PATCH 1/3] BUG: Series.__mul__ for pyarrow strings --- doc/source/whatsnew/v2.2.0.rst | 1 + pandas/core/arrays/arrow/array.py | 22 ++++++++++++---------- pandas/tests/extension/test_arrow.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 67b4052b386c0..c878fd2664dc4 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -576,6 +576,7 @@ Strings ^^^^^^^ - Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`) - Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`) +- Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`) - Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`) Interval diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 1609bf50a834a..3e2826eae9651 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -668,16 +668,18 @@ def _evaluate_op_method(self, other, op, arrow_funcs): pa_type = self._pa_array.type other = self._box_pa(other) - if (pa.types.is_string(pa_type) or pa.types.is_binary(pa_type)) and op in [ - operator.add, - roperator.radd, - ]: - sep = pa.scalar("", type=pa_type) - if op is operator.add: - result = pc.binary_join_element_wise(self._pa_array, other, sep) - else: - result = pc.binary_join_element_wise(other, self._pa_array, sep) - return type(self)(result) + if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type): + if op in [operator.add, roperator.radd, operator.mul, roperator.rmul]: + sep = pa.scalar("", type=pa_type) + if op is operator.add: + result = pc.binary_join_element_wise(self._pa_array, other, sep) + elif op is roperator.radd: + result = pc.binary_join_element_wise(other, self._pa_array, sep) + else: + result = pc.binary_join_element_wise( + *([self._pa_array] * other.as_py()), sep + ) + return type(self)(result) if ( isinstance(other, pa.Scalar) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 9e70a59932701..3dd63bb53a37b 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2955,6 +2955,17 @@ def test_arrowextensiondtype_dataframe_repr(): assert result == expected +def test_str_multiply(): + # GH 51970 + ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string())) + expected = pd.Series(["aaa", None], dtype=ArrowDtype(pa.string())) + result = 3 * ser + tm.assert_series_equal(result, expected) + + result = ser * 3 + tm.assert_series_equal(result, expected) + + def test_pow_missing_operand(): # GH 55512 k = pd.Series([2, None], dtype="int64[pyarrow]") From a54a97055894feca661bb9977cdaf19c8063fd8c Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 6 Dec 2023 12:29:14 -0800 Subject: [PATCH 2/3] Fix existing tests --- pandas/core/arrays/arrow/array.py | 4 ++++ pandas/tests/arrays/string_/test_string.py | 7 +----- pandas/tests/extension/test_arrow.py | 25 +++++++++++----------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 3e2826eae9651..e7a50dbba9935 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -676,6 +676,10 @@ def _evaluate_op_method(self, other, op, arrow_funcs): elif op is roperator.radd: result = pc.binary_join_element_wise(other, self._pa_array, sep) else: + if not ( + isinstance(other, pa.Scalar) and pa.types.is_integer(other.type) + ): + raise TypeError("Can only string multiply by an integer.") result = pc.binary_join_element_wise( *([self._pa_array] * other.as_py()), sep ) diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 524a6632e5544..3e11062b8384e 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -176,12 +176,7 @@ def test_add_sequence(dtype): tm.assert_extension_array_equal(result, expected) -def test_mul(dtype, request, arrow_string_storage): - if dtype.storage in arrow_string_storage: - reason = "unsupported operand type(s) for *: 'ArrowStringArray' and 'int'" - mark = pytest.mark.xfail(raises=NotImplementedError, reason=reason) - request.applymarker(mark) - +def test_mul(dtype): a = pd.array(["a", "b", None], dtype=dtype) result = a * 2 expected = pd.array(["aa", "bb", None], dtype=dtype) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 3dd63bb53a37b..8532d937c4c88 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -965,8 +965,12 @@ def _get_arith_xfail_marker(self, opname, pa_dtype): def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request): pa_dtype = data.dtype.pyarrow_dtype - if all_arithmetic_operators == "__rmod__" and (pa.types.is_binary(pa_dtype)): + if all_arithmetic_operators == "__rmod__" and pa.types.is_binary(pa_dtype): pytest.skip("Skip testing Python string formatting") + elif all_arithmetic_operators in ("__rmul__", "__mul__") and ( + pa.types.is_binary(pa_dtype) or pa.types.is_string(pa_dtype) + ): + pytest.skip("Cannot multiply strings by a string") mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) if mark is not None: @@ -1004,6 +1008,14 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators, request): ), ) ) + elif all_arithmetic_operators in ("__rmul__", "__mul__") and ( + pa.types.is_binary(pa_dtype) or pa.types.is_string(pa_dtype) + ): + request.applymarker( + pytest.mark.xfail( + raises=TypeError, reason="Can only string multiply by an integer." + ) + ) mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) if mark is not None: @@ -2955,17 +2967,6 @@ def test_arrowextensiondtype_dataframe_repr(): assert result == expected -def test_str_multiply(): - # GH 51970 - ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string())) - expected = pd.Series(["aaa", None], dtype=ArrowDtype(pa.string())) - result = 3 * ser - tm.assert_series_equal(result, expected) - - result = ser * 3 - tm.assert_series_equal(result, expected) - - def test_pow_missing_operand(): # GH 55512 k = pd.Series([2, None], dtype="int64[pyarrow]") From f04b187e1e360ce266c84dfa802faf21f85e727b Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 6 Dec 2023 13:18:59 -0800 Subject: [PATCH 3/3] Another test --- pandas/tests/extension/test_arrow.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 8532d937c4c88..3ce3cee9714e4 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -970,7 +970,11 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request) elif all_arithmetic_operators in ("__rmul__", "__mul__") and ( pa.types.is_binary(pa_dtype) or pa.types.is_string(pa_dtype) ): - pytest.skip("Cannot multiply strings by a string") + request.applymarker( + pytest.mark.xfail( + raises=TypeError, reason="Can only string multiply by an integer." + ) + ) mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) if mark is not None: @@ -985,6 +989,14 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) ): pytest.skip("Skip testing Python string formatting") + elif all_arithmetic_operators in ("__rmul__", "__mul__") and ( + pa.types.is_binary(pa_dtype) or pa.types.is_string(pa_dtype) + ): + request.applymarker( + pytest.mark.xfail( + raises=TypeError, reason="Can only string multiply by an integer." + ) + ) mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) if mark is not None: