@@ -460,7 +460,9 @@ def reshape(self, shape):
460
460
# TODO: this np.prod(self.shape) enforces a 2**64 limit to array size
461
461
linear_loc = self .linear_loc ()
462
462
463
- coords = np .empty ((len (shape ), self .nnz ), dtype = np .min_scalar_type (max (shape )))
463
+ max_shape = max (shape ) if len (shape ) != 0 else 1
464
+
465
+ coords = np .empty ((len (shape ), self .nnz ), dtype = np .min_scalar_type (max_shape - 1 ))
464
466
strides = 1
465
467
for i , d in enumerate (shape [::- 1 ]):
466
468
coords [- (i + 1 ), :] = (linear_loc // strides ) % d
@@ -580,31 +582,22 @@ def sum_duplicates(self):
580
582
return self
581
583
582
584
def __add__ (self , other ):
583
- if isinstance (other , numbers .Number ) and other == 0 :
584
- return self
585
- if not isinstance (other , COO ):
586
- return self .maybe_densify () + other
587
- else :
588
- return self .elemwise_binary (operator .add , other )
585
+ return self .elemwise (operator .add , other )
589
586
590
- def __radd__ (self , other ):
591
- return self + other
587
+ __radd__ = __add__
592
588
593
589
def __neg__ (self ):
594
590
return COO (self .coords , - self .data , self .shape , self .has_duplicates ,
595
591
self .sorted )
596
592
597
593
def __sub__ (self , other ):
598
- return self + ( - other )
594
+ return self . elemwise ( operator . sub , other )
599
595
600
596
def __rsub__ (self , other ):
601
- return - self + other
597
+ return - ( self - other )
602
598
603
599
def __mul__ (self , other ):
604
- if isinstance (other , COO ):
605
- return self .elemwise_binary (operator .mul , other )
606
- else :
607
- return self .elemwise (operator .mul , other )
600
+ return self .elemwise (operator .mul , other )
608
601
609
602
__rmul__ = __mul__
610
603
@@ -620,32 +613,86 @@ def __pow__(self, other):
620
613
return self .elemwise (operator .pow , other )
621
614
622
615
def __and__ (self , other ):
623
- return self .elemwise_binary (operator .and_ , other )
616
+ return self .elemwise (operator .and_ , other )
624
617
625
618
def __xor__ (self , other ):
626
- return self .elemwise_binary (operator .xor , other )
619
+ return self .elemwise (operator .xor , other )
627
620
628
621
def __or__ (self , other ):
629
- return self .elemwise_binary (operator .or_ , other )
622
+ return self .elemwise (operator .or_ , other )
623
+
624
+ def __gt__ (self , other ):
625
+ return self .elemwise (operator .gt , other )
626
+
627
+ def __ge__ (self , other ):
628
+ return self .elemwise (operator .ge , other )
629
+
630
+ def __lt__ (self , other ):
631
+ return self .elemwise (operator .lt , other )
632
+
633
+ def __le__ (self , other ):
634
+ return self .elemwise (operator .le , other )
635
+
636
+ def __eq__ (self , other ):
637
+ return self .elemwise (operator .eq , other )
638
+
639
+ def __ne__ (self , other ):
640
+ return self .elemwise (operator .ne , other )
630
641
631
642
def elemwise (self , func , * args , ** kwargs ):
643
+ """
644
+ Apply a function to one or two arguments.
645
+
646
+ Parameters
647
+ ----------
648
+ func
649
+ The function to apply to one or two arguments.
650
+ args : tuple, optional
651
+ The extra arguments to pass to the function. If args[0] is a COO object
652
+ or a scipy.sparse.spmatrix, the function will be treated as a binary
653
+ function. Otherwise, it will be treated as a unary function.
654
+ kwargs : dict, optional
655
+ The kwargs to pass to the function.
656
+
657
+ Returns
658
+ -------
659
+ COO
660
+ The result of applying the function.
661
+ """
662
+ if len (args ) == 0 :
663
+ return self ._elemwise_unary (func , * args , ** kwargs )
664
+ else :
665
+ other = args [0 ]
666
+ if isinstance (other , COO ):
667
+ return self ._elemwise_binary (func , * args , ** kwargs )
668
+ elif isinstance (other , scipy .sparse .spmatrix ):
669
+ other = COO .from_scipy_sparse (other )
670
+ return self ._elemwise_binary (func , other , * args [1 :], ** kwargs )
671
+ else :
672
+ return self ._elemwise_unary (func , * args , ** kwargs )
673
+
674
+ def _elemwise_unary (self , func , * args , ** kwargs ):
632
675
check = kwargs .pop ('check' , True )
633
676
data_zero = _zero_of_dtype (self .dtype )
634
677
func_zero = _zero_of_dtype (func (data_zero , * args , ** kwargs ).dtype )
635
678
if check and func (data_zero , * args , ** kwargs ) != func_zero :
636
679
raise ValueError ("Performing this operation would produce "
637
680
"a dense result: %s" % str (func ))
638
- return COO (self .coords , func (self .data , * args , ** kwargs ),
681
+
682
+ data_func = func (self .data , * args , ** kwargs )
683
+ nonzero = data_func != func_zero
684
+
685
+ return COO (self .coords [:, nonzero ], data_func [nonzero ],
639
686
shape = self .shape ,
640
687
has_duplicates = self .has_duplicates ,
641
688
sorted = self .sorted )
642
689
643
- def elemwise_binary (self , func , other , * args , ** kwargs ):
690
+ def _elemwise_binary (self , func , other , * args , ** kwargs ):
644
691
assert isinstance (other , COO )
692
+ check = kwargs .pop ('check' , True )
645
693
self_zero = _zero_of_dtype (self .dtype )
646
694
other_zero = _zero_of_dtype (other .dtype )
647
- check = kwargs .pop ('check' , True )
648
- func_zero = _zero_of_dtype (func (self_zero , other_zero , * args , ** kwargs ).dtype )
695
+ func_zero = _zero_of_dtype (func (self_zero , other_zero , * args , ** kwargs ).dtype )
649
696
if check and func (self_zero , other_zero , * args , ** kwargs ) != func_zero :
650
697
raise ValueError ("Performing this operation would produce "
651
698
"a dense result: %s" % str (func ))
@@ -690,12 +737,6 @@ def elemwise_binary(self, func, other, *args, **kwargs):
690
737
matched_self , matched_other = _match_arrays (self_reduced_linear ,
691
738
other_reduced_linear )
692
739
693
- # Locate coordinates without a match
694
- unmatched_self = np .ones (self .nnz , dtype = np .bool )
695
- unmatched_self [matched_self ] = False
696
- unmatched_other = np .ones (other .nnz , dtype = np .bool )
697
- unmatched_other [matched_other ] = False
698
-
699
740
# Start with an empty list. This may reduce computation in many cases.
700
741
data_list = []
701
742
coords_list = []
@@ -711,11 +752,10 @@ def elemwise_binary(self, func, other, *args, **kwargs):
711
752
coords_list .append (matched_coords )
712
753
713
754
self_func = func (self_data , other_zero , * args , ** kwargs )
714
-
715
755
# Add unmatched parts as necessary.
716
756
if (self_func != func_zero ).any ():
717
757
self_unmatched_coords , self_unmatched_func = \
718
- self ._get_unmatched_coords_data (self_coords , self_data , self_shape ,
758
+ self ._get_unmatched_coords_data (self_coords , self_func , self_shape ,
719
759
result_shape , matched_self ,
720
760
matched_coords )
721
761
@@ -726,7 +766,7 @@ def elemwise_binary(self, func, other, *args, **kwargs):
726
766
727
767
if (other_func != func_zero ).any ():
728
768
other_unmatched_coords , other_unmatched_func = \
729
- self ._get_unmatched_coords_data (other_coords , other_data , other_shape ,
769
+ self ._get_unmatched_coords_data (other_coords , other_func , other_shape ,
730
770
result_shape , matched_other ,
731
771
matched_coords )
732
772
@@ -1067,7 +1107,7 @@ def __abs__(self):
1067
1107
1068
1108
def exp (self , out = None ):
1069
1109
assert out is None
1070
- return np . exp ( self .maybe_densify () )
1110
+ return self .elemwise ( np . exp )
1071
1111
1072
1112
def expm1 (self , out = None ):
1073
1113
assert out is None
@@ -1123,23 +1163,7 @@ def conjugate(self, out=None):
1123
1163
1124
1164
def astype (self , dtype , out = None ):
1125
1165
assert out is None
1126
- return self .elemwise (np .ndarray .astype , dtype , check = False )
1127
-
1128
- def __gt__ (self , other ):
1129
- if not isinstance (other , numbers .Number ):
1130
- raise NotImplementedError ("Only scalars supported" )
1131
- if other < 0 :
1132
- raise ValueError ("Comparison with negative number would produce "
1133
- "dense result" )
1134
- return self .elemwise (operator .gt , other )
1135
-
1136
- def __ge__ (self , other ):
1137
- if not isinstance (other , numbers .Number ):
1138
- raise NotImplementedError ("Only scalars supported" )
1139
- if other <= 0 :
1140
- raise ValueError ("Comparison with negative number would produce "
1141
- "dense result" )
1142
- return self .elemwise (operator .ge , other )
1166
+ return self .elemwise (np .ndarray .astype , dtype )
1143
1167
1144
1168
def maybe_densify (self , allowed_nnz = 1e3 , allowed_fraction = 0.25 ):
1145
1169
""" Convert to a dense numpy array if not too costly. Err othrewise """
0 commit comments