[mlir][OpenMP] - Transform target offloading directives for easier translation to LLVMIR #83966
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-openmp
Author: Pranav Bhandarkar (bhandarkar-pranav)
Changes
This patch adds a pass that transforms `omp.target`, `omp.target_enter_data`, `omp.target_update_data` and `omp.target_exit_data` whenever these operations have the `depend` clause on them.
These operations are transformed by enclosing them inside a new `omp.task` operation. The `depend` clause related arguments are moved to the new `omp.task` operation from the original 'target' operation.
Example:
Input:
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
  "test.foobar"() : ()->()
  omp.terminator
}
Output:
omp.task depend(taskdependout -> %c : memref<?xi32>) {
  omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) {
  ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
    "test.foobar"() : ()->()
    omp.terminator
  }
  omp.terminator
}
The intent is to make it easier to translate to LLVMIR by avoiding the creation of such tasks in the `OMPIRBuilder`.
Full diff: https://github.com/llvm/llvm-project/pull/83966.diff
11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index 419e24a7335361..51ab0f23cd00d5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -23,3 +23,11 @@ mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIROpenMPTypeInterfacesIncGen)
add_dependencies(mlir-generic-headers MLIROpenMPTypeInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenMP)
+mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix OpenMP)
+mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix OpenMP)
+add_public_tablegen_target(MLIROpenMPPassIncGen)
+
+add_mlir_doc(Passes OpenMPPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.h b/mlir/include/mlir/Dialect/OpenMP/Passes.h
new file mode 100644
index 00000000000000..2167c95055d31f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - OpenMP passes entry points -----------------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_PASSES_H
+#define MLIR_DIALECT_OPENMP_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+std::unique_ptr<Pass> createOpenMPTaskBasedTargetPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+namespace omp {
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenMP/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Passes.td
new file mode 100644
index 00000000000000..4863999db24c56
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.td
@@ -0,0 +1,54 @@
+//===-- Passes.td - OpenMP pass definition file -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_PASSES
+#define MLIR_DIALECT_OPENMP_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
+ let summary = "Nest certain instances of mlir::omp::TargetOp inside mlir::omp::TaskOp";
+
+ let constructor = "mlir::createOpenMPTaskBasedTargetPass()";
+
+ let description = [{
+ This pass transforms `omp.target`, `omp.target_enter_data`,
+ `omp.target_update_data` and `omp.target_exit_data` whenever these operations
+ have the `depend` clause on them.
+
+ These operations are transformed by enclosing them inside a new `omp.task`
+ operation. The `depend` clause related arguments are moved to the new `omp.task`
+ operation from the original 'target' operation.
+
+ Example:
+ Input:
+ ```mlir
+ omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+ ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+ "test.foobar"() : ()->()
+ omp.terminator
+ }
+ ```
+ Output:
+ ```mlir
+ omp.task depend(taskdependout -> %c : memref<?xi32>) {
+ omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) {
+ ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+ "test.foobar"() : ()->()
+ omp.terminator
+ }
+ omp.terminator
+ }
+ ```
+ The intent is to make it easier to translate to LLVMIR by avoiding the
+ creation of such tasks in the OMPIRBuilder.
+ }];
+
+ let dependentDialects = ["omp::OpenMPDialect"];
+}
+#endif
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5d90c197a6cced..902ab8f4c4fd1b 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -36,6 +36,7 @@
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/OpenMP/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
memref::registerMemRefPasses();
mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
+ omp::registerOpenMPPasses();
registerSCFPasses();
registerShapePasses();
spirv::registerSPIRVPasses();
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index 58b8739043f9df..439f0093cc3d26 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -223,6 +223,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIOpenMP
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIROpenMPDialect
+ MLIROpenMPTransforms
)
add_mlir_upstream_c_api_library(MLIRCAPIPDL
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 40b4837484a136..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,17 +1,2 @@
-add_mlir_dialect_library(MLIROpenMPDialect
- IR/OpenMPDialect.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
-
- DEPENDS
- MLIROpenMPOpsIncGen
- MLIROpenMPOpsInterfacesIncGen
- MLIROpenMPTypeInterfacesIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRLLVMDialect
- MLIRFuncDialect
- MLIROpenACCMPCommon
- )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..def53387255653
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIROpenMPDialect
+ OpenMPDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+ DEPENDS
+ MLIROpenMPOpsIncGen
+ MLIROpenMPOpsInterfacesIncGen
+ MLIROpenMPTypeInterfacesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRFuncDialect
+ MLIROpenACCMPCommon
+ )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..1a64b5268e0839
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+ OpenMPTaskBasedTarget.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+ DEPENDS
+ MLIROpenMPPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIROpenMPDialect
+ MLIRFuncDialect
+ MLIRIR
+ MLIRPass
+ MLIRTransforms
+)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
new file mode 100644
index 00000000000000..1e3977eb33e754
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -0,0 +1,120 @@
+//===- OpenMPTaskBasedTarget.cpp - Implementation of OpenMPTaskBasedTargetPass
+//---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that transforms certain omp.target.
+// Specifically, an omp.target op that has the depend clause on it is
+// transformed into an omp.task clause with the same depend clause on it.
+// The original omp.target loses its depend clause and is contained in
+// the new task region.
+//
+// omp.target depend(..) {
+// omp.terminator
+//
+// }
+//
+// =>
+//
+// omp.task depend(..) {
+// omp.target {
+// omp.terminator
+// }
+// omp.terminator
+// }
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_OPENMPTASKBASEDTARGET
+#include "mlir/Dialect/OpenMP/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::omp;
+
+#define DEBUG_TYPE "openmp-task-based-target"
+
+namespace {
+
+struct OpenMPTaskBasedTargetPass
+ : public impl::OpenMPTaskBasedTargetBase<OpenMPTaskBasedTargetPass> {
+
+ void runOnOperation() override;
+};
+template <typename OpTy>
+class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+
+ // Only match a target op with a 'depend' clause on it.
+ if (op.getDependVars().empty()) {
+ return rewriter.notifyMatchFailure(op, "depend clause not found on op");
+ }
+
+ // Step 1: Create a new task op and tack on the dependency from the 'depend'
+ // clause on it.
+ omp::TaskOp taskOp = rewriter.create<omp::TaskOp>(
+ op.getLoc(), /*if_expr*/ Value(),
+ /*final_expr*/ Value(),
+ /*untied*/ UnitAttr(),
+ /*mergeable*/ UnitAttr(),
+ /*in_reduction_vars*/ ValueRange(),
+ /*in_reductions*/ nullptr,
+ /*priority*/ Value(), op.getDepends().value(), op.getDependVars(),
+ /*allocate_vars*/ ValueRange(),
+ /*allocators_vars*/ ValueRange());
+ Block *block = rewriter.createBlock(&taskOp.getRegion());
+ rewriter.setInsertionPointToEnd(block);
+ // Step 2: Clone and put the entire target op inside the newly created
+ // task's region.
+ Operation *clonedTargetOperation = rewriter.clone(*op.getOperation());
+ rewriter.create<mlir::omp::TerminatorOp>(op.getLoc());
+
+ // Step 3: Remove the dependency information from the clone target op.
+ OpTy clonedTargetOp = llvm::dyn_cast<OpTy>(clonedTargetOperation);
+ if (clonedTargetOp) {
+ clonedTargetOp.removeDependsAttr();
+ clonedTargetOp.getDependVarsMutable().clear();
+ }
+ // Step 4: Erase the original target op
+ rewriter.eraseOp(op.getOperation());
+ return success();
+ }
+};
+} // namespace
+static void
+populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>,
+ OmpTaskBasedTargetRewritePattern<omp::EnterDataOp>,
+ OmpTaskBasedTargetRewritePattern<omp::UpdateDataOp>,
+ OmpTaskBasedTargetRewritePattern<omp::ExitDataOp>>(
+ patterns.getContext());
+}
+
+void OpenMPTaskBasedTargetPass::runOnOperation() {
+ Operation *op = getOperation();
+
+ RewritePatternSet patterns(op->getContext());
+ populateOmpTaskBasedTargetRewritePatterns(patterns);
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ signalPassFailure();
+}
+std::unique_ptr<Pass> mlir::createOpenMPTaskBasedTargetPass() {
+ return std::make_unique<OpenMPTaskBasedTargetPass>();
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index fd1de274da60e8..1b6fa5ebd83158 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -685,7 +685,36 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
return bodyGenStatus;
}
-
+template <typename T>
+static void buildDependData(T taskOrTargetop,
+ SmallVector<llvm::OpenMPIRBuilder::DependData> &dds,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ // std::optional<ArrayAttr> depends,
+ // OperandRange &dependVars,
+ if (taskOrTargetop.getDependVars().empty())
+ return;
+ std::optional<ArrayAttr> depends = taskOrTargetop.getDepends();
+ const OperandRange &dependVars = taskOrTargetop.getDependVars();
+ for (auto dep : llvm::zip(dependVars, depends->getValue())) {
+ llvm::omp::RTLDependenceKindTy type;
+ switch (
+ cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
+ case mlir::omp::ClauseTaskDepend::taskdependin:
+ type = llvm::omp::RTLDependenceKindTy::DepIn;
+ break;
+ // The OpenMP runtime requires that the codegen for 'depend' clause for
+ // 'out' dependency kind must be the same as codegen for 'depend' clause
+ // with 'inout' dependency.
+ case mlir::omp::ClauseTaskDepend::taskdependout:
+ case mlir::omp::ClauseTaskDepend::taskdependinout:
+ type = llvm::omp::RTLDependenceKindTy::DepInOut;
+ break;
+ };
+ llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
+ llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+ dds.emplace_back(dd);
+ }
+}
// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
@@ -748,28 +777,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
};
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
- if (!taskOp.getDependVars().empty() && taskOp.getDepends()) {
- for (auto dep :
- llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
- llvm::omp::RTLDependenceKindTy type;
- switch (
- cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
- case mlir::omp::ClauseTaskDepend::taskdependin:
- type = llvm::omp::RTLDependenceKindTy::DepIn;
- break;
- // The OpenMP runtime requires that the codegen for 'depend' clause for
- // 'out' dependency kind must be the same as codegen for 'depend' clause
- // with 'inout' dependency.
- case mlir::omp::ClauseTaskDepend::taskdependout:
- case mlir::omp::ClauseTaskDepend::taskdependinout:
- type = llvm::omp::RTLDependenceKindTy::DepInOut;
- break;
- };
- llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
- llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
- dds.emplace_back(dd);
- }
- }
+ buildDependData(taskOp, dds, moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
diff --git a/mlir/test/Dialect/OpenMP/task-based-target.mlir b/mlir/test/Dialect/OpenMP/task-based-target.mlir
new file mode 100644
index 00000000000000..26cc493047e19b
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/task-based-target.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -openmp-task-based-target -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @omp_target_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+ // CHECK: omp.task depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ // CHECK: omp.target {
+ omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ // CHECK: omp.terminator
+ omp.terminator
+ } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
+ return
+}
+// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
+// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
+func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
+// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
+ %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+
+ // Do some work on the host that writes to 'a'
+ omp.task depend(taskdependout -> %a : memref<?xi32>) {
+ "test.foo"(%a) : (memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Then map that over to the target
+ // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ // CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>)
+ omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)
+
+ // Compute 'b' on the target and copy it back
+ // CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
+ omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
+ ^bb0(%arg0: memref<?xi32>) :
+ "test.foo"(%arg0) : (memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Update 'a' on the host using 'b'
+ omp.task depend(taskdependout -> %a: memref<?xi32>){
+ "test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
+ }
+
+ // Copy the updated 'a' onto the target
+ // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>)
+ omp.target_update_data motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+
+ // Compute 'c' on the target and copy it back
+ // CHECK:[[MAP3:%.*]] = omp.map_info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ // CHECK: omp.task depend(taskdependout -> [[ARG2]] : memref<?xi32>)
+ // CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, [[MAP3]] -> {{%.*}} : memref<?xi32>, memref<?xi32>) {
+ omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+ ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+ "test.foobar"() : ()->()
+ omp.terminator
+ }
+ // CHECK: omp.task depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
+ // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>)
+ omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
+ return
+}
Is there anything preventing this from being done as part of PFT to MLIR lowering rather than a pass? It should be possible to make the In my opinion, that would be less expensive than adding a full pass, and in principle it shouldn't create any big compromises that I can think of. Do you agree with that? I don't mind this approach either, if it's generally preferred by others.
Thank you for your reply @skatrak.
I felt at the time that these were reasons enough for paying the price of going over the IR in a new pass altogether. Having said that, it isn't a very very strong preference at all and I am happy to change my approach should that be required. How about I leave this open for others to comment on, while I test this with an end-to-end testcase by invoking this pass from within flang?
Just a few small comments. I like the pass approach better than handling this in `OpenMP.cpp` tbh :).
@@ -0,0 +1,54 @@
//===-- Passes.td - OpenMP pass definition file -------------*- tablegen -*-===//
Previous OpenMP passes were added in https://github.com/llvm/llvm-project/blob/main/flang/include/flang/Optimizer/Transforms/Passes.td#L321.
I agree that OpenMP passes like this one and the ones linked above can be separated into their own file, since they are not flang specific. Still, I think for consistency it would be better to use the already existing file, to have a more centralized view of all OpenMP-related passes.
Just a little background on why the passes are there, from what I recall at least (more for information to help make informed decisions, or for someone else to chime in with a preference): the `OMPMarkDeclareTargetPass`/`OMPFunctionFiltering` passes are there because they were initially in an experimental phase (alongside another, now erased, pass) to verify whether we'd keep them in the offloading flow and to see how they'd interact with each other over time, as their behavior is interlinked. Not sure when they'd best be moved; it's coming up to their anniversary at this point, but it's something that can wait in any case.
The `OMPDescriptorMapInfoGenPass` is very much Fortran specific at the moment. It primarily handles FIR `BoxType`s, so it's not feasible at the moment (or perhaps ever) to move it to the OpenMP dialect, as it has FIR dependencies and FIR is an external dialect to the MLIR project.
Thanks for the info @agozillon. @ergawy - Like you said, I kept this in mlir only because it isn't really flang specific and I anticipated that if I put it in flang someone might call me out for it. OTOH keeping it in flang would have made testcase generation easier, as I'd have been able to load the `fir` dialect into the MLIR context. Now I had to rely on `fir-opt` to convert all `fir` ops to `llvm` dialect ops, but `fir-opt` is broken in a weird way such that I couldn't run `fir-to-llvm-ir` lowering to rid my testcases of `fir` ops. (Aside: I hope to find some time to chase down the `fir-opt` problem.)
// Only match a target op with a 'depend' clause on it.
if (op.getDependVars().empty()) {
return rewriter.notifyMatchFailure(op, "depend clause not found on op");
Why not use a `ConversionTarget` and set it up with legality information instead? https://github.com/llvm/llvm-project/blob/main/mlir/docs/DialectConversion.md#conversion-target
I did consider it, and correct me if I am wrong, but a `ConversionPattern` is needed only if types need to be legalized as well, correct? Since I didn't need that, I chose to use an `OpRewritePattern`. Happy to change it if there is a strong preference. BTW, thanks for pointing me to your 'do concurrent' PR, that helped me start up quickly.
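As an illustration of why the `depend` guard matters when a pattern is applied repeatedly until nothing changes, here is a minimal standalone sketch. It does not use MLIR: `Node`, `rewriteOnce` and `runToFixpoint` are hypothetical names invented for this example, and the driver loop is only a rough stand-in for a greedy rewrite driver. The point it shows is that the rewritten target no longer carries a depend clause, so the pattern cannot match its own output.

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy operation (not an MLIR type): a name, a flag saying
// whether a depend clause is present, and a nested region of operations.
struct Node {
  std::string name;
  bool hasDepend;
  std::vector<Node> region;
};

// One pattern application. Returns false (match failure) unless the node is
// a target with a depend clause; otherwise wraps it in a task in place.
bool rewriteOnce(Node &n) {
  if (n.name != "omp.target" || !n.hasDepend)
    return false;
  Node inner = std::move(n);
  inner.hasDepend = false;          // the clause moves to the new task
  n = Node{"omp.task", true, {}};
  n.region.push_back(std::move(inner));
  return true;
}

// Apply the pattern to every node, including nested ones, and repeat until a
// full sweep makes no change. Returns the number of successful rewrites.
std::size_t runToFixpoint(std::vector<Node> &ops) {
  std::size_t rewrites = 0;
  bool changed = true;
  while (changed) {
    changed = false;
    for (Node &n : ops) {
      if (rewriteOnce(n)) {
        ++rewrites;
        changed = true;
      }
      std::size_t nested = runToFixpoint(n.region);
      if (nested != 0) {
        rewrites += nested;
        changed = true;
      }
    }
  }
  return rewrites;
}
```

Running `runToFixpoint` on a list holding one target with a depend clause and one without rewrites only the first, and a second run makes no further changes. Without the guard, the nested target would be wrapped again on every sweep and the loop would not terminate.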
It seems like this is similar to what @kparzysz is working on addressing, not quite the same perhaps, but it would be nice if we could produce MLIR directly that has the `omp.task` op. If we do that, would it be possible to remove the `depend` clause from the `omp.target` op?
…exit data as well
…OpenMP/OpenMPToLLVMIRTranslation.cpp
…penMP/OpenMPToLLVMIRTranslation.cpp
Also add the if clause to the newly generated omp.task op that encloses the omp.target op.
ae8a136 to 65c7554
✅ With the latest revision this PR passed the C/C++ code formatter.
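The transformation in the PR description (enclose the target operation in a new task and move the `depend` entries onto that task) can be modeled with a small standalone sketch. This is not the pass implementation and it does not use MLIR: `ToyOp` and `wrapInTask` are hypothetical names made up for illustration, and depend entries are plain strings.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR node (not an MLIR type).
struct ToyOp {
  std::string name;                 // e.g. "omp.target"
  std::vector<std::string> depend;  // e.g. {"taskdependout -> %c"}
  std::vector<ToyOp> region;        // nested operations
};

// Wrap `target` in a new "omp.task" that takes over its depend entries.
// An op without depend entries is returned unchanged.
ToyOp wrapInTask(ToyOp target) {
  if (target.depend.empty())
    return target;
  ToyOp task;
  task.name = "omp.task";
  task.depend = std::move(target.depend);
  target.depend.clear();  // make the moved-from vector explicitly empty
  task.region.push_back(std::move(target));
  return task;
}
```

Calling `wrapInTask` on a toy `omp.target` with one depend entry yields an `omp.task` that owns that entry and contains the target, which now has none. The sketch leaves out what the real pass also has to handle, such as terminators, operand segments and the op's location.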