@@ -454,11 +454,7 @@ def extra_repr(self) -> str:
             tmp_str += ", use_optimum_format=True"
         return tmp_str

-
-# TODO: implement HPUWeightOnlyLinear
-# temporarily let HPUWeightOnlyLinear inherit INCWeightOnlyLinear
-# should be 'class HPUWeightOnlyLinear(WeightOnlyLinear)'
-class HPUWeightOnlyLinear(INCWeightOnlyLinear):
+class HPUWeightOnlyLinear(WeightOnlyLinear):
     def __init__(
         self,
         in_features,
@@ -468,7 +464,7 @@ def __init__(
         group_size=32,
         zp=False,
         bias=False,
-        scale_dtype=torch.float32,
+        scale_dtype=torch.bfloat16,
         compression_dtype=torch.int32,
         compression_dim=1,
         g_idx=False,
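
Note on this hunk: the only functional change is the default scale_dtype, which moves from float32 to bfloat16. As a minimal sketch (not part of this patch; shapes and values below are made up for illustration), under the usual group-wise weight-only scheme W ~= (q - zp) * scale, the dtype of the stored scales is what the dequantized weight inherits, so bf16 scales keep the reconstructed weight in Gaudi's native bf16:

import torch

# Toy 4-bit codes, zero points, and bf16 scales for a single quantization group.
q = torch.randint(0, 16, (8, 4), dtype=torch.int32)
zp = torch.full((1, 4), 8, dtype=torch.int32)
scales = torch.rand(1, 4, dtype=torch.bfloat16)  # bf16 scales, matching the new default

w = (q - zp).to(scales.dtype) * scales  # dequantized weight comes out as bfloat16
print(w.dtype)  # torch.bfloat16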
@@ -482,17 +478,128 @@ def __init__(
             dtype,
             bits,
             group_size,
-            zp,
-            bias,
-            scale_dtype,
-            compression_dtype,
-            compression_dim,
-            g_idx,
             device,
-            use_optimum_format,
-            **kwargs,
+        )
+        self.float_type = torch.bfloat16
+        self.compression_dim = compression_dim
+        self.compression_dtype = compression_dtype
+
+        if bits != 4:
+            raise NotImplementedError("Only 4 bits are supported.")
+        self.maxq = 2**self.bits - 1
+
+        if bias:
+            self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(self.device))
+        else:
+            self.bias = None
+
+        self.register_buffer(
+            "qweight",
+            torch.zeros((in_features, out_features // 32 * self.bits), dtype=self.compression_dtype).to(self.device),
         )

+        self.register_buffer(
+            "qzeros",
+            torch.zeros(
+                (
+                    math.ceil(in_features / self.group_size),
+                    out_features // 32 * self.bits,
+                ),
+                dtype=self.compression_dtype,
+            ),
+        )
+        self.register_buffer(
+            "scales",
+            torch.zeros(
+                (math.ceil(in_features / self.group_size), out_features),
+                dtype=self.float_type,
+            ),
+        )
+
+        if g_idx:
+            self.register_buffer(
+                "g_idx",
+                torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32),
+            )
+        else:
+            self.g_idx = None
+
+        self.half_indim = self.in_features // 2
+
+        self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
+
+    def forward(self, input):
+        input_dtype = input.dtype
+        output_shape = input.shape[:-1] + (self.out_features,)
+        scales = self.scales
+        qweight = self.qweight
+        zeros = self.qzeros
+        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
+        output = torch.matmul(input, weight)
+        output = output.to(dtype=input_dtype).reshape(
+            output_shape
+
+        )  # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocate a float32 output.
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
+    def pack(self, int_weight, scales, zp, bias=None, g_idx=None):
+        logger.debug(f"Packing for HPU")
+
+        scales = scales.T.contiguous()
+        qzeros = zp.T.contiguous()
+        qweight = int_weight.T.contiguous()
+
+        self.scales = scales.to(dtype=torch.bfloat16)
+
+        # weights and zp are on device from unpack, need to load to cpu for packing
+        self.qweight = qweight.cpu()
+        new_qweight = self.pack_tensor(self.qweight)
+        self.qweight = new_qweight.to("hpu")
+
+        self.qzeros = qzeros.cpu()
+        new_qzeros = self.pack_tensor(self.qzeros)
+        self.qzeros = new_qzeros.to("hpu")
+
+        if bias is not None:
+            self.bias = bias.to("hpu").to(torch.bfloat16)
+
+    def unpack(self):
+        logger.debug(f"Unpacking from HPU")
+        self.qweight = self.qweight.cpu()
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
+            self.wf.unsqueeze(-1),
+        ).to(torch.int16 if self.bits == 8 else torch.int8)
+        weight = torch.bitwise_and(weight, (2**self.bits) - 1)
+        weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
+        self.qweight = self.qweight.to(self.device)
+
+        zeros = torch.bitwise_right_shift(
+            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
+            self.wf.unsqueeze(0),
+        ).to(torch.int16 if self.bits == 8 else torch.int8)
+
+        zeros = torch.bitwise_and(
+            zeros, (2**self.bits) - 1
+        ).to(self.scales.dtype)  # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
+        zeros = zeros + 1
+        zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
+        return weight, zeros
+
+    def pack_tensor(self, input, bits=4):
+        normal = input.to(torch.int32)
+        q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
+        i = 0
+        col = 0
+        while col < q.shape[1]:
+            for j in range(i, i + (32 // bits)):
+                q[:, col] |= normal[:, j] << (bits * (j - i))
+            i += 32 // bits
+            col += 1
+        q = q.to(torch.int32)
+        return q

 class FakeAffineTensorQuantFunction(Function):
     """Fake version of affine quantization."""