Adding support for nonzero in LazyTensor shape functions #77572
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit cc2d25e (automated status comment; details on the Dr. CI page).
cc @miladm
Review comment (@Krovatkin):
Could you please add a test case verifying that this approach works? Namely:

- we can set a flag,
- trace nonzero,
- it gives us the correct upper bound, and
- when we move a lazy nonzero to cpu(), it gives us the correct result.
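A rough gtest-style sketch of what such a test might look like. The flag name (`FLAGS_ltc_enable_symbolic_shapes`), its header location, the `"lazy"` device string, and the shape-inspection step are assumptions based on this discussion, not code taken from the PR; it also presumes the lazy backend has been registered before the test runs.

```cpp
#include <gtest/gtest.h>
#include <torch/torch.h>
// Assumed location and name of the symbolic-shape flag.
#include <torch/csrc/lazy/core/shape.h>

TEST(LazyShapeTest, NonzeroUpperBound) {
  // Assumed flag name: enable symbolic shape mode before tracing.
  FLAGS_ltc_enable_symbolic_shapes = true;

  // 2x3 input with three nonzero entries.
  torch::Tensor input = torch::tensor({{1, 0, 2}, {0, 3, 0}});
  torch::Tensor lazy_input = input.to(torch::Device("lazy"));

  // Tracing nonzero should produce the conservative upper bound:
  // at most 2*3 = 6 rows, one index column per input dimension.
  // (That .sizes() exposes the traced upper bound is an assumption.)
  torch::Tensor lazy_result = torch::nonzero(lazy_input);
  EXPECT_EQ(lazy_result.sizes().vec(), (std::vector<int64_t>{6, 2}));

  // Moving the lazy result to cpu() forces computation; the values
  // should match eager-mode nonzero exactly.
  EXPECT_TRUE(torch::equal(lazy_result.cpu(), torch::nonzero(input)));
}
```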
```cpp
std::vector<Shape> compute_shape_nonzero(const at::Tensor& self) {
  return compute_shape_nonzero(self, false);
}
```
Review comment:
oh god, `as_tuple` is to emulate the PyTorch API, right?
Reply:
Yeah
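For context, the one-argument overload above delegates to a two-argument version whose signature can be inferred from the call site; this declaration is reconstructed for illustration, not copied from the diff.

```cpp
// `as_tuple` mirrors the Python-side torch.nonzero(..., as_tuple=...)
// switch; the `false` passed above selects the plain 2-D index tensor.
std::vector<Shape> compute_shape_nonzero(const at::Tensor& t, bool as_tuple);
```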
```cpp
if (op == at::aten::nonzero) {
  // When symbolic shape mode is not enabled, the nonzero shape function
  // returns an incorrect result.
  return !symbolicShapeEnabled();
}
```
Review comment:
- To better understand this condition: the intended behavior is to not fall back to CPU when symbolic shapes are enabled. Correct?
- If yes to the above, are you planning to add other dynamic ops like `unique` and `masked_select` to the list? (Realizing these ops may be out of scope for this PR.)
Reply:
That is correct, and yes, we will add at least `unique` to the list; I will need to look at `masked_select` some more.
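A minimal sketch of how that check might grow to cover the other dynamic-output ops mentioned above; the exact aten symbol spellings for `unique` and `masked_select` are assumptions, not code from this PR.

```cpp
// Sketch only: ops whose output shape depends on the data. Without
// symbolic shapes their traced shapes would be wrong, so fall back
// to CPU in that case.
if (op == at::aten::nonzero ||
    op == at::aten::masked_select ||                      // assumed symbol
    op == c10::Symbol::fromQualString("aten::unique")) {  // assumed symbol
  return !symbolicShapeEnabled();
}
```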
```cpp
int64_t max_elements = 1;  // declaration implied by the diff context
for (auto dim_size : t.sizes()) {
  max_elements *= dim_size;
}
return {Shape(at::kLong, {max_elements, (int64_t)t.sizes().size()})};
```
Review comment:
Please help me understand why we don't do something like this:

```cpp
return {Shape(at::kLong, std::vector<int64_t>(t.sizes().begin(), t.sizes().end()))};
```
Reply:
Those are not the same. `torch.nonzero` always returns a 2-D tensor; think of it as a list of indices of the nonzero values. `(int64_t)t.sizes().size()` is an integer that holds the number of dimensions in the original tensor.
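To make the difference concrete, here is a small self-contained sketch of the upper-bound computation, using plain std types in place of LTC's `Shape` (the helper name is made up):

```cpp
#include <cstdint>
#include <vector>

// Conservative output shape for nonzero: every element could be
// nonzero, and each index row has one coordinate per input dimension,
// so the result shape is {numel, ndim}.
std::vector<int64_t> nonzero_upper_bound(const std::vector<int64_t>& sizes) {
  int64_t max_elements = 1;
  for (int64_t dim_size : sizes) {
    max_elements *= dim_size;
  }
  return {max_elements, static_cast<int64_t>(sizes.size())};
}
```

For a 2x3 input this gives {6, 2}: at most 6 nonzero entries, each identified by a (row, col) pair. Copying `t.sizes()` through, as suggested above, would instead give {2, 3}, the input's own shape rather than the shape of the index tensor that nonzero returns.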
```cpp
for (auto& result_shape : result_shapes) {
  result_shape = result_shape.with_symbolic_dims(c10::nullopt);
}
```
Author comment:
Fixed a bug where we would crash when there isn't a shape function.
@Krovatkin Added a test and fixed a bug.
@pytorchmergebot merge this please
Hey @Gamrix. |
Summary: Pull Request resolved: #77572
Approved by: https://github.com/Krovatkin
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/73480bcbe09bf06e60431b945c1e049667de68dd
Reviewed By: seemethere
Differential Revision: D36494241
Pulled By: Gamrix
fbshipit-source-id: e761ca47ee6cd963f0bd3e7f6f7b3aa90044784d