diff --git a/compiler/rustc_mir/src/transform/early_otherwise_branch.rs b/compiler/rustc_mir/src/transform/early_otherwise_branch.rs index f97dcf4852df4..fb11006274379 100644 --- a/compiler/rustc_mir/src/transform/early_otherwise_branch.rs +++ b/compiler/rustc_mir/src/transform/early_otherwise_branch.rs @@ -34,12 +34,12 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { let bbs_with_switch = body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator())); + let helper = Helper { body, tcx }; let opts_to_apply: Vec> = bbs_with_switch .flat_map(|(bb_idx, bb)| { let switch = bb.terminator(); - let helper = Helper { body, tcx }; - let infos = helper.go(bb, switch)?; - Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx }) + let info = helper.go(bb, switch)?; + Some(OptimizationToApply { info, basic_block_first_switch: bb_idx }) }) .collect(); @@ -48,6 +48,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { for opt_to_apply in opts_to_apply { trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply); + let first_switch_info = opt_to_apply.info.first_switch_info; + let second_switch_info = &opt_to_apply.info.second_switch_infos[0]; let statements_before = body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len(); let end_of_block_location = Location { @@ -58,8 +60,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { let mut patch = MirPatch::new(body); // create temp to store second discriminant in - let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty; - let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span; + let discr_type = second_switch_info.discr_ty; + let discr_span = second_switch_info.discr_source_info.span; let second_discriminant_temp = patch.new_temp(discr_type, discr_span); patch.add_statement( @@ -68,8 +70,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { ); // create assignment of discriminant - let place_of_adt_to_get_discriminant_of = - opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read; + let place_of_adt_to_get_discriminant_of = second_switch_info.place_of_adt_discr_read; patch.add_assign( end_of_block_location, Place::from(second_discriminant_temp), @@ -83,8 +84,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp)); // create NotEqual comparison between the two discriminants - let first_descriminant_place = - opt_to_apply.infos[0].first_switch_info.discr_used_in_switch; + let first_descriminant_place = first_switch_info.discr_used_in_switch; let not_equal_rvalue = Rvalue::BinaryOp( not_equal, Operand::Copy(Place::from(second_discriminant_temp)), @@ -96,19 +96,17 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { ); let new_targets = opt_to_apply - .infos + .info + .second_switch_infos .iter() - .flat_map(|x| x.second_switch_info.targets_with_values.iter()) - .cloned(); + .flat_map(|x| x.switch_targets.iter().map(|x| (x.0, x.1))); - let targets = SwitchTargets::new( - new_targets, - opt_to_apply.infos[0].first_switch_info.otherwise_bb, - ); + let targets = + SwitchTargets::new(new_targets, first_switch_info.switch_targets.otherwise()); // new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal let new_switch_data = BasicBlockData::new(Some(Terminator { - source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info, + source_info: second_switch_info.discr_source_info, kind: TerminatorKind::SwitchInt { // the first and second discriminants are equal, so just pick one discr: Operand::Copy(first_descriminant_place), @@ -121,7 +119,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { // switch on the NotEqual. If true, then jump to the `otherwise` case. // If false, then jump to a basic block that then jumps to the correct disciminant case - let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb; + let true_case = first_switch_info.switch_targets.otherwise(); let false_case = new_switch_bb; patch.patch_terminator( opt_to_apply.basic_block_first_switch, @@ -174,10 +172,8 @@ struct Helper<'a, 'tcx> { struct SwitchDiscriminantInfo<'tcx> { /// Type of the discriminant being switched on discr_ty: Ty<'tcx>, - /// The basic block that the otherwise branch points to - otherwise_bb: BasicBlock, - /// Target along with the value being branched from. Otherwise is not included - targets_with_values: Vec<(u128, BasicBlock)>, + /// Targets and values for the switch + switch_targets: SwitchTargets, discr_source_info: SourceInfo, /// The place of the discriminant used in the switch discr_used_in_switch: Place<'tcx>, @@ -189,7 +185,7 @@ struct SwitchDiscriminantInfo<'tcx> { #[derive(Debug)] struct OptimizationToApply<'tcx> { - infos: Vec>, + info: OptimizationInfo<'tcx>, /// Basic block of the original first switch basic_block_first_switch: BasicBlock, } @@ -198,8 +194,8 @@ struct OptimizationToApply<'tcx> { struct OptimizationInfo<'tcx> { /// Info about the first switch and discriminant first_switch_info: SwitchDiscriminantInfo<'tcx>, - /// Info about the second switch and discriminant - second_switch_info: SwitchDiscriminantInfo<'tcx>, + /// Info about all swtiches that are successors of the first switch + second_switch_infos: Vec>, } impl<'a, 'tcx> Helper<'a, 'tcx> { @@ -207,22 +203,19 @@ impl<'a, 'tcx> Helper<'a, 'tcx> { &self, bb: &BasicBlockData<'tcx>, switch: &Terminator<'tcx>, - ) -> Option>> { + ) -> Option> { // try to find the statement that defines the discriminant that is used for the switch let discr = self.find_switch_discriminant_info(bb, switch)?; // go through each target, finding a discriminant read, and a switch - let results = discr.targets_with_values.iter().map(|(value, target)| { - self.find_discriminant_switch_pairing(&discr, target.clone(), value.clone()) - }); - - // if the optimization did not apply for one of the targets, then abort - if results.clone().any(|x| x.is_none()) || results.len() == 0 { - trace!("NO: not all of the targets matched the pattern for optimization"); - return None; + let mut second_switch_infos = vec![]; + + for (value, target) in discr.switch_targets.iter() { + let info = self.find_discriminant_switch_pairing(&discr, target, value)?; + second_switch_infos.push(info); } - Some(results.flatten().collect()) + Some(OptimizationInfo { first_switch_info: discr, second_switch_infos }) } fn find_discriminant_switch_pairing( @@ -230,36 +223,22 @@ impl<'a, 'tcx> Helper<'a, 'tcx> { discr_info: &SwitchDiscriminantInfo<'tcx>, target: BasicBlock, value: u128, - ) -> Option> { + ) -> Option> { let bb = &self.body.basic_blocks()[target]; // find switch let terminator = bb.terminator(); if is_switch(terminator) { let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?; - // the types of the two adts matched on have to be equalfor this optimization to apply - if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on { - trace!( - "NO: types do not match. LHS: {:?}, RHS: {:?}", - discr_info.type_adt_matched_on, - this_bb_discr_info.type_adt_matched_on - ); - return None; - } - - // the otherwise branch of the two switches have to point to the same bb - if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb { + // The otherwise branch of the two switches have to point to the same bb + if discr_info.switch_targets.otherwise() + != this_bb_discr_info.switch_targets.otherwise() + { trace!("NO: otherwise target is not the same"); return None; } - // check that the value being matched on is the same. The - if this_bb_discr_info.targets_with_values.iter().find(|x| x.0 == value).is_none() { - trace!("NO: values being matched on are not the same"); - return None; - } - - // only allow optimization if the left and right of the tuple being matched are the same variants. + // Only allow optimization if the left and right of the tuple being matched are the same variants. // so the following should not optimize // ```rust // let x: Option<()>; @@ -270,8 +249,8 @@ impl<'a, 'tcx> Helper<'a, 'tcx> { // } // ``` // We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch - if !(this_bb_discr_info.targets_with_values.len() == 1 - && this_bb_discr_info.targets_with_values[0].0 == value) + if !(this_bb_discr_info.switch_targets.iter().len() == 1 + && this_bb_discr_info.switch_targets.iter().next()?.0 == value) { trace!( "NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here" @@ -279,13 +258,20 @@ impl<'a, 'tcx> Helper<'a, 'tcx> { return None; } + // The types of the two adts matched on have to be equal for this optimization to apply + if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on { + trace!( + "NO: types do not match. LHS: {:?}, RHS: {:?}", + discr_info.type_adt_matched_on, + this_bb_discr_info.type_adt_matched_on + ); + return None; + } + // if we reach this point, the optimization applies, and we should be able to optimize this case // store the info that is needed to apply the optimization - Some(OptimizationInfo { - first_switch_info: discr_info.clone(), - second_switch_info: this_bb_discr_info, - }) + Some(this_bb_discr_info) } else { None } @@ -302,9 +288,7 @@ impl<'a, 'tcx> Helper<'a, 'tcx> { // the declaration of the discriminant read. Place of this read is being used in the switch let discr_decl = &self.body.local_decls()[discr_local]; let discr_ty = discr_decl.ty; - // the otherwise target lies as the last element - let otherwise_bb = targets.otherwise(); - let targets_with_values = targets.iter().collect(); + let targets_with_values = targets.clone(); // find the place of the adt where the discriminant is being read from // assume this is the last statement of the block @@ -320,8 +304,7 @@ impl<'a, 'tcx> Helper<'a, 'tcx> { Some(SwitchDiscriminantInfo { discr_used_in_switch: discr.place()?, discr_ty, - otherwise_bb, - targets_with_values, + switch_targets: targets_with_values, discr_source_info: discr_decl.source_info, place_of_adt_discr_read, type_adt_matched_on,