Skip to content

Commit 6e2d584

Browse files
committed
Lower omp_target_{alloc,free} to llvm
1 parent c0dee92 commit 6e2d584

File tree

1 file changed

+129
-22
lines changed

1 file changed

+129
-22
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 129 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,112 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
12481248
};
12491249
} // namespace
12501250

1251+
/// Return the LLVMFuncOp corresponding to omp_target_alloc
1252+
///
1253+
/// void* omp_target_alloc(size_t size, int device_num);
1254+
///
1255+
/// TODO is the abi correct for all targets?
1256+
static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
1257+
auto module = op->getParentOfType<mlir::ModuleOp>();
1258+
if (mlir::LLVM::LLVMFuncOp mallocFunc =
1259+
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc"))
1260+
return mallocFunc;
1261+
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
1262+
auto i64Ty = mlir::IntegerType::get(module->getContext(), 64);
1263+
auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
1264+
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1265+
moduleBuilder.getUnknownLoc(), "omp_target_alloc",
1266+
mlir::LLVM::LLVMFunctionType::get(
1267+
mlir::LLVM::LLVMPointerType::get(module->getContext()),
1268+
{i64Ty, i32Ty},
1269+
/*isVarArg=*/false));
1270+
}
1271+
1272+
namespace {
1273+
struct OmpTargetAllocMemOpConversion
1274+
: public FIROpConversion<fir::OmpTargetAllocMemOp> {
1275+
using FIROpConversion::FIROpConversion;
1276+
1277+
mlir::LogicalResult
1278+
matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor,
1279+
mlir::ConversionPatternRewriter &rewriter) const override {
1280+
mlir::Type heapTy = heap.getType();
1281+
mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(heap);
1282+
mlir::Location loc = heap.getLoc();
1283+
auto ity = lowerTy().indexType();
1284+
mlir::Type dataTy = fir::unwrapRefType(heapTy);
1285+
mlir::Type llvmObjectTy = convertObjectType(dataTy);
1286+
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
1287+
TODO(loc, "fir.allocmem codegen of derived type with length parameters");
1288+
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
1289+
if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
1290+
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
1291+
for (mlir::Value opnd : adaptor.getOperands())
1292+
size = rewriter.create<mlir::LLVM::MulOp>(
1293+
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
1294+
heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
1295+
// TODO need to convert the device argument to the appropriate int type
1296+
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
1297+
heap, ::getLlvmPtrType(heap.getContext()),
1298+
mlir::SmallVector<mlir::Value, 2>({size, heap.getDevice()}),
1299+
heap->getAttrs());
1300+
return mlir::success();
1301+
}
1302+
1303+
/// Compute the allocation size in bytes of the element type of
1304+
/// \p llTy pointer type. The result is returned as a value of \p idxTy
1305+
/// integer type.
1306+
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
1307+
mlir::ConversionPatternRewriter &rewriter,
1308+
mlir::Type llTy) const {
1309+
return computeElementDistance(loc, llTy, idxTy, rewriter);
1310+
}
1311+
};
1312+
} // namespace
1313+
1314+
/// Return the LLVMFuncOp corresponding to omp_target_free
1315+
///
1316+
/// void omp_target_free(void *device_ptr, int device_num);
1317+
///
1318+
/// TODO is the abi correct for all targets?
1319+
static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) {
1320+
auto module = op->getParentOfType<mlir::ModuleOp>();
1321+
if (mlir::LLVM::LLVMFuncOp freeFunc =
1322+
module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_free"))
1323+
return freeFunc;
1324+
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
1325+
auto i32Ty = mlir::IntegerType::get(module->getContext(), 32);
1326+
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1327+
moduleBuilder.getUnknownLoc(), "omp_target_free",
1328+
mlir::LLVM::LLVMFunctionType::get(
1329+
mlir::LLVM::LLVMVoidType::get(module->getContext()),
1330+
{getLlvmPtrType(module->getContext()), i32Ty},
1331+
/*isVarArg=*/false));
1332+
}
1333+
1334+
namespace {
1335+
/// Lower a `fir.freemem` instruction into `llvm.call @free`
1336+
struct OmpTargetFreeMemOpConversion
1337+
: public FIROpConversion<fir::OmpTargetFreeMemOp> {
1338+
using FIROpConversion::FIROpConversion;
1339+
1340+
mlir::LogicalResult
1341+
matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor,
1342+
mlir::ConversionPatternRewriter &rewriter) const override {
1343+
mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem);
1344+
mlir::Location loc = freemem.getLoc();
1345+
freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
1346+
// TODO need to convert the device argument to the appropriate int type
1347+
rewriter.create<mlir::LLVM::CallOp>(
1348+
loc, mlir::TypeRange{},
1349+
mlir::ValueRange{adaptor.getHeapref(), freemem.getDevice()},
1350+
freemem->getAttrs());
1351+
rewriter.eraseOp(freemem);
1352+
return mlir::success();
1353+
}
1354+
};
1355+
} // namespace
1356+
12511357
/// Return the LLVMFuncOp corresponding to the standard free call.
12521358
static mlir::LLVM::LLVMFuncOp
12531359
getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
@@ -3851,28 +3957,29 @@ class FIRToLLVMLowering
38513957
mlir::RewritePatternSet pattern(context);
38523958
pattern.insert<
38533959
AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
3854-
AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
3855-
BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
3856-
BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
3857-
BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion,
3858-
BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion,
3859-
CmpcOpConversion, ConstcOpConversion, ConvertOpConversion,
3860-
CoordinateOpConversion, DTEntryOpConversion, DivcOpConversion,
3861-
EmboxOpConversion, EmboxCharOpConversion, EmboxProcOpConversion,
3862-
ExtractValueOpConversion, FieldIndexOpConversion, FirEndOpConversion,
3863-
FreeMemOpConversion, GlobalLenOpConversion, GlobalOpConversion,
3864-
HasValueOpConversion, InsertOnRangeOpConversion,
3865-
InsertValueOpConversion, IsPresentOpConversion,
3866-
LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
3867-
NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
3868-
SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
3869-
ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
3870-
SliceOpConversion, StoreOpConversion, StringLitOpConversion,
3871-
SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
3872-
UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
3873-
UnreachableOpConversion, UnrealizedConversionCastOpConversion,
3874-
XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
3875-
ZeroOpConversion>(typeConverter, options);
3960+
AllocaOpConversion, AllocMemOpConversion, OmpTargetAllocMemOpConversion,
3961+
BoxAddrOpConversion, BoxCharLenOpConversion, BoxDimsOpConversion,
3962+
BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion,
3963+
BoxIsPtrOpConversion, BoxOffsetOpConversion, BoxProcHostOpConversion,
3964+
BoxRankOpConversion, BoxTypeCodeOpConversion, BoxTypeDescOpConversion,
3965+
CallOpConversion, CmpcOpConversion, ConstcOpConversion,
3966+
ConvertOpConversion, CoordinateOpConversion, DTEntryOpConversion,
3967+
DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
3968+
EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
3969+
FirEndOpConversion, OmpTargetFreeMemOpConversion, FreeMemOpConversion,
3970+
GlobalLenOpConversion, GlobalOpConversion, HasValueOpConversion,
3971+
InsertOnRangeOpConversion, InsertValueOpConversion,
3972+
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
3973+
MulcOpConversion, NegcOpConversion, NoReassocOpConversion,
3974+
SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
3975+
SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
3976+
ShiftOpConversion, SliceOpConversion, StoreOpConversion,
3977+
StringLitOpConversion, SubcOpConversion, TypeDescOpConversion,
3978+
TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
3979+
UndefOpConversion, UnreachableOpConversion,
3980+
UnrealizedConversionCastOpConversion, XArrayCoorOpConversion,
3981+
XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(typeConverter,
3982+
options);
38763983
mlir::populateFuncToLLVMConversionPatterns(typeConverter, pattern);
38773984
mlir::populateOpenMPToLLVMConversionPatterns(typeConverter, pattern);
38783985
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);

0 commit comments

Comments
 (0)