diff --git a/numpy_groupies/utils_numpy.py b/numpy_groupies/utils_numpy.py index f2ef839..0c86065 100644 --- a/numpy_groupies/utils_numpy.py +++ b/numpy_groupies/utils_numpy.py @@ -189,8 +189,53 @@ def check_group_idx(group_idx, a=None, check_min=True): raise ValueError("group_idx contains negative indices") +def _ravel_group_idx(group_idx, a, axis, size, order, method="ravel"): + ndim_a = a.ndim + # Create the broadcast-ready multidimensional indexing. + # Note the user could do this themselves, so this is + # very much just a convenience. + size_in = int(np.max(group_idx)) + 1 if size is None else size + group_idx_in = group_idx + group_idx = [] + size = [] + for ii, s in enumerate(a.shape): + if method == "ravel": + ii_idx = group_idx_in if ii == axis else np.arange(s) + ii_shape = [1] * ndim_a + ii_shape[ii] = s + group_idx.append(ii_idx.reshape(ii_shape)) + size.append(size_in if ii == axis else s) + # Use the indexing, and return. It's a bit simpler than + # using trying to keep all the logic below happy + if method == "ravel": + group_idx = np.ravel_multi_index(group_idx, size, order=order, + mode='raise') + elif method == "offset": + group_idx = offset_labels(group_idx_in, a.shape, axis, order, size_in) + return group_idx, size + +def offset_labels(group_idx, inshape, axis, order, size): + """ + Offset group labels by dimension. This is used when we + reduce over a subset of the dimensions of by. It assumes that the reductions + dimensions have been flattened in the last dimension + Copied from + https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy + """ + if axis not in (-1, len(inshape) - 1): + newshape = (s for idx, s in enumerate(inshape) if idx != axis) + (inshape[axis],) + else: + newshape = inshape + group_idx = np.broadcast_to(group_idx, newshape) + group_idx: np.ndarray = ( + group_idx + + np.arange(np.prod(group_idx.shape[:-1]), dtype=int).reshape((*group_idx.shape[:-1], -1)) + * size + ) + return group_idx.reshape(inshape).ravel() + def input_validation(group_idx, a, size=None, order='C', axis=None, - ravel_group_idx=True, check_bounds=True): + ravel_group_idx=True, check_bounds=True, method="ravel"): """ Do some fairly extensive checking of group_idx and a, trying to give the user as much help as possible with what is wrong. Also, convert ndim-indexing to 1d indexing. @@ -230,23 +275,7 @@ def input_validation(group_idx, a, size=None, order='C', axis=None, raise NotImplementedError("when using axis arg, size must be" "None or scalar.") else: - # Create the broadcast-ready multidimensional indexing. - # Note the user could do this themselves, so this is - # very much just a convenience. - size_in = int(np.max(group_idx)) + 1 if size is None else size - group_idx_in = group_idx - group_idx = [] - size = [] - for ii, s in enumerate(a.shape): - ii_idx = group_idx_in if ii == axis else np.arange(s) - ii_shape = [1] * ndim_a - ii_shape[ii] = s - group_idx.append(ii_idx.reshape(ii_shape)) - size.append(size_in if ii == axis else s) - # Use the indexing, and return. It's a bit simpler than - # using trying to keep all the logic below happy - group_idx = np.ravel_multi_index(group_idx, size, order=order, - mode='raise') + group_idx, size = _ravel_group_idx(group_idx, a, axis, size, order, method=method) flat_size = np.prod(size) ndim_idx = ndim_a return group_idx.ravel(), a.ravel(), flat_size, ndim_idx, size