Skip to content

Commit ad424e6

Browse files
committed
Always fall back to PartialEq when a constant in a pattern is not recursively structural-eq
1 parent 8d00f76 commit ad424e6

File tree

4 files changed

+104
-76
lines changed

4 files changed

+104
-76
lines changed

compiler/rustc_mir_build/src/build/matches/test.rs

+31-5
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
380380
);
381381
}
382382

383-
/// Compare two `&T` values using `<T as std::compare::PartialEq>::eq`
383+
/// Compare two values using `<T as std::compare::PartialEq>::eq`.
384+
/// If the values are already references, just call it directly, otherwise
385+
/// take a reference to the values first and then call it.
384386
fn non_scalar_compare(
385387
&mut self,
386388
block: BasicBlock,
@@ -441,12 +443,36 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
441443
}
442444
}
443445

444-
let ty::Ref(_, deref_ty, _) = *ty.kind() else {
445-
bug!("non_scalar_compare called on non-reference type: {}", ty);
446-
};
446+
match *ty.kind() {
447+
ty::Ref(_, deref_ty, _) => ty = deref_ty,
448+
_ => {
449+
// non_scalar_compare called on non-reference type
450+
let temp = self.temp(ty, source_info.span);
451+
self.cfg.push_assign(block, source_info, temp, Rvalue::Use(expect));
452+
let ref_ty = self.tcx.mk_imm_ref(self.tcx.lifetimes.re_erased, ty);
453+
let ref_temp = self.temp(ref_ty, source_info.span);
454+
455+
self.cfg.push_assign(
456+
block,
457+
source_info,
458+
ref_temp,
459+
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, temp),
460+
);
461+
expect = Operand::Move(ref_temp);
462+
463+
let ref_temp = self.temp(ref_ty, source_info.span);
464+
self.cfg.push_assign(
465+
block,
466+
source_info,
467+
ref_temp,
468+
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, val),
469+
);
470+
val = ref_temp;
471+
}
472+
}
447473

448474
let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, Some(source_info.span));
449-
let method = trait_method(self.tcx, eq_def_id, sym::eq, [deref_ty, deref_ty]);
475+
let method = trait_method(self.tcx, eq_def_id, sym::eq, [ty, ty]);
450476

451477
let bool_ty = self.tcx.types.bool;
452478
let eq_result = self.temp(bool_ty, source_info.span);

compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs

+20-31
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,13 @@ struct ConstToPat<'tcx> {
6262
treat_byte_string_as_slice: bool,
6363
}
6464

65-
mod fallback_to_const_ref {
66-
#[derive(Debug)]
67-
/// This error type signals that we encountered a non-struct-eq situation behind a reference.
68-
/// We bubble this up in order to get back to the reference destructuring and make that emit
69-
/// a const pattern instead of a deref pattern. This allows us to simply call `PartialEq::eq`
70-
/// on such patterns (since that function takes a reference) and not have to jump through any
71-
/// hoops to get a reference to the value.
72-
pub(super) struct FallbackToConstRef(());
73-
74-
pub(super) fn fallback_to_const_ref(c2p: &super::ConstToPat<'_>) -> FallbackToConstRef {
75-
assert!(c2p.behind_reference.get());
76-
FallbackToConstRef(())
77-
}
78-
}
79-
use fallback_to_const_ref::{fallback_to_const_ref, FallbackToConstRef};
65+
/// This error type signals that we encountered a non-struct-eq situation.
66+
/// We bubble this up in order to get back to the reference destructuring and make that emit
67+
/// a const pattern instead of a deref pattern. This allows us to simply call `PartialEq::eq`
68+
/// on such patterns (since that function takes a reference) and not have to jump through any
69+
/// hoops to get a reference to the value.
70+
#[derive(Debug)]
71+
struct FallbackToConstRef;
8072

8173
impl<'tcx> ConstToPat<'tcx> {
8274
fn new(
@@ -236,13 +228,13 @@ impl<'tcx> ConstToPat<'tcx> {
236228

237229
let kind = match cv.ty().kind() {
238230
ty::Float(_) => {
239-
tcx.emit_spanned_lint(
240-
lint::builtin::ILLEGAL_FLOATING_POINT_LITERAL_PATTERN,
241-
id,
242-
span,
243-
FloatPattern,
244-
);
245-
PatKind::Constant { value: cv }
231+
tcx.emit_spanned_lint(
232+
lint::builtin::ILLEGAL_FLOATING_POINT_LITERAL_PATTERN,
233+
id,
234+
span,
235+
FloatPattern,
236+
);
237+
return Err(FallbackToConstRef);
246238
}
247239
ty::Adt(adt_def, _) if adt_def.is_union() => {
248240
// Matching on union fields is unsafe, we can't hide it in constants
@@ -289,7 +281,7 @@ impl<'tcx> ConstToPat<'tcx> {
289281
// Since we are behind a reference, we can just bubble the error up so we get a
290282
// constant at reference type, making it easy to let the fallback call
291283
// `PartialEq::eq` on it.
292-
return Err(fallback_to_const_ref(self));
284+
return Err(FallbackToConstRef);
293285
}
294286
ty::Adt(adt_def, _) if !self.type_marked_structural(cv.ty()) => {
295287
debug!(
@@ -411,7 +403,7 @@ impl<'tcx> ConstToPat<'tcx> {
411403
IndirectStructuralMatch { non_sm_ty: *pointee_ty },
412404
);
413405
}
414-
PatKind::Constant { value: cv }
406+
return Err(FallbackToConstRef);
415407
} else {
416408
if !self.saw_const_match_error.get() {
417409
self.saw_const_match_error.set(true);
@@ -439,20 +431,17 @@ impl<'tcx> ConstToPat<'tcx> {
439431
// we fall back to a const pattern. If we do not do this, we may end up with
440432
// a !structural-match constant that is not of reference type, which makes it
441433
// very hard to invoke `PartialEq::eq` on it as a fallback.
442-
let val = match self.recur(tcx.deref_mir_constant(self.param_env.and(cv)), false) {
443-
Ok(subpattern) => PatKind::Deref { subpattern },
444-
Err(_) => PatKind::Constant { value: cv },
445-
};
434+
let subpattern = self.recur(tcx.deref_mir_constant(self.param_env.and(cv)), false)?;
446435
self.behind_reference.set(old);
447-
val
436+
PatKind::Deref { subpattern }
448437
}
449438
}
450439
},
451440
ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::FnDef(..) => {
452441
PatKind::Constant { value: cv }
453442
}
454443
ty::RawPtr(pointee) if pointee.ty.is_sized(tcx, param_env) => {
455-
PatKind::Constant { value: cv }
444+
return Err(FallbackToConstRef);
456445
}
457446
// FIXME: these can have very surprising behaviour where optimization levels or other
458447
// compilation choices change the runtime behaviour of the match.
@@ -469,7 +458,7 @@ impl<'tcx> ConstToPat<'tcx> {
469458
PointerPattern
470459
);
471460
}
472-
PatKind::Constant { value: cv }
461+
return Err(FallbackToConstRef);
473462
}
474463
_ => {
475464
self.saw_const_match_error.set(true);

tests/ui/pattern/usefulness/consts-opaque.rs

+13-5
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ const BAR: Bar = Bar;
2020
#[derive(PartialEq)]
2121
enum Baz {
2222
Baz1,
23-
Baz2
23+
Baz2,
2424
}
2525
impl Eq for Baz {}
2626
const BAZ: Baz = Baz::Baz1;
2727

28+
#[rustfmt::skip]
2829
fn main() {
2930
match FOO {
3031
FOO => {}
@@ -124,8 +125,16 @@ fn main() {
124125

125126
match WRAPQUUX {
126127
Wrap(_) => {}
127-
WRAPQUUX => {} // detected unreachable because we do inspect the `Wrap` layer
128-
//~^ ERROR unreachable pattern
128+
WRAPQUUX => {}
129+
}
130+
131+
match WRAPQUUX {
132+
Wrap(_) => {}
133+
}
134+
135+
match WRAPQUUX {
136+
//~^ ERROR: non-exhaustive patterns: `Wrap(_)` not covered
137+
WRAPQUUX => {}
129138
}
130139

131140
#[derive(PartialEq, Eq)]
@@ -138,8 +147,7 @@ fn main() {
138147
match WHOKNOWSQUUX {
139148
WHOKNOWSQUUX => {}
140149
WhoKnows::Yay(_) => {}
141-
WHOKNOWSQUUX => {} // detected unreachable because we do inspect the `WhoKnows` layer
142-
//~^ ERROR unreachable pattern
150+
WHOKNOWSQUUX => {}
143151
WhoKnows::Nope => {}
144152
}
145153
}

0 commit comments

Comments
 (0)