diff --git a/compiler/rustc_mir_build/src/build/matches/simplify.rs b/compiler/rustc_mir_build/src/build/matches/simplify.rs index b4a0c965d6b73..c6298904140c3 100644 --- a/compiler/rustc_mir_build/src/build/matches/simplify.rs +++ b/compiler/rustc_mir_build/src/build/matches/simplify.rs @@ -227,15 +227,18 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { _ => (None, 0), }; if let Some((min, max, sz)) = range { - if let (Some(lo), Some(hi)) = (lo.try_to_bits(sz), hi.try_to_bits(sz)) { - // We want to compare ranges numerically, but the order of the bitwise - // representation of signed integers does not match their numeric order. - // Thus, to correct the ordering, we need to shift the range of signed - // integers to correct the comparison. This is achieved by XORing with a - // bias (see pattern/_match.rs for another pertinent example of this - // pattern). - let (lo, hi) = (lo ^ bias, hi ^ bias); - if lo <= min && (hi > max || hi == max && end == RangeEnd::Included) { + // We want to compare ranges numerically, but the order of the bitwise + // representation of signed integers does not match their numeric order. Thus, + // to correct the ordering, we need to shift the range of signed integers to + // correct the comparison. This is achieved by XORing with a bias (see + // pattern/_match.rs for another pertinent example of this pattern). + // + // Also, for performance, it's important to only do the second `try_to_bits` if + // necessary. + let lo = lo.try_to_bits(sz).unwrap() ^ bias; + if lo <= min { + let hi = hi.try_to_bits(sz).unwrap() ^ bias; + if hi > max || hi == max && end == RangeEnd::Included { // Irrefutable pattern match. return Ok(()); } diff --git a/compiler/rustc_mir_build/src/build/matches/test.rs b/compiler/rustc_mir_build/src/build/matches/test.rs index 3774a39503521..598da80c574af 100644 --- a/compiler/rustc_mir_build/src/build/matches/test.rs +++ b/compiler/rustc_mir_build/src/build/matches/test.rs @@ -632,39 +632,30 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } (&TestKind::Range(test), &PatKind::Range(pat)) => { + use std::cmp::Ordering::*; + if test == pat { self.candidate_without_match_pair(match_pair_index, candidate); return Some(0); } - let no_overlap = (|| { - use rustc_hir::RangeEnd::*; - use std::cmp::Ordering::*; - - let tcx = self.tcx; - - let test_ty = test.lo.ty(); - let lo = compare_const_vals(tcx, test.lo, pat.hi, self.param_env, test_ty)?; - let hi = compare_const_vals(tcx, test.hi, pat.lo, self.param_env, test_ty)?; - - match (test.end, pat.end, lo, hi) { - // pat < test - (_, _, Greater, _) | - (_, Excluded, Equal, _) | - // pat > test - (_, _, _, Less) | - (Excluded, _, _, Equal) => Some(true), - _ => Some(false), - } - })(); - - if let Some(true) = no_overlap { - // Testing range does not overlap with pattern range, - // so the pattern can be matched only if this test fails. + // For performance, it's important to only do the second + // `compare_const_vals` if necessary. + let no_overlap = if matches!( + (compare_const_vals(self.tcx, test.hi, pat.lo, self.param_env)?, test.end), + (Less, _) | (Equal, RangeEnd::Excluded) // test < pat + ) || matches!( + (compare_const_vals(self.tcx, test.lo, pat.hi, self.param_env)?, pat.end), + (Greater, _) | (Equal, RangeEnd::Excluded) // test > pat + ) { Some(1) } else { None - } + }; + + // If the testing range does not overlap with pattern range, + // the pattern can be matched only if this test fails. + no_overlap } (&TestKind::Range(range), &PatKind::Constant { value }) => { @@ -768,15 +759,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ) -> Option { use std::cmp::Ordering::*; - let tcx = self.tcx; - - let a = compare_const_vals(tcx, range.lo, value, self.param_env, range.lo.ty())?; - let b = compare_const_vals(tcx, value, range.hi, self.param_env, range.lo.ty())?; - - match (b, range.end) { - (Less, _) | (Equal, RangeEnd::Included) if a != Greater => Some(true), - _ => Some(false), - } + // For performance, it's important to only do the second + // `compare_const_vals` if necessary. + Some( + matches!(compare_const_vals(self.tcx, range.lo, value, self.param_env)?, Less | Equal) + && matches!( + (compare_const_vals(self.tcx, value, range.hi, self.param_env)?, range.end), + (Less, _) | (Equal, RangeEnd::Included) + ), + ) } fn values_not_contained_in_range( diff --git a/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs b/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs index 26532ae33d0cf..60db98073a3b9 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs @@ -828,14 +828,8 @@ impl<'tcx> Constructor<'tcx> { FloatRange(other_from, other_to, other_end), ) => { match ( - compare_const_vals(pcx.cx.tcx, *self_to, *other_to, pcx.cx.param_env, pcx.ty), - compare_const_vals( - pcx.cx.tcx, - *self_from, - *other_from, - pcx.cx.param_env, - pcx.ty, - ), + compare_const_vals(pcx.cx.tcx, *self_to, *other_to, pcx.cx.param_env), + compare_const_vals(pcx.cx.tcx, *self_from, *other_from, pcx.cx.param_env), ) { (Some(to), Some(from)) => { (from == Ordering::Greater || from == Ordering::Equal) @@ -848,16 +842,7 @@ impl<'tcx> Constructor<'tcx> { (Str(self_val), Str(other_val)) => { // FIXME Once valtrees are available we can directly use the bytes // in the `Str` variant of the valtree for the comparison here. - match compare_const_vals( - pcx.cx.tcx, - *self_val, - *other_val, - pcx.cx.param_env, - pcx.ty, - ) { - Some(comparison) => comparison == Ordering::Equal, - None => false, - } + self_val == other_val } (Slice(self_slice), Slice(other_slice)) => self_slice.is_covered_by(*other_slice), diff --git a/compiler/rustc_mir_build/src/thir/pattern/mod.rs b/compiler/rustc_mir_build/src/thir/pattern/mod.rs index f5d957e30ff09..a13748a2d474a 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/mod.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/mod.rs @@ -15,8 +15,9 @@ use rustc_hir::def::{CtorOf, DefKind, Res}; use rustc_hir::pat_util::EnumerateAndAdjustIterator; use rustc_hir::RangeEnd; use rustc_index::vec::Idx; -use rustc_middle::mir::interpret::{get_slice_bytes, ConstValue}; -use rustc_middle::mir::interpret::{ErrorHandled, LitToConstError, LitToConstInput}; +use rustc_middle::mir::interpret::{ + ConstValue, ErrorHandled, LitToConstError, LitToConstInput, Scalar, +}; use rustc_middle::mir::{self, UserTypeProjection}; use rustc_middle::mir::{BorrowKind, Field, Mutability}; use rustc_middle::thir::{Ascription, BindingMode, FieldPat, LocalVarId, Pat, PatKind, PatRange}; @@ -129,7 +130,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> { ) -> PatKind<'tcx> { assert_eq!(lo.ty(), ty); assert_eq!(hi.ty(), ty); - let cmp = compare_const_vals(self.tcx, lo, hi, self.param_env, ty); + let cmp = compare_const_vals(self.tcx, lo, hi, self.param_env); match (end, cmp) { // `x..y` where `x < y`. // Non-empty because the range includes at least `x`. @@ -753,57 +754,49 @@ pub(crate) fn compare_const_vals<'tcx>( a: mir::ConstantKind<'tcx>, b: mir::ConstantKind<'tcx>, param_env: ty::ParamEnv<'tcx>, - ty: Ty<'tcx>, ) -> Option { - let from_bool = |v: bool| v.then_some(Ordering::Equal); - - let fallback = || from_bool(a == b); - - // Use the fallback if any type differs - if a.ty() != b.ty() || a.ty() != ty { - return fallback(); - } - - if a == b { - return from_bool(true); - } - - let a_bits = a.try_eval_bits(tcx, param_env, ty); - let b_bits = b.try_eval_bits(tcx, param_env, ty); - - if let (Some(a), Some(b)) = (a_bits, b_bits) { - use rustc_apfloat::Float; - return match *ty.kind() { - ty::Float(ty::FloatTy::F32) => { - let l = rustc_apfloat::ieee::Single::from_bits(a); - let r = rustc_apfloat::ieee::Single::from_bits(b); - l.partial_cmp(&r) - } - ty::Float(ty::FloatTy::F64) => { - let l = rustc_apfloat::ieee::Double::from_bits(a); - let r = rustc_apfloat::ieee::Double::from_bits(b); - l.partial_cmp(&r) - } - ty::Int(ity) => { - use rustc_middle::ty::layout::IntegerExt; - let size = rustc_target::abi::Integer::from_int_ty(&tcx, ity).size(); - let a = size.sign_extend(a); - let b = size.sign_extend(b); - Some((a as i128).cmp(&(b as i128))) - } - _ => Some(a.cmp(&b)), - }; - } - - if let ty::Str = ty.kind() && let ( - Some(a_val @ ConstValue::Slice { .. }), - Some(b_val @ ConstValue::Slice { .. }), - ) = (a.try_to_value(tcx), b.try_to_value(tcx)) - { - let a_bytes = get_slice_bytes(&tcx, a_val); - let b_bytes = get_slice_bytes(&tcx, b_val); - return from_bool(a_bytes == b_bytes); + assert_eq!(a.ty(), b.ty()); + + let ty = a.ty(); + + // This code is hot when compiling matches with many ranges. So we + // special-case extraction of evaluated scalars for speed, for types where + // raw data comparisons are appropriate. E.g. `unicode-normalization` has + // many ranges such as '\u{037A}'..='\u{037F}', and chars can be compared + // in this way. + match ty.kind() { + ty::Float(_) | ty::Int(_) => {} // require special handling, see below + _ => match (a, b) { + ( + mir::ConstantKind::Val(ConstValue::Scalar(Scalar::Int(a)), _a_ty), + mir::ConstantKind::Val(ConstValue::Scalar(Scalar::Int(b)), _b_ty), + ) => return Some(a.cmp(&b)), + _ => {} + }, + } + + let a = a.eval_bits(tcx, param_env, ty); + let b = b.eval_bits(tcx, param_env, ty); + + use rustc_apfloat::Float; + match *ty.kind() { + ty::Float(ty::FloatTy::F32) => { + let a = rustc_apfloat::ieee::Single::from_bits(a); + let b = rustc_apfloat::ieee::Single::from_bits(b); + a.partial_cmp(&b) + } + ty::Float(ty::FloatTy::F64) => { + let a = rustc_apfloat::ieee::Double::from_bits(a); + let b = rustc_apfloat::ieee::Double::from_bits(b); + a.partial_cmp(&b) + } + ty::Int(ity) => { + use rustc_middle::ty::layout::IntegerExt; + let size = rustc_target::abi::Integer::from_int_ty(&tcx, ity).size(); + let a = size.sign_extend(a); + let b = size.sign_extend(b); + Some((a as i128).cmp(&(b as i128))) + } + _ => Some(a.cmp(&b)), } - - fallback() }