
Commit 28990f3

csummersea authored and facebook-github-bot committed
fix bug when falling back to acc32 when weight is prepacked (pytorch#18881)
Summary:
Pull Request resolved: pytorch#18881
Pull Request resolved: pytorch#18878
When the weight is prepacked and it doesn't contain a prepacked weight for acc32, we shouldn't fall back to acc32.
TODO: add unit tests with better coverage
Reviewed By: feiyu1990
Differential Revision: D14778810
fbshipit-source-id: d49a8c4b7c815ab29b77feb53ee730ad63780488
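The core of the fix is the eligibility check described above: falling back to 32-bit accumulation is only possible when the layout is NHWC and, if the filter arrives prepacked, the packed blob actually carries an acc32 weight. Below is a minimal, self-contained sketch of that decision; PackedWeightBlob and CanFallBackToAcc32 are hypothetical stand-ins for the real Int8ConvDNNLowPPackedWeightBlob and the inline condition in the diff further down.

    #include <memory>

    // Hypothetical stand-in for Int8ConvDNNLowPPackedWeightBlob; only the two
    // members that matter for the fallback decision are modeled here.
    struct PackedWeightBlob {
      std::shared_ptr<int> W;        // prepacked weight for acc32 (may be missing)
      std::shared_ptr<int> W_acc16;  // prepacked weight for acc16 (may be missing)
    };

    // Mirrors the gating condition this commit adds: fall back to 32-bit
    // accumulation only when the layout is NHWC and, if the weight is
    // prepacked, the packed blob also contains the acc32 weight.
    bool CanFallBackToAcc32(bool is_nhwc, const PackedWeightBlob* packed) {
      return is_nhwc && (packed == nullptr || packed->W != nullptr);
    }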
1 parent c1790fa · commit 28990f3

4 files changed, +210 −122 lines changed

caffe2/quantization/server/conv_dnnlowp_acc16_op.cc

Lines changed: 92 additions & 65 deletions
@@ -17,7 +17,6 @@
 C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
 C10_DECLARE_int32(caffe2_dnnlowp_copy_to_32bit_frequency);
 C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer);
-
 // Thresholds to fallback to 32-bit accumulation when 16-bit accumulation
 // doesn't provide performance benefits.
 C10_DEFINE_double(
@@ -62,43 +61,26 @@ ConvDNNLowPAcc16Op<ReluFused>::ConvDNNLowPAcc16Op(
 template <bool ReluFused>
 bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
   if (fallback_to_32_bit_accumulation_) {
-    return true;
-  }
-
-  if (!BaseType::GetQuantizationParameters_()) {
-    return false;
-  }
-
-  if (!Wq_acc16_packed_ &&
-      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
-    CAFFE_ENFORCE_EQ(
-        this->order_,
-        StorageOrder::NHWC,
-        "Pre-packed weight only works with NHWC layout");
-    // If the input is already packed
-    const auto& packed_filter =
-        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
-    Wq_outlier_ = packed_filter.W_outlier;
-    Wq_acc16_packed_ = packed_filter.W_acc16;
-
-    if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
-      LOG(WARNING)
-          << "nbits_in_non_outlier in packed weight "
-          << packed_filter.nbits_in_non_outlier
-          << " doesn't match with nbits_in_non_outlier specified in operator "
-          << nbits_in_non_outlier_;
-    }
-
-    first_invocation_ = false;
-    return true;
+    // Short cut if we already know we are falling back to acc32
+    return BaseType::GetQuantizationParameters_();
   }
 
   int kernel_dim = this->KernelDim_();
   const auto& filter = InputTensorCPU_(FILTER);
   int num_out_channels = filter.dim32(0);
 
   // Check if we should fallback to 32-bit accumulation
-  if (this->order_ == StorageOrder::NHWC) {
+  // We should do this before GetQuantizationParameters_ to make sure
+  // GetQuantizationParameters_ initialize things like Wq_packed_ for acc32
+  // properly.
+
+  // We can't fallback if layout is not NHWC or
+  // if weight is prepacked and the prepacked weight doesn't have acc32.
+  bool can_fallback_to_32_bit_accumulation =
+      this->order_ == StorageOrder::NHWC &&
+      (!this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) ||
+       this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER).W);
+  if (can_fallback_to_32_bit_accumulation) {
     const Tensor& X = InputTensorCPU_(INPUT);
     int N = X.dim32(0);
 
@@ -121,31 +103,71 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
    }
 
    if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) {
-      LOG(INFO) << "M " << N * output_image_size
-                << " of Conv layer with weight blob "
-                << this->debug_def().input(1) << " is smaller than threshold "
-                << FLAGS_caffe2_dnnlowp_acc16_m_threshold
-                << " . Falling back to acc32";
+      LOG_FIRST_N(INFO, 10)
+          << "M " << N * output_image_size << " of Conv layer with weight blob "
+          << this->debug_def().input(FILTER) << " is smaller than threshold "
+          << FLAGS_caffe2_dnnlowp_acc16_m_threshold
+          << " . Falling back to acc32";
+      fallback_to_32_bit_accumulation_ = true;
+    }
+    if (!fallback_to_32_bit_accumulation_ &&
+        num_out_channels / group_ < acc16_n_threshold) {
+      LOG_FIRST_N(INFO, 10)
+          << "N " << num_out_channels / group_
+          << " of Conv layer with weight blob "
+          << this->debug_def().input(FILTER) << " is smaller than threshold "
+          << acc16_n_threshold << " . Falling back to acc32";
      fallback_to_32_bit_accumulation_ = true;
-      return true;
    }
-    if (num_out_channels / group_ < acc16_n_threshold) {
-      LOG(INFO) << "N " << num_out_channels / group_
-                << " of Conv layer with weight blob "
-                << this->debug_def().input(1) << " is smaller than threshold "
-                << acc16_n_threshold << " . Falling back to acc32";
+    if (!fallback_to_32_bit_accumulation_ && kernel_dim < acc16_k_threshold) {
+      LOG_FIRST_N(INFO, 10)
+          << "K " << kernel_dim << " of Conv layer with weight blob "
+          << this->debug_def().input(FILTER) << " is smaller than threshold "
+          << acc16_k_threshold << " . Falling back to acc32";
      fallback_to_32_bit_accumulation_ = true;
-      return true;
    }
-    if (kernel_dim < acc16_k_threshold) {
-      LOG(INFO) << "K " << kernel_dim << " of Conv layer with weight blob "
-                << this->debug_def().input(1) << " is smaller than threshold "
-                << acc16_k_threshold << " . Falling back to acc32";
+    if (!fallback_to_32_bit_accumulation_ &&
+        this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) &&
+        !this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER)
+             .W_acc16) {
+      LOG_FIRST_N(INFO, 10)
+          << "Falling back to acc32 because packed weight for acc16 is not "
+             "available";
      fallback_to_32_bit_accumulation_ = true;
-      return true;
    }
  }
 
+  if (!BaseType::GetQuantizationParameters_()) {
+    return false;
+  }
+
+  if (fallback_to_32_bit_accumulation_) {
+    return true;
+  }
+
+  if (!Wq_acc16_packed_ &&
+      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+    CAFFE_ENFORCE_EQ(
+        this->order_,
+        StorageOrder::NHWC,
+        "Pre-packed weight only works with NHWC layout");
+    // If the input is already packed
+    const auto& packed_filter =
+        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+    Wq_outlier_ = packed_filter.W_outlier;
+    Wq_acc16_packed_ = packed_filter.W_acc16;
+
+    if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
+      LOG_FIRST_N(WARNING, 10)
+          << "nbits_in_non_outlier in packed weight "
+          << packed_filter.nbits_in_non_outlier
+          << " doesn't match with nbits_in_non_outlier specified in operator "
+          << nbits_in_non_outlier_;
+    }
+    first_invocation_ = false;
+    return true;
+  }
+
  // Separate out outliers
  if (!Wq_outlier_ && this->order_ == StorageOrder::NHWC &&
      nbits_in_non_outlier_ < 8) {
@@ -159,20 +181,24 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
        W_quantized_));
    int outlier_cnt = Wq_outlier_->ColPtr()[num_out_channels];
 
-    LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
-              << this->debug_def().input(1) << " is "
-              << static_cast<float>(outlier_cnt) / W_quantized_.size();
-    LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_
-              << " copy_to_32bit_frequency " << copy_to_32bit_frequency_;
-
-    if (static_cast<float>(outlier_cnt) / W_quantized_.size() >
-        FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
-      LOG(INFO) << "Density of outliers is higher than threshold "
-                << FLAGS_caffe2_dnnlowp_acc16_density_threshold
-                << " . Falling back to acc32";
+    LOG_FIRST_N(INFO, 10)
+        << "Proportion of outlier for Conv layer with weight blob "
+        << this->debug_def().input(FILTER) << " is "
+        << static_cast<float>(outlier_cnt) / W_quantized_.size();
+    LOG_FIRST_N(INFO, 10) << "nbits_in_non_outlier " << nbits_in_non_outlier_
+                          << " copy_to_32bit_frequency "
+                          << copy_to_32bit_frequency_;
+
+    if (can_fallback_to_32_bit_accumulation &&
+        static_cast<float>(outlier_cnt) / W_quantized_.size() >
+            FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
+      LOG_FIRST_N(INFO, 10) << "Density of outliers is higher than threshold "
+                            << FLAGS_caffe2_dnnlowp_acc16_density_threshold
+                            << " . Falling back to acc32";
      fallback_to_32_bit_accumulation_ = true;
      Wq_outlier_.reset();
-      return true;
+      // We need to call GetQuantizationParameters_ again to pack for acc32
+      return BaseType::GetQuantizationParameters_();
    }
  }
 
@@ -193,17 +219,18 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
      static int log_occurences = 0;
      if (log_occurences < 32) {
        ++log_occurences;
-        LOG(WARNING) << "Conv with weight " << this->debug_def().input(FILTER)
-                     << " falls back to slow path because " << reason;
+        LOG_FIRST_N(WARNING, 10)
+            << "Conv with weight " << this->debug_def().input(FILTER)
+            << " falls back to slow path because " << reason;
      }
    }
  }
  if (nbits_in_non_outlier_ < 8 && this->order_ != StorageOrder::NHWC) {
    static int log_occurences = 0;
    if (log_occurences < 32) {
      ++log_occurences;
-      LOG(WARNING) << "Outlier-aware quantization only supports "
-                   "NHWC layout";
+      LOG_FIRST_N(WARNING, 10) << "Outlier-aware quantization only supports "
+                                  "NHWC layout";
    }
  }
  first_invocation_ = false;
@@ -359,7 +386,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
    static int log_occurences = 0;
    if (log_occurences < 32) {
      ++log_occurences;
-      LOG(WARNING)
+      LOG_FIRST_N(WARNING, 10)
          << "Consider using DNNLOWP instead of DNNLOWP_ACC16 engine since "
             "we're falling back to a slow path because of NCHW layout";
    }
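Besides the control-flow change, this file also switches its fallback diagnostics from LOG(INFO)/LOG(WARNING) to LOG_FIRST_N, so an operator that runs thousands of times only reports the first few occurrences. A minimal sketch of that behaviour, assuming the glog-style LOG_FIRST_N macro used above (the program below is illustrative, not part of this commit):

    #include <glog/logging.h>

    int main(int argc, char* argv[]) {
      google::InitGoogleLogging(argv[0]);
      FLAGS_logtostderr = 1;
      for (int i = 0; i < 1000; ++i) {
        // Only the first 10 executions of this statement are emitted, so a
        // message printed on every invocation cannot flood the log.
        LOG_FIRST_N(INFO, 10) << "Falling back to acc32 (occurrence " << i << ")";
      }
      return 0;
    }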

caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py

Lines changed: 10 additions & 8 deletions
@@ -7,15 +7,19 @@
 import numpy as np
 from caffe2.python import core, dyndep, utils, workspace
 from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import (
-    check_quantized_results_close,
-    generate_conv_inputs,
-)
+from dnnlowp_test_utils import check_quantized_results_close
 from hypothesis import assume, given
 
 
 dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
-workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
+workspace.GlobalInit(
+    [
+        "caffe2",
+        "--caffe2_omp_num_threads=11",
+        # Increase this threshold to test acc16 with randomly generated data
+        "--caffe2_dnnlowp_acc16_density_threshold=0.9",
+    ]
+)
 
 
 class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
@@ -254,9 +258,7 @@ def test_dnnlowp_conv_acc16_outlier(
        W_min = -100
        W_max = W_min + 255
        W = (
-            np.random.rand(
-                output_channels, kernel, kernel, input_channels_per_group
-            )
+            np.random.rand(output_channels, kernel, kernel, input_channels_per_group)
            * 4
            - 2
            + W_min
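The new --caffe2_dnnlowp_acc16_density_threshold=0.9 argument added to both test files corresponds to the C++ flag read as FLAGS_caffe2_dnnlowp_acc16_density_threshold in conv_dnnlowp_acc16_op.cc; raising it through workspace.GlobalInit keeps randomly generated weights on the acc16 path instead of tripping the density fallback. A rough sketch of the define/read pattern, using a hypothetical flag name (example_acc16_density_threshold) rather than the real one:

    #include <c10/util/Flags.h>

    // Defined once in a .cc file; the second argument is the default value and
    // the third is the help string (same pattern as the real
    // caffe2_dnnlowp_acc16_density_threshold flag in conv_dnnlowp_acc16_op.cc).
    C10_DEFINE_double(
        example_acc16_density_threshold,
        0.05,
        "Outlier density above which the operator falls back to acc32");

    bool ShouldFallBackToAcc32(float outlier_density) {
      // Overridable from the command line, or from Python via
      // workspace.GlobalInit(["caffe2", "--example_acc16_density_threshold=0.9"]).
      return outlier_density > FLAGS_example_acc16_density_threshold;
    }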

caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py

Lines changed: 11 additions & 11 deletions
@@ -7,15 +7,19 @@
 import numpy as np
 from caffe2.python import core, dyndep, utils, workspace
 from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import (
-    check_quantized_results_close,
-    generate_conv_inputs,
-)
+from dnnlowp_test_utils import check_quantized_results_close
 from hypothesis import assume, given
 
 
 dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
-workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
+workspace.GlobalInit(
+    [
+        "caffe2",
+        "--caffe2_omp_num_threads=11",
+        # Increase this threshold to test acc16 with randomly generated data
+        "--caffe2_dnnlowp_acc16_density_threshold=0.9",
+    ]
+)
 
 
 class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
@@ -224,9 +228,7 @@ def test_groupwise_dnnlowp_conv_acc16_outlier(
        W_min = -100
        W_max = W_min + 255
        W = (
-            np.random.rand(
-                output_channels, kernel, kernel, input_channels_per_group
-            )
+            np.random.rand(output_channels, kernel, kernel, input_channels_per_group)
            * 4
            - 2
            + W_min
@@ -237,9 +239,7 @@ def test_groupwise_dnnlowp_conv_acc16_outlier(
        for g in range(group):
            W[g * output_channels_per_group, 0, 0, 0] = W_min
            W[g * output_channels_per_group + 1, 0, 0, 0] = W_max
-            W[
-                g * output_channels_per_group : (g + 1) * output_channels_per_group,
-            ] += g
+            W[g * output_channels_per_group : (g + 1) * output_channels_per_group,] += g
 
        if order == "NCHW":
            X = utils.NHWC2NCHW(X)
