|
1 | 1 | #pragma once
|
2 | 2 |
|
3 |
| -#include <cpuinfo.h> |
| 3 | +#include <ATen/Error.h> |
| 4 | +#include <ATen/ScalarType.h> |
4 | 5 | #include <type_traits>
|
5 |
| -#include <iostream> |
6 | 6 |
|
7 | 7 | // Implements instruction set specific function dispatch.
|
8 | 8 | //
|
|
23 | 23 | // REGISTER_DISPATCH(stub, &kernel);
|
24 | 24 | //
|
25 | 25 | // To call:
|
26 |
| -// stub(tensor); |
| 26 | +// stub(kCPU, tensor); |
27 | 27 | //
|
28 | 28 |
|
29 | 29 | namespace at {
|
30 | 30 | namespace native {
|
31 | 31 |
|
32 |
| -enum class CPUCapability { DEFAULT, AVX, AVX2, NUM_OPTIONS }; |
| 32 | +enum class CPUCapability { |
| 33 | + DEFAULT = 0, |
| 34 | + AVX = 1, |
| 35 | + AVX2 = 2, |
| 36 | + NUM_OPTIONS |
| 37 | +}; |
| 38 | + |
| 39 | +CPUCapability get_cpu_capability(); |
33 | 40 |
|
34 | 41 | template <typename FnPtr>
|
35 | 42 | struct DispatchStub {
|
36 | 43 | static_assert(std::is_pointer<FnPtr>::value, "FnPtr should be a pointer type");
|
37 | 44 |
|
38 | 45 | template <typename... ArgTypes>
|
39 |
| - void operator()(ArgTypes... args) { |
40 |
| - if (!dispatch_ptr) { |
41 |
| - dispatch_ptr = choose_impl(); |
| 46 | + void operator()(Backend backend, ArgTypes... args) { |
| 47 | + if (backend == Backend::CPU) { |
| 48 | + if (!dispatch_ptr) { |
| 49 | + dispatch_ptr = choose_cpu_impl(); |
| 50 | + } |
| 51 | + (*dispatch_ptr)(args...); |
| 52 | + } else if (backend == Backend::CUDA) { |
| 53 | + AT_ASSERTM(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel"); |
| 54 | + (*cuda_dispatch_ptr)(args...); |
| 55 | + } else { |
| 56 | + AT_ERROR("DispatchStub: unsupported backend", backend); |
42 | 57 | }
|
43 |
| - (*dispatch_ptr)(args...); |
44 | 58 | }
|
45 | 59 |
|
46 |
| - FnPtr choose_impl() { |
47 |
| -// Do not use cpuinfo on PowerPC as it shows confusing errors when run on ppc |
48 |
| -#ifndef __powerpc__ |
49 |
| - if (cpuinfo_initialize()) { |
50 |
| - int avx2 = static_cast<int>(CPUCapability::AVX2); |
51 |
| - if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() && |
52 |
| - cpuinfo_has_x86_fma3() && table[avx2]) { |
53 |
| - return table[avx2]; |
54 |
| - } |
55 |
| - int avx = static_cast<int>(CPUCapability::AVX); |
56 |
| - if (!std::getenv("ATEN_DISABLE_AVX") && cpuinfo_has_x86_avx() && table[avx]) { |
57 |
| - return table[avx]; |
58 |
| - } |
59 |
| - } |
60 |
| -#endif |
| 60 | + FnPtr choose_cpu_impl() { |
61 | 61 | int def = static_cast<int>(CPUCapability::DEFAULT);
|
| 62 | + int avx = static_cast<int>(CPUCapability::AVX); |
| 63 | + int avx2 = static_cast<int>(CPUCapability::AVX2); |
| 64 | + |
| 65 | + auto capability = static_cast<int>(get_cpu_capability()); |
| 66 | + if (capability >= avx2 && table[avx2]) { |
| 67 | + return table[avx2]; |
| 68 | + } |
| 69 | + if (capability >= avx && table[avx]) { |
| 70 | + return table[avx]; |
| 71 | + } |
62 | 72 | AT_ASSERTM(table[def], "DispatchStub: missing default kernel");
|
63 | 73 | return table[def];
|
64 | 74 | }
|
65 | 75 |
|
66 | 76 | FnPtr dispatch_ptr = nullptr;
|
| 77 | + FnPtr cuda_dispatch_ptr = nullptr; |
67 | 78 | FnPtr table[static_cast<int>(CPUCapability::NUM_OPTIONS)];
|
68 | 79 | };
|
69 | 80 |
|
70 | 81 |
|
71 |
| -#if defined(CPU_CAPABILITY) |
| 82 | +#if defined(CPU_CAPABILITY) || defined(__CUDACC__) |
72 | 83 |
|
73 |
| -constexpr CPUCapability CURRENT_CAPABILITY = CPUCapability::CPU_CAPABILITY; |
| 84 | +namespace { |
74 | 85 |
|
75 |
| -// Registers an implementation a kernel for the current CPU capability. |
76 |
| -template<typename FnPtr> |
| 86 | +template <typename FnPtr> |
77 | 87 | struct RegisterDispatch {
|
78 | 88 | RegisterDispatch(DispatchStub<FnPtr>& stub, FnPtr value) {
|
79 |
| - stub.table[static_cast<int>(CURRENT_CAPABILITY)] = value; |
| 89 | +#if defined(__CUDACC__) |
| 90 | + stub.cuda_dispatch_ptr = value; |
| 91 | +#else |
| 92 | + int cap = static_cast<int>(CPUCapability::CPU_CAPABILITY); |
| 93 | + AT_ASSERT(!stub.table[cap]) |
| 94 | + stub.table[cap] = value; |
| 95 | +#endif |
80 | 96 | }
|
81 | 97 | };
|
82 | 98 |
|
83 |
| -// We only define the stub once in the DEFAULT capability compilation |
84 |
| -#if defined(CPU_CAPABILITY_DEFAULT) |
85 |
| -#define _DEFINE_STUB(stub, fn) DispatchStub<decltype(fn)> stub |
86 |
| -#else |
87 |
| -#define _DEFINE_STUB(stub, fn) |
88 |
| -#endif |
| 99 | +} // anonymous namespace |
89 | 100 |
|
90 | 101 | #define REGISTER_DISPATCH(stub, fn) \
|
91 |
| - _DEFINE_STUB(stub, fn); \ |
92 | 102 | static RegisterDispatch<decltype(fn)> stub ## __register(stub, fn);
|
93 | 103 |
|
94 | 104 | #endif
|
|
0 commit comments