diff --git a/benches/numeric.rs b/benches/numeric.rs index 76d07f1e4..5a7c111a5 100644 --- a/benches/numeric.rs +++ b/benches/numeric.rs @@ -1,7 +1,7 @@ #![feature(test)] extern crate test; -use test::Bencher; +use test::{black_box, Bencher}; extern crate ndarray; use ndarray::prelude::*; @@ -65,6 +65,38 @@ fn contiguous_sum_1e2(bench: &mut Bencher) }); } +#[bench] +fn contiguous_sum_ix3_1e2(bench: &mut Bencher) +{ + let n = 1e2 as usize; + let a = Array::linspace(-1e6, 1e6, n * n * n) + .into_shape([n, n, n]) + .unwrap(); + bench.iter(|| black_box(&a).sum()); +} + +#[bench] +fn inner_discontiguous_sum_ix3_1e2(bench: &mut Bencher) +{ + let n = 1e2 as usize; + let a = Array::linspace(-1e6, 1e6, n * n * 2*n) + .into_shape([n, n, 2*n]) + .unwrap(); + let v = a.slice(s![.., .., ..;2]); + bench.iter(|| black_box(&v).sum()); +} + +#[bench] +fn middle_discontiguous_sum_ix3_1e2(bench: &mut Bencher) +{ + let n = 1e2 as usize; + let a = Array::linspace(-1e6, 1e6, n * 2*n * n) + .into_shape([n, 2*n, n]) + .unwrap(); + let v = a.slice(s![.., ..;2, ..]); + bench.iter(|| black_box(&v).sum()); +} + #[bench] fn sum_by_row_1e4(bench: &mut Bencher) { @@ -88,3 +120,15 @@ fn sum_by_col_1e4(bench: &mut Bencher) a.sum_axis(Axis(1)) }); } + +#[bench] +fn sum_by_middle_1e2(bench: &mut Bencher) +{ + let n = 1e2 as usize; + let a = Array::linspace(-1e6, 1e6, n * n * n) + .into_shape([n, n, n]) + .unwrap(); + bench.iter(|| { + a.sum_axis(Axis(1)) + }); +} diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 15ddd789e..0be4a0ce0 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -8,7 +8,6 @@ use std::ops::{Add, Div, Mul}; use num_traits::{self, Zero, Float, FromPrimitive}; -use itertools::free::enumerate; use crate::imp_prelude::*; use crate::numeric_util; @@ -33,17 +32,10 @@ impl ArrayBase where A: Clone + Add + num_traits::Zero, { if let Some(slc) = self.as_slice_memory_order() { - return numeric_util::pairwise_sum(&slc) - } - let mut sum = A::zero(); - for row in self.inner_rows() { - if let Some(slc) = row.as_slice() { - sum = sum + numeric_util::pairwise_sum(&slc); - } else { - sum = sum + numeric_util::iterator_pairwise_sum(row.iter()); - } + numeric_util::pairwise_sum(&slc) + } else { + numeric_util::iterator_pairwise_sum(self.iter()) } - sum } /// Return the sum of all elements in the array. @@ -104,16 +96,14 @@ impl ArrayBase D: RemoveAxis, { let n = self.len_of(axis); - let stride = self.strides()[axis.index()]; - if self.ndim() == 2 && stride == 1 { + if self.stride_of(axis) == 1 { // contiguous along the axis we are summing let mut res = Array::zeros(self.raw_dim().remove_axis(axis)); - let ax = axis.index(); - for (i, elt) in enumerate(&mut res) { - *elt = self.index_axis(Axis(1 - ax), i).sum(); - } + Zip::from(&mut res) + .and(self.lanes(axis)) + .apply(|sum, lane| *sum = lane.sum()); res - } else if self.len_of(axis) <= numeric_util::NAIVE_SUM_THRESHOLD { + } else if n <= numeric_util::NAIVE_SUM_THRESHOLD { self.fold_axis(axis, A::zero(), |acc, x| acc.clone() + x.clone()) } else { let (v1, v2) = self.view().split_at(axis, n / 2); diff --git a/src/numeric_util.rs b/src/numeric_util.rs index bcd1080e8..23966b9a9 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -56,16 +56,18 @@ where I: Iterator, A: Clone + Add + Zero, { - let mut partial_sums = vec![]; - let mut partial_sum = A::zero(); - for (i, x) in iter.enumerate() { - partial_sum = partial_sum + x.clone(); - if i % NAIVE_SUM_THRESHOLD == NAIVE_SUM_THRESHOLD - 1 { + let (len, _) = iter.size_hint(); + let cap = len.saturating_sub(1) / NAIVE_SUM_THRESHOLD + 1; // ceiling of division + let mut partial_sums = Vec::with_capacity(cap); + let (_, last_sum) = iter.fold((0, A::zero()), |(count, partial_sum), x| { + if count < NAIVE_SUM_THRESHOLD { + (count + 1, partial_sum + x.clone()) + } else { partial_sums.push(partial_sum); - partial_sum = A::zero(); + (1, x.clone()) } - } - partial_sums.push(partial_sum); + }); + partial_sums.push(last_sum); pure_pairwise_sum(&partial_sums) } @@ -205,3 +207,17 @@ pub fn unrolled_eq(xs: &[A], ys: &[A]) -> bool true } + +#[cfg(test)] +mod tests { + use quickcheck::quickcheck; + use std::num::Wrapping; + use super::iterator_pairwise_sum; + + quickcheck! { + fn iterator_pairwise_sum_is_correct(xs: Vec) -> bool { + let xs: Vec<_> = xs.into_iter().map(|x| Wrapping(x)).collect(); + iterator_pairwise_sum(xs.iter()) == xs.iter().sum() + } + } +}