diff --git a/numpy_groupies/utils_numpy.py b/numpy_groupies/utils_numpy.py index 41ec2c9..545829a 100644 --- a/numpy_groupies/utils_numpy.py +++ b/numpy_groupies/utils_numpy.py @@ -222,21 +222,25 @@ def offset_labels(group_idx, inshape, axis, order, size): Copied from https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy """ + + newaxes = tuple(ax for ax in range(len(inshape)) if ax != axis) + group_idx = np.broadcast_to(np.expand_dims(group_idx, newaxes), inshape) + if axis not in (-1, len(inshape) - 1): + group_idx = np.moveaxis(group_idx, axis, -1) + newshape = group_idx.shape + + group_idx = (group_idx + + np.arange(np.prod(newshape[:-1]), dtype=int).reshape((*newshape[:-1], -1)) + * size + ) if axis not in (-1, len(inshape) - 1): - newshape = (s for idx, s in enumerate(inshape) if idx != axis) + (inshape[axis],) + return np.moveaxis(group_idx, -1, axis) else: - newshape = inshape - group_idx = np.broadcast_to(group_idx, newshape) - group_idx: np.ndarray = ( - group_idx - + np.arange(np.prod(group_idx.shape[:-1]), dtype=int).reshape((*group_idx.shape[:-1], -1)) - * size - ) - return group_idx.reshape(inshape).ravel() + return group_idx def input_validation(group_idx, a, size=None, order='C', axis=None, - ravel_group_idx=True, check_bounds=True, method="ravel", func=None): + ravel_group_idx=True, check_bounds=True, method="offset", func=None): """ Do some fairly extensive checking of group_idx and a, trying to give the user as much help as possible with what is wrong. Also, convert ndim-indexing to 1d indexing.