diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index d212f460b3d..586b6ac1e3d 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -20,9 +20,17 @@ travis-ci = { repository = "rust-random/rand" } appveyor = { repository = "rust-random/rand" } [dependencies] -rand = { path = "..", version = "0.7" } +rand = { path = "..", version = "0.7", default-features = false } +num-traits = { version = "0.2", default-features = false, features = ["libm"] } + +[features] +default = ["std"] +std = ["alloc"] +alloc = [] [dev-dependencies] rand_pcg = { version = "0.2", path = "../rand_pcg" } +# For inline examples +rand = { path = "..", version = "0.7", default-features = false, features = ["std_rng", "std"] } # Histogram implementation for testing uniformity average = "0.10.3" diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index ea58ddf7ebe..eec3c5b34b3 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -11,7 +11,7 @@ use crate::{Distribution, Uniform}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// The binomial distribution `Binomial(n, p)`. /// @@ -53,7 +53,8 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} impl Binomial { /// Construct a new `Binomial` with the given shape parameters `n` (number @@ -72,7 +73,7 @@ impl Binomial { /// Convert a `f64` to an `i64`, panicing on overflow. // In the future (Rust 1.34), this might be replaced with `TryFrom`. fn f64_to_i64(x: f64) -> i64 { - assert!(x < (::std::i64::MAX as f64)); + assert!(x < (core::i64::MAX as f64)); x as i64 } @@ -106,7 +107,7 @@ impl Distribution for Binomial { // Ranlib uses 30, and GSL uses 14. const BINV_THRESHOLD: f64 = 10.; - if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (::std::i32::MAX as u64) { + if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) { // Use the BINV algorithm. let s = p / q; let a = ((self.n + 1) as f64) * s; @@ -338,22 +339,4 @@ mod test { fn test_binomial_invalid_lambda_neg() { Binomial::new(20, -10.0).unwrap(); } - - #[test] - fn value_stability() { - fn test_samples(n: u64, p: f64, expected: &[u64]) { - let distr = Binomial::new(n, p).unwrap(); - let mut rng = crate::test::rng(353); - let mut buf = [0; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - // We have multiple code paths: np < 10, p > 0.5 - test_samples(2, 0.7, &[1, 1, 2, 1]); - test_samples(20, 0.3, &[7, 7, 5, 7]); - test_samples(2000, 0.6, &[1194, 1208, 1192, 1210]); - } } diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs index 3e1b1b5ec72..ffe86d00a9b 100644 --- a/rand_distr/src/cauchy.rs +++ b/rand_distr/src/cauchy.rs @@ -9,10 +9,10 @@ //! The Cauchy distribution. -use crate::utils::Float; +use num_traits::{Float, FloatConst}; use crate::{Distribution, Standard}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// The Cauchy distribution `Cauchy(median, scale)`. /// @@ -32,9 +32,11 @@ use std::{error, fmt}; /// println!("{} is from a Cauchy(2, 5) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Cauchy { - median: N, - scale: N, +pub struct Cauchy +where F: Float + FloatConst, Standard: Distribution +{ + median: F, + scale: F, } /// Error type returned from `Cauchy::new`. 
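With `std::error::Error` now gated on the `std` feature, error handling for `Binomial::new` in a `no_std` build leans on `Debug`/`Display` only. A minimal, illustrative sketch (not part of the patch), using the `BinomialError` re-export shown later in lib.rs:

    use rand_distr::{Binomial, BinomialError};

    // Illustrative helper, not from the patch: `Binomial::new` validates `p`,
    // and without the `std` feature the error still implements `Debug` and
    // `Display`, just not `std::error::Error`.
    fn binomial_checked(n: u64, p: f64) -> Result<Binomial, BinomialError> {
        Binomial::new(n, p)
    }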
@@ -52,30 +54,31 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} -impl Cauchy -where Standard: Distribution +impl Cauchy +where F: Float + FloatConst, Standard: Distribution { /// Construct a new `Cauchy` with the given shape parameters /// `median` the peak location and `scale` the scale factor. - pub fn new(median: N, scale: N) -> Result, Error> { - if !(scale > N::from(0.0)) { + pub fn new(median: F, scale: F) -> Result, Error> { + if !(scale > F::zero()) { return Err(Error::ScaleTooSmall); } Ok(Cauchy { median, scale }) } } -impl Distribution for Cauchy -where Standard: Distribution +impl Distribution for Cauchy +where F: Float + FloatConst, Standard: Distribution { - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { // sample from [0, 1) let x = Standard.sample(rng); // get standard cauchy random number // note that π/2 is not exactly representable, even if x=0.5 the result is finite - let comp_dev = (N::pi() * x).tan(); + let comp_dev = (F::PI() * x).tan(); // shift and scale according to parameters self.median + self.scale * comp_dev } @@ -108,10 +111,12 @@ mod test { sum += numbers[i]; } let median = median(&mut numbers); - println!("Cauchy median: {}", median); + #[cfg(feature = "std")] + std::println!("Cauchy median: {}", median); assert!((median - 10.0).abs() < 0.4); // not 100% certain, but probable enough let mean = sum / 1000.0; - println!("Cauchy mean: {}", mean); + #[cfg(feature = "std")] + std::println!("Cauchy mean: {}", mean); // for a Cauchy distribution the mean should not converge assert!((mean - 10.0).abs() > 0.4); // not 100% certain, but probable enough } @@ -130,8 +135,8 @@ mod test { #[test] fn value_stability() { - fn gen_samples(m: N, s: N, buf: &mut [N]) - where Standard: Distribution { + fn gen_samples(m: F, s: F, buf: &mut [F]) + where Standard: Distribution { let distr = Cauchy::new(m, s).unwrap(); let mut rng = crate::test::rng(353); for x in buf { diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 84c7c933915..1a470f72922 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -8,11 +8,12 @@ // except according to those terms. //! The dirichlet distribution. - -use crate::utils::Float; +#![cfg(feature = "alloc")] +use num_traits::Float; use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; +use alloc::{boxed::Box, vec, vec::Vec}; /// The Dirichlet distribution `Dirichlet(alpha)`. /// @@ -26,14 +27,20 @@ use std::{error, fmt}; /// use rand::prelude::*; /// use rand_distr::Dirichlet; /// -/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap(); +/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); /// let samples = dirichlet.sample(&mut rand::thread_rng()); /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); /// ``` #[derive(Clone, Debug)] -pub struct Dirichlet { +pub struct Dirichlet +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ /// Concentration parameters (alpha) - alpha: Vec, + alpha: Box<[F]>, } /// Error type returned from `Dirchlet::new`. 
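Because `Cauchy` is now generic over `F: Float + FloatConst`, `f32` parameters work wherever `Standard: Distribution<F>` holds. An illustrative sketch, not part of the patch:

    use rand::prelude::*;
    use rand_distr::Cauchy;

    fn main() {
        // Same constructor as before, but the float type is inferred from the
        // arguments instead of being fixed by the crate's old `Float` helper trait.
        let cau = Cauchy::new(2.0f32, 5.0f32).unwrap();
        let v: f32 = cau.sample(&mut thread_rng());
        println!("{} is from a Cauchy(2, 5) distribution", v);
    }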
@@ -58,68 +65,70 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} -impl Dirichlet +impl Dirichlet where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. /// /// Requires `alpha.len() >= 2`. #[inline] - pub fn new>>(alpha: V) -> Result, Error> { - let a = alpha.into(); - if a.len() < 2 { + pub fn new(alpha: &[F]) -> Result, Error> { + if alpha.len() < 2 { return Err(Error::AlphaTooShort); } - for &ai in &a { - if !(ai > N::from(0.0)) { + for &ai in alpha.iter() { + if !(ai > F::zero()) { return Err(Error::AlphaTooSmall); } } - Ok(Dirichlet { alpha: a }) + Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() }) } /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. /// /// Requires `size >= 2`. #[inline] - pub fn new_with_size(alpha: N, size: usize) -> Result, Error> { - if !(alpha > N::from(0.0)) { + pub fn new_with_size(alpha: F, size: usize) -> Result, Error> { + if !(alpha > F::zero()) { return Err(Error::AlphaTooSmall); } if size < 2 { return Err(Error::SizeTooSmall); } Ok(Dirichlet { - alpha: vec![alpha; size], + alpha: vec![alpha; size].into_boxed_slice(), }) } } -impl Distribution> for Dirichlet +impl Distribution> for Dirichlet where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { - fn sample(&self, rng: &mut R) -> Vec { + fn sample(&self, rng: &mut R) -> Vec { let n = self.alpha.len(); - let mut samples = vec![N::from(0.0); n]; - let mut sum = N::from(0.0); + let mut samples = vec![F::zero(); n]; + let mut sum = F::zero(); for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) { - let g = Gamma::new(a, N::from(1.0)).unwrap(); + let g = Gamma::new(a, F::one()).unwrap(); *s = g.sample(rng); - sum += *s; + sum = sum + (*s); } - let invacc = N::from(1.0) / sum; + let invacc = F::one() / sum; for s in samples.iter_mut() { - *s *= invacc; + *s = (*s)*invacc; } samples } @@ -131,7 +140,7 @@ mod test { #[test] fn test_dirichlet() { - let d = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap(); + let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); let mut rng = crate::test::rng(221); let samples = d.sample(&mut rng); let _: Vec = samples @@ -170,20 +179,4 @@ mod test { fn test_dirichlet_invalid_alpha() { Dirichlet::new_with_size(0.0f64, 2).unwrap(); } - - #[test] - fn value_stability() { - let mut rng = crate::test::rng(223); - assert_eq!( - rng.sample(Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap()), - vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146] - ); - assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![ - 0.17684200044809556, - 0.29915953935953055, - 0.1832858056608014, - 0.1425623503573967, - 0.19815030417417595 - ]); - } } diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index 3fe8e22fd09..841096ec6ed 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -9,10 +9,11 @@ //! The exponential distribution. 
-use crate::utils::{ziggurat, Float}; +use crate::utils::ziggurat; +use num_traits::Float; use crate::{ziggurat_tables, Distribution}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// Samples floating-point numbers according to the exponential distribution, /// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or @@ -90,9 +91,11 @@ impl Distribution for Exp1 { /// println!("{} is from a Exp(2) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Exp { +pub struct Exp +where F: Float, Exp1: Distribution +{ /// `lambda` stored as `1/lambda`, since this is what we scale by. - lambda_inverse: N, + lambda_inverse: F, } /// Error type returned from `Exp::new`. @@ -110,10 +113,11 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} -impl Exp -where Exp1: Distribution +impl Exp +where F: Float, Exp1: Distribution { /// Construct a new `Exp` with the given shape parameter /// `lambda`. @@ -125,20 +129,20 @@ where Exp1: Distribution /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types /// yield infinity, since `1 / 0 = infinity`. #[inline] - pub fn new(lambda: N) -> Result, Error> { - if !(lambda >= N::from(0.0)) { + pub fn new(lambda: F) -> Result, Error> { + if !(lambda >= F::zero()) { return Err(Error::LambdaTooSmall); } Ok(Exp { - lambda_inverse: N::from(1.0) / lambda, + lambda_inverse: F::one() / lambda, }) } } -impl Distribution for Exp -where Exp1: Distribution +impl Distribution for Exp +where F: Float, Exp1: Distribution { - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { rng.sample(Exp1) * self.lambda_inverse } } @@ -165,41 +169,10 @@ mod test { fn test_exp_invalid_lambda_neg() { Exp::new(-10.0).unwrap(); } + #[test] #[should_panic] fn test_exp_invalid_lambda_nan() { Exp::new(std::f64::NAN).unwrap(); } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, zero: N, expected: &[N], - ) { - let mut rng = crate::test::rng(223); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - test_samples(Exp1, 0f32, &[1.079617, 1.8325565, 0.04601166, 0.34471703]); - test_samples(Exp1, 0f64, &[ - 1.0796170642388276, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ]); - - test_samples(Exp::new(2.0).unwrap(), 0f32, &[ - 0.5398085, 0.91627824, 0.02300583, 0.17235851, - ]); - test_samples(Exp::new(1.0).unwrap(), 0f64, &[ - 1.0796170642388276, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ]); - } } diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index ba8e4e0eb31..34cb45dfb36 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -13,10 +13,10 @@ use self::ChiSquaredRepr::*; use self::GammaRepr::*; use crate::normal::StandardNormal; -use crate::utils::Float; +use num_traits::Float; use crate::{Distribution, Exp, Exp1, Open01}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// The Gamma distribution `Gamma(shape, scale)` distribution. /// @@ -49,8 +49,14 @@ use std::{error, fmt}; /// (September 2000), 363-372. /// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414) #[derive(Clone, Copy, Debug)] -pub struct Gamma { - repr: GammaRepr, +pub struct Gamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + repr: GammaRepr, } /// Error type returned from `Gamma::new`. 
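`Exp` follows the same pattern: the rate is stored as `1/lambda` in whichever float type the caller picks, relying on `Exp1: Distribution<F>`. Illustrative sketch, not part of the patch:

    use rand::prelude::*;
    use rand_distr::Exp;

    fn main() {
        // `Exp1` is implemented for both f32 and f64, so either type works here.
        let exp = Exp::new(2.0f32).unwrap();
        let v: f32 = exp.sample(&mut thread_rng());
        println!("{} is from an Exp(2) distribution", v);
    }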
@@ -74,13 +80,20 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} #[derive(Clone, Copy, Debug)] -enum GammaRepr { - Large(GammaLargeShape), - One(Exp), - Small(GammaSmallShape), +enum GammaRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + Large(GammaLargeShape), + One(Exp), + Small(GammaSmallShape), } // These two helpers could be made public, but saving the @@ -98,9 +111,14 @@ enum GammaRepr { /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. #[derive(Clone, Copy, Debug)] -struct GammaSmallShape { - inv_shape: N, - large_shape: GammaLargeShape, +struct GammaSmallShape +where + F: Float, + StandardNormal: Distribution, + Open01: Distribution, +{ + inv_shape: F, + large_shape: GammaLargeShape, } /// Gamma distribution where the shape parameter is larger than 1. @@ -108,32 +126,38 @@ struct GammaSmallShape { /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. #[derive(Clone, Copy, Debug)] -struct GammaLargeShape { - scale: N, - c: N, - d: N, +struct GammaLargeShape +where + F: Float, + StandardNormal: Distribution, + Open01: Distribution, +{ + scale: F, + c: F, + d: F, } -impl Gamma +impl Gamma where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { /// Construct an object representing the `Gamma(shape, scale)` /// distribution. #[inline] - pub fn new(shape: N, scale: N) -> Result, Error> { - if !(shape > N::from(0.0)) { + pub fn new(shape: F, scale: F) -> Result, Error> { + if !(shape > F::zero()) { return Err(Error::ShapeTooSmall); } - if !(scale > N::from(0.0)) { + if !(scale > F::zero()) { return Err(Error::ScaleTooSmall); } - let repr = if shape == N::from(1.0) { - One(Exp::new(N::from(1.0) / scale).map_err(|_| Error::ScaleTooLarge)?) - } else if shape < N::from(1.0) { + let repr = if shape == F::one() { + One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?) + } else if shape < F::one() { Small(GammaSmallShape::new_raw(shape, scale)) } else { Large(GammaLargeShape::new_raw(shape, scale)) @@ -142,41 +166,44 @@ where } } -impl GammaSmallShape +impl GammaSmallShape where - StandardNormal: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Open01: Distribution, { - fn new_raw(shape: N, scale: N) -> GammaSmallShape { + fn new_raw(shape: F, scale: F) -> GammaSmallShape { GammaSmallShape { - inv_shape: N::from(1.0) / shape, - large_shape: GammaLargeShape::new_raw(shape + N::from(1.0), scale), + inv_shape: F::one() / shape, + large_shape: GammaLargeShape::new_raw(shape + F::one(), scale), } } } -impl GammaLargeShape +impl GammaLargeShape where - StandardNormal: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Open01: Distribution, { - fn new_raw(shape: N, scale: N) -> GammaLargeShape { - let d = shape - N::from(1. / 3.); + fn new_raw(shape: F, scale: F) -> GammaLargeShape { + let d = shape - F::from(1. / 3.).unwrap(); GammaLargeShape { scale, - c: N::from(1.0) / (N::from(9.) 
* d).sqrt(), + c: F::one() / (F::from(9.).unwrap() * d).sqrt(), d, } } } -impl Distribution for Gamma +impl Distribution for Gamma where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { match self.repr { Small(ref g) => g.sample(rng), One(ref g) => g.sample(rng), @@ -184,38 +211,40 @@ where } } } -impl Distribution for GammaSmallShape +impl Distribution for GammaSmallShape where - StandardNormal: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Open01: Distribution, { - fn sample(&self, rng: &mut R) -> N { - let u: N = rng.sample(Open01); + fn sample(&self, rng: &mut R) -> F { + let u: F = rng.sample(Open01); self.large_shape.sample(rng) * u.powf(self.inv_shape) } } -impl Distribution for GammaLargeShape +impl Distribution for GammaLargeShape where - StandardNormal: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Open01: Distribution, { - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { // Marsaglia & Tsang method, 2000 loop { - let x: N = rng.sample(StandardNormal); - let v_cbrt = N::from(1.0) + self.c * x; - if v_cbrt <= N::from(0.0) { + let x: F = rng.sample(StandardNormal); + let v_cbrt = F::one() + self.c * x; + if v_cbrt <= F::zero() { // a^3 <= 0 iff a <= 0 continue; } let v = v_cbrt * v_cbrt * v_cbrt; - let u: N = rng.sample(Open01); + let u: F = rng.sample(Open01); let x_sqr = x * x; - if u < N::from(1.0) - N::from(0.0331) * x_sqr * x_sqr - || u.ln() < N::from(0.5) * x_sqr + self.d * (N::from(1.0) - v + v.ln()) + if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr + || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln()) { return self.d * v * self.scale; } @@ -241,8 +270,14 @@ where /// println!("{} is from a χ²(11) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct ChiSquared { - repr: ChiSquaredRepr, +pub struct ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + repr: ChiSquaredRepr, } /// Error type returned from `ChiSquared::new` and `StudentT::new`. @@ -262,48 +297,57 @@ impl fmt::Display for ChiSquaredError { } } -impl error::Error for ChiSquaredError {} +#[cfg(feature = "std")] +impl std::error::Error for ChiSquaredError {} #[derive(Clone, Copy, Debug)] -enum ChiSquaredRepr { +enum ChiSquaredRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, // e.g. when alpha = 1/2 as it would be for this case, so special- // casing and using the definition of N(0,1)^2 is faster. DoFExactlyOne, - DoFAnythingElse(Gamma), + DoFAnythingElse(Gamma), } -impl ChiSquared +impl ChiSquared where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { /// Create a new chi-squared distribution with degrees-of-freedom /// `k`. 
- pub fn new(k: N) -> Result, ChiSquaredError> { - let repr = if k == N::from(1.0) { + pub fn new(k: F) -> Result, ChiSquaredError> { + let repr = if k == F::one() { DoFExactlyOne } else { - if !(N::from(0.5) * k > N::from(0.0)) { + if !(F::from(0.5).unwrap() * k > F::zero()) { return Err(ChiSquaredError::DoFTooSmall); } - DoFAnythingElse(Gamma::new(N::from(0.5) * k, N::from(2.0)).unwrap()) + DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) }; Ok(ChiSquared { repr }) } } -impl Distribution for ChiSquared +impl Distribution for ChiSquared where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { match self.repr { DoFExactlyOne => { // k == 1 => N(0,1)^2 - let norm: N = rng.sample(StandardNormal); + let norm: F = rng.sample(StandardNormal); norm * norm } DoFAnythingElse(ref g) => g.sample(rng), @@ -327,12 +371,18 @@ where /// println!("{} is from an F(2, 32) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct FisherF { - numer: ChiSquared, - denom: ChiSquared, +pub struct FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + numer: ChiSquared, + denom: ChiSquared, // denom_dof / numer_dof so that this can just be a straight // multiplication, rather than a division. - dof_ratio: N, + dof_ratio: F, } /// Error type returned from `FisherF::new`. @@ -353,20 +403,23 @@ impl fmt::Display for FisherFError { } } -impl error::Error for FisherFError {} +#[cfg(feature = "std")] +impl std::error::Error for FisherFError {} -impl FisherF +impl FisherF where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { /// Create a new `FisherF` distribution, with the given parameter. - pub fn new(m: N, n: N) -> Result, FisherFError> { - if !(m > N::from(0.0)) { + pub fn new(m: F, n: F) -> Result, FisherFError> { + let zero = F::zero(); + if !(m > zero) { return Err(FisherFError::MTooSmall); } - if !(n > N::from(0.0)) { + if !(n > zero) { return Err(FisherFError::NTooSmall); } @@ -377,13 +430,14 @@ where }) } } -impl Distribution for FisherF +impl Distribution for FisherF where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio } } @@ -401,34 +455,42 @@ where /// println!("{} is from a t(11) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct StudentT { - chi: ChiSquared, - dof: N, +pub struct StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + chi: ChiSquared, + dof: F, } -impl StudentT +impl StudentT where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { /// Create a new Student t distribution with `n` degrees of /// freedom. 
- pub fn new(n: N) -> Result, ChiSquaredError> { + pub fn new(n: F) -> Result, ChiSquaredError> { Ok(StudentT { chi: ChiSquared::new(n)?, dof: n, }) } } -impl Distribution for StudentT +impl Distribution for StudentT where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { - fn sample(&self, rng: &mut R) -> N { - let norm: N = rng.sample(StandardNormal); + fn sample(&self, rng: &mut R) -> F { + let norm: F = rng.sample(StandardNormal); norm * (self.dof / self.chi.sample(rng)).sqrt() } } @@ -445,9 +507,15 @@ where /// println!("{} is from a Beta(2, 5) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Beta { - gamma_a: Gamma, - gamma_b: Gamma, +pub struct Beta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + gamma_a: Gamma, + gamma_b: Gamma, } /// Error type returned from `Beta::new`. @@ -468,31 +536,34 @@ impl fmt::Display for BetaError { } } -impl error::Error for BetaError {} +#[cfg(feature = "std")] +impl std::error::Error for BetaError {} -impl Beta +impl Beta where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { /// Construct an object representing the `Beta(alpha, beta)` /// distribution. - pub fn new(alpha: N, beta: N) -> Result, BetaError> { + pub fn new(alpha: F, beta: F) -> Result, BetaError> { Ok(Beta { - gamma_a: Gamma::new(alpha, N::from(1.)).map_err(|_| BetaError::AlphaTooSmall)?, - gamma_b: Gamma::new(beta, N::from(1.)).map_err(|_| BetaError::BetaTooSmall)?, + gamma_a: Gamma::new(alpha, F::one()).map_err(|_| BetaError::AlphaTooSmall)?, + gamma_b: Gamma::new(beta, F::one()).map_err(|_| BetaError::BetaTooSmall)?, }) } } -impl Distribution for Beta +impl Distribution for Beta where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { let x = self.gamma_a.sample(rng); let y = self.gamma_b.sample(rng); x / (x + y) @@ -565,91 +636,4 @@ mod test { fn test_beta_invalid_dof() { Beta::new(0., 0.).unwrap(); } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, zero: N, expected: &[N], - ) { - let mut rng = crate::test::rng(223); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - // Gamma has 3 cases: shape == 1, shape < 1, shape > 1 - test_samples(Gamma::new(1.0, 5.0).unwrap(), 0f32, &[ - 5.398085, 9.162783, 0.2300583, 1.7235851, - ]); - test_samples(Gamma::new(0.8, 5.0).unwrap(), 0f32, &[ - 0.5051203, 0.9048302, 3.095812, 1.8566116, - ]); - test_samples(Gamma::new(1.1, 5.0).unwrap(), 0f64, &[ - 7.783878094584059, - 1.4939528171618057, - 8.638017638857592, - 3.0949337228829004, - ]); - - // ChiSquared has 2 cases: k == 1, k != 1 - test_samples(ChiSquared::new(1.0).unwrap(), 0f64, &[ - 0.4893526200348249, - 1.635249736808788, - 0.5013580219361969, - 0.1457735613733489, - ]); - test_samples(ChiSquared::new(0.1).unwrap(), 0f64, &[ - 0.014824404726978617, - 0.021602123937134326, - 0.0000003431429746851693, - 0.00000002291755769542258, - ]); - test_samples(ChiSquared::new(10.0).unwrap(), 0f32, &[ - 12.693656, 6.812016, 11.082001, 12.436167, - ]); - - // FisherF has same special cases as ChiSquared on each param 
- test_samples(FisherF::new(1.0, 13.5).unwrap(), 0f32, &[ - 0.32283646, - 0.048049655, - 0.0788893, - 1.817178, - ]); - test_samples(FisherF::new(1.0, 1.0).unwrap(), 0f32, &[ - 0.29925257, 3.4392934, 9.567652, 0.020074, - ]); - test_samples(FisherF::new(0.7, 13.5).unwrap(), 0f64, &[ - 3.3196593155045124, - 0.3409169916262829, - 0.03377989856426519, - 0.00004041672861036937, - ]); - - // StudentT has same special cases as ChiSquared - test_samples(StudentT::new(1.0).unwrap(), 0f32, &[ - 0.54703987, - -1.8545331, - 3.093162, - -0.14168274, - ]); - test_samples(StudentT::new(1.1).unwrap(), 0f64, &[ - 0.7729195887949754, - 1.2606210611616204, - -1.7553606501113175, - -2.377641221169782, - ]); - - // Beta has same special cases as Gamma on each param - test_samples(Beta::new(1.0, 0.8).unwrap(), 0f32, &[ - 0.6444564, 0.357635, 0.4110078, 0.7347192, - ]); - test_samples(Beta::new(0.7, 1.2).unwrap(), 0f64, &[ - 0.6433129944095513, - 0.5373371199711573, - 0.10313293199269491, - 0.002472280249144378, - ]); - } } diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs index 171aa473eee..dee77686008 100644 --- a/rand_distr/src/inverse_gaussian.rs +++ b/rand_distr/src/inverse_gaussian.rs @@ -1,4 +1,5 @@ -use crate::{Distribution, Float, Standard, StandardNormal}; +use crate::{Distribution, Standard, StandardNormal}; +use num_traits::Float; use rand::Rng; /// Error type returned from `InverseGaussian::new` @@ -12,22 +13,31 @@ pub enum Error { /// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) #[derive(Debug)] -pub struct InverseGaussian { - mean: N, - shape: N, +pub struct InverseGaussian +where + F: Float, + StandardNormal: Distribution, + Standard: Distribution, +{ + mean: F, + shape: F, } -impl InverseGaussian -where StandardNormal: Distribution +impl InverseGaussian +where + F: Float, + StandardNormal: Distribution, + Standard: Distribution, { /// Construct a new `InverseGaussian` distribution with the given mean and /// shape. - pub fn new(mean: N, shape: N) -> Result, Error> { - if !(mean > N::from(0.0)) { + pub fn new(mean: F, shape: F) -> Result, Error> { + let zero = F::zero(); + if !(mean > zero) { return Err(Error::MeanNegativeOrNull); } - if !(shape > N::from(0.0)) { + if !(shape > zero) { return Err(Error::ShapeNegativeOrNull); } @@ -35,24 +45,25 @@ where StandardNormal: Distribution } } -impl Distribution for InverseGaussian +impl Distribution for InverseGaussian where - StandardNormal: Distribution, - Standard: Distribution, + F: Float, + StandardNormal: Distribution, + Standard: Distribution, { - fn sample(&self, rng: &mut R) -> N + fn sample(&self, rng: &mut R) -> F where R: Rng + ?Sized { let mu = self.mean; let l = self.shape; - let v: N = rng.sample(StandardNormal); + let v: F = rng.sample(StandardNormal); let y = mu * v * v; - let mu_2l = mu / (N::from(2.) * l); + let mu_2l = mu / (F::from(2.).unwrap() * l); - let x = mu + mu_2l * (y - (N::from(4.) 
* l * y + y * y).sqrt()); + let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt()); - let u: N = rng.gen(); + let u: F = rng.gen(); if u <= mu / (mu + x) { return x; @@ -82,28 +93,4 @@ mod tests { assert!(InverseGaussian::new(1.0, -1.0).is_err()); assert!(InverseGaussian::new(1.0, 1.0).is_ok()); } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, zero: N, expected: &[N], - ) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - test_samples(InverseGaussian::new(1.0, 3.0).unwrap(), 0f32, &[ - 0.9339157, 1.108113, 0.50864697, 0.39849377, - ]); - test_samples(InverseGaussian::new(1.0, 3.0).unwrap(), 0f64, &[ - 1.0707604954722476, - 0.9628140605340697, - 0.4069687656468226, - 0.660283852985818, - ]); - } } diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index ebc14402771..de6e2274614 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -19,6 +19,7 @@ clippy::unreadable_literal )] #![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose +#![no_std] //! Generating random samples from probability distributions. //! @@ -70,6 +71,15 @@ //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution +#[cfg(all(feature = "alloc", not(feature = "std")))] +extern crate alloc; + +#[cfg(feature = "std")] +extern crate std; +// TODO: remove on MSRV bump to 1.36 +#[cfg(feature = "std")] +extern crate std as alloc; + pub use rand::distributions::{ uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, Standard, Uniform, @@ -77,6 +87,7 @@ pub use rand::distributions::{ pub use self::binomial::{Binomial, Error as BinomialError}; pub use self::cauchy::{Cauchy, Error as CauchyError}; +#[cfg(feature = "alloc")] pub use self::dirichlet::{Dirichlet, Error as DirichletError}; pub use self::exponential::{Error as ExpError, Exp, Exp1}; pub use self::gamma::{ @@ -94,10 +105,13 @@ pub use self::unit_ball::UnitBall; pub use self::unit_circle::UnitCircle; pub use self::unit_disc::UnitDisc; pub use self::unit_sphere::UnitSphere; -pub use self::utils::Float; pub use self::weibull::{Error as WeibullError, Weibull}; +#[cfg(feature = "alloc")] pub use self::weighted::{WeightedError, WeightedIndex}; +pub use num_traits; + +#[cfg(feature = "alloc")] pub mod weighted; mod binomial; diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index 39f9402ec04..8067c5b9a6b 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -9,10 +9,11 @@ //! The normal and derived distributions. -use crate::utils::{ziggurat, Float}; +use crate::utils::ziggurat; +use num_traits::Float; use crate::{ziggurat_tables, Distribution, Open01}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// Samples floating-point numbers according to the normal distribution /// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to @@ -111,9 +112,11 @@ impl Distribution for StandardNormal { /// /// [`StandardNormal`]: crate::StandardNormal #[derive(Clone, Copy, Debug)] -pub struct Normal { - mean: N, - std_dev: N, +pub struct Normal +where F: Float, StandardNormal: Distribution +{ + mean: F, + std_dev: F, } /// Error type returned from `Normal::new` and `LogNormal::new`. 
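With the crate now `#![no_std]` and `Normal`/`LogNormal` generic over `F: Float`, a caller can mix float widths freely. Illustrative sketch, not part of the patch (it assumes the usual crate-root re-exports and the default `std` feature for `thread_rng` and `println!`):

    use rand::prelude::*;
    use rand_distr::{LogNormal, Normal};

    fn main() {
        let mut rng = thread_rng();
        // Both constructors only require `StandardNormal: Distribution<F>`.
        let norm = Normal::new(2.0f64, 3.0f64).unwrap();
        let log_norm = LogNormal::new(0.0f32, 1.0f32).unwrap();
        println!("{} {}", norm.sample(&mut rng), log_norm.sample(&mut rng));
    }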
@@ -131,27 +134,28 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} -impl Normal -where StandardNormal: Distribution +impl Normal +where F: Float, StandardNormal: Distribution { /// Construct a new `Normal` distribution with the given mean and /// standard deviation. #[inline] - pub fn new(mean: N, std_dev: N) -> Result, Error> { - if !(std_dev >= N::from(0.0)) { + pub fn new(mean: F, std_dev: F) -> Result, Error> { + if !(std_dev >= F::zero()) { return Err(Error::StdDevTooSmall); } Ok(Normal { mean, std_dev }) } } -impl Distribution for Normal -where StandardNormal: Distribution +impl Distribution for Normal +where F: Float, StandardNormal: Distribution { - fn sample(&self, rng: &mut R) -> N { - let n: N = rng.sample(StandardNormal); + fn sample(&self, rng: &mut R) -> F { + let n: F = rng.sample(StandardNormal); self.mean + self.std_dev * n } } @@ -173,18 +177,20 @@ where StandardNormal: Distribution /// println!("{} is from an ln N(2, 9) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] -pub struct LogNormal { - norm: Normal, +pub struct LogNormal +where F: Float, StandardNormal: Distribution +{ + norm: Normal, } -impl LogNormal -where StandardNormal: Distribution +impl LogNormal +where F: Float, StandardNormal: Distribution { /// Construct a new `LogNormal` distribution with the given mean /// and standard deviation of the logarithm of the distribution. #[inline] - pub fn new(mean: N, std_dev: N) -> Result, Error> { - if !(std_dev >= N::from(0.0)) { + pub fn new(mean: F, std_dev: F) -> Result, Error> { + if !(std_dev >= F::zero()) { return Err(Error::StdDevTooSmall); } Ok(LogNormal { @@ -193,10 +199,10 @@ where StandardNormal: Distribution } } -impl Distribution for LogNormal -where StandardNormal: Distribution +impl Distribution for LogNormal +where F: Float, StandardNormal: Distribution { - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { self.norm.sample(rng).exp() } } @@ -233,54 +239,4 @@ mod tests { fn test_log_normal_invalid_sd() { LogNormal::new(10.0, -1.0).unwrap(); } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, zero: N, expected: &[N], - ) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - test_samples(StandardNormal, 0f32, &[ - -0.11844189, - 0.781378, - 0.06563994, - -1.1932899, - ]); - test_samples(StandardNormal, 0f64, &[ - -0.11844188827977231, - 0.7813779637772346, - 0.06563993969580051, - -1.1932899004186373, - ]); - - test_samples(Normal::new(0.0, 1.0).unwrap(), 0f32, &[ - -0.11844189, - 0.781378, - 0.06563994, - -1.1932899, - ]); - test_samples(Normal::new(2.0, 0.5).unwrap(), 0f64, &[ - 1.940779055860114, - 2.3906889818886174, - 2.0328199698479, - 1.4033550497906813, - ]); - - test_samples(LogNormal::new(0.0, 1.0).unwrap(), 0f32, &[ - 0.88830346, 2.1844804, 1.0678421, 0.30322206, - ]); - test_samples(LogNormal::new(2.0, 0.5).unwrap(), 0f64, &[ - 6.964174338639032, - 10.921015733601452, - 7.6355881556915906, - 4.068828213584092, - ]); - } } diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index fc6e9801217..252a319d877 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -1,4 +1,5 @@ -use crate::{Distribution, Float, InverseGaussian, Standard, StandardNormal}; +use crate::{Distribution, InverseGaussian, Standard, 
StandardNormal}; +use num_traits::Float; use rand::Rng; /// Error type returned from `NormalInverseGaussian::new` @@ -12,19 +13,27 @@ pub enum Error { /// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) #[derive(Debug)] -pub struct NormalInverseGaussian { - alpha: N, - beta: N, - inverse_gaussian: InverseGaussian, +pub struct NormalInverseGaussian +where + F: Float, + StandardNormal: Distribution, + Standard: Distribution, +{ + alpha: F, + beta: F, + inverse_gaussian: InverseGaussian, } -impl NormalInverseGaussian -where StandardNormal: Distribution +impl NormalInverseGaussian +where + F: Float, + StandardNormal: Distribution, + Standard: Distribution, { /// Construct a new `NormalInverseGaussian` distribution with the given alpha (tail heaviness) and /// beta (asymmetry) parameters. - pub fn new(alpha: N, beta: N) -> Result, Error> { - if !(alpha > N::from(0.0)) { + pub fn new(alpha: F, beta: F) -> Result, Error> { + if !(alpha > F::zero()) { return Err(Error::AlphaNegativeOrNull); } @@ -34,9 +43,9 @@ where StandardNormal: Distribution let gamma = (alpha * alpha - beta * beta).sqrt(); - let mu = N::from(1.) / gamma; + let mu = F::one() / gamma; - let inverse_gaussian = InverseGaussian::new(mu, N::from(1.)).unwrap(); + let inverse_gaussian = InverseGaussian::new(mu, F::one()).unwrap(); Ok(Self { alpha, @@ -46,12 +55,13 @@ where StandardNormal: Distribution } } -impl Distribution for NormalInverseGaussian +impl Distribution for NormalInverseGaussian where - StandardNormal: Distribution, - Standard: Distribution, + F: Float, + StandardNormal: Distribution, + Standard: Distribution, { - fn sample(&self, rng: &mut R) -> N + fn sample(&self, rng: &mut R) -> F where R: Rng + ?Sized { let inv_gauss = rng.sample(&self.inverse_gaussian); @@ -79,29 +89,4 @@ mod tests { assert!(NormalInverseGaussian::new(1.0, 2.0).is_err()); assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok()); } - - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, zero: N, expected: &[N], - ) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - test_samples(NormalInverseGaussian::new(2.0, 1.0).unwrap(), 0f32, &[ - 0.6568966, 1.3744819, 2.216063, 0.11488572, - ]); - test_samples(NormalInverseGaussian::new(2.0, 1.0).unwrap(), 0f64, &[ - 0.6838707059642927, - 2.4447306460569784, - 0.2361045023235968, - 1.7774534624785319, - ]); - } } diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs index fd427514c4d..217899ed9a7 100644 --- a/rand_distr/src/pareto.rs +++ b/rand_distr/src/pareto.rs @@ -8,10 +8,10 @@ //! The Pareto distribution. -use crate::utils::Float; +use num_traits::Float; use crate::{Distribution, OpenClosed01}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// Samples floating-point numbers according to the Pareto distribution /// @@ -24,9 +24,11 @@ use std::{error, fmt}; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Pareto { - scale: N, - inv_neg_shape: N, +pub struct Pareto +where F: Float, OpenClosed01: Distribution +{ + scale: F, + inv_neg_shape: F, } /// Error type returned from `Pareto::new`. 
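A usage sketch for the now-generic `NormalInverseGaussian` (illustrative, not part of the patch, and assuming the type is re-exported at the crate root); per the checks and tests above, `alpha` must be positive and larger in magnitude than `beta`:

    use rand::prelude::*;
    use rand_distr::NormalInverseGaussian;

    fn main() {
        let nig = NormalInverseGaussian::new(2.0, 1.0).unwrap();
        let v: f64 = nig.sample(&mut thread_rng());
        println!("{} is from a NIG(2, 1) distribution", v);
    }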
@@ -47,34 +49,37 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} -impl Pareto -where OpenClosed01: Distribution +impl Pareto +where F: Float, OpenClosed01: Distribution { /// Construct a new Pareto distribution with given `scale` and `shape`. /// /// In the literature, `scale` is commonly written as xm or k and /// `shape` is often written as α. - pub fn new(scale: N, shape: N) -> Result, Error> { - if !(scale > N::from(0.0)) { + pub fn new(scale: F, shape: F) -> Result, Error> { + let zero = F::zero(); + + if !(scale > zero) { return Err(Error::ScaleTooSmall); } - if !(shape > N::from(0.0)) { + if !(shape > zero) { return Err(Error::ShapeTooSmall); } Ok(Pareto { scale, - inv_neg_shape: N::from(-1.0) / shape, + inv_neg_shape: F::from(-1.0).unwrap() / shape, }) } } -impl Distribution for Pareto -where OpenClosed01: Distribution +impl Distribution for Pareto +where F: Float, OpenClosed01: Distribution { - fn sample(&self, rng: &mut R) -> N { - let u: N = OpenClosed01.sample(rng); + fn sample(&self, rng: &mut R) -> F { + let u: F = OpenClosed01.sample(rng); self.scale * u.powf(self.inv_neg_shape) } } @@ -103,8 +108,8 @@ mod tests { #[test] fn value_stability() { - fn test_samples>( - distr: D, zero: N, expected: &[N], + fn test_samples>( + distr: D, zero: F, expected: &[F], ) { let mut rng = crate::test::rng(213); let mut buf = [zero; 4]; diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs index 2a00a8c27fb..d6905e014bf 100644 --- a/rand_distr/src/pert.rs +++ b/rand_distr/src/pert.rs @@ -7,10 +7,10 @@ // except according to those terms. //! The PERT distribution. -use crate::utils::Float; +use num_traits::Float; use crate::{Beta, Distribution, Exp1, Open01, StandardNormal}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// The PERT distribution. /// @@ -31,10 +31,16 @@ use std::{error, fmt}; /// /// [`Triangular`]: crate::Triangular #[derive(Clone, Copy, Debug)] -pub struct Pert { - min: N, - range: N, - beta: Beta, +pub struct Pert +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + min: F, + range: F, + beta: Beta, } /// Error type returned from [`Pert`] constructors. @@ -58,41 +64,43 @@ impl fmt::Display for PertError { } } -impl error::Error for PertError {} +#[cfg(feature = "std")] +impl std::error::Error for PertError {} -impl Pert +impl Pert where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { /// Set up the PERT distribution with defined `min`, `max` and `mode`. /// /// This is equivalent to calling `Pert::new_shape` with `shape == 4.0`. #[inline] - pub fn new(min: N, max: N, mode: N) -> Result, PertError> { - Pert::new_with_shape(min, max, mode, N::from(4.)) + pub fn new(min: F, max: F, mode: F) -> Result, PertError> { + Pert::new_with_shape(min, max, mode, F::from(4.).unwrap()) } /// Set up the PERT distribution with defined `min`, `max`, `mode` and /// `shape`. 
- pub fn new_with_shape(min: N, max: N, mode: N, shape: N) -> Result, PertError> { + pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result, PertError> { if !(max > min) { return Err(PertError::RangeTooSmall); } if !(mode >= min && max >= mode) { return Err(PertError::ModeRange); } - if !(shape >= N::from(0.)) { + if !(shape >= F::from(0.).unwrap()) { return Err(PertError::ShapeTooSmall); } let range = max - min; - let mu = (min + max + shape * mode) / (shape + N::from(2.)); + let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap()); let v = if mu == mode { - shape * N::from(0.5) + N::from(1.) + shape * F::from(0.5).unwrap() + F::from(1.).unwrap() } else { - (mu - min) * (N::from(2.) * mode - min - max) / ((mode - mu) * (max - min)) + (mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min)) }; let w = v * (max - mu) / (mu - min); let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?; @@ -100,14 +108,15 @@ where } } -impl Distribution for Pert +impl Distribution for Pert where - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, { #[inline] - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { self.beta.sample(rng) * self.range + self.min } } @@ -115,7 +124,6 @@ where #[cfg(test)] mod test { use super::*; - use std::f64; #[test] fn test_pert() { @@ -136,20 +144,4 @@ mod test { assert!(Pert::new(min, max, mode).is_err()); } } - - #[test] - fn value_stability() { - let rng = crate::test::rng(860); - let distr = Pert::new(2., 10., 3.).unwrap(); // mean = 4, var = 12/7 - let seq = distr.sample_iter(rng).take(5).collect::>(); - println!("seq: {:?}", seq); - let expected = vec![ - 4.631484136029422, - 3.307201472321789, - 3.29995019556348, - 3.66835483991721, - 3.514246139933899, - ]; - assert!(seq == expected); - } } diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index 28f835ab7b8..a190256e15e 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -9,10 +9,10 @@ //! The Poisson distribution. -use crate::utils::Float; +use num_traits::{Float, FloatConst}; use crate::{Cauchy, Distribution, Standard}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// The Poisson distribution `Poisson(lambda)`. /// @@ -25,17 +25,19 @@ use std::{error, fmt}; /// use rand_distr::{Poisson, Distribution}; /// /// let poi = Poisson::new(2.0).unwrap(); -/// let v: u64 = poi.sample(&mut rand::thread_rng()); +/// let v = poi.sample(&mut rand::thread_rng()); /// println!("{} is from a Poisson(2) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Poisson { - lambda: N, +pub struct Poisson +where F: Float + FloatConst, Standard: Distribution +{ + lambda: F, // precalculated values - exp_lambda: N, - log_lambda: N, - sqrt_2lambda: N, - magic_val: N, + exp_lambda: F, + log_lambda: F, + sqrt_2lambda: F, + magic_val: F, } /// Error type returned from `Poisson::new`. @@ -53,15 +55,16 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} -impl Poisson -where Standard: Distribution +impl Poisson +where F: Float + FloatConst, Standard: Distribution { /// Construct a new `Poisson` with the given shape parameter /// `lambda`. 
- pub fn new(lambda: N) -> Result, Error> { - if !(lambda > N::from(0.0)) { + pub fn new(lambda: F) -> Result, Error> { + if !(lambda > F::zero()) { return Err(Error::ShapeTooSmall); } let log_lambda = lambda.ln(); @@ -69,34 +72,34 @@ where Standard: Distribution lambda, exp_lambda: (-lambda).exp(), log_lambda, - sqrt_2lambda: (N::from(2.0) * lambda).sqrt(), - magic_val: lambda * log_lambda - (N::from(1.0) + lambda).log_gamma(), + sqrt_2lambda: (F::from(2.0).unwrap() * lambda).sqrt(), + magic_val: lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda), }) } } -impl Distribution for Poisson -where Standard: Distribution +impl Distribution for Poisson +where F: Float + FloatConst, Standard: Distribution { #[inline] - fn sample(&self, rng: &mut R) -> N { + fn sample(&self, rng: &mut R) -> F { // using the algorithm from Numerical Recipes in C // for low expected values use the Knuth method - if self.lambda < N::from(12.0) { - let mut result = N::from(0.); - let mut p = N::from(1.0); + if self.lambda < F::from(12.0).unwrap() { + let mut result = F::zero(); + let mut p = F::one(); while p > self.exp_lambda { - p *= rng.gen::(); - result += N::from(1.); + p = p*rng.gen::(); + result = result + F::one(); } - result - N::from(1.) + result - F::one() } // high expected values - rejection method else { // we use the Cauchy distribution as the comparison distribution // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(N::from(0.0), N::from(1.0)).unwrap(); + let cauchy = Cauchy::new(F::zero(), F::one()).unwrap(); let mut result; loop { @@ -108,7 +111,7 @@ where Standard: Distribution // shift the peak of the comparison ditribution result = self.sqrt_2lambda * comp_dev + self.lambda; // repeat the drawing until we are in the range of possible values - if result >= N::from(0.0) { + if result >= F::zero() { break; } } @@ -120,15 +123,15 @@ where Standard: Distribution // the magic value scales the distribution function to a range of approximately 0-1 // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 // this doesn't change the resulting distribution, only increases the rate of failed drawings - let check = N::from(0.9) - * (N::from(1.0) + comp_dev * comp_dev) + let check = F::from(0.9).unwrap() + * (F::one() + comp_dev * comp_dev) * (result * self.log_lambda - - (N::from(1.0) + result).log_gamma() + - crate::utils::log_gamma(F::one() + result) - self.magic_val) .exp(); // check with uniform random value - if below the threshold, we are within the target distribution - if rng.gen::() <= check { + if rng.gen::() <= check { break; } } @@ -137,100 +140,29 @@ where Standard: Distribution } } -impl Distribution for Poisson -where Standard: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> u64 { - let result: N = self.sample(rng); - result.to_u64().unwrap() - } -} - #[cfg(test)] mod test { use super::*; - #[test] - fn test_poisson_10() { - let poisson = Poisson::new(10.0).unwrap(); - let mut rng = crate::test::rng(123); - let mut sum_u64 = 0; - let mut sum_f64 = 0.; - for _ in 0..1000 { - let s_u64: u64 = poisson.sample(&mut rng); - let s_f64: f64 = poisson.sample(&mut rng); - sum_u64 += s_u64; - sum_f64 += s_f64; - } - let avg_u64 = (sum_u64 as f64) / 1000.0; - let avg_f64 = sum_f64 / 1000.0; - println!("Poisson averages: {} (u64) {} (f64)", avg_u64, avg_f64); - for &avg in &[avg_u64, avg_f64] { - assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough - } - } - - #[test] - fn test_poisson_15() { - // Take the 'high expected values' path 
- let poisson = Poisson::new(15.0).unwrap(); - let mut rng = crate::test::rng(123); - let mut sum_u64 = 0; - let mut sum_f64 = 0.; - for _ in 0..1000 { - let s_u64: u64 = poisson.sample(&mut rng); - let s_f64: f64 = poisson.sample(&mut rng); - sum_u64 += s_u64; - sum_f64 += s_f64; - } - let avg_u64 = (sum_u64 as f64) / 1000.0; - let avg_f64 = sum_f64 / 1000.0; - println!("Poisson average: {} (u64) {} (f64)", avg_u64, avg_f64); - for &avg in &[avg_u64, avg_f64] { - assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough - } - } - - #[test] - fn test_poisson_10_f32() { - let poisson = Poisson::new(10.0f32).unwrap(); + fn test_poisson_avg_gen(lambda: F, tol: F) + where Standard: Distribution + { + let poisson = Poisson::new(lambda).unwrap(); let mut rng = crate::test::rng(123); - let mut sum_u64 = 0; - let mut sum_f32 = 0.; + let mut sum = F::zero(); for _ in 0..1000 { - let s_u64: u64 = poisson.sample(&mut rng); - let s_f32: f32 = poisson.sample(&mut rng); - sum_u64 += s_u64; - sum_f32 += s_f32; - } - let avg_u64 = (sum_u64 as f32) / 1000.0; - let avg_f32 = sum_f32 / 1000.0; - println!("Poisson averages: {} (u64) {} (f32)", avg_u64, avg_f32); - for &avg in &[avg_u64, avg_f32] { - assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough + sum = sum + poisson.sample(&mut rng); } + let avg = sum / F::from(1000.0).unwrap(); + assert!((avg - lambda).abs() < tol); } #[test] - fn test_poisson_15_f32() { - // Take the 'high expected values' path - let poisson = Poisson::new(15.0f32).unwrap(); - let mut rng = crate::test::rng(123); - let mut sum_u64 = 0; - let mut sum_f32 = 0.; - for _ in 0..1000 { - let s_u64: u64 = poisson.sample(&mut rng); - let s_f32: f32 = poisson.sample(&mut rng); - sum_u64 += s_u64; - sum_f32 += s_f32; - } - let avg_u64 = (sum_u64 as f32) / 1000.0; - let avg_f32 = sum_f32 / 1000.0; - println!("Poisson average: {} (u64) {} (f32)", avg_u64, avg_f32); - for &avg in &[avg_u64, avg_f32] { - assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough - } + fn test_poisson_avg() { + test_poisson_avg_gen::(10.0, 0.5); + test_poisson_avg_gen::(15.0, 0.5); + test_poisson_avg_gen::(10.0, 0.5); + test_poisson_avg_gen::(15.0, 0.5); } #[test] @@ -244,23 +176,4 @@ mod test { fn test_poisson_invalid_lambda_neg() { Poisson::new(-10.0).unwrap(); } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, zero: N, expected: &[N], - ) { - let mut rng = crate::test::rng(223); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - // Special cases: < 12, >= 12 - test_samples(Poisson::new(7.0).unwrap(), 0f32, &[5.0, 11.0, 6.0, 5.0]); - test_samples(Poisson::new(7.0).unwrap(), 0f64, &[9.0, 5.0, 7.0, 6.0]); - test_samples(Poisson::new(27.0).unwrap(), 0f32, &[28.0, 32.0, 36.0, 36.0]); - } } diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs index f290f030765..6d3d4cfd03f 100644 --- a/rand_distr/src/triangular.rs +++ b/rand_distr/src/triangular.rs @@ -7,10 +7,10 @@ // except according to those terms. //! The triangular distribution. -use crate::utils::Float; +use num_traits::Float; use crate::{Distribution, Standard}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// The triangular distribution. 
/// @@ -32,10 +32,12 @@ use std::{error, fmt}; /// /// [`Pert`]: crate::Pert #[derive(Clone, Copy, Debug)] -pub struct Triangular { - min: N, - max: N, - mode: N, +pub struct Triangular +where F: Float, Standard: Distribution +{ + min: F, + max: F, + mode: F, } /// Error type returned from [`Triangular::new`]. @@ -58,14 +60,15 @@ impl fmt::Display for TriangularError { } } -impl error::Error for TriangularError {} +#[cfg(feature = "std")] +impl std::error::Error for TriangularError {} -impl Triangular -where Standard: Distribution +impl Triangular +where F: Float, Standard: Distribution { /// Set up the Triangular distribution with defined `min`, `max` and `mode`. #[inline] - pub fn new(min: N, max: N, mode: N) -> Result, TriangularError> { + pub fn new(min: F, max: F, mode: F) -> Result, TriangularError> { if !(max >= min) { return Err(TriangularError::RangeTooSmall); } @@ -76,12 +79,12 @@ where Standard: Distribution } } -impl Distribution for Triangular -where Standard: Distribution +impl Distribution for Triangular +where F: Float, Standard: Distribution { #[inline] - fn sample(&self, rng: &mut R) -> N { - let f: N = rng.sample(Standard); + fn sample(&self, rng: &mut R) -> F { + let f: F = rng.sample(Standard); let diff_mode_min = self.mode - self.min; let range = self.max - self.min; let f_range = f * range; @@ -97,7 +100,6 @@ where Standard: Distribution mod test { use super::*; use rand::{rngs::mock, Rng}; - use std::f64; #[test] fn test_triangular() { @@ -111,7 +113,8 @@ mod test { (0., 1., 0.9, 0.45f64.sqrt()), (-4., -0.5, -2., -4.0 + 3.5f64.sqrt()), ] { - println!("{} {} {} {}", min, max, mode, median); + #[cfg(feature = "std")] + std::println!("{} {} {} {}", min, max, mode, median); let distr = Triangular::new(min, max, mode).unwrap(); // Test correct value at median: assert_eq!(distr.sample(&mut half_rng), median); @@ -125,20 +128,4 @@ mod test { assert!(Triangular::new(min, max, mode).is_err()); } } - - #[test] - fn value_stability() { - let rng = crate::test::rng(860); - let distr = Triangular::new(2., 10., 3.).unwrap(); - let seq = distr.sample_iter(rng).take(5).collect::>(); - println!("seq: {:?}", seq); - let expected = vec![ - 5.74373257511361, - 7.890059162791258, - 4.7256280652553455, - 2.9474808121184077, - 3.058301946314053, - ]; - assert!(seq == expected); - } } diff --git a/rand_distr/src/unit_ball.rs b/rand_distr/src/unit_ball.rs index 616a25161ae..e5585a1e677 100644 --- a/rand_distr/src/unit_ball.rs +++ b/rand_distr/src/unit_ball.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::utils::Float; +use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; use rand::Rng; @@ -27,10 +27,10 @@ use rand::Rng; #[derive(Clone, Copy, Debug)] pub struct UnitBall; -impl Distribution<[N; 3]> for UnitBall { +impl Distribution<[F; 3]> for UnitBall { #[inline] - fn sample(&self, rng: &mut R) -> [N; 3] { - let uniform = Uniform::new(N::from(-1.), N::from(1.)); + fn sample(&self, rng: &mut R) -> [F; 3] { + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); let mut x1; let mut x2; let mut x3; @@ -38,32 +38,10 @@ impl Distribution<[N; 3]> for UnitBall { x1 = uniform.sample(rng); x2 = uniform.sample(rng); x3 = uniform.sample(rng); - if x1 * x1 + x2 * x2 + x3 * x3 <= N::from(1.) 
{ + if x1 * x1 + x2 * x2 + x3 * x3 <= F::from(1.).unwrap() { break; } } [x1, x2, x3] } } - -#[cfg(test)] -mod tests { - use super::UnitBall; - use crate::Distribution; - - #[test] - fn value_stability() { - let mut rng = crate::test::rng(2); - let expected = [ - [0.018035709265959987, -0.4348771383120438, -0.07982762085055706], - [0.10588569388223945, -0.4734350111375454, -0.7392104908825501], - [0.11060237642041049, -0.16065642822852677, -0.8444043930440075] - ]; - let samples: [[f64; 3]; 3] = [ - UnitBall.sample(&mut rng), - UnitBall.sample(&mut rng), - UnitBall.sample(&mut rng), - ]; - assert_eq!(samples, expected); - } -} diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs index 9f9844a5536..fee9d8ce78f 100644 --- a/rand_distr/src/unit_circle.rs +++ b/rand_distr/src/unit_circle.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::utils::Float; +use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; use rand::Rng; @@ -31,10 +31,10 @@ use rand::Rng; #[derive(Clone, Copy, Debug)] pub struct UnitCircle; -impl Distribution<[N; 2]> for UnitCircle { +impl Distribution<[F; 2]> for UnitCircle { #[inline] - fn sample(&self, rng: &mut R) -> [N; 2] { - let uniform = Uniform::new(N::from(-1.), N::from(1.)); + fn sample(&self, rng: &mut R) -> [F; 2] { + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); let mut x1; let mut x2; let mut sum; @@ -42,12 +42,12 @@ impl Distribution<[N; 2]> for UnitCircle { x1 = uniform.sample(rng); x2 = uniform.sample(rng); sum = x1 * x1 + x2 * x2; - if sum < N::from(1.) { + if sum < F::from(1.).unwrap() { break; } } let diff = x1 * x1 - x2 * x2; - [diff / sum, N::from(2.) * x1 * x2 / sum] + [diff / sum, F::from(2.).unwrap() * x1 * x2 / sum] } } @@ -64,11 +64,11 @@ mod tests { ($a:expr, $b:expr, $prec:expr) => { let diff = ($a - $b).abs(); if diff > $prec { - panic!(format!( + panic!( "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ (left: `{}`, right: `{}`)", diff, $prec, $a, $b - )); + ); } }; } @@ -81,20 +81,4 @@ mod tests { assert_almost_eq!(x[0] * x[0] + x[1] * x[1], 1., 1e-15); } } - - #[test] - fn value_stability() { - let mut rng = crate::test::rng(2); - let expected = [ - [-0.9965658683520504, -0.08280380447614634], - [-0.9790853270389644, -0.20345004884984505], - [-0.8449189758898707, 0.5348943112253227], - ]; - let samples: [[f64; 2]; 3] = [ - UnitCircle.sample(&mut rng), - UnitCircle.sample(&mut rng), - UnitCircle.sample(&mut rng), - ]; - assert_eq!(samples, expected); - } } diff --git a/rand_distr/src/unit_disc.rs b/rand_distr/src/unit_disc.rs index dc37c129cb9..ced548b4dc0 100644 --- a/rand_distr/src/unit_disc.rs +++ b/rand_distr/src/unit_disc.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::utils::Float; +use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; use rand::Rng; @@ -26,41 +26,19 @@ use rand::Rng; #[derive(Clone, Copy, Debug)] pub struct UnitDisc; -impl Distribution<[N; 2]> for UnitDisc { +impl Distribution<[F; 2]> for UnitDisc { #[inline] - fn sample(&self, rng: &mut R) -> [N; 2] { - let uniform = Uniform::new(N::from(-1.), N::from(1.)); + fn sample(&self, rng: &mut R) -> [F; 2] { + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); let mut x1; let mut x2; loop { x1 = uniform.sample(rng); x2 = uniform.sample(rng); - if x1 * x1 + x2 * x2 <= N::from(1.) 
{ + if x1 * x1 + x2 * x2 <= F::from(1.).unwrap() { break; } } [x1, x2] } } - -#[cfg(test)] -mod tests { - use super::UnitDisc; - use crate::Distribution; - - #[test] - fn value_stability() { - let mut rng = crate::test::rng(2); - let expected = [ - [0.018035709265959987, -0.4348771383120438], - [-0.07982762085055706, 0.7765329819820659], - [0.21450745997299503, 0.7398636984333291], - ]; - let samples: [[f64; 2]; 3] = [ - UnitDisc.sample(&mut rng), - UnitDisc.sample(&mut rng), - UnitDisc.sample(&mut rng), - ]; - assert_eq!(samples, expected); - } -} diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs index 54539cc9e41..a5ec0e009ac 100644 --- a/rand_distr/src/unit_sphere.rs +++ b/rand_distr/src/unit_sphere.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::utils::Float; +use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; use rand::Rng; @@ -30,18 +30,18 @@ use rand::Rng; #[derive(Clone, Copy, Debug)] pub struct UnitSphere; -impl Distribution<[N; 3]> for UnitSphere { +impl Distribution<[F; 3]> for UnitSphere { #[inline] - fn sample(&self, rng: &mut R) -> [N; 3] { - let uniform = Uniform::new(N::from(-1.), N::from(1.)); + fn sample(&self, rng: &mut R) -> [F; 3] { + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); loop { let (x1, x2) = (uniform.sample(rng), uniform.sample(rng)); let sum = x1 * x1 + x2 * x2; - if sum >= N::from(1.) { + if sum >= F::from(1.).unwrap() { continue; } - let factor = N::from(2.) * (N::from(1.0) - sum).sqrt(); - return [x1 * factor, x2 * factor, N::from(1.) - N::from(2.) * sum]; + let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt(); + return [x1 * factor, x2 * factor, F::from(1.).unwrap() - F::from(2.).unwrap() * sum]; } } } @@ -59,11 +59,11 @@ mod tests { ($a:expr, $b:expr, $prec:expr) => { let diff = ($a - $b).abs(); if diff > $prec { - panic!(format!( + panic!( "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ (left: `{}`, right: `{}`)", diff, $prec, $a, $b - )); + ); } }; } @@ -76,20 +76,4 @@ mod tests { assert_almost_eq!(x[0] * x[0] + x[1] * x[1] + x[2] * x[2], 1., 1e-15); } } - - #[test] - fn value_stability() { - let mut rng = crate::test::rng(2); - let expected = [ - [0.03247542860231647, -0.7830477442152738, 0.6211131755296027], - [-0.09978440840914075, 0.9706650829833128, -0.21875184231323952], - [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], - ]; - let samples: [[f64; 3]; 3] = [ - UnitSphere.sample(&mut rng), - UnitSphere.sample(&mut rng), - UnitSphere.sample(&mut rng), - ]; - assert_eq!(samples, expected); - } } diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index 478aacf7625..878faf2072b 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -9,184 +9,9 @@ //! Math helper functions use crate::ziggurat_tables; -use core::{cmp, ops}; use rand::distributions::hidden_export::IntoFloat; use rand::Rng; -/// Trait for floating-point scalar types -/// -/// This allows many distributions to work with `f32` or `f64` parameters and is -/// potentially extensible. Note however that the `Exp1` and `StandardNormal` -/// distributions are implemented exclusively for `f32` and `f64`. -/// -/// The bounds and methods are based purely on internal -/// requirements, and will change as needed. 
-pub trait Float: - Copy - + Sized - + cmp::PartialOrd - + ops::Neg - + ops::Add - + ops::Sub - + ops::Mul - + ops::Div - + ops::AddAssign - + ops::SubAssign - + ops::MulAssign - + ops::DivAssign -{ - /// The constant π - fn pi() -> Self; - /// Support approximate representation of a f64 value - fn from(x: f64) -> Self; - /// Support converting to an unsigned integer. - fn to_u64(self) -> Option; - - /// Take the absolute value of self - fn abs(self) -> Self; - /// Take the largest integer less than or equal to self - fn floor(self) -> Self; - - /// Take the exponential of self - fn exp(self) -> Self; - /// Take the natural logarithm of self - fn ln(self) -> Self; - /// Take square root of self - fn sqrt(self) -> Self; - /// Take self to a floating-point power - fn powf(self, power: Self) -> Self; - - /// Take the tangent of self - fn tan(self) -> Self; - /// Take the logarithm of the gamma function of self - fn log_gamma(self) -> Self; -} - -impl Float for f32 { - #[inline] - fn pi() -> Self { - core::f32::consts::PI - } - - #[inline] - fn from(x: f64) -> Self { - x as f32 - } - - #[inline] - fn to_u64(self) -> Option { - if self >= 0. && self <= ::core::u64::MAX as f32 { - Some(self as u64) - } else { - None - } - } - - #[inline] - fn abs(self) -> Self { - self.abs() - } - - #[inline] - fn floor(self) -> Self { - self.floor() - } - - #[inline] - fn exp(self) -> Self { - self.exp() - } - - #[inline] - fn ln(self) -> Self { - self.ln() - } - - #[inline] - fn sqrt(self) -> Self { - self.sqrt() - } - - #[inline] - fn powf(self, power: Self) -> Self { - self.powf(power) - } - - #[inline] - fn tan(self) -> Self { - self.tan() - } - - #[inline] - fn log_gamma(self) -> Self { - let result = log_gamma(self.into()); - assert!(result <= ::core::f32::MAX.into()); - assert!(result >= ::core::f32::MIN.into()); - result as f32 - } -} - -impl Float for f64 { - #[inline] - fn pi() -> Self { - core::f64::consts::PI - } - - #[inline] - fn from(x: f64) -> Self { - x - } - - #[inline] - fn to_u64(self) -> Option { - if self >= 0. && self <= ::core::u64::MAX as f64 { - Some(self as u64) - } else { - None - } - } - - #[inline] - fn abs(self) -> Self { - self.abs() - } - - #[inline] - fn floor(self) -> Self { - self.floor() - } - - #[inline] - fn exp(self) -> Self { - self.exp() - } - - #[inline] - fn ln(self) -> Self { - self.ln() - } - - #[inline] - fn sqrt(self) -> Self { - self.sqrt() - } - - #[inline] - fn powf(self, power: Self) -> Self { - self.powf(power) - } - - #[inline] - fn tan(self) -> Self { - self.tan() - } - - #[inline] - fn log_gamma(self) -> Self { - log_gamma(self) - } -} - /// Calculates ln(gamma(x)) (natural logarithm of the gamma /// function) using the Lanczos approximation. /// @@ -200,33 +25,33 @@ impl Float for f64 { /// `Ag(z)` is an infinite series with coefficients that can be calculated /// ahead of time - we use just the first 6 terms, which is good enough /// for most purposes. 
-pub(crate) fn log_gamma(x: f64) -> f64 { +pub(crate) fn log_gamma(x: F) -> F { // precalculated 6 coefficients for the first 6 terms of the series - let coefficients: [f64; 6] = [ - 76.18009172947146, - -86.50532032941677, - 24.01409824083091, - -1.231739572450155, - 0.1208650973866179e-2, - -0.5395239384953e-5, + let coefficients: [F; 6] = [ + F::from(76.18009172947146).unwrap(), + F::from(-86.50532032941677).unwrap(), + F::from(24.01409824083091).unwrap(), + F::from(-1.231739572450155).unwrap(), + F::from(0.1208650973866179e-2).unwrap(), + F::from(-0.5395239384953e-5).unwrap(), ]; // (x+0.5)*ln(x+g+0.5)-(x+g+0.5) - let tmp = x + 5.5; - let log = (x + 0.5) * tmp.ln() - tmp; + let tmp = x + F::from(5.5).unwrap(); + let log = (x + F::from(0.5).unwrap()) * tmp.ln() - tmp; // the first few terms of the series for Ag(x) - let mut a = 1.000000000190015; + let mut a = F::from(1.000000000190015).unwrap(); let mut denom = x; for &coeff in &coefficients { - denom += 1.0; - a += coeff / denom; + denom = denom + F::one(); + a = a + (coeff / denom); } // get everything together // a is Ag(x) // 2.5066... is sqrt(2pi) - log + (2.5066282746310005 * a / x).ln() + log + (F::from(2.5066282746310005).unwrap() * a / x).ln() } /// Sample a random number using the Ziggurat method (specifically the @@ -274,7 +99,7 @@ where (bits >> 12).into_float_with_exponent(1) - 3.0 } else { // Convert to a value in the range [1,2) and substract to get (0,1) - (bits >> 12).into_float_with_exponent(0) - (1.0 - std::f64::EPSILON / 2.0) + (bits >> 12).into_float_with_exponent(0) - (1.0 - core::f64::EPSILON / 2.0) }; let x = u * x_tab[i]; diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs index 6ef3e553363..184e5e06b16 100644 --- a/rand_distr/src/weibull.rs +++ b/rand_distr/src/weibull.rs @@ -8,10 +8,10 @@ //! The Weibull distribution. -use crate::utils::Float; +use num_traits::Float; use crate::{Distribution, OpenClosed01}; use rand::Rng; -use std::{error, fmt}; +use core::fmt; /// Samples floating-point numbers according to the Weibull distribution /// @@ -24,9 +24,11 @@ use std::{error, fmt}; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Weibull { - inv_shape: N, - scale: N, +pub struct Weibull +where F: Float, OpenClosed01: Distribution +{ + inv_shape: F, + scale: F, } /// Error type returned from `Weibull::new`. @@ -47,31 +49,32 @@ impl fmt::Display for Error { } } -impl error::Error for Error {} +#[cfg(feature = "std")] +impl std::error::Error for Error {} -impl Weibull -where OpenClosed01: Distribution +impl Weibull +where F: Float, OpenClosed01: Distribution { /// Construct a new `Weibull` distribution with given `scale` and `shape`. - pub fn new(scale: N, shape: N) -> Result, Error> { - if !(scale > N::from(0.0)) { + pub fn new(scale: F, shape: F) -> Result, Error> { + if !(scale > F::zero()) { return Err(Error::ScaleTooSmall); } - if !(shape > N::from(0.0)) { + if !(shape > F::zero()) { return Err(Error::ShapeTooSmall); } Ok(Weibull { - inv_shape: N::from(1.) 
/ shape, + inv_shape: F::from(1.).unwrap() / shape, scale, }) } } -impl Distribution for Weibull -where OpenClosed01: Distribution +impl Distribution for Weibull +where F: Float, OpenClosed01: Distribution { - fn sample(&self, rng: &mut R) -> N { - let x: N = rng.sample(OpenClosed01); + fn sample(&self, rng: &mut R) -> F { + let x: F = rng.sample(OpenClosed01); self.scale * (-x.ln()).powf(self.inv_shape) } } @@ -100,8 +103,8 @@ mod tests { #[test] fn value_stability() { - fn test_samples>( - distr: D, zero: N, expected: &[N], + fn test_samples>( + distr: D, zero: F, expected: &[F], ) { let mut rng = crate::test::rng(213); let mut buf = [zero; 4]; diff --git a/rand_distr/src/weighted/alias_method.rs b/rand_distr/src/weighted/alias_method.rs index 71c341f83cc..290a32dcfc9 100644 --- a/rand_distr/src/weighted/alias_method.rs +++ b/rand_distr/src/weighted/alias_method.rs @@ -15,6 +15,7 @@ use core::fmt; use core::iter::Sum; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use rand::Rng; +use alloc::{boxed::Box, vec, vec::Vec}; /// A distribution using weighted sampling to pick a discretely selected item. /// @@ -63,8 +64,8 @@ use rand::Rng; /// [`Uniform::sample`]: Distribution::sample /// [`Uniform::sample`]: Distribution::sample pub struct WeightedIndex { - aliases: Vec, - no_alias_odds: Vec, + aliases: Box<[u32]>, + no_alias_odds: Box<[W]>, uniform_index: Uniform, uniform_within_weight_sum: Uniform, } @@ -112,7 +113,7 @@ impl WeightedIndex { // `weight_sum` would have been zero if `try_from_lossy` causes an error here. let n_converted = W::try_from_u32_lossy(n).unwrap(); - let mut no_alias_odds = weights; + let mut no_alias_odds = weights.into_boxed_slice(); for odds in no_alias_odds.iter_mut() { *odds *= n_converted; // Prevent floating point overflow due to rounding errors. @@ -126,7 +127,7 @@ impl WeightedIndex { /// be ensured that a single index is only ever in one of them at the /// same time. struct Aliases { - aliases: Vec, + aliases: Box<[u32]>, smalls_head: u32, bigs_head: u32, } @@ -134,7 +135,7 @@ impl WeightedIndex { impl Aliases { fn new(size: u32) -> Self { Aliases { - aliases: vec![0; size as usize], + aliases: vec![0; size as usize].into_boxed_slice(), smalls_head: ::core::u32::MAX, bigs_head: ::core::u32::MAX, } diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs new file mode 100644 index 00000000000..192ba748b7f --- /dev/null +++ b/rand_distr/tests/value_stability.rs @@ -0,0 +1,319 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::{fmt::Debug, cmp::PartialEq}; +use rand::Rng; +use rand_distr::*; + +fn get_rng(seed: u64) -> impl rand::Rng { + // For tests, we want a statistically good, fast, reproducible RNG. + // PCG32 will do fine, and will be easy to embed if we ever need to. 
+ const INC: u64 = 11634580027462260723; + rand_pcg::Pcg32::new(seed, INC) +} + +fn test_samples>( + seed: u64, distr: D, expected: &[F], +) { + let mut rng = get_rng(seed); + for &val in expected { + assert_eq!(val, rng.sample(&distr)); + } +} + +#[test] +fn binominal_stability() { + // We have multiple code paths: np < 10, p > 0.5 + test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]); + test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]); + test_samples(353, Binomial::new(2000, 0.6).unwrap(), &[1194, 1208, 1192, 1210]); +} + + +#[test] +fn unit_ball_stability() { + test_samples(2, UnitBall, &[ + [0.018035709265959987f64, -0.4348771383120438, -0.07982762085055706], + [0.10588569388223945, -0.4734350111375454, -0.7392104908825501], + [0.11060237642041049, -0.16065642822852677, -0.8444043930440075] + ]); +} + +#[test] +fn unit_circle_stability() { + test_samples(2, UnitCircle, &[ + [-0.9965658683520504f64, -0.08280380447614634], + [-0.9790853270389644, -0.20345004884984505], + [-0.8449189758898707, 0.5348943112253227], + ]); +} + +#[test] +fn unit_sphere_stability() { + test_samples(2, UnitSphere, &[ + [0.03247542860231647f64, -0.7830477442152738, 0.6211131755296027], + [-0.09978440840914075, 0.9706650829833128, -0.21875184231323952], + [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], + ]); +} + +#[test] +fn unit_disc_stability() { + test_samples(2, UnitDisc, &[ + [0.018035709265959987f64, -0.4348771383120438], + [-0.07982762085055706, 0.7765329819820659], + [0.21450745997299503, 0.7398636984333291], + ]); +} + +#[test] +fn pareto_stability() { + test_samples(213, Pareto::new(1.0, 1.0).unwrap(), &[ + 1.0423688f32, 2.1235929, 4.132709, 1.4679428, + ]); + test_samples(213, Pareto::new(2.0, 0.5).unwrap(), &[ + 9.019295276219136f64, + 4.3097126018270595, + 6.837815045397157, + 105.8826669383772, + ]); +} + +#[test] +fn poisson_stability() { + test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]); + test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]); + test_samples(223, Poisson::new(27.0).unwrap(), &[28.0f32, 32.0, 36.0, 36.0]); +} + + +#[test] +fn triangular_stability() { + test_samples(860, Triangular::new(2., 10., 3.).unwrap(), &[ + 5.74373257511361f64, + 7.890059162791258f64, + 4.7256280652553455f64, + 2.9474808121184077f64, + 3.058301946314053f64, + ]); +} + + +#[test] +fn normal_inverse_gaussian_stability() { + test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ + 0.6568966f32, 1.3744819, 2.216063, 0.11488572, + ]); + test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ + 0.6838707059642927f64, + 2.4447306460569784, + 0.2361045023235968, + 1.7774534624785319, + ]); +} + +#[test] +fn pert_stability() { + // mean = 4, var = 12/7 + test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[ + 4.631484136029422f64, + 3.307201472321789f64, + 3.29995019556348f64, + 3.66835483991721f64, + 3.514246139933899f64, + ]); +} + +#[test] +fn inverse_gaussian_stability() { + test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(),&[ + 0.9339157f32, 1.108113, 0.50864697, 0.39849377, + ]); + test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(), &[ + 1.0707604954722476f64, + 0.9628140605340697, + 0.4069687656468226, + 0.660283852985818, + ]); +} + +#[test] +fn gamma_stability() { + // Gamma has 3 cases: shape == 1, shape < 1, shape > 1 + test_samples(223, Gamma::new(1.0, 5.0).unwrap(), &[ + 5.398085f32, 9.162783, 0.2300583, 1.7235851, + ]); + test_samples(223, Gamma::new(0.8, 5.0).unwrap(), &[ + 
0.5051203f32, 0.9048302, 3.095812, 1.8566116, + ]); + test_samples(223, Gamma::new(1.1, 5.0).unwrap(), &[ + 7.783878094584059f64, + 1.4939528171618057, + 8.638017638857592, + 3.0949337228829004, + ]); + + // ChiSquared has 2 cases: k == 1, k != 1 + test_samples(223, ChiSquared::new(1.0).unwrap(), &[ + 0.4893526200348249f64, + 1.635249736808788, + 0.5013580219361969, + 0.1457735613733489, + ]); + test_samples(223, ChiSquared::new(0.1).unwrap(), &[ + 0.014824404726978617f64, + 0.021602123937134326, + 0.0000003431429746851693, + 0.00000002291755769542258, + ]); + test_samples(223, ChiSquared::new(10.0).unwrap(), &[ + 12.693656f32, 6.812016, 11.082001, 12.436167, + ]); + + // FisherF has same special cases as ChiSquared on each param + test_samples(223, FisherF::new(1.0, 13.5).unwrap(), &[ + 0.32283646f32, 0.048049655, 0.0788893, 1.817178, + ]); + test_samples(223, FisherF::new(1.0, 1.0).unwrap(), &[ + 0.29925257f32, 3.4392934, 9.567652, 0.020074, + ]); + test_samples(223, FisherF::new(0.7, 13.5).unwrap(), &[ + 3.3196593155045124f64, + 0.3409169916262829, + 0.03377989856426519, + 0.00004041672861036937, + ]); + + // StudentT has same special cases as ChiSquared + test_samples(223, StudentT::new(1.0).unwrap(), &[ + 0.54703987f32, -1.8545331, 3.093162, -0.14168274, + ]); + test_samples(223, StudentT::new(1.1).unwrap(), &[ + 0.7729195887949754f64, + 1.2606210611616204, + -1.7553606501113175, + -2.377641221169782, + ]); + + // Beta has same special cases as Gamma on each param + test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[ + 0.6444564f32, 0.357635, 0.4110078, 0.7347192, + ]); + test_samples(223, Beta::new(0.7, 1.2).unwrap(), &[ + 0.6433129944095513f64, + 0.5373371199711573, + 0.10313293199269491, + 0.002472280249144378, + ]); +} + +#[test] +fn exponential_stability() { + test_samples(223, Exp1, &[ + 1.079617f32, 1.8325565, 0.04601166, 0.34471703, + ]); + test_samples(223, Exp1, &[ + 1.0796170642388276f64, + 1.8325565304274, + 0.04601166186842716, + 0.3447170217100157, + ]); + + test_samples(223, Exp::new(2.0).unwrap(), &[ + 0.5398085f32, 0.91627824, 0.02300583, 0.17235851, + ]); + test_samples(223, Exp::new(1.0).unwrap(), &[ + 1.0796170642388276f64, + 1.8325565304274, + 0.04601166186842716, + 0.3447170217100157, + ]); +} + +#[test] +fn normal_stability() { + test_samples(213, StandardNormal, &[ + -0.11844189f32, 0.781378, 0.06563994, -1.1932899, + ]); + test_samples(213, StandardNormal, &[ + -0.11844188827977231f64, + 0.7813779637772346, + 0.06563993969580051, + -1.1932899004186373, + ]); + + test_samples(213, Normal::new(0.0, 1.0).unwrap(), &[ + -0.11844189f32, 0.781378, 0.06563994, -1.1932899, + ]); + test_samples(213, Normal::new(2.0, 0.5).unwrap(), &[ + 1.940779055860114f64, + 2.3906889818886174, + 2.0328199698479, + 1.4033550497906813, + ]); + + test_samples(213, LogNormal::new(0.0, 1.0).unwrap(), &[ + 0.88830346f32, 2.1844804, 1.0678421, 0.30322206, + ]); + test_samples(213, LogNormal::new(2.0, 0.5).unwrap(), &[ + 6.964174338639032f64, + 10.921015733601452, + 7.6355881556915906, + 4.068828213584092, + ]); +} + +#[test] +fn weibull_stability() { + test_samples(213, Weibull::new(1.0, 1.0).unwrap(), &[ + 0.041495778f32, 0.7531094, 1.4189332, 0.38386202, + ]); + test_samples(213, Weibull::new(2.0, 0.5).unwrap(), &[ + 1.1343478702739669f64, + 0.29470010050655226, + 0.7556151370284702, + 7.877212340241561, + ]); +} + +#[cfg(feature = "alloc")] +#[test] +fn dirichlet_stability() { + let mut rng = get_rng(223); + assert_eq!( + rng.sample(Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap()), + 
vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146] + ); + assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![ + 0.17684200044809556, + 0.29915953935953055, + 0.1832858056608014, + 0.1425623503573967, + 0.19815030417417595 + ]); +} + +#[test] +fn cauchy_stability() { + test_samples(353, Cauchy::new(100f64, 10.0).unwrap(), &[ + 77.93369152808678f64, + 90.1606912098641, + 125.31516221323625, + 86.10217834773925, + ]); + + // Unfortunately this test is not fully portable due to reliance on the + // system's implementation of tanf (see doc on Cauchy struct). + let distr = Cauchy::new(10f32, 7.0).unwrap(); + let mut rng = get_rng(353); + let expected = [15.023088, -5.446413, 3.7092876, 3.112482]; + for &a in expected.iter() { + let b = rng.sample(&distr); + assert!((a - b).abs() < 1e-6, "expected: {} = {}", a, b); + } +}
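
The sketch below illustrates what replacing the crate-local float abstraction with `num_traits::Float` buys a downstream user: one helper, written once, that samples `Weibull<F>` at either `f32` or `f64` precision. It is a minimal example under the assumption that the caller depends on the `num-traits` crate directly; the helper name and parameter values are illustrative only, though the bounds mirror those placed on `Weibull<F>` in this patch.

// Hypothetical downstream helper, generic over the float type via num_traits::Float.
use num_traits::Float;
use rand::Rng;
use rand_distr::{Distribution, OpenClosed01, Weibull};

/// Draw `n` Weibull(scale, shape) samples at whatever precision `F` is.
fn weibull_draws<F, R>(rng: &mut R, scale: F, shape: F, n: usize) -> Vec<F>
where
    F: Float,
    OpenClosed01: Distribution<F>, // same bound as on Weibull<F> in this patch
    R: Rng,
{
    let distr = Weibull::new(scale, shape).expect("scale and shape must be positive");
    distr.sample_iter(rng).take(n).collect()
}

fn main() {
    let mut rng = rand::thread_rng();
    // The same code path serves both precisions:
    let xs32: Vec<f32> = weibull_draws(&mut rng, 1.0f32, 1.5, 3);
    let xs64: Vec<f64> = weibull_draws(&mut rng, 1.0f64, 1.5, 3);
    println!("{:?}\n{:?}", xs32, xs64);
}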
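
A note on the `F::from(x).unwrap()` calls that replace the old infallible `N::from(x)` throughout the patch: `num_traits::Float` inherits `NumCast`, whose `from` returns an `Option` because a numeric cast can fail in general. For the small literal constants used here, with `F` being `f32` or `f64`, the cast always succeeds, so the unwraps cannot panic. A tiny illustration follows; the function `half` is a made-up example, not part of the crate.

use num_traits::Float;

fn half<F: Float>() -> F {
    // NumCast::from returns Option<F>; for 0.5 and F = f32/f64 it is always Some.
    F::from(0.5).unwrap()
}

fn main() {
    assert_eq!(half::<f32>(), 0.5f32);
    assert_eq!(half::<f64>(), 0.5f64);
}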
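
Finally, the new `tests/value_stability.rs` relies on a fixed-increment `Pcg32` so that every distribution's output stream stays reproducible across releases. The sketch below shows that pattern in isolation: capture a baseline from a fixed seed, then assert that a fresh RNG with the same seed reproduces it exactly. The seed, increment constant and `Normal(0, 1)` choice mirror values used in the test file; the capture-then-assert `main` is only an illustration of how the hard-coded `expected` slices would be produced.

use rand::Rng;
use rand_distr::{Distribution, Normal};

// Fixed-increment PCG32, matching the reference RNG in tests/value_stability.rs.
fn reference_rng(seed: u64) -> rand_pcg::Pcg32 {
    const INC: u64 = 11634580027462260723;
    rand_pcg::Pcg32::new(seed, INC)
}

fn assert_stable<F, D>(seed: u64, distr: D, baseline: &[F])
where
    F: Copy + PartialEq + core::fmt::Debug,
    D: Distribution<F>,
{
    let mut rng = reference_rng(seed);
    for expected in baseline {
        assert_eq!(*expected, rng.sample(&distr));
    }
}

fn main() {
    let distr = Normal::new(0.0f64, 1.0).unwrap();
    // Capture a baseline from the fixed seed once...
    let mut rng = reference_rng(213);
    let baseline: Vec<f64> = (0..4).map(|_| rng.sample(&distr)).collect();
    // ...then any later run with the same seed must reproduce it exactly; in the
    // real tests the captured numbers are hard-coded as the `expected` slices.
    assert_stable(213, &distr, &baseline);
    println!("baseline: {:?}", baseline);
}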