Skip to content

Commit 4a3bf27

Browse files
authored
[OpenMP] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. (#145464)
This PR introduces two new ops in the omp dialect: omp.target_allocmem and omp.target_freemem. omp.target_allocmem: Allocates heap memory on the device. Will be lowered to an omp_target_alloc call in LLVM. omp.target_freemem: Deallocates heap memory on the device. Will be lowered to an omp_target_free call in LLVM. Example: %1 = omp.target_allocmem %device : i32, i64 omp.target_freemem %device, %1 : i32, i64 The work in this PR is copied from/inspired by @ivanradanov's commits from the coexecute implementation: [Add fir omp target alloc and free ops](ivanradanov@be860ac) [Lower omp_target_{alloc,free} to llvm](ivanradanov@6e2d584)
1 parent e8e3e6e commit 4a3bf27

File tree

10 files changed

+804
-84
lines changed

10 files changed

+804
-84
lines changed

flang/include/flang/Optimizer/Support/Utils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include "llvm/ADT/DenseMap.h"
2828
#include "llvm/ADT/StringRef.h"
2929

30+
#include "flang/Optimizer/CodeGen/TypeConverter.h"
31+
3032
namespace fir {
3133
/// Return the integer value of an arith::ConstantOp.
3234
inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
@@ -198,6 +200,37 @@ std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
198200
fir::RecordType recordType, llvm::StringRef component,
199201
mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);
200202

203+
/// Generate an LLVM constant of type `ity` whose value is `offset`.
204+
mlir::LLVM::ConstantOp
205+
genConstantIndex(mlir::Location loc, mlir::Type ity,
206+
mlir::ConversionPatternRewriter &rewriter,
207+
std::int64_t offset);
208+
209+
/// Helper function for generating the LLVM IR that computes the distance
210+
/// in bytes between adjacent elements pointed to by a pointer
211+
/// to \p llvmObjectType. The result is returned as a value of \p idxTy integer
212+
/// type.
213+
mlir::Value computeElementDistance(mlir::Location loc,
214+
mlir::Type llvmObjectType, mlir::Type idxTy,
215+
mlir::ConversionPatternRewriter &rewriter,
216+
const mlir::DataLayout &dataLayout);
217+
218+
// Compute the alloc scale size (constant factors encoded in the array type).
219+
// We do this for arrays without a constant interior or arrays of character with
220+
// dynamic length, since those are the only ones that get decayed to a
221+
// pointer to the element type.
222+
mlir::Value genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy,
223+
mlir::Type ity,
224+
mlir::ConversionPatternRewriter &rewriter);
225+
226+
/// Perform an extension or truncation as needed on an integer value. Lowering
227+
/// to the specific target may involve some sign-extending or truncation of
228+
/// values, particularly to fit them from abstract box types to the
229+
/// appropriate reified structures.
230+
mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
231+
mlir::Location loc,
232+
mlir::ConversionPatternRewriter &rewriter,
233+
mlir::Type ty, mlir::Value val, bool fold = false);
201234
} // namespace fir
202235

203236
#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 31 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,6 @@ static inline mlir::Type getI8Type(mlir::MLIRContext *context) {
8787
return mlir::IntegerType::get(context, 8);
8888
}
8989

90-
static mlir::LLVM::ConstantOp
91-
genConstantIndex(mlir::Location loc, mlir::Type ity,
92-
mlir::ConversionPatternRewriter &rewriter,
93-
std::int64_t offset) {
94-
auto cattr = rewriter.getI64IntegerAttr(offset);
95-
return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr);
96-
}
97-
9890
static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
9991
mlir::Block *insertBefore) {
10092
assert(insertBefore && "expected valid insertion block");
@@ -208,39 +200,6 @@ getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op,
208200
TODO(op.getLoc(), "did not find allocation function");
209201
}
210202

211-
// Compute the alloc scale size (constant factors encoded in the array type).
212-
// We do this for arrays without a constant interior or arrays of character with
213-
// dynamic length arrays, since those are the only ones that get decayed to a
214-
// pointer to the element type.
215-
template <typename OP>
216-
static mlir::Value
217-
genAllocationScaleSize(OP op, mlir::Type ity,
218-
mlir::ConversionPatternRewriter &rewriter) {
219-
mlir::Location loc = op.getLoc();
220-
mlir::Type dataTy = op.getInType();
221-
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
222-
fir::SequenceType::Extent constSize = 1;
223-
if (seqTy) {
224-
int constRows = seqTy.getConstantRows();
225-
const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
226-
if (constRows != static_cast<int>(shape.size())) {
227-
for (auto extent : shape) {
228-
if (constRows-- > 0)
229-
continue;
230-
if (extent != fir::SequenceType::getUnknownExtent())
231-
constSize *= extent;
232-
}
233-
}
234-
}
235-
236-
if (constSize != 1) {
237-
mlir::Value constVal{
238-
genConstantIndex(loc, ity, rewriter, constSize).getResult()};
239-
return constVal;
240-
}
241-
return nullptr;
242-
}
243-
244203
namespace {
245204
struct DeclareOpConversion : public fir::FIROpConversion<fir::cg::XDeclareOp> {
246205
public:
@@ -275,7 +234,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
275234
auto loc = alloc.getLoc();
276235
mlir::Type ity = lowerTy().indexType();
277236
unsigned i = 0;
278-
mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult();
237+
mlir::Value size = fir::genConstantIndex(loc, ity, rewriter, 1).getResult();
279238
mlir::Type firObjType = fir::unwrapRefType(alloc.getType());
280239
mlir::Type llvmObjectType = convertObjectType(firObjType);
281240
if (alloc.hasLenParams()) {
@@ -307,7 +266,8 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
307266
<< scalarType << " with type parameters";
308267
}
309268
}
310-
if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter))
269+
if (auto scaleSize = fir::genAllocationScaleSize(
270+
alloc.getLoc(), alloc.getInType(), ity, rewriter))
311271
size =
312272
rewriter.createOrFold<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
313273
if (alloc.hasShapeOperands()) {
@@ -484,7 +444,7 @@ struct BoxIsArrayOpConversion : public fir::FIROpConversion<fir::BoxIsArrayOp> {
484444
auto loc = boxisarray.getLoc();
485445
TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType());
486446
mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter);
487-
mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0);
447+
mlir::Value c0 = fir::genConstantIndex(loc, rank.getType(), rewriter, 0);
488448
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
489449
boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0);
490450
return mlir::success();
@@ -820,7 +780,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
820780
// Do folding for constant inputs.
821781
if (auto constVal = fir::getIntIfConstant(op0)) {
822782
mlir::Value normVal =
823-
genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
783+
fir::genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
824784
rewriter.replaceOp(convert, normVal);
825785
return mlir::success();
826786
}
@@ -833,7 +793,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
833793
}
834794

835795
// Compare the input with zero.
836-
mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0);
796+
mlir::Value zero = fir::genConstantIndex(loc, fromTy, rewriter, 0);
837797
auto isTrue = mlir::LLVM::ICmpOp::create(
838798
rewriter, loc, mlir::LLVM::ICmpPredicate::ne, op0, zero);
839799

@@ -1082,21 +1042,6 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op,
10821042
return getMallocInModule(mod, op, rewriter, indexType);
10831043
}
10841044

1085-
/// Helper function for generating the LLVM IR that computes the distance
1086-
/// in bytes between adjacent elements pointed to by a pointer
1087-
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
1088-
/// type.
1089-
static mlir::Value
1090-
computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
1091-
mlir::Type idxTy,
1092-
mlir::ConversionPatternRewriter &rewriter,
1093-
const mlir::DataLayout &dataLayout) {
1094-
llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
1095-
unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType);
1096-
std::int64_t distance = llvm::alignTo(size, alignment);
1097-
return genConstantIndex(loc, idxTy, rewriter, distance);
1098-
}
1099-
11001045
/// Return value of the stride in bytes between adjacent elements
11011046
/// of LLVM type \p llTy. The result is returned as a value of
11021047
/// \p idxTy integer type.
@@ -1105,7 +1050,7 @@ genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy,
11051050
mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy,
11061051
const mlir::DataLayout &dataLayout) {
11071052
// Delegate to computeElementDistance() using the element type directly.
1108-
return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
1053+
return fir::computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
11091054
}
11101055

11111056
namespace {
@@ -1124,17 +1069,18 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
11241069
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
11251070
TODO(loc, "fir.allocmem codegen of derived type with length parameters");
11261071
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
1127-
if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
1128-
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
1072+
if (auto scaleSize =
1073+
fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter))
1074+
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
11291075
for (mlir::Value opnd : adaptor.getOperands())
11301076
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size,
11311077
integerCast(loc, rewriter, ity, opnd));
11321078

11331079
// As the return value of malloc(0) is implementation defined, allocate one
11341080
// byte to ensure the allocation status is true. This behavior aligns with
11351081
// what the runtime has.
1136-
mlir::Value zero = genConstantIndex(loc, ity, rewriter, 0);
1137-
mlir::Value one = genConstantIndex(loc, ity, rewriter, 1);
1082+
mlir::Value zero = fir::genConstantIndex(loc, ity, rewriter, 0);
1083+
mlir::Value one = fir::genConstantIndex(loc, ity, rewriter, 1);
11381084
mlir::Value cmp = mlir::LLVM::ICmpOp::create(
11391085
rewriter, loc, mlir::LLVM::ICmpPredicate::sgt, size, zero);
11401086
size = mlir::LLVM::SelectOp::create(rewriter, loc, cmp, size, one);
@@ -1157,7 +1103,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
11571103
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
11581104
mlir::ConversionPatternRewriter &rewriter,
11591105
mlir::Type llTy) const {
1160-
return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
1106+
return fir::computeElementDistance(loc, llTy, idxTy, rewriter,
1107+
getDataLayout());
11611108
}
11621109
};
11631110
} // namespace
@@ -1344,7 +1291,7 @@ genCUFAllocDescriptor(mlir::Location loc,
13441291
mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
13451292
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
13461293
mlir::Value sizeInBytes =
1347-
genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
1294+
fir::genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
13481295
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
13491296
return mlir::LLVM::CallOp::create(rewriter, loc, fctTy,
13501297
RTNAME_STRING(CUFAllocDescriptor), args)
@@ -1599,7 +1546,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
15991546
// representation of derived types with pointer/allocatable components.
16001547
// This has been seen in hashing algorithms using TRANSFER.
16011548
mlir::Value zero =
1602-
genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
1549+
fir::genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0);
16031550
descriptor = insertField(rewriter, loc, descriptor,
16041551
{getLenParamFieldId(boxTy), 0}, zero);
16051552
}
@@ -1944,8 +1891,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
19441891
bool hasSlice = !xbox.getSlice().empty();
19451892
unsigned sliceOffset = xbox.getSliceOperandIndex();
19461893
mlir::Location loc = xbox.getLoc();
1947-
mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0);
1948-
mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1);
1894+
mlir::Value zero = fir::genConstantIndex(loc, i64Ty, rewriter, 0);
1895+
mlir::Value one = fir::genConstantIndex(loc, i64Ty, rewriter, 1);
19491896
mlir::Value prevPtrOff = one;
19501897
mlir::Type eleTy = boxTy.getEleTy();
19511898
const unsigned rank = xbox.getRank();
@@ -1994,7 +1941,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
19941941
prevDimByteStride =
19951942
getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams());
19961943
} else {
1997-
prevDimByteStride = genConstantIndex(
1944+
prevDimByteStride = fir::genConstantIndex(
19981945
loc, i64Ty, rewriter,
19991946
charTy.getLen() * lowerTy().characterBitsize(charTy) / 8);
20001947
}
@@ -2152,7 +2099,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21522099
if (auto charTy = mlir::dyn_cast<fir::CharacterType>(inputEleTy)) {
21532100
if (charTy.hasConstantLen()) {
21542101
mlir::Value len =
2155-
genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
2102+
fir::genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
21562103
lenParams.emplace_back(len);
21572104
} else {
21582105
mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair,
@@ -2161,7 +2108,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21612108
assert(!isInGlobalOp(rewriter) &&
21622109
"character target in global op must have constant length");
21632110
mlir::Value width =
2164-
genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
2111+
fir::genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
21652112
len = mlir::LLVM::SDivOp::create(rewriter, loc, idxTy, len, width);
21662113
}
21672114
lenParams.emplace_back(len);
@@ -2215,8 +2162,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
22152162
mlir::ConversionPatternRewriter &rewriter) const {
22162163
mlir::Location loc = rebox.getLoc();
22172164
mlir::Value zero =
2218-
genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
2219-
mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
2165+
fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
2166+
mlir::Value one =
2167+
fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
22202168
for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) {
22212169
mlir::Value extent = std::get<0>(iter.value());
22222170
unsigned dim = iter.index();
@@ -2249,7 +2197,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
22492197
mlir::Location loc = rebox.getLoc();
22502198
mlir::Type byteTy = ::getI8Type(rebox.getContext());
22512199
mlir::Type idxTy = lowerTy().indexType();
2252-
mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
2200+
mlir::Value zero = fir::genConstantIndex(loc, idxTy, rewriter, 0);
22532201
// Apply subcomponent and substring shift on base address.
22542202
if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) {
22552203
// Cast to inputEleTy* so that a GEP can be used.
@@ -2277,7 +2225,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
22772225
// and strides.
22782226
llvm::SmallVector<mlir::Value> slicedExtents;
22792227
llvm::SmallVector<mlir::Value> slicedStrides;
2280-
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
2228+
mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
22812229
const bool sliceHasOrigins = !rebox.getShift().empty();
22822230
unsigned sliceOps = rebox.getSliceOperandIndex();
22832231
unsigned shiftOps = rebox.getShiftOperandIndex();
@@ -2350,7 +2298,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
23502298
// which may be OK if all new extents are ones, the stride does not
23512299
// matter, use one.
23522300
mlir::Value stride = inputStrides.empty()
2353-
? genConstantIndex(loc, idxTy, rewriter, 1)
2301+
? fir::genConstantIndex(loc, idxTy, rewriter, 1)
23542302
: inputStrides[0];
23552303
for (unsigned i = 0; i < rebox.getShape().size(); ++i) {
23562304
mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i];
@@ -2585,9 +2533,9 @@ struct XArrayCoorOpConversion
25852533
unsigned shiftOffset = coor.getShiftOperandIndex();
25862534
unsigned sliceOffset = coor.getSliceOperandIndex();
25872535
auto sliceOps = coor.getSlice().begin();
2588-
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
2536+
mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1);
25892537
mlir::Value prevExt = one;
2590-
mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0);
2538+
mlir::Value offset = fir::genConstantIndex(loc, idxTy, rewriter, 0);
25912539
const bool isShifted = !coor.getShift().empty();
25922540
const bool isSliced = !coor.getSlice().empty();
25932541
const bool baseIsBoxed =
@@ -2918,7 +2866,7 @@ struct CoordinateOpConversion
29182866
// of lower bound aspects. This both accounts for dynamically sized
29192867
// types and non contiguous arrays.
29202868
auto idxTy = lowerTy().indexType();
2921-
mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0);
2869+
mlir::Value off = fir::genConstantIndex(loc, idxTy, rewriter, 0);
29222870
unsigned arrayDim = arrTy.getDimension();
29232871
for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) {
29242872
mlir::Value stride =
@@ -3846,7 +3794,7 @@ struct IsPresentOpConversion : public fir::FIROpConversion<fir::IsPresentOp> {
38463794
ptr = mlir::LLVM::ExtractValueOp::create(rewriter, loc, ptr, 0);
38473795
}
38483796
mlir::LLVM::ConstantOp c0 =
3849-
genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
3797+
fir::genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
38503798
auto addr = mlir::LLVM::PtrToIntOp::create(rewriter, loc, idxTy, ptr);
38513799
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
38523800
isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0);

0 commit comments

Comments
 (0)