Speed-up adaptive average pooling for the common case of size=1 output (#17011) #18
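This PR special-cases adaptive average pooling when the requested output size is 1×1. In that case the operation reduces to a mean over the spatial dimensions, so `adaptive_avg_pool2d` becomes a thin wrapper that reshapes the input and calls `mean`, while the generic kernels are renamed to `_adaptive_avg_pool2d` / `_adaptive_avg_pool2d_backward` and remain the fallback for other output sizes (with matching updates to `native_functions.yaml`, `derivatives.yaml`, and the JIT symbolic script). The CPU and CUDA implementations are also simplified by reading sizes and strides with negative indices (`size(-3)`, `stride(-1)`, …), which collapses the separate 3D and 4D (batch mode) branches into one. A sketch of the equivalence the fast path relies on follows the CPU diff below.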

83 changes: 28 additions & 55 deletions aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -75,19 +75,6 @@ namespace {
at::Tensor const& input,
IntArrayRef output_size)
{
int dimD = 0;
int dimH = 1;
int dimW = 2;
int64_t sizeB = 1;
int64_t sizeD = 0;
int64_t isizeH = 0;
int64_t isizeW = 0;

int64_t istrideB = 0;
int64_t istrideD = 0;
int64_t istrideH = 0;
int64_t istrideW = 0;

for (int64_t i = 0; i < input.ndimension(); i++) {
AT_CHECK(input.size(i) > 0,
"adaptive_avg_pooling2d(): expected input to have non-empty spatial dimensions, "
@@ -98,23 +85,14 @@ namespace {
AT_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input");

if (input.ndimension() == 4)
{
istrideB = input.stride(0);
sizeB = input.size(0);
dimD++;
dimH++;
dimW++;
}

/* sizes */
sizeD = input.size(dimD);
isizeH = input.size(dimH);
isizeW = input.size(dimW);
int64_t sizeD = input.size(-3);
int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);
/* strides */
istrideD = input.stride(dimD);
istrideH = input.stride(dimH);
istrideW = input.stride(dimW);
int64_t istrideD = input.stride(-3);
int64_t istrideH = input.stride(-2);
int64_t istrideW = input.stride(-1);

auto osizeH = output_size[0];
auto osizeW = output_size[1];
@@ -138,16 +116,15 @@ namespace {
}
else
{
output.resize_({sizeB, sizeD, osizeH, osizeW});

output.resize_({input.size(-4), sizeD, osizeH, osizeW});
int64_t b;
#pragma omp parallel for private(b)
for (b = 0; b < sizeB; b++)
for (b = 0; b < input.size(0); b++)
{
AT_DISPATCH_FLOATING_TYPES(input.type(), "adaptive_avg_pool2d", [&] {
auto input_data = input.data<scalar_t>();
auto output_data = output.data<scalar_t>();
adaptive_avg_pool2d_out_frame<scalar_t>(input_data+b*istrideB, output_data+b*sizeD*osizeH*osizeW,
adaptive_avg_pool2d_out_frame<scalar_t>(input_data+b*input.stride(0), output_data+b*sizeD*osizeH*osizeW,
sizeD,
isizeH, isizeW,
osizeH, osizeW,
@@ -212,29 +189,12 @@ namespace {
const Tensor& gradOutput_,
const Tensor& input)
{
int dimD = 0;
int dimH = 1;
int dimW = 2;
int64_t sizeB = 1;
int sizeD;
int isizeH;
int isizeW;
int osizeH;
int osizeW;

if (input.ndimension() == 4) {
sizeB = input.size(0);
dimD++;
dimH++;
dimW++;
}

/* sizes */
sizeD = input.size(dimD);
isizeH = input.size(dimH);
isizeW = input.size(dimW);
osizeH = gradOutput_.size(dimH);
osizeW = gradOutput_.size(dimW);
int sizeD = input.size(-3);
int isizeH = input.size(-2);
int isizeW = input.size(-1);
int osizeH = gradOutput_.size(-2);
int osizeW = gradOutput_.size(-1);

/* get contiguous gradOutput */
auto gradOutput = gradOutput_.contiguous();
@@ -260,7 +220,7 @@ namespace {
{
int64_t b;
#pragma omp parallel for private(b)
for (b = 0; b < sizeB; b++)
for (b = 0; b < input.size(0); b++)
{
AT_DISPATCH_FLOATING_TYPES(
input.type(), "adaptive_avg_pool2d_backward", [&] {
@@ -302,6 +262,19 @@ namespace {
return output;
}

Tensor adaptive_avg_pool2d(
at::Tensor const& input,
IntArrayRef output_size){
if (output_size[0] == 1 && output_size[1] == 1) {
//in this case, adaptive pooling is just computing mean over hw dimensions, which can be done more efficiently
int64_t mean_size = input.size(-1) * input.size(-2);
Tensor out = input.contiguous().view({-1, mean_size}).mean(-1);
return input.ndimension() == 3 ? out.view({input.size(0), 1, 1}) : out.view({input.size(0), input.size(1), 1, 1});
} else {
return _adaptive_avg_pool2d(input, output_size);
}
}

Tensor& adaptive_avg_pool2d_backward_out_cpu(
Tensor& gradInput,
const Tensor& gradOutput,
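A minimal sketch of the equivalence the new CPU wrapper exploits, written against the public PyTorch API rather than the ATen internals; the tensor shape and values are arbitrary, and the reshape deliberately mirrors the `contiguous().view({-1, mean_size}).mean(-1)` sequence above:

```python
import torch
import torch.nn.functional as F

# Arbitrary batched input: N x C x H x W.
x = torch.randn(2, 3, 5, 7)

# Generic adaptive pooling to a 1x1 output...
pooled = F.adaptive_avg_pool2d(x, (1, 1))

# ...is just the mean over the H and W positions. Mirror the wrapper:
# flatten H*W into one dimension, take the mean, restore the 1x1 spatial dims.
mean_size = x.size(-1) * x.size(-2)
fast = x.contiguous().view(-1, mean_size).mean(-1).view(x.size(0), x.size(1), 1, 1)

print(torch.allclose(pooled, fast))  # expected: True
```

The 3D (unbatched) case differs only in the final `view`, which is why the wrapper branches on `input.ndimension()`.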
131 changes: 34 additions & 97 deletions aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu
@@ -222,72 +222,44 @@ namespace {

AT_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input");

if (input.ndimension() == 3) {
int64_t sizeD = input.size(0);
int64_t isizeH = input.size(1);
int64_t isizeW = input.size(2);

int64_t istrideD = input.stride(0);
int64_t istrideH = input.stride(1);
int64_t istrideW = input.stride(2);

int64_t osizeH = output_size[0];
int64_t osizeW = output_size[1];
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "adaptive_avg_pool2d", [&] {
scalar_t *input_data = input.data<scalar_t>();

output.resize_({sizeD, osizeH, osizeW});

scalar_t *output_data = output.data<scalar_t>();

// cuda blocks & threads:
int blocksH = std::max<int64_t>((int)(16L / sizeD), 1);
dim3 blocks(sizeD, blocksH);
dim3 threads(32, 8);

// run averagepool kernel
adaptiveaveragepool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
input_data, output_data,
isizeH, isizeW, osizeH, osizeW,
istrideD, istrideH, istrideW);
}
);
Tensor input_ = input;
int64_t grid_x = input.size(-3);
if (input.ndimension() == 4) {
input_ = input.contiguous();
grid_x *= input_.size(-4);
}
int64_t sizeD = input_.size(-3);
int64_t isizeH = input_.size(-2);
int64_t isizeW = input_.size(-1);

int64_t istrideD = input_.stride(-3);
int64_t istrideH = input_.stride(-2);
int64_t istrideW = input_.stride(-1);

int64_t osizeH = output_size[0];
int64_t osizeW = output_size[1];
if (input.ndimension() == 4) {
output.resize_({input_.size(-4), sizeD, osizeH, osizeW});
} else {
Tensor input_ = input.contiguous();
int64_t sizeB = input_.size(0);
int64_t sizeD = input_.size(1);
int64_t isizeH = input_.size(2);
int64_t isizeW = input.size(3);

int64_t istrideD = input_.stride(1);
int64_t istrideH = input_.stride(2);
int64_t istrideW = input_.stride(3);

int64_t osizeH = output_size[0];
int64_t osizeW = output_size[1];
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "adaptive_avg_pool2d", [&] {
output.resize_({sizeD, osizeH, osizeW});
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input_.type(), "adaptive_avg_pool2d", [&] {
scalar_t *input_data = input_.data<scalar_t>();

output.resize_({sizeB, sizeD, osizeH, osizeW});

scalar_t *output_data = output.data<scalar_t>();

// cuda blocks & threads:
int blocksH = std::max<int64_t>((int)(16L / sizeD), 1);
dim3 blocks(sizeB * sizeD, blocksH);
dim3 blocks(grid_x, blocksH);
dim3 threads(32, 8);

// run averagepool kernel
adaptiveaveragepool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
input_data, output_data,
isizeH, isizeW, osizeH, osizeW,
istrideD, istrideH, istrideW);
}
}
);
}
THCudaCheck(cudaGetLastError());
}
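With the 3D and 4D branches merged, the launch configuration is computed once: `grid.x` is the number of (batch, channel) planes to process (`sizeD` for a 3D input, multiplied by the batch size for a 4D one), and `blocksH = max(16 / sizeD, 1)` spreads the output rows over a second grid dimension. As a worked example, assuming a hypothetical 8×3×H×W input, `grid.x = 8 * 3 = 24` and `blocksH = max(16 / 3, 1) = 5` (integer division), so the kernel launches with a 24×5 grid of 32×8 threads, the same configuration the old 4D branch produced.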

Expand All @@ -306,23 +278,25 @@ namespace {

Tensor gradOutput = gradOutput_.contiguous();

if (input.ndimension() == 3) {
int64_t sizeD = input.size(0);
int64_t isizeH = input.size(1);
int64_t isizeW = input.size(2);
int64_t sizeD = input.size(-3);
int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);

int64_t osizeH = gradOutput.size(1);
int64_t osizeW = gradOutput.size(2);
int64_t osizeH = gradOutput.size(-2);
int64_t osizeW = gradOutput.size(-1);

int64_t grid_x = sizeD;
if (input.ndimension() == 4) grid_x *= input.size(-4);

//bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "adaptive_avg_pool2d_backward", [&] {
scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
scalar_t *gradInput_data = gradInput.data<scalar_t>();

// cuda blocks & threads:
int blocksH = std::max((int)(16L / sizeD), 1);
dim3 blocks(sizeD, blocksH);
dim3 blocks(grid_x, blocksH);
dim3 threads(32, 8);

if(atomic)
Expand All @@ -341,43 +315,6 @@ namespace {
}
}
);
} else {
int64_t sizeB = input.size(0);
int64_t sizeD = input.size(1);
int64_t isizeH = input.size(2);
int64_t isizeW = input.size(3);

int64_t osizeH = gradOutput.size(2);
int64_t osizeW = gradOutput.size(3);

//bool atomic = //(isizeW%osizeW != 0) || (isizeH%osizeH != 0);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.type(), "adaptive_avg_pool2d_backward", [&] {
scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
scalar_t *gradInput_data = gradInput.data<scalar_t>();

// cuda blocks & threads:
int blocksH = std::max((int)(16L / sizeD), 1);
dim3 blocks(sizeB * sizeD, blocksH);
dim3 threads(32, 8);

if(atomic)
{
// run updateGradInput kernel, accumulate gradients atomically
atomicadaptiveaveragegradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
isizeH, isizeW, osizeH, osizeW);
}
else
{
// run updateGradInput kernel, accumulate gradients atomically
adaptiveaveragegradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
isizeH, isizeW, osizeH, osizeW);
}
}
);
}
THCudaCheck(cudaGetLastError());
}

10 changes: 3 additions & 7 deletions aten/src/ATen/native/native_functions.yaml
@@ -4166,17 +4166,13 @@
- func: adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor
matches_jit_signature: True
python_module: nn

- func: _adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor
dispatch:
CPU: adaptive_avg_pool2d_cpu
CUDA: adaptive_avg_pool2d_cuda

- func: adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
dispatch:
CPU: adaptive_avg_pool2d_backward_out_cpu
CUDA: adaptive_avg_pool2d_backward_out_cuda

- func: adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor
- func: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor
matches_jit_signature: True
python_module: nn
dispatch:
12 changes: 12 additions & 0 deletions test/common_nn.py
@@ -1927,12 +1927,24 @@ def fractional_max_pool3d_test(test_case):
constructor_args=(3,),
input_fn=lambda: torch.rand(1, 3, 5),
),
dict(
module_name='AdaptiveAvgPool1d',
constructor_args=(1,),
input_fn=lambda: torch.rand(1, 3, 5),
desc='one_output',
),
dict(
module_name='AdaptiveAvgPool2d',
constructor_args=(3,),
input_fn=lambda: torch.rand(1, 3, 5, 6),
desc='single',
),
dict(
module_name='AdaptiveAvgPool2d',
constructor_args=(1,),
input_fn=lambda: torch.rand(1, 3, 5, 6),
desc='single_1x1output',
),
dict(
module_name='AdaptiveAvgPool2d',
constructor_args=((3, 4),),
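The new `common_nn.py` entries add size-1 output cases for both `AdaptiveAvgPool1d` and `AdaptiveAvgPool2d`, so the single-element-output behaviour (and, for the 2d module, the new fast path and its autograd formula) is exercised by the standard forward/backward module tests alongside the existing non-trivial output sizes.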
8 changes: 4 additions & 4 deletions tools/autograd/derivatives.yaml
@@ -1061,8 +1061,8 @@
- name: upsample_nearest3d(Tensor self, IntArrayRef output_size)
self: upsample_nearest3d_backward(grad, output_size, self.sizes())

- name: adaptive_avg_pool2d(Tensor self, IntArrayRef output_size)
self: adaptive_avg_pool2d_backward(grad, self)
- name: _adaptive_avg_pool2d(Tensor self, IntArrayRef output_size)
self: _adaptive_avg_pool2d_backward(grad, self)

- name: adaptive_avg_pool3d(Tensor self, IntArrayRef output_size)
self: adaptive_avg_pool3d_backward(grad, self)
@@ -1148,8 +1148,8 @@

# NN double backwards support

- name: adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self)
grad_output: adaptive_avg_pool2d(grad, { grad_output.size(-2), grad_output.size(-1) })
- name: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self)
grad_output: _adaptive_avg_pool2d(grad, { grad_output.size(-2), grad_output.size(-1) })
self: zeros_like(self)

- name: adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self)
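Because autograd now differentiates the generic path through `_adaptive_avg_pool2d` / `_adaptive_avg_pool2d_backward` while the 1×1 fast path is differentiated through `view` and `mean`, a quick way to sanity-check both routes (including double backward, which the entries above define) is the standard gradcheck utilities. This is only an illustrative sketch, not part of the PR's test suite:

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck, gradgradcheck

# Small double-precision input so the finite-difference checks are accurate.
x = torch.randn(2, 3, 4, 5, dtype=torch.double, requires_grad=True)

# (1, 1) hits the mean-based fast path; (2, 2) hits the generic kernel.
for out_size in [(1, 1), (2, 2)]:
    fn = lambda t: F.adaptive_avg_pool2d(t, out_size)
    assert gradcheck(fn, (x,))       # first-order gradients
    assert gradgradcheck(fn, (x,))   # double backward
```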
6 changes: 3 additions & 3 deletions torch/csrc/jit/symbolic_script.cpp
@@ -303,13 +303,13 @@ const std::vector<std::string> functions = {

return torch.view(self, size), backward

def adaptive_avg_pool2d(self,
def _adaptive_avg_pool2d(self,
output_size: List[int]):
def backward(grad_output):
grad_self = torch.adaptive_avg_pool2d_backward(grad_output, self)
grad_self = torch._adaptive_avg_pool2d_backward(grad_output, self)
return grad_self, None

return torch.adaptive_avg_pool2d(self, output_size), backward
return torch._adaptive_avg_pool2d(self, output_size), backward

def embedding(weight,
indices,