Skip to content

Commit fb8c3be

Browse files
committed
Implement WeightedIndex, SliceRandom::choose_weighted and SliceRandom::choose_weighted_mut
1 parent af1303c commit fb8c3be

File tree

4 files changed

+338
-28
lines changed

4 files changed

+338
-28
lines changed

src/distributions/mod.rs

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ use Rng;
173173
#[doc(inline)] pub use self::other::Alphanumeric;
174174
#[doc(inline)] pub use self::uniform::Uniform;
175175
#[doc(inline)] pub use self::float::{OpenClosed01, Open01};
176+
#[cfg(feature="alloc")]
177+
#[doc(inline)] pub use self::weighted::WeightedIndex;
176178
#[cfg(feature="std")]
177179
#[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
178180
#[cfg(feature="std")]
@@ -192,6 +194,8 @@ use Rng;
192194
#[doc(inline)] pub use self::dirichlet::Dirichlet;
193195

194196
pub mod uniform;
197+
#[cfg(feature="alloc")]
198+
#[doc(hidden)] pub mod weighted;
195199
#[cfg(feature="std")]
196200
#[doc(hidden)] pub mod gamma;
197201
#[cfg(feature="std")]
@@ -372,6 +376,8 @@ pub struct Standard;
372376

373377

374378
/// A value with a particular weight for use with `WeightedChoice`.
379+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
380+
#[allow(deprecated)]
375381
#[derive(Copy, Clone, Debug)]
376382
pub struct Weighted<T> {
377383
/// The numerical weight of this item
@@ -382,34 +388,18 @@ pub struct Weighted<T> {
382388

383389
/// A distribution that selects from a finite collection of weighted items.
384390
///
385-
/// Each item has an associated weight that influences how likely it
386-
/// is to be chosen: higher weight is more likely.
387-
///
388-
/// The `Clone` restriction is a limitation of the `Distribution` trait.
389-
/// Note that `&T` is (cheaply) `Clone` for all `T`, as is `u32`, so one can
390-
/// store references or indices into another vector.
391-
///
392-
/// # Example
393-
///
394-
/// ```
395-
/// use rand::distributions::{Weighted, WeightedChoice, Distribution};
396-
///
397-
/// let mut items = vec!(Weighted { weight: 2, item: 'a' },
398-
/// Weighted { weight: 4, item: 'b' },
399-
/// Weighted { weight: 1, item: 'c' });
400-
/// let wc = WeightedChoice::new(&mut items);
401-
/// let mut rng = rand::thread_rng();
402-
/// for _ in 0..16 {
403-
/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
404-
/// println!("{}", wc.sample(&mut rng));
405-
/// }
406-
/// ```
391+
/// Deprecated: use [`WeightedIndex`] instead.
392+
/// [`WeightedIndex`]: distributions/struct.WeightedIndex.html
393+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
394+
#[allow(deprecated)]
407395
#[derive(Debug)]
408396
pub struct WeightedChoice<'a, T:'a> {
409397
items: &'a mut [Weighted<T>],
410398
weight_range: Uniform<u32>,
411399
}
412400

401+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
402+
#[allow(deprecated)]
413403
impl<'a, T: Clone> WeightedChoice<'a, T> {
414404
/// Create a new `WeightedChoice`.
415405
///
@@ -447,6 +437,8 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
447437
}
448438
}
449439

440+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
441+
#[allow(deprecated)]
450442
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
451443
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
452444
// we want to find the first element that has cumulative
@@ -556,9 +548,11 @@ fn ziggurat<R: Rng + ?Sized, P, Z>(
556548
#[cfg(test)]
557549
mod tests {
558550
use rngs::mock::StepRng;
551+
#[allow(deprecated)]
559552
use super::{WeightedChoice, Weighted, Distribution};
560553

561554
#[test]
555+
#[allow(deprecated)]
562556
fn test_weighted_choice() {
563557
// this makes assumptions about the internal implementation of
564558
// WeightedChoice. It may fail when the implementation in
@@ -618,6 +612,7 @@ mod tests {
618612
}
619613

620614
#[test]
615+
#[allow(deprecated)]
621616
fn test_weighted_clone_initialization() {
622617
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
623618
let clone = initial.clone();
@@ -626,6 +621,7 @@ mod tests {
626621
}
627622

628623
#[test] #[should_panic]
624+
#[allow(deprecated)]
629625
fn test_weighted_clone_change_weight() {
630626
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
631627
let mut clone = initial.clone();
@@ -634,6 +630,7 @@ mod tests {
634630
}
635631

636632
#[test] #[should_panic]
633+
#[allow(deprecated)]
637634
fn test_weighted_clone_change_item() {
638635
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
639636
let mut clone = initial.clone();
@@ -643,15 +640,18 @@ mod tests {
643640
}
644641

645642
#[test] #[should_panic]
643+
#[allow(deprecated)]
646644
fn test_weighted_choice_no_items() {
647645
WeightedChoice::<isize>::new(&mut []);
648646
}
649647
#[test] #[should_panic]
648+
#[allow(deprecated)]
650649
fn test_weighted_choice_zero_weight() {
651650
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
652651
Weighted { weight: 0, item: 1}]);
653652
}
654653
#[test] #[should_panic]
654+
#[allow(deprecated)]
655655
fn test_weighted_choice_weight_overflows() {
656656
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
657657
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },

src/distributions/weighted.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
//! A distribution using weighted sampling to pick an discretely selected item.
12+
//!
13+
//! When a `WeightedIndex` is sampled from, it returns the index
14+
//! of a random element from the iterator used when the `WeightedIndex` was
15+
//! created. The chance of a given element being picked is proportional to the
16+
//! value of the element. The weights can use any type `X` for which an
17+
//! implementaiton of [`Uniform<X>`] exists.
18+
//!
19+
//! # Example
20+
//!
21+
//! ```
22+
//! use rand::prelude::*;
23+
//! use rand::distributions::WeightedIndex;
24+
//!
25+
//! let choices = ['a', 'b', 'c'];
26+
//! let weights = [2, 1, 1];
27+
//! let dist = WeightedIndex::new(&weights).unwrap();
28+
//! let mut rng = thread_rng();
29+
//! for _ in 0..100 {
30+
//! // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
31+
//! println!("{}", choices[dist.sample(&mut rng)]);
32+
//! }
33+
//!
34+
//! let items = [('a', 0), ('b', 3), ('c', 7)];
35+
//! let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
36+
//! for _ in 0..100 {
37+
//! // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
38+
//! println!("{}", items[dist2.sample(&mut rng)].0);
39+
//! }
40+
//! ```
41+
42+
use Rng;
43+
use distributions::Distribution;
44+
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
45+
use ::core::cmp::PartialOrd;
46+
use ::{Error, ErrorKind};
47+
48+
#[cfg(feature = "alloc")]
49+
#[derive(Debug, Clone)]
50+
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
51+
cumulative_weights: Vec<X>,
52+
weight_distribution: X::Sampler,
53+
}
54+
55+
impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
56+
/// Creates a new a [`WeightedIndex`] [`Distribution`] using the values
57+
/// in `weights`. The weights can use any type `X` for which an
58+
/// implementaiton of [`Uniform<X>`] exists.
59+
///
60+
/// Returns an error if the iterator is empty, or its total value is 0.
61+
///
62+
/// # Panics
63+
///
64+
/// If a value in the iterator is `< 0`.
65+
///
66+
/// [`WeightedIndex`]: struct.WeightedIndex.html
67+
/// [`Distribution`]: trait.Distribution.html
68+
/// [`Uniform<X>`]: struct.Uniform.html
69+
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
70+
where I: IntoIterator,
71+
I::Item: SampleBorrow<X>,
72+
X: for<'a> ::core::ops::AddAssign<&'a X> +
73+
Clone +
74+
Default {
75+
let mut iter = weights.into_iter();
76+
let mut total_weight: X = iter.next()
77+
.ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))?
78+
.borrow()
79+
.clone();
80+
81+
let zero = <X as Default>::default();
82+
let weights = iter.map(|w| {
83+
assert!(*w.borrow() >= zero, "Negative weight in WeightedIndex::new");
84+
let prev_weight = total_weight.clone();
85+
total_weight += w.borrow();
86+
prev_weight
87+
}).collect::<Vec<X>>();
88+
89+
if total_weight == zero {
90+
return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new"));
91+
}
92+
let distr = X::Sampler::new(zero, total_weight);
93+
94+
Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
95+
}
96+
}
97+
98+
impl<X> Distribution<usize> for WeightedIndex<X> where
99+
X: SampleUniform + PartialOrd {
100+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
101+
let chosen_weight = self.weight_distribution.sample(rng);
102+
// Invariants: indexes in range [start, end] (inclusive) are candidate indexes
103+
// cumulative_weights[start-1] <= chosen_weight
104+
// chosen_weight < cumulative_weights[end]
105+
// The returned index is the first one whose value is >= chosen_weight
106+
let mut start = 0usize;
107+
let mut end = self.cumulative_weights.len();
108+
while start < end {
109+
let mid = (start + end) / 2;
110+
if chosen_weight >= * unsafe { self.cumulative_weights.get_unchecked(mid) } {
111+
start = mid + 1;
112+
} else {
113+
end = mid;
114+
}
115+
}
116+
debug_assert_eq!(start, end);
117+
start
118+
}
119+
}
120+
121+
#[cfg(test)]
122+
mod test {
123+
use super::*;
124+
#[cfg(feature="std")]
125+
use core::panic::catch_unwind;
126+
127+
#[test]
128+
fn test_weightedindex() {
129+
let mut r = ::test::rng(700);
130+
const N_REPS: u32 = 5000;
131+
let weights = vec![1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
132+
let total_weight = weights.iter().sum::<u32>() as f32;
133+
134+
let verify = |result: [i32; 14]| {
135+
for (i, count) in result.iter().enumerate() {
136+
let exp = (weights[i] * N_REPS) as f32 / total_weight;
137+
let mut err = (*count as f32 - exp).abs();
138+
if err != 0.0 {
139+
err /= exp;
140+
}
141+
if err > 0.25 {
142+
println!("err: {}, i: {}\ncounts: {:?}", err, i, &result);
143+
}
144+
assert!(err <= 0.25);
145+
}
146+
};
147+
148+
// WeightedIndex from array
149+
let mut chosen = [0i32; 14];
150+
let distr = WeightedIndex::new(weights.clone()).unwrap();
151+
for _ in 0..N_REPS {
152+
chosen[distr.sample(&mut r)] += 1;
153+
}
154+
verify(chosen);
155+
156+
// WeightedIndex from slice
157+
chosen = [0i32; 14];
158+
let distr = WeightedIndex::new(&weights[..]).unwrap();
159+
for _ in 0..N_REPS {
160+
chosen[distr.sample(&mut r)] += 1;
161+
}
162+
verify(chosen);
163+
164+
// WeightedIndex from iterator
165+
chosen = [0i32; 14];
166+
let distr = WeightedIndex::new(weights.iter()).unwrap();
167+
for _ in 0..N_REPS {
168+
chosen[distr.sample(&mut r)] += 1;
169+
}
170+
verify(chosen);
171+
}
172+
173+
#[test]
174+
#[cfg(all(feature="std",
175+
not(target_arch = "wasm32"),
176+
not(target_arch = "asmjs")))]
177+
fn test_weighted_assertions() {
178+
assert!(catch_unwind(|| WeightedIndex::new(&[1, 2, 3])).is_ok());
179+
assert!(catch_unwind(|| WeightedIndex::new(&[10, -1, 10])).is_err());
180+
assert!(catch_unwind(|| WeightedIndex::new(&[1, -1])).is_err());
181+
}
182+
}

src/lib.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,6 @@
134134
//!
135135
//! For more slice/sequence related functionality, look in the [`seq` module].
136136
//!
137-
//! There is also [`distributions::WeightedChoice`], which can be used to pick
138-
//! elements at random with some probability. But it does not work well at the
139-
//! moment and is going through a redesign.
140-
//!
141137
//!
142138
//! # Error handling
143139
//!
@@ -187,7 +183,6 @@
187183
//!
188184
//!
189185
//! [`distributions` module]: distributions/index.html
190-
//! [`distributions::WeightedChoice`]: distributions/struct.WeightedChoice.html
191186
//! [`EntropyRng`]: rngs/struct.EntropyRng.html
192187
//! [`Error`]: struct.Error.html
193188
//! [`gen_range`]: trait.Rng.html#method.gen_range

0 commit comments

Comments
 (0)