diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 4c51b61f6bf02..d3f45cdd14ff8 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -761,6 +761,64 @@ bool ClauseProcessor::processHasDeviceAddr(
         addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
                            isDeviceTypes, isDeviceLocs, isDeviceSymbols);
       });
+}
+
+/// Processes the depend clause of a target construct. When dependencies are
+/// present, creates an omp.task (at \p currentLocation) that takes over the
+/// dependencies from \p clauseOps, and positions the builder inside the
+/// task's region. Returns true if the omp.task was created, false otherwise.
+bool ClauseProcessor::processTargetDepend(
+    mlir::Location currentLocation,
+    mlir::omp::DependClauseOps &clauseOps) const {
+  processDepend(clauseOps);
+  if (clauseOps.dependTypeAttrs.empty())
+    return false;
+
+  // If 'dependTypeAttrs' is not empty, the depend clause was used and we
+  // create an omp.task operation that will enclose the omp.target operation
+  // corresponding to the target construct used. The depend clause is tacked
+  // on to this new omp.task and dropped from the original target construct.
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  // The if clause of the new omp.task depends upon the presence of the
+  // nowait clause on the target construct.
+  // If the nowait clause is not present on the target construct, then as per
+  // the spec, the target task is an included task. We add if(false) clause to
+  // the task that we create. A task with an if clause that evaluates to false
+  // is undeferred and because this value is known at compile time, it is an
+  // included task. And an included task is also mergeable. So, we don't
+  // bother with the mergeable clause here. If the nowait clause is present on
+  // the target construct, then as per the spec, the execution of the target
+  // task may be deferred. This makes it trivially not mergeable.
+  mlir::omp::NowaitClauseOps nowaitClauseOp;
+  markClauseOccurrence<omp::clause::Nowait>(nowaitClauseOp.nowaitAttr);
+  mlir::Value ifExpr = firOpBuilder.createBool(
+      currentLocation, static_cast<bool>(nowaitClauseOp.nowaitAttr));
+
+  mlir::omp::TaskOp taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
+      currentLocation, /*if_expr*/ ifExpr,
+      /*final_expr*/ mlir::Value(), /*untied*/ mlir::UnitAttr(),
+      /*mergeable*/ mlir::UnitAttr(),
+      /*in_reduction_vars*/ mlir::ValueRange(), /*in_reductions*/ nullptr,
+      /*priority*/ mlir::Value(),
+      mlir::ArrayAttr::get(firOpBuilder.getContext(),
+                           clauseOps.dependTypeAttrs),
+      clauseOps.dependVars, /*allocate_vars*/ mlir::ValueRange(),
+      /*allocators_vars*/ mlir::ValueRange());
+
+  // Clear the dependencies so that the subsequent omp.target op doesn't have
+  // dependencies.
+  clauseOps.dependTypeAttrs.clear();
+  clauseOps.dependVars.clear();
+
+  // Give the task a body holding only a terminator and leave the insertion
+  // point at the start of that body.
+  // NOTE(review): the insertion point is not restored here; presumably the
+  // caller moves it back out of the task after the target op - confirm.
+  firOpBuilder.createBlock(&taskOp.getRegion());
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&taskOp.getRegion().front());
+  return true;
 }
 
 bool ClauseProcessor::processIf(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 78c148ab02163..24d8e9f1b5dea 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -94,6 +94,12 @@ class ClauseProcessor {
   bool processCopyprivate(mlir::Location currentLocation,
                           mlir::omp::CopyprivateClauseOps &result) const;
   bool processDepend(mlir::omp::DependClauseOps &result) const;
+  // A special case of processDepend for the depend clause on target ops
+  // (TargetOp, EnterDataOp, ExitDataOp, UpdateDataOp). If dependencies are
+  // present, it creates an enclosing omp.task operation, transfers the
+  // 'depend' clause and its arguments to it, and returns true. Otherwise it
+  // returns false.
+  bool processTargetDepend(mlir::Location currentLocation, mlir::omp::DependClauseOps &clauseOps) const;
   bool
   processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
   bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index e932f7c284bca..afe467f8e53e4 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1140,7 +1140,7 @@ static void genTargetClauses(
     llvm::SmallVectorImpl<mlir::Location> &devicePtrLocs,
     llvm::SmallVectorImpl<mlir::Type> &devicePtrTypes) {
   ClauseProcessor cp(converter, semaCtx, clauses);
-  cp.processDepend(clauseOps);
+  cp.processTargetDepend(loc, clauseOps);
   cp.processDevice(stmtCtx, clauseOps);
   cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs,
                           deviceAddrSyms);
@@ -1199,7 +1199,7 @@ static void genTargetEnterExitUpdateDataClauses(
     mlir::Location loc, llvm::omp::Directive directive,
     mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) {
   ClauseProcessor cp(converter, semaCtx, clauses);
-  cp.processDepend(clauseOps);
+  cp.processTargetDepend(loc, clauseOps);
   cp.processDevice(stmtCtx, clauseOps);
   cp.processIf(directive, clauseOps);
   cp.processNowait(clauseOps);
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 51b66327dfb24..5ee76c136189b 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -27,9 +27,10 @@ subroutine omp_target_enter_depend
    !$omp task depend(out: a)
    call foo(a)
    !$omp end task
+   !CHECK: omp.task if(%false) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
    !CHECK: %[[BOUNDS:.*]] = omp.map.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
    !$omp target enter data map(to: a) depend(in: a)
    return
 end subroutine omp_target_enter_depend
@@ -166,9 +167,11 @@ subroutine omp_target_exit_depend
    !$omp task depend(out: a)
    call foo(a)
    !$omp end task
+
+   !CHECK: omp.task if(%false) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
    !CHECK: %[[BOUNDS:.*]] = omp.map.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
    !$omp target exit data map(from: a) depend(out: a)
 end subroutine omp_target_exit_depend
 
@@ -187,9 +190,10 @@ subroutine omp_target_update_depend
    call foo(a)
    !$omp end task
 
+   !CHECK: omp.task if(%false) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
    !CHECK: %[[BOUNDS:.*]] = omp.map.bounds
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target_update motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !CHECK: omp.target_update motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
    !$omp target update to(a) depend(in:a)
 end subroutine omp_target_update_depend
 
@@ -367,12 +371,14 @@ subroutine omp_target_depend
    !$omp task depend(out: a)
    call foo(a)
    !$omp end task
+
+   !CHECK: omp.task if(%false) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
    !CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
    !CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
    !CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
    !CHECK: %[[BOUNDS_A:.*]] = omp.map.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
    !CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>)
    !$omp target map(tofrom: a) depend(in: a)
       a(1) = 10
       !CHECK: omp.terminator
@@ -380,6 +386,34 @@ subroutine omp_target_depend
    !CHECK: }
 end subroutine omp_target_depend
 
+!===============================================================================
+! Target with region `depend` clause and nowait
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_depend_nowait() {
+subroutine omp_target_depend_nowait
+   !CHECK: %[[EXTENT_A:.*]] = arith.constant 1024 : index
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_depend_nowaitEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+
+   !CHECK: omp.task if(%true) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
+   !CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
+   !CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
+   !CHECK: %[[BOUNDS_A:.*]] = omp.map.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
+   !CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target nowait map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target map(tofrom: a) depend(in: a) nowait
+      a(1) = 10
+      !CHECK: omp.terminator
+   !$omp end target
+   !CHECK: }
+end subroutine omp_target_depend_nowait
+
 !===============================================================================
 ! Target implicit capture
 !===============================================================================