diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py
index f438f75707265..43331639b3744 100644
--- a/pandas/core/arrays/string_arrow.py
+++ b/pandas/core/arrays/string_arrow.py
@@ -417,7 +417,7 @@ def _str_isupper(self):
 
     def _str_len(self):
         result = pc.utf8_length(self._pa_array)
-        return Int64Dtype().__from_arrow__(result)
+        return self._convert_int_dtype(result)
 
     def _str_lower(self):
         return type(self)(pc.utf8_lower(self._pa_array))
@@ -446,6 +446,29 @@ def _str_rstrip(self, to_strip=None):
         result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
         return type(self)(result)
 
+    def _str_count(self, pat: str, flags: int = 0):
+        if flags:
+            return super()._str_count(pat, flags)
+        result = pc.count_substring_regex(self._pa_array, pat)
+        return self._convert_int_dtype(result)
+
+    def _str_find(self, sub: str, start: int = 0, end: int | None = None):
+        if start != 0 and end is not None:
+            slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
+            result = pc.find_substring(slices, sub)
+            not_found = pc.equal(result, -1)
+            offset_result = pc.add(result, end - start)
+            result = pc.if_else(not_found, result, offset_result)
+        elif start == 0 and end is None:
+            slices = self._pa_array
+            result = pc.find_substring(slices, sub)
+        else:
+            return super()._str_find(sub, start, end)
+        return self._convert_int_dtype(result)
+
+    def _convert_int_dtype(self, result):
+        return Int64Dtype().__from_arrow__(result)
+
 
 class ArrowStringArrayNumpySemantics(ArrowStringArray):
     _storage = "pyarrow_numpy"
@@ -517,34 +540,11 @@ def _str_map(
             return lib.map_infer_mask(arr, f, mask.view("uint8"))
 
     def _convert_int_dtype(self, result):
+        result = result.to_numpy()
         if result.dtype == np.int32:
             result = result.astype(np.int64)
         return result
 
-    def _str_count(self, pat: str, flags: int = 0):
-        if flags:
-            return super()._str_count(pat, flags)
-        result = pc.count_substring_regex(self._pa_array, pat).to_numpy()
-        return self._convert_int_dtype(result)
-
-    def _str_len(self):
-        result = pc.utf8_length(self._pa_array).to_numpy()
-        return self._convert_int_dtype(result)
-
-    def _str_find(self, sub: str, start: int = 0, end: int | None = None):
-        if start != 0 and end is not None:
-            slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
-            result = pc.find_substring(slices, sub)
-            not_found = pc.equal(result, -1)
-            offset_result = pc.add(result, end - start)
-            result = pc.if_else(not_found, result, offset_result)
-        elif start == 0 and end is None:
-            slices = self._pa_array
-            result = pc.find_substring(slices, sub)
-        else:
-            return super()._str_find(sub, start, end)
-        return self._convert_int_dtype(result.to_numpy())
-
     def _cmp_method(self, other, op):
         result = super()._cmp_method(other, op)
         return result.to_numpy(np.bool_, na_value=False)