-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
WIP: EA SetArray #22382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
+657
−5
Closed
WIP: EA SetArray #22382
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,443 @@ | ||
import sys | ||
import warnings | ||
import copy | ||
import numpy as np | ||
|
||
import operator | ||
|
||
from pandas import Series | ||
|
||
# from pandas._libs.lib import infer_dtype | ||
from pandas.util._decorators import cache_readonly | ||
from pandas.compat import u, range | ||
from pandas.compat import set_function_name | ||
|
||
from pandas.core.dtypes.common import ( | ||
is_integer, is_scalar, is_object_dtype, is_list_like) | ||
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin | ||
from pandas.core.dtypes.base import ExtensionDtype | ||
from pandas.core.dtypes.dtypes import registry | ||
from pandas.core.dtypes.missing import isna | ||
|
||
from pandas.io.formats.printing import ( | ||
format_object_summary, format_object_attrs, default_pprint) | ||
|
||
|
||
class SetDtype(ExtensionDtype): | ||
""" | ||
An ExtensionDtype to hold sets. | ||
""" | ||
name = 'set' | ||
type = object | ||
na_value = np.nan | ||
|
||
def __hash__(self): | ||
# XXX: this needs to be part of the interface. | ||
return hash(str(self)) | ||
|
||
def __eq__(self, other): | ||
# TODO: test | ||
if isinstance(other, type(self)): | ||
return True | ||
else: | ||
return super(SetDtype, self).__eq__(other) | ||
|
||
@property | ||
def _is_numeric(self): | ||
return False | ||
|
||
def __repr__(self): | ||
return self.name | ||
|
||
@classmethod | ||
def construct_array_type(cls): | ||
"""Return the array type associated with this dtype | ||
Returns | ||
------- | ||
type | ||
""" | ||
return SetArray | ||
|
||
@classmethod | ||
def construct_from_string(cls, string): | ||
""" | ||
Construction from a string, raise a TypeError if not | ||
possible | ||
""" | ||
if string == cls.name or string is set: | ||
return cls() | ||
raise TypeError("Cannot construct a '{}' from " | ||
"'{}'".format(cls, string)) | ||
|
||
# @classmethod | ||
# def is_dtype(cls, dtype): | ||
# dtype = getattr(dtype, 'dtype', dtype) | ||
# if (isinstance(dtype, compat.string_types) and | ||
# dtype == 'set'): | ||
# return True | ||
# elif isinstance(dtype, cls): | ||
# return True | ||
# return isinstance(dtype, np.dtype) or dtype == 'set' | ||
|
||
|
||
def to_set_array(values): | ||
""" | ||
Infer and return a set array of the values. | ||
Parameters | ||
---------- | ||
values : 1D list-like of list-likes | ||
Returns | ||
------- | ||
SetArray | ||
Raises | ||
------ | ||
TypeError if incompatible types | ||
""" | ||
return SetArray(values, copy=False) | ||
|
||
|
||
def coerce_to_array(values, mask=None, copy=False): | ||
""" | ||
Coerce the input values array to numpy arrays with a mask | ||
Parameters | ||
---------- | ||
values : 1D list-like | ||
mask : boolean 1D array, optional | ||
copy : boolean, default False | ||
if True, copy the input | ||
Returns | ||
------- | ||
tuple of (values, mask) | ||
""" | ||
|
||
if isinstance(values, SetArray): | ||
values, mask = values._data, values._mask | ||
|
||
if copy: | ||
values = values.copy() | ||
mask = mask.copy() | ||
return values, mask | ||
|
||
values = np.array(values, copy=copy) | ||
if not (is_object_dtype(values) or isna(values).all()): | ||
raise TypeError("{} cannot be converted to a SetDtype".format( | ||
values.dtype)) | ||
|
||
if mask is None: | ||
mask = isna(values) | ||
else: | ||
assert len(mask) == len(values) | ||
|
||
if not values.ndim == 1: | ||
raise TypeError("values must be a 1D list-like") | ||
if not mask.ndim == 1: | ||
raise TypeError("mask must be a 1D list-like") | ||
|
||
if mask.any(): | ||
values = values.copy() | ||
values[mask] = np.nan | ||
|
||
return values, mask | ||
|
||
|
||
class SetArray(ExtensionArray, ExtensionOpsMixin): | ||
""" | ||
We represent a SetArray with 2 numpy arrays | ||
- data: contains a numpy set array of object dtype | ||
- mask: a boolean array holding a mask on the data, False is missing | ||
""" | ||
|
||
@cache_readonly | ||
def dtype(self): | ||
return SetDtype() | ||
|
||
def __init__(self, values, mask=None, dtype=None, copy=False): | ||
""" | ||
Parameters | ||
---------- | ||
values : 1D list-like / SetArray | ||
mask : 1D list-like, optional | ||
copy : bool, default False | ||
Returns | ||
------- | ||
SetArray | ||
""" | ||
self._data, self._mask = coerce_to_array( | ||
values, mask=mask, copy=copy) | ||
|
||
@classmethod | ||
def _from_sequence(cls, scalars, dtype=None, copy=False): | ||
# dtype is ignored | ||
return cls(scalars, copy=copy) | ||
|
||
@classmethod | ||
def _from_factorized(cls, values, original): | ||
return cls(values) | ||
|
||
def __getitem__(self, item): | ||
if is_integer(item): | ||
if self._mask[item]: | ||
return self.dtype.na_value | ||
return self._data[item] | ||
return type(self)(self._data[item], mask=self._mask[item]) | ||
|
||
def _coerce_to_ndarray(self): | ||
""" | ||
coerce to an ndarray of object dtype | ||
""" | ||
data = self._data | ||
data[self._mask] = self._na_value | ||
return data | ||
|
||
def __array__(self): | ||
""" | ||
the array interface, return values | ||
""" | ||
return self._coerce_to_ndarray() | ||
|
||
def __iter__(self): | ||
"""Iterate over elements of the array. | ||
""" | ||
# This needs to be implemented so that pandas recognizes extension | ||
# arrays as list-like. The default implementation makes successive | ||
# calls to ``__getitem__``, which may be slower than necessary. | ||
for i in range(len(self)): | ||
if self._mask[i]: | ||
yield self.dtype.na_value | ||
else: | ||
yield self._data[i] | ||
|
||
def _formatting_values(self): | ||
# type: () -> np.ndarray | ||
return self._coerce_to_ndarray() | ||
|
||
def take(self, indices, allow_fill=False, fill_value=None): | ||
from pandas.core.algorithms import take | ||
|
||
if allow_fill and fill_value is None: | ||
fill_value = self.dtype.na_value | ||
|
||
result = take(self._data, indices, fill_value=fill_value, | ||
allow_fill=allow_fill) | ||
return self._from_sequence(result) | ||
|
||
def copy(self, deep=False): | ||
data, mask = self._data, self._mask | ||
if deep: | ||
data = copy.deepcopy(data) | ||
mask = copy.deepcopy(mask) | ||
else: | ||
data = data.copy() | ||
mask = mask.copy() | ||
return type(self)(data, mask, copy=False) | ||
|
||
def __setitem__(self, key, value): | ||
_is_scalar = is_scalar(value) | ||
if _is_scalar: | ||
value = [value] | ||
value, mask = coerce_to_array(value) | ||
|
||
if _is_scalar: | ||
value = value[0] | ||
mask = mask[0] | ||
|
||
self._data[key] = value | ||
self._mask[key] = mask | ||
|
||
def __len__(self): | ||
return len(self._data) | ||
|
||
def __repr__(self): | ||
""" | ||
Return a string representation for this object. | ||
Invoked by unicode(df) in py2 only. Yields a Unicode String in both | ||
py2/py3. | ||
""" | ||
klass = self.__class__.__name__ | ||
data = format_object_summary(self, default_pprint, False) | ||
attrs = format_object_attrs(self) | ||
space = " " | ||
|
||
prepr = (u(",%s") % | ||
space).join(u("%s=%s") % (k, v) for k, v in attrs) | ||
|
||
res = u("%s(%s%s)") % (klass, data, prepr) | ||
|
||
return res | ||
|
||
@property | ||
def nbytes(self): | ||
return self._data.nbytes + self._mask.nbytes | ||
|
||
def isna(self): | ||
return self._mask | ||
|
||
@property | ||
def _na_value(self): | ||
return np.nan | ||
|
||
@classmethod | ||
def _concat_same_type(cls, to_concat): | ||
data = np.concatenate([x._data for x in to_concat]) | ||
mask = np.concatenate([x._mask for x in to_concat]) | ||
return cls(data, mask=mask) | ||
|
||
def astype(self, dtype, copy=True, errors='raise', fill_value=None): | ||
"""Cast to a NumPy array or SetArray with 'dtype'. | ||
Parameters | ||
---------- | ||
dtype : str or dtype | ||
Typecode or data-type to which the array is cast. | ||
copy : bool, default True | ||
Whether to copy the data, even if not necessary. If False, | ||
a copy is made only if the old dtype does not match the | ||
new dtype. | ||
Returns | ||
------- | ||
array : ndarray or SetArray | ||
NumPy ndarray or SetArray with 'dtype' for its dtype. | ||
Raises | ||
------ | ||
TypeError | ||
if incompatible type with a SetDtype, equivalent of same_kind | ||
casting | ||
""" | ||
|
||
# if we are astyping to an existing IntegerDtype we can fastpath | ||
if isinstance(dtype, SetDtype): | ||
result = self._data.astype(dtype.type, | ||
casting='same_kind', copy=False) | ||
return type(self)(result, mask=self._mask, copy=False) | ||
|
||
# coerce | ||
data = self._coerce_to_ndarray() | ||
return data.astype(dtype, copy=False) | ||
|
||
@property | ||
def _ndarray_values(self): | ||
# type: () -> np.ndarray | ||
"""Internal pandas method for lossy conversion to a NumPy ndarray. | ||
This method is not part of the pandas interface. | ||
The expectation is that this is cheap to compute, and is primarily | ||
used for interacting with our indexers. | ||
""" | ||
return self._data | ||
|
||
def fillna(self, value=None, method=None, limit=None): | ||
# TODO: method/limit | ||
res = self._data.copy() | ||
res[self._mask] = [value] * self._mask.sum() | ||
return type(self)(res, | ||
mask=np.full_like(res, fill_value=False, dtype=bool), | ||
copy=False) | ||
|
||
def dropna(self): | ||
res = self._data[~self._mask] | ||
return type(self)(res, | ||
mask=np.full_like(res, fill_value=False, dtype=bool), | ||
copy=False) | ||
|
||
def unique(self): | ||
raise NotImplementedError | ||
|
||
def factorize(self): | ||
raise NotImplementedError | ||
|
||
def argsort(self): | ||
raise NotImplementedError | ||
|
||
def value_counts(self, dropna=True): | ||
raise NotImplementedError | ||
|
||
def _values_for_argsort(self): | ||
raise NotImplementedError | ||
|
||
@classmethod | ||
def _create_comparison_method(cls, op): | ||
def cmp_method(self, other): | ||
|
||
op_name = op.__name__ | ||
mask = None | ||
if isinstance(other, SetArray): | ||
other, mask = other._data, other._mask | ||
elif (isinstance(other, Series) | ||
and isinstance(other.values, SetArray)): | ||
other, mask = other.values._data, other.values._mask | ||
elif isinstance(other, set) or (is_scalar(other) and isna(other)): | ||
other = np.array([other] * len(self)) | ||
elif is_list_like(other): | ||
other = np.asarray(other) | ||
if other.ndim > 0 and len(self) != len(other): | ||
raise ValueError('Lengths must match to compare') | ||
|
||
mask = self._mask | mask if mask is not None else self._mask | ||
result = np.full_like(self._data, fill_value=np.nan, dtype='O') | ||
|
||
# numpy will show a DeprecationWarning on invalid elementwise | ||
# comparisons, this will raise in the future | ||
with warnings.catch_warnings(record=True): | ||
with np.errstate(all='ignore'): | ||
result[~mask] = op(self._data[~mask], other[~mask]) | ||
|
||
result[mask] = True if op_name == 'ne' else False | ||
return result.astype('bool') | ||
|
||
name = '__{name}__'.format(name=op.__name__) | ||
return set_function_name(cmp_method, name, cls) | ||
|
||
@classmethod | ||
def _create_arithmetic_method(cls, op): | ||
def arithmetic_method(self, other): | ||
|
||
mask = None | ||
if isinstance(other, SetArray): | ||
other, mask = other._data, other._mask | ||
elif isinstance(other, set) or (is_scalar(other) and isna(other)): | ||
other = np.array([other] * len(self)) | ||
elif is_list_like(other): | ||
other = np.asarray(other) | ||
# cannot use isnan due to numpy/numpy#9009 | ||
mask = np.array([x is np.nan for x in other]) | ||
if other.ndim > 0 and len(self) != len(other): | ||
raise ValueError('Lengths must match to compare') | ||
|
||
mask = self._mask | mask if mask is not None else self._mask | ||
result = np.full_like(self._data, fill_value=np.nan, dtype='O') | ||
|
||
with np.errstate(all='ignore'): | ||
result[~mask] = op(self._data[~mask], other[~mask]) | ||
|
||
return type(self)(result, mask=mask, copy=False) | ||
|
||
name = '__{name}__'.format(name=op.__name__) | ||
|
||
def raiser(self, other): | ||
raise NotImplementedError | ||
if name != '__sub__': | ||
return raiser | ||
return set_function_name(arithmetic_method, name, cls) | ||
|
||
|
||
SetArray._add_comparison_ops() | ||
SetArray._add_arithmetic_ops() | ||
SetArray.__or__ = SetArray._create_arithmetic_method(operator.__or__) | ||
SetArray.__xor__ = SetArray._create_arithmetic_method(operator.__xor__) | ||
SetArray.__and__ = SetArray._create_arithmetic_method(operator.__and__) | ||
|
||
module = sys.modules[__name__] | ||
setattr(module, 'SetDtype', SetDtype) | ||
registry.register(SetDtype) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
import numpy as np | ||
import pandas.util.testing as tm | ||
import pytest | ||
|
||
from pandas.tests.extension import base | ||
|
||
from pandas.core.arrays.set import SetDtype, SetArray | ||
|
||
|
||
def make_string_sets(): | ||
s = tm.makeStringSeries() | ||
return s.index.map(set).values | ||
|
||
|
||
def make_int_sets(): | ||
s = tm.makeFloatSeries().astype(str).str.replace(r'\D', '') | ||
return s.map(lambda x: set(map(int, x))).values | ||
|
||
|
||
def make_data(): | ||
return (list(make_string_sets()) + | ||
[np.nan] + | ||
list(make_int_sets()) + | ||
[np.nan, None, set()]) | ||
|
||
|
||
@pytest.fixture | ||
def dtype(): | ||
return SetDtype() | ||
|
||
|
||
@pytest.fixture | ||
def data(): | ||
return SetArray(make_int_sets()) | ||
|
||
|
||
@pytest.fixture | ||
def data_missing(): | ||
return SetArray([np.nan, {1}]) | ||
|
||
|
||
@pytest.fixture | ||
def data_repeated(data): | ||
def gen(count): | ||
for _ in range(count): | ||
yield data | ||
yield gen | ||
|
||
|
||
# @pytest.fixture | ||
# def data_for_sorting(dtype): | ||
# return SetArray(...) | ||
|
||
|
||
# @pytest.fixture | ||
# def data_missing_for_sorting(dtype): | ||
# return SetArray(...) | ||
|
||
|
||
@pytest.fixture | ||
def na_cmp(): | ||
# we are np.nan | ||
return lambda x, y: np.isnan(x) and np.isnan(y) | ||
|
||
|
||
@pytest.fixture | ||
def na_value(): | ||
return np.nan | ||
|
||
# @pytest.fixture | ||
# def data_for_grouping(dtype): | ||
# return SetArray(...) | ||
|
||
|
||
class TestDtype(base.BaseDtypeTests): | ||
|
||
def test_array_type_with_arg(self, data, dtype): | ||
assert dtype.construct_array_type() is SetArray | ||
|
||
|
||
class TestInterface(base.BaseInterfaceTests): | ||
|
||
def test_len(self, data): | ||
assert len(data) == 30 | ||
|
||
|
||
class TestConstructors(base.BaseConstructorsTests): | ||
pass | ||
|
||
|
||
class TestReshaping(base.BaseReshapingTests): | ||
pass | ||
|
||
|
||
class TestGetitem(base.BaseGetitemTests): | ||
|
||
@pytest.mark.skip(reason="Need to think about it.") | ||
def test_take_non_na_fill_value(self, data_missing): | ||
pass | ||
|
||
|
||
class TestSetitem(base.BaseGetitemTests): | ||
pass | ||
|
||
|
||
class TestMissing(base.BaseMissingTests): | ||
|
||
def test_fillna_frame(self, data_missing): | ||
pytest.skip('df.fillna does not dispatch to EA') | ||
|
||
def test_fillna_limit_pad(self): | ||
pytest.skip('TODO') | ||
|
||
def test_fillna_limit_backfill(self): | ||
pytest.skip('TODO') | ||
|
||
def test_fillna_series_method(self): | ||
pytest.skip('TODO') | ||
|
||
def test_fillna_series(self): | ||
pytest.skip('series.fillna does not dispatch to EA') | ||
|
||
|
||
# # most methods (value_counts, unique, factorize) will not be for SetArray | ||
# # rest still buggy | ||
class TestMethods(base.BaseMethodsTests): | ||
pass | ||
|
||
|
||
class TestCasting(base.BaseCastingTests): | ||
pass | ||
|
||
|
||
class TestArithmeticOps(base.BaseArithmeticOpsTests): | ||
|
||
def check_opname(self, s, op_name, other, exc='ignored'): | ||
op = self.get_op_from_name(op_name) | ||
|
||
self._check_op(s, op, other, | ||
None if op_name == '__sub__' else NotImplementedError) | ||
|
||
def test_divmod(self, data): | ||
pytest.skip('Not relevant') | ||
|
||
def test_error(self, data, all_arithmetic_operators): | ||
pytest.skip('TODO') | ||
|
||
|
||
class TestComparisonOps(base.BaseComparisonOpsTests): | ||
|
||
def _compare_other(self, s, data, op_name, other): | ||
op = self.get_op_from_name(op_name) | ||
result = op(s, other) | ||
expected = s.combine(other, op) | ||
self.assert_series_equal(result, expected) | ||
|
||
# # GroupBy won't be implemented for SetArray | ||
# class TestGroupby(base.BaseGroupbyTests): | ||
# pass |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this a repetition of the existing dispatch function?