diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 5cff95c7d125b..83fb2d94508c2 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1037,6 +1037,31 @@ static void genBodyOfTargetOp( genNestedEvaluations(converter, eval); } +// If the symbol is specified in declare target directive, the function returns +// the corresponding declare target operation. +static mlir::omp::DeclareTargetInterface +getDeclareTargetOp(const Fortran::semantics::Symbol &sym, + Fortran::lower::AbstractConverter &converter) { + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + mlir::Operation *op; + op = mod.lookupSymbol(converter.mangleName(sym)); + auto declareTargetOp = + llvm::dyn_cast_if_present(op); + // If declare target op is not found Check if common block containing the + // variable is specified in declare target + if (!declareTargetOp || !declareTargetOp.isDeclareTarget()) { + if (auto cB = Fortran::semantics::FindCommonBlockContaining(sym)) { + op = mod.lookupSymbol(converter.mangleName(*cB)); + declareTargetOp = + llvm::dyn_cast_if_present(op); + } + } + if (declareTargetOp && declareTargetOp.isDeclareTarget()) { + return declareTargetOp; + } + return static_cast(nullptr); +} + static mlir::omp::TargetOp genTargetOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, @@ -1122,11 +1147,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, // If a variable is specified in declare target link and if device // type is not specified as `nohost`, it needs to be mapped tofrom - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - mlir::Operation *op = mod.lookupSymbol(converter.mangleName(sym)); - auto declareTargetOp = - llvm::dyn_cast_if_present(op); - if (declareTargetOp && declareTargetOp.isDeclareTarget()) { + if (auto declareTargetOp = getDeclareTargetOp(sym, converter)) { if (declareTargetOp.getDeclareTargetCaptureClause() == mlir::omp::DeclareTargetCaptureClause::link && declareTargetOp.getDeclareTargetDeviceType() != diff --git a/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90 b/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90 index 7cd0597161578..cd2615faba546 100644 --- a/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90 +++ b/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90 @@ -20,6 +20,13 @@ program test_link integer, pointer :: test_ptr2 !$omp declare target link(test_ptr2) + integer :: test_int_cb + + integer :: test_int_array_cb(3) = (/1,2,3/) + + common /test_cb/ test_int_cb, test_int_array_cb + !$omp declare target link(/test_cb/) + !CHECK-DAG: {{%.*}} = omp.map_info var_ptr({{%.*}} : !fir.ref, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref {name = "test_int"} !$omp target test_int = test_int + 1 @@ -52,4 +59,15 @@ program test_link test_ptr2 = test_ptr2 + 1 !$omp end target + !CHECK-DAG: {{%.*}} = omp.map_info var_ptr({{%.*}} : !fir.ref, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref {name = "test_int_cb"} + !$omp target + test_int_cb = test_int_cb + 1 + !$omp end target + + !CHECK-DAG: {{%.*}} = omp.map_info var_ptr({{%.*}} : !fir.ref>, !fir.array<3xi32>) map_clauses(implicit, tofrom) capture(ByRef) bounds({{%.*}}) -> !fir.ref> {name = "test_int_array_cb"} + !$omp target + do i = 1,3 + test_int_array_cb(i) = i * 2 + end do + !$omp end target end diff --git a/openmp/libomptarget/test/offloading/fortran/declare-target-vars-in-target-region.f90 b/openmp/libomptarget/test/offloading/fortran/declare-target-vars-in-target-region.f90 index f524deac3bcce..63343d504323b 100644 --- a/openmp/libomptarget/test/offloading/fortran/declare-target-vars-in-target-region.f90 +++ b/openmp/libomptarget/test/offloading/fortran/declare-target-vars-in-target-region.f90 @@ -16,6 +16,10 @@ module test_0 !$omp declare target link(arr1) enter(arr2) INTEGER :: scalar = 1 !$omp declare target link(scalar) + INTEGER :: scalar_cb = 1 + INTEGER :: arr_cb(10) = (/0,0,0,0,0,0,0,0,0,0/) + COMMON /CB/ scalar_cb, arr_cb + !$omp declare target link(/CB/) end module test_0 subroutine test_with_array_link_and_tofrom() @@ -73,9 +77,36 @@ subroutine test_with_scalar_link_only() PRINT *, scalar end subroutine test_with_scalar_link_only +subroutine test_with_array_cb_link_only() + use test_0 + integer :: i = 1 + integer :: j = 11 + !$omp target map(i, j) + do while (i <= j) + arr_cb(i) = i + 1; + i = i + 1 + end do + !$omp end target + + ! CHECK: 2 3 4 5 6 7 8 9 10 11 + PRINT *, arr_cb(:) +end subroutine test_with_array_cb_link_only + +subroutine test_with_scalar_cb_link_only() + use test_0 + !$omp target + scalar_cb = 10 + !$omp end target + + ! CHECK: 10 + PRINT *, scalar_cb +end subroutine test_with_scalar_cb_link_only + program main call test_with_array_link_and_tofrom() call test_with_array_link_only() call test_with_array_enter_only() call test_with_scalar_link_only() + call test_with_array_cb_link_only() + call test_with_scalar_cb_link_only() end program