diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index c5643f6e2f830..dfa2e4e0376ed 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -85,11 +85,11 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, // TODO: support more types. return TypeSwitch(slot.elemType) .Case([&](MemRefType t) { - return builder.create(getLoc(), t); + return memref::AllocaOp::create(builder, getLoc(), t); }) .Default([&](Type t) { - return builder.create(getLoc(), t, - builder.getZeroAttr(t)); + return arith::ConstantOp::create(builder, getLoc(), t, + builder.getZeroAttr(t)); }); } @@ -135,7 +135,7 @@ DenseMap memref::AllocaOp::destructure( for (Attribute usedIndex : usedIndices) { Type elemType = memrefType.getTypeAtIndex(usedIndex); MemRefType elemPtr = MemRefType::get({}, elemType); - auto subAlloca = builder.create(getLoc(), elemPtr); + auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr); newAllocators.push_back(subAlloca); slotMap.try_emplace(usedIndex, {subAlloca.getResult(), elemType}); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 51c813682ce25..74b968c27a62a 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -213,9 +213,9 @@ struct SimplifyAllocConst : public OpRewritePattern { assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. - auto newAlloc = rewriter.create( - alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(), - alloc.getAlignmentAttr()); + auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType, + dynamicSizes, alloc.getSymbolOperands(), + alloc.getAlignmentAttr()); // Insert a cast so we have the same type as the old alloc. rewriter.replaceOpWithNewOp(alloc, alloc.getType(), newAlloc); return success(); @@ -797,7 +797,7 @@ void DimOp::getAsmResultNames(function_ref setNameFn) { void DimOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t index) { auto loc = result.location; - Value indexValue = builder.create(loc, index); + Value indexValue = arith::ConstantIndexOp::create(builder, loc, index); build(builder, result, source, indexValue); } @@ -1044,9 +1044,9 @@ struct DimOfMemRefReshape : public OpRewritePattern { rewriter.setInsertionPointAfter(reshape); Location loc = dim.getLoc(); Value load = - rewriter.create(loc, reshape.getShape(), dim.getIndex()); + LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex()); if (load.getType() != dim.getType()) - load = rewriter.create(loc, dim.getType(), load); + load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load); rewriter.replaceOp(dim, load); return success(); } @@ -1319,8 +1319,9 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, assert(isa(maybeConstant) && "The constified value should be either unchanged (i.e., == result) " "or a constant"); - Value constantVal = rewriter.create( - loc, llvm::cast(cast(maybeConstant)).getInt()); + Value constantVal = arith::ConstantIndexOp::create( + rewriter, loc, + llvm::cast(cast(maybeConstant)).getInt()); for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { // modifyOpInPlace: lambda cannot capture structured bindings in C++17 // yet. @@ -2548,8 +2549,9 @@ struct CollapseShapeOpMemRefCastFolder rewriter.modifyOpInPlace( op, [&]() { op.getSrcMutable().assign(cast.getSource()); }); } else { - Value newOp = rewriter.create( - op->getLoc(), cast.getSource(), op.getReassociationIndices()); + Value newOp = + CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(), + op.getReassociationIndices()); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); } return success(); @@ -3006,15 +3008,15 @@ SmallVector mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, Value offset = op.isDynamicOffset(idx) ? op.getDynamicOffset(idx) - : b.create(loc, op.getStaticOffset(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx)); Value size = op.isDynamicSize(idx) ? op.getDynamicSize(idx) - : b.create(loc, op.getStaticSize(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx)); Value stride = op.isDynamicStride(idx) ? op.getDynamicStride(idx) - : b.create(loc, op.getStaticStride(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx)); res.emplace_back(Range{offset, size, stride}); } return res; @@ -3173,8 +3175,8 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern { if (!resultType) return failure(); - Value newSubView = rewriter.create( - subViewOp.getLoc(), resultType, castOp.getSource(), + Value newSubView = SubViewOp::create( + rewriter, subViewOp.getLoc(), resultType, castOp.getSource(), subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(), subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(), subViewOp.getStaticStrides()); @@ -3495,9 +3497,9 @@ struct ViewOpShapeFolder : public OpRewritePattern { return failure(); // Create new ViewOp. - auto newViewOp = rewriter.create( - viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), - viewOp.getByteShift(), newOperands); + auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType, + viewOp.getOperand(0), viewOp.getByteShift(), + newOperands); // Insert a cast so we have the same type as the old memref type. rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), newViewOp); return success(); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 0c03670b4535f..95eb2a9a95bc1 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -155,9 +155,10 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, Type resultType = alloca.getResult().getType(); OpBuilder builder(rewriter.getContext()); // TODO: Add a better builder for this. - globalOp = builder.create( - loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"), - TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); + globalOp = memref::GlobalOp::create( + builder, loc, StringAttr::get(ctx, "alloca"), + StringAttr::get(ctx, "private"), TypeAttr::get(resultType), + Attribute{}, UnitAttr{}, IntegerAttr{}); symbolTable.insert(globalOp); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp index c433415944323..75cc39e61656a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp @@ -22,11 +22,11 @@ struct DefaultAllocationInterface DefaultAllocationInterface, memref::AllocOp> { static std::optional buildDealloc(OpBuilder &builder, Value alloc) { - return builder.create(alloc.getLoc(), alloc) + return memref::DeallocOp::create(builder, alloc.getLoc(), alloc) .getOperation(); } static std::optional buildClone(OpBuilder &builder, Value alloc) { - return builder.create(alloc.getLoc(), alloc) + return bufferization::CloneOp::create(builder, alloc.getLoc(), alloc) .getResult(); } static ::mlir::HoistingKind getHoistingKind() { @@ -35,8 +35,9 @@ struct DefaultAllocationInterface static ::std::optional<::mlir::Operation *> buildPromotedAlloc(OpBuilder &builder, Value alloc) { Operation *definingOp = alloc.getDefiningOp(); - return builder.create( - definingOp->getLoc(), cast(definingOp->getResultTypes()[0]), + return memref::AllocaOp::create( + builder, definingOp->getLoc(), + cast(definingOp->getResultTypes()[0]), definingOp->getOperands(), definingOp->getAttrs()); } }; @@ -52,7 +53,7 @@ struct DefaultReallocationInterface DefaultAllocationInterface, memref::ReallocOp> { static std::optional buildDealloc(OpBuilder &builder, Value realloc) { - return builder.create(realloc.getLoc(), realloc) + return memref::DeallocOp::create(builder, realloc.getLoc(), realloc) .getOperation(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index 7c777e807f08c..106c3b458dbac 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -124,8 +124,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern { } AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr); - Value result = rewriter.create( - op.getLoc(), map, affineApplyOperands); + Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(), map, + affineApplyOperands); offsets.push_back(result); } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index ec2bc95291455..556ea1a8e9c40 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -99,7 +99,7 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx, affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx}); Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal); IntegerType dstType = builder.getIntegerType(targetBits); - return builder.create(loc, dstType, bitOffset); + return arith::IndexCastOp::create(builder, loc, dstType, bitOffset); } /// When writing a subbyte size, masked bitwise operations are used to only @@ -112,14 +112,14 @@ static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices, auto dstIntegerType = builder.getIntegerType(dstBits); auto maskRightAlignedAttr = builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1); - Value maskRightAligned = builder.create( - loc, dstIntegerType, maskRightAlignedAttr); + Value maskRightAligned = arith::ConstantOp::create( + builder, loc, dstIntegerType, maskRightAlignedAttr); Value writeMaskInverse = - builder.create(loc, maskRightAligned, bitwidthOffset); + arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset); auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1); Value flipVal = - builder.create(loc, dstIntegerType, flipValAttr); - return builder.create(loc, writeMaskInverse, flipVal); + arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr); + return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal); } /// Returns the scaled linearized index based on the `srcBits` and `dstBits` @@ -141,7 +141,7 @@ getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits, const SmallVector &indices, Value memref) { auto stridedMetadata = - builder.create(loc, memref); + memref::ExtractStridedMetadataOp::create(builder, loc, memref); OpFoldResult linearizedIndices; std::tie(std::ignore, linearizedIndices) = memref::getLinearizedMemRefOffsetAndSize( @@ -298,16 +298,16 @@ struct ConvertMemRefLoad final : OpConversionPattern { // Special case 0-rank memref loads. Value bitsLoad; if (convertedType.getRank() == 0) { - bitsLoad = rewriter.create(loc, adaptor.getMemref(), - ValueRange{}); + bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(), + ValueRange{}); } else { // Linearize the indices of the original load instruction. Do not account // for the scaling yet. This will be accounted for later. OpFoldResult linearizedIndices = getLinearizedSrcIndices( rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef()); - Value newLoad = rewriter.create( - loc, adaptor.getMemref(), + Value newLoad = memref::LoadOp::create( + rewriter, loc, adaptor.getMemref(), getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits, dstBits)); @@ -315,7 +315,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { // Note, currently only the big-endian is supported. Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits, dstBits, rewriter); - bitsLoad = rewriter.create(loc, newLoad, bitwidthOffset); + bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset); } // Get the corresponding bits. If the arith computation bitwidth equals @@ -331,17 +331,17 @@ struct ConvertMemRefLoad final : OpConversionPattern { : IntegerType::get(rewriter.getContext(), resultTy.getIntOrFloatBitWidth()); if (conversionTy == convertedElementType) { - auto mask = rewriter.create( - loc, convertedElementType, + auto mask = arith::ConstantOp::create( + rewriter, loc, convertedElementType, rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1)); - result = rewriter.create(loc, bitsLoad, mask); + result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask); } else { - result = rewriter.create(loc, conversionTy, bitsLoad); + result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad); } if (conversionTy != resultTy) { - result = rewriter.create(loc, resultTy, result); + result = arith::BitcastOp::create(rewriter, loc, resultTy, result); } rewriter.replaceOp(op, result); @@ -428,20 +428,20 @@ struct ConvertMemrefStore final : OpConversionPattern { // Pad the input value with 0s on the left. Value input = adaptor.getValue(); if (!input.getType().isInteger()) { - input = rewriter.create( - loc, + input = arith::BitcastOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), input.getType().getIntOrFloatBitWidth()), input); } Value extendedInput = - rewriter.create(loc, dstIntegerType, input); + arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input); // Special case 0-rank memref stores. No need for masking. if (convertedType.getRank() == 0) { - rewriter.create(loc, arith::AtomicRMWKind::assign, - extendedInput, adaptor.getMemref(), - ValueRange{}); + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign, + extendedInput, adaptor.getMemref(), + ValueRange{}); rewriter.eraseOp(op); return success(); } @@ -456,16 +456,14 @@ struct ConvertMemrefStore final : OpConversionPattern { dstBits, bitwidthOffset, rewriter); // Align the value to write with the destination bits Value alignedVal = - rewriter.create(loc, extendedInput, bitwidthOffset); + arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset); // Clear destination bits - rewriter.create(loc, arith::AtomicRMWKind::andi, - writeMask, adaptor.getMemref(), - storeIndices); + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi, + writeMask, adaptor.getMemref(), storeIndices); // Write srcs bits to destination - rewriter.create(loc, arith::AtomicRMWKind::ori, - alignedVal, adaptor.getMemref(), - storeIndices); + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori, + alignedVal, adaptor.getMemref(), storeIndices); rewriter.eraseOp(op); return success(); } @@ -525,8 +523,8 @@ struct ConvertMemRefSubview final : OpConversionPattern { } // Transform the offsets, sizes and strides according to the emulation. - auto stridedMetadata = rewriter.create( - loc, subViewOp.getViewSource()); + auto stridedMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, subViewOp.getViewSource()); OpFoldResult linearizedIndices; auto strides = stridedMetadata.getConstifiedMixedStrides(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index e6e4c3b07ecb8..17a148cc31dc0 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -48,15 +48,15 @@ struct MemRefReshapeOpConverter : public OpRewritePattern { Value size; // Load dynamic sizes from the shape input, use constants for static dims. if (op.getType().isDynamicDim(i)) { - Value index = rewriter.create(loc, i); - size = rewriter.create(loc, op.getShape(), index); + Value index = arith::ConstantIndexOp::create(rewriter, loc, i); + size = memref::LoadOp::create(rewriter, loc, op.getShape(), index); if (!isa(size.getType())) - size = rewriter.create( - loc, rewriter.getIndexType(), size); + size = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), size); sizes[i] = size; } else { auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i)); - size = rewriter.create(loc, sizeAttr); + size = arith::ConstantOp::create(rewriter, loc, sizeAttr); sizes[i] = sizeAttr; } if (stride) @@ -66,10 +66,11 @@ struct MemRefReshapeOpConverter : public OpRewritePattern { if (i > 0) { if (stride) { - stride = rewriter.create(loc, stride, size); + stride = arith::MulIOp::create(rewriter, loc, stride, size); } else if (op.getType().isDynamicDim(i)) { - stride = rewriter.create( - loc, rewriter.create(loc, staticStride), + stride = arith::MulIOp::create( + rewriter, loc, + arith::ConstantIndexOp::create(rewriter, loc, staticStride), size); } else { staticStride *= op.getType().getDimSize(i); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp index 7475d442b7b9a..01d32621b2055 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp @@ -73,7 +73,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern { if (ShapedType::isDynamic(inputSize)) { Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc, rewriter.getIndexAttr(0)); - currSize = rewriter.create(loc, op.getSource(), dimZero) + currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero) .getResult(); } @@ -88,10 +88,10 @@ struct ExpandReallocOpPattern : public OpRewritePattern { // the old buffer is smaller than the requested size. Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize); Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize); - Value cond = rewriter.create(loc, arith::CmpIPredicate::ult, - lhs, rhs); - auto ifOp = rewriter.create( - loc, cond, + Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, + lhs, rhs); + auto ifOp = scf::IfOp::create( + rewriter, loc, cond, [&](OpBuilder &builder, Location loc) { // Allocate the new buffer. If it is a dynamic memref we need to pass // an additional operand for the size at runtime, otherwise the static @@ -100,25 +100,26 @@ struct ExpandReallocOpPattern : public OpRewritePattern { if (op.getDynamicResultSize()) dynamicSizeOperands.push_back(op.getDynamicResultSize()); - Value newAlloc = builder.create( - loc, op.getResult().getType(), dynamicSizeOperands, + Value newAlloc = memref::AllocOp::create( + builder, loc, op.getResult().getType(), dynamicSizeOperands, op.getAlignmentAttr()); // Take a subview of the new (bigger) buffer such that we can copy the // old values over (the copy operation requires both operands to have // the same shape). - Value subview = builder.create( - loc, newAlloc, ArrayRef{rewriter.getIndexAttr(0)}, + Value subview = memref::SubViewOp::create( + builder, loc, newAlloc, + ArrayRef{rewriter.getIndexAttr(0)}, ArrayRef{currSize}, ArrayRef{rewriter.getIndexAttr(1)}); - builder.create(loc, op.getSource(), subview); + memref::CopyOp::create(builder, loc, op.getSource(), subview); // Insert the deallocation of the old buffer only if requested // (enabled by default). if (emitDeallocs) - builder.create(loc, op.getSource()); + memref::DeallocOp::create(builder, loc, op.getSource()); - builder.create(loc, newAlloc); + scf::YieldOp::create(builder, loc, newAlloc); }, [&](OpBuilder &builder, Location loc) { // We need to reinterpret-cast here because either the input or output @@ -126,11 +127,12 @@ struct ExpandReallocOpPattern : public OpRewritePattern { // dynamic or vice-versa. If both are static and the original buffer // is already bigger than the requested size, the cast represents a // subview operation. - Value casted = builder.create( - loc, cast(op.getResult().getType()), op.getSource(), - rewriter.getIndexAttr(0), ArrayRef{targetSize}, + Value casted = memref::ReinterpretCastOp::create( + builder, loc, cast(op.getResult().getType()), + op.getSource(), rewriter.getIndexAttr(0), + ArrayRef{targetSize}, ArrayRef{rewriter.getIndexAttr(1)}); - builder.create(loc, casted); + scf::YieldOp::create(builder, loc, casted); }); rewriter.replaceOp(op, ifOp.getResult(0)); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 2ba798f48ac7c..9771bd2aaa143 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -66,7 +66,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter, unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = - rewriter.create(origLoc, source); + memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source); auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); #ifndef NDEBUG @@ -577,7 +577,7 @@ static FailureOr resolveReshapeStridedMetadata( unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = - rewriter.create(origLoc, source); + memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source); // Collect statically known information. auto [strides, offset] = sourceType.getStridesAndOffset(); @@ -828,14 +828,14 @@ struct ExtractStridedMetadataOpAllocFolder if (allocLikeOp.getType() == baseBufferType) results.push_back(allocLikeOp); else - results.push_back(rewriter.create( - loc, baseBufferType, allocLikeOp, offset, + results.push_back(memref::ReinterpretCastOp::create( + rewriter, loc, baseBufferType, allocLikeOp, offset, /*sizes=*/ArrayRef(), /*strides=*/ArrayRef())); } // Offset. - results.push_back(rewriter.create(loc, offset)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset)); for (OpFoldResult size : sizes) results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size)); @@ -900,19 +900,19 @@ struct ExtractStridedMetadataOpGetGlobalFolder if (getGlobalOp.getType() == baseBufferType) results.push_back(getGlobalOp); else - results.push_back(rewriter.create( - loc, baseBufferType, getGlobalOp, offset, + results.push_back(memref::ReinterpretCastOp::create( + rewriter, loc, baseBufferType, getGlobalOp, offset, /*sizes=*/ArrayRef(), /*strides=*/ArrayRef())); // Offset. - results.push_back(rewriter.create(loc, offset)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset)); for (auto size : sizes) - results.push_back(rewriter.create(loc, size)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, size)); for (auto stride : strides) - results.push_back(rewriter.create(loc, stride)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, stride)); rewriter.replaceOp(op, results); return success(); @@ -1008,9 +1008,8 @@ class ExtractStridedMetadataOpReinterpretCastFolder SmallVector results; results.resize_for_overwrite(rank * 2 + 2); - auto newExtractStridedMetadata = - rewriter.create( - loc, reinterpretCastOp.getSource()); + auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, reinterpretCastOp.getSource()); // Register the base_buffer. results[0] = newExtractStridedMetadata.getBaseBuffer(); @@ -1082,9 +1081,8 @@ class ExtractStridedMetadataOpCastFolder SmallVector results; results.resize_for_overwrite(rank * 2 + 2); - auto newExtractStridedMetadata = - rewriter.create(loc, - castOp.getSource()); + auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, castOp.getSource()); // Register the base_buffer. results[0] = newExtractStridedMetadata.getBaseBuffer(); @@ -1142,9 +1140,8 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder auto memSpaceCastOp = source.getDefiningOp(); if (!memSpaceCastOp) return failure(); - auto newExtractStridedMetadata = - rewriter.create( - loc, memSpaceCastOp.getSource()); + auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, memSpaceCastOp.getSource()); SmallVector results(newExtractStridedMetadata.getResults()); // As with most other strided metadata rewrite patterns, don't introduce // a use of the base pointer where non existed. This needs to happen here, @@ -1158,8 +1155,8 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder MemRefType::Builder newTypeBuilder(baseBufferType); newTypeBuilder.setMemorySpace( memSpaceCastOp.getResult().getType().getMemorySpace()); - results[0] = rewriter.create( - loc, Type{newTypeBuilder}, baseBuffer); + results[0] = memref::MemorySpaceCastOp::create( + rewriter, loc, Type{newTypeBuilder}, baseBuffer); } else { results[0] = nullptr; } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp index 2f5c9436fb8c7..0946da8e4e919 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp @@ -42,8 +42,8 @@ static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter, memref::LoadOp loadOp, Value srcMemRef, ArrayRef indices) { Location loc = loadOp.getLoc(); - return rewriter.create(loc, srcMemRef, indices, - loadOp.getNontemporal()); + return memref::LoadOp::create(rewriter, loc, srcMemRef, indices, + loadOp.getNontemporal()); } // Matches getViewSizeForEachDim specs for LoadOp. @@ -72,9 +72,8 @@ static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter, memref::StoreOp storeOp, Value srcMemRef, ArrayRef indices) { Location loc = storeOp.getLoc(); - return rewriter.create(loc, storeOp.getValueToStore(), - srcMemRef, indices, - storeOp.getNontemporal()); + return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(), + srcMemRef, indices, storeOp.getNontemporal()); } // Matches getViewSizeForEachDim specs for StoreOp. @@ -104,8 +103,8 @@ static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter, Value srcMemRef, ArrayRef indices) { Location loc = ldMatrixOp.getLoc(); - return rewriter.create( - loc, ldMatrixOp.getResult().getType(), srcMemRef, indices, + return nvgpu::LdMatrixOp::create( + rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef, indices, ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles()); } @@ -132,8 +131,8 @@ rebuildTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp transferReadOp, Value srcMemRef, ArrayRef indices) { Location loc = transferReadOp.getLoc(); - return rewriter.create( - loc, transferReadOp.getResult().getType(), srcMemRef, indices, + return vector::TransferReadOp::create( + rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices, transferReadOp.getPermutationMap(), transferReadOp.getPadding(), transferReadOp.getMask(), transferReadOp.getInBoundsAttr()); } @@ -150,8 +149,8 @@ rebuildTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp transferWriteOp, Value srcMemRef, ArrayRef indices) { Location loc = transferWriteOp.getLoc(); - return rewriter.create( - loc, transferWriteOp.getValue(), srcMemRef, indices, + return vector::TransferWriteOp::create( + rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices, transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(), transferWriteOp.getInBoundsAttr()); } @@ -182,9 +181,8 @@ static SmallVector getGenericOpViewSizeForEachDim(RewriterBase &rewriter, LoadStoreLikeOp loadStoreLikeOp) { Location loc = loadStoreLikeOp.getLoc(); - auto extractStridedMetadataOp = - rewriter.create( - loc, getSrcMemRef(loadStoreLikeOp)); + auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create( + rewriter, loc, getSrcMemRef(loadStoreLikeOp)); SmallVector srcSizes = extractStridedMetadataOp.getConstifiedMixedSizes(); SmallVector indices = @@ -267,12 +265,12 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern { // apply them properly to the input indices. // Therefore the strides multipliers are simply ones. auto subview = - rewriter.create(loc, /*source=*/srcMemRef, - /*offsets=*/indices, - /*sizes=*/sizes, /*strides=*/ones); + memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef, + /*offsets=*/indices, + /*sizes=*/sizes, /*strides=*/ones); // Rewrite the load/store with the subview as the base pointer. SmallVector zeros(loadStoreRank, - rewriter.create(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices( rewriter, loadStoreLikeOp, subview.getResult(), zeros); rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 76f7788c4dcc5..42be847811d52 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -40,8 +40,8 @@ using namespace mlir; static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in) { if (Attribute offsetAttr = dyn_cast(in)) { - return rewriter.create( - loc, cast(offsetAttr).getInt()); + return arith::ConstantIndexOp::create( + rewriter, loc, cast(offsetAttr).getInt()); } return cast(in); } @@ -60,7 +60,7 @@ static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, } memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create(loc, source); + memref::ExtractStridedMetadataOp::create(rewriter, loc, source); auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth(); OpFoldResult linearizedIndices; @@ -74,8 +74,8 @@ static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, getAsOpFoldResult(indices)); return std::make_pair( - rewriter.create( - loc, source, + memref::ReinterpretCastOp::create( + rewriter, loc, source, /* offset = */ linearizedInfo.linearizedOffset, /* shapes = */ ArrayRef{linearizedInfo.linearizedSize}, @@ -111,7 +111,7 @@ template static void castAllocResult(T oper, T newOper, Location loc, PatternRewriter &rewriter) { memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create(loc, oper); + memref::ExtractStridedMetadataOp::create(rewriter, loc, oper); rewriter.replaceOpWithNewOp( oper, cast(oper.getType()), newOper, /*offset=*/rewriter.getIndexAttr(0), @@ -125,63 +125,68 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, Location loc = op->getLoc(); llvm::TypeSwitch(op.getOperation()) .template Case([&](auto oper) { - auto newAlloc = rewriter.create( - loc, cast(flatMemref.getType()), + auto newAlloc = memref::AllocOp::create( + rewriter, loc, cast(flatMemref.getType()), oper.getAlignmentAttr()); castAllocResult(oper, newAlloc, loc, rewriter); }) .template Case([&](auto oper) { - auto newAlloca = rewriter.create( - loc, cast(flatMemref.getType()), + auto newAlloca = memref::AllocaOp::create( + rewriter, loc, cast(flatMemref.getType()), oper.getAlignmentAttr()); castAllocResult(oper, newAlloca, loc, rewriter); }) .template Case([&](auto op) { - auto newLoad = rewriter.create( - loc, op->getResultTypes(), flatMemref, ValueRange{offset}); + auto newLoad = + memref::LoadOp::create(rewriter, loc, op->getResultTypes(), + flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) .template Case([&](auto op) { - auto newStore = rewriter.create( - loc, op->getOperands().front(), flatMemref, ValueRange{offset}); + auto newStore = + memref::StoreOp::create(rewriter, loc, op->getOperands().front(), + flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) .template Case([&](auto op) { - auto newLoad = rewriter.create( - loc, op->getResultTypes(), flatMemref, ValueRange{offset}); + auto newLoad = + vector::LoadOp::create(rewriter, loc, op->getResultTypes(), + flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) .template Case([&](auto op) { - auto newStore = rewriter.create( - loc, op->getOperands().front(), flatMemref, ValueRange{offset}); + auto newStore = + vector::StoreOp::create(rewriter, loc, op->getOperands().front(), + flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) .template Case([&](auto op) { - auto newMaskedLoad = rewriter.create( - loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(), - op.getPassThru()); + auto newMaskedLoad = vector::MaskedLoadOp::create( + rewriter, loc, op.getType(), flatMemref, ValueRange{offset}, + op.getMask(), op.getPassThru()); newMaskedLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedLoad.getResult()); }) .template Case([&](auto op) { - auto newMaskedStore = rewriter.create( - loc, flatMemref, ValueRange{offset}, op.getMask(), + auto newMaskedStore = vector::MaskedStoreOp::create( + rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(), op.getValueToStore()); newMaskedStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedStore); }) .template Case([&](auto op) { - auto newTransferRead = rewriter.create( - loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding()); + auto newTransferRead = vector::TransferReadOp::create( + rewriter, loc, op.getType(), flatMemref, ValueRange{offset}, + op.getPadding()); rewriter.replaceOp(op, newTransferRead.getResult()); }) .template Case([&](auto op) { - auto newTransferWrite = rewriter.create( - loc, op.getVector(), flatMemref, ValueRange{offset}); + auto newTransferWrite = vector::TransferWriteOp::create( + rewriter, loc, op.getVector(), flatMemref, ValueRange{offset}); rewriter.replaceOp(op, newTransferWrite); }) .Default([&](auto op) { diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 35c661ecb886d..66c1aa6bf3fe1 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -51,7 +51,7 @@ FailureOr memref::buildIndependentOp(OpBuilder &b, // Create a new memref::AllocaOp. Value newAllocaOp = - b.create(loc, newSizes, allocaOp.getType().getElementType()); + AllocaOp::create(b, loc, newSizes, allocaOp.getType().getElementType()); // Create a memref::SubViewOp. SmallVector offsets(newSizes.size(), b.getIndexAttr(0)); @@ -71,11 +71,11 @@ propagateSubViewOp(RewriterBase &rewriter, MemRefType newResultType = SubViewOp::inferRankReducedResultType( op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); - Value newSubview = rewriter.create( - op.getLoc(), newResultType, conversionOp.getOperand(0), + Value newSubview = SubViewOp::create( + rewriter, op.getLoc(), newResultType, conversionOp.getOperand(0), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); - auto newConversionOp = rewriter.create( - op.getLoc(), op.getType(), newSubview); + auto newConversionOp = UnrealizedConversionCastOp::create( + rewriter, op.getLoc(), op.getType(), newSubview); rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0)); return newConversionOp; } @@ -106,8 +106,8 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter, SmallVector unrealizedConversions; for (const auto &it : llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) { - unrealizedConversions.push_back(rewriter.create( - to->getLoc(), std::get<0>(it.value()).getType(), + unrealizedConversions.push_back(UnrealizedConversionCastOp::create( + rewriter, to->getLoc(), std::get<0>(it.value()).getType(), std::get<1>(it.value()))); rewriter.replaceAllUsesWith(from->getResult(it.index()), unrealizedConversions.back()->getResult(0)); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index 0a84962150ead..5d3cec402cab1 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -63,9 +63,10 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter, subviewUse.getType().getShape(), cast(val.getType()), subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), subviewUse.getStaticStrides()); - Value newSubview = rewriter.create( - subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(), - subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); + Value newSubview = memref::SubViewOp::create( + rewriter, subviewUse->getLoc(), newType, val, + subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), + subviewUse.getMixedStrides()); // Ouch recursion ... is this really necessary? replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); @@ -177,8 +178,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, Location loc = allocOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(allocOp); - auto mbAlloc = rewriter.create( - loc, mbMemRefType, ValueRange{}, allocOp->getAttrs()); + auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType, + ValueRange{}, allocOp->getAttrs()); LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n"); // 3. Within the loop, build the modular leading index (i.e. each loop @@ -211,8 +212,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, // Strides is [1, 1 ... 1 ]. MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType( originalShape, mbMemRefType, offsets, sizes, strides); - Value subview = rewriter.create(loc, dstMemref, mbAlloc, - offsets, sizes, strides); + Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc, + offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to @@ -224,7 +225,7 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(deallocOp); auto newDeallocOp = - rewriter.create(deallocOp->getLoc(), mbAlloc); + memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc); (void)newDeallocOp; LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n"); rewriter.eraseOp(deallocOp); diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index 4ec04321dd3e2..fa7991e6c6a80 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -276,8 +276,8 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, if (!callOp) continue; Operation *newCallOp = - builder.create(userOp->getLoc(), callOp.getCalleeAttr(), - resultTypes, userOp->getOperands()); + func::CallOp::create(builder, userOp->getLoc(), callOp.getCalleeAttr(), + resultTypes, userOp->getOperands()); bool replacingMemRefUsesFailed = false; bool returnTypeChanged = false; for (unsigned resIndex : llvm::seq(0, userOp->getNumResults())) { diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp index 46f9d64ebeb15..d65825bbdf391 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp @@ -115,10 +115,12 @@ static LogicalResult reifyOpResultShapes(RewriterBase &rewriter, // Update the type. newRes.setType(reifiedTy); if (isa(reifiedTy)) { - newResults.push_back(rewriter.create(loc, oldTy, newRes)); + newResults.push_back( + tensor::CastOp::create(rewriter, loc, oldTy, newRes)); } else { assert(isa(reifiedTy) && "expected a memref type"); - newResults.push_back(rewriter.create(loc, oldTy, newRes)); + newResults.push_back( + memref::CastOp::create(rewriter, loc, oldTy, newRes)); } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 89a3895d06ba5..6a81a15f30e47 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -69,7 +69,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern { Location loc = dimOp->getLoc(); rewriter.replaceOpWithNewOp( dimOp, resultShape, - rewriter.create(loc, *dimIndex).getResult()); + arith::ConstantIndexOp::create(rewriter, loc, *dimIndex).getResult()); return success(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index d231516884c7d..1f03e9ae8d6a1 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -40,19 +40,18 @@ struct AssumeAlignmentOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto assumeOp = cast(op); - Value ptr = builder.create( - loc, assumeOp.getMemref()); - Value rest = builder.create( - loc, ptr, - builder.create(loc, assumeOp.getAlignment())); - Value isAligned = builder.create( - loc, arith::CmpIPredicate::eq, rest, - builder.create(loc, 0)); - builder.create( - loc, isAligned, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "memref is not aligned to " + - std::to_string(assumeOp.getAlignment()))); + Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc, + assumeOp.getMemref()); + Value rest = arith::RemUIOp::create( + builder, loc, ptr, + arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment())); + Value isAligned = + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest, + arith::ConstantIndexOp::create(builder, loc, 0)); + cf::AssertOp::create(builder, loc, isAligned, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "memref is not aligned to " + + std::to_string(assumeOp.getAlignment()))); } }; @@ -71,15 +70,14 @@ struct CastOpInterface if (isa(srcType)) { // Check rank. - Value srcRank = builder.create(loc, castOp.getSource()); + Value srcRank = RankOp::create(builder, loc, castOp.getSource()); Value resultRank = - builder.create(loc, resultType.getRank()); - Value isSameRank = builder.create( - loc, arith::CmpIPredicate::eq, srcRank, resultRank); - builder.create( - loc, isSameRank, - RuntimeVerifiableOpInterface::generateErrorMessage(op, - "rank mismatch")); + arith::ConstantIndexOp::create(builder, loc, resultType.getRank()); + Value isSameRank = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); + cf::AssertOp::create(builder, loc, isSameRank, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "rank mismatch")); } // Get source offset and strides. We do not have an op to get offsets and @@ -95,8 +93,9 @@ struct CastOpInterface MemRefType::get(dynamicShape, resultType.getElementType(), stridedLayout, resultType.getMemorySpace()); Value helperCast = - builder.create(loc, dynStridesType, castOp.getSource()); - auto metadataOp = builder.create(loc, helperCast); + CastOp::create(builder, loc, dynStridesType, castOp.getSource()); + auto metadataOp = + ExtractStridedMetadataOp::create(builder, loc, helperCast); // Check dimension sizes. for (const auto &it : llvm::enumerate(resultType.getShape())) { @@ -110,13 +109,13 @@ struct CastOpInterface continue; Value srcDimSz = - builder.create(loc, castOp.getSource(), it.index()); + DimOp::create(builder, loc, castOp.getSource(), it.index()); Value resultDimSz = - builder.create(loc, it.value()); - Value isSameSz = builder.create( - loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); - builder.create( - loc, isSameSz, + arith::ConstantIndexOp::create(builder, loc, it.value()); + Value isSameSz = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); + cf::AssertOp::create( + builder, loc, isSameSz, RuntimeVerifiableOpInterface::generateErrorMessage( op, "size mismatch of dim " + std::to_string(it.index()))); } @@ -132,13 +131,12 @@ struct CastOpInterface // Static/dynamic offset -> dynamic offset does not need verification. Value srcOffset = metadataOp.getResult(1); Value resultOffsetVal = - builder.create(loc, resultOffset); - Value isSameOffset = builder.create( - loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); - builder.create( - loc, isSameOffset, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "offset mismatch")); + arith::ConstantIndexOp::create(builder, loc, resultOffset); + Value isSameOffset = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); + cf::AssertOp::create(builder, loc, isSameOffset, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "offset mismatch")); } // Check strides. @@ -150,11 +148,11 @@ struct CastOpInterface Value srcStride = metadataOp.getResult(2 + resultType.getRank() + it.index()); Value resultStrideVal = - builder.create(loc, it.value()); - Value isSameStride = builder.create( - loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); - builder.create( - loc, isSameStride, + arith::ConstantIndexOp::create(builder, loc, it.value()); + Value isSameStride = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); + cf::AssertOp::create( + builder, loc, isSameStride, RuntimeVerifiableOpInterface::generateErrorMessage( op, "stride mismatch of dim " + std::to_string(it.index()))); } @@ -186,7 +184,7 @@ struct CopyOpInterface auto getDimSize = [&](Value memRef, MemRefType type, int64_t dim) -> Value { return type.isDynamicDim(dim) - ? builder.create(loc, memRef, dim).getResult() + ? DimOp::create(builder, loc, memRef, dim).getResult() : builder .create(loc, type.getDimSize(dim)) @@ -194,13 +192,12 @@ struct CopyOpInterface }; Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i); Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i); - Value sameDimSize = builder.create( - loc, arith::CmpIPredicate::eq, sourceDim, targetDim); - builder.create( - loc, sameDimSize, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "size of " + std::to_string(i) + - "-th source/target dim does not match")); + Value sameDimSize = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim); + cf::AssertOp::create(builder, loc, sameDimSize, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "size of " + std::to_string(i) + + "-th source/target dim does not match")); } } }; @@ -211,10 +208,11 @@ struct DimOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto dimOp = cast(op); - Value rank = builder.create(loc, dimOp.getSource()); - Value zero = builder.create(loc, 0); - builder.create( - loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), + Value rank = RankOp::create(builder, loc, dimOp.getSource()); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + cf::AssertOp::create( + builder, loc, + generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), RuntimeVerifiableOpInterface::generateErrorMessage( op, "index is out of bounds")); } @@ -237,7 +235,7 @@ struct LoadStoreOpInterface } auto indices = loadStoreOp.getIndices(); - auto zero = builder.create(loc, 0); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); Value assertCond; for (auto i : llvm::seq(0, rank)) { Value dimOp = builder.createOrFold(loc, memref, i); @@ -247,10 +245,9 @@ struct LoadStoreOpInterface i > 0 ? builder.createOrFold(loc, assertCond, inBounds) : inBounds; } - builder.create( - loc, assertCond, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "out-of-bounds access")); + cf::AssertOp::create(builder, loc, assertCond, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "out-of-bounds access")); } }; @@ -265,10 +262,10 @@ struct SubViewOpInterface // For each dimension, assert that: // 0 <= offset < dim_size // 0 <= offset + (size - 1) * stride < dim_size - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); auto metadataOp = - builder.create(loc, subView.getSource()); + ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); @@ -281,21 +278,21 @@ struct SubViewOpInterface Value dimSize = metadataOp.getSizes()[i]; Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); - builder.create( - loc, offsetInBounds, + cf::AssertOp::create( + builder, loc, offsetInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "offset " + std::to_string(i) + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = builder.create(loc, size, one); + Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = - builder.create(loc, sizeMinusOne, stride); + arith::MulIOp::create(builder, loc, sizeMinusOne, stride); Value lastPos = - builder.create(loc, offset, sizeMinusOneTimesStride); + arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - builder.create( - loc, lastPosInBounds, + cf::AssertOp::create( + builder, loc, lastPosInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "subview runs out-of-bounds along dimension " + std::to_string(i))); @@ -315,7 +312,7 @@ struct ExpandShapeOpInterface for (const auto &it : llvm::enumerate(expandShapeOp.getReassociationIndices())) { Value srcDimSz = - builder.create(loc, expandShapeOp.getSrc(), it.index()); + DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index()); int64_t groupSz = 1; bool foundDynamicDim = false; for (int64_t resultDim : it.value()) { @@ -330,18 +327,17 @@ struct ExpandShapeOpInterface groupSz *= expandShapeOp.getResultType().getDimSize(resultDim); } Value staticResultDimSz = - builder.create(loc, groupSz); + arith::ConstantIndexOp::create(builder, loc, groupSz); // staticResultDimSz must divide srcDimSz evenly. Value mod = - builder.create(loc, srcDimSz, staticResultDimSz); - Value isModZero = builder.create( - loc, arith::CmpIPredicate::eq, mod, - builder.create(loc, 0)); - builder.create( - loc, isModZero, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "static result dims in reassoc group do not " - "divide src dim evenly")); + arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz); + Value isModZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, mod, + arith::ConstantIndexOp::create(builder, loc, 0)); + cf::AssertOp::create(builder, loc, isModZero, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "static result dims in reassoc group do not " + "divide src dim evenly")); } } };