Skip to content

Commit 32ba8ed

Browse files
Remove poor-performing bitmasks, add Select trait, and enable select on integer bitmasks (#482)
1 parent 936d58b commit 32ba8ed

File tree

10 files changed

+246
-617
lines changed

10 files changed

+246
-617
lines changed

crates/core_simd/src/masks.rs

Lines changed: 85 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,33 @@
22
//! Types representing
33
#![allow(non_camel_case_types)]
44

5-
#[cfg_attr(
6-
not(all(target_arch = "x86_64", target_feature = "avx512f")),
7-
path = "masks/full_masks.rs"
8-
)]
9-
#[cfg_attr(
10-
all(target_arch = "x86_64", target_feature = "avx512f"),
11-
path = "masks/bitmask.rs"
12-
)]
13-
mod mask_impl;
14-
15-
use crate::simd::{LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount};
5+
use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount};
166
use core::cmp::Ordering;
177
use core::{fmt, mem};
188

9+
pub(crate) trait FixEndianness {
10+
fn fix_endianness(self) -> Self;
11+
}
12+
13+
macro_rules! impl_fix_endianness {
14+
{ $($int:ty),* } => {
15+
$(
16+
impl FixEndianness for $int {
17+
#[inline(always)]
18+
fn fix_endianness(self) -> Self {
19+
if cfg!(target_endian = "big") {
20+
<$int>::reverse_bits(self)
21+
} else {
22+
self
23+
}
24+
}
25+
}
26+
)*
27+
}
28+
}
29+
30+
impl_fix_endianness! { u8, u16, u32, u64 }
31+
1932
mod sealed {
2033
use super::*;
2134

@@ -109,7 +122,7 @@ impl_element! { isize, usize }
109122
/// and/or Rust versions, and code should not assume that it is equivalent to
110123
/// `[T; N]`.
111124
#[repr(transparent)]
112-
pub struct Mask<T, const N: usize>(mask_impl::Mask<T, N>)
125+
pub struct Mask<T, const N: usize>(Simd<T, N>)
113126
where
114127
T: MaskElement,
115128
LaneCount<N>: SupportedLaneCount;
@@ -141,7 +154,7 @@ where
141154
#[inline]
142155
#[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
143156
pub const fn splat(value: bool) -> Self {
144-
Self(mask_impl::Mask::splat(value))
157+
Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
145158
}
146159

147160
/// Converts an array of bools to a SIMD mask.
@@ -192,8 +205,8 @@ where
192205
// Safety: the caller must confirm this invariant
193206
unsafe {
194207
core::intrinsics::assume(<T as Sealed>::valid(value));
195-
Self(mask_impl::Mask::from_simd_unchecked(value))
196208
}
209+
Self(value)
197210
}
198211

199212
/// Converts a vector of integers to a mask, where 0 represents `false` and -1
@@ -215,14 +228,15 @@ where
215228
#[inline]
216229
#[must_use = "method returns a new vector and does not mutate the original value"]
217230
pub fn to_simd(self) -> Simd<T, N> {
218-
self.0.to_simd()
231+
self.0
219232
}
220233

221234
/// Converts the mask to a mask of any other element size.
222235
#[inline]
223236
#[must_use = "method returns a new mask and does not mutate the original value"]
224237
pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
225-
Mask(self.0.convert())
238+
// Safety: mask elements are integers
239+
unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
226240
}
227241

228242
/// Tests the value of the specified element.
@@ -233,7 +247,7 @@ where
233247
#[must_use = "method returns a new bool and does not mutate the original value"]
234248
pub unsafe fn test_unchecked(&self, index: usize) -> bool {
235249
// Safety: the caller must confirm this invariant
236-
unsafe { self.0.test_unchecked(index) }
250+
unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
237251
}
238252

239253
/// Tests the value of the specified element.
@@ -244,9 +258,7 @@ where
244258
#[must_use = "method returns a new bool and does not mutate the original value"]
245259
#[track_caller]
246260
pub fn test(&self, index: usize) -> bool {
247-
assert!(index < N, "element index out of range");
248-
// Safety: the element index has been checked
249-
unsafe { self.test_unchecked(index) }
261+
T::eq(self.0[index], T::TRUE)
250262
}
251263

252264
/// Sets the value of the specified element.
@@ -257,7 +269,7 @@ where
257269
pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
258270
// Safety: the caller must confirm this invariant
259271
unsafe {
260-
self.0.set_unchecked(index, value);
272+
*self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
261273
}
262274
}
263275

@@ -268,35 +280,67 @@ where
268280
#[inline]
269281
#[track_caller]
270282
pub fn set(&mut self, index: usize, value: bool) {
271-
assert!(index < N, "element index out of range");
272-
// Safety: the element index has been checked
273-
unsafe {
274-
self.set_unchecked(index, value);
275-
}
283+
self.0[index] = if value { T::TRUE } else { T::FALSE }
276284
}
277285

278286
/// Returns true if any element is set, or false otherwise.
279287
#[inline]
280288
#[must_use = "method returns a new bool and does not mutate the original value"]
281289
pub fn any(self) -> bool {
282-
self.0.any()
290+
// Safety: `self` is a mask vector
291+
unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
283292
}
284293

285294
/// Returns true if all elements are set, or false otherwise.
286295
#[inline]
287296
#[must_use = "method returns a new bool and does not mutate the original value"]
288297
pub fn all(self) -> bool {
289-
self.0.all()
298+
// Safety: `self` is a mask vector
299+
unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
290300
}
291301

292302
/// Creates a bitmask from a mask.
293303
///
294304
/// Each bit is set if the corresponding element in the mask is `true`.
295-
/// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
296305
#[inline]
297306
#[must_use = "method returns a new integer and does not mutate the original value"]
298307
pub fn to_bitmask(self) -> u64 {
299-
self.0.to_bitmask_integer()
308+
const {
309+
assert!(N <= 64, "number of elements can't be greater than 64");
310+
}
311+
312+
#[inline]
313+
unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
314+
mask: Mask<T, N>,
315+
) -> U
316+
where
317+
T: MaskElement,
318+
LaneCount<M>: SupportedLaneCount,
319+
LaneCount<N>: SupportedLaneCount,
320+
{
321+
let resized = mask.resize::<M>(false);
322+
323+
// Safety: `resized` is an integer vector with length M, which must match T
324+
let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };
325+
326+
// LLVM assumes bit order should match endianness
327+
bitmask.fix_endianness()
328+
}
329+
330+
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
331+
if N <= 8 {
332+
// Safety: bitmask matches length
333+
unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
334+
} else if N <= 16 {
335+
// Safety: bitmask matches length
336+
unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
337+
} else if N <= 32 {
338+
// Safety: bitmask matches length
339+
unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
340+
} else {
341+
// Safety: bitmask matches length
342+
unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
343+
}
300344
}
301345

302346
/// Creates a mask from a bitmask.
@@ -306,7 +350,7 @@ where
306350
#[inline]
307351
#[must_use = "method returns a new mask and does not mutate the original value"]
308352
pub fn from_bitmask(bitmask: u64) -> Self {
309-
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
353+
Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
310354
}
311355

312356
/// Finds the index of the first set element.
@@ -450,7 +494,8 @@ where
450494
type Output = Self;
451495
#[inline]
452496
fn bitand(self, rhs: Self) -> Self {
453-
Self(self.0 & rhs.0)
497+
// Safety: `self` is an integer vector
498+
unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
454499
}
455500
}
456501

@@ -486,7 +531,8 @@ where
486531
type Output = Self;
487532
#[inline]
488533
fn bitor(self, rhs: Self) -> Self {
489-
Self(self.0 | rhs.0)
534+
// Safety: `self` is an integer vector
535+
unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
490536
}
491537
}
492538

@@ -522,7 +568,8 @@ where
522568
type Output = Self;
523569
#[inline]
524570
fn bitxor(self, rhs: Self) -> Self::Output {
525-
Self(self.0 ^ rhs.0)
571+
// Safety: `self` is an integer vector
572+
unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
526573
}
527574
}
528575

@@ -558,7 +605,7 @@ where
558605
type Output = Mask<T, N>;
559606
#[inline]
560607
fn not(self) -> Self::Output {
561-
Self(!self.0)
608+
Self::splat(true) ^ self
562609
}
563610
}
564611

@@ -569,7 +616,7 @@ where
569616
{
570617
#[inline]
571618
fn bitand_assign(&mut self, rhs: Self) {
572-
self.0 = self.0 & rhs.0;
619+
*self = *self & rhs;
573620
}
574621
}
575622

@@ -591,7 +638,7 @@ where
591638
{
592639
#[inline]
593640
fn bitor_assign(&mut self, rhs: Self) {
594-
self.0 = self.0 | rhs.0;
641+
*self = *self | rhs;
595642
}
596643
}
597644

@@ -613,7 +660,7 @@ where
613660
{
614661
#[inline]
615662
fn bitxor_assign(&mut self, rhs: Self) {
616-
self.0 = self.0 ^ rhs.0;
663+
*self = *self ^ rhs;
617664
}
618665
}
619666

0 commit comments

Comments
 (0)