diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 704faf0ccd856..635604ca33550 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -46,7 +46,8 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
                 "for the target device.";
   let dependentDialects = [
     "mlir::func::FuncDialect",
-    "fir::FIROpsDialect"
+    "fir::FIROpsDialect",
+    "mlir::omp::OpenMPDialect"
   ];
 }
 
diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 9554808824ac3..b600de9702fd4 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -13,13 +13,16 @@
 
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace flangomp {
 #define GEN_PASS_DEF_FUNCTIONFILTERINGPASS
@@ -28,6 +31,104 @@ namespace flangomp {
 
 using namespace mlir;
 
+/// Add an operation to one of the output sets to be later rewritten, based on
+/// whether it is located within the given region.
+template <typename OpTy>
+static void collectRewriteImpl(OpTy op, Region &region,
+                               llvm::SetVector<OpTy> &rewrites,
+                               llvm::SetVector<Operation *> *parentRewrites) {
+  if (rewrites.contains(op))
+    return;
+
+  if (!parentRewrites || region.isAncestor(op->getParentRegion()))
+    rewrites.insert(op);
+  else
+    parentRewrites->insert(op.getOperation());
+}
+
+template <typename OpTy>
+static void collectRewrite(OpTy op, Region &region,
+                           llvm::SetVector<OpTy> &rewrites,
+                           llvm::SetVector<Operation *> *parentRewrites) {
+  collectRewriteImpl(op, region, rewrites, parentRewrites);
+}
+
+/// Add an \c omp.map.info operation and all of its members recursively to one
+/// of the output sets to be later rewritten, based on whether they are located
+/// within the given region.
+///
+/// Dependencies across \c omp.map.info operations are maintained by ensuring
+/// that member maps are added to the output sets before the operations that
+/// use them.
+template <>
+void collectRewrite(omp::MapInfoOp mapOp, Region &region,
+                    llvm::SetVector<omp::MapInfoOp> &rewrites,
+                    llvm::SetVector<Operation *> *parentRewrites) {
+  for (Value member : mapOp.getMembers())
+    collectRewrite(cast<omp::MapInfoOp>(member.getDefiningOp()), region,
+                   rewrites, parentRewrites);
+
+  collectRewriteImpl(mapOp, region, rewrites, parentRewrites);
+}
+
+/// Add the given value to one of the output sets if it should be replaced by a
+/// placeholder when used as an operand that must remain for the device. The
+/// output set used depends on whether the value is defined within the given
+/// region.
+///
+/// Values that are block arguments of \c omp.target_data and \c func.func
+/// operations are skipped, since they will still be available after all
+/// rewrites are completed.
+static void collectRewrite(Value value, Region &region,
+                           llvm::SetVector<Value> &rewrites,
+                           llvm::SetVector<Value> *parentRewrites) {
+  if ((isa<BlockArgument>(value) &&
+       isa<omp::TargetDataOp, func::FuncOp>(
+           cast<BlockArgument>(value).getOwner()->getParentOp())) ||
+      rewrites.contains(value))
+    return;
+
+  if (!parentRewrites) {
+    rewrites.insert(value);
+    return;
+  }
+
+  Region *definingRegion;
+  if (auto blockArg = dyn_cast<BlockArgument>(value))
+    definingRegion = blockArg.getOwner()->getParent();
+  else
+    definingRegion = value.getDefiningOp()->getParentRegion();
+
+  assert(definingRegion && "defining op/block must exist in a region");
+
+  if (region.isAncestor(definingRegion))
+    rewrites.insert(value);
+  else
+    parentRewrites->insert(value);
+}
+
+/// Add operations in \c childRewrites to one of the output sets based on
+/// whether they are located within the given region.
+template <typename OpTy>
+static void
+applyChildRewrites(Region &region,
+                   const llvm::SetVector<Operation *> &childRewrites,
+                   llvm::SetVector<OpTy> &rewrites,
+                   llvm::SetVector<Operation *> *parentRewrites) {
+  for (Operation *rewrite : childRewrites)
+    if (auto op = dyn_cast<OpTy>(*rewrite))
+      collectRewrite(op, region, rewrites, parentRewrites);
+}
+
+/// Add values in \c childRewrites to one of the output sets based on whether
+/// they are defined within the given region.
+static void applyChildRewrites(Region &region,
+                               const llvm::SetVector<Value> &childRewrites,
+                               llvm::SetVector<Value> &rewrites,
+                               llvm::SetVector<Value> *parentRewrites) {
+  for (Value value : childRewrites)
+    collectRewrite(value, region, rewrites, parentRewrites);
+}
+
 namespace {
 class FunctionFilteringPass
     : public flangomp::impl::FunctionFilteringPassBase<FunctionFilteringPass> {
@@ -94,6 +195,12 @@ class FunctionFilteringPass
         funcOp.erase();
         return WalkResult::skip();
       }
+
+      if (failed(rewriteHostRegion(funcOp.getRegion()))) {
+        funcOp.emitOpError() << "could not be rewritten for target device";
+        return WalkResult::interrupt();
+      }
+
       if (declareTargetOp)
         declareTargetOp.setDeclareTarget(declareType,
                                          omp::DeclareTargetCaptureClause::to);
@@ -101,5 +208,346 @@ class FunctionFilteringPass
       return WalkResult::advance();
     });
   }
+
+private:
+  /// Rewrite the given host region, belonging to a function that contains
+  /// \c omp.target operations, to remove host-only operations that are not
+  /// used by device codegen.
+  ///
+  /// It is based on the expected form of the MLIR module as produced by Flang
+  /// lowering and performs the following mutations:
+  ///   - Values returned by the function are replaced with \c fir.undefined.
+  ///   - Operations taking map-like clauses (e.g. \c omp.target,
+  ///     \c omp.target_data, etc.) are moved to the end of the function. If
+  ///     they are nested inside of any other operations, they are hoisted out
+  ///     of them. If the region belongs to \c omp.target_data, these
+  ///     operations are hoisted to its top level, rather than to the parent
+  ///     function.
+  ///   - \c device, \c if and \c depend clauses are removed from these target
+  ///     operations. Values initializing other clauses are either kept or
+  ///     replaced by placeholders, as follows:
+  ///     - Values defined by block arguments are replaced by placeholders
+  ///       only if they are not attached to \c func.func or
+  ///       \c omp.target_data operations. In that case, they are kept
+  ///       unmodified.
+  ///     - \c arith.constant and \c fir.address_of are maintained.
+  ///     - Other values are replaced by a combination of a single-bit
+  ///       \c fir.alloca and an \c fir.convert to the original type of the
+  ///       value.
+  ///       This can be done because the code eventually generated for these
+  ///       operations will be discarded, as they aren't runnable by the
+  ///       target device.
+  ///   - \c omp.map.info operations associated with these target regions are
+  ///     preserved. They are moved above all \c omp.target operations and
+  ///     sorted to satisfy dependencies among them.
+  ///   - \c bounds arguments are removed from \c omp.map.info operations.
+  ///   - \c var_ptr and \c var_ptr_ptr arguments of \c omp.map.info are
+  ///     handled as follows:
+  ///     - \c var_ptr_ptr is expected to be defined by a \c fir.box_offset
+  ///       operation, which is preserved. Otherwise, the pass fails.
+  ///     - \c var_ptr can be defined by an \c hlfir.declare, which is also
+  ///       preserved. Its \c memref argument is replaced by a placeholder or
+  ///       maintained similarly to non-map clauses of target operations
+  ///       described above. If it has \c shape or \c typeparams arguments,
+  ///       they are replaced by applicable constants. \c dummy_scope
+  ///       arguments are discarded.
+  ///   - Every other operation not located inside of an \c omp.target is
+  ///     removed.
+  ///   - Whenever a value or operation that would otherwise be replaced with
+  ///     a placeholder is defined outside of the region being rewritten, it
+  ///     is added to the \c parentOpRewrites or \c parentValRewrites output
+  ///     argument, to be later handled by the caller. This is only intended
+  ///     to properly support nested \c omp.target_data and \c omp.target
+  ///     operations placed inside of \c omp.target_data. When called for the
+  ///     main function, these output arguments must not be set.
+  LogicalResult
+  rewriteHostRegion(Region &region,
+                    llvm::SetVector<Operation *> *parentOpRewrites = nullptr,
+                    llvm::SetVector<Value> *parentValRewrites = nullptr) {
+    // Extract parent op information.
+    auto [funcOp, targetDataOp] = [&region]() {
+      Operation *parent = region.getParentOp();
+      return std::make_tuple(dyn_cast<func::FuncOp>(parent),
+                             dyn_cast<omp::TargetDataOp>(parent));
+    }();
+    assert((bool)funcOp != (bool)targetDataOp &&
+           "region must be defined by either func.func or omp.target_data");
+    assert((bool)parentOpRewrites == (bool)targetDataOp &&
+           (bool)parentValRewrites == (bool)targetDataOp &&
+           "parent rewrites must be passed iff rewriting omp.target_data");
+
+    // Collect operations that have mapping information associated with them.
+    llvm::SmallVector<
+        std::variant<omp::TargetOp, omp::TargetDataOp, omp::TargetEnterDataOp,
+                     omp::TargetExitDataOp, omp::TargetUpdateOp>>
+        targetOps;
+
+    // Sets to store pending rewrites marked by child omp.target_data ops.
+    llvm::SetVector<Operation *> childOpRewrites;
+    llvm::SetVector<Value> childValRewrites;
+    WalkResult result = region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+      // Skip the inside of omp.target regions, since these contain device
+      // code.
+      if (auto targetOp = dyn_cast<omp::TargetOp>(op)) {
+        targetOps.push_back(targetOp);
+        return WalkResult::skip();
+      }
+
+      if (auto targetOp = dyn_cast<omp::TargetDataOp>(op)) {
+        // Recursively rewrite omp.target_data regions as well.
+        if (failed(rewriteHostRegion(targetOp.getRegion(), &childOpRewrites,
+                                     &childValRewrites))) {
+          targetOp.emitOpError() << "rewrite for target device failed";
+          return WalkResult::interrupt();
+        }
+
+        targetOps.push_back(targetOp);
+        return WalkResult::skip();
+      }
+
+      if (auto targetOp = dyn_cast<omp::TargetEnterDataOp>(op))
+        targetOps.push_back(targetOp);
+      else if (auto targetOp = dyn_cast<omp::TargetExitDataOp>(op))
+        targetOps.push_back(targetOp);
+      else if (auto targetOp = dyn_cast<omp::TargetUpdateOp>(op))
+        targetOps.push_back(targetOp);
+
+      return WalkResult::advance();
+    });
+
+    if (result.wasInterrupted())
+      return failure();
+
+    // Make a temporary clone of the parent operation with an empty region,
+    // and update all references to entry block arguments to those of the new
+    // region. Users will later either be moved to the new region or deleted
+    // when the original region is replaced by the new one.
+    OpBuilder builder(&getContext());
+    builder.setInsertionPointAfter(region.getParentOp());
+    Operation *newOp = builder.cloneWithoutRegions(*region.getParentOp());
+    Block &block = newOp->getRegion(0).emplaceBlock();
+
+    llvm::SmallVector<Location> locs;
+    locs.reserve(region.getNumArguments());
+    llvm::transform(region.getArguments(), std::back_inserter(locs),
+                    [](const BlockArgument &arg) { return arg.getLoc(); });
+    block.addArguments(region.getArgumentTypes(), locs);
+
+    for (auto [oldArg, newArg] :
+         llvm::zip_equal(region.getArguments(), block.getArguments()))
+      oldArg.replaceAllUsesWith(newArg);
+
+    // Collect omp.map.info ops while satisfying interdependencies. This must
+    // be updated whenever operands of operations contained in targetOps
+    // change.
+    llvm::SetVector<Value> rewriteValues;
+    llvm::SetVector<omp::MapInfoOp> mapInfos;
+    for (auto targetOp : targetOps) {
+      std::visit(
+          [&](auto op) {
+            // Variables unused by the device, present on all target ops.
+            op.getDeviceMutable().clear();
+            op.getIfExprMutable().clear();
+
+            for (Value mapVar : op.getMapVars())
+              collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()),
+                             region, mapInfos, parentOpRewrites);
+
+            if constexpr (!std::is_same_v<decltype(op), omp::TargetDataOp>) {
+              // Variables unused by the device, present on all target ops
+              // except for omp.target_data.
+              op.getDependVarsMutable().clear();
+              op.setDependKindsAttr(nullptr);
+            }
+
+            if constexpr (std::is_same_v<decltype(op), omp::TargetOp>) {
+              assert(op.getHostEvalVars().empty() &&
+                     "unexpected host_eval in target device module");
+              // TODO: Clear some of these operands rather than rewriting
+              // them, depending on whether they are needed by device codegen
+              // once support for them is fully implemented.
+              for (Value allocVar : op.getAllocateVars())
+                collectRewrite(allocVar, region, rewriteValues,
+                               parentValRewrites);
+              for (Value allocVar : op.getAllocatorVars())
+                collectRewrite(allocVar, region, rewriteValues,
+                               parentValRewrites);
+              for (Value inReduction : op.getInReductionVars())
+                collectRewrite(inReduction, region, rewriteValues,
+                               parentValRewrites);
+              for (Value isDevPtr : op.getIsDevicePtrVars())
+                collectRewrite(isDevPtr, region, rewriteValues,
+                               parentValRewrites);
+              for (Value mapVar : op.getHasDeviceAddrVars())
+                collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()),
+                               region, mapInfos, parentOpRewrites);
+              for (Value privateVar : op.getPrivateVars())
+                collectRewrite(privateVar, region, rewriteValues,
+                               parentValRewrites);
+              if (Value threadLimit = op.getThreadLimit())
+                collectRewrite(threadLimit, region, rewriteValues,
+                               parentValRewrites);
+            } else if constexpr (std::is_same_v<decltype(op),
+                                                omp::TargetDataOp>) {
+              for (Value mapVar : op.getUseDeviceAddrVars())
+                collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()),
+                               region, mapInfos, parentOpRewrites);
+              for (Value mapVar : op.getUseDevicePtrVars())
+                collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()),
+                               region, mapInfos, parentOpRewrites);
+            }
+          },
+          targetOp);
+    }
+
+    applyChildRewrites(region, childOpRewrites, mapInfos, parentOpRewrites);
+
+    // Move omp.map.info ops to the new block and collect dependencies.
+    llvm::SetVector<hlfir::DeclareOp> declareOps;
+    llvm::SetVector<fir::BoxOffsetOp> boxOffsets;
+    for (omp::MapInfoOp mapOp : mapInfos) {
+      if (auto declareOp = dyn_cast_if_present<hlfir::DeclareOp>(
+              mapOp.getVarPtr().getDefiningOp()))
+        collectRewrite(declareOp, region, declareOps, parentOpRewrites);
+      else
+        collectRewrite(mapOp.getVarPtr(), region, rewriteValues,
+                       parentValRewrites);
+
+      if (Value varPtrPtr = mapOp.getVarPtrPtr()) {
+        if (auto boxOffset = llvm::dyn_cast_if_present<fir::BoxOffsetOp>(
+                varPtrPtr.getDefiningOp()))
+          collectRewrite(boxOffset, region, boxOffsets, parentOpRewrites);
+        else
+          return mapOp->emitOpError() << "var_ptr_ptr rewrite only supported "
+                                         "if defined by fir.box_offset";
+      }
+
+      // Bounds are not used during target device codegen.
+      mapOp.getBoundsMutable().clear();
+      mapOp->moveBefore(&block, block.end());
+    }
+
+    applyChildRewrites(region, childOpRewrites, declareOps, parentOpRewrites);
+    applyChildRewrites(region, childOpRewrites, boxOffsets, parentOpRewrites);
+
+    // Create a temporary marker to simplify the op moving process below.
+    builder.setInsertionPointToStart(&block);
+    auto marker = builder.create<fir::UndefOp>(builder.getUnknownLoc(),
+                                               builder.getNoneType());
+    builder.setInsertionPoint(marker);
+
+    // Handle dependencies of hlfir.declare ops.
+    for (hlfir::DeclareOp declareOp : declareOps) {
+      collectRewrite(declareOp.getMemref(), region, rewriteValues,
+                     parentValRewrites);
+
+      // Shape and typeparams aren't needed for target device codegen, but
+      // removing them would break verifiers.
+      Value zero;
+      if (declareOp.getShape() || !declareOp.getTypeparams().empty())
+        zero = builder.create<arith::ConstantOp>(
+            declareOp.getLoc(), builder.getI64IntegerAttr(0));
+
+      if (auto shape = declareOp.getShape()) {
+        // The pre-cg rewrite pass requires the shape to be defined by one of
+        // fir.shape, fir.shape_shift or fir.shift, so we need to make sure
+        // it is still defined by one of these operations after this pass.
+        Operation *shapeOp = shape.getDefiningOp();
+        llvm::SmallVector<Value> extents(shapeOp->getNumOperands(), zero);
+        Value newShape =
+            llvm::TypeSwitch<Operation *, Value>(shapeOp)
+                .Case([&](fir::ShapeOp op) {
+                  return builder.create<fir::ShapeOp>(op.getLoc(), extents);
+                })
+                .Case([&](fir::ShapeShiftOp op) {
+                  auto type = fir::ShapeShiftType::get(op.getContext(),
                                                       extents.size() / 2);
+                  return builder.create<fir::ShapeShiftOp>(op.getLoc(), type,
+                                                           extents);
+                })
+                .Case([&](fir::ShiftOp op) {
+                  auto type =
+                      fir::ShiftType::get(op.getContext(), extents.size());
+                  return builder.create<fir::ShiftOp>(op.getLoc(), type,
+                                                      extents);
+                })
+                .Default([](Operation *op) {
+                  op->emitOpError()
+                      << "hlfir.declare shape expected to be one of: "
+                         "fir.shape, fir.shape_shift or fir.shift";
+                  return nullptr;
+                });
+
+        if (!newShape)
+          return failure();
+
+        declareOp.getShapeMutable().assign(newShape);
+      }
+
+      for (OpOperand &typeParam : declareOp.getTypeparamsMutable())
+        typeParam.assign(zero);
+
+      declareOp.getDummyScopeMutable().clear();
+    }
+
+    // We don't actually need a proper initialization, but rather just to
+    // maintain the basic form of these operands. We create 1-bit placeholder
+    // allocas that we "typecast" to the expected type and replace all uses.
+    // Using fir.undefined here instead is not possible because these
+    // variables cannot be constants, as that would trigger different codegen
+    // for target regions.
+    applyChildRewrites(region, childValRewrites, rewriteValues,
+                       parentValRewrites);
+    for (Value value : rewriteValues) {
+      Location loc = value.getLoc();
+      Value rewriteValue;
+      // If it's defined by fir.address_of, then we need to keep that op as
+      // well because it might be pointing to a 'declare target' global.
+      // Constants can also trigger different codegen paths, so we keep them
+      // as well.
+      if (isa_and_present<fir::AddrOfOp, arith::ConstantOp>(
+              value.getDefiningOp())) {
+        rewriteValue = builder.clone(*value.getDefiningOp())->getResult(0);
+      } else {
+        Value placeholder =
+            builder.create<fir::AllocaOp>(loc, builder.getI1Type());
+        rewriteValue =
+            builder.create<fir::ConvertOp>(loc, value.getType(), placeholder);
+      }
+      value.replaceAllUsesWith(rewriteValue);
+    }
+
+    // Move omp.map.info dependencies.
+    for (hlfir::DeclareOp declareOp : declareOps)
+      declareOp->moveBefore(marker);
+
+    // The box_ref argument of fir.box_offset is expected to be the same
+    // value that was passed as var_ptr to the corresponding omp.map.info, so
+    // we don't need to handle its defining op here.
+    for (fir::BoxOffsetOp boxOffset : boxOffsets)
+      boxOffset->moveBefore(marker);
+
+    marker->erase();
+
+    // Move target operations to the end of the new block.
+    for (auto targetOp : targetOps)
+      std::visit([&block](auto op) { op->moveBefore(&block, block.end()); },
+                 targetOp);
+
+    // Add a terminator to the new block.
+    builder.setInsertionPointToEnd(&block);
+    if (funcOp) {
+      llvm::SmallVector<Value> returnValues;
+      returnValues.reserve(funcOp.getNumResults());
+      for (auto type : funcOp.getResultTypes())
+        returnValues.push_back(
+            builder.create<fir::UndefOp>(funcOp.getLoc(), type));
+
+      builder.create<func::ReturnOp>(funcOp.getLoc(), returnValues);
+    } else {
+      builder.create<omp::TerminatorOp>(targetDataOp.getLoc());
+    }
+
+    // Replace the old region (now missing the moved ops) with the new one and
+    // remove the temporary operation clone.
+    region.takeBody(newOp->getRegion(0));
+    newOp->erase();
+    return success();
+  }
 };
 } // namespace
diff --git a/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90 b/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90
index cfdcd9eda82d1..8f4d1bdd600d7 100644
--- a/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90
+++ b/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90
@@ -1,7 +1,7 @@
-!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
-!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s
-!RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s
-!RUN: bbc -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefixes=BOTH,HOST
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes=BOTH,DEVICE
+!RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefixes=BOTH,HOST
+!RUN: bbc -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes=BOTH,DEVICE
 
 program test_link
 
@@ -20,13 +20,14 @@ program test_link
   integer, pointer :: test_ptr2
   !$omp declare target link(test_ptr2)
 
-  !CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_int"}
+  !BOTH-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_int"}
   !$omp target
   test_int = test_int + 1
   !$omp end target
 
-  !CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.array<3xi32>>, !fir.array<3xi32>) map_clauses(implicit, tofrom) capture(ByRef) bounds({{%.*}}) -> !fir.ref<!fir.array<3xi32>> {name = "test_array_1d"}
+  !HOST-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.array<3xi32>>, !fir.array<3xi32>) map_clauses(implicit, tofrom) capture(ByRef) bounds({{%.*}}) -> !fir.ref<!fir.array<3xi32>> {name = "test_array_1d"}
+  !DEVICE-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.array<3xi32>>, !fir.array<3xi32>) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!fir.array<3xi32>> {name = "test_array_1d"}
   !$omp target
   do i = 1,3
     test_array_1d(i) = i * 2
@@ -35,18 +36,18 @@ program test_link
   allocate(test_ptr1)
   test_ptr1 = 1
 
-  !CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(implicit, to) capture(ByRef) members({{%.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "test_ptr1"}
+  !BOTH-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(implicit, to) capture(ByRef) members({{%.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "test_ptr1"}
   !$omp target
   test_ptr1 = test_ptr1 + 1
   !$omp end target
 
-  !CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_target"}
+  !BOTH-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_target"}
   !$omp target
   test_target = test_target + 1
   !$omp end target
 
-  !CHECK-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(implicit, to) capture(ByRef) members({{%.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "test_ptr2"}
+  !BOTH-DAG: {{%.*}} = omp.map.info var_ptr({{%.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(implicit, to) capture(ByRef) members({{%.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "test_ptr2"}
   test_ptr2 => test_target
   !$omp target
   test_ptr2 = test_ptr2 + 1
diff --git a/flang/test/Lower/OpenMP/host-eval.f90 b/flang/test/Lower/OpenMP/host-eval.f90
index
fe5b9597f8620..c059f7338b26d 100644 --- a/flang/test/Lower/OpenMP/host-eval.f90 +++ b/flang/test/Lower/OpenMP/host-eval.f90 @@ -22,8 +22,10 @@ subroutine teams() !$omp end target - ! BOTH: omp.teams - ! BOTH-SAME: num_teams({{.*}}) thread_limit({{.*}}) { + ! HOST: omp.teams + ! HOST-SAME: num_teams({{.*}}) thread_limit({{.*}}) { + + ! DEVICE-NOT: omp.teams !$omp teams num_teams(1) thread_limit(2) call foo() !$omp end teams @@ -76,13 +78,18 @@ subroutine distribute_parallel_do() !$omp end distribute parallel do !$omp end target teams - ! BOTH: omp.teams + ! HOST: omp.teams + ! DEVICE-NOT: omp.teams !$omp teams - ! BOTH: omp.parallel - ! BOTH-SAME: num_threads({{.*}}) - ! BOTH: omp.distribute - ! BOTH-NEXT: omp.wsloop + ! HOST: omp.parallel + ! HOST-SAME: num_threads({{.*}}) + ! HOST: omp.distribute + ! HOST-NEXT: omp.wsloop + + ! DEVICE-NOT: omp.parallel + ! DEVICE-NOT: omp.distribute + ! DEVICE-NOT: omp.wsloop !$omp distribute parallel do num_threads(1) do i=1,10 call foo() @@ -140,14 +147,20 @@ subroutine distribute_parallel_do_simd() !$omp end distribute parallel do simd !$omp end target teams - ! BOTH: omp.teams + ! HOST: omp.teams + ! DEVICE-NOT: omp.teams !$omp teams - ! BOTH: omp.parallel - ! BOTH-SAME: num_threads({{.*}}) - ! BOTH: omp.distribute - ! BOTH-NEXT: omp.wsloop - ! BOTH-NEXT: omp.simd + ! HOST: omp.parallel + ! HOST-SAME: num_threads({{.*}}) + ! HOST: omp.distribute + ! HOST-NEXT: omp.wsloop + ! HOST-NEXT: omp.simd + + ! DEVICE-NOT: omp.parallel + ! DEVICE-NOT: omp.distribute + ! DEVICE-NOT: omp.wsloop + ! DEVICE-NOT: omp.simd !$omp distribute parallel do simd num_threads(1) do i=1,10 call foo() @@ -194,10 +207,12 @@ subroutine distribute() !$omp end distribute !$omp end target teams - ! BOTH: omp.teams + ! HOST: omp.teams + ! DEVICE-NOT: omp.teams !$omp teams - ! BOTH: omp.distribute + ! HOST: omp.distribute + ! DEVICE-NOT: omp.distribute !$omp distribute do i=1,10 call foo() @@ -246,11 +261,15 @@ subroutine distribute_simd() !$omp end distribute simd !$omp end target teams - ! BOTH: omp.teams + ! HOST: omp.teams + ! DEVICE-NOT: omp.teams !$omp teams - ! BOTH: omp.distribute - ! BOTH-NEXT: omp.simd + ! HOST: omp.distribute + ! HOST-NEXT: omp.simd + + ! DEVICE-NOT: omp.distribute + ! 
DEVICE-NOT: omp.simd !$omp distribute simd do i=1,10 call foo() diff --git a/flang/test/Lower/OpenMP/real10.f90 b/flang/test/Lower/OpenMP/real10.f90 index a31d2ace80044..c76c2bde0f6f6 100644 --- a/flang/test/Lower/OpenMP/real10.f90 +++ b/flang/test/Lower/OpenMP/real10.f90 @@ -5,9 +5,6 @@ !CHECK: hlfir.declare %{{.*}} {uniq_name = "_QFEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) program p + !$omp declare target real(10) :: x - !$omp target - continue - !$omp end target end - diff --git a/flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir b/flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir new file mode 100644 index 0000000000000..4d8975d58a50c --- /dev/null +++ b/flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir @@ -0,0 +1,498 @@ +// RUN: fir-opt --omp-function-filtering %s | FileCheck %s + +module attributes {omp.is_target_device = true} { + // CHECK-LABEL: func.func @basic_checks + // CHECK-SAME: (%[[ARG:.*]]: !fir.ref) -> (i32, f32) + func.func @basic_checks(%arg: !fir.ref) -> (i32, f32) { + // CHECK-NEXT: %[[PLACEHOLDER:.*]] = fir.alloca i1 + // CHECK-NEXT: %[[ALLOC:.*]] = fir.convert %[[PLACEHOLDER]] : (!fir.ref) -> !fir.ref + // CHECK-NEXT: %[[GLOBAL:.*]] = fir.address_of(@global_scalar) : !fir.ref + %r0 = arith.constant 10 : i32 + %r1 = arith.constant 2.5 : f32 + + func.call @foo() : () -> () + + // CHECK-NEXT: %[[ARG_DECL:.*]]:2 = hlfir.declare %[[ARG]] {uniq_name = "arg"} + %0:2 = hlfir.declare %arg {uniq_name = "arg"} : (!fir.ref) -> (!fir.ref, !fir.ref) + + // CHECK-NEXT: %[[GLOBAL_DECL:.*]]:2 = hlfir.declare %[[GLOBAL]] {uniq_name = "global_scalar"} + %global = fir.address_of(@global_scalar) : !fir.ref + %1:2 = hlfir.declare %global {uniq_name = "global_scalar"} : (!fir.ref) -> (!fir.ref, !fir.ref) + + // CHECK-NEXT: %[[ALLOC_DECL:.*]]:2 = hlfir.declare %[[ALLOC]] {uniq_name = "alloc"} + %alloc = fir.alloca i32 + %2:2 = hlfir.declare %alloc {uniq_name = "alloc"} : (!fir.ref) -> (!fir.ref, !fir.ref) + + // CHECK-NEXT: %[[MAP0:.*]] = omp.map.info var_ptr(%[[ARG_DECL]]#1{{.*}}) + // CHECK-NEXT: %[[MAP1:.*]] = omp.map.info var_ptr(%[[GLOBAL_DECL]]#1{{.*}}) + // CHECK-NEXT: %[[MAP3:.*]] = omp.map.info var_ptr(%[[ALLOC]]{{.*}}) + // CHECK-NEXT: %[[MAP2:.*]] = omp.map.info var_ptr(%[[ALLOC_DECL]]#1{{.*}}) + // CHECK-NEXT: %[[MAP4:.*]] = omp.map.info var_ptr(%[[ARG_DECL]]#1{{.*}}) + // CHECK-NEXT: %[[MAP5:.*]] = omp.map.info var_ptr(%[[GLOBAL_DECL]]#1{{.*}}) + // CHECK-NEXT: %[[MAP6:.*]] = omp.map.info var_ptr(%[[ALLOC_DECL]]#1{{.*}}) + // CHECK-NEXT: %[[MAP7:.*]] = omp.map.info var_ptr(%[[ALLOC]]{{.*}}) + // CHECK-NEXT: %[[MAP8:.*]] = omp.map.info var_ptr(%[[ARG_DECL]]#1{{.*}}) + // CHECK-NEXT: %[[MAP9:.*]] = omp.map.info var_ptr(%[[GLOBAL_DECL]]#1{{.*}}) + // CHECK-NEXT: %[[MAP10:.*]] = omp.map.info var_ptr(%[[ALLOC_DECL]]#1{{.*}}) + %m0 = omp.map.info var_ptr(%0#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + %m1 = omp.map.info var_ptr(%1#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + %m2 = omp.map.info var_ptr(%2#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + %m3 = omp.map.info var_ptr(%alloc : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + + // CHECK-NEXT: omp.target has_device_addr(%[[MAP2]] -> {{.*}} : {{.*}}) map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}}, %[[MAP3]] -> {{.*}} : {{.*}}) + omp.target has_device_addr(%m2 -> %arg0 : !fir.ref) map_entries(%m0 -> %arg1, %m1 -> %arg2, %m3 -> %arg3 : !fir.ref, !fir.ref, !fir.ref) { + // CHECK-NEXT: func.call + func.call 
@foo() : () -> () + omp.terminator + } + + // CHECK-NOT: omp.parallel + // CHECK-NOT: func.call + // CHECK-NOT: omp.map.info + omp.parallel { + func.call @foo() : () -> () + omp.terminator + } + + %m4 = omp.map.info var_ptr(%0#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + %m5 = omp.map.info var_ptr(%1#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + %m6 = omp.map.info var_ptr(%2#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + %m7 = omp.map.info var_ptr(%alloc : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + + // CHECK: omp.target_data map_entries(%[[MAP4]], %[[MAP5]], %[[MAP6]], %[[MAP7]] : {{.*}}) + omp.target_data map_entries(%m4, %m5, %m6, %m7 : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { + // CHECK-NOT: func.call + func.call @foo() : () -> () + omp.terminator + } + + // CHECK: omp.target_enter_data map_entries(%[[MAP8]] : {{.*}}) + // CHECK-NEXT: omp.target_exit_data map_entries(%[[MAP9]] : {{.*}}) + // CHECK-NEXT: omp.target_update map_entries(%[[MAP10]] : {{.*}}) + %m8 = omp.map.info var_ptr(%0#1 : !fir.ref, i32) map_clauses(to) capture(ByRef) -> !fir.ref + omp.target_enter_data map_entries(%m8 : !fir.ref) + + %m9 = omp.map.info var_ptr(%1#1 : !fir.ref, i32) map_clauses(from) capture(ByRef) -> !fir.ref + omp.target_exit_data map_entries(%m9 : !fir.ref) + + %m10 = omp.map.info var_ptr(%2#1 : !fir.ref, !fir.ref) map_clauses(to) capture(ByRef) -> !fir.ref + omp.target_update map_entries(%m10 : !fir.ref) + + // CHECK-NOT: func.call + func.call @foo() : () -> () + + // CHECK: %[[RETURN0:.*]] = fir.undefined i32 + // CHECK-NEXT: %[[RETURN1:.*]] = fir.undefined f32 + // CHECK-NEXT: return %[[RETURN0]], %[[RETURN1]] + return %r0, %r1 : i32, f32 + } + + // CHECK-LABEL: func.func @allocatable_array + // CHECK-SAME: (%[[ALLOCATABLE:.*]]: [[ALLOCATABLE_TYPE:.*]], %[[ARRAY:.*]]: [[ARRAY_TYPE:[^)]*]]) + func.func @allocatable_array(%allocatable: !fir.ref>>>, %array: !fir.ref>) { + // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i64 + // CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[ZERO]] : (i64) -> !fir.shape<1> + // CHECK-NEXT: %[[ALLOCATABLE_DECL:.*]]:2 = hlfir.declare %[[ALLOCATABLE]] {fortran_attrs = #fir.var_attrs, uniq_name = "allocatable"} : ([[ALLOCATABLE_TYPE]]) -> ([[ALLOCATABLE_TYPE]], [[ALLOCATABLE_TYPE]]) + // CHECK-NEXT: %[[ARRAY_DECL:.*]]:2 = hlfir.declare %[[ARRAY]](%[[SHAPE]]) {uniq_name = "array"} : ([[ARRAY_TYPE]], !fir.shape<1>) -> ([[ARRAY_TYPE]], [[ARRAY_TYPE]]) + // CHECK-NEXT: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[ALLOCATABLE_DECL]]#1 base_addr : ([[ALLOCATABLE_TYPE]]) -> [[VAR_PTR_PTR_TYPE:.*]] + // CHECK-NEXT: %[[MAP_ALLOCATABLE:.*]] = omp.map.info var_ptr(%[[ALLOCATABLE_DECL]]#1 : [[ALLOCATABLE_TYPE]], f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%[[VAR_PTR_PTR]] : [[VAR_PTR_PTR_TYPE]]) -> [[VAR_PTR_PTR_TYPE]] + // CHECK-NEXT: %[[MAP_ARRAY:.*]] = omp.map.info var_ptr(%[[ARRAY_DECL]]#1 : [[ARRAY_TYPE]], !fir.array<9xi32>) map_clauses(tofrom) capture(ByRef) -> [[ARRAY_TYPE]] + // CHECK-NEXT: omp.target map_entries(%[[MAP_ALLOCATABLE]] -> %{{.*}}, %[[MAP_ARRAY]] -> %{{.*}} : [[VAR_PTR_PTR_TYPE]], [[ARRAY_TYPE]]) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c9 = arith.constant 9 : index + + %0:2 = hlfir.declare %allocatable {fortran_attrs = #fir.var_attrs, uniq_name = "allocatable"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) + %1 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%c8 : index) extent(%c9 : index) stride(%c1 : index) 
start_idx(%c1 : index) + %2 = fir.box_offset %0#1 base_addr : (!fir.ref>>>) -> !fir.llvm_ptr>> + %m0 = omp.map.info var_ptr(%0#1 : !fir.ref>>>, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%2 : !fir.llvm_ptr>>) bounds(%1) -> !fir.llvm_ptr>> + + %3 = fir.shape %c9 : (index) -> !fir.shape<1> + %4:2 = hlfir.declare %array(%3) {uniq_name = "array"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) + %5 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%c8 : index) extent(%c9 : index) stride(%c1 : index) start_idx(%c1 : index) + %6 = omp.map.info var_ptr(%4#1 : !fir.ref>, !fir.array<9xi32>) map_clauses(tofrom) capture(ByRef) bounds(%5) -> !fir.ref> + + omp.target map_entries(%m0 -> %arg0, %6 -> %arg1 : !fir.llvm_ptr>>, !fir.ref>) { + omp.terminator + } + return + } + + // CHECK-LABEL: func.func @character + // CHECK-SAME: (%[[X:.*]]: [[X_TYPE:[^)]*]]) + func.func @character(%x: !fir.ref>) { + // CHECK-NEXT: %[[ZERO]] = arith.constant 0 : i64 + %0 = fir.dummy_scope : !fir.dscope + %c1 = arith.constant 1 : index + // CHECK-NEXT: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] typeparams %[[ZERO]] {uniq_name = "x"} : ([[X_TYPE]], i64) -> ([[X_TYPE]], [[X_TYPE]]) + %3:2 = hlfir.declare %x typeparams %c1 dummy_scope %0 {uniq_name = "x"} : (!fir.ref>, index, !fir.dscope) -> (!fir.ref>, !fir.ref>) + // CHECK-NEXT: %[[MAP:.*]] = omp.map.info var_ptr(%[[X_DECL]]#1 : [[X_TYPE]], !fir.char<1>) map_clauses(tofrom) capture(ByRef) -> [[X_TYPE]] + %map = omp.map.info var_ptr(%3#1 : !fir.ref>, !fir.char<1>) map_clauses(tofrom) capture(ByRef) -> !fir.ref> + // CHECK-NEXT: omp.target map_entries(%[[MAP]] -> %{{.*}}) + omp.target map_entries(%map -> %arg0 : !fir.ref>) { + omp.terminator + } + return + } + + // CHECK-LABEL: func.func @assumed_rank + // CHECK-SAME: (%[[X:.*]]: [[X_TYPE:[^)]*]]) + func.func @assumed_rank(%x: !fir.box>) { + // CHECK-NEXT: %[[PLACEHOLDER:.*]] = fir.alloca i1 + // CHECK-NEXT: %[[ALLOCA:.*]] = fir.convert %[[PLACEHOLDER]] : (!fir.ref) -> !fir.ref<[[X_TYPE]]> + %0 = fir.alloca !fir.box> + %1 = fir.dummy_scope : !fir.dscope + %2:2 = hlfir.declare %x dummy_scope %1 {uniq_name = "x"} : (!fir.box>, !fir.dscope) -> (!fir.box>, !fir.box>) + %3 = fir.box_addr %2#1 : (!fir.box>) -> !fir.ref> + fir.store %2#1 to %0 : !fir.ref>> + // CHECK-NEXT: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[ALLOCA]] base_addr : (!fir.ref<[[X_TYPE]]>) -> [[VAR_PTR_PTR_TYPE:.*]] + %4 = fir.box_offset %0 base_addr : (!fir.ref>>) -> !fir.llvm_ptr>> + // CHECK-NEXT: %[[MAP0:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<[[X_TYPE]]>, !fir.array<*:f32>) {{.*}} var_ptr_ptr(%[[VAR_PTR_PTR]] : [[VAR_PTR_PTR_TYPE]]) -> [[VAR_PTR_PTR_TYPE]] + %5 = omp.map.info var_ptr(%0 : !fir.ref>>, !fir.array<*:f32>) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%4 : !fir.llvm_ptr>>) -> !fir.llvm_ptr>> + // CHECK-NEXT: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<[[X_TYPE]]>, !fir.box>) {{.*}} members(%[[MAP0]] : [0] : [[VAR_PTR_PTR_TYPE]]) -> !fir.ref> + %6 = omp.map.info var_ptr(%0 : !fir.ref>>, !fir.box>) map_clauses(to) capture(ByRef) members(%5 : [0] : !fir.llvm_ptr>>) -> !fir.ref> + // CHECK-NEXT: omp.target map_entries(%[[MAP1]] -> %{{.*}}, %[[MAP0]] -> {{.*}}) + omp.target map_entries(%6 -> %arg1, %5 -> %arg2 : !fir.ref>, !fir.llvm_ptr>>) { + omp.terminator + } + return + } + + // CHECK-LABEL: func.func @box_ptr + // CHECK-SAME: (%[[X:.*]]: [[X_TYPE:[^)]*]]) + func.func @box_ptr(%x: !fir.ref>>>) { + // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i64 + // CHECK-NEXT: %[[SHAPE:.*]] = fir.shape_shift %[[ZERO]], 
%[[ZERO]] : (i64, i64) -> !fir.shapeshift<1> + // CHECK-NEXT: %[[PLACEHOLDER_X:.*]] = fir.alloca i1 + // CHECK-NEXT: %[[ALLOCA_X:.*]] = fir.convert %[[PLACEHOLDER_X]] : (!fir.ref) -> [[X_TYPE]] + %0 = fir.alloca !fir.box>> + %1 = fir.dummy_scope : !fir.dscope + %2:2 = hlfir.declare %x dummy_scope %1 {fortran_attrs = #fir.var_attrs, uniq_name = "x"} : (!fir.ref>>>, !fir.dscope) -> (!fir.ref>>>, !fir.ref>>>) + %3 = fir.load %2#0 : !fir.ref>>> + fir.store %3 to %0 : !fir.ref>>> + + // CHECK-NEXT: %[[PLACEHOLDER_Y:.*]] = fir.alloca i1 + // CHECK-NEXT: %[[ALLOCA_Y:.*]] = fir.convert %[[PLACEHOLDER_Y]] : (!fir.ref) -> [[Y_TYPE:.*]] + %c0 = arith.constant 0 : index + %4:3 = fir.box_dims %3, %c0 : (!fir.box>>, index) -> (index, index, index) + %c1 = arith.constant 1 : index + %c0_0 = arith.constant 0 : index + %5:3 = fir.box_dims %3, %c0_0 : (!fir.box>>, index) -> (index, index, index) + %c0_1 = arith.constant 0 : index + %6 = arith.subi %5#1, %c1 : index + %7 = omp.map.bounds lower_bound(%c0_1 : index) upper_bound(%6 : index) extent(%5#1 : index) stride(%5#2 : index) start_idx(%4#0 : index) {stride_in_bytes = true} + %8 = fir.box_addr %3 : (!fir.box>>) -> !fir.ptr> + %c0_2 = arith.constant 0 : index + %9:3 = fir.box_dims %3, %c0_2 : (!fir.box>>, index) -> (index, index, index) + %10 = fir.shape_shift %9#0, %9#1 : (index, index) -> !fir.shapeshift<1> + + // CHECK-NEXT: %[[Y_DECL:.*]]:2 = hlfir.declare %[[ALLOCA_Y]](%[[SHAPE]]) {fortran_attrs = #fir.var_attrs, uniq_name = "y"} : ([[Y_TYPE]], !fir.shapeshift<1>) -> (!fir.box>, [[Y_TYPE]]) + %11:2 = hlfir.declare %8(%10) {fortran_attrs = #fir.var_attrs, uniq_name = "y"} : (!fir.ptr>, !fir.shapeshift<1>) -> (!fir.box>, !fir.ptr>) + %c1_3 = arith.constant 1 : index + %c0_4 = arith.constant 0 : index + %12:3 = fir.box_dims %11#0, %c0_4 : (!fir.box>, index) -> (index, index, index) + %c0_5 = arith.constant 0 : index + %13 = arith.subi %12#1, %c1_3 : index + %14 = omp.map.bounds lower_bound(%c0_5 : index) upper_bound(%13 : index) extent(%12#1 : index) stride(%12#2 : index) start_idx(%9#0 : index) {stride_in_bytes = true} + + // CHECK-NEXT: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[ALLOCA_X]] base_addr : ([[X_TYPE]]) -> [[VAR_PTR_PTR_TYPE:.*]] + // CHECK-NEXT: %[[MAP0:.*]] = omp.map.info var_ptr(%[[Y_DECL]]#1 : [[Y_TYPE]], i32) {{.*}} -> [[Y_TYPE]] + // CHECK-NEXT: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ALLOCA_X]] : [[X_TYPE]], i32) {{.*}} var_ptr_ptr(%[[VAR_PTR_PTR]] : [[VAR_PTR_PTR_TYPE]]) -> [[VAR_PTR_PTR_TYPE]] + // CHECK-NEXT: %[[MAP2:.*]] = omp.map.info var_ptr(%[[ALLOCA_X]] : [[X_TYPE]], !fir.box>>) {{.*}} members(%[[MAP1]] : [0] : [[VAR_PTR_PTR_TYPE]]) -> [[X_TYPE]] + %15 = omp.map.info var_ptr(%11#1 : !fir.ptr>, i32) map_clauses(tofrom) capture(ByRef) bounds(%14) -> !fir.ptr> + %16 = fir.box_offset %0 base_addr : (!fir.ref>>>) -> !fir.llvm_ptr>> + %17 = omp.map.info var_ptr(%0 : !fir.ref>>>, i32) map_clauses(implicit, to) capture(ByRef) var_ptr_ptr(%16 : !fir.llvm_ptr>>) bounds(%7) -> !fir.llvm_ptr>> + %18 = omp.map.info var_ptr(%0 : !fir.ref>>>, !fir.box>>) map_clauses(implicit, to) capture(ByRef) members(%17 : [0] : !fir.llvm_ptr>>) -> !fir.ref>>> + + // CHECK-NEXT: omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP2]] -> %{{.*}}, %[[MAP1]] -> {{.*}} : [[Y_TYPE]], [[X_TYPE]], [[VAR_PTR_PTR_TYPE]]) + omp.target map_entries(%15 -> %arg1, %18 -> %arg2, %17 -> %arg3 : !fir.ptr>, !fir.ref>>>, !fir.llvm_ptr>>) { + omp.terminator + } + return + } + + // CHECK-LABEL: func.func @target_data + // CHECK-SAME: (%[[MAPPED:.*]]: [[MAPPED_TYPE:[^)]*]], 
%[[USEDEVADDR:.*]]: [[USEDEVADDR_TYPE:[^)]*]], %[[USEDEVPTR:.*]]: [[USEDEVPTR_TYPE:[^)]*]]) + func.func @target_data(%mapped: !fir.ref, %usedevaddr: !fir.ref, %usedevptr: !fir.ref>) { + // CHECK-NEXT: %[[MAPPED_DECL:.*]]:2 = hlfir.declare %[[MAPPED]] {uniq_name = "mapped"} : ([[MAPPED_TYPE]]) -> ([[MAPPED_TYPE]], [[MAPPED_TYPE]]) + %0:2 = hlfir.declare %mapped {uniq_name = "mapped"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %1:2 = hlfir.declare %usedevaddr {uniq_name = "usedevaddr"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %2:2 = hlfir.declare %usedevptr {uniq_name = "usedevptr"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) + %m0 = omp.map.info var_ptr(%0#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + %m1 = omp.map.info var_ptr(%1#1 : !fir.ref, i32) map_clauses(return_param) capture(ByRef) -> !fir.ref + %m2 = omp.map.info var_ptr(%2#1 : !fir.ref>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(return_param) capture(ByRef) -> !fir.ref> + // CHECK: omp.target_data map_entries(%{{.*}}) use_device_addr(%{{.*}} -> %[[USEDEVADDR_ARG:.*]] : [[USEDEVADDR_TYPE]]) use_device_ptr(%{{.*}} -> %[[USEDEVPTR_ARG:.*]] : [[USEDEVPTR_TYPE]]) + omp.target_data map_entries(%m0 : !fir.ref) use_device_addr(%m1 -> %arg0 : !fir.ref) use_device_ptr(%m2 -> %arg1 : !fir.ref>) { + // CHECK-NEXT: %[[USEDEVADDR_DECL:.*]]:2 = hlfir.declare %[[USEDEVADDR_ARG]] {uniq_name = "usedevaddr"} : ([[USEDEVADDR_TYPE]]) -> ([[USEDEVADDR_TYPE]], [[USEDEVADDR_TYPE]]) + %3:2 = hlfir.declare %arg0 {uniq_name = "usedevaddr"} : (!fir.ref) -> (!fir.ref, !fir.ref) + // CHECK-NEXT: %[[USEDEVPTR_DECL:.*]]:2 = hlfir.declare %[[USEDEVPTR_ARG]] {uniq_name = "usedevptr"} : ([[USEDEVPTR_TYPE]]) -> ([[USEDEVPTR_TYPE]], [[USEDEVPTR_TYPE]]) + %4:2 = hlfir.declare %arg1 {uniq_name = "usedevptr"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) + // CHECK-NEXT: %[[MAPPED_MAP:.*]] = omp.map.info var_ptr(%[[MAPPED_DECL]]#1 : [[MAPPED_TYPE]], i32) map_clauses(tofrom) capture(ByRef) -> [[MAPPED_TYPE]] + %m3 = omp.map.info var_ptr(%0#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: %[[USEDEVADDR_MAP:.*]] = omp.map.info var_ptr(%[[USEDEVADDR_DECL]]#1 : [[USEDEVADDR_TYPE]], i32) map_clauses(tofrom) capture(ByRef) -> [[USEDEVADDR_TYPE]] + %m4 = omp.map.info var_ptr(%3#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: %[[USEDEVPTR_MAP:.*]] = omp.map.info var_ptr(%[[USEDEVPTR_DECL]]#1 : [[USEDEVPTR_TYPE]], !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> [[USEDEVPTR_TYPE]] + %m5 = omp.map.info var_ptr(%4#1 : !fir.ref>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref> + + // CHECK-NOT: func.call + func.call @foo() : () -> () + + // CHECK-NEXT: omp.target map_entries(%[[MAPPED_MAP]] -> %{{.*}}, %[[USEDEVADDR_MAP]] -> %{{.*}}, %[[USEDEVPTR_MAP]] -> %{{.*}} : {{.*}}) + omp.target map_entries(%m3 -> %arg2, %m4 -> %arg3, %m5 -> %arg4 : !fir.ref, !fir.ref, !fir.ref>) { + omp.terminator + } + + // CHECK-NOT: func.call + func.call @foo() : () -> () + + omp.terminator + } + + // CHECK: return + return + } + + // CHECK-LABEL: func.func @map_info_members + // CHECK-SAME: (%[[X:.*]]: [[X_TYPE:[^)]*]]) + func.func @map_info_members(%x: !fir.ref>>>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c9 = arith.constant 9 : index + // CHECK-NEXT: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {fortran_attrs = #fir.var_attrs, uniq_name = "x"} : 
([[X_TYPE]]) -> ([[X_TYPE]], [[X_TYPE]]) + %23:2 = hlfir.declare %x {fortran_attrs = #fir.var_attrs, uniq_name = "x"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) + %63 = fir.load %23#0 : !fir.ref>>> + %64:3 = fir.box_dims %63, %c0 : (!fir.box>>, index) -> (index, index, index) + %65:3 = fir.box_dims %63, %c0 : (!fir.box>>, index) -> (index, index, index) + %66 = arith.subi %c1, %64#0 : index + %67 = arith.subi %c9, %64#0 : index + %68 = fir.load %23#0 : !fir.ref>>> + %69:3 = fir.box_dims %68, %c0 : (!fir.box>>, index) -> (index, index, index) + %70 = omp.map.bounds lower_bound(%66 : index) upper_bound(%67 : index) extent(%69#1 : index) stride(%65#2 : index) start_idx(%64#0 : index) {stride_in_bytes = true} + // CHECK-NEXT: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[X_DECL]]#1 base_addr : ([[X_TYPE]]) -> [[VAR_PTR_PTR_TYPE:.*]] + %71 = fir.box_offset %23#1 base_addr : (!fir.ref>>>) -> !fir.llvm_ptr>> + // CHECK-NEXT: %[[MAP0:.*]] = omp.map.info var_ptr(%[[X_DECL]]#1 : [[X_TYPE]], f32) {{.*}} var_ptr_ptr(%[[VAR_PTR_PTR]] : [[VAR_PTR_PTR_TYPE]]) -> [[VAR_PTR_PTR_TYPE]] + %72 = omp.map.info var_ptr(%23#1 : !fir.ref>>>, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%71 : !fir.llvm_ptr>>) bounds(%70) -> !fir.llvm_ptr>> + // CHECK-NEXT: %[[MAP1:.*]] = omp.map.info var_ptr(%[[X_DECL]]#1 : [[X_TYPE]], !fir.box>>) {{.*}} members(%[[MAP0]] : [0] : [[VAR_PTR_PTR_TYPE]]) -> [[X_TYPE]] + %73 = omp.map.info var_ptr(%23#1 : !fir.ref>>>, !fir.box>>) map_clauses(to) capture(ByRef) members(%72 : [0] : !fir.llvm_ptr>>) -> !fir.ref>>> + // CHECK-NEXT: omp.target map_entries(%[[MAP1]] -> {{.*}}, %[[MAP0]] -> %{{.*}} : [[X_TYPE]], [[VAR_PTR_PTR_TYPE]]) + omp.target map_entries(%73 -> %arg0, %72 -> %arg1 : !fir.ref>>>, !fir.llvm_ptr>>) { + omp.terminator + } + return + } + + // CHECK-LABEL: func.func @control_flow + // CHECK-SAME: (%[[X:.*]]: [[X_TYPE:[^,]*]], %[[COND:.*]]: [[COND_TYPE:[^)]*]]) + func.func @control_flow(%x: !fir.ref, %cond: !fir.ref>) { + // CHECK-NEXT: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "x"} : ([[X_TYPE]]) -> ([[X_TYPE]], [[X_TYPE]]) + // CHECK-NEXT: %[[MAP0:.*]] = omp.map.info var_ptr(%[[X_DECL]]#1 : [[X_TYPE]], i32) {{.*}} -> [[X_TYPE]] + // CHECK-NEXT: %[[MAP1:.*]] = omp.map.info var_ptr(%[[X_DECL]]#1 : [[X_TYPE]], i32) {{.*}} -> [[X_TYPE]] + %x_decl:2 = hlfir.declare %x {uniq_name = "x"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %cond_decl:2 = hlfir.declare %cond {uniq_name = "cond"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) + %0 = fir.load %cond_decl#0 : !fir.ref> + %1 = fir.convert %0 : (!fir.logical<4>) -> i1 + cf.cond_br %1, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + fir.call @foo() : () -> () + %m0 = omp.map.info var_ptr(%x_decl#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: omp.target map_entries(%[[MAP0]] -> {{.*}} : [[X_TYPE]]) + omp.target map_entries(%m0 -> %arg2 : !fir.ref) { + omp.terminator + } + fir.call @foo() : () -> () + cf.br ^bb2 + ^bb2: // 2 preds: ^bb0, ^bb1 + fir.call @foo() : () -> () + %m1 = omp.map.info var_ptr(%x_decl#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NOT: fir.call + // CHECK-NOT: omp.map.info + // CHECK: omp.target_data map_entries(%[[MAP1]] : [[X_TYPE]]) + omp.target_data map_entries(%m1 : !fir.ref) { + fir.call @foo() : () -> () + %8 = fir.load %cond_decl#0 : !fir.ref> + %9 = fir.convert %8 : (!fir.logical<4>) -> i1 + cf.cond_br %9, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + fir.call @foo() : () -> () + // CHECK-NEXT: %[[MAP2:.*]] = omp.map.info var_ptr(%[[X_DECL]]#1 : [[X_TYPE]], 
i32) {{.*}} -> [[X_TYPE]] + %m2 = omp.map.info var_ptr(%x_decl#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: omp.target map_entries(%[[MAP2]] -> {{.*}} : [[X_TYPE]]) + omp.target map_entries(%m2 -> %arg2 : !fir.ref) { + omp.terminator + } + // CHECK-NOT: fir.call + // CHECK-NOT: cf.br + fir.call @foo() : () -> () + cf.br ^bb2 + ^bb2: // 2 preds: ^bb0, ^bb1 + fir.call @foo() : () -> () + omp.terminator + } + fir.call @foo() : () -> () + + // CHECK: return + return + } + + // CHECK-LABEL: func.func @block_args + // CHECK-SAME: (%[[X:.*]]: [[X_TYPE:[^)]*]]) + func.func @block_args(%x: !fir.ref) { + // CHECK-NEXT: %[[PLACEHOLDER0:.*]] = fir.alloca i1 + // CHECK-NEXT: %[[ALLOCA0:.*]] = fir.convert %[[PLACEHOLDER0]] : (!fir.ref) -> !fir.ref + // CHECK-NEXT: %[[PLACEHOLDER1:.*]] = fir.alloca i1 + // CHECK-NEXT: %[[ALLOCA1:.*]] = fir.convert %[[PLACEHOLDER1]] : (!fir.ref) -> !fir.ref + // CHECK-NEXT: %[[X_DECL0:.*]]:2 = hlfir.declare %[[ALLOCA0]] {uniq_name = "x"} : ([[X_TYPE]]) -> ([[X_TYPE]], [[X_TYPE]]) + // CHECK-NEXT: %[[X_DECL1:.*]]:2 = hlfir.declare %[[ALLOCA1]] {uniq_name = "x"} : ([[X_TYPE]]) -> ([[X_TYPE]], [[X_TYPE]]) + // CHECK-NEXT: %[[MAP0:.*]] = omp.map.info var_ptr(%[[X_DECL0]]#1 : [[X_TYPE]], i32) {{.*}} -> [[X_TYPE]] + // CHECK-NEXT: %[[MAP1:.*]] = omp.map.info var_ptr(%[[X_DECL1]]#1 : [[X_TYPE]], i32) {{.*}} -> [[X_TYPE]] + %x_decl:2 = hlfir.declare %x {uniq_name = "x"} : (!fir.ref) -> (!fir.ref, !fir.ref) + omp.parallel private(@privatizer %x_decl#0 -> %arg0 : !fir.ref) { + %0:2 = hlfir.declare %arg0 {uniq_name = "x"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %m0 = omp.map.info var_ptr(%0#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: omp.target map_entries(%[[MAP0]] -> {{.*}} : [[X_TYPE]]) + omp.target map_entries(%m0 -> %arg2 : !fir.ref) { + omp.terminator + } + omp.terminator + } + + omp.parallel private(@privatizer %x_decl#0 -> %arg0 : !fir.ref) { + %1:2 = hlfir.declare %arg0 {uniq_name = "x"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %m1 = omp.map.info var_ptr(%1#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NOT: omp.parallel + // CHECK-NOT: hlfir.declare + // CHECK-NOT: omp.map.info + // CHECK: omp.target_data map_entries(%[[MAP1]] : [[X_TYPE]]) + omp.target_data map_entries(%m1 : !fir.ref) { + omp.parallel private(@privatizer %1#0 -> %arg1 : !fir.ref) { + // CHECK-NEXT: %[[PLACEHOLDER2:.*]] = fir.alloca i1 + // CHECK-NEXT: %[[ALLOCA2:.*]] = fir.convert %[[PLACEHOLDER2]] : (!fir.ref) -> !fir.ref + // CHECK-NEXT: %[[X_DECL2:.*]]:2 = hlfir.declare %[[ALLOCA2]] {uniq_name = "x"} : ([[X_TYPE]]) -> ([[X_TYPE]], [[X_TYPE]]) + %2:2 = hlfir.declare %arg1 {uniq_name = "x"} : (!fir.ref) -> (!fir.ref, !fir.ref) + // CHECK-NEXT: %[[MAP2:.*]] = omp.map.info var_ptr(%[[X_DECL2]]#1 : [[X_TYPE]], i32) {{.*}} -> [[X_TYPE]] + %m2 = omp.map.info var_ptr(%2#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: omp.target map_entries(%[[MAP2]] -> {{.*}} : [[X_TYPE]]) + omp.target map_entries(%m2 -> %arg2 : !fir.ref) { + omp.terminator + } + omp.terminator + } + omp.terminator + } + omp.terminator + } + + return + } + + // CHECK-LABEL: func.func @reuse_tests() + func.func @reuse_tests() { + // CHECK-NEXT: %[[PLACEHOLDER:.*]] = fir.alloca i1 + // CHECK-NEXT: %[[THREAD_LIMIT:.*]] = fir.convert %[[PLACEHOLDER]] : (!fir.ref) -> i32 + // CHECK-NEXT: %[[CONST:.*]] = arith.constant 1 : i32 + // CHECK-NEXT: %[[GLOBAL:.*]] = fir.address_of(@global_scalar) : !fir.ref + %global = 
fir.address_of(@global_scalar) : !fir.ref + // CHECK-NEXT: %[[GLOBAL_DECL0:.*]]:2 = hlfir.declare %[[GLOBAL]] {uniq_name = "global_scalar"} + // CHECK-NEXT: %[[GLOBAL_DECL1:.*]]:2 = hlfir.declare %[[GLOBAL]] {uniq_name = "global_scalar"} + %0:2 = hlfir.declare %global {uniq_name = "global_scalar"} : (!fir.ref) -> (!fir.ref, !fir.ref) + // CHECK-NEXT: %[[MAP0:.*]] = omp.map.info var_ptr(%[[GLOBAL_DECL0]]#1 : !fir.ref, i32) + // CHECK-NEXT: %[[MAP3:.*]] = omp.map.info var_ptr(%[[GLOBAL_DECL1]]#1 : !fir.ref, i32) + %m0 = omp.map.info var_ptr(%0#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: omp.target_data map_entries(%[[MAP0]] : !fir.ref) + omp.target_data map_entries(%m0 : !fir.ref) { + // CHECK-NEXT: %[[GLOBAL_DECL2:.*]]:2 = hlfir.declare %[[GLOBAL]] {uniq_name = "global_scalar"} + %1:2 = hlfir.declare %global {uniq_name = "global_scalar"} : (!fir.ref) -> (!fir.ref, !fir.ref) + // CHECK-NEXT: %[[MAP1:.*]] = omp.map.info var_ptr(%[[GLOBAL_DECL0]]#1 : !fir.ref, i32) + %m1 = omp.map.info var_ptr(%0#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: %[[MAP2:.*]] = omp.map.info var_ptr(%[[GLOBAL_DECL2]]#1 : !fir.ref, i32) + %m2 = omp.map.info var_ptr(%1#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK-NEXT: omp.target map_entries(%[[MAP1]] -> %{{.*}}, %[[MAP2]] -> {{.*}} : !fir.ref, !fir.ref) + omp.target map_entries(%m1 -> %arg0, %m2 -> %arg1 : !fir.ref, !fir.ref) { + omp.terminator + } + omp.terminator + } + // CHECK-NOT: fir.load + // CHECK-NOT: hlfir.declare + %2 = fir.load %global : !fir.ref + %3:2 = hlfir.declare %global {uniq_name = "global_scalar"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %m3 = omp.map.info var_ptr(%3#1 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK: omp.target thread_limit(%[[THREAD_LIMIT]] : i32) map_entries(%[[MAP3]] -> %{{.*}} : !fir.ref) + omp.target thread_limit(%2 : i32) map_entries(%m3 -> %arg0 : !fir.ref) { + omp.terminator + } + // CHECK: omp.target thread_limit(%[[CONST]] : i32) + %c1 = arith.constant 1 : i32 + omp.target thread_limit(%c1 : i32) { + omp.terminator + } + // CHECK: omp.target thread_limit(%[[CONST]] : i32) + omp.target thread_limit(%c1 : i32) { + omp.terminator + } + return + } + + // CHECK-LABEL: func.func @all_non_map_clauses + // CHECK-SAME: (%[[REF:.*]]: !fir.ref, %[[INT:.*]]: i32, %[[BOOL:.*]]: i1) + func.func @all_non_map_clauses(%ref: !fir.ref, %int: i32, %bool: i1) { + %m0 = omp.map.info var_ptr(%ref : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref + // CHECK: omp.target_data map_entries({{[^)]*}}) { + omp.target_data device(%int : i32) if(%bool) map_entries(%m0 : !fir.ref) { + omp.terminator + } + // CHECK: omp.target allocate({{[^)]*}}) thread_limit({{[^)]*}}) in_reduction({{[^)]*}}) private({{[^)]*}}) { + omp.target allocate(%ref : !fir.ref -> %ref : !fir.ref) + depend(taskdependin -> %ref : !fir.ref) + device(%int : i32) if(%bool) thread_limit(%int : i32) + in_reduction(@reduction %ref -> %arg0 : !fir.ref) + private(@privatizer %ref -> %arg1 : !fir.ref) { + omp.terminator + } + // CHECK: omp.target_enter_data + // CHECK-NOT: depend + // CHECK-NOT: device + // CHECK-NOT: if + omp.target_enter_data depend(taskdependin -> %ref : !fir.ref) + device(%int : i32) if(%bool) + // CHECK-NEXT: omp.target_exit_data + // CHECK-NOT: depend + // CHECK-NOT: device + // CHECK-NOT: if + omp.target_exit_data depend(taskdependin -> %ref : !fir.ref) + device(%int : i32) if(%bool) + // CHECK-NEXT: omp.target_update + // 
CHECK-NOT: depend + // CHECK-NOT: device + // CHECK-NOT: if + omp.target_update depend(taskdependin -> %ref : !fir.ref) + device(%int : i32) if(%bool) + + // CHECK-NEXT: return + return + } + + func.func private @foo() -> () attributes {omp.declare_target = #omp.declaretarget} + fir.global internal @global_scalar constant : i32 { + %0 = arith.constant 10 : i32 + fir.has_value %0 : i32 + } + omp.private {type = firstprivate} @privatizer : i32 copy { + ^bb0(%arg0: !fir.ref, %arg1: !fir.ref): + %0 = fir.load %arg0 : !fir.ref + hlfir.assign %0 to %arg1 : i32, !fir.ref + omp.yield(%arg1 : !fir.ref) + } + omp.declare_reduction @reduction : i32 + init { + ^bb0(%arg: i32): + %0 = arith.constant 0 : i32 + omp.yield (%0 : i32) + } + combiner { + ^bb1(%arg0: i32, %arg1: i32): + %1 = arith.addi %arg0, %arg1 : i32 + omp.yield (%1 : i32) + } +} diff --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/OpenMP/function-filtering.mlir similarity index 100% rename from flang/test/Transforms/omp-function-filtering.mlir rename to flang/test/Transforms/OpenMP/function-filtering.mlir
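
Note (reviewer illustration, not part of the patch): the MLIR below is a hand-written, simplified sketch of the effect of `rewriteHostRegion` on the host portion of a function when compiling for the target device. The function name, variable names and types are hypothetical and not taken from the tests above; real output may also contain placeholder `fir.alloca i1` plus `fir.convert` pairs for operands that must remain, which this example does not need.

```mlir
// Before: host version of a function containing an omp.target region.
func.func @example(%arg0: !fir.ref<i32>) -> i32 {
  %c5 = arith.constant 5 : i32
  %decl:2 = hlfir.declare %arg0 {uniq_name = "x"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
  hlfir.assign %c5 to %decl#0 : i32, !fir.ref<i32>
  %map = omp.map.info var_ptr(%decl#1 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32>
  omp.target map_entries(%map -> %arg1 : !fir.ref<i32>) {
    omp.terminator
  }
  %ld = fir.load %decl#0 : !fir.ref<i32>
  return %ld : i32
}

// After omp-function-filtering on the device: only what device codegen of
// omp.target needs remains (hlfir.declare, omp.map.info and the target op
// itself); host-only ops are dropped and results become fir.undefined.
func.func @example(%arg0: !fir.ref<i32>) -> i32 {
  %decl:2 = hlfir.declare %arg0 {uniq_name = "x"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
  %map = omp.map.info var_ptr(%decl#1 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32>
  omp.target map_entries(%map -> %arg1 : !fir.ref<i32>) {
    omp.terminator
  }
  %undef = fir.undefined i32
  return %undef : i32
}
```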