-
Notifications
You must be signed in to change notification settings - Fork 7
Index select empty tensor scalar tensor #2513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
third_party/nvfuser/csrc/codegen.cpp
Outdated
void handle(const TensorView* tv) final { | ||
// This allows us to access scalar tensor as if they are just scalar | ||
TORCH_INTERNAL_ASSERT(tv->isZeroDim(), "TensorView can only be handled as scalar tensor"); | ||
code_ << ir_utils::varName(tv) << "[0]"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This I think should be fine for scalar tensor. Seeking advises @naoyam @zasdfgbnm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of handling scalar tensor here, should we handle it during index lowering? In
pytorch/third_party/nvfuser/csrc/ir_internal_nodes.h
Lines 74 to 76 in 9340f80
std::unordered_map<IterDomain*, Val*> getIndexOverridingMap() const { | |
return {{getSelectAxis(), input(1)}}; | |
} |
We can check if
input(1)
is a tensor, if yes, create a kir::TensorIndex
with index
zero value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be fine, but it doesn't seem quite right to me. I think all we need to do is to lower the index input in lower_index.cpp
.
I'll work on cleaning it up. Can you add a test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. Let me add a test on the python API~~ 🙇
Thanks for following up on this one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pushed a commit. The Python test seems fine. Running the other tests.
This reverts commit ef7b91d848bce06e69a872c8d5fa2318ab65d21c.
This reverts commit df4d2789948d5e061367d07c424389a417e9cff8.
9f22351
to
6d17917
Compare
@@ -743,6 +743,28 @@ def fusion_func(fd: FusionDefinition): | |||
test_fn(0) | |||
test_fn(1) | |||
|
|||
def test_index_select_scalar_indices(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@naoyam tests added and verified the failing after reverting changes in codegen.cpp
. It's all yours now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, BTW, current branch I'm on is broken. You might want to revert the fouling commit for scatter 9340f80
sorry, this one totally falls off my radar. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
CI passed. merging this one |
Fixes index_select on empty/scalar indices. Issues found in python API.
numel()==0
), removed the check on that.variable_name[0]
. Add a quick patch for that.