@@ -82,12 +82,11 @@ using namespace at;
82
82
83
83
const int kMaxParallelImgs = 32 ;
84
84
85
- inline unsigned int GET_THREADS (const unsigned int MAX_REGISTERS) {
86
- const unsigned int CUDA_MAX_NUM_THREADS = 1024 ;
87
- unsigned int kMaxRegsNumPerBlock =
88
- at::cuda::getCurrentDeviceProperties ()->regsPerBlock ;
89
-
90
- return std::min (CUDA_MAX_NUM_THREADS, kMaxRegsNumPerBlock / MAX_REGISTERS);
85
+ inline unsigned int GET_THREADS () {
86
+ if (at::cuda::getCurrentDeviceProperties ()->major >= 6 ) {
87
+ return 1024 ;
88
+ }
89
+ return 512 ;
91
90
}
92
91
93
92
inline unsigned int GET_BLOCKS (const unsigned int THREADS, const unsigned int N) {
@@ -231,8 +230,7 @@ static void deformable_im2col(
231
230
at::Tensor data_col) {
232
231
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
233
232
234
- const unsigned int max_registers = 60 ;
235
- const unsigned int threads = GET_THREADS (max_registers);
233
+ const unsigned int threads = GET_THREADS ();
236
234
const unsigned int blocks = GET_BLOCKS (threads, num_kernels);
237
235
238
236
AT_DISPATCH_FLOATING_TYPES_AND_HALF (
@@ -596,8 +594,7 @@ static void compute_grad_input(
596
594
int num_kernels =
597
595
channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
598
596
599
- const unsigned int max_registers = 46 ;
600
- const unsigned int threads = GET_THREADS (max_registers);
597
+ const unsigned int threads = GET_THREADS ();
601
598
const unsigned int blocks = GET_BLOCKS (threads, num_kernels);
602
599
603
600
AT_DISPATCH_FLOATING_TYPES_AND_HALF (
@@ -805,8 +802,7 @@ static void compute_grad_offset_and_mask(
805
802
int num_kernels =
806
803
out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
807
804
808
- const unsigned int max_registers = 63 ;
809
- const unsigned int threads = GET_THREADS (max_registers);
805
+ const unsigned int threads = GET_THREADS ();
810
806
const unsigned int blocks = GET_BLOCKS (threads, num_kernels);
811
807
812
808
AT_DISPATCH_FLOATING_TYPES_AND_HALF (
0 commit comments