diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index 3aaeef3b63760..c6b0732b04c09 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -50,7 +50,7 @@
 from pandas.core.sorting import (
     compress_group_index,
     decons_obs_group_ids,
-    get_flattened_iterator,
+    get_flattened_list,
     get_group_index,
     get_group_index_sorter,
     get_indexer_dict,
@@ -153,7 +153,7 @@ def _get_group_keys(self):
         comp_ids, _, ngroups = self.group_info
 
         # provide "flattened" iterator for multi-group setting
-        return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)
+        return get_flattened_list(comp_ids, ngroups, self.levels, self.codes)
 
     def apply(self, f: F, data: FrameOrSeries, axis: int = 0):
         mutated = self.mutated
diff --git a/pandas/core/sorting.py b/pandas/core/sorting.py
index ee73aa42701b0..8bdd466ae6f33 100644
--- a/pandas/core/sorting.py
+++ b/pandas/core/sorting.py
@@ -1,5 +1,6 @@
 """ miscellaneous sorting / groupby utilities """
-from typing import Callable, Optional
+from collections import defaultdict
+from typing import TYPE_CHECKING, Callable, DefaultDict, Iterable, List, Optional, Tuple
 
 import numpy as np
 
@@ -18,6 +19,9 @@
 import pandas.core.algorithms as algorithms
 from pandas.core.construction import extract_array
 
+if TYPE_CHECKING:
+    from pandas.core.indexes.base import Index  # noqa:F401
+
 _INT64_MAX = np.iinfo(np.int64).max
 
 
@@ -409,7 +413,7 @@ def ensure_key_mapped(values, key: Optional[Callable], levels=None):
     levels : Optional[List], if values is a MultiIndex, list of levels
         to apply the key to.
     """
-    from pandas.core.indexes.api import Index
+    from pandas.core.indexes.api import Index  # noqa:F811
 
     if not key:
         return values
@@ -440,36 +444,21 @@ def ensure_key_mapped(values, key: Optional[Callable], levels=None):
     return result
 
 
-class _KeyMapper:
-    """
-    Map compressed group id -> key tuple.
-    """
-
-    def __init__(self, comp_ids, ngroups: int, levels, labels):
-        self.levels = levels
-        self.labels = labels
-        self.comp_ids = comp_ids.astype(np.int64)
-
-        self.k = len(labels)
-        self.tables = [hashtable.Int64HashTable(ngroups) for _ in range(self.k)]
-
-        self._populate_tables()
-
-    def _populate_tables(self):
-        for labs, table in zip(self.labels, self.tables):
-            table.map(self.comp_ids, labs.astype(np.int64))
-
-    def get_key(self, comp_id):
-        return tuple(
-            level[table.get_item(comp_id)]
-            for table, level in zip(self.tables, self.levels)
-        )
-
-
-def get_flattened_iterator(comp_ids, ngroups, levels, labels):
-    # provide "flattened" iterator for multi-group setting
-    mapper = _KeyMapper(comp_ids, ngroups, levels, labels)
-    return [mapper.get_key(i) for i in range(ngroups)]
+def get_flattened_list(
+    comp_ids: np.ndarray,
+    ngroups: int,
+    levels: Iterable["Index"],
+    labels: Iterable[np.ndarray],
+) -> List[Tuple]:
+    """Map compressed group id -> key tuple."""
+    comp_ids = comp_ids.astype(np.int64, copy=False)
+    arrays: DefaultDict[int, List[int]] = defaultdict(list)
+    for labs, level in zip(labels, levels):
+        table = hashtable.Int64HashTable(ngroups)
+        table.map(comp_ids, labs.astype(np.int64, copy=False))
+        for i in range(ngroups):
+            arrays[i].append(level[table.get_item(i)])
+    return [tuple(array) for array in arrays.values()]
 
 
 def get_indexer_dict(label_list, keys):
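
For context, the patch replaces the _KeyMapper class with a single get_flattened_list function that, for each (labels, level) pair, looks up the level value belonging to every compressed group id and collects one key tuple per group. The snippet below is only a minimal, pure-Python sketch of that behavior, not the hashtable-based code in the patch; the name flattened_list_reference and the toy inputs are made up for illustration.

# Illustrative sketch (not part of the patch): a pure-Python reference for
# what get_flattened_list computes, assuming comp_ids assigns each row a
# compressed group id and labels holds the per-level integer codes.
from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple

import numpy as np


def flattened_list_reference(comp_ids, ngroups, levels, labels) -> List[Tuple]:
    # Record the first row position at which each compressed group id appears.
    first_pos: Dict[int, int] = {}
    for pos, gid in enumerate(comp_ids):
        first_pos.setdefault(int(gid), pos)

    arrays: DefaultDict[int, list] = defaultdict(list)
    for labs, level in zip(labels, levels):
        for i in range(ngroups):
            # The level value for group i is the level entry coded at any
            # row belonging to that group; use the first such row.
            arrays[i].append(level[labs[first_pos[i]]])
    return [tuple(arrays[i]) for i in range(ngroups)]


# Two grouping levels, four rows, three observed groups.
levels = [["a", "b"], [1, 2]]
labels = [np.array([0, 0, 1, 1]), np.array([0, 1, 0, 0])]
comp_ids = np.array([0, 1, 2, 2])

print(flattened_list_reference(comp_ids, 3, levels, labels))
# [('a', 1), ('a', 2), ('b', 1)]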