diff --git a/sparse/coo.py b/sparse/coo.py index f0b5f0b1..53308906 100644 --- a/sparse/coo.py +++ b/sparse/coo.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function from collections import Iterable, defaultdict, deque +from contextlib import contextmanager from functools import reduce, partial import numbers import operator @@ -9,7 +10,8 @@ import scipy.sparse from .slicing import normalize_index -from .utils import _zero_of_dtype +from .utils import _zero_of_dtype, _get_broadcast_shape, SparseArray +from .densification import DensificationConfig # zip_longest with Python 2/3 compat from six.moves import range, zip_longest @@ -19,8 +21,10 @@ except NameError: pass +_DEFAULT_DENSIFICATION_CONFIG = DensificationConfig(densify=False) -class COO(object): + +class COO(SparseArray): """ A sparse multidimensional array. @@ -183,7 +187,7 @@ class COO(object): __array_priority__ = 12 def __init__(self, coords, data=None, shape=None, has_duplicates=True, - sorted=False, cache=False): + sorted=False, cache=False, densification_config=None): self._cache = None if cache: self.enable_caching() @@ -191,11 +195,7 @@ def __init__(self, coords, data=None, shape=None, has_duplicates=True, from .dok import DOK if isinstance(coords, COO): - self.coords = coords.coords - self.data = coords.data - self.has_duplicates = coords.has_duplicates - self.sorted = coords.sorted - self.shape = coords.shape + self._make_shallow_copy(coords) return if isinstance(coords, DOK): @@ -209,20 +209,12 @@ def __init__(self, coords, data=None, shape=None, has_duplicates=True, if isinstance(coords, np.ndarray): result = COO.from_numpy(coords) - self.coords = result.coords - self.data = result.data - self.has_duplicates = result.has_duplicates - self.sorted = result.sorted - self.shape = result.shape + self._make_shallow_copy(result) return if isinstance(coords, scipy.sparse.spmatrix): result = COO.from_scipy_sparse(coords) - self.coords = result.coords - self.data = result.data 
def _make_shallow_copy(self, other):
    """
    Adopt the state of ``other`` without copying any of it.

    Every attribute is rebound to the very same object held by ``other``,
    so the two arrays share their buffers afterwards.

    Parameters
    ----------
    other : object
        Source of the attributes ``coords``, ``data``, ``has_duplicates``,
        ``sorted``, ``shape`` and ``densification_config``.
    """
    shared_attributes = (
        'coords',
        'data',
        'has_duplicates',
        'sorted',
        'shape',
        'densification_config',
    )

    # Assigned one by one, in this order, exactly like explicit statements.
    for attribute in shared_attributes:
        setattr(self, attribute, getattr(other, attribute))
has_duplicates=False, sorted=True) + has_duplicates=False, sorted=True, + densification_config=densification_config) a = a.reshape([self.shape[d] for d in neg_axis]) result = a @@ -1123,10 +1144,15 @@ def transpose(self, axes=None): if ax == axes: return value + densification_config = DensificationConfig.from_parents([ + self.densification_config + ]) + shape = tuple(self.shape[ax] for ax in axes) result = COO(self.coords[axes, :], self.data, shape, has_duplicates=self.has_duplicates, - cache=self._cache is not None) + cache=self._cache is not None, + densification_config=densification_config) if self._cache is not None: self._cache['transpose'].append((axes, result)) @@ -1229,6 +1255,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return NotImplemented def __array__(self, dtype=None, **kwargs): + self.densification_config.check('array', self) x = self.todense() if dtype and x.dtype != dtype: x = x.astype(dtype) @@ -1319,9 +1346,14 @@ def reshape(self, shape): coords[-(i + 1), :] = (linear_loc // strides) % d strides *= d + densification_config = DensificationConfig.from_parents([ + self.densification_config + ]) + result = COO(coords, self.data, shape, has_duplicates=self.has_duplicates, - sorted=self.sorted, cache=self._cache is not None) + sorted=self.sorted, cache=self._cache is not None, + densification_config=densification_config) if self._cache is not None: self._cache['reshape'].append((shape, result)) @@ -1709,8 +1741,12 @@ def broadcast_to(self, shape): params = _get_broadcast_parameters(self.shape, result_shape) coords, data = _get_expanded_coords_data(self.coords, self.data, params, result_shape) + densification_config = DensificationConfig.from_parents([ + self.densification_config + ]) + return COO(coords, data, shape=result_shape, has_duplicates=self.has_duplicates, - sorted=self.sorted) + sorted=self.sorted, densification_config=densification_config) def __abs__(self): """ @@ -2006,7 +2042,16 @@ def astype(self, dtype, out=None): 
@contextmanager
def configure_densification(self, **kwargs):
    """
    Temporarily replace the densification configuration of this array.

    Intended for use in a ``with`` statement. The previous configuration
    is put back when the block is left, whether it finishes normally or
    raises.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to :obj:`DensificationConfig`.

    Raises
    ------
    ValueError
        If :obj:`DensificationConfig` rejects the given arguments. The
        current configuration is left untouched in that case.
    """
    old_densification_config = self.densification_config

    # Constructed before the assignment takes effect, so invalid arguments
    # cannot leave the array in a half-configured state.
    self.densification_config = DensificationConfig(**kwargs)

    try:
        yield
    finally:
        # Without the ``finally`` an exception raised inside the ``with``
        # body would skip the restore and make the temporary config stick.
        self.densification_config = old_densification_config
- ValueError: Operation would require converting large sparse array to dense + ValueError: Performing this operation would produce a dense result: maybe_densify """ - if self.size <= max_size or self.density >= min_density: - return self.todense() - else: - raise ValueError("Operation would require converting " - "large sparse array to dense") + self.densification_config.check('maybe_densify', self) + return self.todense() def tensordot(a, b, axes=2): @@ -2263,8 +2322,13 @@ def concatenate(arrays, axis=0): has_duplicates = any(x.has_duplicates for x in arrays) + densification_config = DensificationConfig.from_parents([ + x.densification_config for x in arrays + ]) + return COO(coords, data, shape=shape, has_duplicates=has_duplicates, - sorted=(axis == 0) and all(a.sorted for a in arrays)) + sorted=(axis == 0) and all(a.sorted for a in arrays), + densification_config=densification_config) def stack(arrays, axis=0): @@ -2311,8 +2375,13 @@ def stack(arrays, axis=0): coords.insert(axis, new) coords = np.stack(coords, axis=0) + densification_config = DensificationConfig.from_parents([ + x.densification_config for x in arrays + ]) + return COO(coords, data, shape=shape, has_duplicates=has_duplicates, - sorted=(axis == 0) and all(a.sorted for a in arrays)) + sorted=(axis == 0) and all(a.sorted for a in arrays), + densification_config=densification_config) def triu(x, k=0): @@ -2344,7 +2413,12 @@ def triu(x, k=0): coords = x.coords[:, mask] data = x.data[mask] - return COO(coords, data, x.shape, x.has_duplicates, x.sorted) + densification_config = DensificationConfig.from_parents([ + x.densification_config + ]) + + return COO(coords, data, x.shape, x.has_duplicates, x.sorted, + densification_config=densification_config) def tril(x, k=0): @@ -2376,7 +2450,12 @@ def tril(x, k=0): coords = x.coords[:, mask] data = x.data[mask] - return COO(coords, data, x.shape, x.has_duplicates, x.sorted) + densification_config = DensificationConfig.from_parents([ + x.densification_config + 
]) + + return COO(coords, data, x.shape, x.has_duplicates, x.sorted, + densification_config=densification_config) # (c) Paul Panzer @@ -2460,33 +2539,37 @@ def _grouped_reduce(x, groups, method, **kwargs): def _elemwise_binary(func, self, other, *args, **kwargs): - check = kwargs.pop('check', True) self_zero = _zero_of_dtype(self.dtype) other_zero = _zero_of_dtype(other.dtype) func_zero = _zero_of_dtype(func(self_zero, other_zero, *args, **kwargs).dtype) - if check and func(self_zero, other_zero, *args, **kwargs) != func_zero: - raise ValueError("Performing this operation would produce " - "a dense result: %s" % str(func)) if not isinstance(self, COO): - if not check or np.array_equiv(func(self, other_zero, *args, **kwargs), func_zero): + if np.array_equiv(func(self, other_zero, *args, **kwargs), func_zero): return _elemwise_binary_self_dense(func, self, other, *args, **kwargs) else: - raise ValueError("Performing this operation would produce " - "a dense result: %s" % str(func)) + other.densification_config.check(str(func), self, other) + return func(self, other.todense(), *args, **kwargs) if not isinstance(other, COO): - if not check or np.array_equiv(func(self_zero, other, *args, **kwargs), func_zero): + if np.array_equiv(func(self_zero, other, *args, **kwargs), func_zero): temp_func = _reverse_self_other(func) return _elemwise_binary_self_dense(temp_func, other, self, *args, **kwargs) else: - raise ValueError("Performing this operation would produce " - "a dense result: %s" % str(func)) + self.densification_config.check(str(func), self, other) + return func(self.todense(), other, *args, **kwargs) - self_shape, other_shape = self.shape, other.shape + if func(self_zero, other_zero, *args, **kwargs) != func_zero: + DensificationConfig.from_parents([ + self.densification_config, + other.densification_config + ]).check(str(func), self, other) + + return func(self.todense(), other.todense(), *args, **kwargs) + self_shape, other_shape = self.shape, other.shape 
result_shape = _get_broadcast_shape(self_shape, other_shape) + self_params = _get_broadcast_parameters(self.shape, result_shape) other_params = _get_broadcast_parameters(other.shape, result_shape) combined_params = [p1 and p2 for p1, p2 in zip(self_params, other_params)] @@ -2570,7 +2653,12 @@ def _elemwise_binary(func, self, other, *args, **kwargs): data = data[nonzero] coords = coords[:, nonzero] - return COO(coords, data, shape=result_shape, has_duplicates=False) + densification_config = DensificationConfig.from_parents([ + self.densification_config, + other.densification_config + ]) + return COO(coords, data, shape=result_shape, has_duplicates=False, + densification_config=densification_config) def _elemwise_binary_self_dense(func, self, other, *args, **kwargs): @@ -2593,9 +2681,14 @@ def _elemwise_binary_self_dense(func, self, other, *args, **kwargs): func_data = func_data[mask] func_coords = other.coords[:, mask] + densification_config = DensificationConfig.from_parents([ + other.densification_config + ]) + return COO(func_coords, func_data, shape=result_shape, has_duplicates=other.has_duplicates, - sorted=other.sorted) + sorted=other.sorted, + densification_config=densification_config) def _reverse_self_other(func): @@ -2670,39 +2763,6 @@ def _get_unmatched_coords_data(coords, data, shape, result_shape, matched_idx, return coords_list, data_list -def _get_broadcast_shape(shape1, shape2, is_result=False): - """ - Get the overall broadcasted shape. - - Parameters - ---------- - shape1, shape2 : tuple[int] - The input shapes to broadcast together. - is_result : bool - Whether or not shape2 is also the result shape. - - Returns - ------- - result_shape : tuple[int] - The overall shape of the result. - - Raises - ------ - ValueError - If the two shapes cannot be broadcast together. 
def _elemwise_unary(func, self, *args, **kwargs):
    """
    Apply a one-operand function element-wise to a sparse array.

    Parameters
    ----------
    func : callable
        The function to apply. It is called with the operand first,
        followed by ``*args`` and ``**kwargs``.
    self : COO
        The sparse operand.
    *args, **kwargs
        Extra arguments handed to ``func`` on every call.

    Returns
    -------
    COO or dense result
        A new :obj:`COO` when ``func`` maps zero to zero. Otherwise the
        result of ``func`` on the densified operand, provided the
        operand's densification config permits it.

    Raises
    ------
    ValueError
        Raised by the densification check when ``func`` does not map zero
        to zero and densifying is not permitted.
    """
    zero_in = _zero_of_dtype(self.dtype)
    zero_out = _zero_of_dtype(func(zero_in, *args, **kwargs).dtype)

    # A function that turns zero into something else fills every implicit
    # entry, so the result can only be represented densely.
    fills_background = func(zero_in, *args, **kwargs) != zero_out
    if fills_background:
        self.densification_config.check(str(func), self)
        return func(self.todense(), *args, **kwargs)

    transformed = func(self.data, *args, **kwargs)
    # Stored entries that became zero are dropped from the result.
    keep = transformed != zero_out

    config = DensificationConfig.from_parents([self.densification_config])

    return COO(
        self.coords[:, keep],
        transformed[keep],
        shape=self.shape,
        has_duplicates=self.has_duplicates,
        sorted=self.sorted,
        densification_config=config,
    )
class DensificationConfig(object):
    """
    Policy deciding whether a sparse array may be turned into a dense one.

    A config is either a *root* (``parents is None``) holding its own
    settings, or a *derived* config created by :meth:`from_parents` whose
    effective settings are the most restrictive combination of its parents.

    Parameters
    ----------
    densify : bool or None, optional
        ``True`` always permits densification, ``False`` never permits it,
        ``None`` defers the decision to ``max_size`` and ``min_density``.
    max_size : int, optional
        With ``densify=None``, results of at most this many elements may be
        densified.
    min_density : float, optional
        With ``densify=None``, operands of at least this density may be
        densified. Must lie between 0 and 1.

    Raises
    ------
    ValueError
        If an argument has the wrong type or is out of range.
    """

    def __init__(self, densify=None, max_size=10000, min_density=0.25):
        if not isinstance(densify, bool) and densify is not None:
            raise ValueError('densify must be a bool or None.')

        if not isinstance(max_size, Integral) or max_size < 0:
            raise ValueError('max_size must be a non-negative integer.')

        if (not isinstance(min_density, Number) or
                not (0.0 <= min_density <= 1.0)):
            raise ValueError('min_density must be a number between 0 and 1.')

        self.densify = TriState(densify)
        self.max_size = int(max_size)
        self.min_density = float(min_density)
        # Set of derived configs registered while this config is entered as
        # a context manager; ``None`` outside of a ``with`` block.
        self.children = None
        # Set of root configs this one is derived from; ``None`` for roots
        # and for derived configs that have already been resolved.
        self.parents = None
        # Number of registering parents that have exited so far.
        self._disconnected_parents = 0
        # Number of parents that registered this config as their child.
        self._parents_with_children = 0
        self._lock = threading.Lock()

    def _should_densify(self, size, density):
        """Return whether a result of ``size`` and ``density`` may be dense."""
        if self.densify.value is True:
            return True
        elif self.densify.value is False:
            return False
        else:
            # Either criterion on its own is sufficient.
            return self.max_size >= size or self.min_density <= density

    def _raise_if_fails(self, size, density, name):
        """Raise ``ValueError`` naming ``name`` if densifying is not allowed."""
        if not self._should_densify(size, density):
            raise ValueError("Performing this operation would produce "
                             "a dense result: %s" % name)

    def check(self, name, *arrays):
        """
        Verify that the operation ``name`` on ``arrays`` may densify.

        Parameters
        ----------
        name : str
            Name of the operation, used in the error message.
        *arrays
            Operands of the operation.

        Raises
        ------
        ValueError
            If the effective configuration forbids the densification, or
            if the operand shapes cannot be broadcast together.
        """
        config = self._reduce_from_parents()
        result_shape = ()
        density = 1.0

        for arr in arrays:
            if isinstance(arr, SparseArray):
                # The sparsest operand determines the density that is tested.
                density = min(density, arr.density)

            # Dense operands take part in broadcasting as well; operands
            # without a ``shape`` attribute are treated as scalars.
            result_shape = _get_broadcast_shape(result_shape,
                                                getattr(arr, 'shape', ()))

        size = np.prod(result_shape, dtype=np.uint64)

        config._raise_if_fails(size, density, name)

    @staticmethod
    def validate(configs):
        """
        Check that ``configs`` is one config or an iterable of configs.

        Raises
        ------
        ValueError
            If any item is not a :obj:`DensificationConfig`.
        """
        if isinstance(configs, Iterable):
            for config in configs:
                DensificationConfig._validate_single(config)

            return

        DensificationConfig._validate_single(configs)

    @staticmethod
    def _validate_single(config):
        """Raise ``ValueError`` unless ``config`` is a DensificationConfig."""
        if not isinstance(config, DensificationConfig):
            raise ValueError('Invalid DensificationConfig.')

    @staticmethod
    def from_parents(parents):
        """
        Build the config of a result computed from several operands.

        Parameters
        ----------
        parents : iterable of DensificationConfig
            The configs of the operands.

        Returns
        -------
        DensificationConfig
            A derived config. It is resolved to concrete settings at once
            unless one of its root parents is currently entered as a
            context manager, in which case resolution is postponed until
            all such parents have exited.

        Raises
        ------
        ValueError
            If an item of ``parents`` is not a :obj:`DensificationConfig`.
        """
        root_parents = set()

        for parent in parents:
            DensificationConfig._validate_single(parent)
            root_parents.update(parent._get_all_parents())

        result = DensificationConfig()
        result.parents = root_parents

        for parent in root_parents:
            # ``children`` is a set only while the parent is entered.
            if isinstance(parent.children, set):
                parent.children.add(result)
                result._parents_with_children += 1

        if not result._parents_with_children:
            result._reduce_from_parents(in_place=True)

        return result

    def _get_all_parents(self):
        """Return the root configs behind this one (itself if it is a root)."""
        if isinstance(self.parents, Iterable):
            return self.parents

        parents = set()
        parents.add(self)

        return parents

    def _reduce_from_parents(self, in_place=False):
        """
        Combine the settings of all parents into one effective config.

        The result uses the smallest ``max_size``, the largest
        ``min_density`` and the smallest ``densify`` tri-state found among
        the parents.

        Parameters
        ----------
        in_place : bool, optional
            If true, store the combined settings on this object and drop
            the link to the parents.

        Returns
        -------
        DensificationConfig
            ``self`` for roots and for in-place reduction, a new config
            otherwise.
        """
        if not isinstance(self.parents, Iterable):
            return self

        max_size = _max_size
        min_density = 0.0
        densify = TriState(True)

        for parent in self.parents:
            max_size = min(max_size, parent.max_size)
            min_density = max(min_density, parent.min_density)
            densify = min(densify, parent.densify)

        if in_place:
            self.parents = None
            self._disconnected_parents = 0
            self._parents_with_children = 0
            self.max_size = max_size
            self.min_density = min_density
            self.densify = densify
            return self

        # The constructor accepts only a plain bool or None, so the
        # tri-state has to be unwrapped before it is handed over.
        return DensificationConfig(densify.value, max_size, min_density)

    def __enter__(self):
        # From now on configs derived from this one register themselves
        # here instead of being resolved immediately.
        self.children = set()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for child in self.children or ():
            # ``with`` guarantees the lock is released even if the
            # reduction raises.
            with child._lock:
                child._disconnected_parents += 1
                if (child._disconnected_parents ==
                        child._parents_with_children):
                    # The last registering parent is leaving: freeze the
                    # child's settings from its parents' current values.
                    child._reduce_from_parents(in_place=True)

        self.children = None

    def __str__(self):
        if isinstance(self.parents, set):
            return '<DensificationConfig: parents=%d>' % len(self.parents)
        elif isinstance(self.densify.value, bool):
            return '<DensificationConfig: densify=%s>' % self.densify.value
        else:
            return '<DensificationConfig: max_size=%d, min_density=%s>' % \
                   (self.max_size, self.min_density)
+from .utils import _zero_of_dtype, SparseArray try: # Windows compatibility int = long @@ -18,7 +18,7 @@ pass -class DOK(object): +class DOK(SparseArray): """ A class for building sparse multidimensional arrays. @@ -341,6 +341,8 @@ def _setitem(self, key_list, value): if value != _zero_of_dtype(self.dtype): self.data[tuple(key_list)] = value[()] + else: + self.data.pop(tuple(key_list), None) def __str__(self): return "" % (self.shape, self.dtype, self.nnz) diff --git a/sparse/tests/test_coo.py b/sparse/tests/test_coo.py index f5883d70..0f98dcc7 100644 --- a/sparse/tests/test_coo.py +++ b/sparse/tests/test_coo.py @@ -1085,8 +1085,8 @@ def test_scalar_shape_construction(): def test_len(): - s = sparse.random((20, 30, 40)) - assert len(s) == 20 + s = sparse.random((2, 3, 4)) + assert len(s) == 2 def test_density(): @@ -1095,12 +1095,119 @@ def test_density(): def test_size(): - s = sparse.random((20, 30, 40)) - assert s.size == 20 * 30 * 40 + s = sparse.random((2, 3, 4)) + assert s.size == 2 * 3 * 4 def test_np_array(): - s = sparse.random((20, 30, 40)) - x = np.array(s) + s = sparse.random((2, 3, 4)) + with s.configure_densification(densify=True): + x = np.array(s) assert isinstance(x, np.ndarray) assert_eq(x, s) + + +@pytest.mark.parametrize('func', [ + lambda s1, s2: np.exp(s1), + lambda s1, s2: np.cos(s1), + lambda s1, s2: s1 + 1, + lambda s1, s2: s1 + s2 + 1, + lambda s1, s2: -(s1 + s2) + 1, +]) +@pytest.mark.parametrize('kwargs', [ + {'densify': True}, + { + 'densify': None, + 'max_size': 0, + 'min_density': 0.3, + }, + { + 'densify': None, + 'max_size': 25, + 'min_density': 0.7, + }, +]) +def test_densification_config(func, kwargs): + s1 = sparse.random((2, 3, 4), density=0.5) + s2 = sparse.random((2, 3, 4), density=0.5) + + x1 = s1.todense() + x2 = s2.todense() + + func_x = func(x1, x2) + + with s1.configure_densification(**kwargs), \ + s2.configure_densification(**kwargs): + func_s = func(s1, s2) + + assert isinstance(func_s, np.ndarray) + assert_eq(func_x, 
# (c) kindall
# Taken from https://stackoverflow.com/a/9504358/774273
# License: https://creativecommons.org/licenses/by-sa/3.0/
@total_ordering
class TriState(object):
    """
    A three-valued flag holding ``True``, ``False`` or ``None``.

    Instances are ordered ``False < None < True``. Equality and ordering
    both accept another :obj:`TriState` as well as a bare ``True``,
    ``False`` or ``None``; the comparison is by identity, so ``1`` and
    ``0`` are not treated as ``True`` and ``False``.

    Parameters
    ----------
    value : bool or None, optional
        The state to hold.

    Raises
    ------
    ValueError
        If ``value`` is not exactly ``True``, ``False`` or ``None``.
    """

    def __init__(self, value=None):
        if any(value is v for v in (True, False, None)):
            self.value = value
        else:
            raise ValueError("Tristate value must be True, False, or None")

    @staticmethod
    def _rank(value):
        """Return the sort position of a raw state, or None if invalid."""
        # Identity tests on purpose: a dict lookup would let 0 and 1 pass
        # for False and True.
        if value is False:
            return 0
        if value is None:
            return 1
        if value is True:
            return 2
        return None

    def __eq__(self, other):
        return (self.value is other.value if isinstance(other, TriState)
                else self.value is other)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``==``.
        return not self.__eq__(other)

    def __hash__(self):
        # Defined next to ``__eq__`` so that equal states hash equally.
        return hash(self.value)

    def __le__(self, other):
        other_value = other.value if isinstance(other, TriState) else other

        own_rank = TriState._rank(self.value)
        other_rank = TriState._rank(other_value)

        if own_rank is None or other_rank is None:
            # Not a tri-state value: let Python try the reflected
            # operation or report the unsupported comparison.
            return NotImplemented

        return own_rank <= other_rank

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return "Tristate(%s)" % self.value