Skip to content

Commit 44c2eb4

Browse files
authored
Encapsulate and standardize nms (#3081)
* Syncing, where possible, the names of functions across devices. * Placing all internal functions in anonymous namespaces. * Renaming C++/CUDA kernel files and moving operator code from the header to the cpp file. * Creating, for each cpp file, a separate header file with the "public" functions. * Removing unnecessary repeated includes. * Updating CMakeLists.txt to include all headers.
1 parent 231529f commit 44c2eb4

File tree

9 files changed

+77
-40
lines changed

9 files changed

+77
-40
lines changed

CMakeLists.txt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ file(GLOB HEADERS torchvision/csrc/*.h)
3232
# Image extension
3333
file(GLOB IMAGE_HEADERS torchvision/csrc/cpu/image/*.h)
3434
file(GLOB IMAGE_SOURCES torchvision/csrc/cpu/image/*.cpp)
35-
file(GLOB OPERATOR_SOURCES torchvision/csrc/cpu/*.h torchvision/csrc/cpu/*.cpp ${IMAGE_HEADERS} ${IMAGE_SOURCES} ${HEADERS} torchvision/csrc/*.cpp)
35+
file(GLOB OPERATOR_HEADERS torchvision/csrc/cpu/*.h)
36+
file(GLOB OPERATOR_SOURCES ${OPERATOR_HEADERS} torchvision/csrc/cpu/*.cpp ${IMAGE_HEADERS} ${IMAGE_SOURCES} ${HEADERS} torchvision/csrc/*.cpp)
3637
if(WITH_CUDA)
37-
file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} torchvision/csrc/cuda/*.h torchvision/csrc/cuda/*.cu)
38+
file(GLOB OPERATOR_HEADERS ${OPERATOR_HEADERS} torchvision/csrc/cuda/*.h)
39+
file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} ${OPERATOR_HEADERS} torchvision/csrc/cuda/*.cu)
3840
endif()
3941
file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h)
4042
file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)
@@ -95,11 +97,11 @@ install(EXPORT TorchVisionTargets
9597

9698
install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME})
9799
install(FILES
98-
torchvision/csrc/cpu/vision_cpu.h
100+
${OPERATOR_HEADERS}
99101
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu)
100102
if(WITH_CUDA)
101103
install(FILES
102-
torchvision/csrc/cuda/vision_cuda.h
104+
${OPERATOR_HEADERS}
103105
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cuda)
104106
endif()
105107
install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/models)

torchvision/csrc/cpu/nms_cpu.cpp renamed to torchvision/csrc/cpu/nms_kernel.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
#include "vision_cpu.h"
1+
#include "nms_kernel.h"
2+
3+
namespace {
24

35
template <typename scalar_t>
4-
at::Tensor nms_cpu_kernel(
6+
at::Tensor nms_kernel(
57
const at::Tensor& dets,
68
const at::Tensor& scores,
79
double iou_threshold) {
@@ -69,6 +71,8 @@ at::Tensor nms_cpu_kernel(
6971
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
7072
}
7173

74+
} // namespace
75+
7276
at::Tensor nms_cpu(
7377
const at::Tensor& dets,
7478
const at::Tensor& scores,
@@ -95,7 +99,7 @@ at::Tensor nms_cpu(
9599
auto result = at::empty({0}, dets.options());
96100

97101
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
98-
result = nms_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
102+
result = nms_kernel<scalar_t>(dets, scores, iou_threshold);
99103
});
100104
return result;
101105
}

torchvision/csrc/cpu/nms_kernel.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "../macros.h"
5+
6+
VISION_API at::Tensor nms_cpu(
7+
const at::Tensor& dets,
8+
const at::Tensor& scores,
9+
double iou_threshold);

torchvision/csrc/cpu/vision_cpu.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44

55
// TODO: Delete this file once all the methods are gone
66

7-
VISION_API at::Tensor nms_cpu(
8-
const at::Tensor& dets,
9-
const at::Tensor& scores,
10-
double iou_threshold);
11-
127
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
138
const at::Tensor& input,
149
const at::Tensor& rois,

torchvision/csrc/cuda/nms_cuda.cu renamed to torchvision/csrc/cuda/nms_kernel.cu

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
#include <c10/cuda/CUDAGuard.h>
44

55
#include "cuda_helpers.h"
6+
#include "nms_kernel.h"
67

7-
#include <iostream>
8-
#include <vector>
8+
namespace {
99

1010
int const threadsPerBlock = sizeof(unsigned long long) * 8;
1111

1212
template <typename T>
13-
__device__ inline bool devIoU(T const* const a, T const* const b, const float threshold) {
13+
__device__ inline bool devIoU(
14+
T const* const a,
15+
T const* const b,
16+
const float threshold) {
1417
T left = max(a[0], b[0]), right = min(a[2], b[2]);
1518
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
1619
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
@@ -29,7 +32,8 @@ __global__ void nms_kernel(
2932
const int row_start = blockIdx.y;
3033
const int col_start = blockIdx.x;
3134

32-
if (row_start > col_start) return;
35+
if (row_start > col_start)
36+
return;
3337

3438
const int row_size =
3539
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
@@ -68,6 +72,8 @@ __global__ void nms_kernel(
6872
}
6973
}
7074

75+
} // namespace
76+
7177
at::Tensor nms_cuda(const at::Tensor& dets,
7278
const at::Tensor& scores,
7379
double iou_threshold) {

torchvision/csrc/cuda/nms_kernel.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "../macros.h"
5+
6+
VISION_API at::Tensor nms_cuda(
7+
const at::Tensor& dets,
8+
const at::Tensor& scores,
9+
double iou_threshold);

torchvision/csrc/cuda/vision_cuda.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44

55
// TODO: Delete this file once all the methods are gone
66

7-
VISION_API at::Tensor nms_cuda(
8-
const at::Tensor& dets,
9-
const at::Tensor& scores,
10-
double iou_threshold);
11-
127
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
138
const at::Tensor& input,
149
const at::Tensor& rois,

torchvision/csrc/nms.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include "nms.h"
2+
#include <torch/extension.h>
3+
4+
#if defined(WITH_CUDA) || defined(WITH_HIP)
5+
#include <ATen/autocast_mode.h>
6+
#endif
7+
8+
at::Tensor nms(
9+
const at::Tensor& dets,
10+
const at::Tensor& scores,
11+
double iou_threshold) {
12+
static auto op = c10::Dispatcher::singleton()
13+
.findSchemaOrThrow("torchvision::nms", "")
14+
.typed<decltype(nms)>();
15+
return op.call(dets, scores, iou_threshold);
16+
}
17+
18+
#if defined(WITH_CUDA) || defined(WITH_HIP)
19+
at::Tensor nms_autocast(
20+
const at::Tensor& dets,
21+
const at::Tensor& scores,
22+
double iou_threshold) {
23+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
24+
return nms(
25+
at::autocast::cached_cast(at::kFloat, dets),
26+
at::autocast::cached_cast(at::kFloat, scores),
27+
iou_threshold);
28+
}
29+
#endif

torchvision/csrc/nms.h

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,24 @@
11
#pragma once
22

3-
#include "cpu/vision_cpu.h"
3+
#include "cpu/nms_kernel.h"
44

55
#ifdef WITH_CUDA
6-
#include "autocast.h"
7-
#include "cuda/vision_cuda.h"
6+
#include "cuda/nms_kernel.h"
87
#endif
98
#ifdef WITH_HIP
10-
#include "autocast.h"
11-
#include "hip/vision_cuda.h"
9+
#include "hip/nms_kernel.h"
1210
#endif
1311

14-
// nms dispatch nexus
12+
// C++ Forward
1513
at::Tensor nms(
1614
const at::Tensor& dets,
1715
const at::Tensor& scores,
18-
double iou_threshold) {
19-
static auto op = c10::Dispatcher::singleton()
20-
.findSchemaOrThrow("torchvision::nms", "")
21-
.typed<decltype(nms)>();
22-
return op.call(dets, scores, iou_threshold);
23-
}
16+
double iou_threshold);
2417

18+
// Autocast Forward
2519
#if defined(WITH_CUDA) || defined(WITH_HIP)
2620
at::Tensor nms_autocast(
2721
const at::Tensor& dets,
2822
const at::Tensor& scores,
29-
double iou_threshold) {
30-
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
31-
return nms(
32-
at::autocast::cached_cast(at::kFloat, dets),
33-
at::autocast::cached_cast(at::kFloat, scores),
34-
iou_threshold);
35-
}
23+
double iou_threshold);
3624
#endif

0 commit comments

Comments
 (0)