-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][OpenMP] - Transform target offloading directives for easier translation to LLVMIR #83966
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
25c5dda
99819c7
3b97140
3993997
167bb19
5789bfb
e60969a
23b1ea0
cee6f36
8cc7a89
0f71c22
65c7554
47a09c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
//===- Passes.h - OpenMP passes entry points -----------------------*- C++ | ||
//-*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This header file defines prototypes that expose pass constructors. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_OPENMP_PASSES_H | ||
#define MLIR_DIALECT_OPENMP_PASSES_H | ||
|
||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir { | ||
std::unique_ptr<Pass> createOpenMPTaskBasedTargetPass(); | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Registration | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace omp { | ||
|
||
/// Generate the code for registering passes. | ||
#define GEN_PASS_REGISTRATION | ||
#include "mlir/Dialect/OpenMP/Passes.h.inc" | ||
|
||
} // namespace omp | ||
} // namespace mlir | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
//===-- Passes.td - OpenMP pass definition file -------------*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_OPENMP_PASSES | ||
#define MLIR_DIALECT_OPENMP_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> { | ||
let summary = "Nest certain instances of mlir::omp::TargetOp inside mlir::omp::TaskOp"; | ||
|
||
let constructor = "mlir::createOpenMPTaskBasedTargetPass()"; | ||
|
||
let description = [{ | ||
This pass transforms `omp.target`, `omp.target_enter_data`, | ||
`omp.target_update_data` and `omp.target_exit_data` whenever these operations | ||
have the `depend` clause on them. | ||
|
||
These operations are transformed by enclosing them inside a new `omp.task` | ||
operation. The `depend` clause related arguments are moved to the new `omp.task` | ||
operation from the original 'target' operation. | ||
|
||
Example: | ||
Input: | ||
```mlir | ||
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) { | ||
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) : | ||
"test.foobar"() : ()->() | ||
omp.terminator | ||
} | ||
``` | ||
Output: | ||
```mlir | ||
omp.task depend(taskdependout -> %c : memref<?xi32>) { | ||
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) { | ||
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) : | ||
"test.foobar"() : ()->() | ||
omp.terminator | ||
} | ||
omp.terminator | ||
} | ||
``` | ||
The intent is to make it easier to translate to LLVMIR by avoiding the | ||
creation of such tasks in the OMPIRBuilder. | ||
}]; | ||
|
||
let dependentDialects = ["omp::OpenMPDialect"]; | ||
} | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,2 @@ | ||
add_mlir_dialect_library(MLIROpenMPDialect | ||
IR/OpenMPDialect.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP | ||
|
||
DEPENDS | ||
omp_gen | ||
MLIROpenMPOpsIncGen | ||
MLIROpenMPOpsInterfacesIncGen | ||
MLIROpenMPTypeInterfacesIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRLLVMDialect | ||
MLIRFuncDialect | ||
MLIROpenACCMPCommon | ||
) | ||
add_subdirectory(IR) | ||
add_subdirectory(Transforms) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
add_mlir_dialect_library(MLIROpenMPDialect | ||
OpenMPDialect.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP | ||
|
||
DEPENDS | ||
omp_gen | ||
MLIROpenMPOpsIncGen | ||
MLIROpenMPOpsInterfacesIncGen | ||
MLIROpenMPTypeInterfacesIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRLLVMDialect | ||
MLIRFuncDialect | ||
MLIROpenACCMPCommon | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
add_mlir_dialect_library(MLIROpenMPTransforms | ||
OpenMPTaskBasedTarget.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP | ||
|
||
DEPENDS | ||
MLIROpenMPPassIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIROpenMPDialect | ||
MLIRArithDialect | ||
MLIRFuncDialect | ||
MLIRIR | ||
MLIRPass | ||
MLIRTransforms | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
//===- OpenMPTaskBasedTarget.cpp - Implementation of OpenMPTaskBasedTargetPass | ||
//---===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements a pass that transforms certain omp.target. | ||
// Specifically, an omp.target op that has the depend clause on it is | ||
// transformed into an omp.task clause with the same depend clause on it. | ||
// The original omp.target loses its depend clause and is contained in | ||
// the new task region. | ||
// | ||
// omp.target depend(..) { | ||
// omp.terminator | ||
// | ||
// } | ||
// | ||
// => | ||
// | ||
// omp.task depend(..) { | ||
// omp.target { | ||
// omp.terminator | ||
// } | ||
// omp.terminator | ||
// } | ||
// | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/OpenMP/Passes.h" | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "llvm/Support/Debug.h" | ||
|
||
namespace mlir { | ||
#define GEN_PASS_DEF_OPENMPTASKBASEDTARGET | ||
#include "mlir/Dialect/OpenMP/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
using namespace mlir::omp; | ||
|
||
#define DEBUG_TYPE "openmp-task-based-target" | ||
|
||
namespace { | ||
|
||
struct OpenMPTaskBasedTargetPass | ||
: public impl::OpenMPTaskBasedTargetBase<OpenMPTaskBasedTargetPass> { | ||
|
||
void runOnOperation() override; | ||
}; | ||
template <typename OpTy> | ||
class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> { | ||
public: | ||
using OpRewritePattern<OpTy>::OpRewritePattern; | ||
LogicalResult matchAndRewrite(OpTy op, | ||
PatternRewriter &rewriter) const override { | ||
|
||
// Only match a target op with a 'depend' clause on it. | ||
if (op.getDependVars().empty()) { | ||
return rewriter.notifyMatchFailure(op, "depend clause not found on op"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not use a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did consider it and correct me if I am wrong, but a |
||
} | ||
|
||
// Step 1: Create a new task op and tack on the dependency from the 'depend' | ||
// clause on it. | ||
Type i1Ty = rewriter.getI1Type(); | ||
// mlir::BoolAttr T = rewriter.getBoolAttr(true); | ||
// mlir::BoolAttr F = rewriter.getBoolAttr(false); | ||
omp::TaskOp taskOp = rewriter.create<omp::TaskOp>( | ||
op.getLoc(), | ||
/*if_expr*/ op.getNowait() | ||
? rewriter.create<mlir::arith::ConstantOp>( | ||
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)) | ||
: rewriter.create<mlir::arith::ConstantOp>( | ||
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)), | ||
/*final_expr*/ Value(), | ||
/*untied*/ UnitAttr(), | ||
/*mergeable*/ UnitAttr(), | ||
/*in_reduction_vars*/ ValueRange(), | ||
/*in_reductions*/ nullptr, | ||
/*priority*/ Value(), op.getDepends().value(), op.getDependVars(), | ||
/*allocate_vars*/ ValueRange(), | ||
/*allocate_vars*/ ValueRange()); | ||
Block *block = rewriter.createBlock(&taskOp.getRegion()); | ||
rewriter.setInsertionPointToEnd(block); | ||
// Step 2: Clone and put the entire target op inside the newly created | ||
// task's region. | ||
Operation *clonedTargetOperation = rewriter.clone(*op.getOperation()); | ||
rewriter.create<mlir::omp::TerminatorOp>(op.getLoc()); | ||
|
||
// Step 3: Remove the dependency information from the clone target op. | ||
OpTy clonedTargetOp = llvm::dyn_cast<OpTy>(clonedTargetOperation); | ||
if (clonedTargetOp) { | ||
clonedTargetOp.removeDependsAttr(); | ||
clonedTargetOp.getDependVarsMutable().clear(); | ||
} | ||
// Step 4: Erase the original target op | ||
rewriter.eraseOp(op.getOperation()); | ||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
static void | ||
populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) { | ||
patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>, | ||
OmpTaskBasedTargetRewritePattern<omp::TargetEnterDataOp>, | ||
OmpTaskBasedTargetRewritePattern<omp::TargetUpdateOp>, | ||
OmpTaskBasedTargetRewritePattern<omp::TargetExitDataOp>>( | ||
patterns.getContext()); | ||
} | ||
|
||
void OpenMPTaskBasedTargetPass::runOnOperation() { | ||
Operation *op = getOperation(); | ||
|
||
RewritePatternSet patterns(op->getContext()); | ||
populateOmpTaskBasedTargetRewritePatterns(patterns); | ||
|
||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) | ||
signalPassFailure(); | ||
} | ||
std::unique_ptr<Pass> mlir::createOpenMPTaskBasedTargetPass() { | ||
return std::make_unique<OpenMPTaskBasedTargetPass>(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// RUN: mlir-opt %s -openmp-task-based-target -split-input-file | FileCheck %s | ||
|
||
// CHECK-LABEL: @omp_target_depend | ||
// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) { | ||
func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) { | ||
// CHECK: omp.task if(%false) depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) { | ||
// CHECK: omp.target { | ||
omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) { | ||
// CHECK: omp.terminator | ||
omp.terminator | ||
} {operandSegmentSizes = array<i32: 0,0,0,3,0>} | ||
return | ||
} | ||
// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend | ||
// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) { | ||
func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) { | ||
// CHECK: [[MAP0:%.*]] = omp.map.info | ||
// CHECK-NEXT: [[MAP1:%.*]] = omp.map.info | ||
// CHECK-NEXT: [[MAP2:%.*]] = omp.map.info | ||
%map_a = omp.map.info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> | ||
%map_b = omp.map.info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> | ||
%map_c = omp.map.info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> | ||
|
||
// Do some work on the host that writes to 'a' | ||
omp.task depend(taskdependout -> %a : memref<?xi32>) { | ||
"test.foo"(%a) : (memref<?xi32>) -> () | ||
omp.terminator | ||
} | ||
|
||
// Then map that over to the target | ||
// CHECK: omp.task if(%true) depend(taskdependin -> [[ARG0]] : memref<?xi32>) | ||
// CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>) | ||
omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>) | ||
|
||
// Compute 'b' on the target and copy it back | ||
// CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) { | ||
omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) { | ||
^bb0(%arg0: memref<?xi32>) : | ||
"test.foo"(%arg0) : (memref<?xi32>) -> () | ||
omp.terminator | ||
} | ||
|
||
// Update 'a' on the host using 'b' | ||
omp.task depend(taskdependout -> %a: memref<?xi32>){ | ||
"test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> () | ||
} | ||
|
||
// Copy the updated 'a' onto the target | ||
// CHECK: omp.task if(%true) depend(taskdependin -> [[ARG0]] : memref<?xi32>) | ||
// CHECK: omp.target_update nowait motion_entries([[MAP0]] : memref<?xi32>) | ||
omp.target_update motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait | ||
|
||
// Compute 'c' on the target and copy it back | ||
// CHECK:[[MAP3:%.*]] = omp.map.info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> | ||
%map_c_from = omp.map.info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> | ||
// CHECK: omp.task if(%false) depend(taskdependout -> [[ARG2]] : memref<?xi32>) | ||
// CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, [[MAP3]] -> {{%.*}} : memref<?xi32>, memref<?xi32>) { | ||
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) { | ||
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) : | ||
"test.foobar"() : ()->() | ||
omp.terminator | ||
} | ||
// CHECK: omp.task if(%false) depend(taskdependin -> [[ARG2]] : memref<?xi32>) { | ||
// CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) | ||
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>) | ||
return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previous OpenMP passes were added in https://github.com/llvm/llvm-project/blob/main/flang/include/flang/Optimizer/Transforms/Passes.td#L321.
Even though I agree that OpenMP passes like this one and the ones linked above can be separated into their own file since they are not flang specific, but I think for consistency it would be better to use the already existing file to have a more centralized view of all OpenMP-related passes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a little background on why the passes are there from what I recall at least (more for information to help make informed decisions or for someone else to chime in with a preference), the
OMPMarkDeclareTargetPass
/OMPFunctionFiltering
are there as they were initially in an experimental phase (alongside another now erased pass) to verify if we'd keep them in the offloading flow and see how they'd interact with each other overtime as they are interlinked behavior. Not sure when they'd best to be moved, it's coming up to their anniversary at this point, but it's something that can wait in any case.The
OMPDescriptorMapInfoGenPass
is very much Fortran specific at the moment, it primarily handles FIRBoxType
's, so it's not feasible at the moment (or perhaps ever) to move to the OpenMP dialect as it has FIR dependencies and FIR is an external dialect to the MLIR project.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the info @agozillon. @ergawy - Like you said, I kept this in mlir only because it isnt really flang specific and I anticipated that if I put it in flang someone might call me out for it. OTOH keeping it in flang would have made testcase generation easier as I'd have been able to load the
fir
dialect into the MLIR context. Now I had to rely onfir-opt
to convert allfir
ops tollvm
dialect ops, butfir-opt
is broken in a weird way such that I couldn't runfir-to-llvm-ir
lowering to rid my testcases offir
ops. (Aside: I hope to find some time to chase down thefir-opt
problem)