Skip to content

Commit 8b86fbe

Browse files
committed
Create foreach cpp file a separate header file with "public" functions.
1 parent e2ff902 commit 8b86fbe

File tree

8 files changed

+48
-22
lines changed

8 files changed

+48
-22
lines changed

torchvision/csrc/cpu/nms_kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "vision_cpu.h"
1+
#include "nms_kernel.h"
22

33
namespace {
44

torchvision/csrc/cpu/nms_kernel.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "../macros.h"
5+
6+
VISION_API at::Tensor nms_cpu(
7+
const at::Tensor& dets,
8+
const at::Tensor& scores,
9+
double iou_threshold);

torchvision/csrc/cpu/vision_cpu.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44

55
// TODO: Delete this file once all the methods are gone
66

7-
VISION_API at::Tensor nms_cpu(
8-
const at::Tensor& dets,
9-
const at::Tensor& scores,
10-
double iou_threshold);
11-
127
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
138
const at::Tensor& input,
149
const at::Tensor& rois,

torchvision/csrc/cuda/nms_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <c10/cuda/CUDAGuard.h>
44

55
#include "cuda_helpers.h"
6+
#include "nms_kernel.h"
67

78
#include <iostream>
89
#include <vector>

torchvision/csrc/cuda/nms_kernel.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "../macros.h"
5+
6+
VISION_API at::Tensor nms_cuda(
7+
const at::Tensor& dets,
8+
const at::Tensor& scores,
9+
double iou_threshold);

torchvision/csrc/cuda/vision_cuda.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44

55
// TODO: Delete this file once all the methods are gone
66

7-
VISION_API at::Tensor nms_cuda(
8-
const at::Tensor& dets,
9-
const at::Tensor& scores,
10-
double iou_threshold);
11-
127
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
138
const at::Tensor& input,
149
const at::Tensor& rois,

torchvision/csrc/nms.cpp

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
1-
#pragma once
1+
#include "nms.h"
2+
#include <torch/extension.h>
23

3-
#include "cpu/vision_cpu.h"
4-
5-
#ifdef WITH_CUDA
6-
#include "autocast.h"
7-
#include "cuda/vision_cuda.h"
8-
#endif
9-
#ifdef WITH_HIP
10-
#include "autocast.h"
11-
#include "hip/vision_cuda.h"
4+
#if defined(WITH_CUDA) || defined(WITH_HIP)
5+
#include <ATen/autocast_mode.h>
126
#endif
137

14-
// nms dispatch nexus
158
at::Tensor nms(
169
const at::Tensor& dets,
1710
const at::Tensor& scores,

torchvision/csrc/nms.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
#include "cpu/nms_kernel.h"
4+
5+
#ifdef WITH_CUDA
6+
#include "cuda/nms_kernel.h"
7+
#endif
8+
#ifdef WITH_HIP
9+
#include "hip/nms_kernel.h"
10+
#endif
11+
12+
// C++ Forward
13+
at::Tensor nms(
14+
const at::Tensor& dets,
15+
const at::Tensor& scores,
16+
double iou_threshold);
17+
18+
// Autocast Forward
19+
#if defined(WITH_CUDA) || defined(WITH_HIP)
20+
at::Tensor nms_autocast(
21+
const at::Tensor& dets,
22+
const at::Tensor& scores,
23+
double iou_threshold);
24+
#endif

0 commit comments

Comments
 (0)