Skip to content

Commit 736d7d6

Browse files
authored
Remove auto densification and unify operator code. (#46)
* Get rid of auto densification and unify ops and elemwise code. * Add more operators. * Add tests for all operators. * Move scalar logic to elemwise_binary. * Unify elemwise and elemwise_binary. * Add computed function to test instead of re-computing it. * Added newline to docstring for Sphinx. * Added test for operation with scipy sparse matrix. * Fix spontaneous test failure.
1 parent 13bb9a2 commit 736d7d6

File tree

2 files changed

+157
-67
lines changed

2 files changed

+157
-67
lines changed

sparse/core.py

Lines changed: 73 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,9 @@ def reshape(self, shape):
460460
# TODO: this np.prod(self.shape) enforces a 2**64 limit to array size
461461
linear_loc = self.linear_loc()
462462

463-
coords = np.empty((len(shape), self.nnz), dtype=np.min_scalar_type(max(shape)))
463+
max_shape = max(shape) if len(shape) != 0 else 1
464+
465+
coords = np.empty((len(shape), self.nnz), dtype=np.min_scalar_type(max_shape - 1))
464466
strides = 1
465467
for i, d in enumerate(shape[::-1]):
466468
coords[-(i + 1), :] = (linear_loc // strides) % d
@@ -580,31 +582,22 @@ def sum_duplicates(self):
580582
return self
581583

582584
def __add__(self, other):
583-
if isinstance(other, numbers.Number) and other == 0:
584-
return self
585-
if not isinstance(other, COO):
586-
return self.maybe_densify() + other
587-
else:
588-
return self.elemwise_binary(operator.add, other)
585+
return self.elemwise(operator.add, other)
589586

590-
def __radd__(self, other):
591-
return self + other
587+
__radd__ = __add__
592588

593589
def __neg__(self):
594590
return COO(self.coords, -self.data, self.shape, self.has_duplicates,
595591
self.sorted)
596592

597593
def __sub__(self, other):
598-
return self + (-other)
594+
return self.elemwise(operator.sub, other)
599595

600596
def __rsub__(self, other):
601-
return -self + other
597+
return -(self - other)
602598

603599
def __mul__(self, other):
604-
if isinstance(other, COO):
605-
return self.elemwise_binary(operator.mul, other)
606-
else:
607-
return self.elemwise(operator.mul, other)
600+
return self.elemwise(operator.mul, other)
608601

609602
__rmul__ = __mul__
610603

@@ -620,32 +613,86 @@ def __pow__(self, other):
620613
return self.elemwise(operator.pow, other)
621614

622615
def __and__(self, other):
623-
return self.elemwise_binary(operator.and_, other)
616+
return self.elemwise(operator.and_, other)
624617

625618
def __xor__(self, other):
626-
return self.elemwise_binary(operator.xor, other)
619+
return self.elemwise(operator.xor, other)
627620

628621
def __or__(self, other):
629-
return self.elemwise_binary(operator.or_, other)
622+
return self.elemwise(operator.or_, other)
623+
624+
def __gt__(self, other):
625+
return self.elemwise(operator.gt, other)
626+
627+
def __ge__(self, other):
628+
return self.elemwise(operator.ge, other)
629+
630+
def __lt__(self, other):
631+
return self.elemwise(operator.lt, other)
632+
633+
def __le__(self, other):
634+
return self.elemwise(operator.le, other)
635+
636+
def __eq__(self, other):
637+
return self.elemwise(operator.eq, other)
638+
639+
def __ne__(self, other):
640+
return self.elemwise(operator.ne, other)
630641

631642
def elemwise(self, func, *args, **kwargs):
643+
"""
644+
Apply a function to one or two arguments.
645+
646+
Parameters
647+
----------
648+
func
649+
The function to apply to one or two arguments.
650+
args : tuple, optional
651+
The extra arguments to pass to the function. If args[0] is a COO object
652+
or a scipy.sparse.spmatrix, the function will be treated as a binary
653+
function. Otherwise, it will be treated as a unary function.
654+
kwargs : dict, optional
655+
The kwargs to pass to the function.
656+
657+
Returns
658+
-------
659+
COO
660+
The result of applying the function.
661+
"""
662+
if len(args) == 0:
663+
return self._elemwise_unary(func, *args, **kwargs)
664+
else:
665+
other = args[0]
666+
if isinstance(other, COO):
667+
return self._elemwise_binary(func, *args, **kwargs)
668+
elif isinstance(other, scipy.sparse.spmatrix):
669+
other = COO.from_scipy_sparse(other)
670+
return self._elemwise_binary(func, other, *args[1:], **kwargs)
671+
else:
672+
return self._elemwise_unary(func, *args, **kwargs)
673+
674+
def _elemwise_unary(self, func, *args, **kwargs):
632675
check = kwargs.pop('check', True)
633676
data_zero = _zero_of_dtype(self.dtype)
634677
func_zero = _zero_of_dtype(func(data_zero, *args, **kwargs).dtype)
635678
if check and func(data_zero, *args, **kwargs) != func_zero:
636679
raise ValueError("Performing this operation would produce "
637680
"a dense result: %s" % str(func))
638-
return COO(self.coords, func(self.data, *args, **kwargs),
681+
682+
data_func = func(self.data, *args, **kwargs)
683+
nonzero = data_func != func_zero
684+
685+
return COO(self.coords[:, nonzero], data_func[nonzero],
639686
shape=self.shape,
640687
has_duplicates=self.has_duplicates,
641688
sorted=self.sorted)
642689

643-
def elemwise_binary(self, func, other, *args, **kwargs):
690+
def _elemwise_binary(self, func, other, *args, **kwargs):
644691
assert isinstance(other, COO)
692+
check = kwargs.pop('check', True)
645693
self_zero = _zero_of_dtype(self.dtype)
646694
other_zero = _zero_of_dtype(other.dtype)
647-
check = kwargs.pop('check', True)
648-
func_zero = _zero_of_dtype(func(self_zero, other_zero, * args, **kwargs).dtype)
695+
func_zero = _zero_of_dtype(func(self_zero, other_zero, *args, **kwargs).dtype)
649696
if check and func(self_zero, other_zero, *args, **kwargs) != func_zero:
650697
raise ValueError("Performing this operation would produce "
651698
"a dense result: %s" % str(func))
@@ -690,12 +737,6 @@ def elemwise_binary(self, func, other, *args, **kwargs):
690737
matched_self, matched_other = _match_arrays(self_reduced_linear,
691738
other_reduced_linear)
692739

693-
# Locate coordinates without a match
694-
unmatched_self = np.ones(self.nnz, dtype=np.bool)
695-
unmatched_self[matched_self] = False
696-
unmatched_other = np.ones(other.nnz, dtype=np.bool)
697-
unmatched_other[matched_other] = False
698-
699740
# Start with an empty list. This may reduce computation in many cases.
700741
data_list = []
701742
coords_list = []
@@ -711,11 +752,10 @@ def elemwise_binary(self, func, other, *args, **kwargs):
711752
coords_list.append(matched_coords)
712753

713754
self_func = func(self_data, other_zero, *args, **kwargs)
714-
715755
# Add unmatched parts as necessary.
716756
if (self_func != func_zero).any():
717757
self_unmatched_coords, self_unmatched_func = \
718-
self._get_unmatched_coords_data(self_coords, self_data, self_shape,
758+
self._get_unmatched_coords_data(self_coords, self_func, self_shape,
719759
result_shape, matched_self,
720760
matched_coords)
721761

@@ -726,7 +766,7 @@ def elemwise_binary(self, func, other, *args, **kwargs):
726766

727767
if (other_func != func_zero).any():
728768
other_unmatched_coords, other_unmatched_func = \
729-
self._get_unmatched_coords_data(other_coords, other_data, other_shape,
769+
self._get_unmatched_coords_data(other_coords, other_func, other_shape,
730770
result_shape, matched_other,
731771
matched_coords)
732772

@@ -1067,7 +1107,7 @@ def __abs__(self):
10671107

10681108
def exp(self, out=None):
10691109
assert out is None
1070-
return np.exp(self.maybe_densify())
1110+
return self.elemwise(np.exp)
10711111

10721112
def expm1(self, out=None):
10731113
assert out is None
@@ -1123,23 +1163,7 @@ def conjugate(self, out=None):
11231163

11241164
def astype(self, dtype, out=None):
11251165
assert out is None
1126-
return self.elemwise(np.ndarray.astype, dtype, check=False)
1127-
1128-
def __gt__(self, other):
1129-
if not isinstance(other, numbers.Number):
1130-
raise NotImplementedError("Only scalars supported")
1131-
if other < 0:
1132-
raise ValueError("Comparison with negative number would produce "
1133-
"dense result")
1134-
return self.elemwise(operator.gt, other)
1135-
1136-
def __ge__(self, other):
1137-
if not isinstance(other, numbers.Number):
1138-
raise NotImplementedError("Only scalars supported")
1139-
if other <= 0:
1140-
raise ValueError("Comparison with negative number would produce "
1141-
"dense result")
1142-
return self.elemwise(operator.ge, other)
1166+
return self.elemwise(np.ndarray.astype, dtype)
11431167

11441168
def maybe_densify(self, allowed_nnz=1e3, allowed_fraction=0.25):
11451169
""" Convert to a dense numpy array if not too costly. Err othrewise """

sparse/tests/test_core.py

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,10 @@ def test_elemwise(func):
185185
assert_eq(func(x), func(s))
186186

187187

188-
@pytest.mark.parametrize('func', [operator.mul, operator.add])
188+
@pytest.mark.parametrize('func', [
189+
operator.mul, operator.add, operator.sub, operator.gt,
190+
operator.lt, operator.ne
191+
])
189192
@pytest.mark.parametrize('shape', [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
190193
def test_elemwise_binary(func, shape):
191194
x = random_x(shape)
@@ -197,6 +200,80 @@ def test_elemwise_binary(func, shape):
197200
assert_eq(func(xs, ys), func(x, y))
198201

199202

203+
@pytest.mark.parametrize('func', [
204+
operator.pow, operator.truediv, operator.floordiv,
205+
operator.ge, operator.le, operator.eq
206+
])
207+
@pytest.mark.filterwarnings('ignore:divide by zero')
208+
@pytest.mark.filterwarnings('ignore:invalid value')
209+
def test_auto_densification_fails(func):
210+
xs = COO.from_numpy(random_x((2, 3, 4)))
211+
ys = COO.from_numpy(random_x((2, 3, 4)))
212+
213+
with pytest.raises(ValueError):
214+
func(xs, ys)
215+
216+
217+
def test_op_scipy_sparse():
218+
x = random_x((3, 4))
219+
y = random_x((3, 4))
220+
221+
xs = COO.from_numpy(x)
222+
ys = scipy.sparse.csr_matrix(y)
223+
224+
assert_eq(x + y, xs + ys)
225+
226+
227+
@pytest.mark.parametrize('func, scalar', [
228+
(operator.mul, 5),
229+
(operator.add, 0),
230+
(operator.sub, 0),
231+
(operator.pow, 5),
232+
(operator.truediv, 3),
233+
(operator.floordiv, 4),
234+
(operator.gt, 5),
235+
(operator.lt, -5),
236+
(operator.ne, 0),
237+
(operator.ge, 5),
238+
(operator.le, -3),
239+
(operator.eq, 1)
240+
])
241+
def test_elemwise_scalar(func, scalar):
242+
x = random_x((2, 3, 4))
243+
y = scalar
244+
245+
xs = COO.from_numpy(x)
246+
fs = func(xs, y)
247+
248+
assert isinstance(fs, COO)
249+
assert xs.nnz >= fs.nnz
250+
251+
assert_eq(fs, func(x, y))
252+
253+
254+
@pytest.mark.parametrize('func, scalar', [
255+
(operator.add, 5),
256+
(operator.sub, -5),
257+
(operator.pow, -3),
258+
(operator.truediv, 0),
259+
(operator.floordiv, 0),
260+
(operator.gt, -5),
261+
(operator.lt, 5),
262+
(operator.ne, 1),
263+
(operator.ge, -3),
264+
(operator.le, 3),
265+
(operator.eq, 0)
266+
])
267+
@pytest.mark.filterwarnings('ignore:divide by zero')
268+
@pytest.mark.filterwarnings('ignore:invalid value')
269+
def test_scalar_densification_fails(func, scalar):
270+
xs = COO.from_numpy(random_x((2, 3, 4)))
271+
y = scalar
272+
273+
with pytest.raises(ValueError):
274+
func(xs, y)
275+
276+
200277
@pytest.mark.parametrize('func', [operator.and_, operator.or_, operator.xor])
201278
@pytest.mark.parametrize('shape', [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
202279
def test_bitwise_binary(func, shape):
@@ -399,30 +476,19 @@ def test_addition():
399476

400477
assert_eq(x + y, a + b)
401478
assert_eq(x - y, a - b)
402-
assert_eq(-x, -a)
403-
404-
405-
def test_addition_ok_when_mostly_dense():
406-
x = np.arange(10)
407-
y = COO.from_numpy(x)
408-
409-
assert_eq(x + 1, y + 1)
410-
assert_eq(x - 1, y - 1)
411-
assert_eq(1 - x, 1 - y)
412-
assert_eq(np.exp(x), np.exp(y))
413479

414480

415481
def test_addition_not_ok_when_large_and_sparse():
416482
x = COO({(0, 0): 1}, shape=(1000000, 1000000))
417-
with pytest.raises(Exception):
483+
with pytest.raises(ValueError):
418484
x + 1
419-
with pytest.raises(Exception):
485+
with pytest.raises(ValueError):
420486
1 + x
421-
with pytest.raises(Exception):
487+
with pytest.raises(ValueError):
422488
1 - x
423-
with pytest.raises(Exception):
489+
with pytest.raises(ValueError):
424490
x - 1
425-
with pytest.raises(Exception):
491+
with pytest.raises(ValueError):
426492
np.exp(x)
427493

428494

@@ -537,7 +603,7 @@ def test_cache_csr():
537603

538604

539605
def test_empty_shape():
540-
x = COO([], [1.0])
606+
x = COO(np.empty((0, 1), dtype=np.int8), [1.0])
541607
assert x.shape == ()
542608
assert ((2 * x).todense() == np.array(2.0)).all()
543609

0 commit comments

Comments
 (0)