From 8a0d8dde44b3b413488aa69caebc9cc3fb0f7edf Mon Sep 17 00:00:00 2001 From: Bastian Kersting Date: Tue, 1 Jul 2025 08:59:52 +0000 Subject: [PATCH] Make the enum check work for negative discriminants The discriminant check was not working correctly for negative numbers. This change fixes that by masking out the relevant bits correctly. --- .../rustc_mir_transform/src/check_enums.rs | 30 +++++++++-- tests/ui/mir/enum/negative_discr_break.rs | 14 +++++ tests/ui/mir/enum/negative_discr_ok.rs | 53 +++++++++++++++++++ 3 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 tests/ui/mir/enum/negative_discr_break.rs create mode 100644 tests/ui/mir/enum/negative_discr_ok.rs diff --git a/compiler/rustc_mir_transform/src/check_enums.rs b/compiler/rustc_mir_transform/src/check_enums.rs index 240da87ab278b..fae984b4936f2 100644 --- a/compiler/rustc_mir_transform/src/check_enums.rs +++ b/compiler/rustc_mir_transform/src/check_enums.rs @@ -120,6 +120,7 @@ enum EnumCheckType<'tcx> { }, } +#[derive(Debug, Copy, Clone)] struct TyAndSize<'tcx> { pub ty: Ty<'tcx>, pub size: Size, @@ -337,7 +338,7 @@ fn insert_direct_enum_check<'tcx>( let invalid_discr_block_data = BasicBlockData::new(None, false); let invalid_discr_block = basic_blocks.push(invalid_discr_block_data); let block_data = &mut basic_blocks[current_block]; - let discr = insert_discr_cast_to_u128( + let discr_place = insert_discr_cast_to_u128( tcx, local_decls, block_data, @@ -348,13 +349,34 @@ fn insert_direct_enum_check<'tcx>( source_info, ); + // Mask out the bits of the discriminant type. + let mask = discr.size.unsigned_int_max(); + let discr_masked = + local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into(); + let rvalue = Rvalue::BinaryOp( + BinOp::BitAnd, + Box::new(( + Operand::Copy(discr_place), + Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128), + })), + )), + ); + block_data + .statements + .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue))))); + // Branch based on the discriminant value. block_data.terminator = Some(Terminator { source_info, kind: TerminatorKind::SwitchInt { - discr: Operand::Copy(discr), + discr: Operand::Copy(discr_masked), targets: SwitchTargets::new( - discriminants.into_iter().map(|discr| (discr, new_block)), + discriminants + .into_iter() + .map(|discr_val| (discr.size.truncate(discr_val), new_block)), invalid_discr_block, ), }, @@ -371,7 +393,7 @@ fn insert_direct_enum_check<'tcx>( })), expected: true, target: new_block, - msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))), + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))), // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. // We never want to insert an unwind into unsafe code, because unwinding could // make a failing UB check turn into much worse UB when we start unwinding. diff --git a/tests/ui/mir/enum/negative_discr_break.rs b/tests/ui/mir/enum/negative_discr_break.rs new file mode 100644 index 0000000000000..fa1284f72a079 --- /dev/null +++ b/tests/ui/mir/enum/negative_discr_break.rs @@ -0,0 +1,14 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0xfd + +#[allow(dead_code)] +enum Foo { + A = -2, + B = -1, + C = 1, +} + +fn main() { + let _val: Foo = unsafe { std::mem::transmute::(-3) }; +} diff --git a/tests/ui/mir/enum/negative_discr_ok.rs b/tests/ui/mir/enum/negative_discr_ok.rs new file mode 100644 index 0000000000000..5c15b33fa8423 --- /dev/null +++ b/tests/ui/mir/enum/negative_discr_ok.rs @@ -0,0 +1,53 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[derive(Debug, PartialEq)] +enum Foo { + A = -12121, + B = -2, + C = -1, + D = 1, + E = 2, + F = 12121, +} + +#[allow(dead_code)] +#[repr(i64)] +#[derive(Debug, PartialEq)] +enum Bar { + A = i64::MIN, + B = -2, + C = -1, + D = 1, + E = 2, + F = i64::MAX, +} + +fn main() { + let val: Foo = unsafe { std::mem::transmute::(-12121) }; + assert_eq!(val, Foo::A); + let val: Foo = unsafe { std::mem::transmute::(-2) }; + assert_eq!(val, Foo::B); + let val: Foo = unsafe { std::mem::transmute::(-1) }; + assert_eq!(val, Foo::C); + let val: Foo = unsafe { std::mem::transmute::(1) }; + assert_eq!(val, Foo::D); + let val: Foo = unsafe { std::mem::transmute::(2) }; + assert_eq!(val, Foo::E); + let val: Foo = unsafe { std::mem::transmute::(12121) }; + assert_eq!(val, Foo::F); + + let val: Bar = unsafe { std::mem::transmute::(i64::MIN) }; + assert_eq!(val, Bar::A); + let val: Bar = unsafe { std::mem::transmute::(-2) }; + assert_eq!(val, Bar::B); + let val: Bar = unsafe { std::mem::transmute::(-1) }; + assert_eq!(val, Bar::C); + let val: Bar = unsafe { std::mem::transmute::(1) }; + assert_eq!(val, Bar::D); + let val: Bar = unsafe { std::mem::transmute::(2) }; + assert_eq!(val, Bar::E); + let val: Bar = unsafe { std::mem::transmute::(i64::MAX) }; + assert_eq!(val, Bar::F); +}