Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 86 additions & 5 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,74 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto convertedType =
cast<MemRefType>(getTypeConverter()->convertType(op.getSourceType()));
auto convertedElementType = convertedType.getElementType();
auto oldElementType = op.getSourceType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}

MemRefType newTy =
cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}

// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 1 is not supported");
}

// Only support stride of 1.
if (op.getStaticStride(0) != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with stride != 1 is not supported");
}

auto size = op.getStaticSize(0);
auto offset = op.getStaticOffset(0);
// Only support static sizes and offsets.
if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with dynamic size or offset is not supported");
}

int elementsPerByte = dstBits / srcBits;
if (size % elementsPerByte != 0 || offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(),
"subview with size or offset not multiple of elementsPerByte is not "
"supported");
}

size = size / elementsPerByte;
offset = offset / elementsPerByte;

rewriter.replaceOpWithNewOp<memref::SubViewOp>(
op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
op.getStaticStrides());
return success();
}
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
Expand All @@ -220,9 +288,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {

// Populate `memref.*` conversion patterns.
patterns
.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
typeConverter, patterns.getContext());
patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

Expand Down Expand Up @@ -271,9 +339,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
return std::nullopt;

StridedLayoutAttr layoutAttr;
// If the offset is 0, we do not need a strided layout as the stride is
// 1, so we only use the strided layout if the offset is not 0.
if (offset != 0) {
layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
if (offset == ShapedType::kDynamic) {
layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
} else {
// Check if the number of bytes are a multiple of the loadStoreWidth
// and if so, divide it by the loadStoreWidth to get the offset.
if ((offset * width) % loadStoreWidth != 0)
return std::nullopt;
offset = (offset * width) / loadStoreWidth;

layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
}
}

return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,22 @@ func.func @rank_zero_memref() -> i4 {
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i32>
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
// CHECK32: return %[[TRUNC]]

// -----

func.func @memref_strided_i4(%idx : index) -> i4 {
%arr = memref.alloc() : memref<128xi4>
%subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>>
%1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>>
return %1 : i4
}

// CHECK-LABEL: func @memref_strided_i4
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>>
// CHECK: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]

// CHECK32-LABEL: func @memref_strided_i4
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]