Skip to content

[flang][OpenMP] - Transform target offloading directives with dependencies during PFT to MLIR conversion #85130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,61 @@ bool ClauseProcessor::processHasDeviceAddr(
addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
});

}

/// Lowers the `depend` clause of a target-family construct by wrapping the
/// construct in a newly created `omp.task` operation.
///
/// The clause is first processed with processDepend(). If that yields no
/// dependence types, nothing is emitted and the function returns false.
/// Otherwise an `omp.task` carrying the dependencies is created, the
/// dependencies are removed from \p clauseOps so that the target operation
/// generated afterwards does not repeat them, and the builder's insertion
/// point is left at the start of the task body, ahead of its terminator.
///
/// \param currentLocation location attached to every operation created here.
/// \param clauseOps       receives the processed depend operands; on a `true`
///                        return its `dependTypeAttrs` and `dependVars` have
///                        been cleared.
/// \return true if an enclosing `omp.task` was created, false otherwise.
bool ClauseProcessor::processTargetDepend(
    mlir::Location currentLocation,
    mlir::omp::DependClauseOps &clauseOps) const {
  processDepend(clauseOps);
  if (clauseOps.dependTypeAttrs.empty())
    return false;

  // 'dependTypeAttrs' is not empty, so a depend clause was used. Create an
  // omp.task operation that encloses the operation generated for the target
  // construct and attach the dependencies to that task instead; the depend
  // clause on the original target construct is dropped.
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // The if clause of the new task is derived from the presence of the nowait
  // clause on the target construct:
  //  - Without nowait, as per the spec, the target task is an included task.
  //    The task therefore gets an if clause that is false: a task whose if
  //    clause evaluates to false is undeferred and, because the value is known
  //    at compile time, it is an included task. An included task is also
  //    mergeable, so no mergeable clause is added.
  //  - With nowait, as per the spec, execution of the target task may be
  //    deferred, which makes it trivially not mergeable. The if clause is true.
  mlir::omp::NowaitClauseOps nowaitClauseOp;
  markClauseOccurrence<omp::clause::Nowait>(nowaitClauseOp.nowaitAttr);
  const bool hasNowait = static_cast<bool>(nowaitClauseOp.nowaitAttr);
  mlir::Value ifExpr = firOpBuilder.createBool(currentLocation, hasNowait);

  mlir::omp::TaskOp taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
      currentLocation, /*if_expr*/ ifExpr,
      /*final_expr*/ mlir::Value(), /*untied*/ mlir::UnitAttr(),
      /*mergeable*/ mlir::UnitAttr(),
      /*in_reduction_vars*/ mlir::ValueRange(), /*in_reductions*/ nullptr,
      /*priority*/ mlir::Value(),
      /*depends*/
      mlir::ArrayAttr::get(firOpBuilder.getContext(),
                           clauseOps.dependTypeAttrs),
      /*depend_vars*/ clauseOps.dependVars,
      /*allocate_vars*/ mlir::ValueRange(),
      /*allocators_vars*/ mlir::ValueRange());

  // Clear the dependencies so that the subsequent target operation does not
  // carry them as well.
  clauseOps.dependTypeAttrs.clear();
  clauseOps.dependVars.clear();

  // Give the task a body consisting of just a terminator, then move the
  // insertion point in front of that terminator so that whatever the caller
  // emits next is placed inside the task.
  firOpBuilder.createBlock(&taskOp.getRegion());
  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
  firOpBuilder.setInsertionPointToStart(&taskOp.getRegion().front());
  return true;
}

bool ClauseProcessor::processIf(
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ class ClauseProcessor {
bool processCopyprivate(mlir::Location currentLocation,
mlir::omp::CopyprivateClauseOps &result) const;
bool processDepend(mlir::omp::DependClauseOps &result) const;
// This is a special case of processDepend that processes the depend
// clause on target ops: TargetOp, EnterDataOp, ExitDataOp and UpdateDataOp.
// It sets up the generation of MLIR code for the target construct
// in question by first creating an enclosing omp.task operation and then
// transferring the 'depend' clause and its arguments to this new omp.task
// operation.
bool processTargetDepend(mlir::Location currentLocation, mlir::omp::DependClauseOps &clauseOps) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,7 @@ static void genTargetClauses(
llvm::SmallVectorImpl<mlir::Location> &devicePtrLocs,
llvm::SmallVectorImpl<mlir::Type> &devicePtrTypes) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDepend(clauseOps);
cp.processTargetDepend(loc, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs,
deviceAddrSyms);
Expand Down Expand Up @@ -1199,7 +1199,7 @@ static void genTargetEnterExitUpdateDataClauses(
mlir::Location loc, llvm::omp::Directive directive,
mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDepend(clauseOps);
cp.processTargetDepend(loc, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processIf(directive, clauseOps);
cp.processNowait(clauseOps);
Expand Down
42 changes: 38 additions & 4 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ subroutine omp_target_enter_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task
!CHECK: omp.task if(%false) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[BOUNDS:.*]] = omp.map.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target enter data map(to: a) depend(in: a)
return
end subroutine omp_target_enter_depend
Expand Down Expand Up @@ -166,9 +167,11 @@ subroutine omp_target_exit_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task

!CHECK: omp.task if(%false) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!CHECK: %[[BOUNDS:.*]] = omp.map.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target exit data map(from: a) depend(out: a)
end subroutine omp_target_exit_depend

Expand All @@ -187,9 +190,10 @@ subroutine omp_target_update_depend
call foo(a)
!$omp end task

!CHECK: omp.task if(%false) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[BOUNDS:.*]] = omp.map.bounds
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_update motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!CHECK: omp.target_update motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target update to(a) depend(in:a)
end subroutine omp_target_update_depend

Expand Down Expand Up @@ -367,19 +371,49 @@ subroutine omp_target_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task

!CHECK: omp.task if(%false) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
!CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
!CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
!CHECK: %[[BOUNDS_A:.*]] = omp.map.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
!CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target map(tofrom: a) depend(in: a)
a(1) = 10
!CHECK: omp.terminator
!$omp end target
!CHECK: }
end subroutine omp_target_depend

!===============================================================================
! Target with region `depend` clause and nowait
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_depend_nowait() {
subroutine omp_target_depend_nowait
! Exercises lowering of a `target` construct that has both `depend` and
! `nowait`. The checks below expect the dependence on an enclosing omp.task
! whose if clause is %true, and the nested omp.target to be matched with
! `nowait` and without a depend clause.
!CHECK: %[[EXTENT_A:.*]] = arith.constant 1024 : index
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_depend_nowaitEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)
! Producer task: writes `a`, so the target below has something to depend on.
!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task

! Task generated around the target construct; it carries the `in` dependence.
!CHECK: omp.task if(%true) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
!CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
!CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
!CHECK: %[[BOUNDS_A:.*]] = omp.map.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
!CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target nowait map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target map(tofrom: a) depend(in: a) nowait
a(1) = 10
!CHECK: omp.terminator
!$omp end target
!CHECK: }
end subroutine omp_target_depend_nowait

!===============================================================================
! Target implicit capture
!===============================================================================
Expand Down
Loading