Skip to content

Commit 3265483

Browse files
committed
fix: Fix stddev nondeterministically producing NaN
In the `VarianceGroupsAccumulator` we were missing a `count == 0` check that is present in the normal `Accumulator`. This mostly does not matter, except when the first state to be merged has `count == 0`: both the accumulated and the incoming counts are then 0, so the `merge` function divides by a `new_count` of 0 and incorrectly calculates a new m2 of NaN, which propagates to the final result. This fixes the bug by adding the missing `count == 0` check.
1 parent 2482ff4 commit 3265483

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

datafusion/functions-aggregate/src/variance.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ fn merge(
316316
mean2: f64,
317317
m22: f64,
318318
) -> (u64, f64, f64) {
319+
debug_assert!(count != 0 || count2 != 0, "Cannot merge two empty states");
319320
let new_count = count + count2;
320321
let new_mean =
321322
mean * count as f64 / new_count as f64 + mean2 * count2 as f64 / new_count as f64;
@@ -573,6 +574,9 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
573574
partial_m2s,
574575
opt_filter,
575576
|group_index, partial_count, partial_mean, partial_m2| {
577+
if partial_count == 0 {
578+
return;
579+
}
576580
let (new_count, new_mean, new_m2) = merge(
577581
self.counts[group_index],
578582
self.means[group_index],
@@ -612,3 +616,32 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
612616
+ self.counts.capacity() * size_of::<u64>()
613617
}
614618
}
619+
620+
#[cfg(test)]
621+
// Regression tests for `VarianceGroupsAccumulator`.
mod tests {
622+
use datafusion_expr::EmitTo;
623+
624+
use super::*;
625+
626+
// Merging a partial state whose count is 0 *before* a non-empty one must not
// poison the group: the evaluated result is expected to be exactly 1.0
// (the commit this test accompanies reports NaN without the count check).
#[test]
627+
fn test_groups_accumulator_merge_empty_states() -> Result<()> {
628+
// Empty partial state: count 0. The two float columns are presumably
// mean and m2, matching the merge closure's argument order -- TODO confirm.
let state_1 = vec![
629+
Arc::new(UInt64Array::from(vec![0])) as ArrayRef,
630+
Arc::new(Float64Array::from(vec![0.0])),
631+
Arc::new(Float64Array::from(vec![0.0])),
632+
];
633+
// Non-empty partial state: count 2, both float columns set to 1.0.
let state_2 = vec![
634+
Arc::new(UInt64Array::from(vec![2])) as ArrayRef,
635+
Arc::new(Float64Array::from(vec![1.0])),
636+
Arc::new(Float64Array::from(vec![1.0])),
637+
];
638+
let mut acc = VarianceGroupsAccumulator::new(StatsType::Sample);
639+
// Both batches target group index 0; the empty state goes in first, which
// is the ordering this test is meant to exercise.
acc.merge_batch(&state_1, &[0], None, 1)?;
640+
acc.merge_batch(&state_2, &[0], None, 1)?;
641+
let result = acc.evaluate(EmitTo::All)?;
642+
// The assertions below read the evaluated output as a `Float64Array`.
let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
643+
// Exactly one group was used, so exactly one output value is expected.
assert_eq!(result.len(), 1);
644+
assert_eq!(result.value(0), 1.0);
645+
Ok(())
646+
}
647+
}

0 commit comments

Comments
 (0)