
Commit f677ea3

Remove cpp extensions in favor of torch ops (#1348)
* Remove C++ extensions in favor of custom ops
* Remove unused custom_ops.cpp file
* Rename _custom_ops.py
* Reorganize functions
* Minor improvements and fixes
* Fix lint
* Fully scriptable ops
* Import types used by annotations
1 parent 0dd5588 commit f677ea3
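After this change the ROI ops are reached through the torch.ops.torchvision registry (or the Python wrappers in torchvision.ops) rather than through the separate torchvision._custom_ops extension module that setup.py used to build. A minimal usage sketch, not part of the diff (tensor values are illustrative):

import torch
import torchvision  # importing torchvision loads the registered custom ops

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.]])
scores = torch.tensor([0.9, 0.8])

# nms is one of the ops exposed through the torch.ops registry;
# it is callable from both eager mode and TorchScript.
keep = torch.ops.torchvision.nms(boxes, scores, 0.5)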

File tree

14 files changed: +230, -239 lines

setup.py

Lines changed: 3 additions & 19 deletions
@@ -52,9 +52,9 @@ def write_version_file():
     with open(version_path, 'w') as f:
         f.write("__version__ = '{}'\n".format(version))
         f.write("git_version = {}\n".format(repr(sha)))
-        f.write("from torchvision import _C\n")
-        f.write("if hasattr(_C, 'CUDA_VERSION'):\n")
-        f.write("    cuda = _C.CUDA_VERSION\n")
+        f.write("from torchvision.extension import _check_cuda_version\n")
+        f.write("if _check_cuda_version() > 0:\n")
+        f.write("    cuda = _check_cuda_version()\n")


 write_version_file()

@@ -96,21 +96,12 @@ def get_extensions():
     source_models = [os.path.join(models_dir, s) for s in source_models]
     tests = test_file + source_models

-    custom_ops_sources = [os.path.join(extensions_dir, "custom_ops", "custom_ops.cpp"),
-                          os.path.join(extensions_dir, "cpu", "nms_cpu.cpp"),
-                          os.path.join(extensions_dir, "cpu", "ROIAlign_cpu.cpp"),
-                          os.path.join(extensions_dir, "cpu", "ROIPool_cpu.cpp")]
-    custom_ops_sources_cuda = [os.path.join(extensions_dir, "cuda", "nms_cuda.cu"),
-                               os.path.join(extensions_dir, "cuda", "ROIAlign_cuda.cu"),
-                               os.path.join(extensions_dir, "cuda", "ROIPool_cuda.cu")]
-
     define_macros = []

     extra_compile_args = {}
     if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
         extension = CUDAExtension
         sources += source_cuda
-        custom_ops_sources += custom_ops_sources_cuda
         define_macros += [('WITH_CUDA', None)]
         nvcc_flags = os.getenv('NVCC_FLAGS', '')
         if nvcc_flags == '':

@@ -148,13 +139,6 @@ def get_extensions():
             define_macros=define_macros,
             extra_compile_args=extra_compile_args,
         ),
-        extension(
-            "torchvision._custom_ops",
-            sources=custom_ops_sources,
-            include_dirs=include_dirs,
-            define_macros=define_macros,
-            extra_compile_args=extra_compile_args,
-        ),
     ]

     return ext_modules
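With the torchvision._custom_ops extension removed from setup.py, the generated version file now queries CUDA support through torchvision.extension rather than through _C.CUDA_VERSION. The resulting torchvision/version.py would look roughly like this sketch (version and SHA values are placeholders, not taken from the diff):

# torchvision/version.py -- generated by write_version_file(); values below are placeholders
__version__ = '0.5.0a0+f677ea3'
git_version = 'f677ea3'
from torchvision.extension import _check_cuda_version
if _check_cuda_version() > 0:
    cuda = _check_cuda_version()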

test/test_ops.py

Lines changed: 4 additions & 4 deletions
@@ -190,7 +190,7 @@ def func(input):

         @torch.jit.script
         def script_func(input, rois):
-            return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
+            return ops.roi_pool(input, rois, 5, 1.0)[0]

         assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool'

@@ -282,7 +282,7 @@ def func(input):

         @torch.jit.script
         def script_func(input, rois):
-            return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
+            return ops.roi_pool(input, rois, 5, 1.0)[0]

         assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool on CUDA'

@@ -442,7 +442,7 @@ def func(input):

         @torch.jit.script
         def script_func(input, rois):
-            return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
+            return ops.roi_align(input, rois, 5, 0.5, 1)[0]

         assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align'

@@ -482,7 +482,7 @@ def func(input):

         @torch.jit.script
         def script_func(input, rois):
-            return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
+            return ops.roi_align(input, rois, 5, 0.5, 1)[0]

         assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA'
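The tests now go through the Python wrappers in torchvision.ops, whose argument order places output_size before spatial_scale (the raw op takes spatial_scale, pooled_height, pooled_width). A self-contained sketch of the pattern the updated tests exercise, with illustrative shapes:

import torch
from torch.autograd import gradcheck
from torchvision import ops

x = torch.rand(1, 3, 10, 10, dtype=torch.double, requires_grad=True)
# Each ROI is (batch_index, x1, y1, x2, y2)
rois = torch.tensor([[0., 0., 0., 9., 9.]], dtype=torch.double)

@torch.jit.script
def script_func(input, rois):
    # output_size=5, spatial_scale=1.0 -- matches the updated call in the tests
    return ops.roi_pool(input, rois, 5, 1.0)[0]

assert gradcheck(lambda t: script_func(t, rois), (x,)), 'gradcheck failed for scripted roi_pool'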

torchvision/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 from torchvision import utils
 from torchvision import io

+from .extension import _HAS_OPS
+
 try:
     from .version import __version__  # noqa: F401
 except ImportError:
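The new torchvision/extension.py module is among the 14 changed files but is not shown in this excerpt; a rough, hypothetical sketch of how a flag like _HAS_OPS could be set by loading the compiled operator library (the lookup of torchvision._C here is an assumption, not taken from the diff):

# Hypothetical sketch -- the real torchvision/extension.py is not shown above.
import importlib.util
import torch

try:
    # Load the compiled extension so that its operator registrations
    # populate the torch.ops.torchvision namespace.
    _spec = importlib.util.find_spec("torchvision._C")
    torch.ops.load_library(_spec.origin)
    _HAS_OPS = True
except Exception:
    _HAS_OPS = False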

torchvision/csrc/ROIAlign.h

Lines changed: 71 additions & 0 deletions
@@ -74,3 +74,74 @@ at::Tensor ROIAlign_backward(
       width,
       sampling_ratio);
 }
+
+using namespace at;
+using torch::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      Variable input,
+      Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t sampling_ratio) {
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["sampling_ratio"] = sampling_ratio;
+    ctx->saved_data["input_shape"] = input.sizes();
+    ctx->save_for_backward({rois});
+    auto result = ROIAlign_forward(
+        input,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        sampling_ratio);
+    return {result};
+  }
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = ROIAlign_backward(
+        grad_output[0],
+        rois,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3],
+        ctx->saved_data["sampling_ratio"].toInt());
+    return {
+        grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
+  }
+};
+
+Tensor roi_align(
+    const Tensor& input,
+    const Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t sampling_ratio) {
+  return ROIAlignFunction::apply(
+      input,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio)[0];
+}
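Wrapping the kernels in a torch::autograd::Function means gradients flow through the raw custom op itself, not only through a Python-side binding. A small illustrative check from Python (shapes and values are arbitrary; the argument order follows the roi_align signature above):

import torch
import torchvision  # noqa: F401 -- loads the custom ops

x = torch.rand(1, 3, 10, 10, dtype=torch.double, requires_grad=True)
rois = torch.tensor([[0., 0., 0., 9., 9.]], dtype=torch.double)

# spatial_scale=0.5, pooled_height=5, pooled_width=5, sampling_ratio=1
y = torch.ops.torchvision.roi_align(x, rois, 0.5, 5, 5, 1)
y.sum().backward()   # dispatches to ROIAlignFunction::backward
print(x.grad.shape)  # torch.Size([1, 3, 10, 10])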

torchvision/csrc/ROIPool.h

Lines changed: 63 additions & 1 deletion
@@ -63,4 +63,66 @@ at::Tensor ROIPool_backward(
       channels,
       height,
       width);
-}
+}
+
+using namespace at;
+using torch::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      Variable input,
+      Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width) {
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["input_shape"] = input.sizes();
+    auto result = ROIPool_forward(
+        input, rois, spatial_scale, pooled_height, pooled_width);
+    auto output = std::get<0>(result);
+    auto argmax = std::get<1>(result);
+    ctx->save_for_backward({rois, argmax});
+    ctx->mark_non_differentiable({argmax});
+    return {output, argmax};
+  }
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto argmax = saved[1];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = ROIPool_backward(
+        grad_output[0],
+        rois,
+        argmax,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3]);
+    return {grad_in, Variable(), Variable(), Variable(), Variable()};
+  }
+};
+
+std::tuple<Tensor, Tensor> roi_pool(
+    const Tensor& input,
+    const Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width) {
+  auto result = ROIPoolFunction::apply(
+      input, rois, spatial_scale, pooled_height, pooled_width);
+  return std::tuple<Tensor, Tensor>(result[0], result[1]);
+}
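roi_pool is wrapped the same way, but it returns both the pooled output and the argmax indices, with argmax marked non-differentiable. An illustrative call against the raw op (values arbitrary; the argument order follows the roi_pool signature above):

import torch
import torchvision  # noqa: F401 -- loads the custom ops

x = torch.rand(1, 3, 10, 10, dtype=torch.double, requires_grad=True)
rois = torch.tensor([[0., 0., 0., 9., 9.]], dtype=torch.double)

# spatial_scale=1.0, pooled_height=5, pooled_width=5
output, argmax = torch.ops.torchvision.roi_pool(x, rois, 1.0, 5, 5)
output.sum().backward()      # dispatches to ROIPoolFunction::backward
print(argmax.requires_grad)  # False -- argmax was marked non-differentiable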

torchvision/csrc/custom_ops/custom_ops.cpp

Lines changed: 0 additions & 159 deletions
This file was deleted.
