-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
PERF: optimize algos.take for repeated calls #39692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
4512f9c
36c3ed2
6d52932
ded773a
2ee2543
f489ba5
96305c5
480d2b4
c70ac4d
9fba887
5273cd5
d3dd4e4
288c6f2
06a3901
ca30487
bf598a7
05b6b87
2284813
4861fdb
a41ee6b
b52e1ec
76371cf
2faf70b
1c19732
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
""" | ||
from __future__ import annotations | ||
|
||
import functools | ||
import operator | ||
from textwrap import dedent | ||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, cast | ||
|
@@ -1534,40 +1535,73 @@ def _take_nd_object(arr, indexer, out, axis: int, fill_value, mask_info): | |
} | ||
|
||
|
||
@functools.lru_cache(maxsize=128)
def __get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis):
    """
    Part of _get_take_nd_function below that doesn't need the mask
    and thus can be cached.

    The mask can be an array, which is not hashable and therefore cannot be
    used as an argument of a cached function; only ``ndim``, the two dtypes
    and ``axis`` take part in the lookup here.

    Parameters
    ----------
    ndim : int
        Number of dimensions of the array that is taken from. Only 1 and 2
        have lookup tables.
    arr_dtype, out_dtype
        Objects exposing a ``.name`` attribute; the pair of names is the key
        into the lookup tables.
    axis : int
        For ``ndim == 2``: ``0`` selects ``_take_2d_axis0_dict``, any other
        value selects ``_take_2d_axis1_dict``. Ignored for ``ndim == 1``.

    Returns
    -------
    callable or None
        The function registered for ``(arr_dtype.name, out_dtype.name)`` if
        there is one; otherwise the function registered for
        ``(out_dtype.name, out_dtype.name)`` passed through
        ``_convert_wrapper(func, out_dtype)``; otherwise ``None``.
        ``None`` is also returned for any ``ndim`` other than 1 or 2.
    """
    # Pick the lookup table once; both key lookups below use the same one.
    if ndim == 1:
        table = _take_1d_dict
    elif ndim == 2:
        table = _take_2d_axis0_dict if axis == 0 else _take_2d_axis1_dict
    else:
        # No table for this dimensionality: report "not found" instead of
        # reaching the checks below with an unbound local.
        return None

    func = table.get((arr_dtype.name, out_dtype.name), None)
    if func is not None:
        return func

    # No direct (arr -> out) entry: fall back to the (out -> out) entry and
    # wrap it with _convert_wrapper.
    func = table.get((out_dtype.name, out_dtype.name), None)
    if func is not None:
        return _convert_wrapper(func, out_dtype)

    return None
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is the caching not on this function? Having too many levels of indirection is a -1. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I will clarify the comment above, the […] |
||
def _get_take_nd_function(
    ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None
):
    """
    Return the function that performs the take for the given dimension,
    dtypes and axis.

    Parameters
    ----------
    ndim : int
    arr_dtype, out_dtype
        Passed on to the cached lookup.
    axis : int, default 0
    mask_info : optional
        Only used by the fallback; forwarded unchanged to ``_take_nd_object``.

    Returns
    -------
    callable
        Either the result of the cached lookup, or a fallback with signature
        ``func(arr, indexer, out, fill_value=np.nan)``.
    """
    func = None
    if ndim <= 2:
        # for this part we don't need `mask_info` -> use the cached lookup
        # (mask_info may hold an array, which cannot be part of a cache key)
        func = __get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis)

    if func is None:
        # nothing found (or ndim > 2): fall back to _take_nd_object with
        # `axis` and `mask_info` bound in the closure

        def func(arr, indexer, out, fill_value=np.nan):
            indexer = ensure_int64(indexer)
            _take_nd_object(
                arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info
            )

    return func
|
||
|
||
@functools.lru_cache(maxsize=128)
def _maybe_promote_cached(dtype, fill_value, fill_value_type):
    """
    Cached front end for ``maybe_promote(dtype, fill_value)``.

    ``fill_value_type`` is never read in the body. It is accepted only so
    that it takes part in the cache lookup, which differentiates fill values
    such as ``1`` and ``True``.
    """
    promoted = maybe_promote(dtype, fill_value)
    return promoted
|
||
|
||
def _maybe_promote(dtype, fill_value): | ||
jbrockmendel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
try: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is very strange to do. Can you simply change all of the calls of […]? Putting a try/except inside here is reversing the paradigm and not good. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That is not simply possible. We don't know in advance if the fill_value is going to be hashable or not. So that's the reason the fallback is needed. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, but the try/except needs to be in the cached method, NOT here. IOW you are now exposing 2 APIs; we need to have exactly one. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. But the try/except is because of non-hashable fill_values, which thus cannot be inside the cached method; that's the whole reason I added the try/except in the first place. I am not exposing two different APIs. These are internal helper methods, and […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we just simply ban non-hashables from […] |
||
# error: Argument 3 to "__call__" of "_lru_cache_wrapper" has incompatible type | ||
# "Type[Any]"; expected "Hashable" [arg-type] | ||
return _maybe_promote_cached( | ||
dtype, fill_value, type(fill_value) | ||
) # type: ignore[arg-type] | ||
except TypeError: | ||
# if fill_value is not hashable (required for caching) | ||
return maybe_promote(dtype, fill_value) | ||
|
||
|
||
def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None): | ||
|
@@ -1661,6 +1695,40 @@ def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None) | |
return result | ||
|
||
|
||
def _take_preprocess_indexer_and_fill_value(
    arr, indexer, axis, out, fill_value, allow_fill
):
    """
    Normalize the indexer and settle the dtype / fill value for a take.

    Returns
    -------
    tuple
        ``(indexer, dtype, fill_value, mask_info)`` where ``mask_info`` is
        ``None`` when no mask was computed, ``(None, False)`` when
        ``allow_fill`` is false, and ``(mask, needs_masking)`` otherwise.

    Raises
    ------
    TypeError
        If filling is actually required and ``out`` is given with a dtype
        different from the promoted dtype.
    """
    if indexer is None:
        # no indexer: every position along `axis`, dummy fill value
        full_range = np.arange(arr.shape[axis], dtype=np.int64)
        return full_range, arr.dtype, arr.dtype.type(), None

    indexer = ensure_int64(indexer, copy=False)

    if not allow_fill:
        return indexer, arr.dtype, arr.dtype.type(), (None, False)

    # check for promotion based on types only (do this first because
    # it's faster than computing a mask)
    dtype, fill_value = _maybe_promote(arr.dtype, fill_value)
    must_inspect_indexer = dtype != arr.dtype and (
        out is None or out.dtype != dtype
    )
    if not must_inspect_indexer:
        return indexer, dtype, fill_value, None

    # check if promotion is actually required based on indexer
    mask = indexer == -1
    needs_masking = mask.any()
    mask_info = mask, needs_masking

    if not needs_masking:
        # depromote, set fill_value to dummy (it won't be used but we don't
        # want the cython code to crash when trying to cast it to dtype)
        return indexer, arr.dtype, arr.dtype.type(), mask_info

    if out is not None and out.dtype != dtype:
        raise TypeError("Incompatible type for fill_value")

    return indexer, dtype, fill_value, mask_info
|
||
|
||
def take_nd( | ||
arr, | ||
indexer, | ||
|
@@ -1700,8 +1768,6 @@ def take_nd( | |
subarray : array-like | ||
May be the same type as the input, or cast to an ndarray. | ||
""" | ||
mask_info = None | ||
|
||
if fill_value is lib.no_default: | ||
fill_value = na_value_for_dtype(arr.dtype, compat=False) | ||
|
||
|
@@ -1712,31 +1778,9 @@ def take_nd( | |
arr = extract_array(arr) | ||
arr = np.asarray(arr) | ||
|
||
if indexer is None: | ||
indexer = np.arange(arr.shape[axis], dtype=np.int64) | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
else: | ||
indexer = ensure_int64(indexer, copy=False) | ||
if not allow_fill: | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
mask_info = None, False | ||
else: | ||
# check for promotion based on types only (do this first because | ||
# it's faster than computing a mask) | ||
dtype, fill_value = maybe_promote(arr.dtype, fill_value) | ||
if dtype != arr.dtype and (out is None or out.dtype != dtype): | ||
# check if promotion is actually required based on indexer | ||
mask = indexer == -1 | ||
needs_masking = mask.any() | ||
mask_info = mask, needs_masking | ||
if needs_masking: | ||
if out is not None and out.dtype != dtype: | ||
raise TypeError("Incompatible type for fill_value") | ||
else: | ||
# if not, then depromote, set fill_value to dummy | ||
# (it won't be used but we don't want the cython code | ||
# to crash when trying to cast it to dtype) | ||
dtype, fill_value = arr.dtype, arr.dtype.type() | ||
indexer, dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value( | ||
arr, indexer, axis, out, fill_value, allow_fill | ||
) | ||
|
||
flip_order = False | ||
if arr.ndim == 2 and arr.flags.f_contiguous: | ||
|
@@ -1776,6 +1820,43 @@ def take_nd( | |
take_1d = take_nd | ||
|
||
|
||
def take_1d_array(
    arr: np.ndarray,
    indexer: np.ndarray,
    out=None,
    fill_value=lib.no_default,
    allow_fill: bool = True,
):
    """
    Specialized version for 1D arrays. Differences compared to take_nd/take_1d:

    - Assumes input (arr, indexer) has already been converted to numpy arrays
    - Only works for 1D arrays

    Parameters
    ----------
    arr : np.ndarray
    indexer : np.ndarray
    out : optional
        If given, the result is written into it and it is returned;
        otherwise a new array is allocated.
    fill_value : default ``lib.no_default``
        When left at the default it is replaced by
        ``na_value_for_dtype(arr.dtype, compat=False)``.
    allow_fill : bool, default True

    Returns
    -------
    The filled ``out`` array, or the result of ``arr.take`` when ``arr`` is
    an ``ABCExtensionArray``.
    """
    if fill_value is lib.no_default:
        fill_value = na_value_for_dtype(arr.dtype, compat=False)

    if isinstance(arr, ABCExtensionArray):
        # Check for EA to catch DatetimeArray, TimedeltaArray
        return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)

    indexer, dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value(
        arr, indexer, 0, out, fill_value, allow_fill
    )

    if out is None:
        # at this point, it's guaranteed that dtype can hold both the arr
        # values and the fill_value
        out = np.empty(indexer.shape, dtype=dtype)
    # else: honor the caller-supplied buffer instead of discarding it; the
    # preprocessing step above has already raised if filling is needed and
    # out.dtype differs from the promoted dtype

    func = _get_take_nd_function(
        arr.ndim, arr.dtype, out.dtype, axis=0, mask_info=mask_info
    )
    func(arr, indexer, out, fill_value)

    return out
|
||
|
||
def take_2d_multi(arr, indexer, fill_value=np.nan): | ||
""" | ||
Specialized Cython take which sets NaN values in one pass. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment doesn't make any sense without the PR context. Can you put / move a docstring here? Typing is a +1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The "and thus can be cached" on the next line is the essential continuation of the sentence.
The mask can be an array, and thus is not hashable, and thus cannot be used as an argument for a cached function.