
[MLIR][TORCH] Add op verifier for aten.index_put op #4184


Status: Open. Wants to merge 3 commits into base branch main.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6351,6 +6351,7 @@ def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
  let hasVerifier = 1;
}

def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -158,6 +158,10 @@ LogicalResult getPermutedType(BaseTensorType inType,
                              SmallVector<int64_t> permuteDims,
                              Type &permutedType);

// Check whether the given shapes of 2 tensors are broadcastable or not.
LogicalResult areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
                                               ArrayRef<int64_t> shapeB);

Reviewer comment (Collaborator): I am actually surprised that there are no functions already available in Torch MLIR to get the shape after broadcasting or to check if shapes are broadcastable.

} // namespace Torch
} // namespace torch
} // namespace mlir
85 changes: 85 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6086,6 +6086,91 @@ LogicalResult AtenCountNonzeroDimIntListOp::verify() {
  return success();
}

//===----------------------------------------------------------------------===//
// AtenIndexPutOp
//===----------------------------------------------------------------------===//

// Determine the common broadcast shape of all the index tensors.
SmallVector<int64_t>
getIndexBroadcastShape(SmallVector<Torch::ValueTensorType> indicesTypes) {
Reviewer comment (Collaborator):
The indices may not be broadcast compatible on their own. E.g.

t = torch.empty([3, 4, 5], dtype=torch.int32)
indices = [
    torch.zeros(size=[5], dtype=torch.long),
    torch.zeros(size=[3], dtype=torch.long),
]
t[indices]

raises

    t[indices]
    ~^^^^^^^^^
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [5], [3]

This function should be failable (see the sketch after the function body below).

The resulting shape after indexing is more complicated. I think this verification would reject valid IR.
Example:

indices = [
    torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7, 4, 5])

indices = [
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7, 5])

indices = [
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7])

indices = [
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)

raises

    print(t[indices].shape)
          ~^^^^^^^^^
IndexError: too many indices for tensor of dimension 3

Here I am using t[indices] indexing because, according to the docs, torch.Tensor.index_put_ is equivalent to tensor[indices] = values.

  int64_t indicesBroadcastRank = 0;
  SmallVector<int64_t> indicesRank;
  SmallVector<ArrayRef<int64_t>> indicesShape;
  for (auto indexTy : indicesTypes) {
    indicesShape.push_back(indexTy.getSizes());
    int64_t rank = indexTy.getSizes().size();
    indicesRank.push_back(rank);
    indicesBroadcastRank = std::max(rank, indicesBroadcastRank);
  }

  auto maxDim = [](int64_t dim0, int64_t dim1) {
    if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
      return Torch::kUnknownSize;
    return std::max(dim0, dim1);
  };

  SmallVector<int64_t> broadcastShape(indicesBroadcastRank, 0);
  for (unsigned i = 0; i < indicesTypes.size(); i++) {
    for (int32_t j = 0; j < indicesRank[i]; ++j) {
      auto size = indicesShape[i][j];
      int32_t idx = broadcastShape.size() - indicesRank[i] + j;
      broadcastShape[idx] = maxDim(size, broadcastShape[idx]);
    }
  }
  return broadcastShape;
}
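
As a sketch of the reviewer's suggestion, a failable variant could return FailureOr and reject index shapes that cannot be broadcast together, instead of silently taking the max of mismatched dimensions. This is illustrative only and not part of the PR: the name getIndexBroadcastShapeChecked is made up, and it assumes the usual MLIR/torch-mlir declarations (FailureOr, SmallVector, Torch::kUnknownSize) are in scope.

// Hypothetical failable variant (sketch, not in this PR).
FailureOr<SmallVector<int64_t>>
getIndexBroadcastShapeChecked(ArrayRef<Torch::ValueTensorType> indicesTypes) {
  SmallVector<int64_t> result;
  for (auto indexTy : indicesTypes) {
    ArrayRef<int64_t> shape = indexTy.getSizes();
    // Broadcasting aligns trailing dimensions, so pad the accumulated shape
    // on the left with 1s up to the larger rank before merging.
    SmallVector<int64_t> merged(std::max(result.size(), shape.size()), 1);
    std::copy(result.begin(), result.end(), merged.end() - result.size());
    for (size_t j = 0; j < shape.size(); ++j) {
      int64_t &dst = merged[merged.size() - shape.size() + j];
      int64_t src = shape[j];
      if (dst == Torch::kUnknownSize || src == Torch::kUnknownSize)
        dst = Torch::kUnknownSize;
      else if (dst == 1)
        dst = src;
      else if (src != 1 && src != dst)
        return failure(); // e.g. index shapes [5] and [3] from the review.
    }
    result = std::move(merged);
  }
  return result;
}

The verifier could then propagate the failure as an op error, mirroring the IndexError PyTorch raises for the [5], [3] example.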

LogicalResult AtenIndexPutOp::verify() {
  if (isa<Torch::NoneType>(getIndices().getType()))
    return success();

  SmallVector<Value> indices;
  if (!getListConstructElements(getIndices(), indices))
    return success();

  SmallVector<Torch::ValueTensorType> indicesTypes;
  for (auto index : indices) {
    // Skipping the none value in the indices list.
    if (auto indexTy = dyn_cast<Torch::ValueTensorType>(index.getType())) {
      if (!indexTy.hasSizes())
        return success();
Reviewer comment (Collaborator), on lines +6135 to +6136: Shouldn't the verification fail if it does not have sizes? I assume this is the shape. What does it even mean if there is no size?

      indicesTypes.push_back(indexTy);
    }
  }

  auto inputType = cast<BaseTensorType>(getSelf().getType());
  if (!inputType.hasSizes())
    return success();
Reviewer comment (Collaborator), on lines +6142 to +6143: Shouldn't we fail here as in the other case?

  SmallVector<int64_t> inputShape(inputType.getSizes());

  auto valuesType = cast<BaseTensorType>(getValues().getType());
  if (!valuesType.hasSizes())
    return success();
Reviewer comment (Collaborator), on lines +6147 to +6148: Shouldn't we fail here as in the other case?

  SmallVector<int64_t> valuesShape(valuesType.getSizes());

  SmallVector<int64_t> indicesBroadcastShape(
      getIndexBroadcastShape(indicesTypes));
  // In the case where the input rank is greater than the number of index
  // tensors, the remaining dimensions of the input are indexed in their
  // entirety. Thus, we need to append the remaining dimensions to get the
  // shape of the indexed slice.
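  // For instance, in the reviewer's example above, an input of shape
  // [3, 4, 5] indexed by a single index tensor of shape [7] produces the
  // slice shape [7, 4, 5]: dims 4 and 5 are appended by the loop below.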
  for (size_t i = indices.size(); i < inputShape.size(); i++) {
    indicesBroadcastShape.push_back(inputShape[i]);
  }

  // Check whether the values tensor is broadcast compatible with the indexing
  // result shape. Here, we only check the static dimensions; the dynamic ones
  // will be caught by the downstream lowering through runtime checks.
  if (failed(
          areStaticallyBroadcastCompatible(valuesShape, indicesBroadcastShape)))
    return emitOpError("values tensor shape [")
           << valuesShape
           << "] cannot be broadcasted to indexing result shape ["
           << indicesBroadcastShape << "]\n";

  return success();
}

//===----------------------------------------------------------------------===//
// OnnxVariantRotaryEmbeddingOp
//===----------------------------------------------------------------------===//
26 changes: 26 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -709,3 +709,29 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
    return rewriter.getI64Type();
  return inputType;
}

// Check whether the shapes of the tensors are broadcastable or not.
// Two tensors are “broadcastable” if the following rules hold:
// 1.) Each tensor has at least one dimension.
// 2.) When iterating over the dimension sizes, starting at the trailing
// dimension, the dimension sizes must either be equal, one of them is 1, or
// one of them does not exist.
LogicalResult
Torch::areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
                                        ArrayRef<int64_t> shapeB) {
  unsigned rankA = shapeA.size();
  unsigned rankB = shapeB.size();
  unsigned minRank = std::min(rankA, rankB);

  for (unsigned i = 0; i < minRank; i++) {
    int64_t dimA = shapeA[rankA - i - 1];
    int64_t dimB = shapeB[rankB - i - 1];
    // Here, we only check the static dimensions for compatibility.
    if (dimA == Torch::kUnknownSize || dimB == Torch::kUnknownSize)
      continue;
    if (!(dimA == dimB || dimA == 1 || dimB == 1))
      return failure();
  }

  return success();
}
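
To make the trailing-dimension rule concrete, a few illustrative checks follow. These are not part of the PR or its tests; the shapes are made up, and the snippet assumes the mlir::torch namespace and <cassert> are available.

// Illustrative uses of the helper above (not from the PR).
assert(succeeded(Torch::areStaticallyBroadcastCompatible({5, 1, 4}, {3, 4})));
// A dynamic dimension is skipped, so this passes statically and is left to
// runtime checks in the downstream lowering.
assert(succeeded(Torch::areStaticallyBroadcastCompatible(
    {Torch::kUnknownSize, 16}, {8, 16})));
// Mirrors the invalid.mlir test below: 16 vs. 128 is neither equal nor 1.
assert(failed(Torch::areStaticallyBroadcastCompatible({16, 192}, {128, 192})));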
@@ -557,7 +557,8 @@ def emit_with_mutating_variants(key, **kwargs):
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants(
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)",
has_verifier=True,
)
emit_with_mutating_variants(
"aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)"
10 changes: 10 additions & 0 deletions test/Dialect/Torch/invalid.mlir
@@ -403,3 +403,13 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -
  torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
  return %arg0 : !torch.vtensor<[?],f32>
}

// -----

func.func @index_put_values_shape_broadcast_incompatible(%arg0: !torch.vtensor<[?,32,16,192],f16>, %arg1: !torch.vtensor<[?],si64>, %arg2: !torch.vtensor<[?,32,128,192],f16>) -> !torch.vtensor<[?,32,16,192],f16> attributes {torch.onnx_meta.opset_version = 10 : si64} {
  %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
  %false = torch.constant.bool false
  // expected-error @+1 {{'torch.aten.index_put' op values tensor shape [-1, 32, 128, 192] cannot be broadcasted to indexing result shape [-1, 32, 16, 192]}}
  %1 = torch.aten.index_put %arg0, %0, %arg2, %false : !torch.vtensor<[?,32,16,192],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,128,192],f16>, !torch.bool -> !torch.vtensor<[?,32,16,192],f16>
}