diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 6b104708bdb0d..386de1fa51f18 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -503,6 +503,37 @@ class OpenMPIRBuilder { return allocaInst; } }; + + /// Type used throughout for insertion points. + using InsertPointTy = IRBuilder<>::InsertPoint; + + /// Type used to represent an insertion point or an error value. + using InsertPointOrErrorTy = Expected; + + struct ScanInformation { + public: + /// Dominates the body of the loop before scan directive + llvm::BasicBlock *OMPBeforeScanBlock = nullptr; + /// Dominates the body of the loop after scan directive + llvm::BasicBlock *OMPAfterScanBlock = nullptr; + /// Controls the flow to before or after scan blocks + llvm::BasicBlock *OMPScanDispatch = nullptr; + /// Exit block of loop body + llvm::BasicBlock *OMPScanLoopExit = nullptr; + /// Block before loop body where scan initializations are done + llvm::BasicBlock *OMPScanInit = nullptr; + /// Block after loop body where scan finalizations are done + llvm::BasicBlock *OMPScanFinish = nullptr; + /// If true, it indicates Input phase is lowered; else it indicates + /// ScanPhase is lowered + bool OMPFirstScanLoop = false; + /// Maps the private reduction variable to the pointer of the temporary + /// buffer + llvm::SmallDenseMap ScanBuffPtrs; + llvm::Value *IV = nullptr; ///< Induction value seen by the current loop body + llvm::Value *Span = nullptr; ///< Trip count of the scan loops + } ScanInfo; + /// Initialize the internal state, this will put structures types and /// potentially other helpers into the underlying module. Must be called /// before any other method and only once! This internal state includes types @@ -519,12 +550,6 @@ class OpenMPIRBuilder { /// Add attributes known for \p FnID to \p Fn. void addAttributes(omp::RuntimeFunction FnID, Function &Fn); - /// Type used throughout for insertion points. 
- using InsertPointTy = IRBuilder<>::InsertPoint; - - /// Type used to represent an insertion point or an error value. - using InsertPointOrErrorTy = Expected; - /// Get the create a name using the platform specific separators. /// \param Parts parts of the final name that needs separation /// The created name has a first separator between the first and second part @@ -729,6 +754,35 @@ class OpenMPIRBuilder { LoopBodyGenCallbackTy BodyGenCB, Value *TripCount, const Twine &Name = "loop"); + /// Generator for the control flow structure of OpenMP canonical loops if + /// the parent directive has an `inscan` modifier specified. + /// If the `inscan` modifier is specified, the region of the parent is + /// expected to have a `scan` directive. Based on the clauses in + /// scan directive, the body of the loop is split into two loops: Input loop + /// and Scan Loop. Input loop contains the code generated for input phase of + /// scan and Scan loop contains the code generated for scan phase of scan. + /// + /// \param Loc The insert and source location description. + /// \param BodyGenCB Callback generating the loop body; invoked once per loop. + /// \param Start Value of the loop counter for the first iteration. + /// \param Stop Loop counter values past this will stop the loop. + /// \param Step Loop counter increment after each iteration; negative + /// means counting down. + /// \param IsSigned Whether Start, Stop and Step are signed integers. + /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop + /// counter. + /// \param ComputeIP Insertion point for instructions computing the trip + /// count. Can be used to ensure the trip count is available + /// at the outermost loop of a loop nest. If not set, + /// defaults to the preheader of the generated loop. + /// \param Name Base name used to derive BB and instruction names. + /// + /// \returns A vector with the input loop's info first, then the scan loop's. 
+ Expected> createCanonicalScanLoops( + const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, + Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, + InsertPointTy ComputeIP, const Twine &Name); + /// Calculate the trip count of a canonical loop. /// /// This allows specifying user-defined loop counter values using increment, @@ -798,13 +852,16 @@ class OpenMPIRBuilder { /// at the outermost loop of a loop nest. If not set, /// defaults to the preheader of the generated loop. /// \param Name Base name used to derive BB and instruction names. + /// \param InScan Whether the loop is lowered for an `inscan` reduction. /// /// \returns An object representing the created control flow structure which /// can be used for loop-associated directives. - Expected createCanonicalLoop( - const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, - Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, - InsertPointTy ComputeIP = {}, const Twine &Name = "loop"); + Expected + createCanonicalLoop(const LocationDescription &Loc, + LoopBodyGenCallbackTy BodyGenCB, Value *Start, + Value *Stop, Value *Step, bool IsSigned, + bool InclusiveStop, InsertPointTy ComputeIP = {}, + const Twine &Name = "loop", bool InScan = false); /// Collapse a loop nest into a single loop. /// @@ -1532,6 +1589,45 @@ class OpenMPIRBuilder { ArrayRef ReductionInfos, Function *ReduceFn, AttributeList FuncAttrs); + /// Creates a call to \p Callee and marks it as non-throwing. + /// \param Callee Function Declaration Value + /// \param Args Arguments passed to the call + /// \param Name Name of the call instruction; may be empty. + /// + /// \return The Runtime call instruction created. 
+ llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee Callee, + ArrayRef Args, + const llvm::Twine &Name); + + /// Helper function for createCanonicalScanLoops: runs \p InputLoopGen to + /// create the input loop, then \p ScanLoopGen to create the scan loop. + /// \param InputLoopGen Callback for generating the loop for input phase + /// \param ScanLoopGen Callback for generating the loop for scan phase + /// + /// \return error if any produced, else return success. + Error emitScanBasedDirectiveIR( + llvm::function_ref InputLoopGen, + llvm::function_ref ScanLoopGen); + + /// Creates the basic blocks required for scan reduction. + void createScanBBs(); + + /// Dynamically allocates the buffer needed for scan reduction. + /// \param AllocaIP The IP where possibly-shared pointer of buffer needs to be + /// declared. \param ScanVars Scan Variables. + /// \param ScanVarsType Element types of the buffers, one per scan variable. + /// \return error if any produced, else return success. + Error emitScanBasedDirectiveDeclsIR(InsertPointTy AllocaIP, + ArrayRef ScanVars, + ArrayRef ScanVarsType); + + /// Copies the result back to the reduction variable and frees the buffer. + /// \param ReductionInfos Array type containing the ReductionOps. + /// + /// \return error if any produced, else return success. + Error emitScanBasedDirectiveFinalsIR( + SmallVector ReductionInfos); + /// This function emits a helper that gathers Reduce lists from the first /// lane of every active warp to lanes in the first warp. /// @@ -2179,7 +2275,6 @@ class OpenMPIRBuilder { // block, if possible, or else at the end of the function. Also add a branch // from current block to BB if current block does not have a terminator. 
void emitBlock(BasicBlock *BB, Function *CurFn, bool IsFinished = false); - /// Emits code for OpenMP 'if' clause using specified \a BodyGenCallbackTy /// Here is the logic: /// if (Cond) { @@ -2607,6 +2702,41 @@ class OpenMPIRBuilder { BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, Value *Filter); + /// This function performs the scan reduction of the values updated in + /// the input phase. The reduction logic needs to be emitted between input + /// and scan loop returned by `createCanonicalScanLoops`. The following + /// is the code that is generated, `buffer` and `span` are expected to be + /// populated before executing the generated code. + /// + /// for (int k = 0; k != ceil(log2(span)); ++k) { + /// i=pow(2,k) + /// for (size cnt = last_iter; cnt >= i; --cnt) + /// buffer[cnt] op= buffer[cnt-i]; + /// } + /// \param Loc The insert and source location description. + /// \param ReductionInfos Array type containing the ReductionOps. + /// + /// \returns The insertion position *after* the scan reduction. + InsertPointOrErrorTy emitScanReduction( + const LocationDescription &Loc, + SmallVector ReductionInfos); + + /// This directive splits and directs the control flow to input phase + /// blocks or scan phase blocks based on 1. whether input loop or scan loop + /// is executed, 2. whether exclusive or inclusive scan is used. + /// + /// \param Loc The insert and source location description. + /// \param AllocaIP The IP where the temporary buffer for scan reduction + /// needs to be allocated. + /// \param ScanVars Scan Variables. + /// \param ScanVarsType Types of the values held by the scan variables. + /// \param IsInclusive Whether it is an inclusive or exclusive scan. + /// + /// \returns The insertion position *after* the scan. 
+ InsertPointOrErrorTy createScan(const LocationDescription &Loc, + InsertPointTy AllocaIP, + ArrayRef ScanVars, + ArrayRef ScanVarsType, + bool IsInclusive); /// Generator for '#omp critical' /// /// \param Loc The insert and source location description. 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 28662efc02882..23a548a1a60d0 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -59,6 +59,8 @@ #include "llvm/Transforms/Utils/LoopPeel.h" #include "llvm/Transforms/Utils/UnrollLoop.h" +#include +#include #include #include @@ -3981,6 +3983,333 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc, /*Conditional*/ true, /*hasFinalize*/ true); } +llvm::CallInst * +OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee Callee, + ArrayRef Args, + const llvm::Twine &Name) { + llvm::CallInst *Call = Builder.CreateCall( + Callee, Args, SmallVector(), Name); + Call->setDoesNotThrow(); + return Call; +} + +// Expects input basic block is dominated by BeforeScanBB. +// Once Scan directive is encountered, the code after scan directive should be +// dominated by AfterScanBB. Scan directive splits the code sequence to +// scan and input phase. Based on whether inclusive or exclusive +// clause is used in the scan directive and whether input loop or scan loop +// is lowered, it adds jumps to input and scan phase. First Scan loop is the +// input loop and second is the scan loop. The code generated handles only +// inclusive scans now. +OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan( + const LocationDescription &Loc, InsertPointTy AllocaIP, + ArrayRef ScanVars, ArrayRef ScanVarsType, + bool IsInclusive) { + if (ScanInfo.OMPFirstScanLoop) { + llvm::Error Err = + emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars, ScanVarsType); + if (Err) { + return Err; + } + } + if (!updateToLocation(Loc)) + return Loc.IP; + + llvm::Value *IV = ScanInfo.IV; + + if (ScanInfo.OMPFirstScanLoop) { + // Emit buffer[i] = red; at the end of the input phase. 
+ for (size_t i = 0; i < ScanVars.size(); i++) { + Value *BuffPtr = ScanInfo.ScanBuffPtrs[ScanVars[i]]; + Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr); + Type *DestTy = ScanVarsType[i]; + Value *Val = Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset"); + Value *Src = Builder.CreateLoad(DestTy, ScanVars[i]); + + Builder.CreateStore(Src, Val); + } + } + Builder.CreateBr(ScanInfo.OMPScanLoopExit); + emitBlock(ScanInfo.OMPScanDispatch, Builder.GetInsertBlock()->getParent()); + + if (!ScanInfo.OMPFirstScanLoop) { + IV = ScanInfo.IV; + // Emit red = buffer[i]; at the entrance to the scan phase. + // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated. + for (size_t i = 0; i < ScanVars.size(); i++) { + Value *BuffPtr = ScanInfo.ScanBuffPtrs[ScanVars[i]]; + Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr); + Type *DestTy = ScanVarsType[i]; + Value *SrcPtr = + Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset"); + Value *Src = Builder.CreateLoad(DestTy, SrcPtr); + Builder.CreateStore(Src, ScanVars[i]); + } + } + + // TODO: Update it to CreateBr and remove dead blocks + llvm::Value *CmpI = Builder.getInt1(true); + if (ScanInfo.OMPFirstScanLoop == IsInclusive) { + Builder.CreateCondBr(CmpI, ScanInfo.OMPBeforeScanBlock, + ScanInfo.OMPAfterScanBlock); + } else { + Builder.CreateCondBr(CmpI, ScanInfo.OMPAfterScanBlock, + ScanInfo.OMPBeforeScanBlock); + } + emitBlock(ScanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent()); + Builder.SetInsertPoint(ScanInfo.OMPAfterScanBlock); + return Builder.saveIP(); +} + +Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR( + InsertPointTy AllocaIP, ArrayRef ScanVars, + ArrayRef ScanVarsType) { + + Builder.restoreIP(AllocaIP); + // Create the shared pointer at alloca IP. 
+ for (size_t i = 0; i < ScanVars.size(); i++) { + llvm::Value *BuffPtr = + Builder.CreateAlloca(Builder.getPtrTy(), nullptr, "vla"); + ScanInfo.ScanBuffPtrs[ScanVars[i]] = BuffPtr; + } + + // Allocate temporary buffer by master thread + auto BodyGenCB = [&](InsertPointTy AllocaIP, + InsertPointTy CodeGenIP) -> Error { + Builder.restoreIP(CodeGenIP); + Value *AllocSpan = Builder.CreateAdd(ScanInfo.Span, Builder.getInt32(1)); + for (size_t i = 0; i < ScanVars.size(); i++) { + Type *IntPtrTy = Builder.getInt32Ty(); + Constant *Allocsize = ConstantExpr::getSizeOf(ScanVarsType[i]); + Allocsize = ConstantExpr::getTruncOrBitCast(Allocsize, IntPtrTy); + Value *Buff = Builder.CreateMalloc(IntPtrTy, ScanVarsType[i], Allocsize, + AllocSpan, nullptr, "arr"); + Builder.CreateStore(Buff, ScanInfo.ScanBuffPtrs[ScanVars[i]]); + } + return Error::success(); + }; + // TODO: Perform finalization actions for variables. This has to be + // called for variables which have destructors/finalizers. + auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); }; + + Builder.SetInsertPoint(ScanInfo.OMPScanInit->getTerminator()); + llvm::Value *FilterVal = Builder.getInt32(0); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = + createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal); + + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + BasicBlock *InputBB = Builder.GetInsertBlock(); + if (InputBB->getTerminator()) + Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator()); + AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier); + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + + return Error::success(); +} + +Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR( + SmallVector ReductionInfos) { + auto BodyGenCB = [&](InsertPointTy AllocaIP, + InsertPointTy CodeGenIP) -> Error { + Builder.restoreIP(CodeGenIP); + for (ReductionInfo RedInfo : ReductionInfos) { + Value *PrivateVar = 
RedInfo.PrivateVariable; + Value *OrigVar = RedInfo.Variable; + Value *BuffPtr = ScanInfo.ScanBuffPtrs[PrivateVar]; + Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr); + + Type *SrcTy = RedInfo.ElementType; + Value *Val = + Builder.CreateInBoundsGEP(SrcTy, Buff, ScanInfo.Span, "arrayOffset"); + Value *Src = Builder.CreateLoad(SrcTy, Val); + + Builder.CreateStore(Src, OrigVar); + Builder.CreateFree(Buff); + } + return Error::success(); + }; + // TODO: Perform finalization actions for variables. This has to be + // called for variables which have destructors/finalizers. + auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); }; + + Builder.SetInsertPoint(ScanInfo.OMPScanFinish->getTerminator()); + llvm::Value *FilterVal = Builder.getInt32(0); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = + createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal); + + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + BasicBlock *InputBB = Builder.GetInsertBlock(); + if (InputBB->getTerminator()) + Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator()); + AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier); + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + return Error::success(); +} + +OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction( + const LocationDescription &Loc, + SmallVector ReductionInfos) { + + if (!updateToLocation(Loc)) + return Loc.IP; + auto BodyGenCB = [&](InsertPointTy AllocaIP, + InsertPointTy CodeGenIP) -> Error { + Builder.restoreIP(CodeGenIP); + Function *CurFn = Builder.GetInsertBlock()->getParent(); + // for (int k = 0; k <= ceil(log2(n)); ++k) + llvm::BasicBlock *LoopBB = + BasicBlock::Create(CurFn->getContext(), "omp.outer.log.scan.body"); + llvm::BasicBlock *ExitBB = + splitBB(Builder, false, "omp.outer.log.scan.exit"); + llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration( + 
Builder.GetInsertBlock()->getModule(), + (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy()); + llvm::BasicBlock *InputBB = Builder.GetInsertBlock(); + llvm::Value *Arg = + Builder.CreateUIToFP(ScanInfo.Span, Builder.getDoubleTy()); + llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, ""); + F = llvm::Intrinsic::getOrInsertDeclaration( + Builder.GetInsertBlock()->getModule(), + (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy()); + LogVal = emitNoUnwindRuntimeCall(F, LogVal, ""); + LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty()); + llvm::Value *NMin1 = Builder.CreateNUWSub( + ScanInfo.Span, llvm::ConstantInt::get(ScanInfo.Span->getType(), 1)); + Builder.SetInsertPoint(InputBB); + Builder.CreateBr(LoopBB); + emitBlock(LoopBB, Builder.GetInsertBlock()->getParent()); + Builder.SetInsertPoint(LoopBB); + + PHINode *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2); + //// size pow2k = 1; + PHINode *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2); + Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0), + InputBB); + Pow2K->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 1), + InputBB); + //// for (size i = n - 1; i >= 2 ^ k; --i) + //// tmp[i] op= tmp[i-pow2k]; + llvm::BasicBlock *InnerLoopBB = + BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.body"); + llvm::BasicBlock *InnerExitBB = + BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.exit"); + llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K); + Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB); + emitBlock(InnerLoopBB, Builder.GetInsertBlock()->getParent()); + Builder.SetInsertPoint(InnerLoopBB); + auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2); + IVal->addIncoming(NMin1, LoopBB); + for (ReductionInfo RedInfo : ReductionInfos) { + Value *ReductionVal = RedInfo.PrivateVariable; + Value *BuffPtr = ScanInfo.ScanBuffPtrs[ReductionVal]; + Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr); + Type 
*DestTy = RedInfo.ElementType; + Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1)); + Value *LHSPtr = + Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset"); + Value *OffsetIval = Builder.CreateNUWSub(IV, Pow2K); + Value *RHSPtr = + Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIval, "arrayOffset"); + Value *LHS = Builder.CreateLoad(DestTy, LHSPtr); + Value *RHS = Builder.CreateLoad(DestTy, RHSPtr); + llvm::Value *Result; + InsertPointOrErrorTy AfterIP = + RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result); + if (!AfterIP) + return AfterIP.takeError(); + Builder.CreateStore(Result, LHSPtr); + } + llvm::Value *NextIVal = Builder.CreateNUWSub( + IVal, llvm::ConstantInt::get(Builder.getInt32Ty(), 1)); + IVal->addIncoming(NextIVal, Builder.GetInsertBlock()); + CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K); + Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB); + emitBlock(InnerExitBB, Builder.GetInsertBlock()->getParent()); + llvm::Value *Next = Builder.CreateNUWAdd( + Counter, llvm::ConstantInt::get(Counter->getType(), 1)); + Counter->addIncoming(Next, Builder.GetInsertBlock()); + // pow2k <<= 1; + llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true); + Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock()); + llvm::Value *Cmp = Builder.CreateICmpNE(Next, LogVal); + Builder.CreateCondBr(Cmp, LoopBB, ExitBB); + Builder.SetInsertPoint(ExitBB->getFirstInsertionPt()); + return Error::success(); + }; + + // TODO: Perform finalization actions for variables. This has to be + // called for variables which have destructors/finalizers. 
+ auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); }; + + llvm::Value *FilterVal = Builder.getInt32(0); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = + createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal); + + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier); + + if (!AfterIP) + return AfterIP.takeError(); + Builder.restoreIP(*AfterIP); + Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos); + if (Err) { + return Err; + } + + return AfterIP; +} + +Error OpenMPIRBuilder::emitScanBasedDirectiveIR( + llvm::function_ref InputLoopGen, + llvm::function_ref ScanLoopGen) { + + { + // Emit loop with input phase: + // for (i: 0..) { + // ; + // buffer[i] = red; + // } + ScanInfo.OMPFirstScanLoop = true; + auto Result = InputLoopGen(); + if (Result) + return Result; + } + { + // Emit loop with scan phase: + // for (i: 0..) { + // red = buffer[i]; + // ; + // } + ScanInfo.OMPFirstScanLoop = false; + auto Result = ScanLoopGen(Builder.saveIP()); + if (Result) + return Result; + } + return Error::success(); +} + +void OpenMPIRBuilder::createScanBBs() { + Function *Fun = Builder.GetInsertBlock()->getParent(); + ScanInfo.OMPScanDispatch = + BasicBlock::Create(Fun->getContext(), "omp.inscan.dispatch"); + ScanInfo.OMPAfterScanBlock = + BasicBlock::Create(Fun->getContext(), "omp.after.scan.bb"); + ScanInfo.OMPBeforeScanBlock = + BasicBlock::Create(Fun->getContext(), "omp.before.scan.bb"); + ScanInfo.OMPScanLoopExit = + BasicBlock::Create(Fun->getContext(), "omp.scan.loop.exit"); +} + CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton( DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore, BasicBlock *PostInsertBefore, const Twine &Name) { @@ -4078,10 +4407,95 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc, return CL; } +Expected> +OpenMPIRBuilder::createCanonicalScanLoops( + const 
LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, + Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, + InsertPointTy ComputeIP, const Twine &Name) { + LocationDescription ComputeLoc = + ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc; + updateToLocation(ComputeLoc); + + Value *TripCount = calculateCanonicalLoopTripCount( + ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name); + ScanInfo.Span = TripCount; + ScanInfo.OMPScanInit = splitBB(Builder, true, "scan.init"); + Builder.SetInsertPoint(ScanInfo.OMPScanInit); + + auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) { + /// The control of the loopbody of following structure: + /// + /// InputBlock + /// | + /// ContinueBlock + /// + /// is transformed to: + /// + /// InputBlock + /// | + /// OMPScanDispatch + /// + /// OMPBeforeScanBlock + /// | + /// OMPScanLoopExit + /// | + /// ContinueBlock + /// + /// OMPBeforeScanBlock dominates the control flow of code generated until + /// scan directive is encountered and OMPAfterScanBlock dominates the + /// control flow of code generated after scan is encountered. The successor + /// of OMPScanDispatch can be OMPBeforeScanBlock or OMPAfterScanBlock based + /// on 1.whether it is in Input phase or Scan Phase , 2. whether it is an + /// exclusive or inclusive scan. 
+ Builder.restoreIP(CodeGenIP); + ScanInfo.IV = IV; + createScanBBs(); + BasicBlock *InputBlock = Builder.GetInsertBlock(); + Instruction *Terminator = InputBlock->getTerminator(); + assert(Terminator->getNumSuccessors() == 1); + BasicBlock *ContinueBlock = Terminator->getSuccessor(0); + Terminator->setSuccessor(0, ScanInfo.OMPScanDispatch); + emitBlock(ScanInfo.OMPBeforeScanBlock, + Builder.GetInsertBlock()->getParent()); + Builder.CreateBr(ScanInfo.OMPScanLoopExit); + emitBlock(ScanInfo.OMPScanLoopExit, Builder.GetInsertBlock()->getParent()); + Builder.CreateBr(ContinueBlock); + Builder.SetInsertPoint(ScanInfo.OMPBeforeScanBlock->getFirstInsertionPt()); + return BodyGenCB(Builder.saveIP(), IV); + }; + + SmallVector Result; + const auto &&InputLoopGen = [&]() -> Error { + auto LoopInfo = + createCanonicalLoop(Builder.saveIP(), BodyGen, Start, Stop, Step, + IsSigned, InclusiveStop, ComputeIP, Name, true); + if (!LoopInfo) + return LoopInfo.takeError(); + Result.push_back(*LoopInfo); + Builder.restoreIP((*LoopInfo)->getAfterIP()); + return Error::success(); + }; + const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error { + auto LoopInfo = + createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned, + InclusiveStop, ComputeIP, Name, true); + if (!LoopInfo) + return LoopInfo.takeError(); + Result.push_back(*LoopInfo); + Builder.restoreIP((*LoopInfo)->getAfterIP()); + ScanInfo.OMPScanFinish = Builder.GetInsertBlock(); + return Error::success(); + }; + Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen); + if (Err) { + return Err; + } + return Result; +} + Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount( const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, const Twine &Name) { - // Consider the following difficulties (assuming 8-bit signed integers): // * Adding \p Step to the loop counter which passes \p Stop may overflow: // DO I = 1, 100, 50 @@ -4141,7 +4555,7 @@ Value 
*OpenMPIRBuilder::calculateCanonicalLoopTripCount( Expected OpenMPIRBuilder::createCanonicalLoop( const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, - InsertPointTy ComputeIP, const Twine &Name) { + InsertPointTy ComputeIP, const Twine &Name, bool InScan) { LocationDescription ComputeLoc = ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc; @@ -4152,6 +4566,9 @@ Expected OpenMPIRBuilder::createCanonicalLoop( Builder.restoreIP(CodeGenIP); Value *Span = Builder.CreateMul(IV, Step); Value *IndVar = Builder.CreateAdd(Span, Start); + if (InScan) { + ScanInfo.IV = IndVar; + } return BodyGenCB(Builder.saveIP(), IndVar); }; LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 2d3d318be7ff1..4cc312c1b0f1c 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -1440,6 +1440,16 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) { EXPECT_EQ(&Loop->getAfter()->front(), RetInst); } +void createScan(llvm::Value *scanVar, llvm::Type *scanType, + OpenMPIRBuilder &OMPBuilder, IRBuilder<> &Builder, + OpenMPIRBuilder::LocationDescription Loc, + OpenMPIRBuilder::InsertPointTy &allocaIP) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + ASSERT_EXPECTED_INIT( + InsertPointTy, retIp, + OMPBuilder.createScan(Loc, allocaIP, {scanVar}, {scanType}, true)); + Builder.restoreIP(retIp); +} TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) { OpenMPIRBuilder OMPBuilder(*M); @@ -5336,6 +5346,62 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) { EXPECT_TRUE(findGEPZeroOne(ReductionFn->getArg(1), FirstRHS, SecondRHS)); } +TEST_F(OpenMPIRBuilderTest, ScanReduction) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + 
IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + Value *TripCount = F->getArg(0); + Type *LCTy = TripCount->getType(); + Value *StartVal = ConstantInt::get(LCTy, 1); + Value *StopVal = ConstantInt::get(LCTy, 100); + Value *Step = ConstantInt::get(LCTy, 1); + auto allocaIP = Builder.saveIP(); + + llvm::Value *scanVar = Builder.CreateAlloca(Builder.getFloatTy()); + llvm::Value *origVar = Builder.CreateAlloca(Builder.getFloatTy()); + unsigned NumBodiesGenerated = 0; + auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) { + NumBodiesGenerated += 1; + Builder.restoreIP(CodeGenIP); + createScan(scanVar, Builder.getFloatTy(), OMPBuilder, Builder, Loc, + allocaIP); + return Error::success(); + }; + SmallVector Loops; + ASSERT_EXPECTED_INIT(SmallVector, loopsVec, + OMPBuilder.createCanonicalScanLoops( + Loc, LoopBodyGenCB, StartVal, StopVal, Step, false, + false, Builder.saveIP(), "scan")); + Loops = loopsVec; + EXPECT_EQ(Loops.size(), 2U); + CanonicalLoopInfo *InputLoop = Loops.front(); + CanonicalLoopInfo *ScanLoop = Loops.back(); + Builder.restoreIP(ScanLoop->getAfterIP()); + InputLoop->assertOK(); + ScanLoop->assertOK(); + + EXPECT_EQ(InputLoop->getPreheader()->getSinglePredecessor(), + &F->getEntryBlock()); + EXPECT_EQ(ScanLoop->getAfter(), Builder.GetInsertBlock()); + EXPECT_EQ(NumBodiesGenerated, 2U); + SmallVector reductionInfos = { + {Builder.getFloatTy(), origVar, scanVar, + /*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, sumReduction, + /*ReductionGenClang=*/nullptr, sumAtomicReduction}}; + OpenMPIRBuilder::LocationDescription RedLoc({InputLoop->getAfterIP(), DL}); + llvm::BasicBlock *Cont = splitBB(Builder, false, "omp.scan.loop.cont"); + ASSERT_EXPECTED_INIT(InsertPointTy, retIp, + OMPBuilder.emitScanReduction(RedLoc, reductionInfos)); + Builder.restoreIP(retIp); + Builder.CreateBr(Cont); + SmallVector MaskedCalls; + findCalls(F, omp::RuntimeFunction::OMPRTL___kmpc_masked, OMPBuilder, + 
MaskedCalls); + ASSERT_EQ(MaskedCalls.size(), 1u); +} + TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8d1cc9b10a950..c84789e634101 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -47,6 +47,10 @@ using namespace mlir; +llvm::SmallDenseMap ReductionVarToType; +llvm::OpenMPIRBuilder::InsertPointTy + parallelAllocaIP; // TODO: change this alloca IP to point to originalvar + // allocaIP. ReductionDecl need to be linked to scan var. namespace { static llvm::omp::ScheduleKind convertToScheduleKind(std::optional schedKind) { @@ -86,7 +90,9 @@ class OpenMPLoopInfoStackFrame : public LLVM::ModuleTranslation::StackFrameBase { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame) - llvm::CanonicalLoopInfo *loopInfo = nullptr; + // For constructs like scan, one Loop info frame can contain multiple + // Canonical Loops + SmallVector loopInfos; }; /// Custom error class to signal translation errors that don't need reporting, @@ -169,6 +175,10 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.getDistScheduleChunkSize()) result = todo("dist_schedule with chunk_size"); }; + auto checkExclusive = [&todo](auto op, LogicalResult &result) { + if (!op.getExclusiveVars().empty()) + result = todo("exclusive"); + }; auto checkHint = [](auto op, LogicalResult &) { if (op.getHint()) op.emitWarning("hint clause discarded"); @@ -232,8 +242,8 @@ static LogicalResult checkImplementationStatus(Operation &op) { op.getReductionSyms()) result = todo("reduction"); if (op.getReductionMod() && - op.getReductionMod().value() != omp::ReductionModifier::defaultmod) - result = todo("reduction with 
modifier"); + op.getReductionMod().value() == omp::ReductionModifier::task) + result = todo("reduction with task modifier"); }; auto checkTaskReduction = [&todo](auto op, LogicalResult &result) { if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() || @@ -253,6 +263,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkOrder(op, result); }) .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); }) + .Case([&](omp::ScanOp op) { checkExclusive(op, result); }) .Case([&](omp::SectionsOp op) { checkAllocate(op, result); checkPrivate(op, result); @@ -382,15 +393,15 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder, /// Find the loop information structure for the loop nest being translated. It /// will return a `null` value unless called from the translation function for /// a loop wrapper operation after successfully translating its body. -static llvm::CanonicalLoopInfo * -findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) { - llvm::CanonicalLoopInfo *loopInfo = nullptr; +static SmallVector +findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) { + SmallVector loopInfos; moduleTranslation.stackWalk( [&](OpenMPLoopInfoStackFrame &frame) { - loopInfo = frame.loopInfo; + loopInfos = frame.loopInfos; return WalkResult::interrupt(); }); - return loopInfo; + return loopInfos; } /// Converts the given region that appears within an OpenMP dialect operation to @@ -1133,6 +1144,11 @@ initReductionVars(OP op, ArrayRef reductionArgs, // variables. Although this could be done after allocas, we don't want to mess // up with the alloca insertion point. 
for (unsigned i = 0; i < op.getNumReductionVars(); ++i) { + + llvm::Type *reductionType = + moduleTranslation.convertType(reductionDecls[i].getType()); + ReductionVarToType[privateReductionVariables[i]] = reductionType; + SmallVector phis; // map block argument to initializer region @@ -1206,9 +1222,11 @@ static void collectReductionInfo( atomicGen = owningAtomicReductionGens[i]; llvm::Value *variable = moduleTranslation.lookupValue(loop.getReductionVars()[i]); + llvm::Type *reductionType = + moduleTranslation.convertType(reductionDecls[i].getType()); + ReductionVarToType[privateReductionVariables[i]] = reductionType; reductionInfos.push_back( - {moduleTranslation.convertType(reductionDecls[i].getType()), variable, - privateReductionVariables[i], + {reductionType, variable, privateReductionVariables[i], /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar, owningReductionGens[i], /*ReductionGenClang=*/nullptr, atomicGen}); @@ -2342,27 +2360,60 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, if (failed(handleError(regionBlock, opInst))) return failure(); - builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); - llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation); - - llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = - ompBuilder->applyWorkshareLoop( - ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier, - convertToScheduleKind(schedule), chunk, isSimd, - scheduleMod == omp::ScheduleModifier::monotonic, - scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, - workshareLoopType); - - if (failed(handleError(wsloopIP, opInst))) - return failure(); - - // Process the reductions if required. 
- if (failed(createReductionsAndCleanup( - wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls, - privateReductionVariables, isByRef, wsloopOp.getNowait(), - /*isTeamsReduction=*/false))) - return failure(); + SmallVector loopInfos = + findCurrentLoopInfos(moduleTranslation); + auto inputLoopFinishIp = loopInfos.front()->getAfterIP(); + bool isInScanRegion = + wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() == + mlir::omp::ReductionModifier::inscan); + if (isInScanRegion) { + builder.restoreIP(inputLoopFinishIp); + SmallVector owningReductionGens; + SmallVector owningAtomicReductionGens; + SmallVector reductionInfos; + collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls, + owningReductionGens, owningAtomicReductionGens, + privateReductionVariables, reductionInfos); + llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont"); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP = + ompBuilder->emitScanReduction(builder.saveIP(), reductionInfos); + if (failed(handleError(redIP, opInst))) + return failure(); + builder.restoreIP(*redIP); + builder.CreateBr(cont); + } + for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) { + llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = + ompBuilder->applyWorkshareLoop( + ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier, + convertToScheduleKind(schedule), chunk, isSimd, + scheduleMod == omp::ScheduleModifier::monotonic, + scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, + workshareLoopType); + + if (failed(handleError(wsloopIP, opInst))) + return failure(); + } + builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); + if (isInScanRegion) { + SmallVector reductionRegions; + llvm::transform(reductionDecls, std::back_inserter(reductionRegions), + [](omp::DeclareReductionOp reductionDecl) { + return &reductionDecl.getCleanupRegion(); + }); + if (failed(inlineOmpRegionCleanup( + reductionRegions, privateReductionVariables, moduleTranslation, + 
builder, "omp.reduction.cleanup"))) + return failure(); + } else { + // Process the reductions if required. + if (failed(createReductionsAndCleanup( + wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls, + privateReductionVariables, isByRef, wsloopOp.getNowait(), + /*isTeamsReduction=*/false))) + return failure(); + } return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(), privateVarsInfo.llvmVars, privateVarsInfo.privatizers); @@ -2528,6 +2579,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); + parallelAllocaIP = allocaIP; llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = @@ -2553,6 +2605,64 @@ convertOrderKind(std::optional o) { llvm_unreachable("Unknown ClauseOrderKind kind"); } +static LogicalResult +convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + if (failed(checkImplementationStatus(opInst))) + return failure(); + auto scanOp = cast(opInst); + bool isInclusive = scanOp.hasInclusiveVars(); + SmallVector llvmScanVars; + SmallVector llvmScanVarsType; + mlir::OperandRange mlirScanVars = scanOp.getInclusiveVars(); + if (!isInclusive) + mlirScanVars = scanOp.getExclusiveVars(); + for (auto val : mlirScanVars) { + llvm::Value *llvmVal = moduleTranslation.lookupValue(val); + llvmScanVars.push_back(llvmVal); + llvmScanVarsType.push_back(ReductionVarToType[llvmVal]); + val.getDefiningOp(); + } + auto parallelOp = scanOp->getParentOfType(); + if (!parallelOp) { + return failure(); + } + llvm::OpenMPIRBuilder::InsertPointTy allocaIP = parallelAllocaIP; + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = + moduleTranslation.getOpenMPBuilder()->createScan( + ompLoc, allocaIP, llvmScanVars, llvmScanVarsType, isInclusive); + if 
(failed(handleError(afterIP, opInst))) + return failure(); + + builder.restoreIP(*afterIP); + + // TODO: The argument of LoopnestOp is stored into the index variable and this + // variable is used across scan operation. However that makes the mlir + // invalid.(`Intra-iteration dependences from a statement in the structured + // block sequence that precede a scan directive to a statement in the + // structured block sequence that follows a scan directive must not exist, + // except for dependences for the list items specified in an inclusive or + // exclusive clause.`). The argument of LoopNestOp need to be loaded again + // after ScanOp again so mlir generated is valid. + auto parentOp = scanOp->getParentOp(); + auto loopOp = cast(parentOp); + if (loopOp) { + auto &firstBlock = *(scanOp->getParentRegion()->getBlocks()).begin(); + auto &ins = *(firstBlock.begin()); + if (isa(ins)) { + LLVM::StoreOp storeOp = dyn_cast(ins); + auto src = moduleTranslation.lookupValue(storeOp->getOperand(0)); + if (src == moduleTranslation.lookupValue( + (loopOp.getRegion().getArguments())[0])) { + auto dest = moduleTranslation.lookupValue(storeOp->getOperand(1)); + builder.CreateStore(src, dest); + } + } + } + return success(); +} + /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder. static LogicalResult convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, @@ -2626,12 +2736,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, return failure(); builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); - llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation); - ompBuilder->applySimd(loopInfo, alignedVars, - simdOp.getIfExpr() - ? moduleTranslation.lookupValue(simdOp.getIfExpr()) - : nullptr, - order, simdlen, safelen); + SmallVector loopInfos = + findCurrentLoopInfos(moduleTranslation); + for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) { + ompBuilder->applySimd( + loopInfo, alignedVars, + simdOp.getIfExpr() ? 
moduleTranslation.lookupValue(simdOp.getIfExpr()) + : nullptr, + order, simdlen, safelen); + } return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(), privateVarsInfo.llvmVars, @@ -2698,16 +2811,53 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, ompLoc.DL); computeIP = loopInfos.front()->getPreheaderIP(); } + if (auto wsloopOp = loopOp->getParentOfType()) { + bool isInScanRegion = + wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() == + mlir::omp::ReductionModifier::inscan); + if (isInScanRegion) { + //TODO: Handle nesting if Scan loop is nested in a loop + assert(loopOp.getNumLoops() == 1); + llvm::Expected> loopResults = + ompBuilder->createCanonicalScanLoops( + loc, bodyGen, lowerBound, upperBound, step, + /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, + "loop"); + + if (failed(handleError(loopResults, *loopOp))) + return failure(); + auto inputLoop = loopResults->front(); + auto scanLoop = loopResults->back(); + moduleTranslation.stackWalk( + [&](OpenMPLoopInfoStackFrame &frame) { + frame.loopInfos.push_back(inputLoop); + frame.loopInfos.push_back(scanLoop); + return WalkResult::interrupt(); + }); + builder.restoreIP(scanLoop->getAfterIP()); + return success(); + } else { + llvm::Expected loopResult = + ompBuilder->createCanonicalLoop( + loc, bodyGen, lowerBound, upperBound, step, + /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP); - llvm::Expected loopResult = - ompBuilder->createCanonicalLoop( - loc, bodyGen, lowerBound, upperBound, step, - /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP); + if (failed(handleError(loopResult, *loopOp))) + return failure(); - if (failed(handleError(loopResult, *loopOp))) - return failure(); + loopInfos.push_back(*loopResult); + } + } else { + llvm::Expected loopResult = + ompBuilder->createCanonicalLoop( + loc, bodyGen, lowerBound, upperBound, step, + /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP); + + if (failed(handleError(loopResult, 
*loopOp))) + return failure(); - loopInfos.push_back(*loopResult); + loopInfos.push_back(*loopResult); + } } // Collapse loops. Store the insertion point because LoopInfos may get @@ -2719,7 +2869,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, // after applying transformations. moduleTranslation.stackWalk( [&](OpenMPLoopInfoStackFrame &frame) { - frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}); + frame.loopInfos.push_back( + ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {})); return WalkResult::interrupt(); }); @@ -4329,18 +4480,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, bool loopNeedsBarrier = false; llvm::Value *chunk = nullptr; - llvm::CanonicalLoopInfo *loopInfo = - findCurrentLoopInfo(moduleTranslation); - llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = - ompBuilder->applyWorkshareLoop( - ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier, - convertToScheduleKind(schedule), chunk, isSimd, - scheduleMod == omp::ScheduleModifier::monotonic, - scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, - workshareLoopType); - - if (!wsloopIP) - return wsloopIP.takeError(); + SmallVector loopInfos = + findCurrentLoopInfos(moduleTranslation); + for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) { + llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = + ompBuilder->applyWorkshareLoop( + ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier, + convertToScheduleKind(schedule), chunk, isSimd, + scheduleMod == omp::ScheduleModifier::monotonic, + scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, + workshareLoopType); + + if (!wsloopIP) + return wsloopIP.takeError(); + } } if (failed(cleanupPrivateVars(builder, moduleTranslation, @@ -5370,6 +5523,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, .Case([&](omp::WsloopOp) { return convertOmpWsloop(*op, builder, moduleTranslation); }) + .Case([&](omp::ScanOp) { + return convertOmpScan(*op, builder, 
moduleTranslation); + }) .Case([&](omp::SimdOp) { return convertOmpSimd(*op, builder, moduleTranslation); }) diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir new file mode 100644 index 0000000000000..a88c1993aebe1 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir @@ -0,0 +1,120 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +omp.declare_reduction @add_reduction_i32 : i32 init { +^bb0(%arg0: i32): + %0 = llvm.mlir.constant(0 : i32) : i32 + omp.yield(%0 : i32) +} combiner { +^bb0(%arg0: i32, %arg1: i32): + %0 = llvm.add %arg0, %arg1 : i32 + omp.yield(%0 : i32) +} +// CHECK-LABEL: @scan_reduction +llvm.func @scan_reduction() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x i32 {bindc_name = "z"} : (i64) -> !llvm.ptr + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i64) : i64 + %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr + %6 = llvm.mlir.constant(1 : i64) : i64 + %7 = llvm.alloca %6 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr + %8 = llvm.mlir.constant(0 : index) : i64 + %9 = llvm.mlir.constant(1 : index) : i64 + %10 = llvm.mlir.constant(100 : i32) : i32 + %11 = llvm.mlir.constant(1 : i32) : i32 + %12 = llvm.mlir.constant(0 : i32) : i32 + %13 = llvm.mlir.constant(100 : index) : i64 + %14 = llvm.mlir.addressof @_QFEa : !llvm.ptr + %15 = llvm.mlir.addressof @_QFEb : !llvm.ptr + omp.parallel { + %37 = llvm.mlir.constant(1 : i64) : i64 + %38 = llvm.alloca %37 x i32 {bindc_name = "k", pinned} : (i64) -> !llvm.ptr + %39 = llvm.mlir.constant(1 : i64) : i64 + omp.wsloop reduction(mod: inscan, @add_reduction_i32 %5 -> %arg0 : !llvm.ptr) { + omp.loop_nest (%arg1) : i32 = (%11) to (%10) inclusive step (%11) { + llvm.store %arg1, %38 : i32, !llvm.ptr + %40 = llvm.load %arg0 : !llvm.ptr -> i32 + %41 = llvm.load %38 : !llvm.ptr -> i32 + %42 = 
llvm.sext %41 : i32 to i64 + %50 = llvm.getelementptr %14[%42] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %51 = llvm.load %50 : !llvm.ptr -> i32 + %52 = llvm.add %40, %51 : i32 + llvm.store %52, %arg0 : i32, !llvm.ptr + omp.scan inclusive(%arg0 : !llvm.ptr) + %53 = llvm.load %arg0 : !llvm.ptr -> i32 + %54 = llvm.load %38 : !llvm.ptr -> i32 + %55 = llvm.sext %54 : i32 to i64 + %63 = llvm.getelementptr %15[%55] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + llvm.store %53, %63 : i32, !llvm.ptr + omp.yield + } + } + omp.terminator + } + llvm.return +} +llvm.mlir.global internal @_QFEa() {addr_space = 0 : i32} : !llvm.array<100 x i32> { + %0 = llvm.mlir.zero : !llvm.array<100 x i32> + llvm.return %0 : !llvm.array<100 x i32> +} +llvm.mlir.global internal @_QFEb() {addr_space = 0 : i32} : !llvm.array<100 x i32> { + %0 = llvm.mlir.zero : !llvm.array<100 x i32> + llvm.return %0 : !llvm.array<100 x i32> +} +llvm.mlir.global internal constant @_QFECn() {addr_space = 0 : i32} : i32 { + %0 = llvm.mlir.constant(100 : i32) : i32 + llvm.return %0 : i32 +} +//CHECK: %[[BUFF:.+]] = alloca i32, i32 100, align 4 +//CHECK: omp_loop.preheader{{.*}}: ; preds = %omp.wsloop.region +//CHECK: omp_loop.after: ; preds = %omp_loop.exit +//CHECK: %[[LOG:.+]] = call double @llvm.log2.f64(double 1.000000e+02) #0 +//CHECK: %[[CEIL:.+]] = call double @llvm.ceil.f64(double %[[LOG]]) #0 +//CHECK: %[[UB:.+]] = fptoui double %[[CEIL]] to i32 +//CHECK: br label %omp.outer.log.scan.body +//CHECK: omp.outer.log.scan.body: ; preds = %omp.inner.log.scan.exit, %omp_loop.after +//CHECK: %[[K:.+]] = phi i32 [ 0, %omp_loop.after ], [ %[[NEXTK:.+]], %omp.inner.log.scan.exit ] +//CHECK: %[[I:.+]] = phi i32 [ 1, %omp_loop.after ], [ %[[NEXTI:.+]], %omp.inner.log.scan.exit ] +//CHECK: %[[CMP1:.+]] = icmp uge i32 99, %[[I]] +//CHECK: br i1 %[[CMP1]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit +//CHECK: omp.inner.log.scan.exit: ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body +//CHECK: %[[NEXTK]] 
= add nuw i32 %[[K]], 1 +//CHECK: %[[NEXTI]] = shl nuw i32 %[[I]], 1 +//CHECK: %[[CMP2:.+]] = icmp ne i32 %[[NEXTK]], %[[UB]] +//CHECK: br i1 %[[CMP2]], label %omp.outer.log.scan.body, label %omp.outer.log.scan.exit +//CHECK: omp.outer.log.scan.exit: ; preds = %omp.inner.log.scan.exit +//CHECK: call void @__kmpc_barrier{{.*}} +//CHECK: br label %omp.scan.loop.cont +//CHECK: omp.scan.loop.cont: ; preds = %omp.outer.log.scan.exit +//CHECK: br label %omp_loop.preheader{{.*}} +//CHECK: omp_loop.after{{.*}}: ; preds = %omp_loop.exit{{.*}} +//CHECK: %[[ARRLAST:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 100 +//CHECK: %[[RES:.+]] = load i32, ptr %[[ARRLAST]], align 4 +//CHECK: store i32 %[[RES]], ptr %loadgep{{.*}}, align 4 +//CHECK: omp.inscan.dispatch{{.*}}: ; preds = %omp_loop.body{{.*}} +//CHECK: store i32 0, ptr %[[REDPRIV:.+]], align 4 +//CHECK: %[[arrayOffset1:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %{{.*}} +//CHECK: %[[BUFFVAL1:.+]] = load i32, ptr %[[arrayOffset1]], align 4 +//CHECK: store i32 %[[BUFFVAL1]], ptr %[[REDPRIV]], align 4 +//CHECK: omp.inner.log.scan.body: ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body +//CHECK: %[[CNT:.+]] = phi i32 [ 99, %omp.outer.log.scan.body ], [ %[[CNTNXT:.+]], %omp.inner.log.scan.body ] +//CHECK: %[[IND1:.+]] = add i32 %[[CNT]], 1 +//CHECK: %[[IND1PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND1]] +//CHECK: %[[IND2:.+]] = sub nuw i32 %[[IND1]], %[[I]] +//CHECK: %[[IND2PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND2]] +//CHECK: %[[IND1VAL:.+]] = load i32, ptr %[[IND1PTR]], align 4 +//CHECK: %[[IND2VAL:.+]] = load i32, ptr %[[IND2PTR]], align 4 +//CHECK: %[[REDVAL:.+]] = add i32 %[[IND1VAL]], %[[IND2VAL]] +//CHECK: store i32 %[[REDVAL]], ptr %[[IND1PTR]], align 4 +//CHECK: %[[CNTNXT]] = sub nuw i32 %[[CNT]], 1 +//CHECK: %[[CMP3:.+]] = icmp uge i32 %[[CNTNXT]], %[[I]] +//CHECK: br i1 %[[CMP3]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit 
+//CHECK: omp.inscan.dispatch: ; preds = %omp_loop.body +//CHECK: store i32 0, ptr %[[REDPRIV]], align 4 +//CHECK: br i1 true, label %omp.before.scan.bb, label %omp.after.scan.bb +//CHECK: omp.loop_nest.region: ; preds = %omp.before.scan.bb +//CHECK: %[[ARRAYOFFSET2:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %{{.*}} +//CHECK: %[[REDPRIVVAL:.+]] = load i32, ptr %[[REDPRIV]], align 4 +//CHECK: store i32 %[[REDPRIVVAL]], ptr %[[ARRAYOFFSET2]], align 4 +//CHECK: br label %omp.scan.loop.exit diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 7eafe396082e4..7b8e8b509d72b 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -212,37 +212,6 @@ llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { // ----- -omp.declare_reduction @add_f32 : f32 -init { -^bb0(%arg: f32): - %0 = llvm.mlir.constant(0.0 : f32) : f32 - omp.yield (%0 : f32) -} -combiner { -^bb1(%arg0: f32, %arg1: f32): - %1 = llvm.fadd %arg0, %arg1 : f32 - omp.yield (%1 : f32) -} -atomic { -^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr): - %2 = llvm.load %arg3 : !llvm.ptr -> f32 - llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32 - omp.yield -} -llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause reduction with modifier in omp.wsloop operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} - omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.scan inclusive(%prv : !llvm.ptr) - omp.yield - } - } - llvm.return -} - -// ----- - llvm.func @single_allocate(%x : !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause allocate in omp.single operation}} // expected-error@below {{LLVM Translation failed for operation: omp.single}}