diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index d528772f28724..17ebf93edcce1 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -472,17 +472,26 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter, /// \param [in] infoAccessor - for a private variable, this returns the /// data we want to merge: type or location. /// \param [out] allRegionArgsInfo - the merged list of region info. +/// \param [in] addBeforePrivate - `true` if the passed information goes before +/// private information. template static void mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef currentList, llvm::function_ref infoAccessor, - llvm::SmallVectorImpl &allRegionArgsInfo) { + llvm::SmallVectorImpl &allRegionArgsInfo, + bool addBeforePrivate) { mlir::OperandRange privateVars = op.getPrivateVars(); - llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), - [](InfoTy i) { return i; }); + if (addBeforePrivate) + llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), + [](InfoTy i) { return i; }); + llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo), infoAccessor); + + if (!addBeforePrivate) + llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), + [](InfoTy i) { return i; }); } //===----------------------------------------------------------------------===// @@ -868,12 +877,12 @@ static void genBodyOfTargetOp( mergePrivateVarsInfo(targetOp, mapSymTypes, llvm::function_ref{ [](mlir::Value v) { return v.getType(); }}, - allRegionArgTypes); + allRegionArgTypes, /*addBeforePrivate=*/true); mergePrivateVarsInfo(targetOp, mapSymLocs, llvm::function_ref{ [](mlir::Value v) { return v.getLoc(); }}, - allRegionArgLocs); + allRegionArgLocs, /*addBeforePrivate=*/true); mlir::Block *regionBlock = firOpBuilder.createBlock( ®ion, {}, allRegionArgTypes, allRegionArgLocs); @@ -1478,21 +1487,21 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mergePrivateVarsInfo(parallelOp, reductionTypes, llvm::function_ref{ [](mlir::Value v) { return v.getType(); }}, - allRegionArgTypes); + allRegionArgTypes, /*addBeforePrivate=*/false); llvm::SmallVector allRegionArgLocs; mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs), llvm::function_ref{ [](mlir::Value v) { return v.getLoc(); }}, - allRegionArgLocs); + allRegionArgLocs, /*addBeforePrivate=*/false); mlir::Region ®ion = parallelOp.getRegion(); firOpBuilder.createBlock(®ion, /*insertPt=*/{}, allRegionArgTypes, allRegionArgLocs); - llvm::SmallVector allSymbols(reductionSyms); - allSymbols.append(dsp->getDelayedPrivSymbols().begin(), - dsp->getDelayedPrivSymbols().end()); + llvm::SmallVector allSymbols( + dsp->getDelayedPrivSymbols()); + allSymbols.append(reductionSyms.begin(), reductionSyms.end()); unsigned argIdx = 0; for (const semantics::Symbol *arg : allSymbols) { diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 index 2943957117932..6c00bb23f15b9 100644 --- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 +++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 @@ -26,5 +26,5 @@ subroutine red_and_delayed_private ! CHECK-LABEL: _QPred_and_delayed_private ! CHECK: omp.parallel -! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref) -! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref) { +! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref) +! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref) { diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 index d814b2b0ff0f3..38139e52ce95c 100644 --- a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 +++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 @@ -29,5 +29,5 @@ subroutine red_and_delayed_private ! CHECK-LABEL: _QPred_and_delayed_private ! CHECK: omp.parallel -! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref) -! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref) { +! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref) +! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref) { diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index c579ba6e751d2..876d53766a0ca 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -451,7 +451,7 @@ class OpenMP_InReductionClauseSkip< > : OpenMP_Clause { let traits = [ - ReductionClauseInterface + BlockArgOpenMPOpInterface, ReductionClauseInterface ]; let arguments = (ins @@ -472,6 +472,8 @@ class OpenMP_InReductionClauseSkip< return SmallVector(getInReductionVars().begin(), getInReductionVars().end()); } + + unsigned numInReductionBlockArgs() { return getInReductionVars().size(); } }]; // Description varies depending on the operation. @@ -575,6 +577,8 @@ class OpenMP_MapClauseSkip< > : OpenMP_Clause { let traits = [ + // Not adding the BlockArgOpenMPOpInterface here because omp.target is the + // only operation defining block arguments for `map` clauses. MapClauseOwningOpInterface ]; @@ -923,6 +927,10 @@ class OpenMP_PrivateClauseSkip< bit description = false, bit extraClassDeclaration = false > : OpenMP_Clause { + let traits = [ + BlockArgOpenMPOpInterface + ]; + let arguments = (ins Variadic:$private_vars, OptionalAttr:$private_syms @@ -933,6 +941,10 @@ class OpenMP_PrivateClauseSkip< custom($private_vars, type($private_vars), $private_syms) `)` }]; + let extraClassDeclaration = [{ + unsigned numPrivateBlockArgs() { return getPrivateVars().size(); } + }]; + // TODO: Add description. } @@ -973,7 +985,7 @@ class OpenMP_ReductionClauseSkip< > : OpenMP_Clause { let traits = [ - ReductionClauseInterface + BlockArgOpenMPOpInterface, ReductionClauseInterface ]; let arguments = (ins @@ -991,6 +1003,7 @@ class OpenMP_ReductionClauseSkip< let extraClassDeclaration = [{ /// Returns the number of reduction variables. unsigned getNumReductionVars() { return getReductionVars().size(); } + unsigned numReductionBlockArgs() { return getReductionVars().size(); } }]; // Description varies depending on the operation. @@ -1104,7 +1117,7 @@ class OpenMP_TaskReductionClauseSkip< > : OpenMP_Clause { let traits = [ - ReductionClauseInterface + BlockArgOpenMPOpInterface, ReductionClauseInterface ]; let arguments = (ins @@ -1119,6 +1132,18 @@ class OpenMP_TaskReductionClauseSkip< $task_reduction_byref, $task_reduction_syms) `)` }]; + let extraClassDeclaration = [{ + /// Returns the reduction variables. + SmallVector getReductionVars() { + return SmallVector(getTaskReductionVars().begin(), + getTaskReductionVars().end()); + } + + unsigned numTaskReductionBlockArgs() { + return getTaskReductionVars().size(); + } + }]; + let description = [{ The `task_reduction` clause specifies a reduction among tasks. For each list item, the number of copies is unspecified. Any copies associated with the @@ -1130,14 +1155,6 @@ class OpenMP_TaskReductionClauseSkip< attribute, and whether the reduction variable should be passed into the reduction region by value or by reference in `task_reduction_byref`. }]; - - let extraClassDeclaration = [{ - /// Returns the reduction variables. - SmallVector getReductionVars() { - return SmallVector(getTaskReductionVars().begin(), - getTaskReductionVars().end()); - } - }]; } def OpenMP_TaskReductionClause : OpenMP_TaskReductionClauseSkip<>; diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 9d2123a2e9bf5..326bdd3bbc946 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1043,7 +1043,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [ //===----------------------------------------------------------------------===// def TargetOp : OpenMP_Op<"target", traits = [ - AttrSizedOperandSegments, IsolatedFromAbove, OutlineableOpenMPOpInterface + AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove, + OutlineableOpenMPOpInterface ], clauses = [ // TODO: Complete clause list (defaultmap, uses_allocators). OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause, @@ -1065,6 +1066,10 @@ def TargetOp : OpenMP_Op<"target", traits = [ OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)> ]; + let extraClassDeclaration = [{ + unsigned numMapBlockArgs() { return getMapVars().size(); } + }] # clausesExtraClassDeclaration; + let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index ea1e3ebecef7b..2602384744f23 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -15,6 +15,114 @@ include "mlir/IR/OpBase.td" +def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { + let description = [{ + OpenMP operations that define entry block arguments as part of the + representation of its clauses. + }]; + + let cppNamespace = "::mlir::omp"; + + let methods = [ + // Default-implemented methods to be overriden by the corresponding clauses. + InterfaceMethod<"Get number of block arguments defined by `in_reduction`.", + "unsigned", "numInReductionBlockArgs", (ins), [{}], [{ + return 0; + }]>, + InterfaceMethod<"Get number of block arguments defined by `map`.", + "unsigned", "numMapBlockArgs", (ins), [{}], [{ + return 0; + }]>, + InterfaceMethod<"Get number of block arguments defined by `private`.", + "unsigned", "numPrivateBlockArgs", (ins), [{}], [{ + return 0; + }]>, + InterfaceMethod<"Get number of block arguments defined by `reduction`.", + "unsigned", "numReductionBlockArgs", (ins), [{}], [{ + return 0; + }]>, + InterfaceMethod<"Get number of block arguments defined by `task_reduction`.", + "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{ + return 0; + }]>, + + // Unified access methods for clause-associated entry block arguments. + InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.", + "unsigned", "getInReductionBlockArgsStart", (ins), [{ + return 0; + }]>, + InterfaceMethod<"Get start index of block arguments defined by `map`.", + "unsigned", "getMapBlockArgsStart", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return iface.getInReductionBlockArgsStart() + + $_op.numInReductionBlockArgs(); + }]>, + InterfaceMethod<"Get start index of block arguments defined by `private`.", + "unsigned", "getPrivateBlockArgsStart", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return iface.getMapBlockArgsStart() + $_op.numMapBlockArgs(); + }]>, + InterfaceMethod<"Get start index of block arguments defined by `reduction`.", + "unsigned", "getReductionBlockArgsStart", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return iface.getPrivateBlockArgsStart() + $_op.numPrivateBlockArgs(); + }]>, + InterfaceMethod<"Get start index of block arguments defined by `task_reduction`.", + "unsigned", "getTaskReductionBlockArgsStart", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs(); + }]>, + + InterfaceMethod<"Get block arguments defined by `in_reduction`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getInReductionBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getInReductionBlockArgsStart(), $_op.numInReductionBlockArgs()); + }]>, + InterfaceMethod<"Get block arguments defined by `map`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getMapBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getMapBlockArgsStart(), $_op.numMapBlockArgs()); + }]>, + InterfaceMethod<"Get block arguments defined by `private`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getPrivateBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getPrivateBlockArgsStart(), $_op.numPrivateBlockArgs()); + }]>, + InterfaceMethod<"Get block arguments defined by `reduction`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getReductionBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getReductionBlockArgsStart(), $_op.numReductionBlockArgs()); + }]>, + InterfaceMethod<"Get block arguments defined by `task_reduction`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getTaskReductionBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getTaskReductionBlockArgsStart(), + $_op.numTaskReductionBlockArgs()); + }]>, + ]; + + let verify = [{ + auto iface = ::llvm::cast($_op); + unsigned expectedArgs = iface.numInReductionBlockArgs() + + iface.numMapBlockArgs() + iface.numPrivateBlockArgs() + + iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs(); + if ($_op->getRegion(0).getNumArguments() < expectedArgs) + return $_op->emitOpError() << "expected at least " << expectedArgs + << " entry block argument(s)"; + return ::mlir::success(); + }]; +} + def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> { let description = [{ OpenMP operations whose region will be outlined will implement this diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 59e71ecc6ec5d..6b1abbc186a19 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -536,13 +536,6 @@ static ParseResult parseParallelRegion( llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms) { llvm::SmallVector regionPrivateArgs; - if (succeeded(parser.parseOptionalKeyword("reduction"))) { - if (failed(parseClauseWithRegionArgs(parser, region, reductionVars, - reductionTypes, reductionByref, - reductionSyms, regionPrivateArgs))) - return failure(); - } - if (succeeded(parser.parseOptionalKeyword("private"))) { auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {}); if (failed(parseClauseWithRegionArgs(parser, region, privateVars, @@ -557,6 +550,13 @@ static ParseResult parseParallelRegion( } } + if (succeeded(parser.parseOptionalKeyword("reduction"))) { + if (failed(parseClauseWithRegionArgs(parser, region, reductionVars, + reductionTypes, reductionByref, + reductionSyms, regionPrivateArgs))) + return failure(); + } + return parser.parseRegion(region, regionPrivateArgs); } @@ -566,18 +566,9 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) { - if (reductionSyms) { - auto *argsBegin = region.front().getArguments().begin(); - MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size()); - printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars, - reductionTypes, reductionByref, reductionSyms); - } - if (privateSyms) { auto *argsBegin = region.front().getArguments().begin(); - MutableArrayRef argsSubrange(argsBegin + reductionVars.size(), - argsBegin + reductionVars.size() + - privateTypes.size()); + MutableArrayRef argsSubrange(argsBegin, argsBegin + privateTypes.size()); mlir::SmallVector isByRefVec; isByRefVec.resize(privateTypes.size(), false); DenseBoolArrayAttr isByRef = @@ -587,6 +578,15 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, privateTypes, isByRef, privateSyms); } + if (reductionSyms) { + auto *argsBegin = region.front().getArguments().begin(); + MutableArrayRef argsSubrange(argsBegin + privateVars.size(), + argsBegin + privateVars.size() + + reductionTypes.size()); + printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars, + reductionTypes, reductionByref, reductionSyms); + } + p.printRegion(region, /*printEntryBlockArgs=*/false); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index c22d9a189a7e0..7c89d3bd6ec5a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -920,7 +920,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, DenseMap reductionVariableMap; MutableArrayRef reductionArgs = - sectionsOp.getRegion().getArguments(); + cast(opInst).getReductionBlockArgs(); if (failed(allocAndInitializeReductionVars( sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP, @@ -1216,7 +1216,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, DenseMap reductionVariableMap; MutableArrayRef reductionArgs = - wsloopOp.getRegion().getArguments(); + cast(opInst).getReductionBlockArgs(); if (failed(allocAndInitializeReductionVars( wsloopOp, reductionArgs, builder, moduleTranslation, allocaIP, @@ -1329,31 +1329,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, class OmpParallelOpConversionManager { public: OmpParallelOpConversionManager(omp::ParallelOp opInst) - : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()), - privateArgBeginIdx(opInst.getNumReductionVars()), - privateArgEndIdx(privateArgBeginIdx + privateVars.size()) { - auto privateVarsIt = privateVars.begin(); - - for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx; - ++argIdx, ++privateVarsIt) - mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx), - *privateVarsIt, region); + : region(opInst.getRegion()), + privateBlockArgs(cast(*opInst) + .getPrivateBlockArgs()), + privateVars(opInst.getPrivateVars()) { + for (auto [blockArg, var] : llvm::zip_equal(privateBlockArgs, privateVars)) + mlir::replaceAllUsesInRegionWith(blockArg, var, region); } ~OmpParallelOpConversionManager() { - auto privateVarsIt = privateVars.begin(); - - for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx; - ++argIdx, ++privateVarsIt) - mlir::replaceAllUsesInRegionWith(*privateVarsIt, - region.getArgument(argIdx), region); + for (auto [blockArg, var] : llvm::zip_equal(privateBlockArgs, privateVars)) + mlir::replaceAllUsesInRegionWith(var, blockArg, region); } private: Region ®ion; + llvm::MutableArrayRef privateBlockArgs; OperandRange privateVars; - unsigned privateArgBeginIdx; - unsigned privateArgEndIdx; }; // Looks up from the operation from and returns the PrivateClauseOp with @@ -1417,9 +1409,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, DenseMap reductionVariableMap; MutableArrayRef reductionArgs = - opInst.getRegion().getArguments().slice( - opInst.getNumAllocateVars() + opInst.getNumAllocatorsVars(), - opInst.getNumReductionVars()); + cast(*opInst).getReductionBlockArgs(); allocaIP = InsertPointTy(allocaIP.getBlock(), @@ -3414,6 +3404,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, auto &targetRegion = targetOp.getRegion(); DataLayout dl = DataLayout(opInst.getParentOfType()); SmallVector mapVars = targetOp.getMapVars(); + ArrayRef mapBlockArgs = + cast(opInst).getMapBlockArgs(); llvm::Function *llvmOutlinedFn = nullptr; // TODO: It can also be false if a compile-time constant `false` IF clause is @@ -3442,11 +3434,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, llvmOutlinedFn->addFnAttr(attr); builder.restoreIP(codeGenIP); - for (auto [argIndex, mapOp] : llvm::enumerate(mapVars)) { + for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) { auto mapInfoOp = cast(mapOp.getDefiningOp()); llvm::Value *mapOpValue = moduleTranslation.lookupValue(mapInfoOp.getVarPtr()); - const auto &arg = targetRegion.front().getArgument(argIndex); moduleTranslation.mapValue(arg, mapOpValue); } @@ -3457,18 +3448,13 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, OperandRange privateVars = targetOp.getPrivateVars(); std::optional privateSyms = targetOp.getPrivateSyms(); - unsigned numMapVars = targetOp.getMapVars().size(); - Block &firstTargetBlock = targetRegion.front(); - BlockArgument *blockArgsStart = firstTargetBlock.getArguments().begin(); - BlockArgument *privArgsStart = blockArgsStart + numMapVars; - BlockArgument *privArgsEnd = - privArgsStart + targetOp.getPrivateVars().size(); - MutableArrayRef privateBlockArgs(privArgsStart, privArgsEnd); + MutableArrayRef privateBlockArgs = + cast(opInst).getPrivateBlockArgs(); for (auto [privVar, privatizerNameAttr, privBlockArg] : llvm::zip_equal(privateVars, *privateSyms, privateBlockArgs)) { - SymbolRefAttr privSym = llvm::cast(privatizerNameAttr); + SymbolRefAttr privSym = cast(privatizerNameAttr); omp::PrivateClauseOp privatizer = findPrivatizer(&opInst, privSym); if (privatizer.getDataSharingType() == omp::DataSharingClauseType::FirstPrivate || diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 35a8883e3a317..4899583ac3bff 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1471,6 +1471,7 @@ func.func @omp_sections(%data_var : memref) -> () { func.func @omp_sections(%data_var : memref) -> () { // expected-error @below {{expected as many reduction symbol references as reduction variables}} "omp.sections" (%data_var) ({ + ^bb0(%arg0: memref): omp.terminator }) {operandSegmentSizes = array} : (memref) -> () return @@ -1662,6 +1663,7 @@ func.func @omp_task_depend(%data_var: memref) { func.func @omp_task(%ptr: !llvm.ptr) { // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}} omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -1686,6 +1688,7 @@ combiner { func.func @omp_task(%ptr: !llvm.ptr) { // expected-error @below {{op accumulator variable used more than once}} omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr, @add_f32 -> %ptr : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -1716,6 +1719,7 @@ atomic { func.func @omp_task(%mem: memref<1xf32>) { // expected-error @below {{op expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr')}} omp.task in_reduction(@add_i32 -> %mem : memref<1xf32>) { + ^bb0(%arg0: memref<1xf32>): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index e7d3e67ca7e05..2116071f8523a 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1096,6 +1096,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr // CHECK: omp.teams reduction(@add_f32 -> %{{.+}} : !llvm.ptr) { omp.teams reduction(@add_f32 -> %0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): %1 = arith.constant 2.0 : f32 // CHECK: omp.terminator omp.terminator @@ -1104,6 +1105,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, // Test reduction byref // CHECK: omp.teams reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr) { omp.teams reduction(byref @add_f32 -> %0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): %1 = arith.constant 2.0 : f32 // CHECK: omp.terminator omp.terminator @@ -1125,6 +1127,7 @@ func.func @sections_reduction() { %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr // CHECK: omp.sections reduction(@add_f32 -> {{.+}} : !llvm.ptr) omp.sections reduction(@add_f32 -> %0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): // CHECK: omp.section omp.section { %1 = arith.constant 2.0 : f32 @@ -1146,6 +1149,7 @@ func.func @sections_reduction_byref() { %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr // CHECK: omp.sections reduction(byref @add_f32 -> {{.+}} : !llvm.ptr) omp.sections reduction(byref @add_f32 -> %0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): // CHECK: omp.section omp.section { %1 = arith.constant 2.0 : f32 @@ -1245,6 +1249,7 @@ func.func @sections_reduction2() { %0 = memref.alloca() : memref<1xf32> // CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>) omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) { + ^bb0(%arg0: !llvm.ptr): omp.section { %1 = arith.constant 2.0 : f32 omp.terminator @@ -1901,6 +1906,7 @@ func.func @omp_sectionsop(%data_var1 : memref, %data_var2 : memref, // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) "omp.sections" (%redn_var) ({ + ^bb0(%arg0: !llvm.ptr): // CHECK: omp.terminator omp.terminator }) {operandSegmentSizes = array, reduction_byref = array, reduction_syms=[@add_f32]} : (!llvm.ptr) -> () @@ -1913,6 +1919,7 @@ func.func @omp_sectionsop(%data_var1 : memref, %data_var2 : memref, // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) { omp.sections reduction(@add_f32 -> %redn_var : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): // CHECK: omp.terminator omp.terminator } @@ -2087,6 +2094,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr %1 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr // CHECK: omp.task in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) { omp.task in_reduction(@add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -2096,6 +2104,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr // Checking `in_reduction` clause (mixed) byref // CHECK: omp.task in_reduction(byref @add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) { omp.task in_reduction(byref @add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -2129,6 +2138,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr in_reduction(@add_f32 -> %0 : !llvm.ptr, byref @add_f32 -> %1 : !llvm.ptr) // CHECK-SAME: priority(%[[i32_var]] : i32) untied priority(%i32_var : i32) untied { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -2306,6 +2316,7 @@ func.func @omp_taskgroup_clauses() -> () { %testf32 = "test.f32"() : () -> (!llvm.ptr) // CHECK: omp.taskgroup allocate(%{{.+}}: memref -> %{{.+}}: memref) task_reduction(@add_f32 -> %{{.+}}: !llvm.ptr) omp.taskgroup allocate(%testmemref : memref -> %testmemref : memref) task_reduction(@add_f32 -> %testf32 : !llvm.ptr) { + ^bb0(%arg0 : !llvm.ptr): // CHECK: omp.task omp.task { "test.foo"() : () -> () @@ -2783,15 +2794,15 @@ omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc { // CHECK-LABEL: parallel_op_reduction_and_private func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !llvm.ptr, %reduc_var: !llvm.ptr, %reduc_var2: !llvm.ptr) { // CHECK: omp.parallel - // CHECK-SAME: reduction( - // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr, - // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr) - // // CHECK-SAME: private( // CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]] : !llvm.ptr, // CHECK-SAME: @y.privatizer %[[PRIV_VAR2:[^[:space:]]+]] -> %[[PRIV_ARG2:[^[:space:]]+]] : !llvm.ptr) - omp.parallel reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr) - private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) { + // + // CHECK-SAME: reduction( + // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr, + // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr) + omp.parallel private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) + reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr) { // CHECK: llvm.load %[[PRIV_ARG]] %0 = llvm.load %priv_arg : !llvm.ptr -> f32 // CHECK: llvm.load %[[PRIV_ARG2]] diff --git a/mlir/test/Target/LLVMIR/openmp-private.mlir b/mlir/test/Target/LLVMIR/openmp-private.mlir index 21167668bbee1..a06e44fc5cfe0 100644 --- a/mlir/test/Target/LLVMIR/openmp-private.mlir +++ b/mlir/test/Target/LLVMIR/openmp-private.mlir @@ -206,7 +206,7 @@ llvm.func @private_and_reduction_() attributes {fir.internal_name = "_QPprivate_ %0 = llvm.mlir.constant(1 : i64) : i64 %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr %2 = llvm.alloca %0 x f32 {bindc_name = "to_priv"} : (i64) -> !llvm.ptr - omp.parallel reduction(byref @reducer.part %1 -> %arg0 : !llvm.ptr) private(@privatizer.part %2 -> %arg1 : !llvm.ptr) { + omp.parallel private(@privatizer.part %2 -> %arg1 : !llvm.ptr) reduction(byref @reducer.part %1 -> %arg0 : !llvm.ptr) { %3 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> %4 = llvm.mlir.constant(8.000000e+00 : f32) : f32 llvm.store %4, %arg1 : f32, !llvm.ptr