-
Notifications
You must be signed in to change notification settings - Fork 365
fix: aten.index converter #2487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: aten.index converter #2487
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor) | ||
def index_dtype_validator(node: Node) -> bool: | ||
index = node.args[1] | ||
return all(ind.meta["val"].dtype == torch.int32 for ind in index if ind is not None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use get
for dictionary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that the bug in #2480 is only related to the fact that our index
converter cannot handle boolean mask indices. The fact that the example has data-dependent shapes does not seem to be the root cause of the initial error.
This is a good temporary fix. Could you also file an issue which specifies that we need to expand index
support to boolean masks, since this is a valid usage in Torch and we already support something very similar in aten.where
, we just need to add the functionality here. See the sample case below:
>>> import torch
>>> y = torch.rand(2, 2)
>>> y
tensor([[0.3105, 0.9580],
[0.4824, 0.0796]])
>>> y[torch.Tensor([[True, True], [False, False]]).bool()]
tensor([0.3105, 0.9580])
Description
See details in the issue below.
Fixes #2480
Type of change
Checklist: