Skip to content

Commit fa51755

Browse files
[Flang] Upstream conversion of the XRebox Op
The XRebox op is produced by the codegen rewrite pass, which makes the rebox operation easier to convert to LLVM. The XRebox op carries the information from the rebox op together with its associated slice, shift, and shape ops. During conversion, a new descriptor is created to hold the reboxed entity. Co-authored-by: Jean Perier <[email protected]> Co-authored-by: Eric Schweitz <[email protected]> Co-authored-by: Val Donaldson <[email protected]> Reviewed By: clementval Differential Revision: https://reviews.llvm.org/D114709
1 parent 36529a2 commit fa51755

File tree

3 files changed

+394
-1
lines changed

3 files changed

+394
-1
lines changed

flang/include/flang/Optimizer/CodeGen/CGOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,14 @@ def fircg_XReboxOp : fircg_Op<"ext_rebox", [AttrSizedOperandSegments]> {
127127
unsigned getRank();
128128
// The rank of the result box
129129
unsigned getOutRank();
130+
131+
// Index of the first shape operand in the operand list. The shape operands
// start at position 1.
unsigned shapeOffset() {
  return 1;
}
132+
// Index of the first shift operand: the shift operands come directly after
// the shape operands.
unsigned shiftOffset() {
  const unsigned firstShape = shapeOffset();
  return firstShape + shape().size();
}
133+
// Index of the first slice operand: the slice operands come directly after
// the shift operands.
unsigned sliceOffset() {
  const unsigned firstShift = shiftOffset();
  return firstShift + shift().size();
}
134+
// Index of the first subcomponent operand: the subcomponent operands come
// directly after the slice operands.
unsigned subcomponentOffset() {
  const unsigned firstSlice = sliceOffset();
  return firstSlice + slice().size();
}
135+
unsigned substrOffset() {
136+
return subcomponentOffset() + subcomponent().size();
137+
}
130138
}];
131139
}
132140

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 222 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ static constexpr unsigned defaultAlign = 8;
3838
static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
3939
static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;
4040

41+
/// Return the LLVM dialect pointer type whose pointee is an 8-bit integer.
/// It is used in this file as an untyped ("void *") pointer.
static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
  mlir::Type byteTy = mlir::IntegerType::get(context, 8);
  return mlir::LLVM::LLVMPointerType::get(byteTy);
}
44+
4145
static mlir::LLVM::ConstantOp
4246
genConstantIndex(mlir::Location loc, mlir::Type ity,
4347
mlir::ConversionPatternRewriter &rewriter,
@@ -1854,6 +1858,222 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
18541858
}
18551859
};
18561860

1861+
/// Create a new box given a box reference.
1862+
struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
1863+
using EmboxCommonConversion::EmboxCommonConversion;
1864+
1865+
mlir::LogicalResult
1866+
matchAndRewrite(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
1867+
mlir::ConversionPatternRewriter &rewriter) const override {
1868+
mlir::Location loc = rebox.getLoc();
1869+
mlir::Type idxTy = lowerTy().indexType();
1870+
mlir::Value loweredBox = adaptor.getOperands()[0];
1871+
mlir::ValueRange operands = adaptor.getOperands();
1872+
1873+
// Create new descriptor and fill its non-shape related data.
1874+
llvm::SmallVector<mlir::Value, 2> lenParams;
1875+
mlir::Type inputEleTy = getInputEleTy(rebox);
1876+
if (auto charTy = inputEleTy.dyn_cast<fir::CharacterType>()) {
1877+
mlir::Value len =
1878+
loadElementSizeFromBox(loc, idxTy, loweredBox, rewriter);
1879+
if (charTy.getFKind() != 1) {
1880+
mlir::Value width =
1881+
genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
1882+
len = rewriter.create<mlir::LLVM::SDivOp>(loc, idxTy, len, width);
1883+
}
1884+
lenParams.emplace_back(len);
1885+
} else if (auto recTy = inputEleTy.dyn_cast<fir::RecordType>()) {
1886+
if (recTy.getNumLenParams() != 0)
1887+
TODO(loc, "reboxing descriptor of derived type with length parameters");
1888+
}
1889+
auto [boxTy, dest, eleSize] =
1890+
consDescriptorPrefix(rebox, rewriter, rebox.getOutRank(), lenParams);
1891+
1892+
// Read input extents, strides, and base address
1893+
llvm::SmallVector<mlir::Value> inputExtents;
1894+
llvm::SmallVector<mlir::Value> inputStrides;
1895+
const unsigned inputRank = rebox.getRank();
1896+
for (unsigned i = 0; i < inputRank; ++i) {
1897+
mlir::Value dim = genConstantIndex(loc, idxTy, rewriter, i);
1898+
SmallVector<mlir::Value, 3> dimInfo =
1899+
getDimsFromBox(loc, {idxTy, idxTy, idxTy}, loweredBox, dim, rewriter);
1900+
inputExtents.emplace_back(dimInfo[1]);
1901+
inputStrides.emplace_back(dimInfo[2]);
1902+
}
1903+
1904+
mlir::Type baseTy = getBaseAddrTypeFromBox(loweredBox.getType());
1905+
mlir::Value baseAddr =
1906+
loadBaseAddrFromBox(loc, baseTy, loweredBox, rewriter);
1907+
1908+
if (!rebox.slice().empty() || !rebox.subcomponent().empty())
1909+
return sliceBox(rebox, dest, baseAddr, inputExtents, inputStrides,
1910+
operands, rewriter);
1911+
return reshapeBox(rebox, dest, baseAddr, inputExtents, inputStrides,
1912+
operands, rewriter);
1913+
}
1914+
1915+
private:
1916+
/// Write resulting shape and base address in descriptor, and replace rebox
1917+
/// op.
1918+
mlir::LogicalResult
1919+
finalizeRebox(fir::cg::XReboxOp rebox, mlir::Value dest, mlir::Value base,
1920+
mlir::ValueRange lbounds, mlir::ValueRange extents,
1921+
mlir::ValueRange strides,
1922+
mlir::ConversionPatternRewriter &rewriter) const {
1923+
mlir::Location loc = rebox.getLoc();
1924+
mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
1925+
for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) {
1926+
unsigned dim = iter.index();
1927+
mlir::Value lb = lbounds.empty() ? one : lbounds[dim];
1928+
dest = insertLowerBound(rewriter, loc, dest, dim, lb);
1929+
dest = insertExtent(rewriter, loc, dest, dim, std::get<0>(iter.value()));
1930+
dest = insertStride(rewriter, loc, dest, dim, std::get<1>(iter.value()));
1931+
}
1932+
dest = insertBaseAddress(rewriter, loc, dest, base);
1933+
mlir::Value result =
1934+
placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), dest);
1935+
rewriter.replaceOp(rebox, result);
1936+
return success();
1937+
}
1938+
1939+
// Apply slice given the base address, extents and strides of the input box.
1940+
mlir::LogicalResult
1941+
sliceBox(fir::cg::XReboxOp rebox, mlir::Value dest, mlir::Value base,
1942+
mlir::ValueRange inputExtents, mlir::ValueRange inputStrides,
1943+
mlir::ValueRange operands,
1944+
mlir::ConversionPatternRewriter &rewriter) const {
1945+
mlir::Location loc = rebox.getLoc();
1946+
mlir::Type voidPtrTy = ::getVoidPtrType(rebox.getContext());
1947+
mlir::Type idxTy = lowerTy().indexType();
1948+
mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
1949+
// Apply subcomponent and substring shift on base address.
1950+
if (!rebox.subcomponent().empty() || !rebox.substr().empty()) {
1951+
// Cast to inputEleTy* so that a GEP can be used.
1952+
mlir::Type inputEleTy = getInputEleTy(rebox);
1953+
auto llvmElePtrTy =
1954+
mlir::LLVM::LLVMPointerType::get(convertType(inputEleTy));
1955+
base = rewriter.create<mlir::LLVM::BitcastOp>(loc, llvmElePtrTy, base);
1956+
1957+
if (!rebox.subcomponent().empty()) {
1958+
llvm::SmallVector<mlir::Value> gepOperands = {zero};
1959+
for (unsigned i = 0; i < rebox.subcomponent().size(); ++i)
1960+
gepOperands.push_back(operands[rebox.subcomponentOffset() + i]);
1961+
base = genGEP(loc, llvmElePtrTy, rewriter, base, gepOperands);
1962+
}
1963+
if (!rebox.substr().empty())
1964+
base = shiftSubstringBase(rewriter, loc, base,
1965+
operands[rebox.substrOffset()]);
1966+
}
1967+
1968+
if (rebox.slice().empty())
1969+
// The array section is of the form array[%component][substring], keep
1970+
// the input array extents and strides.
1971+
return finalizeRebox(rebox, dest, base, /*lbounds*/ llvm::None,
1972+
inputExtents, inputStrides, rewriter);
1973+
1974+
// Strides from the fir.box are in bytes.
1975+
base = rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
1976+
1977+
// The slice is of the form array(i:j:k)[%component]. Compute new extents
1978+
// and strides.
1979+
llvm::SmallVector<mlir::Value> slicedExtents;
1980+
llvm::SmallVector<mlir::Value> slicedStrides;
1981+
mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
1982+
const bool sliceHasOrigins = !rebox.shift().empty();
1983+
unsigned sliceOps = rebox.sliceOffset();
1984+
unsigned shiftOps = rebox.shiftOffset();
1985+
auto strideOps = inputStrides.begin();
1986+
const unsigned inputRank = inputStrides.size();
1987+
for (unsigned i = 0; i < inputRank;
1988+
++i, ++strideOps, ++shiftOps, sliceOps += 3) {
1989+
mlir::Value sliceLb =
1990+
integerCast(loc, rewriter, idxTy, operands[sliceOps]);
1991+
mlir::Value inputStride = *strideOps; // already idxTy
1992+
// Apply origin shift: base += (lb-shift)*input_stride
1993+
mlir::Value sliceOrigin =
1994+
sliceHasOrigins
1995+
? integerCast(loc, rewriter, idxTy, operands[shiftOps])
1996+
: one;
1997+
mlir::Value diff =
1998+
rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, sliceLb, sliceOrigin);
1999+
mlir::Value offset =
2000+
rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, inputStride);
2001+
base = genGEP(loc, voidPtrTy, rewriter, base, offset);
2002+
// Apply upper bound and step if this is a triplet. Otherwise, the
2003+
// dimension is dropped and no extents/strides are computed.
2004+
mlir::Value upper = operands[sliceOps + 1];
2005+
const bool isTripletSlice =
2006+
!mlir::isa_and_nonnull<mlir::LLVM::UndefOp>(upper.getDefiningOp());
2007+
if (isTripletSlice) {
2008+
mlir::Value step =
2009+
integerCast(loc, rewriter, idxTy, operands[sliceOps + 2]);
2010+
// extent = ub-lb+step/step
2011+
mlir::Value sliceUb = integerCast(loc, rewriter, idxTy, upper);
2012+
mlir::Value extent = computeTripletExtent(rewriter, loc, sliceLb,
2013+
sliceUb, step, zero, idxTy);
2014+
slicedExtents.emplace_back(extent);
2015+
// stride = step*input_stride
2016+
mlir::Value stride =
2017+
rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, step, inputStride);
2018+
slicedStrides.emplace_back(stride);
2019+
}
2020+
}
2021+
return finalizeRebox(rebox, dest, base, /*lbounds*/ llvm::None,
2022+
slicedExtents, slicedStrides, rewriter);
2023+
}
2024+
2025+
/// Apply a new shape to the data described by a box given the base address,
2026+
/// extents and strides of the box.
2027+
mlir::LogicalResult
2028+
reshapeBox(fir::cg::XReboxOp rebox, mlir::Value dest, mlir::Value base,
2029+
mlir::ValueRange inputExtents, mlir::ValueRange inputStrides,
2030+
mlir::ValueRange operands,
2031+
mlir::ConversionPatternRewriter &rewriter) const {
2032+
mlir::ValueRange reboxShifts{operands.begin() + rebox.shiftOffset(),
2033+
operands.begin() + rebox.shiftOffset() +
2034+
rebox.shift().size()};
2035+
if (rebox.shape().empty()) {
2036+
// Only setting new lower bounds.
2037+
return finalizeRebox(rebox, dest, base, reboxShifts, inputExtents,
2038+
inputStrides, rewriter);
2039+
}
2040+
2041+
mlir::Location loc = rebox.getLoc();
2042+
// Strides from the fir.box are in bytes.
2043+
mlir::Type voidPtrTy = ::getVoidPtrType(rebox.getContext());
2044+
base = rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
2045+
2046+
llvm::SmallVector<mlir::Value> newStrides;
2047+
llvm::SmallVector<mlir::Value> newExtents;
2048+
mlir::Type idxTy = lowerTy().indexType();
2049+
// First stride from input box is kept. The rest is assumed contiguous
2050+
// (it is not possible to reshape otherwise). If the input is scalar,
2051+
// which may be OK if all new extents are ones, the stride does not
2052+
// matter, use one.
2053+
mlir::Value stride = inputStrides.empty()
2054+
? genConstantIndex(loc, idxTy, rewriter, 1)
2055+
: inputStrides[0];
2056+
for (unsigned i = 0; i < rebox.shape().size(); ++i) {
2057+
mlir::Value rawExtent = operands[rebox.shapeOffset() + i];
2058+
mlir::Value extent = integerCast(loc, rewriter, idxTy, rawExtent);
2059+
newExtents.emplace_back(extent);
2060+
newStrides.emplace_back(stride);
2061+
// nextStride = extent * stride;
2062+
stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
2063+
}
2064+
return finalizeRebox(rebox, dest, base, reboxShifts, newExtents, newStrides,
2065+
rewriter);
2066+
}
2067+
2068+
/// Return scalar element type of the input box.
2069+
static mlir::Type getInputEleTy(fir::cg::XReboxOp rebox) {
2070+
auto ty = fir::dyn_cast_ptrOrBoxEleTy(rebox.box().getType());
2071+
if (auto seqTy = ty.dyn_cast<fir::SequenceType>())
2072+
return seqTy.getEleTy();
2073+
return ty;
2074+
}
2075+
};
2076+
18572077
// Code shared between insert_value and extract_value Ops.
18582078
struct ValueOpCommon {
18592079
// Translate the arguments pertaining to any multidimensional array to
@@ -2616,7 +2836,8 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
26162836
SliceOpConversion, StoreOpConversion, StringLitOpConversion,
26172837
SubcOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
26182838
UndefOpConversion, UnreachableOpConversion, XArrayCoorOpConversion,
2619-
XEmboxOpConversion, ZeroOpConversion>(typeConverter);
2839+
XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(
2840+
typeConverter);
26202841
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
26212842
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
26222843
pattern);

0 commit comments

Comments
 (0)