Skip to content

[mlir][OpenMP] - Transform target offloading directives for easier translation to LLVMIR #83966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

bhandarkar-pranav
Copy link
Contributor

This patch adds a pass that transforms omp.target, omp.target_enter_data,
omp.target_update_data and omp.target_exit_data whenever these operations
have the depend clause on them.

These operations are transformed by enclosing them inside a new omp.task
operation. The depend clause related arguments are moved to the new omp.task
operation from the original 'target' operation.

Example:
Input:

omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
  ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
      "test.foobar"() : ()->()
       omp.terminator
}

Output:

omp.task depend(taskdependout -> %c : memref<?xi32>) {
  omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) {
    ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
        "test.foobar"() : ()->()
         omp.terminator
  }
  omp.terminator
}

The intent is to make it easier to translate to LLVMIR by avoiding the creation of such tasks in the OMPIRBuilder.

@llvmbot
Copy link
Member

llvmbot commented Mar 5, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-flang-openmp

Author: Pranav Bhandarkar (bhandarkar-pranav)

Changes

This patch adds a pass that transforms omp.target, omp.target_enter_data,
omp.target_update_data and omp.target_exit_data whenever these operations
have the depend clause on them.

These operations are transformed by enclosing them inside a new omp.task
operation. The depend clause related arguments are moved to the new omp.task
operation from the original 'target' operation.

Example:
Input:

omp.target map_entries(%map_a -&gt; %arg0, %map_c_from -&gt; %arg1 : memref&lt;?xi32&gt;, memref&lt;?xi32&gt;) depend(taskdependout -&gt; %c : memref&lt;?xi32&gt;) {
  ^bb0(%arg0 : memref&lt;?xi32&gt;, %arg1 : memref&lt;?xi32&gt;) :
      "test.foobar"() : ()-&gt;()
       omp.terminator
}

Output:

omp.task depend(taskdependout -&gt; %c : memref&lt;?xi32&gt;) {
  omp.target map_entries(%map_a -&gt; %arg0, %map_c_from -&gt; %arg1 : memref&lt;?xi32&gt;, memref&lt;?xi32&gt;) {
    ^bb0(%arg0 : memref&lt;?xi32&gt;, %arg1 : memref&lt;?xi32&gt;) :
        "test.foobar"() : ()-&gt;()
         omp.terminator
  }
  omp.terminator
}

The intent is to make it easier to translate to LLVMIR by avoiding the creation of such tasks in the OMPIRBuilder.


Full diff: https://github.com/llvm/llvm-project/pull/83966.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt (+8)
  • (added) mlir/include/mlir/Dialect/OpenMP/Passes.h (+35)
  • (added) mlir/include/mlir/Dialect/OpenMP/Passes.td (+54)
  • (modified) mlir/include/mlir/InitAllPasses.h (+2)
  • (modified) mlir/lib/CAPI/Dialect/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/OpenMP/CMakeLists.txt (+2-17)
  • (added) mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt (+17)
  • (added) mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt (+16)
  • (added) mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp (+120)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+31-23)
  • (added) mlir/test/Dialect/OpenMP/task-based-target.mlir (+67)
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index 419e24a7335361..51ab0f23cd00d5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -23,3 +23,11 @@ mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIROpenMPTypeInterfacesIncGen)
 add_dependencies(mlir-generic-headers MLIROpenMPTypeInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenMP)
+mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix OpenMP)
+mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix OpenMP)
+add_public_tablegen_target(MLIROpenMPPassIncGen)
+
+add_mlir_doc(Passes OpenMPPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.h b/mlir/include/mlir/Dialect/OpenMP/Passes.h
new file mode 100644
index 00000000000000..2167c95055d31f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - OpenMP passes entry points -----------------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_PASSES_H
+#define MLIR_DIALECT_OPENMP_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+std::unique_ptr<Pass> createOpenMPTaskBasedTargetPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+namespace omp {
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenMP/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/OpenMP/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Passes.td
new file mode 100644
index 00000000000000..4863999db24c56
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Passes.td
@@ -0,0 +1,54 @@
+//===-- Passes.td - OpenMP pass definition file -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_PASSES
+#define MLIR_DIALECT_OPENMP_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
+  let summary = "Nest certain instances of mlir::omp::TargetOp inside mlir::omp::TaskOp";
+
+  let constructor = "mlir::createOpenMPTaskBasedTargetPass()";
+
+  let description = [{
+  This pass transforms `omp.target`, `omp.target_enter_data`,
+  `omp.target_update_data` and `omp.target_exit_data` whenever these operations
+  have the `depend` clause on them.
+
+  These operations are transformed by enclosing them inside a new `omp.task`
+  operation. The `depend` clause related arguments are moved to the new `omp.task`
+  operation from the original 'target' operation.
+
+  Example:
+  Input:
+  ```mlir
+  omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+    ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+        "test.foobar"() : ()->()
+         omp.terminator
+  }
+  ```
+  Output:
+  ```mlir
+  omp.task depend(taskdependout -> %c : memref<?xi32>) {
+    omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) {
+      ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+          "test.foobar"() : ()->()
+           omp.terminator
+    }
+    omp.terminator
+  }
+  ```
+  The intent is to make it easier to translate to LLVMIR by avoiding the
+  creation of such tasks in the OMPIRBuilder.
+  }];
+
+  let dependentDialects = ["omp::OpenMPDialect"];
+}
+#endif
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5d90c197a6cced..902ab8f4c4fd1b 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -36,6 +36,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/OpenMP/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
   memref::registerMemRefPasses();
   mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
+  omp::registerOpenMPPasses();
   registerSCFPasses();
   registerShapePasses();
   spirv::registerSPIRVPasses();
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index 58b8739043f9df..439f0093cc3d26 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -223,6 +223,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIOpenMP
   LINK_LIBS PUBLIC
   MLIRCAPIIR
   MLIROpenMPDialect
+  MLIROpenMPTransforms
 )
 
 add_mlir_upstream_c_api_library(MLIRCAPIPDL
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 40b4837484a136..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,17 +1,2 @@
-add_mlir_dialect_library(MLIROpenMPDialect
-  IR/OpenMPDialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
-
-  DEPENDS
-  MLIROpenMPOpsIncGen
-  MLIROpenMPOpsInterfacesIncGen
-  MLIROpenMPTypeInterfacesIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMDialect
-  MLIRFuncDialect
-  MLIROpenACCMPCommon
-  )
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..def53387255653
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIROpenMPDialect
+  OpenMPDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+  DEPENDS
+  MLIROpenMPOpsIncGen
+  MLIROpenMPOpsInterfacesIncGen
+  MLIROpenMPTypeInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRFuncDialect
+  MLIROpenACCMPCommon
+  )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..1a64b5268e0839
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+  OpenMPTaskBasedTarget.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+  DEPENDS
+  MLIROpenMPPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIROpenMPDialect
+  MLIRFuncDialect
+  MLIRIR
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
new file mode 100644
index 00000000000000..1e3977eb33e754
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
@@ -0,0 +1,120 @@
+//===- OpenMPTaskBasedTarget.cpp - Implementation of OpenMPTaskBasedTargetPass
+//---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that transforms certain omp.target.
+// Specifically, an omp.target op that has the depend clause on it is
+// transformed into an omp.task clause with the same depend clause on it.
+// The original omp.target loses its depend clause and is contained in
+// the new task region.
+//
+// omp.target depend(..) {
+//  omp.terminator
+//
+// }
+//
+// =>
+//
+// omp.task depend(..) {
+//   omp.target {
+//     omp.terminator
+//   }
+//   omp.terminator
+// }
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_OPENMPTASKBASEDTARGET
+#include "mlir/Dialect/OpenMP/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::omp;
+
+#define DEBUG_TYPE "openmp-task-based-target"
+
+namespace {
+
+struct OpenMPTaskBasedTargetPass
+    : public impl::OpenMPTaskBasedTargetBase<OpenMPTaskBasedTargetPass> {
+
+  void runOnOperation() override;
+};
+template <typename OpTy>
+class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+
+    // Only match a target op with a  'depend' clause on it.
+    if (op.getDependVars().empty()) {
+      return rewriter.notifyMatchFailure(op, "depend clause not found on op");
+    }
+
+    // Step 1: Create a new task op and tack on the dependency from the 'depend'
+    // clause on it.
+    omp::TaskOp taskOp = rewriter.create<omp::TaskOp>(
+        op.getLoc(), /*if_expr*/ Value(),
+        /*final_expr*/ Value(),
+        /*untied*/ UnitAttr(),
+        /*mergeable*/ UnitAttr(),
+        /*in_reduction_vars*/ ValueRange(),
+        /*in_reductions*/ nullptr,
+        /*priority*/ Value(), op.getDepends().value(), op.getDependVars(),
+        /*allocate_vars*/ ValueRange(),
+        /*allocate_vars*/ ValueRange());
+    Block *block = rewriter.createBlock(&taskOp.getRegion());
+    rewriter.setInsertionPointToEnd(block);
+    // Step 2: Clone and put the entire target op inside the newly created
+    // task's region.
+    Operation *clonedTargetOperation = rewriter.clone(*op.getOperation());
+    rewriter.create<mlir::omp::TerminatorOp>(op.getLoc());
+
+    // Step 3: Remove the dependency information from the clone target op.
+    OpTy clonedTargetOp = llvm::dyn_cast<OpTy>(clonedTargetOperation);
+    if (clonedTargetOp) {
+      clonedTargetOp.removeDependsAttr();
+      clonedTargetOp.getDependVarsMutable().clear();
+    }
+    // Step 4: Erase the original target op
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+} // namespace
+static void
+populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>,
+               OmpTaskBasedTargetRewritePattern<omp::EnterDataOp>,
+               OmpTaskBasedTargetRewritePattern<omp::UpdateDataOp>,
+               OmpTaskBasedTargetRewritePattern<omp::ExitDataOp>>(
+      patterns.getContext());
+}
+
+void OpenMPTaskBasedTargetPass::runOnOperation() {
+  Operation *op = getOperation();
+
+  RewritePatternSet patterns(op->getContext());
+  populateOmpTaskBasedTargetRewritePatterns(patterns);
+
+  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+    signalPassFailure();
+}
+std::unique_ptr<Pass> mlir::createOpenMPTaskBasedTargetPass() {
+  return std::make_unique<OpenMPTaskBasedTargetPass>();
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index fd1de274da60e8..1b6fa5ebd83158 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -685,7 +685,36 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
       ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
   return bodyGenStatus;
 }
-
+template <typename T>
+static void buildDependData(T taskOrTargetop,
+                            SmallVector<llvm::OpenMPIRBuilder::DependData> &dds,
+                            LLVM::ModuleTranslation &moduleTranslation) {
+  // std::optional<ArrayAttr> depends,
+  //                  OperandRange &dependVars,
+  if (taskOrTargetop.getDependVars().empty())
+    return;
+  std::optional<ArrayAttr> depends = taskOrTargetop.getDepends();
+  const OperandRange &dependVars = taskOrTargetop.getDependVars();
+  for (auto dep : llvm::zip(dependVars, depends->getValue())) {
+    llvm::omp::RTLDependenceKindTy type;
+    switch (
+        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
+    case mlir::omp::ClauseTaskDepend::taskdependin:
+      type = llvm::omp::RTLDependenceKindTy::DepIn;
+      break;
+      // The OpenMP runtime requires that the codegen for 'depend' clause for
+      // 'out' dependency kind must be the same as codegen for 'depend' clause
+      // with 'inout' dependency.
+    case mlir::omp::ClauseTaskDepend::taskdependout:
+    case mlir::omp::ClauseTaskDepend::taskdependinout:
+      type = llvm::omp::RTLDependenceKindTy::DepInOut;
+      break;
+    };
+    llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
+    llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
+    dds.emplace_back(dd);
+  }
+}
 // Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
 static LogicalResult
 convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
@@ -748,28 +777,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   };
 
   SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
-  if (!taskOp.getDependVars().empty() && taskOp.getDepends()) {
-    for (auto dep :
-         llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
-      llvm::omp::RTLDependenceKindTy type;
-      switch (
-          cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
-      case mlir::omp::ClauseTaskDepend::taskdependin:
-        type = llvm::omp::RTLDependenceKindTy::DepIn;
-        break;
-      // The OpenMP runtime requires that the codegen for 'depend' clause for
-      // 'out' dependency kind must be the same as codegen for 'depend' clause
-      // with 'inout' dependency.
-      case mlir::omp::ClauseTaskDepend::taskdependout:
-      case mlir::omp::ClauseTaskDepend::taskdependinout:
-        type = llvm::omp::RTLDependenceKindTy::DepInOut;
-        break;
-      };
-      llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
-      llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
-      dds.emplace_back(dd);
-    }
-  }
+  buildDependData(taskOp, dds, moduleTranslation);
 
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
diff --git a/mlir/test/Dialect/OpenMP/task-based-target.mlir b/mlir/test/Dialect/OpenMP/task-based-target.mlir
new file mode 100644
index 00000000000000..26cc493047e19b
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/task-based-target.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -openmp-task-based-target -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @omp_target_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+  // CHECK: omp.task depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+  // CHECK: omp.target {
+  omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+    // CHECK: omp.terminator
+    omp.terminator
+  } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
+  return
+}
+// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
+// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
+func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
+// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
+  %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+  %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+
+  // Do some work on the host that writes to 'a'
+  omp.task depend(taskdependout -> %a : memref<?xi32>) {
+    "test.foo"(%a) : (memref<?xi32>) -> ()
+    omp.terminator
+  }
+
+  // Then map that over to the target
+  // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+  // CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>)
+  omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin ->  %a: memref<?xi32>)
+
+  // Compute 'b' on the target and copy it back
+  // CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
+  omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>) :
+      "test.foo"(%arg0) : (memref<?xi32>) -> ()
+      omp.terminator
+  }
+
+  // Update 'a' on the host using 'b'
+  omp.task depend(taskdependout -> %a: memref<?xi32>){
+    "test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
+  }
+
+  // Copy the updated 'a' onto the target
+  // CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+  // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>)
+  omp.target_update_data motion_entries(%map_a :  memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+
+  // Compute 'c' on the target and copy it back
+  // CHECK:[[MAP3:%.*]] = omp.map_info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  // CHECK: omp.task depend(taskdependout -> [[ARG2]] : memref<?xi32>)
+  // CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, [[MAP3]] -> {{%.*}} : memref<?xi32>, memref<?xi32>) {
+  omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+  ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+    "test.foobar"() : ()->()
+    omp.terminator
+  }
+  // CHECK: omp.task depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
+  // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>)
+  omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
+  return
+}

@skatrak
Copy link
Member

skatrak commented Mar 5, 2024

Is there anything preventing this from being done as part of PFT to MLIR lowering rather than a pass? It should be possible to make the ClauseProcessor::processDepend method create the outer omp.task with the corresponding DEPEND clause, and then have a simpler ClauseProcessor::processTaskDepend for the case when it's already attached to the TASK construct (so no new omp.task is created then). omp.target* ops wouldn't need any DEPEND-related arguments at that point.

In my opinion, that would be less expensive than adding a full pass, and in principle it shouldn't create any big compromises that I can think of. Do you agree with that? I don't mind this approach either, if it's generally preferred by others.

@bhandarkar-pranav
Copy link
Contributor Author

Is there anything preventing this from being done as part of PFT to MLIR lowering rather than a pass? It should be possible to make the ClauseProcessor::processDepend method create the outer omp.task with the corresponding DEPEND clause, and then have a simpler ClauseProcessor::processTaskDepend for the case when it's already attached to the TASK construct (so no new omp.task is created then). omp.target* ops wouldn't need any DEPEND-related arguments at that point.

In my opinion, that would be less expensive than adding a full pass, and in principle it shouldn't create any big compromises that I can think of. Do you agree with that? I don't mind this approach either, if it's generally preferred by others.

Thank you for your reply @skatrak.
I don't think there is anything preventing it from being done in flang (PFT->MLIR lowering) and in fact, @ergawy had even suggested this in my discussion with him. However, my reasons for doing it this way largely boil down to preference for the following reasons (in no particular order)

  • Keeping it in MLIR allows other frontends to use this because for any frontend to realize offloading in LLVMIR, it'll have to go the route of creating an outer enclosing task
  • This way, arguably, keeps the interface of PFT->MLIR lowering cleaner - There will only be one process method handling depend -> processDepend and all ops using the depend clause will have the same interface (aka processDepend)
  • It separates the concerns of lowering (PFT->MLIR) and codegen (this transformation is needed essentially for codegen)

I felt at the time that these were reasons enough for paying the price of going over the IR in a new pass altogether. Having said that, it isn't a very very strong preference at all and I am happy to change my approach should that be required. How about I leave this open for others to comment on, while I test this with an end-to-end testcase by invoking this pass from within flang?

Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few small comments. I like the pass approach better than handling this OpenMP.cpp tbh :).

@@ -0,0 +1,54 @@
//===-- Passes.td - OpenMP pass definition file -------------*- tablegen -*-===//
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous OpenMP passes were added in https://github.com/llvm/llvm-project/blob/main/flang/include/flang/Optimizer/Transforms/Passes.td#L321.

Even though I agree that OpenMP passes like this one and the ones linked above can be separated into their own file since they are not flang specific, but I think for consistency it would be better to use the already existing file to have a more centralized view of all OpenMP-related passes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a little background on why the passes are there from what I recall at least (more for information to help make informed decisions or for someone else to chime in with a preference), the OMPMarkDeclareTargetPass/OMPFunctionFiltering are there as they were initially in an experimental phase (alongside another now erased pass) to verify if we'd keep them in the offloading flow and see how they'd interact with each other overtime as they are interlinked behavior. Not sure when they'd best to be moved, it's coming up to their anniversary at this point, but it's something that can wait in any case.

The OMPDescriptorMapInfoGenPass is very much Fortran specific at the moment, it primarily handles FIR BoxType's, so it's not feasible at the moment (or perhaps ever) to move to the OpenMP dialect as it has FIR dependencies and FIR is an external dialect to the MLIR project.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the info @agozillon. @ergawy - Like you said, I kept this in mlir only because it isnt really flang specific and I anticipated that if I put it in flang someone might call me out for it. OTOH keeping it in flang would have made testcase generation easier as I'd have been able to load the fir dialect into the MLIR context. Now I had to rely on fir-opt to convert all fir ops to llvm dialect ops, but fir-opt is broken in a weird way such that I couldn't run fir-to-llvm-ir lowering to rid my testcases of fir ops. (Aside: I hope to find some time to chase down the fir-opt problem)


// Only match a target op with a 'depend' clause on it.
if (op.getDependVars().empty()) {
return rewriter.notifyMatchFailure(op, "depend clause not found on op");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use a ConversionTarget and setup it with legality information instead? https://github.com/llvm/llvm-project/blob/main/mlir/docs/DialectConversion.md#conversion-target

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did consider it and correct me if I am wrong, but a ConversionPattern is needed only if types need to be legalized as well correct? Since I didn't need that I chose to use an OpRewritePattern. Happy to change it if there is a strong preference. BTW, thanks for pointing me to your 'do concurrent' PR, that helped me start up quickly.

@jsjodin
Copy link
Contributor

jsjodin commented Mar 7, 2024

Is there anything preventing this from being done as part of PFT to MLIR lowering rather than a pass? It should be possible to make the ClauseProcessor::processDepend method create the outer omp.task with the corresponding DEPEND clause, and then have a simpler ClauseProcessor::processTaskDepend for the case when it's already attached to the TASK construct (so no new omp.task is created then). omp.target* ops wouldn't need any DEPEND-related arguments at that point.
In my opinion, that would be less expensive than adding a full pass, and in principle it shouldn't create any big compromises that I can think of. Do you agree with that? I don't mind this approach either, if it's generally preferred by others.

Thank you for your reply @skatrak. I don't think there is anything preventing it from being done in flang (PFT->MLIR lowering) and in fact, @ergawy had even suggested this in my discussion with him. However, my reasons for doing it this way largely boil down to preference for the following reasons (in no particular order)

  • Keeping it in MLIR allows other frontends to use this because for any frontend to realize offloading in LLVMIR, it'll have to go the route of creating an outer enclosing task
  • This way, arguably, keeps the interface of PFT->MLIR lowering cleaner - There will only be one process method handling depend -> processDepend and all ops using the depend clause will have the same interface (aka processDepend)
  • It separates the concerns of lowering (PFT->MLIR) and codegen (this transformation is needed essentially for codegen)

I felt at the time that these were reasons enough for paying the price of going over the IR in a new pass altogether. Having said that, it isn't a very very strong preference at all and I am happy to change my approach should that be required. How about I leave this open for others to comment on, while I test this with an end-to-end testcase by invoking this pass from within flang?

It seems like this is similar to what @kparzysz is working on addressing, not quite the same perhaps, but it would be nice if we could produce MLIR directly that has the omp.task op. If we do that, would it be possible to remove the depend clause from the omp.target op?

Also add the if clause to the newly generated omp.task op that
encloses the omp.target op.
@bhandarkar-pranav bhandarkar-pranav force-pushed the mlir/depend_clause_mlir_to_llvmir branch from ae8a136 to 65c7554 Compare May 2, 2024 04:59
Copy link

github-actions bot commented May 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@clementval clementval removed their request for review August 6, 2024 21:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants