diff --git a/src/liballoc/benches/btree/set.rs b/src/liballoc/benches/btree/set.rs index 08e1db5fbb74d..0875b3a5fedfe 100644 --- a/src/liballoc/benches/btree/set.rs +++ b/src/liballoc/benches/btree/set.rs @@ -3,34 +3,32 @@ use std::collections::BTreeSet; use rand::{thread_rng, Rng}; use test::{black_box, Bencher}; -fn random(n1: u32, n2: u32) -> [BTreeSet; 2] { +fn random(n1: usize, n2: usize) -> [BTreeSet; 2] { let mut rng = thread_rng(); - let mut set1 = BTreeSet::new(); - let mut set2 = BTreeSet::new(); - for _ in 0..n1 { - let i = rng.gen::(); - set1.insert(i); - } - for _ in 0..n2 { - let i = rng.gen::(); - set2.insert(i); + let mut sets = [BTreeSet::new(), BTreeSet::new()]; + for i in 0..2 { + while sets[i].len() < [n1, n2][i] { + sets[i].insert(rng.gen()); + } } - [set1, set2] + assert_eq!(sets[0].len(), n1); + assert_eq!(sets[1].len(), n2); + sets } -fn staggered(n1: u32, n2: u32) -> [BTreeSet; 2] { - let mut even = BTreeSet::new(); - let mut odd = BTreeSet::new(); - for i in 0..n1 { - even.insert(i * 2); - } - for i in 0..n2 { - odd.insert(i * 2 + 1); +fn stagger(n1: usize, factor: usize) -> [BTreeSet; 2] { + let n2 = n1 * factor; + let mut sets = [BTreeSet::new(), BTreeSet::new()]; + for i in 0..(n1 + n2) { + let b = i % (factor + 1) != 0; + sets[b as usize].insert(i as u32); } - [even, odd] + assert_eq!(sets[0].len(), n1); + assert_eq!(sets[1].len(), n2); + sets } -fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet; 2] { +fn neg_vs_pos(n1: usize, n2: usize) -> [BTreeSet; 2] { let mut neg = BTreeSet::new(); let mut pos = BTreeSet::new(); for i in -(n1 as i32)..=-1 { @@ -39,22 +37,20 @@ fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet; 2] { for i in 1..=(n2 as i32) { pos.insert(i); } + assert_eq!(neg.len(), n1); + assert_eq!(pos.len(), n2); [neg, pos] } -fn pos_vs_neg(n1: u32, n2: u32) -> [BTreeSet; 2] { - let mut neg = BTreeSet::new(); - let mut pos = BTreeSet::new(); - for i in -(n1 as i32)..=-1 { - neg.insert(i); - } - for i in 1..=(n2 as i32) { - pos.insert(i); - } - [pos, neg] +fn pos_vs_neg(n1: usize, n2: usize) -> [BTreeSet; 2] { + let mut sets = neg_vs_pos(n2, n1); + sets.reverse(); + assert_eq!(sets[0].len(), n1); + assert_eq!(sets[1].len(), n2); + sets } -macro_rules! set_intersection_bench { +macro_rules! intersection_bench { ($name: ident, $sets: expr) => { #[bench] pub fn $name(b: &mut Bencher) { @@ -68,21 +64,64 @@ macro_rules! set_intersection_bench { }) } }; + ($name: ident, $sets: expr, $intersection_kind: ident) => { + #[bench] + pub fn $name(b: &mut Bencher) { + // setup + let sets = $sets; + assert!(sets[0].len() >= 1); + assert!(sets[1].len() >= sets[0].len()); + + // measure + b.iter(|| { + let x = BTreeSet::$intersection_kind(&sets[0], &sets[1]).count(); + black_box(x); + }) + } + }; } -set_intersection_bench! {intersect_random_100, random(100, 100)} -set_intersection_bench! {intersect_random_10k, random(10_000, 10_000)} -set_intersection_bench! {intersect_random_10_vs_10k, random(10, 10_000)} -set_intersection_bench! {intersect_random_10k_vs_10, random(10_000, 10)} -set_intersection_bench! {intersect_staggered_100, staggered(100, 100)} -set_intersection_bench! {intersect_staggered_10k, staggered(10_000, 10_000)} -set_intersection_bench! {intersect_staggered_10_vs_10k, staggered(10, 10_000)} -set_intersection_bench! {intersect_staggered_10k_vs_10, staggered(10_000, 10)} -set_intersection_bench! {intersect_neg_vs_pos_100, neg_vs_pos(100, 100)} -set_intersection_bench! {intersect_neg_vs_pos_10k, neg_vs_pos(10_000, 10_000)} -set_intersection_bench! {intersect_neg_vs_pos_10_vs_10k,neg_vs_pos(10, 10_000)} -set_intersection_bench! {intersect_neg_vs_pos_10k_vs_10,neg_vs_pos(10_000, 10)} -set_intersection_bench! {intersect_pos_vs_neg_100, pos_vs_neg(100, 100)} -set_intersection_bench! {intersect_pos_vs_neg_10k, pos_vs_neg(10_000, 10_000)} -set_intersection_bench! {intersect_pos_vs_neg_10_vs_10k,pos_vs_neg(10, 10_000)} -set_intersection_bench! {intersect_pos_vs_neg_10k_vs_10,pos_vs_neg(10_000, 10)} +intersection_bench! {intersect_100_neg_vs_100_pos, neg_vs_pos(100, 100)} +intersection_bench! {intersect_100_neg_vs_10k_pos, neg_vs_pos(100, 10_000)} +intersection_bench! {intersect_100_pos_vs_100_neg, pos_vs_neg(100, 100)} +intersection_bench! {intersect_100_pos_vs_10k_neg, pos_vs_neg(100, 10_000)} +intersection_bench! {intersect_10k_neg_vs_100_pos, neg_vs_pos(10_000, 100)} +intersection_bench! {intersect_10k_neg_vs_10k_pos, neg_vs_pos(10_000, 10_000)} +intersection_bench! {intersect_10k_pos_vs_100_neg, pos_vs_neg(10_000, 100)} +intersection_bench! {intersect_10k_pos_vs_10k_neg, pos_vs_neg(10_000, 10_000)} +intersection_bench! {intersect_random_100_vs_100_actual,random(100, 100)} +intersection_bench! {intersect_random_100_vs_100_search,random(100, 100), intersection_search} +intersection_bench! {intersect_random_100_vs_100_stitch,random(100, 100), intersection_stitch} +intersection_bench! {intersect_random_100_vs_10k_actual,random(100, 10_000)} +intersection_bench! {intersect_random_100_vs_10k_search,random(100, 10_000), intersection_search} +intersection_bench! {intersect_random_100_vs_10k_stitch,random(100, 10_000), intersection_stitch} +intersection_bench! {intersect_random_10k_vs_10k_actual,random(10_000, 10_000)} +intersection_bench! {intersect_random_10k_vs_10k_search,random(10_000, 10_000), intersection_search} +intersection_bench! {intersect_random_10k_vs_10k_stitch,random(10_000, 10_000), intersection_stitch} +intersection_bench! {intersect_stagger_100_actual, stagger(100, 1)} +intersection_bench! {intersect_stagger_100_search, stagger(100, 1), intersection_search} +intersection_bench! {intersect_stagger_100_stitch, stagger(100, 1), intersection_stitch} +intersection_bench! {intersect_stagger_10k_actual, stagger(10_000, 1)} +intersection_bench! {intersect_stagger_10k_search, stagger(10_000, 1), intersection_search} +intersection_bench! {intersect_stagger_10k_stitch, stagger(10_000, 1), intersection_stitch} +intersection_bench! {intersect_stagger_1_actual, stagger(1, 1)} +intersection_bench! {intersect_stagger_1_search, stagger(1, 1), intersection_search} +intersection_bench! {intersect_stagger_1_stitch, stagger(1, 1), intersection_stitch} +intersection_bench! {intersect_stagger_diff1_actual, stagger(100, 1 << 1)} +intersection_bench! {intersect_stagger_diff1_search, stagger(100, 1 << 1), intersection_search} +intersection_bench! {intersect_stagger_diff1_stitch, stagger(100, 1 << 1), intersection_stitch} +intersection_bench! {intersect_stagger_diff2_actual, stagger(100, 1 << 2)} +intersection_bench! {intersect_stagger_diff2_search, stagger(100, 1 << 2), intersection_search} +intersection_bench! {intersect_stagger_diff2_stitch, stagger(100, 1 << 2), intersection_stitch} +intersection_bench! {intersect_stagger_diff3_actual, stagger(100, 1 << 3)} +intersection_bench! {intersect_stagger_diff3_search, stagger(100, 1 << 3), intersection_search} +intersection_bench! {intersect_stagger_diff3_stitch, stagger(100, 1 << 3), intersection_stitch} +intersection_bench! {intersect_stagger_diff4_actual, stagger(100, 1 << 4)} +intersection_bench! {intersect_stagger_diff4_search, stagger(100, 1 << 4), intersection_search} +intersection_bench! {intersect_stagger_diff4_stitch, stagger(100, 1 << 4), intersection_stitch} +intersection_bench! {intersect_stagger_diff5_actual, stagger(100, 1 << 5)} +intersection_bench! {intersect_stagger_diff5_search, stagger(100, 1 << 5), intersection_search} +intersection_bench! {intersect_stagger_diff5_stitch, stagger(100, 1 << 5), intersection_stitch} +intersection_bench! {intersect_stagger_diff6_actual, stagger(100, 1 << 6)} +intersection_bench! {intersect_stagger_diff6_search, stagger(100, 1 << 6), intersection_search} +intersection_bench! {intersect_stagger_diff6_stitch, stagger(100, 1 << 6), intersection_stitch} diff --git a/src/liballoc/benches/lib.rs b/src/liballoc/benches/lib.rs index 4bf5ec10c41e7..c9cf318cc07df 100644 --- a/src/liballoc/benches/lib.rs +++ b/src/liballoc/benches/lib.rs @@ -1,5 +1,6 @@ #![feature(repr_simd)] #![feature(test)] +#![feature(benches_btree_set)] extern crate test; diff --git a/src/liballoc/collections/btree/set.rs b/src/liballoc/collections/btree/set.rs index 2be6455ad5903..3f1f185129904 100644 --- a/src/liballoc/collections/btree/set.rs +++ b/src/liballoc/collections/btree/set.rs @@ -3,7 +3,7 @@ use core::borrow::Borrow; use core::cmp::Ordering::{self, Less, Greater, Equal}; -use core::cmp::{min, max}; +use core::cmp::max; use core::fmt::{self, Debug}; use core::iter::{Peekable, FromIterator, FusedIterator}; use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds}; @@ -164,8 +164,15 @@ impl fmt::Debug for SymmetricDifference<'_, T> { /// [`intersection`]: struct.BTreeSet.html#method.intersection #[stable(feature = "rust1", since = "1.0.0")] pub struct Intersection<'a, T: 'a> { - a: Peekable>, - b: Peekable>, + a: Range<'a, T>, + b: IntersectionOther<'a, T>, + max_size: usize, +} + +#[derive(Debug)] +enum IntersectionOther<'a, T> { + Stitch(Range<'a, T>), + Search(&'a BTreeSet), } #[stable(feature = "collection_debug", since = "1.17.0")] @@ -174,6 +181,7 @@ impl fmt::Debug for Intersection<'_, T> { f.debug_tuple("Intersection") .field(&self.a) .field(&self.b) + .field(&self.max_size) .finish() } } @@ -326,9 +334,55 @@ impl BTreeSet { /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn intersection<'a>(&'a self, other: &'a BTreeSet) -> Intersection<'a, T> { + let (a_set, b_set) = if self.len() <= other.len() { + (self, other) + } else { + (other, self) + }; + if a_set.len() <= 1 { + // At least one set is empty or a singleton, so determining + // a common range is either impossible or wasteful. + Intersection { + a: a_set.range(..), + b: IntersectionOther::Search(b_set), + max_size: a_set.len(), + } + } else if a_set.len() >= b_set.len() / 16 { + // Both sets are roughly of similar size, iterate both. + Self::intersection_stitch(a_set, b_set) + } else { + // Iterate small set only and find matches in large set. + Self::intersection_search(a_set, b_set) + } + } + #[doc(hidden)] + #[unstable(feature = "benches_btree_set", reason = "benchmarks for pull #58577", issue = "0")] + pub fn intersection_stitch<'a>( + a_set: &'a BTreeSet, + b_set: &'a BTreeSet, + ) -> Intersection<'a, T> { + let a_min = a_set.iter().next().unwrap(); + let b_min = b_set.iter().next().unwrap(); + let a_range = a_set.range(b_min..); + let b_range = b_set.range(a_min..); Intersection { - a: self.iter().peekable(), - b: other.iter().peekable(), + a: a_range, + b: IntersectionOther::Stitch(b_range), + max_size: a_set.len(), + } + } + #[doc(hidden)] + #[unstable(feature = "benches_btree_set", reason = "benchmarks for pull #58577", issue = "0")] + pub fn intersection_search<'a>( + a_set: &'a BTreeSet, + b_set: &'a BTreeSet, + ) -> Intersection<'a, T> { + let b_min = b_set.iter().next().unwrap(); + let a_range = a_set.range(b_min..); + Intersection { + a: a_range, + b: IntersectionOther::Search(b_set), + max_size: a_set.len(), } } @@ -1069,12 +1123,21 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> { #[stable(feature = "fused", since = "1.26.0")] impl FusedIterator for SymmetricDifference<'_, T> {} +impl Clone for IntersectionOther<'_, T> { + fn clone(&self) -> Self { + match &self { + IntersectionOther::Stitch(range) => IntersectionOther::Stitch(range.clone()), + IntersectionOther::Search(set) => IntersectionOther::Search(set), + } + } +} #[stable(feature = "rust1", since = "1.0.0")] impl Clone for Intersection<'_, T> { fn clone(&self) -> Self { Intersection { a: self.a.clone(), b: self.b.clone(), + max_size: self.max_size, } } } @@ -1083,24 +1146,29 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> { type Item = &'a T; fn next(&mut self) -> Option<&'a T> { - loop { - match Ord::cmp(self.a.peek()?, self.b.peek()?) { - Less => { - self.a.next(); - } - Equal => { - self.b.next(); - return self.a.next(); - } - Greater => { - self.b.next(); + match &mut self.b { + IntersectionOther::Stitch(self_b) => { + let mut a_elt = self.a.next()?; + let mut b_elt = self_b.next()?; + loop { + match Ord::cmp(a_elt, b_elt) { + Less => a_elt = self.a.next()?, + Equal => return Some(a_elt), + Greater => b_elt = self_b.next()?, + } } } + IntersectionOther::Search(b_set) => loop { + let a_elt = self.a.next()?; + if b_set.contains(&a_elt) { + return Some(a_elt); + } + }, } } fn size_hint(&self) -> (usize, Option) { - (0, Some(min(self.a.len(), self.b.len()))) + (0, Some(self.max_size)) } } diff --git a/src/liballoc/tests/btree/set.rs b/src/liballoc/tests/btree/set.rs index 4f5168f1ce572..a98e08e0d7ebb 100644 --- a/src/liballoc/tests/btree/set.rs +++ b/src/liballoc/tests/btree/set.rs @@ -69,6 +69,19 @@ fn test_intersection() { check_intersection(&[11, 1, 3, 77, 103, 5, -5], &[2, 11, 77, -9, -42, 5, 3], &[3, 5, 11, 77]); + + let mut large = [0i32; 512]; + for i in 0..512 { + large[i] = i as i32 + } + check_intersection(&large[..], &[], &[]); + check_intersection(&large[..], &[-1], &[]); + check_intersection(&large[..], &[42], &[42]); + check_intersection(&large[..], &[4, 2], &[2, 4]); + check_intersection(&[], &large[..], &[]); + check_intersection(&[-1], &large[..], &[]); + check_intersection(&[42], &large[..], &[42]); + check_intersection(&[4, 2], &large[..], &[2, 4]); } #[test]