2
2
//! Types representing
3
3
#![ allow( non_camel_case_types) ]
4
4
5
- #[ cfg_attr(
6
- not( all( target_arch = "x86_64" , target_feature = "avx512f" ) ) ,
7
- path = "masks/full_masks.rs"
8
- ) ]
9
- #[ cfg_attr(
10
- all( target_arch = "x86_64" , target_feature = "avx512f" ) ,
11
- path = "masks/bitmask.rs"
12
- ) ]
13
- mod mask_impl;
14
-
15
- use crate :: simd:: { LaneCount , Simd , SimdCast , SimdElement , SupportedLaneCount } ;
5
+ use crate :: simd:: { LaneCount , Select , Simd , SimdCast , SimdElement , SupportedLaneCount } ;
16
6
use core:: cmp:: Ordering ;
17
7
use core:: { fmt, mem} ;
18
8
9
+ pub ( crate ) trait FixEndianness {
10
+ fn fix_endianness ( self ) -> Self ;
11
+ }
12
+
13
+ macro_rules! impl_fix_endianness {
14
+ { $( $int: ty) ,* } => {
15
+ $(
16
+ impl FixEndianness for $int {
17
+ #[ inline( always) ]
18
+ fn fix_endianness( self ) -> Self {
19
+ if cfg!( target_endian = "big" ) {
20
+ <$int>:: reverse_bits( self )
21
+ } else {
22
+ self
23
+ }
24
+ }
25
+ }
26
+ ) *
27
+ }
28
+ }
29
+
30
+ impl_fix_endianness ! { u8 , u16 , u32 , u64 }
31
+
19
32
mod sealed {
20
33
use super :: * ;
21
34
@@ -109,7 +122,7 @@ impl_element! { isize, usize }
109
122
/// and/or Rust versions, and code should not assume that it is equivalent to
110
123
/// `[T; N]`.
111
124
#[ repr( transparent) ]
112
- pub struct Mask < T , const N : usize > ( mask_impl :: Mask < T , N > )
125
+ pub struct Mask < T , const N : usize > ( Simd < T , N > )
113
126
where
114
127
T : MaskElement ,
115
128
LaneCount < N > : SupportedLaneCount ;
@@ -141,7 +154,7 @@ where
141
154
#[ inline]
142
155
#[ rustc_const_unstable( feature = "portable_simd" , issue = "86656" ) ]
143
156
pub const fn splat ( value : bool ) -> Self {
144
- Self ( mask_impl :: Mask :: splat ( value) )
157
+ Self ( Simd :: splat ( if value { T :: TRUE } else { T :: FALSE } ) )
145
158
}
146
159
147
160
/// Converts an array of bools to a SIMD mask.
@@ -192,8 +205,8 @@ where
192
205
// Safety: the caller must confirm this invariant
193
206
unsafe {
194
207
core:: intrinsics:: assume ( <T as Sealed >:: valid ( value) ) ;
195
- Self ( mask_impl:: Mask :: from_simd_unchecked ( value) )
196
208
}
209
+ Self ( value)
197
210
}
198
211
199
212
/// Converts a vector of integers to a mask, where 0 represents `false` and -1
@@ -215,14 +228,15 @@ where
215
228
#[ inline]
216
229
#[ must_use = "method returns a new vector and does not mutate the original value" ]
217
230
pub fn to_simd ( self ) -> Simd < T , N > {
218
- self . 0 . to_simd ( )
231
+ self . 0
219
232
}
220
233
221
234
/// Converts the mask to a mask of any other element size.
222
235
#[ inline]
223
236
#[ must_use = "method returns a new mask and does not mutate the original value" ]
224
237
pub fn cast < U : MaskElement > ( self ) -> Mask < U , N > {
225
- Mask ( self . 0 . convert ( ) )
238
+ // Safety: mask elements are integers
239
+ unsafe { Mask ( core:: intrinsics:: simd:: simd_as ( self . 0 ) ) }
226
240
}
227
241
228
242
/// Tests the value of the specified element.
@@ -233,7 +247,7 @@ where
233
247
#[ must_use = "method returns a new bool and does not mutate the original value" ]
234
248
pub unsafe fn test_unchecked ( & self , index : usize ) -> bool {
235
249
// Safety: the caller must confirm this invariant
236
- unsafe { self . 0 . test_unchecked ( index) }
250
+ unsafe { T :: eq ( * self . 0 . as_array ( ) . get_unchecked ( index) , T :: TRUE ) }
237
251
}
238
252
239
253
/// Tests the value of the specified element.
@@ -244,9 +258,7 @@ where
244
258
#[ must_use = "method returns a new bool and does not mutate the original value" ]
245
259
#[ track_caller]
246
260
pub fn test ( & self , index : usize ) -> bool {
247
- assert ! ( index < N , "element index out of range" ) ;
248
- // Safety: the element index has been checked
249
- unsafe { self . test_unchecked ( index) }
261
+ T :: eq ( self . 0 [ index] , T :: TRUE )
250
262
}
251
263
252
264
/// Sets the value of the specified element.
@@ -257,7 +269,7 @@ where
257
269
pub unsafe fn set_unchecked ( & mut self , index : usize , value : bool ) {
258
270
// Safety: the caller must confirm this invariant
259
271
unsafe {
260
- self . 0 . set_unchecked ( index, value) ;
272
+ * self . 0 . as_mut_array ( ) . get_unchecked_mut ( index) = if value { T :: TRUE } else { T :: FALSE }
261
273
}
262
274
}
263
275
@@ -268,35 +280,67 @@ where
268
280
#[ inline]
269
281
#[ track_caller]
270
282
pub fn set ( & mut self , index : usize , value : bool ) {
271
- assert ! ( index < N , "element index out of range" ) ;
272
- // Safety: the element index has been checked
273
- unsafe {
274
- self . set_unchecked ( index, value) ;
275
- }
283
+ self . 0 [ index] = if value { T :: TRUE } else { T :: FALSE }
276
284
}
277
285
278
286
/// Returns true if any element is set, or false otherwise.
279
287
#[ inline]
280
288
#[ must_use = "method returns a new bool and does not mutate the original value" ]
281
289
pub fn any ( self ) -> bool {
282
- self . 0 . any ( )
290
+ // Safety: `self` is a mask vector
291
+ unsafe { core:: intrinsics:: simd:: simd_reduce_any ( self . 0 ) }
283
292
}
284
293
285
294
/// Returns true if all elements are set, or false otherwise.
286
295
#[ inline]
287
296
#[ must_use = "method returns a new bool and does not mutate the original value" ]
288
297
pub fn all ( self ) -> bool {
289
- self . 0 . all ( )
298
+ // Safety: `self` is a mask vector
299
+ unsafe { core:: intrinsics:: simd:: simd_reduce_all ( self . 0 ) }
290
300
}
291
301
292
302
/// Creates a bitmask from a mask.
293
303
///
294
304
/// Each bit is set if the corresponding element in the mask is `true`.
295
- /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
296
305
#[ inline]
297
306
#[ must_use = "method returns a new integer and does not mutate the original value" ]
298
307
pub fn to_bitmask ( self ) -> u64 {
299
- self . 0 . to_bitmask_integer ( )
308
+ const {
309
+ assert ! ( N <= 64 , "number of elements can't be greater than 64" ) ;
310
+ }
311
+
312
+ #[ inline]
313
+ unsafe fn to_bitmask_impl < T , U : FixEndianness , const M : usize , const N : usize > (
314
+ mask : Mask < T , N > ,
315
+ ) -> U
316
+ where
317
+ T : MaskElement ,
318
+ LaneCount < M > : SupportedLaneCount ,
319
+ LaneCount < N > : SupportedLaneCount ,
320
+ {
321
+ let resized = mask. resize :: < M > ( false ) ;
322
+
323
+ // Safety: `resized` is an integer vector with length M, which must match T
324
+ let bitmask: U = unsafe { core:: intrinsics:: simd:: simd_bitmask ( resized. 0 ) } ;
325
+
326
+ // LLVM assumes bit order should match endianness
327
+ bitmask. fix_endianness ( )
328
+ }
329
+
330
+ // TODO modify simd_bitmask to zero-extend output, making this unnecessary
331
+ if N <= 8 {
332
+ // Safety: bitmask matches length
333
+ unsafe { to_bitmask_impl :: < T , u8 , 8 , N > ( self ) as u64 }
334
+ } else if N <= 16 {
335
+ // Safety: bitmask matches length
336
+ unsafe { to_bitmask_impl :: < T , u16 , 16 , N > ( self ) as u64 }
337
+ } else if N <= 32 {
338
+ // Safety: bitmask matches length
339
+ unsafe { to_bitmask_impl :: < T , u32 , 32 , N > ( self ) as u64 }
340
+ } else {
341
+ // Safety: bitmask matches length
342
+ unsafe { to_bitmask_impl :: < T , u64 , 64 , N > ( self ) }
343
+ }
300
344
}
301
345
302
346
/// Creates a mask from a bitmask.
@@ -306,7 +350,7 @@ where
306
350
#[ inline]
307
351
#[ must_use = "method returns a new mask and does not mutate the original value" ]
308
352
pub fn from_bitmask ( bitmask : u64 ) -> Self {
309
- Self ( mask_impl :: Mask :: from_bitmask_integer ( bitmask ) )
353
+ Self ( bitmask . select ( Simd :: splat ( T :: TRUE ) , Simd :: splat ( T :: FALSE ) ) )
310
354
}
311
355
312
356
/// Finds the index of the first set element.
@@ -450,7 +494,8 @@ where
450
494
type Output = Self ;
451
495
#[ inline]
452
496
fn bitand ( self , rhs : Self ) -> Self {
453
- Self ( self . 0 & rhs. 0 )
497
+ // Safety: `self` is an integer vector
498
+ unsafe { Self ( core:: intrinsics:: simd:: simd_and ( self . 0 , rhs. 0 ) ) }
454
499
}
455
500
}
456
501
@@ -486,7 +531,8 @@ where
486
531
type Output = Self ;
487
532
#[ inline]
488
533
fn bitor ( self , rhs : Self ) -> Self {
489
- Self ( self . 0 | rhs. 0 )
534
+ // Safety: `self` is an integer vector
535
+ unsafe { Self ( core:: intrinsics:: simd:: simd_or ( self . 0 , rhs. 0 ) ) }
490
536
}
491
537
}
492
538
@@ -522,7 +568,8 @@ where
522
568
type Output = Self ;
523
569
#[ inline]
524
570
fn bitxor ( self , rhs : Self ) -> Self :: Output {
525
- Self ( self . 0 ^ rhs. 0 )
571
+ // Safety: `self` is an integer vector
572
+ unsafe { Self ( core:: intrinsics:: simd:: simd_xor ( self . 0 , rhs. 0 ) ) }
526
573
}
527
574
}
528
575
@@ -558,7 +605,7 @@ where
558
605
type Output = Mask < T , N > ;
559
606
#[ inline]
560
607
fn not ( self ) -> Self :: Output {
561
- Self ( ! self . 0 )
608
+ Self :: splat ( true ) ^ self
562
609
}
563
610
}
564
611
@@ -569,7 +616,7 @@ where
569
616
{
570
617
#[ inline]
571
618
fn bitand_assign ( & mut self , rhs : Self ) {
572
- self . 0 = self . 0 & rhs. 0 ;
619
+ * self = * self & rhs;
573
620
}
574
621
}
575
622
@@ -591,7 +638,7 @@ where
591
638
{
592
639
#[ inline]
593
640
fn bitor_assign ( & mut self , rhs : Self ) {
594
- self . 0 = self . 0 | rhs. 0 ;
641
+ * self = * self | rhs;
595
642
}
596
643
}
597
644
@@ -613,7 +660,7 @@ where
613
660
{
614
661
#[ inline]
615
662
fn bitxor_assign ( & mut self , rhs : Self ) {
616
- self . 0 = self . 0 ^ rhs. 0 ;
663
+ * self = * self ^ rhs;
617
664
}
618
665
}
619
666
0 commit comments