6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -92,6 +92,12 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
 
 /// Return true if `ofr` is constant integer equal to `value`.
 bool isConstantIntValue(OpFoldResult ofr, int64_t value);
+/// Return true if all of `ofrs` are constant integers equal to `value`.
+bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
+/// Return true if all of `ofrs` are constant integers equal to the
+/// corresponding value in `values`.
+bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
+                          ArrayRef<int64_t> values);
 
 /// Return true if ofr1 and ofr2 are the same integer constant attribute
 /// values or the same SSA value. Ignore integer bitwidth and type mismatch
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 
@@ -636,6 +637,28 @@ struct InsertOpInterface
   }
 };
 
+template <typename InsertOpTy>
+static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
+                                      OpOperand &opOperand) {
+  // The source is always read.
+  if (opOperand == insertSliceOp.getSourceMutable())
+    return true;
+
+  // For the destination, it depends...
+  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
+
+  // Dest is not read if it is entirely overwritten. E.g.:
+  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+  bool allOffsetsZero =
+      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
+  RankedTensorType destType = insertSliceOp.getDestType();
+  bool sizesMatchDestSizes =
+      areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
+  bool allStridesOne =
+      areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
+  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+}
+
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
 ///
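As context for the predicate above, a minimal IR sketch (illustrative only, not taken from this patch; the function and value names are made up): an insert_slice with zero offsets, unit strides, and sizes equal to the destination shape overwrites the whole destination and does not read it, while a partial insert must preserve the untouched elements and therefore reads the destination.

func.func @full_vs_partial_overwrite(%src: tensor<10xf32>, %small: tensor<5xf32>,
                                     %dest: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) {
  // Full overwrite: zero offset, size 10 equals the dest size, stride 1.
  // The destination operand is not read.
  %full = tensor.insert_slice %src into %dest[0] [10] [1]
      : tensor<10xf32> into tensor<10xf32>
  // Partial overwrite: the elements of %dest outside the inserted window must
  // be preserved, so the destination operand is read.
  %partial = tensor.insert_slice %small into %dest[2] [5] [1]
      : tensor<5xf32> into tensor<10xf32>
  return %full, %partial : tensor<10xf32>, tensor<10xf32>
}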
@@ -646,32 +669,8 @@ struct InsertSliceOpInterface
                                                     tensor::InsertSliceOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    RankedTensorType destType = insertSliceOp.getDestType();
-
-    // The source is always read.
-    if (opOperand == insertSliceOp.getSourceMutable())
-      return true;
-
-    // For the destination, it depends...
-    assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
-
-    // Dest is not read if it is entirely overwritten. E.g.:
-    // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
-    bool allOffsetsZero =
-        llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
-          return isConstantIntValue(ofr, 0);
-        });
-    bool sizesMatchDestSizes = llvm::all_of(
-        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
-          return getConstantIntValue(it.value()) ==
-                 destType.getDimSize(it.index());
-        });
-    bool allStridesOne =
-        llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-          return isConstantIntValue(ofr, 1);
-        });
-    return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
+                                     opOperand);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -931,7 +930,8 @@ struct ParallelInsertSliceOpInterface
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    return true;
+    return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
+                                     opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
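For the parallel op the same reasoning applies. A minimal sketch (illustrative only; names and shapes are made up) of the complementary case to the new test added at the end of this patch: a tensor.parallel_insert_slice that covers only part of the shared output still reads its destination under the new predicate.

func.func @partial_parallel_insert_reads_dest(%in: tensor<1xf32>,
                                              %out: tensor<2xf32>) -> tensor<2xf32> {
  %r = scf.forall (%i) in (2) shared_outs(%o = %out) -> (tensor<2xf32>) {
    scf.forall.in_parallel {
      // Each iteration writes a single element, so the remaining elements of
      // %o must be preserved: the destination is considered read.
      tensor.parallel_insert_slice %in into %o[%i] [1] [1]
          : tensor<1xf32> into tensor<2xf32>
    }
  }
  return %r : tensor<2xf32>
}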
@@ -16,11 +16,6 @@ namespace mlir {
 namespace tensor {
 namespace {
 
-static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
-  return llvm::all_of(
-      ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
-}
-
 /// Returns the number of shape sizes that is either dynamic or greater than 1.
 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
   return llvm::count_if(
15 changes: 14 additions & 1 deletion mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/MathExtras.h"
 
 namespace mlir {
@@ -131,12 +132,24 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs)
   return res;
 }
 
-/// Return true if `ofr` is constant integer equal to `value`.
 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
   auto val = getConstantIntValue(ofr);
   return val && *val == value;
 }
 
+bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
+  return llvm::all_of(
+      ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
+}
+
+bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
+                          ArrayRef<int64_t> values) {
+  if (ofrs.size() != values.size())
+    return false;
+  std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
+  return constOfrs && llvm::equal(constOfrs.value(), values);
+}
+
 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
 /// or the same SSA value.
 /// Ignore integer bitwidth and type mismatch that come from the fact there is
15 changes: 15 additions & 0 deletions mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
 
 // -----
 
+// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
+// CHECK-NOT: memref.alloc()
+func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  return %3 : tensor<2xf32>
+}
+
+// -----
+
 // This test case could bufferize in-place with a better analysis. However, it
 // is simpler to let the canonicalizer fold away the tensor.insert_slice.
 