Added cdist forward/backward batching rules #306


Merged · 10 commits into main · Dec 21, 2021

Conversation

@vfdev-5 (Contributor) commented Dec 2, 2021

Description:

  • Added cdist forward/backward batching rules
  • updated tests
    • added a.select(in_dim, idx).contiguous() when computing loop expected values, since cdist_backward requires contiguous inputs (see the sketch after this list)

Related to #240
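
For illustration, a minimal C++/ATen sketch of the contiguity issue this test change works around. The actual tests live in the Python test suite, so this standalone program is only a stand-in, not the test code itself:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Slicing one example out of a batched tensor with select() returns a
  // view, and that view is generally non-contiguous when in_dim > 0.
  auto a = at::randn({3, 4, 5});
  auto slice = a.select(/*dim=*/1, /*index=*/2);
  std::cout << slice.is_contiguous() << "\n";   // 0
  // cdist_backward requires contiguous inputs, hence the explicit copy:
  auto slice_c = slice.contiguous();
  std::cout << slice_c.is_contiguous() << "\n"; // 1
}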

@vfdev-5 vfdev-5 marked this pull request as ready for review December 6, 2021 22:15
@vfdev-5 vfdev-5 changed the title [WIP] Added cdist forward/backward batching rules Added cdist forward/backward batching rules Dec 6, 2021
@vfdev-5 vfdev-5 requested a review from zou3519 December 6, 2021 22:19
@vfdev-5 (Contributor, Author) commented Dec 10, 2021

@zou3519 thanks for the review, I'll update the PR accordingly

@zou3519 (Contributor) commented Dec 10, 2021

@vfdev-5 No rush, I'm still making my way through it (need to read the backward batch rule)

Commit pushed (Fixed code according to the review):
  • rewrote forward pass reusing BINARY_POINTWISE with an update
  • rewrote backward pass + comments
@vfdev-5 (Contributor, Author) commented Dec 13, 2021

@zou3519 I moved the code to BatchRulesBinaryOps.cpp to take advantage of BINARY_POINTWISE, which I had to modify slightly. I also rewrote the backward pass and left comments explaining why the inputs had to be modified that way. Could you please review and let me know if we can simplify it further? Thanks!
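
For context, a rough self-contained sketch of the pattern a BINARY_POINTWISE-style batch rule follows. This is an illustration with at::add standing in for the wrapped op, not the actual macro expansion from BatchRulesBinaryOps.cpp:

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <algorithm>
#include <tuple>

using at::Tensor;

std::tuple<Tensor, c10::optional<int64_t>> binary_pointwise_batch_rule(
    const Tensor& x, c10::optional<int64_t> x_bdim,
    const Tensor& y, c10::optional<int64_t> y_bdim) {
  // Move an existing batch dim to the front; give unbatched inputs a
  // size-1 batch dim so both sides broadcast over the same leading dim.
  Tensor x_ = x_bdim.has_value() ? x.movedim(*x_bdim, 0) : x.unsqueeze(0);
  Tensor y_ = y_bdim.has_value() ? y.movedim(*y_bdim, 0) : y.unsqueeze(0);
  // Pad the lower-rank side with size-1 dims right after the batch dim so
  // right-aligned broadcasting cannot swallow the batch dim.
  const int64_t rank = std::max(x_.dim(), y_.dim());
  while (x_.dim() < rank) { x_ = x_.unsqueeze(1); }
  while (y_.dim() < rank) { y_ = y_.unsqueeze(1); }
  // The result carries its batch dim at position 0.
  return std::make_tuple(at::add(x_, y_), c10::optional<int64_t>(0));
}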

Comment on lines +208 to +213
// RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
// but expected shape compatible with [4, 5]
auto bs = cdist.size(*cdist_bdim);
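// ensure_has_bdim (a functorch helper) materializes an explicit batch
// dimension of size bs on a tensor that does not already carry one: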
x1_ = ensure_has_bdim(x1, false, bs);
x1_ = x1_.contiguous();
x1_bdim = 0;
@zou3519 (Contributor):
Hmm, what happens if you don't create the batch dim on x1? I don't know if you've tried this already, but the error message suggests another potential solution: if the output of cdist_backward_batch_rule is a regular Tensor with no batch dim, then we should return std::make_tuple(regular_tensor, nullopt)

@zou3519 (Contributor):
Reading the rest of the code, it looks like you may have tried that already.

@vfdev-5 (Contributor, Author):
Yes, maybe we can simplify that somehow...
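
For context, a hypothetical sketch of the two return conventions discussed in this thread; ensure_has_bdim_sketch is a stand-in re-implementation of the functorch helper, written only for illustration:

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <tuple>
#include <vector>

using at::Tensor;

// Stand-in: give an unbatched tensor an explicit batch dim of size
// batch_size at the front by expanding a size-1 leading dim.
Tensor ensure_has_bdim_sketch(const Tensor& t, int64_t batch_size) {
  std::vector<int64_t> sizes = {batch_size};
  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
  return t.unsqueeze(0).expand(sizes);
}

// What the PR does: materialize the batch dim (and make it contiguous)
// so autograd sees a gradient of shape [bs, ...].
std::tuple<Tensor, c10::optional<int64_t>> with_bdim(const Tensor& grad,
                                                     int64_t bs) {
  return std::make_tuple(ensure_has_bdim_sketch(grad, bs).contiguous(),
                         c10::optional<int64_t>(0));
}

// The alternative raised in review: hand the unbatched tensor back as-is
// and report that it has no batch dim.
std::tuple<Tensor, c10::optional<int64_t>> without_bdim(const Tensor& grad) {
  return std::make_tuple(grad, c10::optional<int64_t>(c10::nullopt));
}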

@zou3519 (Contributor) left a comment

This looks pretty good to me. Two questions, but the logic looks correct.

@zou3519 (Contributor) left a comment

Thanks for fixing the contiguity problem in pytorch/pytorch! I had one last comment on whether cdist does type promotion.
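
To illustrate why the question matters (a generic example, not cdist-specific): binary pointwise ops promote mixed input dtypes, so a batch rule that rebuilds its inputs must not silently change the promoted result type.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto a = at::ones({2}, at::kFloat);
  auto b = at::ones({2}, at::kDouble);
  // at::result_type reports the dtype the op itself would promote to.
  std::cout << at::result_type(a, b) << "\n";   // Double
  std::cout << (a + b).scalar_type() << "\n";   // Double
}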

@zou3519 (Contributor) left a comment

lgtm!

@zou3519 zou3519 merged commit 3ba93d3 into pytorch:main Dec 21, 2021
@vfdev-5 vfdev-5 deleted the add-cdist-br branch December 21, 2021 22:11
zou3519 pushed a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
…torch#306)

* WIP on adding cdist batching rules

* Updated cdist forward / backward batch rules

* Fixed code according to the review
- rewrote forward pass reusing BINARY_POINTWISE with an update
- rewrote backward pass + comments

* Restore previous code as cdist issue has been fixed

* Added comment about type promotion for cdist
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
…torch#306)
