Skip to content

Commit 3a142f0

Browse files
committed
handle the array case consistently for simd_select_bitmask and simd_bitmask
also move the two next to each other
1 parent 8e3dbb2 commit 3a142f0

File tree

1 file changed

+45
-43
lines changed

1 file changed

+45
-43
lines changed

src/shims/intrinsics/simd.rs

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -405,37 +405,35 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
405405
this.write_immediate(*val, &dest)?;
406406
}
407407
}
408+
// Variant of `select` that takes a bitmask rather than a "vector of bool".
408409
"select_bitmask" => {
409410
let [mask, yes, no] = check_arg_count(args)?;
410411
let (yes, yes_len) = this.operand_to_simd(yes)?;
411412
let (no, no_len) = this.operand_to_simd(no)?;
412413
let (dest, dest_len) = this.place_to_simd(dest)?;
413414
let bitmask_len = dest_len.max(8);
414415

416+
// The mask must be an integer or an array.
417+
assert!(
418+
mask.layout.ty.is_integral()
419+
|| matches!(mask.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
420+
);
415421
assert!(bitmask_len <= 64);
416422
assert_eq!(bitmask_len, mask.layout.size.bits());
417423
assert_eq!(dest_len, yes_len);
418424
assert_eq!(dest_len, no_len);
419425
let dest_len = u32::try_from(dest_len).unwrap();
420426
let bitmask_len = u32::try_from(bitmask_len).unwrap();
421427

422-
// The mask can be a single integer or an array.
423-
let mask: u64 = match mask.layout.ty.kind() {
424-
ty::Int(..) | ty::Uint(..) =>
425-
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap(),
426-
ty::Array(elem, _) if matches!(elem.kind(), ty::Uint(ty::UintTy::U8)) => {
427-
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
428-
let mask = mask.transmute(mask_ty, this)?;
429-
this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap()
430-
}
431-
_ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
432-
};
428+
// To read the mask, we transmute it to an integer.
429+
// That does the right thing wrt endianess.
430+
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
431+
let mask = mask.transmute(mask_ty, this)?;
432+
let mask: u64 = this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap();
433433

434434
for i in 0..dest_len {
435-
let mask = mask
436-
& 1u64
437-
.checked_shl(simd_bitmask_index(i, dest_len, this.data_layout().endian))
438-
.unwrap();
435+
let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
436+
let mask = mask & 1u64.checked_shl(bit_i).unwrap();
439437
let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
440438
let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
441439
let dest = this.project_index(&dest, i.into())?;
@@ -445,6 +443,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
445443
}
446444
for i in dest_len..bitmask_len {
447445
// If the mask is "padded", ensure that padding is all-zero.
446+
// This deliberately does not use `simd_bitmask_index`; these bits are outside
447+
// the bitmask. It does not matter in which order we check them.
448448
let mask = mask & 1u64.checked_shl(i).unwrap();
449449
if mask != 0 {
450450
throw_ub_format!(
@@ -453,6 +453,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
453453
}
454454
}
455455
}
456+
// Converts a "vector of bool" into a bitmask.
457+
"bitmask" => {
458+
let [op] = check_arg_count(args)?;
459+
let (op, op_len) = this.operand_to_simd(op)?;
460+
let bitmask_len = op_len.max(8);
461+
462+
// Returns either an unsigned integer or array of `u8`.
463+
assert!(
464+
dest.layout.ty.is_integral()
465+
|| matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
466+
);
467+
assert!(bitmask_len <= 64);
468+
assert_eq!(bitmask_len, dest.layout.size.bits());
469+
let op_len = u32::try_from(op_len).unwrap();
470+
471+
let mut res = 0u64;
472+
for i in 0..op_len {
473+
let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
474+
if simd_element_to_bool(op)? {
475+
res |= 1u64
476+
.checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
477+
.unwrap();
478+
}
479+
}
480+
// We have to change the type of the place to be able to write `res` into it. This
481+
// transmutes the integer to an array, which does the right thing wrt endianess.
482+
let dest =
483+
dest.transmute(this.machine.layouts.uint(dest.layout.size).unwrap(), this)?;
484+
this.write_int(res, &dest)?;
485+
}
456486
"cast" | "as" | "cast_ptr" | "expose_addr" | "from_exposed_addr" => {
457487
let [op] = check_arg_count(args)?;
458488
let (op, op_len) = this.operand_to_simd(op)?;
@@ -635,34 +665,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
635665
}
636666
}
637667
}
638-
"bitmask" => {
639-
let [op] = check_arg_count(args)?;
640-
let (op, op_len) = this.operand_to_simd(op)?;
641-
let bitmask_len = op_len.max(8);
642-
643-
// Returns either an unsigned integer or array of `u8`.
644-
assert!(
645-
dest.layout.ty.is_integral()
646-
|| matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
647-
);
648-
assert!(bitmask_len <= 64);
649-
assert_eq!(bitmask_len, dest.layout.size.bits());
650-
let op_len = u32::try_from(op_len).unwrap();
651-
652-
let mut res = 0u64;
653-
for i in 0..op_len {
654-
let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
655-
if simd_element_to_bool(op)? {
656-
res |= 1u64
657-
.checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
658-
.unwrap();
659-
}
660-
}
661-
// We have to force the place type to be an int so that we can write `res` into it.
662-
let mut dest = this.force_allocation(dest)?;
663-
dest.layout = this.machine.layouts.uint(dest.layout.size).unwrap();
664-
this.write_int(res, &dest)?;
665-
}
666668

667669
name => throw_unsup_format!("unimplemented intrinsic: `simd_{name}`"),
668670
}

0 commit comments

Comments
 (0)