[OpenMP] [IR Builder] Changes to Support Scan Operation #136035

Open
wants to merge 2 commits into
base: main
128 changes: 124 additions & 4 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -508,6 +508,30 @@ class OpenMPIRBuilder {
return allocaInst;
}
};

struct ScanInformation {
/// Dominates the body of the loop before scan directive
llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
/// Dominates the body of the loop after scan directive
llvm::BasicBlock *OMPAfterScanBlock = nullptr;
/// Controls the flow to before or after scan blocks
llvm::BasicBlock *OMPScanDispatch = nullptr;
/// Exit block of loop body
llvm::BasicBlock *OMPScanLoopExit = nullptr;
/// Block before loop body where scan initializations are done
llvm::BasicBlock *OMPScanInit = nullptr;
/// Block after loop body where scan finalizations are done
llvm::BasicBlock *OMPScanFinish = nullptr;
/// If true, it indicates Input phase is lowered; else it indicates
/// ScanPhase is lowered
bool OMPFirstScanLoop = false;
/// Maps the private reduction variable to the pointer of the temporary
/// buffer
llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ScanBuffPtrs;
llvm::Value *IV = nullptr;
llvm::Value *Span = nullptr;
} ScanInfo;

/// Initialize the internal state, this will put structures types and
/// potentially other helpers into the underlying module. Must be called
/// before any other method and only once! This internal state includes types
@@ -743,6 +767,35 @@ class OpenMPIRBuilder {
LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
const Twine &Name = "loop");

/// Generator for the control flow structure of an OpenMP canonical loop when
/// the parent directive has an `inscan` modifier specified.
/// If the `inscan` modifier is specified, the region of the parent is
/// expected to have a `scan` directive. Based on the clauses in the
/// scan directive, the body of the loop is split into two loops: an input
/// loop and a scan loop. The input loop contains the code generated for the
/// input phase of the scan, and the scan loop contains the code generated
/// for the scan phase.
///
/// \param Loc The insert and source location description.
/// \param BodyGenCB Callback that will generate the loop body code.
/// \param Start Value of the loop counter for the first iteration.
/// \param Stop Loop counter values past this will stop the loop.
/// \param Step Loop counter increment after each iteration; negative
/// means counting down.
/// \param IsSigned Whether Start, Stop and Step are signed integers.
/// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
/// counter.
/// \param ComputeIP Insertion point for instructions computing the trip
/// count. Can be used to ensure the trip count is available
/// at the outermost loop of a loop nest. If not set,
/// defaults to the preheader of the generated loop.
/// \param Name Base name used to derive BB and instruction names.
///
/// \returns A vector containing Loop Info of Input Loop and Scan Loop.
Expected<SmallVector<llvm::CanonicalLoopInfo *>> createCanonicalScanLoops(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
InsertPointTy ComputeIP, const Twine &Name);
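The split described above can be pictured with a host-side model. This is only a sketch under stated assumptions: `inclusiveScanTwoLoops` is a hypothetical name, not part of `OpenMPIRBuilder`. It models an inclusive scan with `+` as the reduction operator, and a plain serial prefix sum stands in for the reduction that runs between the two loops.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical host-side model of an inclusive `inscan` reduction with '+'.
// It only mirrors the input-loop / scan-loop split described in the comment;
// it is not the IR the builder emits.
std::vector<int> inclusiveScanTwoLoops(const std::vector<int> &in) {
  const std::size_t n = in.size();
  std::vector<int> buffer(n), out(n);

  // Input loop: run only the input phase and store each iteration's
  // private reduction value in the temporary buffer.
  for (std::size_t i = 0; i < n; ++i) {
    int x = 0;   // private copy, initialized to the identity of '+'
    x += in[i];  // input phase
    buffer[i] = x;
  }

  // Between the loops: combine the buffered values into prefix results.
  // A serial prefix sum is used here purely for illustration.
  for (std::size_t i = 1; i < n; ++i)
    buffer[i] += buffer[i - 1];

  // Scan loop: run only the scan phase, reading the prefix result.
  for (std::size_t i = 0; i < n; ++i) {
    int x = buffer[i];
    out[i] = x;  // scan phase
  }
  return out;
}
```

In this model, each iteration's private value is written to the buffer in the first loop and read back, already reduced, in the second.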

/// Calculate the trip count of a canonical loop.
///
/// This allows specifying user-defined loop counter values using increment,
@@ -811,13 +864,16 @@ class OpenMPIRBuilder {
/// at the outermost loop of a loop nest. If not set,
/// defaults to the preheader of the generated loop.
/// \param Name Base name used to derive BB and instruction names.
/// \param InScan Whether loop has a scan reduction specified.
///
/// \returns An object representing the created control flow structure which
/// can be used for loop-associated directives.
LLVM_ABI Expected<CanonicalLoopInfo *> createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
InsertPointTy ComputeIP = {}, const Twine &Name = "loop");
LLVM_ABI Expected<CanonicalLoopInfo *>
createCanonicalLoop(const LocationDescription &Loc,
LoopBodyGenCallbackTy BodyGenCB, Value *Start,
Value *Stop, Value *Step, bool IsSigned,
bool InclusiveStop, InsertPointTy ComputeIP = {},
const Twine &Name = "loop", bool InScan = false);

/// Collapse a loop nest into a single loop.
///
@@ -1548,6 +1604,35 @@ class OpenMPIRBuilder {
ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
Function *ReduceFn, AttributeList FuncAttrs);

/// Helper function for createCanonicalScanLoops that creates the input loop
/// via \p InputLoopGen and the scan loop via \p ScanLoopGen.
/// \param InputLoopGen Callback for generating the loop for the input phase.
/// \param ScanLoopGen Callback for generating the loop for the scan phase.
///
/// \returns An error if any is produced, else success.
Error emitScanBasedDirectiveIR(
llvm::function_ref<Error()> InputLoopGen,
llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen);

/// Creates the basic blocks required for scan reduction.
void createScanBBs();

/// Dynamically allocates the buffer needed for scan reduction.
/// \param AllocaIP The IP where the possibly-shared pointer of the buffer
/// needs to be declared.
/// \param ScanVars Scan variables.
/// \param ScanVarsType Types of the scan variables.
///
/// \returns An error if any is produced, else success.
Error emitScanBasedDirectiveDeclsIR(InsertPointTy AllocaIP,
ArrayRef<llvm::Value *> ScanVars,
ArrayRef<llvm::Type *> ScanVarsType);

/// Copies the result back to the reduction variable.
/// \param ReductionInfos Array type containing the ReductionOps.
///
/// \return error if any produced, else return success.
Error emitScanBasedDirectiveFinalsIR(
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos);

/// This function emits a helper that gathers Reduce lists from the first
/// lane of every active warp to lanes in the first warp.
///
@@ -2631,6 +2716,41 @@ class OpenMPIRBuilder {
FinalizeCallbackTy FiniCB,
Value *Filter);

/// This function performs the scan reduction of the values updated in
/// the input phase. The reduction logic needs to be emitted between the
/// input and scan loops returned by `createCanonicalScanLoops`. The
/// following is the code that is generated; `buffer` and `span` are expected
/// to be populated before executing the generated code.
///
/// for (int k = 0; k != ceil(log2(span)); ++k) {
///   i = pow(2, k);
///   for (size_t cnt = last_iter; cnt >= i; --cnt)
///     buffer[cnt] op= buffer[cnt - i];
/// }
/// \param Loc The insert and source location description.
/// \param ReductionInfos Array type containing the ReductionOps.
///
/// \returns The insertion position *after* the scan reduction.
InsertPointOrErrorTy emitScanReduction(
const LocationDescription &Loc,
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos);
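The log-step loop in the comment above can be tried on the host with a small sketch. `logStepScan` is a hypothetical helper, not part of `OpenMPIRBuilder`. It assumes `span` equals the buffer size and `last_iter` equals `span - 1`, and it replaces `pow(2, k)` with a doubling stride, which visits the same values of `i`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: in-place log-step scan over `buffer`, following the
// pseudocode in the emitScanReduction comment.
// Assumptions: span == buffer.size(), last_iter == span - 1.
template <typename T, typename Op>
void logStepScan(std::vector<T> &buffer, Op op) {
  const std::size_t span = buffer.size();
  // The stride i takes the values 1, 2, 4, ... while i < span, which matches
  // k = 0 .. ceil(log2(span)) - 1 with i = 2^k.
  for (std::size_t i = 1; i < span; i *= 2) {
    // Walk downwards so buffer[cnt - i] still holds the value from the
    // previous step when it is read.
    for (std::size_t cnt = span - 1; cnt >= i; --cnt)
      buffer[cnt] = op(buffer[cnt], buffer[cnt - i]);
  }
}
```

Applied to `{1, 2, 3, 4, 5}` with addition, the buffer ends up holding the inclusive prefix sums `{1, 3, 6, 10, 15}` after three passes with strides 1, 2 and 4.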

/// This directive splits and directs the control flow to the input phase
/// blocks or scan phase blocks based on 1. whether the input loop or scan
/// loop is executed, and 2. whether an exclusive or inclusive scan is used.
///
/// \param Loc The insert and source location description.
/// \param AllocaIP The IP where the temporary buffer for scan reduction
/// needs to be allocated.
/// \param ScanVars Scan variables.
/// \param ScanVarsType Types of the scan variables.
/// \param IsInclusive Whether it is an inclusive or exclusive scan.
///
/// \returns The insertion position *after* the scan.
InsertPointOrErrorTy createScan(const LocationDescription &Loc,
InsertPointTy AllocaIP,
ArrayRef<llvm::Value *> ScanVars,
ArrayRef<llvm::Type *> ScanVarsType,
bool IsInclusive);
/// Generator for '#omp critical'
///
/// \param Loc The insert and source location description.