
Commit 5f2fe7a

[Quant][core][gpu][improvement] Refactored implementation for conv2d_cudnn to use packed parameters (Reland PR#73510)
Summary: This is a reland of #73510 -- please see that PR for the full summary and test plan. The only change here is that the removed op names are added to the allowlist in check_forward_backward_compatibility.py to get around the backward-compatibility failure introduced by the previous PR, which changed the names of the cudnn conv operations.

ghstack-source-id: d37922b
Pull Request resolved: #74220
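For context, the call pattern this refactor moves to (mirroring the test changes further down) looks roughly like this. A minimal sketch, not part of the commit: it assumes a CUDA build with cuDNN v8 and quantized-CUDA support, and the shapes, scales, and zero points are purely illustrative.

    import torch

    # Illustrative tensors; qint8 activations/weights in channels-last layout.
    x = torch.randn(1, 8, 32, 32, device="cuda")
    w = torch.randn(16, 8, 3, 3, device="cuda")
    x_int8 = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
    w_int8 = torch.quantize_per_tensor(w, 0.1, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
    stride, padding, dilation, groups = [1, 1], [1, 1], [1, 1], 1

    # Old (removed by this commit):
    #   torch.ops.quantized.conv2d_cudnn(x_int8, w_int8, None, stride, padding,
    #                                    dilation, groups, output_scale, output_zero_point)
    # New: pack the weight once, then call the generic packed-parameter op.
    packed = torch.ops.quantized.conv2d_prepack(w_int8, None, stride, padding, dilation, groups)
    out = torch.ops.quantized.conv2d(x_int8, packed, 0.1, 0)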
1 parent 3947038 commit 5f2fe7a

7 files changed: +424 −129 lines changed

aten/src/ATen/native/quantized/cudnn/Conv.cpp

Lines changed: 112 additions & 119 deletions
Large diffs are not rendered by default.
Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#if HAS_CUDNN_V8()

#include <ATen/ATen.h>
#include <torch/library.h>
#include <ATen/native/quantized/cudnn/cudnnpack_utils.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/QScheme.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#include <array>
#include <vector>

template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightCudnn<
    kSpatialDim>::
    prepack(
        at::Tensor weight,
        c10::optional<at::Tensor> bias,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose) {
  TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme()));
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "Weights are expected to have ",
      kSpatialDim + 2,
      " dimensions");
  TORCH_CHECK(
      stride.size() == kSpatialDim,
      "stride should contain ",
      kSpatialDim,
      " elements for ",
      kSpatialDim,
      "D convolution.");
  TORCH_CHECK(
      padding.size() == kSpatialDim,
      "quantized::conv_prepack (cudnn): Specify front/top/left padding only. "
      "end/bottom/right padding assumed to be equal to front/top/left");
  TORCH_CHECK(
      !transpose || output_padding.size() == kSpatialDim,
      "quantized::conv_prepack: Specify top/left output padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      dilation.size() == kSpatialDim,
      "quantized::conv_prepack (cudnn): dilation should contain ",
      kSpatialDim,
      " elements for ",
      kSpatialDim,
      "D convolution.");
  const int output_channels = transpose ? weight.size(1) * groups
                                        : weight.size(0);
  const auto qtype = weight.qscheme();
  if (bias.has_value()) {
    TORCH_CHECK(bias.value().dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias.value().size(0) == output_channels,
        "bias should have K elements: " + std::to_string(output_channels));
    // TODO: we create a broadcasted_bias tensor later so I think we don't need to make this contiguous here.
    // we will revisit this when nvidia adds proper support for broadcasting
    // bias_contig = bias->contiguous();
  }

  auto ret_ptr = c10::make_intrusive<PackedConvWeightCudnn<kSpatialDim>>(
      weight.contiguous(c10::MemoryFormat::ChannelsLast), // TODO: this assumes 2D I think. make it more general?
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose,
      qtype);
  return ret_ptr;
}

template
c10::intrusive_ptr<ConvPackedParamsBase<2>> PackedConvWeightCudnn<
    2>::
    prepack(
        at::Tensor weight,
        c10::optional<at::Tensor> bias_in,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose);

namespace at {
namespace native {
namespace {

template <int kSpatialDim = 2>
class QConvPackWeightInt8Cudnn final {
 public:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_conv(
      Tensor weight,
      c10::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    torch::List<int64_t> output_padding;
    output_padding.reserve(kSpatialDim);
    for (const auto idx : c10::irange(kSpatialDim)) {
      (void)idx; // Suppress unused variable warning
      output_padding.push_back((int64_t)0);
    }
    return _run(weight, bias, stride, padding, output_padding, dilation, groups,
                /*transpose=*/false);
  }

 private:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
      Tensor weight,
      c10::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose) {
    return PackedConvWeightCudnn<kSpatialDim>::prepack(
        weight, bias, stride, padding, output_padding, dilation, groups,
        transpose);
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8Cudnn<2>::run_conv));
}

} // namespace
} // namespace native
} // namespace at

#endif  // HAS_CUDNN_V8
#endif  // AT_CUDNN_ENABLED
#endif  // USE_CUDA
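The TORCH_LIBRARY_IMPL block above registers this prepack under the QuantizedCUDA dispatch key, so the existing quantized::conv2d_prepack operator routes here when given CUDA quantized tensors; no cudnn-specific operator name is introduced. A hedged sketch of what that enables (same cuDNN v8 / CUDA assumptions as the sketch near the top):

    import torch

    w_int8 = torch.quantize_per_tensor(
        torch.randn(16, 8, 3, 3, device="cuda"), 0.1, 0, torch.qint8
    ).contiguous(memory_format=torch.channels_last)

    # Same operator name as the CPU backends; the QuantizedCUDA key on w_int8
    # should select QConvPackWeightInt8Cudnn<2>::run_conv registered above.
    packed = torch.ops.quantized.conv2d_prepack(w_int8, None, [1, 1], [0, 0], [1, 1], 1)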
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#if HAS_CUDNN_V8()

#include <ATen/ATen.h>
#include <ATen/native/quantized/cudnn/cudnnpack_utils.h>
#include <ATen/native/quantized/packed_params.h>
#include <torch/library.h>

#include <tuple>

template <int kSpatialDim>
std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightCudnn<
    kSpatialDim>::unpack() {
  return std::tuple<at::Tensor, c10::optional<at::Tensor>>{orig_weight_, bias_};
}

template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightCudnn<
    2>::unpack();

#endif  // HAS_CUDNN_V8
#endif  // AT_CUDNN_ENABLED
#endif  // USE_CUDA
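unpack() simply hands back the quantized weight and optional bias captured at prepack time. A small round-trip sketch, reusing w_int8 from the sketch above and assuming the packed-params custom class exposes unpack() to Python (as the quantized conv modules rely on); treat it as illustrative rather than a guaranteed API:

    packed = torch.ops.quantized.conv2d_prepack(w_int8, None, [1, 1], [0, 0], [1, 1], 1)
    w_back, b_back = packed.unpack()
    # Values are unchanged even though prepack stores a channels-last copy.
    assert torch.equal(w_back.int_repr(), w_int8.int_repr())
    assert b_back is None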
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
#pragma once

#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#if HAS_CUDNN_V8()

#include <ATen/Tensor.h>
#include <ATen/native/quantized/packed_params.h>
#include <c10/core/QScheme.h>

template <int kSpatialDim = 2>
struct TORCH_API PackedConvWeightCudnn : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightCudnn(
      at::Tensor orig_weight,
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose,
      c10::QScheme q_scheme)
      : orig_weight_(std::move(orig_weight)),
        bias_(std::move(bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose),
        q_scheme_(q_scheme) {}

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) {
    TORCH_CHECK(false, "apply_dynamic is currently not reported");
  }

  at::Tensor apply_dynamic_relu(
      const at::Tensor& input,
      bool reduce_range) {
    TORCH_CHECK(false, "apply_dynamic_relu is currently not reported");
  }

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  const float* GetBiasData(at::Tensor* bias);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return transpose_;
  }

 private:
  at::Tensor orig_weight_;
  c10::optional<at::Tensor> bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  bool transpose_;
  c10::QScheme q_scheme_;

  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);

  template <bool ReluFused>
  void apply_impl_helper(
      const at::Tensor& quantized_output,
      const at::Tensor& input,
      double bias_multiplier,
      double requantize_multiplier);
};

#endif  // HAS_CUDNN_V8
#endif  // AT_CUDNN_ENABLED
#endif  // USE_CUDA

aten/src/ATen/native/quantized/library.cpp

Lines changed: 0 additions & 5 deletions
@@ -188,11 +188,6 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor"));
-
-  // quantized ops implemented in cudnn, with QuantizedCUDA dispatch
-  // TODO: use the same signature as quantized::conv2d
-  m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_cudnn(Tensor act, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor"));
-  m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_relu_cudnn(Tensor act, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor"));
 }
 
 // According to #33294: The "_" prefix registration will be

test/forward_backward_compatibility/check_forward_backward_compatibility.py

Lines changed: 2 additions & 0 deletions
@@ -111,6 +111,8 @@
     ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
     ("aten::_scatter_reduce.two", datetime.date(9999, 1, 1)),
     ("aten::_s_where", datetime.date(2022, 9, 30)),
+    ("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)),
+    ("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)),
 ]
 
 ALLOW_LIST_COMPILED = [

test/quantization/core/test_quantized_op.py

Lines changed: 6 additions & 5 deletions
@@ -4249,9 +4249,9 @@ def test_qconv2d_cudnn(
         dilations = (dilation, dilation)
 
         if use_relu:
-            qconv = torch.ops.quantized.conv2d_relu_cudnn
+            qconv = torch.ops.quantized.conv2d_relu
         else:
-            qconv = torch.ops.quantized.conv2d_cudnn
+            qconv = torch.ops.quantized.conv2d
         conv_op = torch.nn.Conv2d(
             input_channels,
             output_channels,
@@ -4262,7 +4262,7 @@ def test_qconv2d_cudnn(
             groups,
         ).to(torch.device("cuda"))
         self._test_qconv_impl(
-            qconv, None, conv_op, batch_size,
+            qconv, torch.ops.quantized.conv2d_prepack, conv_op, batch_size,
             input_channels_per_group, (height, width),
             output_channels_per_group, groups, kernels, strides, pads, None,
             dilations, X_scale, X_zero_point, W_scale, W_zero_point,
@@ -4338,13 +4338,14 @@ def trace_handler(p):
         weight_int8 = torch.quantize_per_tensor(weight, 1, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
         scale = 1.0
         zero_point = 0
-        conv_op = torch.ops.quantized.conv2d_cudnn
+        conv_op = torch.ops.quantized.conv2d
+        weight_prepacked = torch.ops.quantized.conv2d_prepack(weight_int8, None, stride, padding, dilation, groups)
         with profile(
                 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 schedule=my_schedule,
                 on_trace_ready=trace_handler) as prof:
             for i in range(30):
-                conv_op(input_int8, weight_int8, None, stride, padding, dilation, groups, scale, zero_point)
+                conv_op(input_int8, weight_prepacked, scale, zero_point)
             prof.step()
 
         print("int8 benchmark result:")