@@ -38,6 +38,10 @@ static constexpr unsigned defaultAlign = 8;
38
38
static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
39
39
static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;
40
40
41
+ static mlir::Type getVoidPtrType (mlir::MLIRContext *context) {
42
+ return mlir::LLVM::LLVMPointerType::get (mlir::IntegerType::get (context, 8 ));
43
+ }
44
+
41
45
static mlir::LLVM::ConstantOp
42
46
genConstantIndex (mlir::Location loc, mlir::Type ity,
43
47
mlir::ConversionPatternRewriter &rewriter,
@@ -1854,6 +1858,222 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
1854
1858
}
1855
1859
};
1856
1860
1861
+ // / Create a new box given a box reference.
1862
+ struct XReboxOpConversion : public EmboxCommonConversion <fir::cg::XReboxOp> {
1863
+ using EmboxCommonConversion::EmboxCommonConversion;
1864
+
1865
+ mlir::LogicalResult
1866
+ matchAndRewrite (fir::cg::XReboxOp rebox, OpAdaptor adaptor,
1867
+ mlir::ConversionPatternRewriter &rewriter) const override {
1868
+ mlir::Location loc = rebox.getLoc ();
1869
+ mlir::Type idxTy = lowerTy ().indexType ();
1870
+ mlir::Value loweredBox = adaptor.getOperands ()[0 ];
1871
+ mlir::ValueRange operands = adaptor.getOperands ();
1872
+
1873
+ // Create new descriptor and fill its non-shape related data.
1874
+ llvm::SmallVector<mlir::Value, 2 > lenParams;
1875
+ mlir::Type inputEleTy = getInputEleTy (rebox);
1876
+ if (auto charTy = inputEleTy.dyn_cast <fir::CharacterType>()) {
1877
+ mlir::Value len =
1878
+ loadElementSizeFromBox (loc, idxTy, loweredBox, rewriter);
1879
+ if (charTy.getFKind () != 1 ) {
1880
+ mlir::Value width =
1881
+ genConstantIndex (loc, idxTy, rewriter, charTy.getFKind ());
1882
+ len = rewriter.create <mlir::LLVM::SDivOp>(loc, idxTy, len, width);
1883
+ }
1884
+ lenParams.emplace_back (len);
1885
+ } else if (auto recTy = inputEleTy.dyn_cast <fir::RecordType>()) {
1886
+ if (recTy.getNumLenParams () != 0 )
1887
+ TODO (loc, " reboxing descriptor of derived type with length parameters" );
1888
+ }
1889
+ auto [boxTy, dest, eleSize] =
1890
+ consDescriptorPrefix (rebox, rewriter, rebox.getOutRank (), lenParams);
1891
+
1892
+ // Read input extents, strides, and base address
1893
+ llvm::SmallVector<mlir::Value> inputExtents;
1894
+ llvm::SmallVector<mlir::Value> inputStrides;
1895
+ const unsigned inputRank = rebox.getRank ();
1896
+ for (unsigned i = 0 ; i < inputRank; ++i) {
1897
+ mlir::Value dim = genConstantIndex (loc, idxTy, rewriter, i);
1898
+ SmallVector<mlir::Value, 3 > dimInfo =
1899
+ getDimsFromBox (loc, {idxTy, idxTy, idxTy}, loweredBox, dim, rewriter);
1900
+ inputExtents.emplace_back (dimInfo[1 ]);
1901
+ inputStrides.emplace_back (dimInfo[2 ]);
1902
+ }
1903
+
1904
+ mlir::Type baseTy = getBaseAddrTypeFromBox (loweredBox.getType ());
1905
+ mlir::Value baseAddr =
1906
+ loadBaseAddrFromBox (loc, baseTy, loweredBox, rewriter);
1907
+
1908
+ if (!rebox.slice ().empty () || !rebox.subcomponent ().empty ())
1909
+ return sliceBox (rebox, dest, baseAddr, inputExtents, inputStrides,
1910
+ operands, rewriter);
1911
+ return reshapeBox (rebox, dest, baseAddr, inputExtents, inputStrides,
1912
+ operands, rewriter);
1913
+ }
1914
+
1915
+ private:
1916
+ // / Write resulting shape and base address in descriptor, and replace rebox
1917
+ // / op.
1918
+ mlir::LogicalResult
1919
+ finalizeRebox (fir::cg::XReboxOp rebox, mlir::Value dest, mlir::Value base,
1920
+ mlir::ValueRange lbounds, mlir::ValueRange extents,
1921
+ mlir::ValueRange strides,
1922
+ mlir::ConversionPatternRewriter &rewriter) const {
1923
+ mlir::Location loc = rebox.getLoc ();
1924
+ mlir::Value one = genConstantIndex (loc, lowerTy ().indexType (), rewriter, 1 );
1925
+ for (auto iter : llvm::enumerate (llvm::zip (extents, strides))) {
1926
+ unsigned dim = iter.index ();
1927
+ mlir::Value lb = lbounds.empty () ? one : lbounds[dim];
1928
+ dest = insertLowerBound (rewriter, loc, dest, dim, lb);
1929
+ dest = insertExtent (rewriter, loc, dest, dim, std::get<0 >(iter.value ()));
1930
+ dest = insertStride (rewriter, loc, dest, dim, std::get<1 >(iter.value ()));
1931
+ }
1932
+ dest = insertBaseAddress (rewriter, loc, dest, base);
1933
+ mlir::Value result =
1934
+ placeInMemoryIfNotGlobalInit (rewriter, rebox.getLoc (), dest);
1935
+ rewriter.replaceOp (rebox, result);
1936
+ return success ();
1937
+ }
1938
+
1939
+ // Apply slice given the base address, extents and strides of the input box.
1940
+ mlir::LogicalResult
1941
+ sliceBox (fir::cg::XReboxOp rebox, mlir::Value dest, mlir::Value base,
1942
+ mlir::ValueRange inputExtents, mlir::ValueRange inputStrides,
1943
+ mlir::ValueRange operands,
1944
+ mlir::ConversionPatternRewriter &rewriter) const {
1945
+ mlir::Location loc = rebox.getLoc ();
1946
+ mlir::Type voidPtrTy = ::getVoidPtrType (rebox.getContext ());
1947
+ mlir::Type idxTy = lowerTy ().indexType ();
1948
+ mlir::Value zero = genConstantIndex (loc, idxTy, rewriter, 0 );
1949
+ // Apply subcomponent and substring shift on base address.
1950
+ if (!rebox.subcomponent ().empty () || !rebox.substr ().empty ()) {
1951
+ // Cast to inputEleTy* so that a GEP can be used.
1952
+ mlir::Type inputEleTy = getInputEleTy (rebox);
1953
+ auto llvmElePtrTy =
1954
+ mlir::LLVM::LLVMPointerType::get (convertType (inputEleTy));
1955
+ base = rewriter.create <mlir::LLVM::BitcastOp>(loc, llvmElePtrTy, base);
1956
+
1957
+ if (!rebox.subcomponent ().empty ()) {
1958
+ llvm::SmallVector<mlir::Value> gepOperands = {zero};
1959
+ for (unsigned i = 0 ; i < rebox.subcomponent ().size (); ++i)
1960
+ gepOperands.push_back (operands[rebox.subcomponentOffset () + i]);
1961
+ base = genGEP (loc, llvmElePtrTy, rewriter, base, gepOperands);
1962
+ }
1963
+ if (!rebox.substr ().empty ())
1964
+ base = shiftSubstringBase (rewriter, loc, base,
1965
+ operands[rebox.substrOffset ()]);
1966
+ }
1967
+
1968
+ if (rebox.slice ().empty ())
1969
+ // The array section is of the form array[%component][substring], keep
1970
+ // the input array extents and strides.
1971
+ return finalizeRebox (rebox, dest, base, /* lbounds*/ llvm::None,
1972
+ inputExtents, inputStrides, rewriter);
1973
+
1974
+ // Strides from the fir.box are in bytes.
1975
+ base = rewriter.create <mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
1976
+
1977
+ // The slice is of the form array(i:j:k)[%component]. Compute new extents
1978
+ // and strides.
1979
+ llvm::SmallVector<mlir::Value> slicedExtents;
1980
+ llvm::SmallVector<mlir::Value> slicedStrides;
1981
+ mlir::Value one = genConstantIndex (loc, idxTy, rewriter, 1 );
1982
+ const bool sliceHasOrigins = !rebox.shift ().empty ();
1983
+ unsigned sliceOps = rebox.sliceOffset ();
1984
+ unsigned shiftOps = rebox.shiftOffset ();
1985
+ auto strideOps = inputStrides.begin ();
1986
+ const unsigned inputRank = inputStrides.size ();
1987
+ for (unsigned i = 0 ; i < inputRank;
1988
+ ++i, ++strideOps, ++shiftOps, sliceOps += 3 ) {
1989
+ mlir::Value sliceLb =
1990
+ integerCast (loc, rewriter, idxTy, operands[sliceOps]);
1991
+ mlir::Value inputStride = *strideOps; // already idxTy
1992
+ // Apply origin shift: base += (lb-shift)*input_stride
1993
+ mlir::Value sliceOrigin =
1994
+ sliceHasOrigins
1995
+ ? integerCast (loc, rewriter, idxTy, operands[shiftOps])
1996
+ : one;
1997
+ mlir::Value diff =
1998
+ rewriter.create <mlir::LLVM::SubOp>(loc, idxTy, sliceLb, sliceOrigin);
1999
+ mlir::Value offset =
2000
+ rewriter.create <mlir::LLVM::MulOp>(loc, idxTy, diff, inputStride);
2001
+ base = genGEP (loc, voidPtrTy, rewriter, base, offset);
2002
+ // Apply upper bound and step if this is a triplet. Otherwise, the
2003
+ // dimension is dropped and no extents/strides are computed.
2004
+ mlir::Value upper = operands[sliceOps + 1 ];
2005
+ const bool isTripletSlice =
2006
+ !mlir::isa_and_nonnull<mlir::LLVM::UndefOp>(upper.getDefiningOp ());
2007
+ if (isTripletSlice) {
2008
+ mlir::Value step =
2009
+ integerCast (loc, rewriter, idxTy, operands[sliceOps + 2 ]);
2010
+ // extent = ub-lb+step/step
2011
+ mlir::Value sliceUb = integerCast (loc, rewriter, idxTy, upper);
2012
+ mlir::Value extent = computeTripletExtent (rewriter, loc, sliceLb,
2013
+ sliceUb, step, zero, idxTy);
2014
+ slicedExtents.emplace_back (extent);
2015
+ // stride = step*input_stride
2016
+ mlir::Value stride =
2017
+ rewriter.create <mlir::LLVM::MulOp>(loc, idxTy, step, inputStride);
2018
+ slicedStrides.emplace_back (stride);
2019
+ }
2020
+ }
2021
+ return finalizeRebox (rebox, dest, base, /* lbounds*/ llvm::None,
2022
+ slicedExtents, slicedStrides, rewriter);
2023
+ }
2024
+
2025
+ // / Apply a new shape to the data described by a box given the base address,
2026
+ // / extents and strides of the box.
2027
+ mlir::LogicalResult
2028
+ reshapeBox (fir::cg::XReboxOp rebox, mlir::Value dest, mlir::Value base,
2029
+ mlir::ValueRange inputExtents, mlir::ValueRange inputStrides,
2030
+ mlir::ValueRange operands,
2031
+ mlir::ConversionPatternRewriter &rewriter) const {
2032
+ mlir::ValueRange reboxShifts{operands.begin () + rebox.shiftOffset (),
2033
+ operands.begin () + rebox.shiftOffset () +
2034
+ rebox.shift ().size ()};
2035
+ if (rebox.shape ().empty ()) {
2036
+ // Only setting new lower bounds.
2037
+ return finalizeRebox (rebox, dest, base, reboxShifts, inputExtents,
2038
+ inputStrides, rewriter);
2039
+ }
2040
+
2041
+ mlir::Location loc = rebox.getLoc ();
2042
+ // Strides from the fir.box are in bytes.
2043
+ mlir::Type voidPtrTy = ::getVoidPtrType (rebox.getContext ());
2044
+ base = rewriter.create <mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
2045
+
2046
+ llvm::SmallVector<mlir::Value> newStrides;
2047
+ llvm::SmallVector<mlir::Value> newExtents;
2048
+ mlir::Type idxTy = lowerTy ().indexType ();
2049
+ // First stride from input box is kept. The rest is assumed contiguous
2050
+ // (it is not possible to reshape otherwise). If the input is scalar,
2051
+ // which may be OK if all new extents are ones, the stride does not
2052
+ // matter, use one.
2053
+ mlir::Value stride = inputStrides.empty ()
2054
+ ? genConstantIndex (loc, idxTy, rewriter, 1 )
2055
+ : inputStrides[0 ];
2056
+ for (unsigned i = 0 ; i < rebox.shape ().size (); ++i) {
2057
+ mlir::Value rawExtent = operands[rebox.shapeOffset () + i];
2058
+ mlir::Value extent = integerCast (loc, rewriter, idxTy, rawExtent);
2059
+ newExtents.emplace_back (extent);
2060
+ newStrides.emplace_back (stride);
2061
+ // nextStride = extent * stride;
2062
+ stride = rewriter.create <mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
2063
+ }
2064
+ return finalizeRebox (rebox, dest, base, reboxShifts, newExtents, newStrides,
2065
+ rewriter);
2066
+ }
2067
+
2068
+ // / Return scalar element type of the input box.
2069
+ static mlir::Type getInputEleTy (fir::cg::XReboxOp rebox) {
2070
+ auto ty = fir::dyn_cast_ptrOrBoxEleTy (rebox.box ().getType ());
2071
+ if (auto seqTy = ty.dyn_cast <fir::SequenceType>())
2072
+ return seqTy.getEleTy ();
2073
+ return ty;
2074
+ }
2075
+ };
2076
+
1857
2077
// Code shared between insert_value and extract_value Ops.
1858
2078
struct ValueOpCommon {
1859
2079
// Translate the arguments pertaining to any multidimensional array to
@@ -2616,7 +2836,8 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
2616
2836
SliceOpConversion, StoreOpConversion, StringLitOpConversion,
2617
2837
SubcOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
2618
2838
UndefOpConversion, UnreachableOpConversion, XArrayCoorOpConversion,
2619
- XEmboxOpConversion, ZeroOpConversion>(typeConverter);
2839
+ XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(
2840
+ typeConverter);
2620
2841
mlir::populateStdToLLVMConversionPatterns (typeConverter, pattern);
2621
2842
mlir::arith::populateArithmeticToLLVMConversionPatterns (typeConverter,
2622
2843
pattern);
0 commit comments