Add a `Bernoulli` distribution and reimplement `Rng::gen_bool` on top of it.

Review note: the blob hashes on the `index` lines predate the review fixes
made to added lines (panic message, doc link target, comment typo).

diff --git a/benches/misc.rs b/benches/misc.rs
index 4eb910c9d2d..1d17cc50193 100644
--- a/benches/misc.rs
+++ b/benches/misc.rs
@@ -11,14 +11,15 @@ use rand::prelude::*;
 use rand::seq::*;
 
 #[bench]
-fn misc_gen_bool(b: &mut Bencher) {
+fn misc_gen_bool_const(b: &mut Bencher) {
     let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
     b.iter(|| {
+        // Can be evaluated at compile time.
         let mut accum = true;
         for _ in 0..::RAND_BENCH_N {
             accum ^= rng.gen_bool(0.18);
         }
-        black_box(accum);
+        accum
     })
 }
 
@@ -27,12 +28,37 @@ fn misc_gen_bool_var(b: &mut Bencher) {
     let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
     b.iter(|| {
         let mut p = 0.18;
+        black_box(&mut p); // Avoid constant folding.
+        for _ in 0..::RAND_BENCH_N {
+            black_box(rng.gen_bool(p));
+        }
+    })
+}
+
+#[bench]
+fn misc_bernoulli_const(b: &mut Bencher) {
+    let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
+    let d = rand::distributions::Bernoulli::new(0.18);
+    b.iter(|| {
+        // Can be evaluated at compile time.
         let mut accum = true;
         for _ in 0..::RAND_BENCH_N {
-            accum ^= rng.gen_bool(p);
-            p += 0.0001;
+            accum ^= rng.sample(d);
+        }
+        accum
+    })
+}
+
+#[bench]
+fn misc_bernoulli_var(b: &mut Bencher) {
+    let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
+    b.iter(|| {
+        let mut p = 0.18;
+        black_box(&mut p); // Avoid constant folding.
+        let d = rand::distributions::Bernoulli::new(p);
+        for _ in 0..::RAND_BENCH_N {
+            black_box(rng.sample(d));
         }
-        black_box(accum);
     })
 }
 
diff --git a/src/distributions/bernoulli.rs b/src/distributions/bernoulli.rs
new file mode 100644
index 00000000000..2361fac0c21
--- /dev/null
+++ b/src/distributions/bernoulli.rs
@@ -0,0 +1,120 @@
+// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// https://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+//! The Bernoulli distribution.
+
+use Rng;
+use distributions::Distribution;
+
+/// The Bernoulli distribution.
+///
+/// This is a special case of the Binomial distribution where `n = 1`.
+///
+/// # Example
+///
+/// ```rust
+/// use rand::distributions::{Bernoulli, Distribution};
+///
+/// let d = Bernoulli::new(0.3);
+/// let v = d.sample(&mut rand::thread_rng());
+/// println!("{} is from a Bernoulli distribution", v);
+/// ```
+///
+/// # Precision
+///
+/// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`),
+/// so only probabilities that are multiples of 2<sup>-64</sup> can be
+/// represented.
+#[derive(Clone, Copy, Debug)]
+pub struct Bernoulli {
+    /// Probability of success, relative to the maximal integer.
+    p_int: u64,
+}
+
+impl Bernoulli {
+    /// Construct a new `Bernoulli` with the given probability of success `p`.
+    ///
+    /// # Panics
+    ///
+    /// If `p < 0` or `p > 1`.
+    ///
+    /// # Precision
+    ///
+    /// For `p = 1.0`, the resulting distribution will always generate true.
+    /// For `p = 0.0`, the resulting distribution will always generate false.
+    ///
+    /// This method is accurate for any input `p` in the range `[0, 1]` which is
+    /// a multiple of 2<sup>-64</sup>. (Note that not all multiples of
+    /// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
+    #[inline]
+    pub fn new(p: f64) -> Bernoulli {
+        assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 1");
+        // Technically, this should be 2^64 or `u64::MAX + 1` because we compare
+        // using `<` when sampling. However, `u64::MAX` rounds to an `f64`
+        // larger than `u64::MAX` anyway.
+        const MAX_P_INT: f64 = ::core::u64::MAX as f64;
+        let p_int = if p < 1.0 {
+            (p * MAX_P_INT) as u64
+        } else {
+            // Avoid overflow: `MAX_P_INT` cannot be represented as u64.
+            ::core::u64::MAX
+        };
+        Bernoulli { p_int }
+    }
+}
+
+impl Distribution<bool> for Bernoulli {
+    #[inline]
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
+        // Make sure to always return true for p = 1.0.
+        if self.p_int == ::core::u64::MAX {
+            return true;
+        }
+        let r: u64 = rng.gen();
+        r < self.p_int
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use Rng;
+    use distributions::Distribution;
+    use super::Bernoulli;
+
+    #[test]
+    fn test_trivial() {
+        let mut r = ::test::rng(1);
+        let always_false = Bernoulli::new(0.0);
+        let always_true = Bernoulli::new(1.0);
+        for _ in 0..5 {
+            assert_eq!(r.sample::<bool, _>(&always_false), false);
+            assert_eq!(r.sample::<bool, _>(&always_true), true);
+            assert_eq!(Distribution::<bool>::sample(&always_false, &mut r), false);
+            assert_eq!(Distribution::<bool>::sample(&always_true, &mut r), true);
+        }
+    }
+
+    #[test]
+    fn test_average() {
+        const P: f64 = 0.3;
+        let d = Bernoulli::new(P);
+        const N: u32 = 10_000_000;
+
+        let mut sum: u32 = 0;
+        let mut rng = ::test::rng(2);
+        for _ in 0..N {
+            if d.sample(&mut rng) {
+                sum += 1;
+            }
+        }
+        let avg = (sum as f64) / (N as f64);
+
+        assert!((avg - P).abs() < 1e-3);
+    }
+}
diff --git a/src/distributions/binomial.rs b/src/distributions/binomial.rs
index 8a03e1d5841..eb716f444b7 100644
--- a/src/distributions/binomial.rs
+++ b/src/distributions/binomial.rs
@@ -31,13 +31,17 @@ use std::f64::consts::PI;
 /// ```
 #[derive(Clone, Copy, Debug)]
 pub struct Binomial {
-    n: u64, // number of trials
-    p: f64, // probability of success
+    /// Number of trials.
+    n: u64,
+    /// Probability of success.
+    p: f64,
 }
 
 impl Binomial {
-    /// Construct a new `Binomial` with the given shape parameters
-    /// `n`, `p`. Panics if `p <= 0` or `p >= 1`.
+    /// Construct a new `Binomial` with the given shape parameters `n` (number
+    /// of trials) and `p` (probability of success).
+    ///
+    /// Panics if `p <= 0` or `p >= 1`.
     pub fn new(n: u64, p: f64) -> Binomial {
         assert!(p > 0.0, "Binomial::new called with p <= 0");
         assert!(p < 1.0, "Binomial::new called with p >= 1");
diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs
index aaf330420c8..a7761e388bf 100644
--- a/src/distributions/mod.rs
+++ b/src/distributions/mod.rs
@@ -178,6 +178,7 @@ pub use self::uniform::Uniform as Range;
 #[doc(inline)] pub use self::poisson::Poisson;
 #[cfg(feature = "std")]
 #[doc(inline)] pub use self::binomial::Binomial;
+#[doc(inline)] pub use self::bernoulli::Bernoulli;
 
 pub mod uniform;
 #[cfg(feature="std")]
@@ -190,6 +191,7 @@ pub mod uniform;
 #[doc(hidden)] pub mod poisson;
 #[cfg(feature = "std")]
 #[doc(hidden)] pub mod binomial;
+#[doc(hidden)] pub mod bernoulli;
 
 mod float;
 mod integer;
diff --git a/src/lib.rs b/src/lib.rs
index 01cc15dd0bf..0599fc1ea02 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -318,7 +318,6 @@ pub trait Rng: RngCore {
     /// println!("{}", x);
     /// println!("{:?}", rng.gen::<(f64, bool)>());
     /// ```
-    #[inline(always)]
     fn gen<T>(&mut self) -> T where Standard: Distribution<T> {
         Standard.sample(self)
     }
@@ -474,6 +473,8 @@ pub trait Rng: RngCore {
 
     /// Return a bool with a probability `p` of being true.
     ///
+    /// This is a wrapper around [`distributions::Bernoulli`].
+    ///
     /// # Example
     ///
     /// ```rust
@@ -483,20 +484,15 @@ pub trait Rng: RngCore {
     /// println!("{}", rng.gen_bool(1.0 / 3.0));
     /// ```
     ///
-    /// # Accuracy note
+    /// # Panics
+    ///
+    /// If `p` < 0 or `p` > 1.
     ///
-    /// `gen_bool` uses 32 bits of the RNG, so if you use it to generate close
-    /// to or more than `2^32` results, a tiny bias may become noticable.
-    /// A notable consequence of the method used here is that the worst case is
-    /// `rng.gen_bool(0.0)`: it has a chance of 1 in `2^32` of being true, while
-    /// it should always be false. But using `gen_bool` to consume *many* values
-    /// from an RNG just to consistently generate `false` does not match with
-    /// the intent of this method.
+    /// [`distributions::Bernoulli`]: distributions/struct.Bernoulli.html
+    #[inline]
     fn gen_bool(&mut self, p: f64) -> bool {
-        assert!(p >= 0.0 && p <= 1.0);
-        // If `p` is constant, this will be evaluated at compile-time.
-        let p_int = (p * f64::from(core::u32::MAX)) as u32;
-        self.gen::<u32>() <= p_int
+        let d = distributions::Bernoulli::new(p);
+        self.sample(d)
     }
 
     /// Return a random element from `values`.
@@ -897,7 +893,6 @@ pub fn weak_rng() -> XorShiftRng {
 /// [`thread_rng`]: fn.thread_rng.html
 /// [`Standard`]: distributions/struct.Standard.html
 #[cfg(feature="std")]
-#[inline]
 pub fn random<T>() -> T where Standard: Distribution<T> {
     thread_rng().gen()
 }
@@ -918,7 +913,6 @@ pub fn random<T>() -> T where Standard: Distribution<T> {
 /// println!("{:?}", sample);
 /// ```
 #[cfg(feature="std")]
-#[inline(always)]
 #[deprecated(since="0.4.0", note="renamed to seq::sample_iter")]
 pub fn sample<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Vec<T>
     where I: IntoIterator<Item=T>,
diff --git a/tests/bool.rs b/tests/bool.rs
new file mode 100644
index 00000000000..c4208a009a3
--- /dev/null
+++ b/tests/bool.rs
@@ -0,0 +1,23 @@
+#![no_std]
+
+extern crate rand;
+
+use rand::SeedableRng;
+use rand::rngs::SmallRng;
+use rand::distributions::{Distribution, Bernoulli};
+
+/// This test should make sure that we don't accidentally have undefined
+/// behavior for large probabilities due to
+/// https://github.com/rust-lang/rust/issues/10184.
+/// Expressions like `1.0*(u64::MAX as f64) as u64` have to be avoided.
+#[test]
+fn large_probability() {
+    let p = 1. - ::core::f64::EPSILON / 2.;
+    assert!(p < 1.);
+    let d = Bernoulli::new(p);
+    let mut rng = SmallRng::from_seed(
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
+    for _ in 0..10 {
+        assert!(d.sample(&mut rng), "extremely unlikely to fail by accident");
+    }
+}