diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 957e0424c94..eea8bed3fea 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -80,6 +80,10 @@ extern crate alloc; #[cfg(feature = "std")] extern crate std; +// This is used for doc links: +#[allow(unused)] +use rand::Rng; + pub use rand::distributions::{ uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, Standard, Uniform, diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 819a0bccd99..bbd96948f8c 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -706,6 +706,78 @@ uniform_simd_int_impl! { u8 } +impl SampleUniform for char { + type Sampler = UniformChar; +} + +/// The back-end implementing [`UniformSampler`] for `char`. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// This differs from integer range sampling since code points in the range +/// `0xD800..=0xDFFF` are reserved for surrogate pairs in UCS and UTF-16, and +/// consequently are not valid `char` values (Unicode scalar values). We must +/// therefore avoid sampling values in this range. +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct UniformChar { + sampler: UniformInt, +} + +/// UTF-16 surrogate range start +const CHAR_SURROGATE_START: u32 = 0xD800; +/// UTF-16 surrogate range size +const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START; + +/// Convert `char` to compressed `u32` +fn char_to_comp_u32(c: char) -> u32 { + match c as u32 { + c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN, + c => c, + } +} + +impl UniformSampler for UniformChar { + type X = char; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time.
+ fn new(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::::new(low, high); + UniformChar { sampler } + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Self + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::::new_inclusive(low, high); + UniformChar { sampler } + } + + fn sample(&self, rng: &mut R) -> Self::X { + let mut x = self.sampler.sample(rng); + if x >= CHAR_SURROGATE_START { + x += CHAR_SURROGATE_LEN; + } + // SAFETY: x must not be in surrogate range or greater than char::MAX. + // This relies on range constructors which accept char arguments. + // Validity of input char values is assumed. + unsafe { core::char::from_u32_unchecked(x) } + } +} /// The back-end implementing [`UniformSampler`] for floating-point types. 
/// @@ -1207,6 +1279,27 @@ mod tests { } } + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_char() { + let mut rng = crate::test::rng(891); + let mut max = core::char::from_u32(0).unwrap(); + for _ in 0..100 { + let c = rng.gen_range('A'..='Z'); + assert!('A' <= c && c <= 'Z'); + max = max.max(c); + } + assert_eq!(max, 'Z'); + let d = Uniform::new( + core::char::from_u32(0xD7F0).unwrap(), + core::char::from_u32(0xE010).unwrap(), + ); + for _ in 0..100 { + let c = d.sample(&mut rng); + assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); + } + } + #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_floats() { diff --git a/src/seq/index.rs b/src/seq/index.rs index 0a5619c5aae..c09e5804229 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -359,11 +359,11 @@ where } // Partially sort the array to find the `amount` elements with the greatest - // keys. Do this by using `partition_at_index` to put the elements with + // keys. Do this by using `select_nth_unstable` to put the elements with // the *smallest* keys at the beginning of the list in `O(n)` time, which // provides equivalent information about the elements with the *greatest* keys. let (_, mid, greater) - = candidates.partition_at_index(length.as_usize() - amount.as_usize()); + = candidates.select_nth_unstable(length.as_usize() - amount.as_usize()); let mut result: Vec = Vec::with_capacity(amount.as_usize()); result.push(mid.index); diff --git a/src/seq/mod.rs b/src/seq/mod.rs index f8d21c9a444..9e6ffaf1242 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -296,7 +296,7 @@ pub trait IteratorRandom: Iterator + Sized { /// depends on size hints. In particular, `Iterator` combinators that don't /// change the values yielded but change the size hints may result in /// `choose` returning different elements. If you want consistent results - /// and RNG usage consider using [`choose_stable`]. + /// and RNG usage consider using [`IteratorRandom::choose_stable`]. 
fn choose(mut self, rng: &mut R) -> Option where R: Rng + ?Sized { let (mut lower, mut upper) = self.size_hint(); @@ -364,6 +364,8 @@ pub trait IteratorRandom: Iterator + Sized { /// constructing elements where possible, however the selection and `rng` /// calls are the same in the face of this optimization. If you want to /// force every element to be created regardless call `.inspect(|e| ())`. + /// + /// [`choose`]: IteratorRandom::choose fn choose_stable(mut self, rng: &mut R) -> Option where R: Rng + ?Sized { let mut consumed = 0;