17
17
#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
18
18
#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19
19
#include " mlir/Dialect/Vector/IR/VectorOps.h"
20
+ #include " mlir/Support/MathExtras.h"
20
21
#include " mlir/Transforms/DialectConversion.h"
21
22
#include " llvm/Support/FormatVariadic.h"
22
23
#include " llvm/Support/MathExtras.h"
@@ -209,6 +210,76 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
209
210
return success ();
210
211
}
211
212
};
213
+
214
+ // ===----------------------------------------------------------------------===//
215
+ // ConvertMemRefSubview
216
+ // ===----------------------------------------------------------------------===//
217
+
218
+ // / Emulating narrow ints on subview has limited support, supporting only
219
+ // / static offset and size and stride of 1. Ideally, the subview should be
220
+ // / folded away before running narrow type emulation, and this pattern would
221
+ // / never run. This pattern is mostly used for testing purposes.
222
+ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
223
+ using OpConversionPattern::OpConversionPattern;
224
+
225
+ LogicalResult
226
+ matchAndRewrite (memref::SubViewOp op, OpAdaptor adaptor,
227
+ ConversionPatternRewriter &rewriter) const override {
228
+ MemRefType newTy =
229
+ dyn_cast<MemRefType>(getTypeConverter ()->convertType (op.getType ()));
230
+ if (!newTy) {
231
+ return rewriter.notifyMatchFailure (
232
+ op->getLoc (),
233
+ llvm::formatv (" failed to convert memref type: {0}" , op.getType ()));
234
+ }
235
+
236
+ auto convertedElementType = newTy.getElementType ();
237
+ auto oldElementType = op.getType ().getElementType ();
238
+ int srcBits = oldElementType.getIntOrFloatBitWidth ();
239
+ int dstBits = convertedElementType.getIntOrFloatBitWidth ();
240
+ if (dstBits % srcBits != 0 ) {
241
+ return rewriter.notifyMatchFailure (
242
+ op, " only dstBits % srcBits == 0 supported" );
243
+ }
244
+
245
+ // Only support offset for 1-D subview.
246
+ if (op.getType ().getRank () != 1 ) {
247
+ return rewriter.notifyMatchFailure (
248
+ op->getLoc (), " subview with rank > 1 is not supported" );
249
+ }
250
+
251
+ // Only support stride of 1.
252
+ if (op.getStaticStride (0 ) != 1 ) {
253
+ return rewriter.notifyMatchFailure (
254
+ op->getLoc (), " subview with stride != 1 is not supported" );
255
+ }
256
+
257
+ int64_t size = op.getStaticSize (0 );
258
+ int64_t offset = op.getStaticOffset (0 );
259
+ // Only support static sizes and offsets.
260
+ if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic ) {
261
+ return rewriter.notifyMatchFailure (
262
+ op->getLoc (), " subview with dynamic size or offset is not supported" );
263
+ }
264
+
265
+ int elementsPerByte = dstBits / srcBits;
266
+ if (offset % elementsPerByte != 0 ) {
267
+ return rewriter.notifyMatchFailure (
268
+ op->getLoc (),
269
+ " subview with offset not multiple of elementsPerByte is not "
270
+ " supported" );
271
+ }
272
+
273
+ size = ceilDiv (size, elementsPerByte);
274
+ offset = offset / elementsPerByte;
275
+
276
+ rewriter.replaceOpWithNewOp <memref::SubViewOp>(
277
+ op, newTy, *adaptor.getODSOperands (0 ).begin (), offset, size,
278
+ op.getStaticStrides ());
279
+ return success ();
280
+ }
281
+ };
282
+
212
283
} // end anonymous namespace
213
284
214
285
// ===----------------------------------------------------------------------===//
@@ -220,9 +291,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
220
291
RewritePatternSet &patterns) {
221
292
222
293
// Populate `memref.*` conversion patterns.
223
- patterns
224
- . add <ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
225
- typeConverter, patterns.getContext ());
294
+ patterns. add <ConvertMemRefAlloc, ConvertMemRefLoad,
295
+ ConvertMemRefAssumeAlignment, ConvertMemRefSubview >(
296
+ typeConverter, patterns.getContext ());
226
297
memref::populateResolveExtractStridedMetadataPatterns (patterns);
227
298
}
228
299
@@ -271,9 +342,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
271
342
return std::nullopt;
272
343
273
344
StridedLayoutAttr layoutAttr;
345
+ // If the offset is 0, we do not need a strided layout as the stride is
346
+ // 1, so we only use the strided layout if the offset is not 0.
274
347
if (offset != 0 ) {
275
- layoutAttr = StridedLayoutAttr::get (ty.getContext (), offset,
276
- ArrayRef<int64_t >{1 });
348
+ if (offset == ShapedType::kDynamic ) {
349
+ layoutAttr = StridedLayoutAttr::get (ty.getContext (), offset,
350
+ ArrayRef<int64_t >{1 });
351
+ } else {
352
+ // Check if the number of bytes are a multiple of the loadStoreWidth
353
+ // and if so, divide it by the loadStoreWidth to get the offset.
354
+ if ((offset * width) % loadStoreWidth != 0 )
355
+ return std::nullopt;
356
+ offset = (offset * width) / loadStoreWidth;
357
+
358
+ layoutAttr = StridedLayoutAttr::get (ty.getContext (), offset,
359
+ ArrayRef<int64_t >{1 });
360
+ }
277
361
}
278
362
279
363
return MemRefType::get (getLinearizedShape (ty, width, loadStoreWidth),
0 commit comments