
Commit ac23009

metascroy authored and facebook-github-bot committed
Enable 8-bit
Summary: Enables the 8-bit kernel in operators and tests.

Differential Revision: D65688410
1 parent 242f181 commit ac23009

10 files changed: +74 −9 lines
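The practical effect of this commit is that nbit=8 now passes the range asserts in quant_api.py (diff below). A minimal sketch of exercising the new bound; the "group_size" kwargs key is an assumption, only "nbit" and "has_weight_zeros" are visible in the diff:

    import torch
    import torch.nn as nn
    from torchao.experimental.quant_api import _replace_linear_with_quantized_linear

    # nbit=8 passes the updated range assert (previously limited to nbit <= 7).
    model = nn.Sequential(nn.Linear(256, 512, bias=False))
    _replace_linear_with_quantized_linear(
        model,
        kwargs={"nbit": 8, "has_weight_zeros": True, "group_size": 32},
    )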

torchao/experimental/_linear_8bit_act_xbit_weight_layout.py

+1 −1

@@ -65,7 +65,7 @@ def __init__(
         group_size: int,
         target: str,
     ):
-        assert nbit <= 7
+        assert nbit <= 8
         self.nbit = nbit
         self.group_size = group_size
         self.target = target_from_str(target)

torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp

+3

@@ -36,6 +36,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) {
   DEFINE_OP(5);
   DEFINE_OP(6);
   DEFINE_OP(7);
+  DEFINE_OP(8);
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -46,6 +47,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(5);
   DEFINE_CPU_IMPL(6);
   DEFINE_CPU_IMPL(7);
+  DEFINE_CPU_IMPL(8);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -56,4 +58,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(5);
   DEFINE_META_IMPL(6);
   DEFINE_META_IMPL(7);
+  DEFINE_META_IMPL(8);
 }

torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp

+1

@@ -37,3 +37,4 @@ DEFINE_OP(4);
 DEFINE_OP(5);
 DEFINE_OP(6);
 DEFINE_OP(7);
+DEFINE_OP(8);

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp

+3

@@ -68,6 +68,7 @@ TORCH_LIBRARY(torchao, m) {
   DEFINE_OP(5);
   DEFINE_OP(6);
   DEFINE_OP(7);
+  DEFINE_OP(8);
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -78,6 +79,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(5);
   DEFINE_CPU_IMPL(6);
   DEFINE_CPU_IMPL(7);
+  DEFINE_CPU_IMPL(8);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -88,4 +90,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(5);
   DEFINE_META_IMPL(6);
   DEFINE_META_IMPL(7);
+  DEFINE_META_IMPL(8);
 }
New file (path not shown in this capture): ExecuTorch registration for the 8-bit, has_weight_zeros = false variant

+29

@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to allow only one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant.
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu</*weight_nbit*/ 8, /*has_weight_zeros*/ false>(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_8bit0zp_weight.out", _op_out);
New file (path not shown in this capture): ExecuTorch registration for the 8-bit, has_weight_zeros = true variant

+29

@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to allow only one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant.
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu</*weight_nbit*/ 8, /*has_weight_zeros*/ true>(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_8bit_weight.out", _op_out);
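The two new files above differ only in the has_weight_zeros template argument and the registered op name; the "0zp" suffix matches the Python-side dispatch in quant_api.py (diff below). A small sketch of that naming rule, generalized from the two concrete names here, so treat the pattern for other nbit values as an assumption:

    # Mirrors wzp_suffix = "" if has_weight_zeros else "0zp" from quant_api.py.
    def expected_out_op_name(nbit: int, has_weight_zeros: bool) -> str:
        wzp_suffix = "" if has_weight_zeros else "0zp"
        return f"torchao::_linear_8bit_act_{nbit}bit{wzp_suffix}_weight.out"

    assert expected_out_op_name(8, True) == "torchao::_linear_8bit_act_8bit_weight.out"
    assert expected_out_op_name(8, False) == "torchao::_linear_8bit_act_8bit0zp_weight.out"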

torchao/experimental/quant_api.py

+5 −5

@@ -198,7 +198,7 @@ def forward(self, x):
 
 def _maybe_get_quantized_linear_native(nbit, has_weight_zeros):
     try:
-        if nbit in [1, 2, 3, 4, 5, 6, 7]:
+        if nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             wzp_suffix = "" if has_weight_zeros else "0zp"
             return _Int8DynActIntxWeightQuantizedLinearNative(
                 pack_weight_op=getattr(
@@ -230,7 +230,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
     has_weight_zeros = kwargs["has_weight_zeros"]
 
     assert not isinstance(module, nn.Linear)
-    assert nbit >= 1 and nbit <= 7
+    assert nbit >= 1 and nbit <= 8
 
     for name, child in module.named_children():
         if not isinstance(child, nn.Linear):
@@ -366,9 +366,9 @@ def quantize_and_pack_weights(self, weights, group_size):
         weight_qvals, weight_scales, weight_zeros = _quantize(
             weights, self.group_size, self.nbit, has_weight_zeros=True
         )
-        self.weight_qvals = weight_qvals.to(torch.int8)
+        self.weight_qvals = weight_qvals.to(torch.int32)
         self.weight_scales = weight_scales
-        self.weight_zeros = weight_zeros.to(torch.int8)
+        self.weight_zeros = weight_zeros.to(torch.int32)
 
     def forward(self, x):
         shape = x.shape
@@ -394,7 +394,7 @@ def _replace_embedding_with_quantized_embedding(module: nn.Module, kwargs={}):
     nbit = kwargs["nbit"]
 
     assert not isinstance(module, nn.Embedding)
-    assert nbit >= 1 and nbit <= 7
+    assert nbit >= 1 and nbit <= 8
 
     for name, child in module.named_children():
         if not isinstance(child, nn.Embedding):
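The commit message does not explain the weight_qvals/weight_zeros dtype change from int8 to int32. One plausible motivation, offered here only as an assumption, is that 8-bit qvals and zero points each fit in int8 individually, but their difference spans [-255, 255] and silently wraps in int8 arithmetic:

    import torch

    q = torch.tensor([127], dtype=torch.int8)    # max 8-bit qval
    z = torch.tensor([-128], dtype=torch.int8)   # min 8-bit zero point
    print(q - z)                                  # tensor([-1], dtype=torch.int8) -- wrapped
    print(q.to(torch.int32) - z.to(torch.int32))  # tensor([255], dtype=torch.int32) -- correct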

torchao/experimental/tests/test_embedding_xbit_quantizer.py

+1 −1

@@ -65,7 +65,7 @@ def test_accuracy(self):
         model = torch.nn.Sequential(*[torch.nn.Embedding(num_embeddings, embedding_dim)])
         indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)
 
-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             print(f"Testing nbit={nbit}")
             quantized_model = copy.deepcopy(model)
             quantizer = IntxWeightEmbeddingQuantizer(
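A hedged usage sketch for the newly allowed nbit=8 path this test now covers; the diff truncates the IntxWeightEmbeddingQuantizer call, so the constructor arguments shown here (device, precision, bitwidth, groupsize) and the import path are assumptions:

    import copy
    import torch
    from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

    model = torch.nn.Sequential(torch.nn.Embedding(128, 64))
    quantizer = IntxWeightEmbeddingQuantizer(
        device="cpu",
        precision=torch.float32,
        bitwidth=8,   # newly allowed by this commit
        groupsize=32,
    )
    quantized_model = quantizer.quantize(copy.deepcopy(model))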

torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py

+1 −1

@@ -67,7 +67,7 @@ def test_accuracy(self):
         activations = torch.randn(2, 3, m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
 
-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             for has_weight_zeros in [True, False]:
                 print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                 quantized_model = copy.deepcopy(model)

torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py

+1 −1

@@ -70,7 +70,7 @@ def test_accuracy(self):
         activations = torch.randn(m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
 
-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             for has_weight_zeros in [True, False]:
                 print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                 quantized_model = copy.deepcopy(model)
