Skip to content

Commit 382d359

Browse files
authored
[flang][cuda] Propagate the data attribute on the converted calls (#124877)
1 parent 29441e4 commit 382d359

File tree

4 files changed: 19 additions and 13 deletions.

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,19 +294,22 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
294294
matchAndRewrite(cuf::AllocOp op,
295295
mlir::PatternRewriter &rewriter) const override {
296296

297+
mlir::Location loc = op.getLoc();
298+
297299
if (inDeviceContext(op.getOperation())) {
298300
// In device context, just replace the cuf.alloc operation with a fir.alloca;
299301
// the cuf.free will be removed.
300-
rewriter.replaceOpWithNewOp<fir::AllocaOp>(
301-
op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
302+
auto allocaOp = rewriter.create<fir::AllocaOp>(
303+
loc, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
302304
op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
303305
op.getShape());
306+
allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
307+
rewriter.replaceOp(op, allocaOp);
304308
return mlir::success();
305309
}
306310

307311
auto mod = op->getParentOfType<mlir::ModuleOp>();
308312
fir::FirOpBuilder builder(rewriter, mod);
309-
mlir::Location loc = op.getLoc();
310313
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
311314

312315
if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
@@ -359,6 +362,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
359362
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
360363
builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
361364
auto callOp = builder.create<fir::CallOp>(loc, func, args);
365+
callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
362366
auto convOp = builder.createConvert(loc, op.getResult().getType(),
363367
callOp.getResult(0));
364368
rewriter.replaceOp(op, convOp);
@@ -381,6 +385,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
381385
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
382386
builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
383387
auto callOp = builder.create<fir::CallOp>(loc, func, args);
388+
callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
384389
auto convOp = builder.createConvert(loc, op.getResult().getType(),
385390
callOp.getResult(0));
386391
rewriter.replaceOp(op, convOp);
@@ -508,7 +513,8 @@ struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
508513
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
509514
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
510515
builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
511-
builder.create<fir::CallOp>(loc, func, args);
516+
auto callOp = builder.create<fir::CallOp>(loc, func, args);
517+
callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
512518
rewriter.eraseOp(op);
513519
return mlir::success();
514520
}

flang/test/Fir/CUDA/cuda-alloc-free.fir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func.func @_QPsub1() {
1111

1212
// CHECK-LABEL: func.func @_QPsub1()
1313
// CHECK: %[[BYTES:.*]] = fir.convert %c4{{.*}} : (index) -> i64
14-
// CHECK: %[[ALLOC:.*]] = fir.call @_FortranACUFMemAlloc(%[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
14+
// CHECK: %[[ALLOC:.*]] = fir.call @_FortranACUFMemAlloc(%[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
1515
// CHECK: %[[CONV:.*]] = fir.convert %3 : (!fir.llvm_ptr<i8>) -> !fir.ref<i32>
1616
// CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[CONV]] {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
1717
// CHECK: %[[DEVPTR:.*]] = fir.convert %[[DECL]]#1 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
@@ -26,7 +26,7 @@ func.func @_QPsub2() {
2626
// CHECK-LABEL: func.func @_QPsub2()
2727
// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : index
2828
// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
29-
// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
29+
// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
3030
// CHECK: fir.call @_FortranACUFMemFree
3131

3232
func.func @_QPsub3(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<i32> {fir.bindc_name = "m"}) {
@@ -58,7 +58,7 @@ func.func @_QPsub3(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<
5858
// CHECK: %[[NBELEM:.*]] = arith.muli %[[N]], %[[M]] : index
5959
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %c4{{.*}} : index
6060
// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
61-
// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
61+
// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
6262
// CHECK: fir.call @_FortranACUFMemFree
6363

6464
func.func @_QPtest_type() {
@@ -71,7 +71,7 @@ func.func @_QPtest_type() {
7171
// CHECK-LABEL: func.func @_QPtest_type()
7272
// CHECK: %[[BYTES:.*]] = arith.constant 12 : index
7373
// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
74-
// CHECK: fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
74+
// CHECK: fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
7575

7676
gpu.module @cuda_device_mod {
7777
gpu.func @_QMalloc() kernel {
@@ -81,7 +81,7 @@ gpu.module @cuda_device_mod {
8181
}
8282

8383
// CHECK-LABEL: gpu.func @_QMalloc() kernel
84-
// CHECK: fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QMallocEa"}
84+
// CHECK: fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", cuf.data_attr = #cuf.cuda<device>, uniq_name = "_QMallocEa"}
8585

8686
func.func @_QQalloc_char() attributes {fir.bindc_name = "alloc_char"} {
8787
%c1 = arith.constant 1 : index
@@ -92,6 +92,6 @@ func.func @_QQalloc_char() attributes {fir.bindc_name = "alloc_char"} {
9292
// CHECK-LABEL: func.func @_QQalloc_char()
9393
// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c1{{.*}} : index
9494
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
95-
// CHECK: fir.call @_FortranACUFMemAlloc(%[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
95+
// CHECK: fir.call @_FortranACUFMemAlloc(%[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
9696

9797
} // end module

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func.func @_QPsub1() {
1515
}
1616

1717
// CHECK-LABEL: func.func @_QPsub1()
18-
// CHECK: %[[DESC_RT_CALL:.*]] = fir.call @_FortranACUFAllocDescriptor(%{{.*}}, %{{.*}}, %{{.*}}) : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
18+
// CHECK: %[[DESC_RT_CALL:.*]] = fir.call @_FortranACUFAllocDescriptor(%{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
1919
// CHECK: %[[DESC:.*]] = fir.convert %[[DESC_RT_CALL]] : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
2020
// CHECK: %[[DECL_DESC:.*]]:2 = hlfir.declare %[[DESC]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
2121
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
@@ -24,7 +24,7 @@ func.func @_QPsub1() {
2424
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
2525
// CHECK: %{{.*}} = fir.call @_FortranAAllocatableDeallocate(%[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
2626
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
27-
// CHECK: fir.call @_FortranACUFFreeDescriptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> ()
27+
// CHECK: fir.call @_FortranACUFFreeDescriptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> ()
2828

2929
fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
3030
%0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ func.func @_QPtest_array_type() {
329329
// CHECK-LABEL: func.func @_QPtest_array_type()
330330
// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c12 : index
331331
// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
332-
// CHECK: fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
332+
// CHECK: fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
333333
// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c12{{.*}} : i64
334334
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%{{.*}}, %{{.*}}, %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> ()
335335

0 commit comments

Comments
 (0)