@@ -1248,6 +1248,112 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
1248
1248
};
1249
1249
} // namespace
1250
1250
1251
+ // / Return the LLVMFuncOp corresponding to omp_target_alloc
1252
+ // /
1253
+ // / void* omp_target_alloc(size_t size, int device_num);
1254
+ // /
1255
+ // / TODO is the abi correct for all targets?
1256
+ static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc (mlir::Operation *op) {
1257
+ auto module = op->getParentOfType <mlir::ModuleOp>();
1258
+ if (mlir::LLVM::LLVMFuncOp mallocFunc =
1259
+ module .lookupSymbol <mlir::LLVM::LLVMFuncOp>(" omp_target_alloc" ))
1260
+ return mallocFunc;
1261
+ mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
1262
+ auto i64Ty = mlir::IntegerType::get (module ->getContext (), 64 );
1263
+ auto i32Ty = mlir::IntegerType::get (module ->getContext (), 32 );
1264
+ return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1265
+ moduleBuilder.getUnknownLoc (), " omp_target_alloc" ,
1266
+ mlir::LLVM::LLVMFunctionType::get (
1267
+ mlir::LLVM::LLVMPointerType::get (module ->getContext ()),
1268
+ {i64Ty, i32Ty},
1269
+ /* isVarArg=*/ false ));
1270
+ }
1271
+
1272
+ namespace {
1273
+ struct OmpTargetAllocMemOpConversion
1274
+ : public FIROpConversion<fir::OmpTargetAllocMemOp> {
1275
+ using FIROpConversion::FIROpConversion;
1276
+
1277
+ mlir::LogicalResult
1278
+ matchAndRewrite (fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor,
1279
+ mlir::ConversionPatternRewriter &rewriter) const override {
1280
+ mlir::Type heapTy = heap.getType ();
1281
+ mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc (heap);
1282
+ mlir::Location loc = heap.getLoc ();
1283
+ auto ity = lowerTy ().indexType ();
1284
+ mlir::Type dataTy = fir::unwrapRefType (heapTy);
1285
+ mlir::Type llvmObjectTy = convertObjectType (dataTy);
1286
+ if (fir::isRecordWithTypeParameters (fir::unwrapSequenceType (dataTy)))
1287
+ TODO (loc, " fir.allocmem codegen of derived type with length parameters" );
1288
+ mlir::Value size = genTypeSizeInBytes (loc, ity, rewriter, llvmObjectTy);
1289
+ if (auto scaleSize = genAllocationScaleSize (heap, ity, rewriter))
1290
+ size = rewriter.create <mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
1291
+ for (mlir::Value opnd : adaptor.getOperands ())
1292
+ size = rewriter.create <mlir::LLVM::MulOp>(
1293
+ loc, ity, size, integerCast (loc, rewriter, ity, opnd));
1294
+ heap->setAttr (" callee" , mlir::SymbolRefAttr::get (mallocFunc));
1295
+ // TODO need to convert the device argument to the appropriate int type
1296
+ rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
1297
+ heap, ::getLlvmPtrType (heap.getContext ()),
1298
+ mlir::SmallVector<mlir::Value, 2 >({size, heap.getDevice ()}),
1299
+ heap->getAttrs ());
1300
+ return mlir::success ();
1301
+ }
1302
+
1303
+ // / Compute the allocation size in bytes of the element type of
1304
+ // / \p llTy pointer type. The result is returned as a value of \p idxTy
1305
+ // / integer type.
1306
+ mlir::Value genTypeSizeInBytes (mlir::Location loc, mlir::Type idxTy,
1307
+ mlir::ConversionPatternRewriter &rewriter,
1308
+ mlir::Type llTy) const {
1309
+ return computeElementDistance (loc, llTy, idxTy, rewriter);
1310
+ }
1311
+ };
1312
+ } // namespace
1313
+
1314
+ // / Return the LLVMFuncOp corresponding to omp_target_free
1315
+ // /
1316
+ // / void omp_target_free(void *device_ptr, int device_num);
1317
+ // /
1318
+ // / TODO is the abi correct for all targets?
1319
+ static mlir::LLVM::LLVMFuncOp getOmpTargetFree (mlir::Operation *op) {
1320
+ auto module = op->getParentOfType <mlir::ModuleOp>();
1321
+ if (mlir::LLVM::LLVMFuncOp freeFunc =
1322
+ module .lookupSymbol <mlir::LLVM::LLVMFuncOp>(" omp_target_free" ))
1323
+ return freeFunc;
1324
+ mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
1325
+ auto i32Ty = mlir::IntegerType::get (module ->getContext (), 32 );
1326
+ return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1327
+ moduleBuilder.getUnknownLoc (), " omp_target_free" ,
1328
+ mlir::LLVM::LLVMFunctionType::get (
1329
+ mlir::LLVM::LLVMVoidType::get (module ->getContext ()),
1330
+ {getLlvmPtrType (module ->getContext ()), i32Ty},
1331
+ /* isVarArg=*/ false ));
1332
+ }
1333
+
1334
+ namespace {
1335
+ // / Lower a `fir.freemem` instruction into `llvm.call @free`
1336
+ struct OmpTargetFreeMemOpConversion
1337
+ : public FIROpConversion<fir::OmpTargetFreeMemOp> {
1338
+ using FIROpConversion::FIROpConversion;
1339
+
1340
+ mlir::LogicalResult
1341
+ matchAndRewrite (fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor,
1342
+ mlir::ConversionPatternRewriter &rewriter) const override {
1343
+ mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree (freemem);
1344
+ mlir::Location loc = freemem.getLoc ();
1345
+ freemem->setAttr (" callee" , mlir::SymbolRefAttr::get (freeFunc));
1346
+ // TODO need to convert the device argument to the appropriate int type
1347
+ rewriter.create <mlir::LLVM::CallOp>(
1348
+ loc, mlir::TypeRange{},
1349
+ mlir::ValueRange{adaptor.getHeapref (), freemem.getDevice ()},
1350
+ freemem->getAttrs ());
1351
+ rewriter.eraseOp (freemem);
1352
+ return mlir::success ();
1353
+ }
1354
+ };
1355
+ } // namespace
1356
+
1251
1357
// / Return the LLVMFuncOp corresponding to the standard free call.
1252
1358
static mlir::LLVM::LLVMFuncOp
1253
1359
getFree (fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
@@ -3851,28 +3957,29 @@ class FIRToLLVMLowering
3851
3957
mlir::RewritePatternSet pattern (context);
3852
3958
pattern.insert <
3853
3959
AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
3854
- AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
3855
- BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
3856
- BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
3857
- BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion,
3858
- BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion,
3859
- CmpcOpConversion, ConstcOpConversion, ConvertOpConversion,
3860
- CoordinateOpConversion, DTEntryOpConversion, DivcOpConversion,
3861
- EmboxOpConversion, EmboxCharOpConversion, EmboxProcOpConversion,
3862
- ExtractValueOpConversion, FieldIndexOpConversion, FirEndOpConversion,
3863
- FreeMemOpConversion, GlobalLenOpConversion, GlobalOpConversion,
3864
- HasValueOpConversion, InsertOnRangeOpConversion,
3865
- InsertValueOpConversion, IsPresentOpConversion,
3866
- LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
3867
- NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
3868
- SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
3869
- ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
3870
- SliceOpConversion, StoreOpConversion, StringLitOpConversion,
3871
- SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
3872
- UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
3873
- UnreachableOpConversion, UnrealizedConversionCastOpConversion,
3874
- XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
3875
- ZeroOpConversion>(typeConverter, options);
3960
+ AllocaOpConversion, AllocMemOpConversion, OmpTargetAllocMemOpConversion,
3961
+ BoxAddrOpConversion, BoxCharLenOpConversion, BoxDimsOpConversion,
3962
+ BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion,
3963
+ BoxIsPtrOpConversion, BoxOffsetOpConversion, BoxProcHostOpConversion,
3964
+ BoxRankOpConversion, BoxTypeCodeOpConversion, BoxTypeDescOpConversion,
3965
+ CallOpConversion, CmpcOpConversion, ConstcOpConversion,
3966
+ ConvertOpConversion, CoordinateOpConversion, DTEntryOpConversion,
3967
+ DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
3968
+ EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
3969
+ FirEndOpConversion, OmpTargetFreeMemOpConversion, FreeMemOpConversion,
3970
+ GlobalLenOpConversion, GlobalOpConversion, HasValueOpConversion,
3971
+ InsertOnRangeOpConversion, InsertValueOpConversion,
3972
+ IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
3973
+ MulcOpConversion, NegcOpConversion, NoReassocOpConversion,
3974
+ SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
3975
+ SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
3976
+ ShiftOpConversion, SliceOpConversion, StoreOpConversion,
3977
+ StringLitOpConversion, SubcOpConversion, TypeDescOpConversion,
3978
+ TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
3979
+ UndefOpConversion, UnreachableOpConversion,
3980
+ UnrealizedConversionCastOpConversion, XArrayCoorOpConversion,
3981
+ XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(typeConverter,
3982
+ options);
3876
3983
mlir::populateFuncToLLVMConversionPatterns (typeConverter, pattern);
3877
3984
mlir::populateOpenMPToLLVMConversionPatterns (typeConverter, pattern);
3878
3985
mlir::arith::populateArithToLLVMConversionPatterns (typeConverter, pattern);
0 commit comments