Skip to content

Per file C++ Operator registration #3135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/tracing/frcnn/test_frcnn_tracing.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include <ATen/ATen.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <torchvision/roi_align.h>
#include <torchvision/nms.h>
#include <torchvision/roi_align.h>
Comment on lines 4 to +5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are those individual imports needed or can we just include torchvision/vision.h?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including the nms is necessary because vision.h no longer includes all individual header files. I'm not sure that roi_align is necessary though. I left it because we had it before. Do you want me to try and remove it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah. We should fix this linker issue soon so that we can get rid of those hacks.


#ifdef _WIN32
// Windows only
// This is necessary until operators are automatically registered on include
static auto _nms = &vision::ops::nms_cpu;
static auto _nms = &vision::ops::nms;
#endif

int main() {
Expand Down
16 changes: 11 additions & 5 deletions torchvision/csrc/cpu/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include "deform_conv2d_kernel.h"
#include <ATen/ATen.h>
#include <torch/library.h>

namespace vision {
namespace ops {
Expand Down Expand Up @@ -852,9 +853,7 @@ at::Tensor backward_gradient_parameters(
return grad_weight;
}

} // namespace

at::Tensor deform_conv2d_forward_cpu(
at::Tensor deform_conv2d_forward_kernel(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
Expand Down Expand Up @@ -1070,7 +1069,7 @@ at::Tensor deform_conv2d_forward_cpu(
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_cpu(
deform_conv2d_backward_kernel(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
Expand Down Expand Up @@ -1141,5 +1140,12 @@ deform_conv2d_backward_cpu(
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
}

} // namespace ops
} // namespace vision
45 changes: 0 additions & 45 deletions torchvision/csrc/cpu/deform_conv2d_kernel.h

This file was deleted.

15 changes: 10 additions & 5 deletions torchvision/csrc/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "nms_kernel.h"
#include <ATen/ATen.h>
#include <torch/library.h>

namespace vision {
namespace ops {
Expand Down Expand Up @@ -74,9 +75,7 @@ at::Tensor nms_kernel_impl(
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}

} // namespace

at::Tensor nms_cpu(
at::Tensor nms_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
Expand All @@ -101,11 +100,17 @@ at::Tensor nms_cpu(

auto result = at::empty({0}, dets.options());

AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_cpu", [&] {
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] {
result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
});
return result;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("nms", nms_kernel);
}

} // namespace ops
} // namespace vision
15 changes: 0 additions & 15 deletions torchvision/csrc/cpu/nms_kernel.h

This file was deleted.

24 changes: 15 additions & 9 deletions torchvision/csrc/cpu/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ps_roi_align_kernel.h"
#include <ATen/ATen.h>
#include <torch/library.h>

namespace vision {
namespace ops {
Expand Down Expand Up @@ -301,9 +302,7 @@ void ps_roi_align_backward_kernel_impl(
}
}

} // namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
Expand All @@ -318,7 +317,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(

at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

at::CheckedFrom c = "ps_roi_align_forward_cpu";
at::CheckedFrom c = "ps_roi_align_forward_kernel";
at::checkAllSameType(c, {input_t, rois_t});

int num_rois = rois.size(0);
Expand All @@ -343,7 +342,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(

auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "ps_roi_align_forward_cpu", [&] {
input.scalar_type(), "ps_roi_align_forward_kernel", [&] {
ps_roi_align_forward_kernel_impl<scalar_t>(
output_size,
input_.data_ptr<scalar_t>(),
Expand All @@ -362,7 +361,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
return std::make_tuple(output, channel_mapping);
}

at::Tensor ps_roi_align_backward_cpu(
at::Tensor ps_roi_align_backward_kernel(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
Expand All @@ -384,7 +383,7 @@ at::Tensor ps_roi_align_backward_cpu(
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};

at::CheckedFrom c = "ps_roi_align_backward_cpu";
at::CheckedFrom c = "ps_roi_align_backward_kernel";
at::checkAllSameType(c, {grad_t, rois_t});

auto num_rois = rois.size(0);
Expand All @@ -400,7 +399,7 @@ at::Tensor ps_roi_align_backward_cpu(

auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "ps_roi_align_backward_cpu", [&] {
grad.scalar_type(), "ps_roi_align_backward_kernel", [&] {
ps_roi_align_backward_kernel_impl<scalar_t>(
grad.numel(),
grad_.data_ptr<scalar_t>(),
Expand All @@ -420,5 +419,12 @@ at::Tensor ps_roi_align_backward_cpu(
return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
}

} // namespace ops
} // namespace vision
31 changes: 0 additions & 31 deletions torchvision/csrc/cpu/ps_roi_align_kernel.h

This file was deleted.

24 changes: 15 additions & 9 deletions torchvision/csrc/cpu/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ps_roi_pool_kernel.h"
#include <ATen/ATen.h>
#include <torch/library.h>

namespace vision {
namespace ops {
Expand Down Expand Up @@ -145,9 +146,7 @@ void ps_roi_pool_backward_kernel_impl(
}
}

} // namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_kernel(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
Expand All @@ -161,7 +160,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(

at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

at::CheckedFrom c = "ps_roi_pool_forward_cpu";
at::CheckedFrom c = "ps_roi_pool_forward_kernel";
at::checkAllSameType(c, {input_t, rois_t});

int num_rois = rois.size(0);
Expand All @@ -186,7 +185,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(

auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "ps_roi_pool_forward_cpu", [&] {
input.scalar_type(), "ps_roi_pool_forward_kernel", [&] {
ps_roi_pool_forward_kernel_impl<scalar_t>(
input_.data_ptr<scalar_t>(),
spatial_scale,
Expand All @@ -204,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
return std::make_tuple(output, channel_mapping);
}

at::Tensor ps_roi_pool_backward_cpu(
at::Tensor ps_roi_pool_backward_kernel(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
Expand All @@ -225,7 +224,7 @@ at::Tensor ps_roi_pool_backward_cpu(
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};

at::CheckedFrom c = "ps_roi_pool_backward_cpu";
at::CheckedFrom c = "ps_roi_pool_backward_kernel";
at::checkAllSameType(c, {grad_t, rois_t});

auto num_rois = rois.size(0);
Expand All @@ -241,7 +240,7 @@ at::Tensor ps_roi_pool_backward_cpu(

auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "ps_roi_pool_backward_cpu", [&] {
grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] {
ps_roi_pool_backward_kernel_impl<scalar_t>(
grad_.data_ptr<scalar_t>(),
channel_mapping.data_ptr<int>(),
Expand All @@ -259,5 +258,12 @@ at::Tensor ps_roi_pool_backward_cpu(
return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel);
}

} // namespace ops
} // namespace vision
29 changes: 0 additions & 29 deletions torchvision/csrc/cpu/ps_roi_pool_kernel.h

This file was deleted.

Loading