    |   1 | +/*
    |   2 | +The dispatch registrations at the end of this file apply to the fbgemm, qnnpack, and cudnn backends.
    |   3 | +The correct unpack backend function is selected through runtime polymorphism via the packed_weight pointer,
    |   4 | +which has type intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> and at runtime points to a PackedConvWeightsQnnp,
    |   5 | +PackedConvWeight (fbgemm), or PackedConvWeightsCudnn object, all of which inherit from ConvPackedParamsBase.
    |   6 | +The implementations of the unpack functions can be found in /cpu/qconv_unpack_impl.cpp for fbgemm and qnnpack,
    |   7 | +and in /cudnn/conv_unpack_impl.cpp for cudnn.
    |   8 | +*/
    |   9 | +
  1 |  10 | #include <tuple>
  2 |     | -#include <vector>
  3 |  11 |
  4 |  12 | #include <ATen/ATen.h>
  5 |  13 | #include <torch/library.h>
  8 |  16 | #include <ATen/native/quantized/cpu/quant_utils.h>
  9 |  17 | #include <ATen/native/quantized/packed_params.h>
 10 |  18 |
 11 |     | -#ifdef USE_FBGEMM
 12 |     | -template <int kSpatialDim>
 13 |     | -std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeight<
 14 |     | -    kSpatialDim>::unpack() {
 15 |     | -  auto* packed_weights_p = w.get();
 16 |     | -  // output channels
 17 |     | -  const int output_channels = packed_weights_p->outputChannels();
 18 |     | -  const int input_channels = packed_weights_p->inputChannels();
 19 |     | -  const int groups = packed_weights_p->groups();
 20 |     | -
 21 |     | -  const int kernel_d = kSpatialDim == 2 ? 1 : kernel[0];
 22 |     | -  // R (kernel height)
 23 |     | -  const int kernel_h = kernel[kSpatialDim - 2];
 24 |     | -  // S (kernel width)
 25 |     | -  const int kernel_w = kernel[kSpatialDim - 1];
 26 |     | -
 27 |     | -  const int C_per_G = input_channels / groups;
 28 |     | -
 29 |     | -  // Tensor for unpacked weights
 30 |     | -  // Unpacked format would be physical KRS(C/G) but logical KCRS (channels
 31 |     | -  // first) because that's how
 32 |     | -  // ChannelsLast3d is not available now.FBGEMM stores the weights
 33 |     | -  // TODO: Unify 2d and 3d when ChannelsLast3d is ready.
 34 |     | -  at::Tensor unpacked_weights;
 35 |     | -  if (q_scheme == c10::kPerTensorAffine) {
 36 |     | -    unpacked_weights = kSpatialDim == 2
 37 |     | -        ? at::_empty_affine_quantized(
 38 |     | -              {output_channels, C_per_G, kernel_h, kernel_w},
 39 |     | -              device(c10::kCPU)
 40 |     | -                  .dtype(c10::kQInt8)
 41 |     | -                  .memory_format(c10::MemoryFormat::ChannelsLast),
 42 |     | -              w_scale[0],
 43 |     | -              w_zp[0],
 44 |     | -              c10::nullopt)
 45 |     | -        : at::native::fbgemm_utils::
 46 |     | -              MakeEmptyAffineQuantizedChannelsLast3dTensor(
 47 |     | -                  output_channels,
 48 |     | -                  C_per_G,
 49 |     | -                  kernel_d,
 50 |     | -                  kernel_h,
 51 |     | -                  kernel_w,
 52 |     | -                  device(c10::kCPU).dtype(c10::kQInt8),
 53 |     | -                  w_scale[0],
 54 |     | -                  w_zp[0]);
 55 |     | -  } else if (q_scheme == c10::kPerChannelAffine) {
 56 |     | -    TORCH_CHECK(
 57 |     | -        !transpose(),
 58 |     | -        "Per Channel Quantization is currently disabled for transposed conv");
 59 |     | -    auto scales = at::from_blob(
 60 |     | -        w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat));
 61 |     | -    auto zero_points = at::from_blob(
 62 |     | -        w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kInt));
 63 |     | -    unpacked_weights = kSpatialDim == 2
 64 |     | -        ? at::_empty_per_channel_affine_quantized(
 65 |     | -              {output_channels, C_per_G, kernel_h, kernel_w},
 66 |     | -              scales.toType(c10::kDouble),
 67 |     | -              zero_points.toType(c10::kLong),
 68 |     | -              0, /* The output channel axis is 0 */
 69 |     | -              device(c10::kCPU).dtype(c10::kQInt8),
 70 |     | -              c10::MemoryFormat::ChannelsLast)
 71 |     | -        : at::native::fbgemm_utils::
 72 |     | -              MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor(
 73 |     | -                  output_channels,
 74 |     | -                  C_per_G,
 75 |     | -                  kernel_d,
 76 |     | -                  kernel_h,
 77 |     | -                  kernel_w,
 78 |     | -                  device(c10::kCPU).dtype(c10::kQInt8),
 79 |     | -                  scales.toType(c10::kDouble),
 80 |     | -                  zero_points.toType(c10::kLong));
 81 |     | -  } else {
 82 |     | -    TORCH_CHECK(false, "Unsupported qscheme: ", toString(q_scheme));
 83 |     | -  }
 84 |     | -  int8_t* unpacked_weights_p =
 85 |     | -      reinterpret_cast<int8_t*>(unpacked_weights.data_ptr<c10::qint8>());
 86 |     | -  packed_weights_p->unpack(unpacked_weights_p);
 87 |     | -  if (transpose()) {
 88 |     | -    unpacked_weights =
 89 |     | -        at::native::fbgemm_utils::TransposeConvTensorUnpackConversion<
 90 |     | -            kSpatialDim>(unpacked_weights, groups);
 91 |     | -  }
 92 |     | -  return std::tuple<at::Tensor, c10::optional<at::Tensor>>(
 93 |     | -      unpacked_weights, bias);
 94 |     | -}
 95 |     | -
 96 |     | -template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeight<
 97 |     | -    2>::unpack();
 98 |     | -template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeight<
 99 |     | -    3>::unpack();
100 |     | -#endif // USE_FBGEMM
101 |     | -
102 |     | -#ifdef USE_PYTORCH_QNNPACK
103 |     | -template <int kSpatialDim>
104 |     | -std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsQnnp<
105 |     | -    kSpatialDim>::unpack() {
106 |     | -  TORCH_CHECK(
107 |     | -      kSpatialDim == 2,
108 |     | -      "QNNPACK only supports conv2d_unpack right "
109 |     | -      "now.");
110 |     | -  TORCH_CHECK(
111 |     | -      orig_weight.defined(),
112 |     | -      "Cannot unpack weights. "
113 |     | -      "Call at::globalContext()::setReleaseOriginalWeights(false) before packing or loading to enable unpacking.");
114 |     | -  return std::tuple<at::Tensor, c10::optional<at::Tensor>>(orig_weight, bias);
115 |     | -}
116 |     | -
117 |     | -template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsQnnp<
118 |     | -    2>::unpack();
119 |     | -template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightsQnnp<
120 |     | -    3>::unpack();
121 |     | -#endif // USE_PYTORCH_QNNPACK
122 |     | -
123 |  19 | namespace at {
124 |  20 | namespace native {
125 |  21 | namespace {
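To make the runtime-polymorphism dispatch described in the new header comment concrete, here is a minimal standalone sketch. It is not PyTorch's actual implementation: the real code holds the packed weights in a c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> and returns std::tuple<at::Tensor, c10::optional<at::Tensor>>, whereas the sketch substitutes std::shared_ptr, std::optional, a placeholder Tensor type, and a hypothetical run_unpack() helper, keeping only the class names used in the diff.

```cpp
// Minimal, standalone sketch of the dispatch pattern the header comment
// describes. NOT the real PyTorch types: at::Tensor, c10::intrusive_ptr and
// c10::optional are stubbed with a placeholder Tensor, std::shared_ptr and
// std::optional, and run_unpack() is a hypothetical stand-in for the
// registered unpack operator.
#include <iostream>
#include <memory>
#include <optional>
#include <tuple>

struct Tensor {};  // placeholder for at::Tensor

// Stand-in for ConvPackedParamsBase<kSpatialDim>: the common interface every
// backend's packed-weight class implements.
template <int kSpatialDim>
struct ConvPackedParamsBase {
  virtual ~ConvPackedParamsBase() = default;
  virtual std::tuple<Tensor, std::optional<Tensor>> unpack() = 0;
};

// Stand-ins for the per-backend classes (fbgemm / qnnpack); cudnn would look
// the same. The real bodies live in the backend-specific *_unpack_impl.cpp files.
template <int kSpatialDim>
struct PackedConvWeight : ConvPackedParamsBase<kSpatialDim> {
  std::tuple<Tensor, std::optional<Tensor>> unpack() override {
    std::cout << "fbgemm unpack\n";
    return {Tensor{}, std::nullopt};
  }
};

template <int kSpatialDim>
struct PackedConvWeightsQnnp : ConvPackedParamsBase<kSpatialDim> {
  std::tuple<Tensor, std::optional<Tensor>> unpack() override {
    std::cout << "qnnpack unpack\n";
    return {Tensor{}, std::nullopt};
  }
};

// Hypothetical stand-in for the registered operator: it only sees the base
// pointer, so the virtual call selects whichever backend packed the weights.
template <int kSpatialDim>
std::tuple<Tensor, std::optional<Tensor>> run_unpack(
    const std::shared_ptr<ConvPackedParamsBase<kSpatialDim>>& packed_weight) {
  return packed_weight->unpack();
}

int main() {
  std::shared_ptr<ConvPackedParamsBase<2>> w =
      std::make_shared<PackedConvWeightsQnnp<2>>();
  run_unpack<2>(w);                             // prints "qnnpack unpack"
  w = std::make_shared<PackedConvWeight<2>>();
  run_unpack<2>(w);                             // prints "fbgemm unpack"
}
```

Because the registered unpack kernels only ever see the base-class pointer, moving the per-backend unpack bodies into /cpu/qconv_unpack_impl.cpp and /cudnn/conv_unpack_impl.cpp does not change the dispatch: each derived class still supplies its own virtual unpack(), and the call site stays backend-agnostic.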