Commit bcf0bf4

colesbury authored and facebook-github-bot committed
Extend DispatchStub to support CUDA dispatch (pytorch#9579)
Summary: This is a few files taken from pytorch#8919. They're unchanged from the latest versions of that PR.

```
This is part of pytorch#8919. It's separated to make it easier to merge the PR in pieces.

There are a few major changes to DispatchStub:

- The environment variable ATEN_CPU_CAPABILITY overrides the CPU capability detection code (previously ATEN_DISABLE_AVX/AVX2).
- DispatchStub is defined in the generic native code instead of the CPU_CAPABILITY_DEFAULT kernel.
```

Pull Request resolved: pytorch#9579
Differential Revision: D8909000
Pulled By: colesbury
fbshipit-source-id: fdeb606270b06acdab3c01dba97ec9d81584ecc0
1 parent a08119a commit bcf0bf4
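
In short, callers now name the backend explicitly, and CUDA kernels can be registered into the same stub object. A rough before/after at a call site (names taken from the ReduceOps.cpp diff below; this sketch is not meant to compile on its own):

```cpp
// Before this commit: the stub always resolved to a CPU kernel.
//   sum_kernel(result, self, at::nullopt);

// After this commit: the backend is the first argument, so the same stub can
// route to a CUDA kernel once one is registered from a .cu file.
sum_kernel(kCPU, result, self, at::nullopt);
// sum_kernel(kCUDA, result, self, at::nullopt);  // asserts until a CUDA kernel is registered
```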

File tree

10 files changed: +139 -55 lines changed

.jenkins/pytorch/test.sh

Lines changed: 3 additions & 6 deletions
@@ -44,13 +44,10 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then
   (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)")
 fi
 
-export ATEN_DISABLE_AVX=
-export ATEN_DISABLE_AVX2=
 if [[ "${JOB_BASE_NAME}" == *-NO_AVX-* ]]; then
-  export ATEN_DISABLE_AVX=1
-fi
-if [[ "${JOB_BASE_NAME}" == *-NO_AVX2-* ]]; then
-  export ATEN_DISABLE_AVX2=1
+  export ATEN_CPU_CAPABILITY=default
+elif [[ "${JOB_BASE_NAME}" == *-NO_AVX2-* ]]; then
+  export ATEN_CPU_CAPABILITY=avx
 fi
 
 test_python_nn() {

aten/src/ATen/native/DispatchStub.cpp

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+#include "DispatchStub.h"
+
+#include <ATen/Error.h>
+
+#include <cpuinfo.h>
+#include <cstdlib>
+#include <cstring>
+
+namespace at { namespace native {
+
+static CPUCapability compute_cpu_capability() {
+  auto envar = std::getenv("ATEN_CPU_CAPABILITY");
+  if (envar) {
+    if (strcmp(envar, "avx2") == 0) {
+      return CPUCapability::AVX2;
+    }
+    if (strcmp(envar, "avx") == 0) {
+      return CPUCapability::AVX;
+    }
+    if (strcmp(envar, "default") == 0) {
+      return CPUCapability::DEFAULT;
+    }
+    AT_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar);
+  }
+
+#ifndef __powerpc__
+  if (cpuinfo_initialize()) {
+    if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
+      return CPUCapability::AVX2;
+    }
+    if (cpuinfo_has_x86_avx()) {
+      return CPUCapability::AVX;
+    }
+  }
+#endif
+  return CPUCapability::DEFAULT;
+}
+
+CPUCapability get_cpu_capability() {
+  static CPUCapability capability = compute_cpu_capability();
+  return capability;
+}
+
+}} // namespace at::native

aten/src/ATen/native/cpu/CapabilityDispatch.h renamed to aten/src/ATen/native/DispatchStub.h

Lines changed: 45 additions & 35 deletions
@@ -1,8 +1,8 @@
 #pragma once
 
-#include <cpuinfo.h>
+#include <ATen/Error.h>
+#include <ATen/ScalarType.h>
 #include <type_traits>
-#include <iostream>
 
 // Implements instruction set specific function dispatch.
 //
@@ -23,72 +23,82 @@
 // REGISTER_DISPATCH(stub, &kernel);
 //
 // To call:
-//   stub(tensor);
+//   stub(kCPU, tensor);
 //
 
 namespace at {
 namespace native {
 
-enum class CPUCapability { DEFAULT, AVX, AVX2, NUM_OPTIONS };
+enum class CPUCapability {
+  DEFAULT = 0,
+  AVX = 1,
+  AVX2 = 2,
+  NUM_OPTIONS
+};
+
+CPUCapability get_cpu_capability();
 
 template <typename FnPtr>
 struct DispatchStub {
   static_assert(std::is_pointer<FnPtr>::value, "FnPtr should be a pointer type");
 
   template <typename... ArgTypes>
-  void operator()(ArgTypes... args) {
-    if (!dispatch_ptr) {
-      dispatch_ptr = choose_impl();
+  void operator()(Backend backend, ArgTypes... args) {
+    if (backend == Backend::CPU) {
+      if (!dispatch_ptr) {
+        dispatch_ptr = choose_cpu_impl();
+      }
+      (*dispatch_ptr)(args...);
+    } else if (backend == Backend::CUDA) {
+      AT_ASSERTM(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel");
+      (*cuda_dispatch_ptr)(args...);
+    } else {
+      AT_ERROR("DispatchStub: unsupported backend", backend);
     }
-    (*dispatch_ptr)(args...);
   }
 
-  FnPtr choose_impl() {
-// Do not use cpuinfo on PowerPC as it shows confusing errors when run on ppc
-#ifndef __powerpc__
-    if (cpuinfo_initialize()) {
-      int avx2 = static_cast<int>(CPUCapability::AVX2);
-      if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() &&
-          cpuinfo_has_x86_fma3() && table[avx2]) {
-        return table[avx2];
-      }
-      int avx = static_cast<int>(CPUCapability::AVX);
-      if (!std::getenv("ATEN_DISABLE_AVX") && cpuinfo_has_x86_avx() && table[avx]) {
-        return table[avx];
-      }
-    }
-#endif
+  FnPtr choose_cpu_impl() {
    int def = static_cast<int>(CPUCapability::DEFAULT);
+    int avx = static_cast<int>(CPUCapability::AVX);
+    int avx2 = static_cast<int>(CPUCapability::AVX2);
+
+    auto capability = static_cast<int>(get_cpu_capability());
+    if (capability >= avx2 && table[avx2]) {
+      return table[avx2];
+    }
+    if (capability >= avx && table[avx]) {
+      return table[avx];
+    }
    AT_ASSERTM(table[def], "DispatchStub: missing default kernel");
    return table[def];
  }
 
  FnPtr dispatch_ptr = nullptr;
+  FnPtr cuda_dispatch_ptr = nullptr;
  FnPtr table[static_cast<int>(CPUCapability::NUM_OPTIONS)];
};
 
 
-#if defined(CPU_CAPABILITY)
+#if defined(CPU_CAPABILITY) || defined(__CUDACC__)
 
-constexpr CPUCapability CURRENT_CAPABILITY = CPUCapability::CPU_CAPABILITY;
+namespace {
 
-// Registers an implementation a kernel for the current CPU capability.
-template<typename FnPtr>
+template <typename FnPtr>
 struct RegisterDispatch {
   RegisterDispatch(DispatchStub<FnPtr>& stub, FnPtr value) {
-    stub.table[static_cast<int>(CURRENT_CAPABILITY)] = value;
+#if defined(__CUDACC__)
+    stub.cuda_dispatch_ptr = value;
+#else
+    int cap = static_cast<int>(CPUCapability::CPU_CAPABILITY);
+    AT_ASSERT(!stub.table[cap])
+    stub.table[cap] = value;
+#endif
  }
};
 
-// We only define the stub once in the DEFAULT capability compilation
-#if defined(CPU_CAPABILITY_DEFAULT)
-#define _DEFINE_STUB(stub, fn) DispatchStub<decltype(fn)> stub
-#else
-#define _DEFINE_STUB(stub, fn)
-#endif
+} // anonymous namespace
 
 #define REGISTER_DISPATCH(stub, fn) \
-  _DEFINE_STUB(stub, fn); \
  static RegisterDispatch<decltype(fn)> stub ## __register(stub, fn);
 
 #endif
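
To make the new dispatch path concrete, here is a small self-contained mock (plain C++, no ATen types; the enum values and the capability fallback rule mirror the header above, everything else is simplified and hypothetical):

```cpp
#include <cassert>
#include <cstdio>

// Simplified stand-ins for the ATen types used by DispatchStub above.
enum class Backend { CPU, CUDA };
enum class CPUCapability { DEFAULT = 0, AVX = 1, AVX2 = 2, NUM_OPTIONS };

// Pretend capability detection; the real one consults ATEN_CPU_CAPABILITY and cpuinfo.
static CPUCapability get_cpu_capability() { return CPUCapability::AVX; }

using kernel_fn = void (*)(const char*);

struct MockDispatchStub {
  void operator()(Backend backend, const char* arg) {
    if (backend == Backend::CPU) {
      if (!dispatch_ptr) dispatch_ptr = choose_cpu_impl();
      (*dispatch_ptr)(arg);
    } else {
      assert(cuda_dispatch_ptr && "missing CUDA kernel");
      (*cuda_dispatch_ptr)(arg);
    }
  }

  kernel_fn choose_cpu_impl() {
    // Same rule as choose_cpu_impl() above: the best registered kernel
    // whose capability does not exceed the detected one.
    for (int c = static_cast<int>(get_cpu_capability()); c > 0; --c) {
      if (table[c]) return table[c];
    }
    assert(table[0] && "missing default kernel");
    return table[0];
  }

  kernel_fn dispatch_ptr = nullptr;
  kernel_fn cuda_dispatch_ptr = nullptr;
  kernel_fn table[static_cast<int>(CPUCapability::NUM_OPTIONS)] = {};
};

static void default_kernel(const char* s) { std::printf("default kernel: %s\n", s); }
static void avx_kernel(const char* s) { std::printf("avx kernel: %s\n", s); }

int main() {
  MockDispatchStub stub;
  stub.table[0] = default_kernel;  // what REGISTER_DISPATCH does in the DEFAULT build
  stub.table[1] = avx_kernel;      // what REGISTER_DISPATCH does in the AVX build
  stub(Backend::CPU, "tensor");    // prints "avx kernel: tensor" given the mock detection
  return 0;
}
```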

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 7 additions & 4 deletions
@@ -17,6 +17,9 @@
 namespace at {
 namespace native {
 
+DispatchStub<reduce_fn> sum_kernel;
+DispatchStub<reduce_fn> prod_kernel;
+
 static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
   ScalarType scalarType = self.type().scalarType();
   ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType) ? ScalarType::Long : scalarType);
@@ -127,7 +130,7 @@ Tensor sum(const Tensor &self) {
 Tensor _sum_cpu(const Tensor& self) {
   if (self.is_contiguous()) {
     Tensor result = at::empty({}, self.type());
-    sum_kernel(result, self, at::nullopt);
+    sum_kernel(kCPU, result, self, at::nullopt);
     return result;
   }
   return self._sumall();
@@ -148,7 +151,7 @@ Tensor prod(const Tensor &self) {
 Tensor _prod_cpu(const Tensor &self) {
   if (self.is_contiguous()) {
     Tensor result = at::empty({}, self.type());
-    prod_kernel(result, self, at::nullopt);
+    prod_kernel(kCPU, result, self, at::nullopt);
     return result;
   }
   return self._prodall();
@@ -222,7 +225,7 @@ Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
     return result;
   if (self.is_contiguous() && result.is_contiguous()) {
     _dimreduce_setup(result, self, dim);
-    sum_kernel(result, self, dim);
+    sum_kernel(kCPU, result, self, dim);
     if (!keepdim) result.squeeze_(dim);
     return result;
   }
@@ -260,7 +263,7 @@ Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
     return result;
   if (self.is_contiguous() && result.is_contiguous()) {
     _dimreduce_setup(result, self, dim);
-    prod_kernel(result, self, dim);
+    prod_kernel(kCPU, result, self, dim);
     if (!keepdim) result.squeeze_(dim);
     return result;
   }
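
The two stub definitions added at the top of ReduceOps.cpp implement the second change from the summary: the DispatchStub object now lives in the generic native translation unit instead of being emitted only by the CPU_CAPABILITY_DEFAULT kernel build. A plausible picture of how the surrounding pieces fit together (the header declaration and the kernel-side registration are not part of this diff, so those lines are assumptions):

```cpp
// ReduceOpsKernel.h -- assumed declaration shared by both sides (not shown in this diff):
//   extern DispatchStub<reduce_fn> sum_kernel;

// ReduceOps.cpp -- definition added by this commit (see the hunk above):
//   DispatchStub<reduce_fn> sum_kernel;

// ReduceOpsKernel.cpp -- registration in the per-capability build; the kernel
// function name here is hypothetical:
//   REGISTER_DISPATCH(sum_kernel, &sum_kernel_impl);
```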

aten/src/ATen/native/SoftMax.cpp

Lines changed: 10 additions & 4 deletions
@@ -128,7 +128,7 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_) {
       dim >= 0 && dim < input.dim(),
       "dim must be non-negative and less than input dimensions");
   if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
-    softmax_lastdim_kernel(output, input);
+    softmax_lastdim_kernel(kCPU, output, input);
   } else {
     AT_DISPATCH_FLOATING_TYPES(input.type(), "softmax", [&] {
       host_softmax<scalar_t, false>(output, input, dim);
@@ -147,7 +147,7 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_) {
       dim >= 0 && dim < input.dim(),
       "dim must be non-negative and less than input dimensions");
   if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
-    log_softmax_lastdim_kernel(output, input);
+    log_softmax_lastdim_kernel(kCPU, output, input);
   } else {
     AT_DISPATCH_FLOATING_TYPES(input.type(), "log_softmax", [&] {
       host_softmax<scalar_t, true>(output, input, dim);
@@ -176,7 +176,7 @@ Tensor softmax_backward_cpu(
       dim >= 0 && dim < grad.dim(),
       "dim must be non-negative and less than input dimensions");
   if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) {
-    softmax_backward_lastdim_kernel(grad_input, grad, output);
+    softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output);
   } else {
     AT_DISPATCH_FLOATING_TYPES(grad.type(), "softmax_backward", [&] {
       host_softmax_backward<scalar_t, false>(grad_input, grad, output, dim);
@@ -205,13 +205,19 @@ Tensor log_softmax_backward_cpu(
       dim >= 0 && dim < grad.dim(),
       "dim must be non-negative and less than input dimensions");
   if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) {
-    log_softmax_backward_lastdim_kernel(grad_input, grad, output);
+    log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output);
   } else {
     AT_DISPATCH_FLOATING_TYPES(grad.type(), "log_softmax_backward", [&] {
       host_softmax_backward<scalar_t, true>(grad_input, grad, output, dim);
     });
   }
   return grad_input;
 }
+
+DispatchStub<forward_fn> softmax_lastdim_kernel;
+DispatchStub<forward_fn> log_softmax_lastdim_kernel;
+DispatchStub<backward_fn> softmax_backward_lastdim_kernel;
+DispatchStub<backward_fn> log_softmax_backward_lastdim_kernel;
+
 }
 }

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 26 additions & 2 deletions
@@ -92,14 +92,14 @@ Tensor& fill_(Tensor& self, const Tensor& value) {
   Tensor& _##op##__cpu(Tensor& self_) {                          \
     if (self_.numel() > 0) {                                     \
       Tensor self = sort_strides(self_);                         \
-      op##Impl(self, self);                                      \
+      op##Impl(kCPU, self, self);                                \
     }                                                            \
     return self_;                                                \
   }                                                              \
   Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) {  \
     result.resize_(self.sizes());                                \
     if (result.numel() > 0) {                                    \
-      op##Impl(result, self);                                    \
+      op##Impl(kCPU, result, self);                              \
     }                                                            \
     return result;                                               \
   }
@@ -145,5 +145,29 @@ IMPLEMENT_UNARY_OP_VEC(tan)
 IMPLEMENT_UNARY_OP_VEC(tanh)
 IMPLEMENT_UNARY_OP_VEC(trunc)
 
+DispatchStub<unary_fn> absImpl;
+DispatchStub<unary_fn> acosImpl;
+DispatchStub<unary_fn> asinImpl;
+DispatchStub<unary_fn> atanImpl;
+DispatchStub<unary_fn> ceilImpl;
+DispatchStub<unary_fn> cosImpl;
+DispatchStub<unary_fn> erfImpl;
+DispatchStub<unary_fn> erfcImpl;
+DispatchStub<unary_fn> expImpl;
+DispatchStub<unary_fn> expm1Impl;
+DispatchStub<unary_fn> floorImpl;
+DispatchStub<unary_fn> logImpl;
+DispatchStub<unary_fn> log10Impl;
+DispatchStub<unary_fn> log1pImpl;
+DispatchStub<unary_fn> log2Impl;
+DispatchStub<unary_fn> roundImpl;
+DispatchStub<unary_fn> rsqrtImpl;
+DispatchStub<unary_fn> sigmoidImpl;
+DispatchStub<unary_fn> sinImpl;
+DispatchStub<unary_fn> sqrtImpl;
+DispatchStub<unary_fn> tanImpl;
+DispatchStub<unary_fn> tanhImpl;
+DispatchStub<unary_fn> truncImpl;
+
 }
 } // namespace at

aten/src/ATen/native/cpu/ReduceOpsKernel.h

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 #pragma once
 
 #include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
 #include <ATen/optional.h>
-#include "CapabilityDispatch.h"
 
 namespace at {
 namespace native {

aten/src/ATen/native/cpu/SoftmaxKernel.h

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 #pragma once
 
 #include <ATen/ATen.h>
-#include "CapabilityDispatch.h"
+#include <ATen/native/DispatchStub.h>
 
 namespace at {
 namespace native {

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 #include "ATen/Dispatch.h"
 #include "ATen/cpu/vml.h"
 #include "ATen/CPUApplyUtils.h"
-#include "ATen/native/cpu/CapabilityDispatch.h"
+#include "ATen/native/DispatchStub.h"
 #ifdef __AVX2__
 #include "ATen/native/cpu/avx_mathfun.h"
 #endif

aten/src/ATen/native/cpu/UnaryOpsKernel.h

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 #pragma once
 
 #include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
 #include <stdexcept>
-#include "CapabilityDispatch.h"
 
 namespace at { namespace native {
 