import torch
from torch import Tensor
from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones
+from torchao.prototype.custom_fp_utils import (
+    _f32_to_fpx_unpacked,
+    _fpx_unpacked_to_f32,
+    _n_ones,
+)
from torchao.dtypes.utils import (
    LayoutType,
)


def _pack(x: Tensor, n_bits: int) -> Tensor:
-    return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
+    return reduce(
+        torch.bitwise_or,
+        [
+            x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits)
+            for i in range(8 // n_bits)
+        ],
+    )


def _unpack(x: Tensor, n_bits: int) -> Tensor:
-    return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
+    return torch.stack(
+        [
+            (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1)
+            for i in range(8 // n_bits)
+        ],
+        dim=-1,
+    ).flatten(-2)

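For reference, a minimal round-trip sketch (illustrative only, not part of this change) of what `_pack`/`_unpack` do: each uint8 byte holds `8 // n_bits` consecutive values, most-significant bits first.

vals = torch.tensor([0, 1, 2, 3, 3, 2, 1, 0], dtype=torch.uint8)  # 2-bit values
packed = _pack(vals, n_bits=2)        # -> tensor([27, 228]) == [0b00011011, 0b11100100]
unpacked = _unpack(packed, n_bits=2)  # recovers the original order
assert torch.equal(unpacked, vals)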
# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
@@ -35,8 +51,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:

    if not undo:
        bit_order = {
-            1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
-                0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
+            1: [
+                1,
+                5,
+                9,
+                13,
+                17,
+                21,
+                25,
+                29,
+                3,
+                7,
+                11,
+                15,
+                19,
+                23,
+                27,
+                31,
+                0,
+                4,
+                8,
+                12,
+                16,
+                20,
+                24,
+                28,
+                2,
+                6,
+                10,
+                14,
+                18,
+                22,
+                26,
+                30,
+            ],
            2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
            4: [1, 5, 3, 7, 0, 4, 2, 6],
        }[n_bits]
@@ -45,8 +93,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
        # this is inverse of the above, obtained by running
        # [v.index(i) for i in range(len(v))]
        bit_order = {
-            1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
-                20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
+            1: [
+                16,
+                0,
+                24,
+                8,
+                17,
+                1,
+                25,
+                9,
+                18,
+                2,
+                26,
+                10,
+                19,
+                3,
+                27,
+                11,
+                20,
+                4,
+                28,
+                12,
+                21,
+                5,
+                29,
+                13,
+                22,
+                6,
+                30,
+                14,
+                23,
+                7,
+                31,
+                15,
+            ],
            2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
            4: [4, 0, 6, 2, 5, 1, 7, 3],
        }[n_bits]
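As a quick sanity check of the comment above (illustrative, not part of this change), the undo table is just the inverse permutation of the forward one; for the 4-bit case:

forward = [1, 5, 3, 7, 0, 4, 2, 6]
inverse = [forward.index(i) for i in range(len(forward))]
assert inverse == [4, 0, 6, 2, 5, 1, 7, 3]  # matches the 4-bit undo entry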
@@ -82,8 +162,12 @@ def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
            tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
            tensor_ybit = _pack(tensor_ybit, y)

-            tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)  # Pass 2 from original code
-            tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y)  # Pass 3 from original code
+            tensor_ybit = (
+                tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
+            )  # Pass 2 from original code
+            tensor_ybit = _bit_interleave(
+                tensor_ybit.flatten(), y
+            )  # Pass 3 from original code
            fragments.append(tensor_ybit)
            used_bits += y

@@ -125,7 +209,9 @@ def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Te

    # workaround: global lookup table
    exp_bias = _ONES_TABLE[ebits - 1]
-    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
+    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
+        _ONES_TABLE[mbits + 1] / (2 ** mbits)
+    )

    tensor = tensor.float()
    scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
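To make the `max_normal` formula concrete, a hedged worked example for FP6 (ebits=3, mbits=2): the largest representable magnitude is 28, so each row's scale maps its absolute maximum onto 28.

ebits, mbits = 3, 2                                    # FP6 (e3m2), illustrative values
exp_bias = _ONES_TABLE[ebits - 1]                      # _n_ones(2) = 3
max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
    _ONES_TABLE[mbits + 1] / (2 ** mbits)
)                                                      # 2**4 * 7/4 = 28.0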
@@ -151,8 +237,10 @@ def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
            tensor_ybit = tensor[offset : offset + size_ybit]
            offset += size_ybit

-            tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True)        # undo Pass 3
-            tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)  # undo Pass 2
+            tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True)  # undo Pass 3
+            tensor_ybit = (
+                tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)
+            )  # undo Pass 2

            tensor_ybit = _unpack(tensor_ybit.flatten(), y)
            tensor_ybit = tensor_ybit << (nbits - used_bits - y)
@@ -223,7 +311,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 5,
        14336: 7,
        28672: 7,
-        57344: 7
+        57344: 7,
    },
    {  # tokens: [65:128]
        3072: 9,
@@ -234,7 +322,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 5,
        14336: 7,
        28672: 7,
-        57344: 6
+        57344: 6,
    },
    {  # tokens: [129:192]
        3072: 6,
@@ -245,7 +333,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 5,
        14336: 5,
        28672: 5,
-        57344: 4
+        57344: 4,
    },
    {  # tokens: [193:256]
        3072: 9,
@@ -256,7 +344,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 4,
        14336: 8,
        28672: 6,
-        57344: 4
+        57344: 4,
    },
    {  # tokens: [257:320]
        3072: 7,
@@ -267,7 +355,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 1,
        14336: 3,
        28672: 3,
-        57344: 4
+        57344: 4,
    },
    {  # tokens: [321:384]
        3072: 3,
@@ -278,7 +366,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 8,
        14336: 3,
        28672: 4,
-        57344: 3
+        57344: 3,
    },
    {  # tokens: [385:448]
        3072: 5,
@@ -289,7 +377,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 3,
        14336: 1,
        28672: 1,
-        57344: 3
+        57344: 3,
    },
    {  # tokens: [449:512]
        3072: 2,
@@ -300,7 +388,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 2,
        14336: 6,
        28672: 4,
-        57344: 1
+        57344: 1,
    },
    {  # tokens: [513:576]
        3072: 2,
@@ -311,7 +399,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 3,
        14336: 3,
        28672: 1,
-        57344: 1
+        57344: 1,
    },
    {  # tokens: [577:640]
        3072: 5,
@@ -322,7 +410,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 1,
        14336: 1,
        28672: 1,
-        57344: 1
+        57344: 1,
    },
    {  # tokens: [641:704]
        3072: 3,
@@ -333,7 +421,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 2,
        14336: 1,
        28672: 1,
-        57344: 1
+        57344: 1,
    },
    {  # tokens: [705:768]
        3072: 3,
@@ -344,20 +432,22 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 1,
        14336: 1,
        28672: 1,
-        57344: 1
-    }
+        57344: 1,
+    },
]


# quantization api integrations

+
@dataclass(frozen=True)
class FpxTensorCoreLayoutType(LayoutType):
-    """Layout type for FpxTensorCoreAQTLayout
-    """
+    """Layout type for FpxTensorCoreAQTLayout"""
+
    ebits: int
    mbits: int

+
@register_layout_cls(FpxTensorCoreLayoutType)
class FpxTensorCoreAQTLayout(AQTLayout):
    """FpxTensorCoreAQTLayout represents a Tensor with dtype fpx(ebits=a, mbits=b),
@@ -381,6 +471,7 @@ class FpxTensorCoreAQTLayout(AQTLayout):
    it will then pack the weight and instantiate the FpxTensorCoreAQTLayout tensor
    FpxTensorCoreAQTLayout.__init__() takes a packed fpx Tensor of shape (M, N // 8 * nbit)
    """
+
    def __new__(
        cls,
        packed_fpx_data: torch.Tensor,
@@ -389,11 +480,16 @@ def __new__(
    ):
        assert packed_fpx_data.ndim == 2
        assert packed_fpx_data.dtype == torch.uint8
-        shape = (packed_fpx_data.shape[0], packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8)
+        shape = (
+            packed_fpx_data.shape[0],
+            packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8,
+        )
        kwargs = {}
        kwargs["device"] = packed_fpx_data.device
        kwargs["layout"] = (
-            kwargs.get("layout") if kwargs.get("layout", False) else packed_fpx_data.layout
+            kwargs.get("layout")
+            if kwargs.get("layout", False)
+            else packed_fpx_data.layout
        )
        kwargs["dtype"] = packed_fpx_data.dtype
        kwargs["requires_grad"] = False
@@ -416,12 +512,17 @@ def __tensor_flatten__(self):
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
-        packed_fpx_data, scale = tensor_data_dict["packed_fpx_data"], tensor_data_dict["scale"]
-        layout_type, = tensor_attributes
+        packed_fpx_data, scale = (
+            tensor_data_dict["packed_fpx_data"],
+            tensor_data_dict["scale"],
+        )
+        (layout_type,) = tensor_attributes
        return cls(packed_fpx_data, scale, layout_type)

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        unpacked_fpx_data = unpack_tc_fpx(self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits)
+        unpacked_fpx_data = unpack_tc_fpx(
+            self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits
+        )
        return unpacked_fpx_data, self.scale

    @classmethod
@@ -440,7 +541,9 @@ def from_plain(
        bit, M is mantissa bit
        """
        assert isinstance(layout_type, FpxTensorCoreLayoutType)
-        packed_fpx_data = pack_tc_fpx(unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits)
+        packed_fpx_data = pack_tc_fpx(
+            unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits
+        )
        return cls(packed_fpx_data, scale, layout_type)

    def __repr__(self):
@@ -478,7 +581,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
            )
        elif func is aten._to_copy.default:
            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: x.to(device=kwargs.pop("device", None))
+                ),
            )

        raise NotImplementedError(