@@ -141,6 +141,20 @@ struct CufDeallocateOpConversion
141
141
}
142
142
};
143
143
144
+ static bool inDeviceContext (mlir::Operation *op) {
145
+ if (op->getParentOfType <cuf::KernelOp>())
146
+ return true ;
147
+ if (auto funcOp = op->getParentOfType <mlir::func::FuncOp>()) {
148
+ if (auto cudaProcAttr =
149
+ funcOp.getOperation ()->getAttrOfType <cuf::ProcAttributeAttr>(
150
+ cuf::getProcAttrName ())) {
151
+ return cudaProcAttr.getValue () != cuf::ProcAttribute::Host &&
152
+ cudaProcAttr.getValue () != cuf::ProcAttribute::HostDevice;
153
+ }
154
+ }
155
+ return false ;
156
+ }
157
+
144
158
struct CufAllocOpConversion : public mlir ::OpRewritePattern<cuf::AllocOp> {
145
159
using OpRewritePattern::OpRewritePattern;
146
160
@@ -157,6 +171,16 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
157
171
if (!boxTy)
158
172
return failure ();
159
173
174
+ if (inDeviceContext (op.getOperation ())) {
175
+ // In device context just replace the cuf.alloc operation with a fir.alloc
176
+ // the cuf.free will be removed.
177
+ rewriter.replaceOpWithNewOp <fir::AllocaOp>(
178
+ op, op.getInType (), op.getUniqName () ? *op.getUniqName () : " " ,
179
+ op.getBindcName () ? *op.getBindcName () : " " , op.getTypeparams (),
180
+ op.getShape ());
181
+ return mlir::success ();
182
+ }
183
+
160
184
auto mod = op->getParentOfType <mlir::ModuleOp>();
161
185
fir::FirOpBuilder builder (rewriter, mod);
162
186
mlir::Location loc = op.getLoc ();
@@ -200,6 +224,11 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
200
224
if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy ()))
201
225
return failure ();
202
226
227
+ if (inDeviceContext (op.getOperation ())) {
228
+ rewriter.eraseOp (op);
229
+ return mlir::success ();
230
+ }
231
+
203
232
auto mod = op->getParentOfType <mlir::ModuleOp>();
204
233
fir::FirOpBuilder builder (rewriter, mod);
205
234
mlir::Location loc = op.getLoc ();
@@ -248,6 +277,7 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
248
277
[](::cuf::AllocateOp op) { return isBoxGlobal (op); });
249
278
target.addDynamicallyLegalOp <cuf::DeallocateOp>(
250
279
[](::cuf::DeallocateOp op) { return isBoxGlobal (op); });
280
+ target.addLegalDialect <fir::FIROpsDialect>();
251
281
patterns.insert <CufAllocOpConversion>(ctx, &*dl, &typeConverter);
252
282
patterns.insert <CufAllocateOpConversion, CufDeallocateOpConversion,
253
283
CufFreeOpConversion>(ctx);
0 commit comments