diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index e4b1241151e9d..ca35b91e6300a 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -508,6 +508,30 @@ class OpenMPIRBuilder { return allocaInst; } }; + + struct ScanInformation { + /// Dominates the body of the loop before scan directive + llvm::BasicBlock *OMPBeforeScanBlock = nullptr; + /// Dominates the body of the loop after scan directive + llvm::BasicBlock *OMPAfterScanBlock = nullptr; + /// Controls the flow to before or after scan blocks + llvm::BasicBlock *OMPScanDispatch = nullptr; + /// Exit block of loop body + llvm::BasicBlock *OMPScanLoopExit = nullptr; + /// Block before loop body where scan initializations are done + llvm::BasicBlock *OMPScanInit = nullptr; + /// Block after loop body where scan finalizations are done + llvm::BasicBlock *OMPScanFinish = nullptr; + /// If true, it indicates Input phase is lowered; else it indicates + /// ScanPhase is lowered + bool OMPFirstScanLoop = false; + /// Maps the private reduction variable to the pointer of the temporary + /// buffer + llvm::SmallDenseMap ScanBuffPtrs; + /// Loop induction variable used to index the temporary buffers. + /// NOTE(review): IV and Span have no default initializer, unlike the members + /// above - confirm they are always set before use. + llvm::Value *IV; + /// Trip count of the scan loops, set by createCanonicalScanLoops. + llvm::Value *Span; + } ScanInfo; + /// Initialize the internal state, this will put structures types and /// potentially other helpers into the underlying module. Must be called /// before any other method and only once! This internal state includes types @@ -743,6 +767,35 @@ class OpenMPIRBuilder { LoopBodyGenCallbackTy BodyGenCB, Value *TripCount, const Twine &Name = "loop"); + /// Generator for the control flow structure of OpenMP canonical loops if + /// the parent directive has an `inscan` modifier specified. + /// If the `inscan` modifier is specified, the region of the parent is + /// expected to have a `scan` directive. 
Based on the clauses in + /// scan directive, the body of the loop is split into two loops: Input loop + /// and Scan loop. Input loop contains the code generated for input phase of + /// scan and Scan loop contains the code generated for scan phase of scan. + /// + /// \param Loc The insert and source location description. + /// \param BodyGenCB Callback that will generate the loop body code. + /// \param Start Value of the loop counter for the first iteration. + /// \param Stop Loop counter values past this will stop the loop. + /// \param Step Loop counter increment after each iteration; negative + /// means counting down. + /// \param IsSigned Whether Start, Stop and Step are signed integers. + /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop + /// counter. + /// \param ComputeIP Insertion point for instructions computing the trip + /// count. Can be used to ensure the trip count is available + /// at the outermost loop of a loop nest. If not set, + /// defaults to the preheader of the generated loop. + /// \param Name Base name used to derive BB and instruction names. + /// + /// \returns A vector holding the input loop info, then the scan loop info. + Expected> createCanonicalScanLoops( + const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, + Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, + InsertPointTy ComputeIP, const Twine &Name); + /// Calculate the trip count of a canonical loop. /// /// This allows specifying user-defined loop counter values using increment, @@ -811,13 +864,16 @@ class OpenMPIRBuilder { /// at the outermost loop of a loop nest. If not set, /// defaults to the preheader of the generated loop. /// \param Name Base name used to derive BB and instruction names. + /// \param InScan Whether loop has a scan reduction specified. /// /// \returns An object representing the created control flow structure which /// can be used for loop-associated directives. 
- LLVM_ABI Expected createCanonicalLoop( - const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, - Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, - InsertPointTy ComputeIP = {}, const Twine &Name = "loop"); + LLVM_ABI Expected + createCanonicalLoop(const LocationDescription &Loc, + LoopBodyGenCallbackTy BodyGenCB, Value *Start, + Value *Stop, Value *Step, bool IsSigned, + bool InclusiveStop, InsertPointTy ComputeIP = {}, + const Twine &Name = "loop", bool InScan = false); /// Collapse a loop nest into a single loop. /// @@ -1548,6 +1604,35 @@ class OpenMPIRBuilder { ArrayRef ReductionInfos, Function *ReduceFn, AttributeList FuncAttrs); + /// Helper function for createCanonicalScanLoops that emits the input loop + /// through InputLoopGen and then the scan loop through ScanLoopGen. + /// \param InputLoopGen Callback for generating the loop for input phase + /// \param ScanLoopGen Callback for generating the loop for scan phase + /// + /// \return error if any produced, else return success. + Error emitScanBasedDirectiveIR( + llvm::function_ref InputLoopGen, + llvm::function_ref ScanLoopGen); + + /// Creates the basic blocks required for scan reduction. + void createScanBBs(); + + /// Dynamically allocates the buffer needed for scan reduction. + /// \param AllocaIP IP where the possibly-shared buffer pointers are declared. + /// \param ScanVars Scan Variables. + /// \param ScanVarsType Element types of the scan variables. + /// + /// \return error if any produced, else return success. + Error emitScanBasedDirectiveDeclsIR(InsertPointTy AllocaIP, + ArrayRef ScanVars, + ArrayRef ScanVarsType); + + /// Copies the result back to the reduction variable. + /// \param ReductionInfos Array type containing the ReductionOps. + /// + /// \return error if any produced, else return success. + Error emitScanBasedDirectiveFinalsIR( + SmallVector ReductionInfos); + /// This function emits a helper that gathers Reduce lists from the first /// lane of every active warp to lanes in the first warp. 
/// @@ -2631,6 +2716,41 @@ class OpenMPIRBuilder { FinalizeCallbackTy FiniCB, Value *Filter); + /// This function performs the scan reduction of the values updated in + /// the input phase. The reduction logic needs to be emitted between input + /// and scan loop returned by `createCanonicalScanLoops`. The following + /// is the code that is generated, `buffer` and `span` are expected to be + /// populated before executing the generated code. + /// + /// for (int k = 0; k != ceil(log2(span)); ++k) { + /// i=pow(2,k) + /// for (size cnt = last_iter; cnt >= i; --cnt) + /// buffer[cnt] op= buffer[cnt-i]; + /// } + /// \param Loc The insert and source location description. + /// \param ReductionInfos Array type containing the ReductionOps. + /// + /// \returns The insertion position *after* the scan reduction. + InsertPointOrErrorTy emitScanReduction( + const LocationDescription &Loc, + SmallVector ReductionInfos); + + /// This directive splits and directs the control flow to input phase + /// blocks or scan phase blocks based on 1. whether input loop or scan loop + /// is executed, 2. whether exclusive or inclusive scan is used. + /// + /// \param Loc The insert and source location description. + /// \param AllocaIP The IP where the temporary buffer for scan reduction + /// needs to be allocated. + /// \param ScanVars Scan Variables. + /// \param ScanVarsType Element types of the scan variables. + /// \param IsInclusive Whether it is an inclusive or exclusive scan. + /// + /// \returns The insertion position *after* the scan. + InsertPointOrErrorTy createScan(const LocationDescription &Loc, + InsertPointTy AllocaIP, + ArrayRef ScanVars, + ArrayRef ScanVarsType, + bool IsInclusive); /// Generator for '#omp critical' /// /// \param Loc The insert and source location description. 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index ca3d8438654dc..5cdb0406e2b6d 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -59,6 +59,8 @@ #include "llvm/Transforms/Utils/LoopPeel.h" #include "llvm/Transforms/Utils/UnrollLoop.h" +#include +#include #include #include @@ -4011,6 +4013,334 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc, /*Conditional*/ true, /*hasFinalize*/ true); } +/* Emits a call to Callee with Args and marks the call as non-throwing. NOTE(review): defined without `static`, so it has external linkage - confirm that is intended. */ llvm::CallInst *emitNoUnwindRuntimeCall(IRBuilder<> &Builder, + llvm::FunctionCallee Callee, + ArrayRef Args, + const llvm::Twine &Name) { + llvm::CallInst *Call = Builder.CreateCall( + Callee, Args, SmallVector(), Name); + Call->setDoesNotThrow(); + return Call; +} + +// Expects the input basic block to be dominated by BeforeScanBB. +// Once Scan directive is encountered, the code after scan directive should be +// dominated by AfterScanBB. Scan directive splits the code sequence to +// scan and input phase. Based on whether inclusive or exclusive +// clause is used in the scan directive and whether input loop or scan loop +// is lowered, it adds jumps to input and scan phase. First Scan loop is the +// input loop and second is the scan loop. The code generated handles only +// inclusive scans now. +OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan( + const LocationDescription &Loc, InsertPointTy AllocaIP, + ArrayRef ScanVars, ArrayRef ScanVarsType, + bool IsInclusive) { + if (ScanInfo.OMPFirstScanLoop) { + llvm::Error Err = + emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars, ScanVarsType); + if (Err) + return Err; + } + if (!updateToLocation(Loc)) + return Loc.IP; + + llvm::Value *IV = ScanInfo.IV; + + if (ScanInfo.OMPFirstScanLoop) { + // Emit buffer[i] = red; at the end of the input phase. 
+ for (size_t i = 0; i < ScanVars.size(); i++) { + Value *BuffPtr = ScanInfo.ScanBuffPtrs[ScanVars[i]]; + Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr); + Type *DestTy = ScanVarsType[i]; + Value *Val = Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset"); + Value *Src = Builder.CreateLoad(DestTy, ScanVars[i]); + + Builder.CreateStore(Src, Val); + } + } + Builder.CreateBr(ScanInfo.OMPScanLoopExit); + emitBlock(ScanInfo.OMPScanDispatch, Builder.GetInsertBlock()->getParent()); + + if (!ScanInfo.OMPFirstScanLoop) { + IV = ScanInfo.IV; + // Emit red = buffer[i]; at the entrance to the scan phase. + // TODO: for an exclusive scan, emit red = buffer[i-1] instead. + for (size_t i = 0; i < ScanVars.size(); i++) { + Value *BuffPtr = ScanInfo.ScanBuffPtrs[ScanVars[i]]; + Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr); + Type *DestTy = ScanVarsType[i]; + Value *SrcPtr = + Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset"); + Value *Src = Builder.CreateLoad(DestTy, SrcPtr); + Builder.CreateStore(Src, ScanVars[i]); + } + } + + // TODO: The condition is always true; use CreateBr and remove dead blocks. + llvm::Value *CmpI = Builder.getInt1(true); + if (ScanInfo.OMPFirstScanLoop == IsInclusive) { + Builder.CreateCondBr(CmpI, ScanInfo.OMPBeforeScanBlock, + ScanInfo.OMPAfterScanBlock); + } else { + Builder.CreateCondBr(CmpI, ScanInfo.OMPAfterScanBlock, + ScanInfo.OMPBeforeScanBlock); + } + emitBlock(ScanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent()); + Builder.SetInsertPoint(ScanInfo.OMPAfterScanBlock); + return Builder.saveIP(); +} + +Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR( + InsertPointTy AllocaIP, ArrayRef ScanVars, + ArrayRef ScanVarsType) { + + Builder.restoreIP(AllocaIP); + // Create one buffer-pointer slot per scan variable at the alloca IP. 
+ for (size_t i = 0; i < ScanVars.size(); i++) { + llvm::Value *BuffPtr = + Builder.CreateAlloca(Builder.getPtrTy(), nullptr, "vla"); + ScanInfo.ScanBuffPtrs[ScanVars[i]] = BuffPtr; + } + + // Allocate Span + 1 elements per scan variable inside a masked region. + // NOTE(review): Builder.getInt32(1) assumes ScanInfo.Span is i32 - confirm. + auto BodyGenCB = [&](InsertPointTy AllocaIP, + InsertPointTy CodeGenIP) -> Error { + Builder.restoreIP(CodeGenIP); + Value *AllocSpan = Builder.CreateAdd(ScanInfo.Span, Builder.getInt32(1)); + for (size_t i = 0; i < ScanVars.size(); i++) { + Type *IntPtrTy = Builder.getInt32Ty(); + Constant *Allocsize = ConstantExpr::getSizeOf(ScanVarsType[i]); + Allocsize = ConstantExpr::getTruncOrBitCast(Allocsize, IntPtrTy); + Value *Buff = Builder.CreateMalloc(IntPtrTy, ScanVarsType[i], Allocsize, + AllocSpan, nullptr, "arr"); + Builder.CreateStore(Buff, ScanInfo.ScanBuffPtrs[ScanVars[i]]); + } + return Error::success(); + }; + // TODO: Perform finalization actions for variables. This has to be + // called for variables which have destructors/finalizers. + auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); }; + + Builder.SetInsertPoint(ScanInfo.OMPScanInit->getTerminator()); + llvm::Value *FilterVal = Builder.getInt32(0); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = + createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal); + + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + BasicBlock *InputBB = Builder.GetInsertBlock(); + if (InputBB->getTerminator()) + Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator()); + AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier); + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + + return Error::success(); +} + +Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR( + SmallVector ReductionInfos) { + auto BodyGenCB = [&](InsertPointTy AllocaIP, + InsertPointTy CodeGenIP) -> Error { + Builder.restoreIP(CodeGenIP); + for (ReductionInfo RedInfo : ReductionInfos) { + Value *PrivateVar = 
RedInfo.PrivateVariable; + Value *OrigVar = RedInfo.Variable; + Value *BuffPtr = ScanInfo.ScanBuffPtrs[PrivateVar]; + Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr); + + Type *SrcTy = RedInfo.ElementType; + Value *Val = + Builder.CreateInBoundsGEP(SrcTy, Buff, ScanInfo.Span, "arrayOffset"); + Value *Src = Builder.CreateLoad(SrcTy, Val); + + Builder.CreateStore(Src, OrigVar); + Builder.CreateFree(Buff); + } + return Error::success(); + }; + // TODO: Perform finalization actions for variables. This has to be + // called for variables which have destructors/finalizers. + auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); }; + + if (ScanInfo.OMPScanFinish->getTerminator()) + Builder.SetInsertPoint(ScanInfo.OMPScanFinish->getTerminator()); + else + Builder.SetInsertPoint(ScanInfo.OMPScanFinish); + + llvm::Value *FilterVal = Builder.getInt32(0); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = + createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal); + + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + BasicBlock *InputBB = Builder.GetInsertBlock(); + if (InputBB->getTerminator()) + Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator()); + AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier); + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + return Error::success(); +} + +OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction( + const LocationDescription &Loc, + SmallVector ReductionInfos) { + + if (!updateToLocation(Loc)) + return Loc.IP; + auto BodyGenCB = [&](InsertPointTy AllocaIP, + InsertPointTy CodeGenIP) -> Error { + Builder.restoreIP(CodeGenIP); + Function *CurFn = Builder.GetInsertBlock()->getParent(); + // for (int k = 0; k != ceil(log2(n)); ++k) + llvm::BasicBlock *LoopBB = + BasicBlock::Create(CurFn->getContext(), "omp.outer.log.scan.body"); + llvm::BasicBlock *ExitBB = + splitBB(Builder, false, 
"omp.outer.log.scan.exit"); + llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration( + Builder.GetInsertBlock()->getModule(), + (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy()); + llvm::BasicBlock *InputBB = Builder.GetInsertBlock(); + llvm::Value *Arg = + Builder.CreateUIToFP(ScanInfo.Span, Builder.getDoubleTy()); + llvm::Value *LogVal = emitNoUnwindRuntimeCall(Builder, F, Arg, ""); + F = llvm::Intrinsic::getOrInsertDeclaration( + Builder.GetInsertBlock()->getModule(), + (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy()); + LogVal = emitNoUnwindRuntimeCall(Builder, F, LogVal, ""); + LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty()); + llvm::Value *NMin1 = Builder.CreateNUWSub( + ScanInfo.Span, llvm::ConstantInt::get(ScanInfo.Span->getType(), 1)); + Builder.SetInsertPoint(InputBB); + Builder.CreateBr(LoopBB); + emitBlock(LoopBB, CurFn); + Builder.SetInsertPoint(LoopBB); + + PHINode *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2); + // size pow2k = 1; + PHINode *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2); + Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0), + InputBB); + Pow2K->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 1), + InputBB); + // for (size i = n - 1; i >= pow2k; --i) + // tmp[i + 1] op= tmp[i + 1 - pow2k]; + llvm::BasicBlock *InnerLoopBB = + BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.body"); + llvm::BasicBlock *InnerExitBB = + BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.exit"); + llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K); + Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB); + emitBlock(InnerLoopBB, CurFn); + Builder.SetInsertPoint(InnerLoopBB); + auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2); + IVal->addIncoming(NMin1, LoopBB); + for (ReductionInfo RedInfo : ReductionInfos) { + Value *ReductionVal = RedInfo.PrivateVariable; + Value *BuffPtr = ScanInfo.ScanBuffPtrs[ReductionVal]; + Value *Buff = 
Builder.CreateLoad(Builder.getPtrTy(), BuffPtr); + Type *DestTy = RedInfo.ElementType; + Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1)); + Value *LHSPtr = + Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset"); + Value *OffsetIval = Builder.CreateNUWSub(IV, Pow2K); + Value *RHSPtr = + Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIval, "arrayOffset"); + Value *LHS = Builder.CreateLoad(DestTy, LHSPtr); + Value *RHS = Builder.CreateLoad(DestTy, RHSPtr); + llvm::Value *Result; + InsertPointOrErrorTy AfterIP = + RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result); + if (!AfterIP) + return AfterIP.takeError(); + Builder.CreateStore(Result, LHSPtr); + } + llvm::Value *NextIVal = Builder.CreateNUWSub( + IVal, llvm::ConstantInt::get(Builder.getInt32Ty(), 1)); + IVal->addIncoming(NextIVal, Builder.GetInsertBlock()); + CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K); + Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB); + emitBlock(InnerExitBB, CurFn); + llvm::Value *Next = Builder.CreateNUWAdd( + Counter, llvm::ConstantInt::get(Counter->getType(), 1)); + Counter->addIncoming(Next, Builder.GetInsertBlock()); + // pow2k <<= 1; + llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true); + Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock()); + // NOTE(review): the exit test runs only after a full pass, so this + // presumably relies on ceil(log2(Span)) >= 1 (Span > 1) - confirm. + llvm::Value *Cmp = Builder.CreateICmpNE(Next, LogVal); + Builder.CreateCondBr(Cmp, LoopBB, ExitBB); + Builder.SetInsertPoint(ExitBB->getFirstInsertionPt()); + return Error::success(); + }; + + // TODO: Perform finalization actions for variables. This has to be + // called for variables which have destructors/finalizers. 
+ auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); }; + + llvm::Value *FilterVal = Builder.getInt32(0); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = + createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal); + + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier); + + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos); + if (Err) + return Err; + + return AfterIP; +} + +Error OpenMPIRBuilder::emitScanBasedDirectiveIR( + llvm::function_ref InputLoopGen, + llvm::function_ref ScanLoopGen) { + + { + // Emit loop with input phase: + // for (i: 0..num_iters) { + // (input phase); + // buffer[i] = red; + // } + ScanInfo.OMPFirstScanLoop = true; + auto Result = InputLoopGen(); + if (Result) + return Result; + } + { + // Emit loop with scan phase: + // for (i: 0..num_iters) { + // red = buffer[i]; + // (scan phase); + // } + ScanInfo.OMPFirstScanLoop = false; + auto Result = ScanLoopGen(Builder.saveIP()); + if (Result) + return Result; + } + return Error::success(); +} + +void OpenMPIRBuilder::createScanBBs() { + Function *Fun = Builder.GetInsertBlock()->getParent(); + ScanInfo.OMPScanDispatch = + BasicBlock::Create(Fun->getContext(), "omp.inscan.dispatch"); + ScanInfo.OMPAfterScanBlock = + BasicBlock::Create(Fun->getContext(), "omp.after.scan.bb"); + ScanInfo.OMPBeforeScanBlock = + BasicBlock::Create(Fun->getContext(), "omp.before.scan.bb"); + ScanInfo.OMPScanLoopExit = + BasicBlock::Create(Fun->getContext(), "omp.scan.loop.exit"); +} CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton( DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore, BasicBlock *PostInsertBefore, const Twine &Name) { @@ -4108,6 +4438,92 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc, return CL; } +Expected> +OpenMPIRBuilder::createCanonicalScanLoops( + const LocationDescription 
&Loc, LoopBodyGenCallbackTy BodyGenCB, + Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, + InsertPointTy ComputeIP, const Twine &Name) { + LocationDescription ComputeLoc = + ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc; + updateToLocation(ComputeLoc); + + Value *TripCount = calculateCanonicalLoopTripCount( + ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name); + ScanInfo.Span = TripCount; + ScanInfo.OMPScanInit = splitBB(Builder, true, "scan.init"); + Builder.SetInsertPoint(ScanInfo.OMPScanInit); + + auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) { + /// The control flow of a loop body with the following structure: + /// + /// InputBlock + /// | + /// ContinueBlock + /// + /// is transformed to: + /// + /// InputBlock + /// | + /// OMPScanDispatch + /// + /// OMPBeforeScanBlock + /// | + /// OMPScanLoopExit + /// | + /// ContinueBlock + /// + /// OMPBeforeScanBlock dominates the control flow of code generated until + /// scan directive is encountered and OMPAfterScanBlock dominates the + /// control flow of code generated after scan is encountered. The successor + /// of OMPScanDispatch can be OMPBeforeScanBlock or OMPAfterScanBlock based + /// on 1. whether it is in Input phase or Scan phase, 2. whether it is an + /// exclusive or inclusive scan. 
+ Builder.restoreIP(CodeGenIP); + ScanInfo.IV = IV; + createScanBBs(); + BasicBlock *InputBlock = Builder.GetInsertBlock(); + Instruction *Terminator = InputBlock->getTerminator(); + assert(Terminator->getNumSuccessors() == 1); + BasicBlock *ContinueBlock = Terminator->getSuccessor(0); + Terminator->setSuccessor(0, ScanInfo.OMPScanDispatch); + emitBlock(ScanInfo.OMPBeforeScanBlock, + Builder.GetInsertBlock()->getParent()); + Builder.CreateBr(ScanInfo.OMPScanLoopExit); + emitBlock(ScanInfo.OMPScanLoopExit, Builder.GetInsertBlock()->getParent()); + Builder.CreateBr(ContinueBlock); + Builder.SetInsertPoint(ScanInfo.OMPBeforeScanBlock->getFirstInsertionPt()); + return BodyGenCB(Builder.saveIP(), IV); + }; + + SmallVector Result; + const auto &&InputLoopGen = [&]() -> Error { + auto LoopInfo = + createCanonicalLoop(Builder.saveIP(), BodyGen, Start, Stop, Step, + IsSigned, InclusiveStop, ComputeIP, Name, true); + if (!LoopInfo) + return LoopInfo.takeError(); + Result.push_back(*LoopInfo); + Builder.restoreIP((*LoopInfo)->getAfterIP()); + return Error::success(); + }; + const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error { + auto LoopInfo = + createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned, + InclusiveStop, ComputeIP, Name, true); + if (!LoopInfo) + return LoopInfo.takeError(); + Result.push_back(*LoopInfo); + Builder.restoreIP((*LoopInfo)->getAfterIP()); + ScanInfo.OMPScanFinish = Builder.GetInsertBlock(); + return Error::success(); + }; + Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen); + if (Err) { + return Err; + } + return Result; +} + Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount( const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, const Twine &Name) { @@ -4171,7 +4587,7 @@ Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount( Expected OpenMPIRBuilder::createCanonicalLoop( const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, Value *Start, Value 
*Stop, Value *Step, bool IsSigned, bool InclusiveStop, - InsertPointTy ComputeIP, const Twine &Name) { + InsertPointTy ComputeIP, const Twine &Name, bool InScan) { LocationDescription ComputeLoc = ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc; @@ -4182,6 +4598,8 @@ Expected OpenMPIRBuilder::createCanonicalLoop( Builder.restoreIP(CodeGenIP); Value *Span = Builder.CreateMul(IV, Step); Value *IndVar = Builder.CreateAdd(Span, Start); + if (InScan) + ScanInfo.IV = IndVar; return BodyGenCB(Builder.saveIP(), IndVar); }; LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index be98be260c9dc..eba3b4db88afb 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -23,6 +23,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include #include using namespace llvm; @@ -5349,6 +5350,100 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) { EXPECT_TRUE(findGEPZeroOne(ReductionFn->getArg(1), FirstRHS, SecondRHS)); } +/* Test helper: emits an inclusive scan for a single scan variable through OMPBuilder.createScan and moves Builder to the returned insertion point. */ void createScan(llvm::Value *scanVar, llvm::Type *scanType, + OpenMPIRBuilder &OMPBuilder, IRBuilder<> &Builder, + OpenMPIRBuilder::LocationDescription Loc, + OpenMPIRBuilder::InsertPointTy &allocaIP) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + ASSERT_EXPECTED_INIT( + InsertPointTy, retIp, + OMPBuilder.createScan(Loc, allocaIP, {scanVar}, {scanType}, true)); + Builder.restoreIP(retIp); +} + +/* Builds the input and scan loops with createCanonicalScanLoops, emits the scan reduction after the input loop, then counts the calls that ended up in F. */ TEST_F(OpenMPIRBuilderTest, ScanReduction) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + Value *TripCount = F->getArg(0); + Type *LCTy = TripCount->getType(); + Value *StartVal = ConstantInt::get(LCTy, 1); + Value *StopVal = 
ConstantInt::get(LCTy, 100); + Value *Step = ConstantInt::get(LCTy, 1); + auto AllocaIP = Builder.saveIP(); + + llvm::Value *ScanVar = Builder.CreateAlloca(Builder.getFloatTy()); + llvm::Value *OrigVar = Builder.CreateAlloca(Builder.getFloatTy()); + unsigned NumBodiesGenerated = 0; + auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) { + NumBodiesGenerated += 1; + Builder.restoreIP(CodeGenIP); + createScan(ScanVar, Builder.getFloatTy(), OMPBuilder, Builder, Loc, + AllocaIP); + return Error::success(); + }; + SmallVector Loops; + ASSERT_EXPECTED_INIT(SmallVector, loopsVec, + OMPBuilder.createCanonicalScanLoops( + Loc, LoopBodyGenCB, StartVal, StopVal, Step, false, + false, Builder.saveIP(), "scan")); + Loops = loopsVec; + EXPECT_EQ(Loops.size(), 2U); + CanonicalLoopInfo *InputLoop = Loops.front(); + CanonicalLoopInfo *ScanLoop = Loops.back(); + Builder.restoreIP(ScanLoop->getAfterIP()); + InputLoop->assertOK(); + ScanLoop->assertOK(); + + EXPECT_EQ(ScanLoop->getAfter(), Builder.GetInsertBlock()); + EXPECT_EQ(NumBodiesGenerated, 2U); + SmallVector ReductionInfos = { + {Builder.getFloatTy(), OrigVar, ScanVar, + /*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, sumReduction, + /*ReductionGenClang=*/nullptr, sumAtomicReduction}}; + OpenMPIRBuilder::LocationDescription RedLoc({InputLoop->getAfterIP(), DL}); + llvm::BasicBlock *Cont = splitBB(Builder, false, "omp.scan.loop.cont"); + ASSERT_EXPECTED_INIT(InsertPointTy, retIp, + OMPBuilder.emitScanReduction(RedLoc, ReductionInfos)); + Builder.restoreIP(retIp); + Builder.CreateBr(Cont); + Builder.SetInsertPoint(Cont); + unsigned NumMallocs = 0; + unsigned NumFrees = 0; + unsigned NumMasked = 0; + unsigned NumEndMasked = 0; + unsigned NumLog = 0; + unsigned NumCeil = 0; + /* Count the runtime and intrinsic calls emitted into F. */ for (Instruction &I : instructions(F)) { + if (isa(I)) { + CallInst *Call = dyn_cast(&I); + auto Name = Call->getCalledFunction()->getName(); + if (Name.equals_insensitive("malloc")) { + NumMallocs += 1; + } else if 
(Name.equals_insensitive("free")) { + NumFrees += 1; + } else if (Name.equals_insensitive("__kmpc_masked")) { + NumMasked += 1; + } else if (Name.equals_insensitive("__kmpc_end_masked")) { + NumEndMasked += 1; + } else if (Name.equals_insensitive("llvm.log2.f64")) { + NumLog += 1; + } else if (Name.equals_insensitive("llvm.ceil.f64")) { + NumCeil += 1; + } + } + } + EXPECT_EQ(NumBodiesGenerated, 2U); + /* Expected masked regions: buffer allocation, scan reduction, final copy-back. */ EXPECT_EQ(NumMasked, 3U); + EXPECT_EQ(NumEndMasked, 3U); + EXPECT_EQ(NumMallocs, 1U); + EXPECT_EQ(NumFrees, 1U); + EXPECT_EQ(NumLog, 1U); + EXPECT_EQ(NumCeil, 1U); +} + TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M);