
Commit 0f58d20

jianyuh authored and facebook-github-bot committed
Add quantized::fbgemm_linear_unpack operator for serialization (#97)
Summary:
Pull Request resolved: pytorch/FBGEMM#97
Pull Request resolved: pytorch#20721

- FBGEMM: Add an unpack function to the PackBMatrix class that unpacks the pmat buffer into origin_buf (used by serialization to recover the weight matrix).
- PyTorch Quantizer: Add the quantized::fbgemm_linear_unpack operator for serialization.

Reviewed By: zafartahirov

Differential Revision: D15314568

fbshipit-source-id: 12080c8887ce31dc849d23e132ae1766ac319407
1 parent 4b576e5 commit 0f58d20
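To illustrate what the new operator is for, here is a minimal Python sketch of the pack/unpack round trip, mirroring the test added in test_quantized.py below. The weight shape and quantization parameters are illustrative, and an FBGEMM-enabled CPU build is assumed:

import numpy as np
import torch

# Quantize a 2-D float weight (shape and qparams chosen arbitrarily).
W = torch.randn(4, 8)
W_q = torch.quantize_linear(W, scale=0.1, zero_point=0, dtype=torch.qint8)

# Pack the weight for FBGEMM, then unpack it to recover the original
# quantized weight matrix (the path serialization relies on).
W_prepack = torch.ops.quantized.fbgemm_linear_prepack(W_q)
W_q_origin = torch.ops.quantized.fbgemm_linear_unpack(W_prepack)

# The unpacked weight matches the original integer values and qparams.
np.testing.assert_equal(W_q.int_repr().numpy(), W_q_origin.int_repr().numpy())
assert W_q.q_scale() == W_q_origin.q_scale()
assert W_q.q_zero_point() == W_q_origin.q_zero_point()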

File tree

5 files changed: +104 -12 lines changed


aten/src/ATen/native/quantized/cpu/fbgemm_utils.h

Lines changed: 16 additions & 9 deletions
@@ -13,7 +13,7 @@
 // of the A rows. The column offsets are needed for the asymmetric quantization
 // (affine quantization) of input matrix.
 // Note that in JIT mode we can think of a way to fuse col_offsets with bias.
-struct FBGEMM_API PackedFCWeight {
+struct FBGEMM_API PackedLinearWeight {
   std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
   std::vector<int32_t> col_offsets;
   float w_scale;
@@ -28,17 +28,24 @@ struct FBGEMM_API PackedConvWeight {
   int32_t w_zp;
 };
 
-// Convert the weight from uint8 to int8.
+// PackWeight: Convert the weight from uint8 to int8.
 static void convert_uint8_int8(
-    int K,
-    int N,
+    int len,
     const uint8_t* src_uint8,
     int8_t* dst_int8) {
-  for (size_t i = 0; i < N; ++i) {
-    for (size_t j = 0; j < K; ++j) {
-      dst_int8[i * K + j] =
-          static_cast<int8_t>(static_cast<int32_t>(src_uint8[i * K + j]) - 128);
-    }
+  for (int i = 0; i < len; ++i) {
+    dst_int8[i] = static_cast<int8_t>(static_cast<int32_t>(src_uint8[i]) - 128);
+  }
+}
+
+// UnpackWeight: Convert the weight from int8 to uint8.
+static void convert_int8_uint8(
+    int len,
+    const int8_t* src_int8,
+    uint8_t* dst_uint8) {
+  for (int i = 0; i < len; ++i) {
+    dst_uint8[i] =
+        static_cast<uint8_t>(static_cast<int32_t>(src_int8[i]) + 128);
   }
 }
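The pack/unpack helpers above are a plain offset by 128 in opposite directions, so convert_int8_uint8 exactly inverts convert_uint8_int8. An illustrative numpy sketch of the same transformation (not code from this commit):

import numpy as np

src_uint8 = np.array([0, 127, 128, 255], dtype=np.uint8)

# convert_uint8_int8 direction: subtract 128 (uint8 -> int8).
as_int8 = (src_uint8.astype(np.int32) - 128).astype(np.int8)   # [-128, -1, 0, 127]

# convert_int8_uint8 direction: add 128 (int8 -> uint8).
back = (as_int8.astype(np.int32) + 128).astype(np.uint8)

assert np.array_equal(back, src_uint8)  # lossless round trip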

aten/src/ATen/native/quantized/cpu/qlinear.cpp

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ class QLinearInt8 final : public c10::OperatorKernel {
     }
 
     // Pull out the PackBMatrix and col_offsets instance from the owning tensor.
-    auto& pack_ptr = cpp_custom_type_hack::cast<PackedFCWeight>(packed_weight);
+    auto& pack_ptr =
+        cpp_custom_type_hack::cast<PackedLinearWeight>(packed_weight);
     auto packB = pack_ptr.w.get();
     // packB->printPackedMatrix("packedB inside fbgemm_linear (QLinearInt8): ");
     auto& col_offsets = pack_ptr.col_offsets;

aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp

Lines changed: 6 additions & 2 deletions
@@ -11,7 +11,7 @@
 namespace caffe2 {
 #ifdef USE_FBGEMM
 // Required for cpp_custom_type_hack to work
-CAFFE_KNOWN_TYPE(PackedFCWeight);
+CAFFE_KNOWN_TYPE(PackedLinearWeight);
 #endif // USE_FBGEMM
 } // namespace caffe2
 
@@ -42,6 +42,10 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel {
   }
 
   at::Tensor operator()(at::Tensor weight) {
+    TORCH_CHECK(
+        weight.dim() == 2,
+        "The weight tensor for quantized::fbgemm_linear_prepack should be 2-dimensional.");
+
     auto N = weight.size(0);
     auto K = weight.size(1);
 
@@ -61,7 +65,7 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel {
       /*B_zero_point=*/weight_zero_point_int32,
      /*col_offsets=*/col_offsets.data());
 
-    auto ret_ptr = guts::make_unique<PackedFCWeight>(PackedFCWeight{
+    auto ret_ptr = guts::make_unique<PackedLinearWeight>(PackedLinearWeight{
        guts::make_unique<fbgemm::PackBMatrix<int8_t>>(
            /*trans=*/fbgemm::matrix_op_t::Transpose,
            /*nRow=*/K,
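With the new TORCH_CHECK, a weight that is not 2-dimensional is rejected up front. A sketch of the expected user-visible behavior, assuming an FBGEMM-enabled build (TORCH_CHECK failures surface as RuntimeError in Python):

import torch

# A 1-D quantized weight should be refused by the prepack operator.
W_1d = torch.quantize_linear(torch.randn(8), scale=0.1, zero_point=0, dtype=torch.qint8)
try:
    torch.ops.quantized.fbgemm_linear_prepack(W_1d)
except RuntimeError as err:
    assert "2-dimensional" in str(err)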
aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+#include <ATen/ATen.h>
+#include <ATen/core/Type.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/cpp_custom_type_hack.h>
+#include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/quantized/Quantizer.h>
+
+namespace at {
+namespace native {
+namespace {
+
+class QLinearUnpackWeightInt8 final : public c10::OperatorKernel {
+ public:
+#ifdef USE_FBGEMM
+  at::Tensor operator()(at::Tensor packed_weight) {
+    // Pull out the PackBMatrix instance from the owning tensor.
+    auto& pack_ptr =
+        cpp_custom_type_hack::cast<PackedLinearWeight>(packed_weight);
+    auto packB = pack_ptr.w.get();
+
+    int64_t N = static_cast<int64_t>(packB->numCols());
+    int64_t K = static_cast<int64_t>(packB->numRows());
+
+    float weight_scale_float = pack_ptr.w_scale;
+    int32_t weight_zero_point_int32 = pack_ptr.w_zp;
+
+    auto weight_origin = _empty_affine_quantized(
+        {N, K},
+        at::device(kCPU).dtype(kQInt8),
+        weight_scale_float,
+        weight_zero_point_int32);
+    int8_t* weight_ptr_int8 =
+        reinterpret_cast<int8_t*>(weight_origin.data<c10::qint8>());
+
+    // packB->printPackedMatrix("packedB inside fbgemm_unpack
+    // (QLinearUnpackWeightInt8): ");
+    packB->unpack(weight_ptr_int8);
+
+    return weight_origin;
+  }
+#else // USE_FBGEMM
+  at::Tensor operator()(at::Tensor /* weight */
+  ) {
+    // We make a strong guarantee that models using these operators will have
+    // the same numerics across different machines. Therefore, we do not provide
+    // a fallback path and rather fail loudly if we cannot run FBGEMM.
+    TORCH_CHECK(
+        false, "This PyTorch installation was not built with FBGEMM operators");
+  }
+#endif // USE_FBGEMM
+};
+
+static auto registry = c10::RegisterOperators().op(
+    "quantized::fbgemm_linear_unpack(Tensor W_prepack) -> Tensor W_origin",
+    c10::RegisterOperators::options().kernel<QLinearUnpackWeightInt8>(
+        CPUTensorId()));
+
+} // namespace
+} // namespace native
+} // namespace at

test/test_quantized.py

Lines changed: 20 additions & 0 deletions
@@ -451,6 +451,26 @@ def test_qlinear_relu(self):
         # Assert equal
         np.testing.assert_equal(Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy())
 
+    """Tests the correctness of the quantized::fbgemm_linear_unpack op."""
+    @given(Q=qtensor(shapes=array_shapes(2, 2,), dtypes=((torch.qint8, np.int8, None),)))
+    def test_qlinear_unpack(self, Q):
+        W, (W_scale, W_zp), (qmin, qmax), (torch_type, np_type) = Q
+        qlinear_prepack = torch.ops.quantized.fbgemm_linear_prepack
+        qlinear_unpack = torch.ops.quantized.fbgemm_linear_unpack
+
+        W = torch.from_numpy(W)
+        W_q = torch.quantize_linear(W, scale=W_scale, zero_point=W_zp, dtype=torch_type)
+
+        # Weight prepacking operator for quantized Linear
+        W_prepack = qlinear_prepack(W_q)
+        # Weight unpack operator for quantized Linear (Used for serialization)
+        W_q_origin = qlinear_unpack(W_prepack)
+
+        # Assert equal
+        np.testing.assert_equal(W_q.int_repr(), W_q_origin.int_repr().numpy())
+        np.testing.assert_equal(W_q.q_scale(), W_q_origin.q_scale())
+        np.testing.assert_equal(W_q.q_zero_point(), W_q_origin.q_zero_point())
+
 
 @unittest.skipIf(
     TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
