Added cdist forward/backward batching rules #306
Conversation
@zou3519 thanks for the review, I'll update the PR accordingly
@vfdev-5 No rush, I'm still making my way through it (need to read the backward batch rule)
- rewrote forward pass reusing BINARY_POINTWISE with an update
- rewrote backward pass + comments
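Since the forward pass reuses the BINARY_POINTWISE machinery, here is a minimal sketch of what such a batch rule could look like, assuming functorch's moveBatchDimToFront and ensure_has_bdim helpers from BatchRulesHelper.h; the names and structure are illustrative, not the PR's exact code:

// Hedged sketch of a cdist forward batch rule in the binary-pointwise style.
std::tuple<at::Tensor, c10::optional<int64_t>> cdist_forward_sketch(
    const at::Tensor& x1, c10::optional<int64_t> x1_bdim,
    const at::Tensor& x2, c10::optional<int64_t> x2_bdim,
    double p) {
  // Batch size comes from whichever input actually carries a batch dim.
  auto bs = x1_bdim.has_value() ? x1.size(*x1_bdim) : x2.size(*x2_bdim);
  // Move the batch dim to the front, materializing one if it is absent,
  // so both inputs end up shaped [bs, ..., D].
  auto x1_ = ensure_has_bdim(moveBatchDimToFront(x1, x1_bdim), x1_bdim.has_value(), bs);
  auto x2_ = ensure_has_bdim(moveBatchDimToFront(x2, x2_bdim), x2_bdim.has_value(), bs);
  // cdist accepts batched inputs, so the result keeps its batch dim at 0.
  return std::make_tuple(at::cdist(x1_, x2_, p), 0);
}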
(force-pushed from 72ea65e to b683495)
@zou3519 I moved the code to BatchRulesBinaryOps.cpp to take advantage of …
// RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
// but expected shape compatible with [4, 5]
auto bs = cdist.size(*cdist_bdim);
x1_ = ensure_has_bdim(x1, false, bs);
x1_ = x1_.contiguous();
x1_bdim = 0;
Hmm, what happens if you don't create the batch dim on x1? I don't know if you've tried this already, but the error message suggests another potential solution: if the output of cdist_backward_batch_rule is a regular Tensor with no batch dim, then we should return std::make_tuple(regular_tensor, nullopt)
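To make the suggested return convention concrete, a hedged, illustrative sketch of the two ways a batch rule can hand its result back (names are hypothetical):

// If the gradient came out batched, report which dim is the batch dim
// (0 after moving it to the front); if it is a plain Tensor, report nullopt
// instead of materializing a fake batch dim on x1.
std::tuple<at::Tensor, c10::optional<int64_t>> wrap_grad(
    const at::Tensor& grad, bool grad_is_batched) {
  if (grad_is_batched) {
    return std::make_tuple(grad, 0);
  }
  return std::make_tuple(grad, c10::nullopt);
}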
Reading the rest of the code it looks like you may have tried that already
Yes, maybe we can simplify that somehow...
This looks pretty good to me. Two questions, but the logic looks correct.
(force-pushed from 0505195 to 7daf8d8)
Thanks for fixing the contiguity problem in pytorch/pytorch! I had one last comment on whether cdist does type promotion.
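One quick way to check that question: a hedged probe (not from the PR) that feeds at::cdist two different floating dtypes. If cdist type-promotes its inputs, it prints the promoted dtype; otherwise it throws a dtype-mismatch error.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto x1 = at::randn({4, 5}, at::kFloat);
  auto x2 = at::randn({3, 5}, at::kDouble);
  try {
    // Prints the result dtype if promotion happens (expected: Double).
    std::cout << at::cdist(x1, x2).scalar_type() << std::endl;
  } catch (const std::exception& e) {
    std::cout << "no promotion: " << e.what() << std::endl;
  }
  return 0;
}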
lgtm!
(force-pushed from 6246faa to 20ea1ec)
…torch#306)
* WIP on adding cdist batching rules
* Updated cdist forward / backward batch rules
* Fixed code according to the review
  - rewrote forward pass reusing BINARY_POINTWISE with an update
  - rewrote backward pass + comments
* Restore previous code as cdist issue has been fixed
* Added comment about type promotion for cdist
Description:
Uses a.select(in_dim, idx).contiguous() when computing loop expected values, since cdist_backward requires contiguous input.
Related to #240
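For context, a minimal sketch (hypothetical helper, assuming a single batch dim at in_dim) of how loop expected values for cdist could be computed slice by slice, with the .contiguous() call the description refers to:

#include <ATen/ATen.h>
#include <vector>

at::Tensor loop_cdist(const at::Tensor& a, const at::Tensor& b, int64_t in_dim) {
  std::vector<at::Tensor> outs;
  for (int64_t idx = 0; idx < a.size(in_dim); ++idx) {
    // .contiguous() matters here: cdist_backward requires contiguous inputs.
    auto a_slice = a.select(in_dim, idx).contiguous();
    auto b_slice = b.select(in_dim, idx).contiguous();
    outs.push_back(at::cdist(a_slice, b_slice));
  }
  // Re-stack the per-slice results along the original batch position.
  return at::stack(outs, in_dim);
}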