diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp index fbf4042cbdddd..90f87874b8060 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #define DEBUG_TYPE "sil-combine" + #include "SILCombiner.h" #include "swift/Basic/STLExtras.h" #include "swift/SIL/DebugUtils.h" @@ -18,16 +19,17 @@ #include "swift/SIL/InstructionUtils.h" #include "swift/SIL/PatternMatch.h" #include "swift/SIL/Projection.h" +#include "swift/SIL/SILBitfield.h" #include "swift/SIL/SILBuilder.h" +#include "swift/SIL/SILInstruction.h" #include "swift/SIL/SILVisitor.h" -#include "swift/SIL/SILBitfield.h" #include "swift/SILOptimizer/Analysis/ARCAnalysis.h" #include "swift/SILOptimizer/Analysis/AliasAnalysis.h" #include "swift/SILOptimizer/Analysis/ValueTracking.h" +#include "swift/SILOptimizer/Utils/BasicBlockOptUtils.h" #include "swift/SILOptimizer/Utils/CFGOptUtils.h" #include "swift/SILOptimizer/Utils/Devirtualize.h" #include "swift/SILOptimizer/Utils/InstOptUtils.h" -#include "swift/SILOptimizer/Utils/BasicBlockOptUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -1764,10 +1766,22 @@ SILInstruction *SILCombiner::visitUncheckedTakeEnumDataAddrInst( auto *svi = cast<SingleValueInstruction>(user); SILValue newValue; if (auto *oldLoad = dyn_cast<LoadInst>(svi)) { - auto newLoad = Builder.emitLoadValueOperation( - loc, enumAddr, oldLoad->getOwnershipQualifier()); - newValue = - Builder.createUncheckedEnumData(loc, newLoad, enumElt, payloadType); + // If the old load is trivial and our enum addr is non-trivial, we need to + // use a load_borrow here. We know that the unchecked_enum_data will + // produce a trivial value meaning that we can just do a + // load_borrow/immediately end the lifetime here. + if (oldLoad->getOwnershipQualifier() == LoadOwnershipQualifier::Trivial && + !enumAddr->getType().isTrivial(Builder.getFunction())) { + Builder.emitScopedBorrowOperation(loc, enumAddr, [&](SILValue newLoad) { + newValue = Builder.createUncheckedEnumData(loc, newLoad, enumElt, + payloadType); + }); + } else { + auto newLoad = Builder.emitLoadValueOperation( + loc, enumAddr, oldLoad->getOwnershipQualifier()); + newValue = + Builder.createUncheckedEnumData(loc, newLoad, enumElt, payloadType); + } } else if (auto *lbi = cast<LoadBorrowInst>(svi)) { auto newLoad = Builder.emitLoadBorrowOperation(loc, enumAddr); for (auto ui = lbi->consuming_use_begin(), ue = lbi->consuming_use_end(); diff --git a/test/SILOptimizer/sil_combine_ossa.sil b/test/SILOptimizer/sil_combine_ossa.sil index d73c40c80700a..c7aa651b7cbf9 100644 --- a/test/SILOptimizer/sil_combine_ossa.sil +++ b/test/SILOptimizer/sil_combine_ossa.sil @@ -41,6 +41,12 @@ enum AddressOnlyEnum { case AddressOnly(Any) } +enum NonTrivialLoadableEnum { +case noPayload +case trivialPayload(Builtin.Int32) +case nonTrivialPayload(Builtin.NativeObject) +} + protocol FakeProtocol { func requirement() } @@ -4800,3 +4806,21 @@ bb4: destroy_value %cast : $Builtin.NativeObject return %2 : $UInt } + +// CHECK-LABEL: sil [ossa] @unchecked_take_enum_data_addr_promotion_trivialpayload_nontrivial_enum : $@convention(thin) (@inout NonTrivialLoadableEnum) -> Builtin.Int32 { +// CHECK-NOT: unchecked_take_enum_data_addr +// CHECK: unchecked_enum_data +// CHECK-NOT: unchecked_take_enum_data_addr +// CHECK: } // end sil function 'unchecked_take_enum_data_addr_promotion_trivialpayload_nontrivial_enum' +sil [ossa] @unchecked_take_enum_data_addr_promotion_trivialpayload_nontrivial_enum : $@convention(thin) (@inout NonTrivialLoadableEnum) -> Builtin.Int32 { +bb0(%0 : $*NonTrivialLoadableEnum): + switch_enum_addr %0 : $*NonTrivialLoadableEnum, case #NonTrivialLoadableEnum.trivialPayload!enumelt: bb1, default bb2 + +bb1: + %2 = unchecked_take_enum_data_addr %0 : $*NonTrivialLoadableEnum, #NonTrivialLoadableEnum.trivialPayload!enumelt + %3 = load [trivial] %2 : $*Builtin.Int32 + return %3 : $Builtin.Int32 + +bb2: + unreachable +}