Skip to content

Commit 1aa8839

Browse files
committed
Ruff auto-format
1 parent 5091d35 commit 1aa8839

File tree

8 files changed

+675
-253
lines changed

8 files changed

+675
-253
lines changed

torchao/dtypes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from .nf4tensor import NF4Tensor, to_nf4
2+
23
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
34
from .uint4 import UInt4Tensor
45
from .affine_quantized_tensor import (

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 411 additions & 167 deletions
Large diffs are not rendered by default.

torchao/dtypes/fpx/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,10 @@
1-
from .fpx import FpxTensorCoreLayoutType, FpxTensorCoreAQTLayout, to_scaled_tc_fpx, from_scaled_tc_fpx, _SPLIT_K_MAP
1+
from .fpx import (
2+
FpxTensorCoreLayoutType,
3+
FpxTensorCoreAQTLayout,
4+
to_scaled_tc_fpx,
5+
from_scaled_tc_fpx,
6+
_SPLIT_K_MAP,
7+
)
28

39
__all__ = [
410
"FpxTensorCoreAQTLayout",

torchao/dtypes/fpx/fpx.py

Lines changed: 142 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,11 @@
44
import torch
55
from torch import Tensor
66
from torch.utils._python_dispatch import return_and_correct_aliasing
7-
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones
7+
from torchao.prototype.custom_fp_utils import (
8+
_f32_to_fpx_unpacked,
9+
_fpx_unpacked_to_f32,
10+
_n_ones,
11+
)
812
from torchao.dtypes.utils import (
913
LayoutType,
1014
)
@@ -17,11 +21,23 @@
1721

1822

1923
def _pack(x: Tensor, n_bits: int) -> Tensor:
20-
return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
24+
return reduce(
25+
torch.bitwise_or,
26+
[
27+
x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits)
28+
for i in range(8 // n_bits)
29+
],
30+
)
2131

2232

2333
def _unpack(x: Tensor, n_bits: int) -> Tensor:
24-
return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
34+
return torch.stack(
35+
[
36+
(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1)
37+
for i in range(8 // n_bits)
38+
],
39+
dim=-1,
40+
).flatten(-2)
2541

2642

2743
# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
@@ -35,8 +51,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
3551

3652
if not undo:
3753
bit_order = {
38-
1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
39-
0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
54+
1: [
55+
1,
56+
5,
57+
9,
58+
13,
59+
17,
60+
21,
61+
25,
62+
29,
63+
3,
64+
7,
65+
11,
66+
15,
67+
19,
68+
23,
69+
27,
70+
31,
71+
0,
72+
4,
73+
8,
74+
12,
75+
16,
76+
20,
77+
24,
78+
28,
79+
2,
80+
6,
81+
10,
82+
14,
83+
18,
84+
22,
85+
26,
86+
30,
87+
],
4088
2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
4189
4: [1, 5, 3, 7, 0, 4, 2, 6],
4290
}[n_bits]
@@ -45,8 +93,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
4593
# this is inverse of the above, obtained by running
4694
# [v.index(i) for i in range(len(v))]
4795
bit_order = {
48-
1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
49-
20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
96+
1: [
97+
16,
98+
0,
99+
24,
100+
8,
101+
17,
102+
1,
103+
25,
104+
9,
105+
18,
106+
2,
107+
26,
108+
10,
109+
19,
110+
3,
111+
27,
112+
11,
113+
20,
114+
4,
115+
28,
116+
12,
117+
21,
118+
5,
119+
29,
120+
13,
121+
22,
122+
6,
123+
30,
124+
14,
125+
23,
126+
7,
127+
31,
128+
15,
129+
],
50130
2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
51131
4: [4, 0, 6, 2, 5, 1, 7, 3],
52132
}[n_bits]
@@ -82,8 +162,12 @@ def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
82162
tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
83163
tensor_ybit = _pack(tensor_ybit, y)
84164

85-
tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code
86-
tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code
165+
tensor_ybit = (
166+
tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
167+
) # Pass 2 from original code
168+
tensor_ybit = _bit_interleave(
169+
tensor_ybit.flatten(), y
170+
) # Pass 3 from original code
87171
fragments.append(tensor_ybit)
88172
used_bits += y
89173

@@ -125,7 +209,9 @@ def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Te
125209

126210
# workaround: global lookup table
127211
exp_bias = _ONES_TABLE[ebits - 1]
128-
max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
212+
max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
213+
_ONES_TABLE[mbits + 1] / (2**mbits)
214+
)
129215

130216
tensor = tensor.float()
131217
scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
@@ -151,8 +237,10 @@ def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
151237
tensor_ybit = tensor[offset : offset + size_ybit]
152238
offset += size_ybit
153239

154-
tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
155-
tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2
240+
tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
241+
tensor_ybit = (
242+
tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)
243+
) # undo Pass 2
156244

157245
tensor_ybit = _unpack(tensor_ybit.flatten(), y)
158246
tensor_ybit = tensor_ybit << (nbits - used_bits - y)
@@ -223,7 +311,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
223311
10240: 5,
224312
14336: 7,
225313
28672: 7,
226-
57344: 7
314+
57344: 7,
227315
},
228316
{ # tokens: [65:128]
229317
3072: 9,
@@ -234,7 +322,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
234322
10240: 5,
235323
14336: 7,
236324
28672: 7,
237-
57344: 6
325+
57344: 6,
238326
},
239327
{ # tokens: [129:192]
240328
3072: 6,
@@ -245,7 +333,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
245333
10240: 5,
246334
14336: 5,
247335
28672: 5,
248-
57344: 4
336+
57344: 4,
249337
},
250338
{ # tokens: [193:256]
251339
3072: 9,
@@ -256,7 +344,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
256344
10240: 4,
257345
14336: 8,
258346
28672: 6,
259-
57344: 4
347+
57344: 4,
260348
},
261349
{ # tokens: [257:320]
262350
3072: 7,
@@ -267,7 +355,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
267355
10240: 1,
268356
14336: 3,
269357
28672: 3,
270-
57344: 4
358+
57344: 4,
271359
},
272360
{ # tokens: [321:384]
273361
3072: 3,
@@ -278,7 +366,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
278366
10240: 8,
279367
14336: 3,
280368
28672: 4,
281-
57344: 3
369+
57344: 3,
282370
},
283371
{ # tokens: [385:448]
284372
3072: 5,
@@ -289,7 +377,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
289377
10240: 3,
290378
14336: 1,
291379
28672: 1,
292-
57344: 3
380+
57344: 3,
293381
},
294382
{ # tokens: [449:512]
295383
3072: 2,
@@ -300,7 +388,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
300388
10240: 2,
301389
14336: 6,
302390
28672: 4,
303-
57344: 1
391+
57344: 1,
304392
},
305393
{ # tokens: [513:576]
306394
3072: 2,
@@ -311,7 +399,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
311399
10240: 3,
312400
14336: 3,
313401
28672: 1,
314-
57344: 1
402+
57344: 1,
315403
},
316404
{ # tokens: [577:640]
317405
3072: 5,
@@ -322,7 +410,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
322410
10240: 1,
323411
14336: 1,
324412
28672: 1,
325-
57344: 1
413+
57344: 1,
326414
},
327415
{ # tokens: [641:704]
328416
3072: 3,
@@ -333,7 +421,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
333421
10240: 2,
334422
14336: 1,
335423
28672: 1,
336-
57344: 1
424+
57344: 1,
337425
},
338426
{ # tokens: [705:768]
339427
3072: 3,
@@ -344,20 +432,22 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
344432
10240: 1,
345433
14336: 1,
346434
28672: 1,
347-
57344: 1
348-
}
435+
57344: 1,
436+
},
349437
]
350438

351439

352440
# quantization api integrations
353441

442+
354443
@dataclass(frozen=True)
355444
class FpxTensorCoreLayoutType(LayoutType):
356-
"""Layout type for FpxTensorCoreAQTLayout
357-
"""
445+
"""Layout type for FpxTensorCoreAQTLayout"""
446+
358447
ebits: int
359448
mbits: int
360449

450+
361451
@register_layout_cls(FpxTensorCoreLayoutType)
362452
class FpxTensorCoreAQTLayout(AQTLayout):
363453
"""FpxTensorCoreAQTLayout represents a Tensor with dtype fpx(ebits=a, mbits=b),
@@ -381,6 +471,7 @@ class FpxTensorCoreAQTLayout(AQTLayout):
381471
it will then pack the weight and instantiate the FpxTensorCoreAQTLayout tensor
382472
FpxTensorCoreAQTLayout.__init__() takes a packed fpx Tensor of shape (M, N // 8 * nbit)
383473
"""
474+
384475
def __new__(
385476
cls,
386477
packed_fpx_data: torch.Tensor,
@@ -389,11 +480,16 @@ def __new__(
389480
):
390481
assert packed_fpx_data.ndim == 2
391482
assert packed_fpx_data.dtype == torch.uint8
392-
shape = (packed_fpx_data.shape[0], packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8)
483+
shape = (
484+
packed_fpx_data.shape[0],
485+
packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8,
486+
)
393487
kwargs = {}
394488
kwargs["device"] = packed_fpx_data.device
395489
kwargs["layout"] = (
396-
kwargs.get("layout") if kwargs.get("layout", False) else packed_fpx_data.layout
490+
kwargs.get("layout")
491+
if kwargs.get("layout", False)
492+
else packed_fpx_data.layout
397493
)
398494
kwargs["dtype"] = packed_fpx_data.dtype
399495
kwargs["requires_grad"] = False
@@ -416,12 +512,17 @@ def __tensor_flatten__(self):
416512
def __tensor_unflatten__(
417513
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
418514
):
419-
packed_fpx_data, scale = tensor_data_dict["packed_fpx_data"], tensor_data_dict["scale"]
420-
layout_type, = tensor_attributes
515+
packed_fpx_data, scale = (
516+
tensor_data_dict["packed_fpx_data"],
517+
tensor_data_dict["scale"],
518+
)
519+
(layout_type,) = tensor_attributes
421520
return cls(packed_fpx_data, scale, layout_type)
422521

423522
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
424-
unpacked_fpx_data = unpack_tc_fpx(self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits)
523+
unpacked_fpx_data = unpack_tc_fpx(
524+
self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits
525+
)
425526
return unpacked_fpx_data, self.scale
426527

427528
@classmethod
@@ -440,7 +541,9 @@ def from_plain(
440541
bit, M is mantissa bit
441542
"""
442543
assert isinstance(layout_type, FpxTensorCoreLayoutType)
443-
packed_fpx_data = pack_tc_fpx(unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits)
544+
packed_fpx_data = pack_tc_fpx(
545+
unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits
546+
)
444547
return cls(packed_fpx_data, scale, layout_type)
445548

446549
def __repr__(self):
@@ -478,7 +581,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
478581
)
479582
elif func is aten._to_copy.default:
480583
return return_and_correct_aliasing(
481-
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
584+
func,
585+
args,
586+
kwargs,
587+
args[0]._apply_fn_to_data(
588+
lambda x: x.to(device=kwargs.pop("device", None))
589+
),
482590
)
483591

484592
raise NotImplementedError(

torchao/dtypes/uint4.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,6 @@ def __new__(cls, elem, **kwargs):
105105
)
106106

107107
def __init__(self, elem, **kwargs):
108-
109108
self.elem = elem
110109

111110
@classmethod

0 commit comments

Comments
 (0)