Skip to content

TST: simplify pyarrow tests, make mode work with temporal dtypes #50688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 47 additions & 18 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,12 +657,11 @@ def factorize(
pa_type = self._data.type
if pa.types.is_duration(pa_type):
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]"))
indices, uniques = arr.factorize(use_na_sentinel=use_na_sentinel)
uniques = uniques.astype(self.dtype)
return indices, uniques
data = self._data.cast(pa.int64())
else:
data = self._data

encoded = self._data.dictionary_encode(null_encoding=null_encoding)
encoded = data.dictionary_encode(null_encoding=null_encoding)
if encoded.length() == 0:
indices = np.array([], dtype=np.intp)
uniques = type(self)(pa.chunked_array([], type=encoded.type.value_type))
Expand All @@ -674,6 +673,9 @@ def factorize(
np.intp, copy=False
)
uniques = type(self)(encoded.chunk(0).dictionary)

if pa.types.is_duration(pa_type):
uniques = cast(ArrowExtensionArray, uniques.astype(self.dtype))
return indices, uniques

def reshape(self, *args, **kwargs):
Expand Down Expand Up @@ -858,13 +860,20 @@ def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
-------
ArrowExtensionArray
"""
if pa.types.is_duration(self._data.type):
pa_type = self._data.type

if pa.types.is_duration(pa_type):
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
arr = cast(ArrowExtensionArrayT, self.astype("int64[pyarrow]"))
result = arr.unique()
return cast(ArrowExtensionArrayT, result.astype(self.dtype))
data = self._data.cast(pa.int64())
else:
data = self._data

pa_result = pc.unique(data)

return type(self)(pc.unique(self._data))
if pa.types.is_duration(pa_type):
pa_result = pa_result.cast(pa_type)

return type(self)(pa_result)

def value_counts(self, dropna: bool = True) -> Series:
"""
Expand All @@ -883,27 +892,30 @@ def value_counts(self, dropna: bool = True) -> Series:
--------
Series.value_counts
"""
if pa.types.is_duration(self._data.type):
pa_type = self._data.type
if pa.types.is_duration(pa_type):
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
arr = cast(ArrowExtensionArray, self.astype("int64[pyarrow]"))
result = arr.value_counts()
result.index = result.index.astype(self.dtype)
return result
data = self._data.cast(pa.int64())
else:
data = self._data

from pandas import (
Index,
Series,
)

vc = self._data.value_counts()
vc = data.value_counts()

values = vc.field(0)
counts = vc.field(1)
if dropna and self._data.null_count > 0:
if dropna and data.null_count > 0:
mask = values.is_valid()
values = values.filter(mask)
counts = counts.filter(mask)

if pa.types.is_duration(pa_type):
values = values.cast(pa_type)

# No missing values so we can adhere to the interface and return a numpy array.
counts = np.array(counts)

Expand Down Expand Up @@ -1231,12 +1243,29 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra
"""
if pa_version_under6p0:
raise NotImplementedError("mode only supported for pyarrow version >= 6.0")
modes = pc.mode(self._data, pc.count_distinct(self._data).as_py())

pa_type = self._data.type
if pa.types.is_temporal(pa_type):
nbits = pa_type.bit_width
if nbits == 32:
data = self._data.cast(pa.int32())
elif nbits == 64:
data = self._data.cast(pa.int64())
else:
raise NotImplementedError(pa_type)
else:
data = self._data

modes = pc.mode(data, pc.count_distinct(data).as_py())
values = modes.field(0)
counts = modes.field(1)
# counts are sorted in descending order, i.e. counts[0] is the maximum
mask = pc.equal(counts, counts[0])
most_common = values.filter(mask)

if pa.types.is_temporal(pa_type):
most_common = most_common.cast(pa_type)

return type(self)(most_common)

def _maybe_convert_setitem_value(self, value):
Expand Down
Loading