Skip to content

Commit ec56dd2

Browse files
PERF: optimize algos.take for repeated calls (#39692)
1 parent 6409754 commit ec56dd2

File tree

4 files changed

+95
-32
lines changed

4 files changed

+95
-32
lines changed

pandas/core/array_algos/take.py

+50-30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
from typing import Optional
45

56
import numpy as np
@@ -177,41 +178,60 @@ def take_2d_multi(
177178
return out
178179

179180

181+
@functools.lru_cache(maxsize=128)
182+
def _get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis):
183+
"""
184+
Part of _get_take_nd_function below that doesn't need `mask_info` and thus
185+
can be cached (mask_info potentially contains a numpy ndarray which is not
186+
hashable and thus cannot be used as argument for cached function).
187+
"""
188+
tup = (arr_dtype.name, out_dtype.name)
189+
if ndim == 1:
190+
func = _take_1d_dict.get(tup, None)
191+
elif ndim == 2:
192+
if axis == 0:
193+
func = _take_2d_axis0_dict.get(tup, None)
194+
else:
195+
func = _take_2d_axis1_dict.get(tup, None)
196+
if func is not None:
197+
return func
198+
199+
tup = (out_dtype.name, out_dtype.name)
200+
if ndim == 1:
201+
func = _take_1d_dict.get(tup, None)
202+
elif ndim == 2:
203+
if axis == 0:
204+
func = _take_2d_axis0_dict.get(tup, None)
205+
else:
206+
func = _take_2d_axis1_dict.get(tup, None)
207+
if func is not None:
208+
func = _convert_wrapper(func, out_dtype)
209+
return func
210+
211+
return None
212+
213+
180214
def _get_take_nd_function(
181-
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int = 0, mask_info=None
215+
ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None
182216
):
183-
217+
"""
218+
Get the appropriate "take" implementation for the given dimension, axis
219+
and dtypes.
220+
"""
221+
func = None
184222
if ndim <= 2:
185-
tup = (arr_dtype.name, out_dtype.name)
186-
if ndim == 1:
187-
func = _take_1d_dict.get(tup, None)
188-
elif ndim == 2:
189-
if axis == 0:
190-
func = _take_2d_axis0_dict.get(tup, None)
191-
else:
192-
func = _take_2d_axis1_dict.get(tup, None)
193-
if func is not None:
194-
return func
195-
196-
tup = (out_dtype.name, out_dtype.name)
197-
if ndim == 1:
198-
func = _take_1d_dict.get(tup, None)
199-
elif ndim == 2:
200-
if axis == 0:
201-
func = _take_2d_axis0_dict.get(tup, None)
202-
else:
203-
func = _take_2d_axis1_dict.get(tup, None)
204-
if func is not None:
205-
func = _convert_wrapper(func, out_dtype)
206-
return func
223+
# for this part we don't need `mask_info` -> use the cached algo lookup
224+
func = _get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis)
207225

208-
def func2(arr, indexer, out, fill_value=np.nan):
209-
indexer = ensure_int64(indexer)
210-
_take_nd_object(
211-
arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info
212-
)
226+
if func is None:
227+
228+
def func(arr, indexer, out, fill_value=np.nan):
229+
indexer = ensure_int64(indexer)
230+
_take_nd_object(
231+
arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info
232+
)
213233

214-
return func2
234+
return func
215235

216236

217237
def _view_wrapper(f, arr_dtype=None, out_dtype=None, fill_wrap=None):

pandas/core/dtypes/cast.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
datetime,
1111
timedelta,
1212
)
13+
import functools
1314
import inspect
1415
from typing import (
1516
TYPE_CHECKING,
@@ -573,6 +574,35 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan):
573574
ValueError
574575
If fill_value is a non-scalar and dtype is not object.
575576
"""
577+
# TODO(2.0): need to directly use the non-cached version as long as we
578+
# possibly raise a deprecation warning for datetime dtype
579+
if dtype.kind == "M":
580+
return _maybe_promote(dtype, fill_value)
581+
# for performance, we are using a cached version of the actual implementation
582+
# of the function in _maybe_promote. However, this doesn't always work (in case
583+
# of non-hashable arguments), so we fallback to the actual implementation if needed
584+
try:
585+
# error: Argument 3 to "__call__" of "_lru_cache_wrapper" has incompatible type
586+
# "Type[Any]"; expected "Hashable" [arg-type]
587+
return _maybe_promote_cached(
588+
dtype, fill_value, type(fill_value) # type: ignore[arg-type]
589+
)
590+
except TypeError:
591+
# if fill_value is not hashable (required for caching)
592+
return _maybe_promote(dtype, fill_value)
593+
594+
595+
@functools.lru_cache(maxsize=128)
596+
def _maybe_promote_cached(dtype, fill_value, fill_value_type):
597+
# The cached version of _maybe_promote below
598+
# This also use fill_value_type as (unused) argument to use this in the
599+
# cache lookup -> to differentiate 1 and True
600+
return _maybe_promote(dtype, fill_value)
601+
602+
603+
def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
604+
# The actual implementation of the function, use `maybe_promote` above for
605+
# a cached version.
576606
if not is_scalar(fill_value):
577607
# with object dtype there is nothing to promote, and the user can
578608
# pass pretty much any weird fill_value they like
@@ -623,7 +653,7 @@ def maybe_promote(dtype: np.dtype, fill_value=np.nan):
623653
"dtype is deprecated. In a future version, this will be cast "
624654
"to object dtype. Pass `fill_value=Timestamp(date_obj)` instead.",
625655
FutureWarning,
626-
stacklevel=7,
656+
stacklevel=8,
627657
)
628658
return dtype, fv
629659
elif isinstance(fill_value, str):

pandas/core/internals/array_manager.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
)
5757

5858
import pandas.core.algorithms as algos
59+
from pandas.core.array_algos.take import take_nd
5960
from pandas.core.arrays import (
6061
DatetimeArray,
6162
ExtensionArray,
@@ -1020,7 +1021,7 @@ def unstack(self, unstacker, fill_value) -> ArrayManager:
10201021
new_arrays = []
10211022
for arr in self.arrays:
10221023
for i in range(unstacker.full_shape[1]):
1023-
new_arr = algos.take(
1024+
new_arr = take_nd(
10241025
arr, new_indexer2D[:, i], allow_fill=True, fill_value=fill_value
10251026
)
10261027
new_arrays.append(new_arr)

pandas/tests/test_take.py

+12
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,18 @@ def test_take_axis_1(self):
417417
with pytest.raises(IndexError, match="indices are out-of-bounds"):
418418
algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0)
419419

420+
def test_take_non_hashable_fill_value(self):
421+
arr = np.array([1, 2, 3])
422+
indexer = np.array([1, -1])
423+
with pytest.raises(ValueError, match="fill_value must be a scalar"):
424+
algos.take(arr, indexer, allow_fill=True, fill_value=[1])
425+
426+
# with object dtype it is allowed
427+
arr = np.array([1, 2, 3], dtype=object)
428+
result = algos.take(arr, indexer, allow_fill=True, fill_value=[1])
429+
expected = np.array([2, [1]], dtype=object)
430+
tm.assert_numpy_array_equal(result, expected)
431+
420432

421433
class TestExtensionTake:
422434
# The take method found in pd.api.extensions

0 commit comments

Comments
 (0)