[MLIR][TORCH] Add op verifier for aten.index_put op #4184
@@ -6086,6 +6086,91 @@ LogicalResult AtenCountNonzeroDimIntListOp::verify() {
  return success();
}

//===----------------------------------------------------------------------===//
// AtenIndexPutOp
//===----------------------------------------------------------------------===//

// Determine the common broadcast shape of all the index tensors.
SmallVector<int64_t>
getIndexBroadcastShape(SmallVector<Torch::ValueTensorType> indicesTypes) {
Review comment: The indices may not be broadcast compatible on their own; indexing with incompatible index shapes raises a shape-mismatch error in PyTorch. This function should be failable. The resulting shape after indexing is more complicated, and I think this verification would reject valid IR.
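The reviewer's original example was elided; a minimal illustration of the point, using NumPy (which follows the same advanced-indexing broadcast rule as PyTorch), is that index arrays whose shapes cannot be broadcast together make the indexed assignment fail at runtime:

```python
import numpy as np

# Index arrays of shapes [3] and [2] are not broadcast compatible,
# so advanced-indexing assignment raises rather than succeeding.
t = np.zeros((5, 5))
i0 = np.array([0, 1, 2])  # shape [3]
i1 = np.array([0, 1])     # shape [2]

try:
    t[i0, i1] = 1.0
    raised = False
except IndexError:
    raised = True

print(raised)  # True: the shapes could not be broadcast together
```

This is the failure mode the verifier's helper would silently paper over by taking a plain element-wise max.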
  int64_t indicesBroadcastRank = 0;
  SmallVector<int64_t> indicesRank;
  SmallVector<ArrayRef<int64_t>> indicesShape;
  for (auto indexTy : indicesTypes) {
    indicesShape.push_back(indexTy.getSizes());
    int64_t rank = indexTy.getSizes().size();
    indicesRank.push_back(rank);
    indicesBroadcastRank = std::max(rank, indicesBroadcastRank);
  }

  auto maxDim = [](int64_t dim0, int64_t dim1) {
    if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
      return Torch::kUnknownSize;
    return std::max(dim0, dim1);
  };

  SmallVector<int64_t> broadcastShape(indicesBroadcastRank, 0);
  for (unsigned i = 0; i < indicesTypes.size(); i++) {
    for (int32_t j = 0; j < indicesRank[i]; ++j) {
      auto size = indicesShape[i][j];
      int32_t idx = broadcastShape.size() - indicesRank[i] + j;
      broadcastShape[idx] = maxDim(size, broadcastShape[idx]);
    }
  }
  return broadcastShape;
}
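The helper above amounts to a right-aligned, element-wise max over the index shapes, with unknown sizes treated as a wildcard. A minimal Python sketch of the same logic, under the assumption that `Torch::kUnknownSize` is a negative sentinel (here `-1`); note that, like the C++ code, it does not reject incompatible shapes, which is the reviewer's objection:

```python
UNKNOWN = -1  # stand-in for Torch::kUnknownSize (assumed sentinel value)

def index_broadcast_shape(shapes):
    """Right-align the shapes and take the element-wise max, mirroring
    getIndexBroadcastShape. Any unknown dim poisons the result dim."""
    rank = max((len(s) for s in shapes), default=0)
    out = [0] * rank
    for s in shapes:
        for j, size in enumerate(s):
            idx = rank - len(s) + j  # right-aligned position
            if size == UNKNOWN or out[idx] == UNKNOWN:
                out[idx] = UNKNOWN
            else:
                out[idx] = max(size, out[idx])
    return out

print(index_broadcast_shape([[2, 1, 3], [4, 1], [3]]))  # [2, 4, 3]
```

Observe that `index_broadcast_shape([[3], [2]])` returns `[3]` instead of failing, even though shapes `[3]` and `[2]` are not broadcast compatible.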
LogicalResult AtenIndexPutOp::verify() {
  if (isa<Torch::NoneType>(getIndices().getType()))
    return success();

  SmallVector<Value> indices;
  if (!getListConstructElements(getIndices(), indices))
    return success();

  SmallVector<Torch::ValueTensorType> indicesTypes;
  for (auto index : indices) {
    // Skipping the none value in the indices list.
    if (auto indexTy = dyn_cast<Torch::ValueTensorType>(index.getType())) {
      if (!indexTy.hasSizes())
        return success();
+6135
to
+6136
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't the verification fail if does not have sizes. I assume this is the shape. What does it mean even if there is no size? |
||
      indicesTypes.push_back(indexTy);
    }
  }

  auto inputType = cast<BaseTensorType>(getSelf().getType());
  if (!inputType.hasSizes())
    return success();
Review comment on lines +6142 to +6143: Shouldn't we fail here, as in the other case?
  SmallVector<int64_t> inputShape(inputType.getSizes());

  auto valuesType = cast<BaseTensorType>(getValues().getType());
  if (!valuesType.hasSizes())
    return success();
Review comment on lines +6147 to +6148: Shouldn't we fail here, as in the other case?
  SmallVector<int64_t> valuesShape(valuesType.getSizes());

  SmallVector<int64_t> indicesBroadcastShape(
      getIndexBroadcastShape(indicesTypes));
  // In the case where the input rank is greater than the number of index
  // tensors, the remaining dimensions of the input are indexed in their
  // entirety. Thus, we need to append the remaining dimensions to get the
  // shape of the indexed slice.
  for (size_t i = indices.size(); i < inputShape.size(); i++) {
    indicesBroadcastShape.push_back(inputShape[i]);
  }
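The "remaining dimensions are indexed in their entirety" rule the comment above describes can be seen directly in NumPy, whose advanced indexing matches PyTorch on this point: with fewer index arrays than input dimensions, the indexed slice's shape is the indices' broadcast shape followed by the untouched trailing input dimensions.

```python
import numpy as np

# Input of rank 3, but only one index array: the trailing dims [5, 6]
# are taken whole, so the slice shape is [2] + [5, 6].
t = np.zeros((4, 5, 6))
i0 = np.array([0, 2])  # single index array, shape [2]
sliced = t[i0]
print(sliced.shape)  # (2, 5, 6)
```

This slice shape is exactly what the values tensor must be broadcastable to in the check that follows.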
  // Check if the values tensor is broadcast compatible with the indexing
  // result shape. Here, we only check the static dimensions; the dynamic ones
  // will be caught by the downstream lowering through runtime checks.
  if (failed(
          areStaticallyBroadcastCompatible(valuesShape, indicesBroadcastShape)))
    return emitOpError("values tensor shape [")
           << valuesShape
           << "] cannot be broadcasted to indexing result shape ["
           << indicesBroadcastShape << "]\n";

  return success();
}
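`areStaticallyBroadcastCompatible` is introduced elsewhere in this patch and its body is not shown here. The following is only a plausible Python sketch of a static-dims-only compatibility check, not the upstream implementation: shapes are compared right-aligned, a static pair passes if equal or if the values dim is 1, and any dynamic dim is skipped (deferred to runtime checks downstream).

```python
UNKNOWN = -1  # stand-in for a dynamic (unknown-size) dimension

def statically_broadcast_compatible(values_shape, target_shape):
    """Can values_shape broadcast TO target_shape, judging static dims only?"""
    if len(values_shape) > len(target_shape):
        return False  # values cannot have higher rank than the target
    for v, t in zip(reversed(values_shape), reversed(target_shape)):
        if v == UNKNOWN or t == UNKNOWN:
            continue  # dynamic dim: defer to runtime checks
        if v != t and v != 1:
            return False  # static mismatch that 1-broadcasting cannot fix
    return True

print(statically_broadcast_compatible([1, 6], [2, 5, 6]))  # True
print(statically_broadcast_compatible([3, 6], [2, 5, 6]))  # False (3 vs 5)
```

Under this sketch, dynamic dimensions never cause a verifier failure, matching the comment that they "will be caught by the downstream lowering through runtime checks."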

//===----------------------------------------------------------------------===//
// OnnxVariantRotaryEmbeddingOp
//===----------------------------------------------------------------------===//
Review comment: I am actually surprised that there are no functions already available in Torch MLIR to get the shape after broadcasting, or to check whether shapes are broadcastable.