Add support for `select` op #2179
Conversation
LGTM. Thanks for adding support for `select`.
```diff
@@ -1615,7 +1612,10 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
        loops,
        root_ind);

    root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv);
    if (root_dom[i] != selected_id) {
```
I think we should just assert this for now. `getProducerIndexWithHalo` should just return `root_ind` if the root domain is not extended with halo nor accessed through a shift op.
There is an assert inside `getProducerHaloOffset`:

```cpp
auto it = p2c.find(producer_id);
// p2c should always have a mapping for producer_id. The only case
// where no mapping exists for a producer axis is when it is a
// reduction axis. Since this function is only used for indexing
// producer tensors, where reduction axes are skipped, producer_id
// should never be a reduction axis.
TORCH_INTERNAL_ASSERT(it != p2c.end());
```

which fails for the selected ID. Do we want to remove this assert and return `0` if not mapped?
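To make the proposal concrete, here is a minimal, self-contained sketch of the suggested change, using plain `int` IDs as stand-ins for nvFuser's `IterDomain*`/`Val*` types (the names here are hypothetical, not actual nvFuser code):

```cpp
#include <cassert>
#include <unordered_map>

// Simplified stand-in for the producer-to-consumer (p2c) lookup inside
// getProducerHaloOffset. Real code maps IterDomain* to IterDomain* and
// computes a Val* offset; here both sides are plain ints.
using IdMap = std::unordered_map<int, int>;

// Proposed behavior: instead of TORCH_INTERNAL_ASSERT(it != p2c.end()),
// return a zero halo offset when producer_id has no consumer mapping
// (the case hit by the selected ID).
int getHaloOffsetFor(const IdMap& p2c, int producer_id) {
  auto it = p2c.find(producer_id);
  if (it == p2c.end()) {
    return 0;  // unmapped (e.g. selected) axis: no halo offset
  }
  return it->second;  // mapped axis: use the mapped halo offset
}
```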
Yeah, I think we should remove the assert, but I'd feel more assured with some validation that the producer ID is a selected ID.
I moved the check inside `getProducerIndexWithHalo`, something like:

```cpp
if (override) {
  offset = 0;
} else {
  assert
}
```
LGTM. Added some minor comments.
```diff
@@ -1591,18 +1575,27 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
        continue;
      }

      TORCH_INTERNAL_ASSERT(
      Val* root_ind = nullptr;
      auto override_it = override_index.find(root_dom[i]);
```
I was thinking about passing the optional map to `getTensorIndexFromIdGraph` to provide an initial ID-to-index map. That would be more consistent if we ever want to allow the same kind of initial map in consumer indexing. That said, I think this is good enough for now given that the whole indexing code will be redesigned.
Pinging @csarofeen
Agreed
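For illustration, the alternative discussed above (an optional initial ID-to-index map rather than a select-specific override) might look roughly like the following sketch. This is a hypothetical simplification with `int` IDs and indices standing in for `IterDomain*` and `Val*`; none of these names exist in nvFuser:

```cpp
#include <unordered_map>
#include <vector>

// Hypothetical indexing entry point that accepts an optional initial
// ID-to-index map. Axes present in the map keep their pre-bound index
// (e.g. the constant index of a selected axis); all other axes fall
// through to the normal per-axis index computation, modeled here by a
// trivial placeholder.
std::vector<int> computeIndices(
    const std::vector<int>& axes,
    const std::unordered_map<int, int>& initial_index_map = {}) {
  std::vector<int> indices;
  for (int axis : axes) {
    auto it = initial_index_map.find(axis);
    if (it != initial_index_map.end()) {
      indices.push_back(it->second);  // pre-bound index (e.g. select)
    } else {
      indices.push_back(axis * 2);    // placeholder for real index math
    }
  }
  return indices;
}
```

The appeal of this shape is that the same `initial_index_map` parameter could serve consumer indexing unchanged, instead of threading a select-specific override through producer indexing only.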
Is `select` an alias operation?

It is, it returns a view of the input tensor.

In nvFuser, even a view isn't really an alias, as it creates a new (transformed) copy. I don't think we have a true alias in nvFuser. Is it different in PyTorch? Is a view a true "view" of another tensor?

Yes, in eager mode,

Hmm, so I'm assuming we somehow handle the semantic difference, right? I wonder what would happen when writing to view tensors.

I don't know. I remember seeing something called

Yeah, this was why I asked. We would need to figure out how to support the alias analysis of the op to pipe it through.

Sounds like we need to have an alias analysis to ensure that changing the
The select op takes a slice of a tensor at a given dimension. For example, for a 3D tensor, `t.select(1, 123)` means taking a slice along dim 1 at index 123, which is identical to `t[:, 123, :]`.

The reasons why I wrote this PR are:

- `select` is a commonly used op in PyTorch, and it is always good to have better op coverage.
- `select` is similar to `index_select` but easier, so this PR could serve as a reference for `index_select`, and it is also a good place to discuss indexing for `index_select`.

The current support of `select` is not complete: it has restrictions that require the input tensor to be a fusion input, and I don't think this PR detects and rejects all unsupported cases in `registry.cpp`. But for now, I believe providing this level of support is sufficient.
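The equivalence between `t.select(1, 123)` and `t[:, 123, :]` can be checked with plain row-major strided-offset arithmetic. The following is an illustrative sketch, not nvFuser code; the function names and the row-major layout are assumptions:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major linear offset of a coordinate tuple within a tensor of the
// given per-dimension sizes.
std::size_t flatOffset(
    const std::vector<std::size_t>& sizes,
    const std::vector<std::size_t>& coords) {
  std::size_t offset = 0;
  for (std::size_t d = 0; d < sizes.size(); ++d) {
    offset = offset * sizes[d] + coords[d];  // Horner-style linearization
  }
  return offset;
}

// Offset into the original tensor for a coordinate of the selected view:
// the selected dimension is pinned to `index`, and the view's remaining
// coordinates pass through unchanged.
std::size_t selectOffset(
    const std::vector<std::size_t>& sizes,
    std::size_t dim,
    std::size_t index,
    std::vector<std::size_t> view_coords) {
  view_coords.insert(view_coords.begin() + dim, index);
  return flatOffset(sizes, view_coords);
}
```

For a tensor of sizes `{4, 200, 5}`, `selectOffset(sizes, 1, 123, {i, j})` lands on the same element as `flatOffset(sizes, {i, 123, j})`, which is exactly the `t[:, 123, :]` semantics described above.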