Skip to content

Commit 841327d

Browse files
authored
[flang][cuda] Convert cuf.alloc for box to fir.alloca in device context (#102662)
In device context, managed memory is not available, so it makes no sense to allocate the descriptor with it. Fall back to fir.alloca, which is handled well in device code; the matching cuf.free is simply dropped.
1 parent 66d8735 commit 841327d

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,20 @@ struct CufDeallocateOpConversion
141141
}
142142
};
143143

144+
/// Return true if \p op is located in a device context.
///
/// An operation counts as being in a device context when it is nested inside
/// a cuf::KernelOp, or when its enclosing mlir::func::FuncOp carries the CUF
/// procedure attribute (looked up under cuf::getProcAttrName()) with a value
/// that is neither ProcAttribute::Host nor ProcAttribute::HostDevice.
/// Returns false in every other case, including when the enclosing function
/// has no such attribute or when there is no enclosing function.
static bool inDeviceContext(mlir::Operation *op) {
145+
if (op->getParentOfType<cuf::KernelOp>())
146+
return true;
147+
if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
148+
if (auto cudaProcAttr =
149+
funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
150+
cuf::getProcAttrName())) {
151+
return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
152+
cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
153+
}
154+
}
155+
return false;
156+
}
157+
144158
struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
145159
using OpRewritePattern::OpRewritePattern;
146160

@@ -157,6 +171,16 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
157171
if (!boxTy)
158172
return failure();
159173

174+
if (inDeviceContext(op.getOperation())) {
175+
// In device context, just replace the cuf.alloc operation with a fir.alloca;
176+
// the cuf.free will be removed.
177+
rewriter.replaceOpWithNewOp<fir::AllocaOp>(
178+
op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
179+
op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
180+
op.getShape());
181+
return mlir::success();
182+
}
183+
160184
auto mod = op->getParentOfType<mlir::ModuleOp>();
161185
fir::FirOpBuilder builder(rewriter, mod);
162186
mlir::Location loc = op.getLoc();
@@ -200,6 +224,11 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
200224
if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
201225
return failure();
202226

227+
if (inDeviceContext(op.getOperation())) {
228+
rewriter.eraseOp(op);
229+
return mlir::success();
230+
}
231+
203232
auto mod = op->getParentOfType<mlir::ModuleOp>();
204233
fir::FirOpBuilder builder(rewriter, mod);
205234
mlir::Location loc = op.getLoc();
@@ -248,6 +277,7 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
248277
[](::cuf::AllocateOp op) { return isBoxGlobal(op); });
249278
target.addDynamicallyLegalOp<cuf::DeallocateOp>(
250279
[](::cuf::DeallocateOp op) { return isBoxGlobal(op); });
280+
target.addLegalDialect<fir::FIROpsDialect>();
251281
patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
252282
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
253283
CufFreeOpConversion>(ctx);

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ func.func @_QPsub3() {
5757
// CHECK: cuf.allocate
5858
// CHECK: cuf.deallocate
5959

60+
func.func @_QPsub4() attributes {cuf.proc_attr = #cuf.cuda_proc<device>} {
61+
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
62+
%4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
63+
cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
64+
return
65+
}
66+
67+
// CHECK-LABEL: func.func @_QPsub4()
68+
// CHECK: fir.alloca
69+
// CHECK-NOT: cuf.free
70+
6071
}
6172

6273

0 commit comments

Comments
 (0)