Make custom ops differentiable #1314


Merged · 5 commits · Sep 10, 2019
Conversation

@t-vi (Contributor) commented Sep 9, 2019

Make custom ops differentiable and replace autograd.Function. Use the ops unconditionally.

We may consider removing the extension module in a follow-up.

The code path is tested by the existing tests for differentiability.
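The pattern this PR moves into C++ (an op with an explicit forward/backward pair, in the style of `torch::autograd::Function`) can be sketched in plain Python. This is an illustrative stand-in, not the torchvision code; all names below are made up:

```python
# Stand-in for a custom differentiable op: forward computes the output and
# saves what backward needs (the "ctx"); backward maps the upstream gradient
# to gradients w.r.t. the inputs via the chain rule.
class SquareOp:
    def forward(self, x):
        self.saved_x = x              # analogous to ctx.save_for_backward
        return x * x

    def backward(self, grad_out):
        # d(x^2)/dx = 2x, scaled by the incoming gradient
        return grad_out * 2.0 * self.saved_x

op = SquareOp()
y = op.forward(3.0)
grad_x = op.backward(1.0)
print(y, grad_x)  # 9.0 6.0
```

In the real C++ version, forward/backward become static methods on a Function subclass and the saved state lives in an autograd context object rather than on the op instance.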
@fmassa (Member) left a comment

Wow, this is pretty clean, thanks a lot!

Could you also add a test showing that the gradients of a scripted function work as expected?

Removing the extension module in a follow-up sounds good to me

I'm also cc'ing @ezyang; this looks like the first example of how to use a C++ Function for autograd, and it looks great.
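The requested gradient test is essentially what `torch.autograd.gradcheck` automates: compare an op's analytic backward against a central finite difference. A stdlib-only sketch of the idea (`f`, `grad_f`, and `gradcheck` here are hypothetical, not the torchvision test):

```python
def f(x):
    return x ** 3

def grad_f(x):
    # Analytic gradient, as an op's backward would return it
    return 3 * x ** 2

def gradcheck(fn, grad_fn, x, eps=1e-6, tol=1e-4):
    # Central finite difference: (f(x+eps) - f(x-eps)) / (2*eps)
    numeric = (fn(x + eps) - fn(x - eps)) / (2 * eps)
    return abs(numeric - grad_fn(x)) < tol

print(gradcheck(f, grad_f, 2.0))  # True
```

The real test would script the op with `torch.jit.script` first and then run the same comparison, so both the eager and scripted code paths are covered.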

@fmassa (Member) commented Sep 9, 2019

@t-vi the test failures seem to be related:

 /tmp/pip-req-build-4yu7qgl5/torchvision/csrc/custom_ops/custom_ops.cpp: In function ‘std::tuple<at::Tensor, at::Tensor> roi_pool(const at::Tensor&, const at::Tensor&, double, int64_t, int64_t)’:
  /tmp/pip-req-build-4yu7qgl5/torchvision/csrc/custom_ops/custom_ops.cpp:136:31: error: converting to ‘std::tuple<at::Tensor, at::Tensor>’ from initializer list would use explicit constructor ‘constexpr std::tuple<_T1, _T2>::tuple(_U1&&, _U2&&) [with _U1 = torch::autograd::Variable&; _U2 = torch::autograd::Variable&; <template-parameter-2-3> = void; _T1 = at::Tensor; _T2 = at::Tensor]’
     return {result[0], result[1]};

@t-vi (Contributor, Author) commented Sep 9, 2019

Which CI failure is that? I'm seeing a lot of data_ptr-related failures, but I cannot seem to find that error message... 😕

@fmassa (Member) commented Sep 9, 2019

Here is one example: https://travis-ci.org/pytorch/vision/jobs/582603380, line 954.

I've re-run some of the CircleCI failures, which I think were still picking up the patch before my fix.

@fmassa (Member) commented Sep 10, 2019

@t-vi can you rebase your PR with current master? This looks pretty good!

@codecov-io commented Sep 10, 2019

Codecov Report

Merging #1314 into master will increase coverage by 0.01%.
The diff coverage is 100%.


@@            Coverage Diff             @@
##           master    #1314      +/-   ##
==========================================
+ Coverage   65.83%   65.84%   +0.01%     
==========================================
  Files          75       75              
  Lines        5824     5782      -42     
  Branches      886      884       -2     
==========================================
- Hits         3834     3807      -27     
+ Misses       1725     1710      -15     
  Partials      265      265
Impacted Files Coverage Δ
torchvision/ops/roi_align.py 71.42% <100%> (+3.42%) ⬆️
torchvision/ops/roi_pool.py 74.07% <100%> (+3.86%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@fmassa (Member) left a comment

Awesome, thanks a lot!

3 participants