Skip to content

Commit 192b955

Browse files
committed
Adjust # of threads of DeformConv2D by Compute Capability
1 parent dfb94b7 commit 192b955

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

torchvision/csrc/cuda/DeformConv_cuda.cu

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,11 @@ using namespace at;
8282

8383
const int kMaxParallelImgs = 32;
8484

85-
inline unsigned int GET_THREADS(const unsigned int MAX_REGISTERS) {
86-
const unsigned int CUDA_MAX_NUM_THREADS = 1024;
87-
unsigned int kMaxRegsNumPerBlock =
88-
at::cuda::getCurrentDeviceProperties()->regsPerBlock;
89-
90-
return std::min(CUDA_MAX_NUM_THREADS, kMaxRegsNumPerBlock / MAX_REGISTERS);
85+
inline unsigned int GET_THREADS() {
86+
if (at::cuda::getCurrentDeviceProperties()->major >= 6) {
87+
return 1024;
88+
}
89+
return 512;
9190
}
9291

9392
inline unsigned int GET_BLOCKS(const unsigned int THREADS, const unsigned int N) {
@@ -231,8 +230,7 @@ static void deformable_im2col(
231230
at::Tensor data_col) {
232231
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
233232

234-
const unsigned int max_registers = 60;
235-
const unsigned int threads = GET_THREADS(max_registers);
233+
const unsigned int threads = GET_THREADS();
236234
const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
237235

238236
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -596,8 +594,7 @@ static void compute_grad_input(
596594
int num_kernels =
597595
channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
598596

599-
const unsigned int max_registers = 46;
600-
const unsigned int threads = GET_THREADS(max_registers);
597+
const unsigned int threads = GET_THREADS();
601598
const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
602599

603600
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -805,8 +802,7 @@ static void compute_grad_offset_and_mask(
805802
int num_kernels =
806803
out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
807804

808-
const unsigned int max_registers = 63;
809-
const unsigned int threads = GET_THREADS(max_registers);
805+
const unsigned int threads = GET_THREADS();
810806
const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
811807

812808
AT_DISPATCH_FLOATING_TYPES_AND_HALF(

0 commit comments

Comments
 (0)