Skip to content

[mlir][OpenMP] - Transform target offloading directives for easier translation to LLVMIR #83966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,11 @@ mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIROpenMPTypeInterfacesIncGen)
add_dependencies(mlir-generic-headers MLIROpenMPTypeInterfacesIncGen)

# Run mlir-tblgen on Passes.td to generate the OpenMP pass declarations
# (Passes.h.inc) and the pass C API header/implementation fragments, and expose
# them through the public tablegen target MLIROpenMPPassIncGen.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenMP)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix OpenMP)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix OpenMP)
add_public_tablegen_target(MLIROpenMPPassIncGen)

# Generate the pass documentation (OpenMPPasses) from the same Passes.td file.
add_mlir_doc(Passes OpenMPPasses ./ -gen-pass-doc)
35 changes: 35 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===- Passes.h - OpenMP passes entry points --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes that expose pass constructors.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_OPENMP_PASSES_H
#define MLIR_DIALECT_OPENMP_PASSES_H

#include "mlir/Pass/Pass.h"

namespace mlir {
/// Creates an instance of the `openmp-task-based-target` pass declared in
/// Passes.td. The pass nests OpenMP target operations that have a `depend`
/// clause inside a new `omp.task` operation.
std::unique_ptr<Pass> createOpenMPTaskBasedTargetPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

namespace omp {

/// Generate the code for registering passes.
// GEN_PASS_REGISTRATION selects the registration section of the
// tablegen-generated Passes.h.inc included below.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/OpenMP/Passes.h.inc"

} // namespace omp
} // namespace mlir

#endif
54 changes: 54 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===-- Passes.td - OpenMP pass definition file -------------*- tablegen -*-===//
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous OpenMP passes were added in https://github.com/llvm/llvm-project/blob/main/flang/include/flang/Optimizer/Transforms/Passes.td#L321.

I agree that OpenMP passes like this one and the ones linked above could be separated into their own file, since they are not flang-specific, but I think that, for consistency, it would be better to use the already existing file so that there is a more centralized view of all OpenMP-related passes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a little background on why the passes are there, from what I recall at least (more as information to help make informed decisions, or for someone else to chime in with a preference): the OMPMarkDeclareTargetPass/OMPFunctionFiltering passes are there because they were initially in an experimental phase (alongside another, now erased, pass) to verify whether we'd keep them in the offloading flow and to see how they'd interact with each other over time, as their behavior is interlinked. I'm not sure when they would best be moved; it's coming up to their anniversary at this point, but it's something that can wait in any case.

The OMPDescriptorMapInfoGenPass is very much Fortran-specific at the moment: it primarily handles FIR BoxTypes, so it's not feasible at the moment (or perhaps ever) to move it to the OpenMP dialect, as it has FIR dependencies and FIR is a dialect external to the MLIR project.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the info @agozillon. @ergawy - Like you said, I kept this in mlir only because it isn't really flang-specific, and I anticipated that if I put it in flang someone might call me out for it. OTOH, keeping it in flang would have made testcase generation easier, as I'd have been able to load the fir dialect into the MLIR context. Instead, I had to rely on fir-opt to convert all fir ops to llvm dialect ops, but fir-opt is broken in a weird way such that I couldn't run fir-to-llvm-ir lowering to rid my testcases of fir ops. (Aside: I hope to find some time to chase down the fir-opt problem.)

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_OPENMP_PASSES
#define MLIR_DIALECT_OPENMP_PASSES

include "mlir/Pass/PassBase.td"

// Declares the `openmp-task-based-target` pass, which runs on `func::FuncOp`
// and wraps OpenMP target operations that have a `depend` clause in a new
// `omp.task` operation.
def OpenMPTaskBasedTarget : Pass<"openmp-task-based-target", "func::FuncOp"> {
  let summary = "Nest certain instances of mlir::omp::TargetOp inside mlir::omp::TaskOp";

  let constructor = "mlir::createOpenMPTaskBasedTargetPass()";

  let description = [{
    This pass transforms `omp.target`, `omp.target_enter_data`,
    `omp.target_update` and `omp.target_exit_data` whenever these operations
    have the `depend` clause on them.

    These operations are transformed by enclosing them inside a new `omp.task`
    operation. The `depend` clause related arguments are moved to the new `omp.task`
    operation from the original 'target' operation. The new `omp.task` also gets
    an `if` clause fed by an `arith.constant`: `true` when the original operation
    has the `nowait` clause and `false` otherwise.

    Example:
    Input:
    ```mlir
    omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
      ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
        "test.foobar"() : ()->()
        omp.terminator
    }
    ```
    Output:
    ```mlir
    %false = arith.constant false
    omp.task if(%false) depend(taskdependout -> %c : memref<?xi32>) {
      omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) {
        ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
          "test.foobar"() : ()->()
          omp.terminator
      }
      omp.terminator
    }
    ```
    The intent is to make it easier to translate to LLVMIR by avoiding the
    creation of such tasks in the OMPIRBuilder.
  }];

  // The pass creates `arith.constant` operations for the task's `if` clause,
  // so the arith dialect must be loaded before the pass runs, in addition to
  // the OpenMP dialect.
  let dependentDialects = ["arith::ArithDialect", "omp::OpenMPDialect"];
}
#endif
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/Dialect/OpenMP/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
Expand Down Expand Up @@ -83,6 +84,7 @@ inline void registerAllPasses() {
memref::registerMemRefPasses();
mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
omp::registerOpenMPPasses();
registerSCFPasses();
registerShapePasses();
spirv::registerSPIRVPasses();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/CAPI/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIOpenMP
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIROpenMPDialect
MLIROpenMPTransforms
)

add_mlir_upstream_c_api_library(MLIRCAPIPDL
Expand Down
20 changes: 2 additions & 18 deletions mlir/lib/Dialect/OpenMP/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,2 @@
add_mlir_dialect_library(MLIROpenMPDialect
IR/OpenMPDialect.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP

DEPENDS
omp_gen
MLIROpenMPOpsIncGen
MLIROpenMPOpsInterfacesIncGen
MLIROpenMPTypeInterfacesIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
MLIRFuncDialect
MLIROpenACCMPCommon
)
add_subdirectory(IR)
add_subdirectory(Transforms)
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# OpenMP dialect library, built from OpenMPDialect.cpp in this directory.
# DEPENDS lists the tablegen targets that must be generated before compiling;
# LINK_LIBS lists the MLIR libraries this library links against publicly.
add_mlir_dialect_library(MLIROpenMPDialect
OpenMPDialect.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP

DEPENDS
omp_gen
MLIROpenMPOpsIncGen
MLIROpenMPOpsInterfacesIncGen
MLIROpenMPTypeInterfacesIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
MLIRFuncDialect
MLIROpenACCMPCommon
)
17 changes: 17 additions & 0 deletions mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# OpenMP transformation passes library, built from OpenMPTaskBasedTarget.cpp.
# It depends on the generated pass declarations (MLIROpenMPPassIncGen) and
# links against the dialects and infrastructure libraries listed below.
add_mlir_dialect_library(MLIROpenMPTransforms
OpenMPTaskBasedTarget.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP

DEPENDS
MLIROpenMPPassIncGen

LINK_LIBS PUBLIC
MLIROpenMPDialect
MLIRArithDialect
MLIRFuncDialect
MLIRIR
MLIRPass
MLIRTransforms
)
129 changes: 129 additions & 0 deletions mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
//===- OpenMPTaskBasedTarget.cpp - OpenMPTaskBasedTargetPass impl ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that transforms certain OpenMP target
// operations. Specifically, a target op (omp.target, omp.target_enter_data,
// omp.target_update or omp.target_exit_data) that has the depend clause on it
// is nested inside a new omp.task operation with the same depend clause on it.
// The original target op loses its depend clause and is contained in
// the new task region.
//
// omp.target depend(..) {
// omp.terminator
//
// }
//
// =>
//
// omp.task depend(..) {
// omp.target {
// omp.terminator
// }
// omp.terminator
// }
//
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenMP/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_OPENMPTASKBASEDTARGET
#include "mlir/Dialect/OpenMP/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::omp;

#define DEBUG_TYPE "openmp-task-based-target"

namespace {

/// Pass implementation for `openmp-task-based-target`. It derives from the
/// tablegen-generated base class and only overrides the pass entry point.
class OpenMPTaskBasedTargetPass
    : public impl::OpenMPTaskBasedTargetBase<OpenMPTaskBasedTargetPass> {
public:
  /// Applies the task-based-target rewrite patterns to the operation the pass
  /// is scheduled on.
  void runOnOperation() override;
};
/// Rewrite pattern that wraps a target-like OpenMP operation of type `OpTy`
/// that has a `depend` clause in a new `omp.task` operation. The `depend`
/// information moves to the task; a clone of the original operation, stripped
/// of its `depend` information, is placed in the task's region and the
/// original operation is erased.
template <typename OpTy>
class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
/// Matches `op` only when it has depend variables. On a match, replaces it
/// with an `omp.task` that contains a depend-free clone of `op`, and returns
/// success; otherwise reports a match failure.
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {

// Only match a target op with a 'depend' clause on it.
if (op.getDependVars().empty()) {
return rewriter.notifyMatchFailure(op, "depend clause not found on op");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use a ConversionTarget and setup it with legality information instead? https://github.com/llvm/llvm-project/blob/main/mlir/docs/DialectConversion.md#conversion-target

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did consider it and correct me if I am wrong, but a ConversionPattern is needed only if types need to be legalized as well correct? Since I didn't need that I chose to use an OpRewritePattern. Happy to change it if there is a strong preference. BTW, thanks for pointing me to your 'do concurrent' PR, that helped me start up quickly.

}

// Step 1: Create a new task op and tack on the dependency from the 'depend'
// clause on it.
Type i1Ty = rewriter.getI1Type();
// The task's `if` operand is an i1 arith.constant: 1 (true) when the
// original op has `nowait`, 0 (false) otherwise.
// NOTE(review): presumably `if(false)` is used to make the task undeferred
// when `nowait` is absent - confirm against the OpenMP specification.
omp::TaskOp taskOp = rewriter.create<omp::TaskOp>(
op.getLoc(),
/*if_expr*/ op.getNowait()
? rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1))
: rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)),
/*final_expr*/ Value(),
/*untied*/ UnitAttr(),
/*mergeable*/ UnitAttr(),
/*in_reduction_vars*/ ValueRange(),
/*in_reductions*/ nullptr,
/*priority*/ Value(), /*depends*/ op.getDepends().value(),
/*depend_vars*/ op.getDependVars(),
/*allocate_vars*/ ValueRange(),
/*allocators_vars*/ ValueRange());
// Give the task region a single block and build the rest of the rewrite at
// the end of that block.
Block *block = rewriter.createBlock(&taskOp.getRegion());
rewriter.setInsertionPointToEnd(block);
// Step 2: Clone and put the entire target op inside the newly created
// task's region.
Operation *clonedTargetOperation = rewriter.clone(*op.getOperation());
// Terminate the task region after the cloned target op.
rewriter.create<mlir::omp::TerminatorOp>(op.getLoc());

// Step 3: Remove the dependency information from the cloned target op.
// NOTE(review): the clone is made from an `OpTy`, so this dyn_cast is
// expected to always succeed; the null check is purely defensive.
OpTy clonedTargetOp = llvm::dyn_cast<OpTy>(clonedTargetOperation);
if (clonedTargetOp) {
clonedTargetOp.removeDependsAttr();
clonedTargetOp.getDependVarsMutable().clear();
}
// Step 4: Erase the original target op
rewriter.eraseOp(op.getOperation());
return success();
}
};
} // namespace
/// Adds one OmpTaskBasedTargetRewritePattern instantiation to `patterns` for
/// each operation handled by the pass: omp::TargetOp, omp::TargetEnterDataOp,
/// omp::TargetUpdateOp and omp::TargetExitDataOp.
static void
populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  // Register the patterns one at a time, in the same order as before.
  patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>>(context);
  patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetEnterDataOp>>(
      context);
  patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetUpdateOp>>(context);
  patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetExitDataOp>>(
      context);
}

void OpenMPTaskBasedTargetPass::runOnOperation() {
Operation *op = getOperation();

RewritePatternSet patterns(op->getContext());
populateOmpTaskBasedTargetRewritePatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
signalPassFailure();
}
/// Factory for the `openmp-task-based-target` pass; returns a newly allocated
/// OpenMPTaskBasedTargetPass owned by the caller.
std::unique_ptr<Pass> mlir::createOpenMPTaskBasedTargetPass() {
  std::unique_ptr<Pass> pass = std::make_unique<OpenMPTaskBasedTargetPass>();
  return pass;
}
67 changes: 67 additions & 0 deletions mlir/test/Dialect/OpenMP/task-based-target.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: mlir-opt %s -openmp-task-based-target -split-input-file | FileCheck %s

// An `omp.target` with several `depend` entries and no `nowait` is expected to
// be nested inside an `omp.task` that carries the `depend` entries and an
// `if` operand named %false.
// CHECK-LABEL: @omp_target_depend
// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
// CHECK: omp.task if(%false) depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
// CHECK: omp.target {
omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
// CHECK: omp.terminator
omp.terminator
} {operandSegmentSizes = array<i32: 0,0,0,3,0>}
return
}
// Exercises omp.target_enter_data, omp.target_update, omp.target and
// omp.target_exit_data with `depend` clauses. Each is expected to be wrapped in
// an `omp.task` whose `if` operand is %true when the op has `nowait` and
// %false otherwise; the `omp.target` without a `depend` clause is expected to
// stay unwrapped.
// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
// CHECK: [[MAP0:%.*]] = omp.map.info
// CHECK-NEXT: [[MAP1:%.*]] = omp.map.info
// CHECK-NEXT: [[MAP2:%.*]] = omp.map.info
%map_a = omp.map.info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
%map_b = omp.map.info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
%map_c = omp.map.info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>

// Do some work on the host that writes to 'a'
omp.task depend(taskdependout -> %a : memref<?xi32>) {
"test.foo"(%a) : (memref<?xi32>) -> ()
omp.terminator
}

// Then map that over to the target
// CHECK: omp.task if(%true) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
// CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>)
omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)

// Compute 'b' on the target and copy it back
// CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
^bb0(%arg0: memref<?xi32>) :
"test.foo"(%arg0) : (memref<?xi32>) -> ()
omp.terminator
}

// Update 'a' on the host using 'b'
omp.task depend(taskdependout -> %a: memref<?xi32>){
"test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
}

// Copy the updated 'a' onto the target
// CHECK: omp.task if(%true) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
// CHECK: omp.target_update nowait motion_entries([[MAP0]] : memref<?xi32>)
omp.target_update motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait

// Compute 'c' on the target and copy it back
// CHECK:[[MAP3:%.*]] = omp.map.info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
%map_c_from = omp.map.info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
// CHECK: omp.task if(%false) depend(taskdependout -> [[ARG2]] : memref<?xi32>)
// CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, [[MAP3]] -> {{%.*}} : memref<?xi32>, memref<?xi32>) {
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
"test.foobar"() : ()->()
omp.terminator
}
// CHECK: omp.task if(%false) depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
// CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>)
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
return
}
Loading