
Commit 56c03d8

Yantom1 authored and ulivne committed
[SW-190303] Implement HPUWeightOnlyLinear class in INC
Change-Id: Ie05c8787e708e2c3559dce24ef0758d6c498ac41
1 parent 969f467 commit 56c03d8
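
Editorial note: the commit replaces the temporary INCWeightOnlyLinear-based stub with a real HPU implementation. As a rough, hypothetical sketch of how the new class is driven (not taken from the commit: the constructor keyword defaults, the example shapes, and the availability of an `hpu` device with the Habana software stack and `torch.ops.hpu.convert_from_uint4` are all assumptions):

import torch
from neural_compressor.torch.algorithms.weight_only.modules import HPUWeightOnlyLinear

in_features, out_features, group_size = 128, 256, 32

# Hypothetical construction; unspecified defaults are assumed to mirror the base class.
layer = HPUWeightOnlyLinear(in_features, out_features, bits=4, group_size=group_size, device="hpu")

# pack() transposes its inputs internally, so tensors are passed in (out_features, ...) layout:
# 4-bit integer weights, per-group scales, and per-group zero points.
int_weight = torch.randint(0, 16, (out_features, in_features), dtype=torch.int32)
scales = torch.rand(out_features, in_features // group_size, dtype=torch.bfloat16)
zp = torch.full((out_features, in_features // group_size), 8, dtype=torch.int32)
layer.pack(int_weight, scales, zp)

# forward() dequantizes on the fly via torch.ops.hpu.convert_from_uint4 and matmuls.
x = torch.randn(4, in_features, dtype=torch.bfloat16, device="hpu")
y = layer(x)  # shape (4, out_features)

pack() re-packs the tensors on the CPU and then moves the packed buffers to the HPU, so the call above only succeeds on a machine with the Habana stack installed.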

File tree

  • neural_compressor/torch/algorithms/weight_only/modules.py

1 file changed: +121 -14 lines changed


neural_compressor/torch/algorithms/weight_only/modules.py

Lines changed: 121 additions & 14 deletions
@@ -454,11 +454,7 @@ def extra_repr(self) -> str:
         tmp_str += ", use_optimum_format=True"
         return tmp_str

-
-# TODO: implement HPUWeightOnlyLinear
-# temporarily let HPUWeightOnlyLinear inherit INCWeightOnlyLinear
-# should be 'class HPUWeightOnlyLinear(WeightOnlyLinear)'
-class HPUWeightOnlyLinear(INCWeightOnlyLinear):
+class HPUWeightOnlyLinear(WeightOnlyLinear):
     def __init__(
         self,
         in_features,
@@ -468,7 +464,7 @@ def __init__(
         group_size=32,
         zp=False,
         bias=False,
-        scale_dtype=torch.float32,
+        scale_dtype=torch.bfloat16,
         compression_dtype=torch.int32,
         compression_dim=1,
         g_idx=False,
@@ -482,17 +478,128 @@ def __init__(
             dtype,
             bits,
             group_size,
-            zp,
-            bias,
-            scale_dtype,
-            compression_dtype,
-            compression_dim,
-            g_idx,
             device,
-            use_optimum_format,
-            **kwargs,
+        )
+        self.float_type = torch.bfloat16
+        self.compression_dim = compression_dim
+        self.compression_dtype = compression_dtype
+
+        if bits != 4:
+            raise NotImplementedError("Only 4 bits are supported.")
+        self.maxq = 2**self.bits - 1
+
+        if bias:
+            self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(self.device))
+        else:
+            self.bias = None
+
+        self.register_buffer(
+            "qweight",
+            torch.zeros((in_features, out_features // 32 * self.bits), dtype=self.compression_dtype).to(self.device),
         )

+        self.register_buffer(
+            "qzeros",
+            torch.zeros(
+                (
+                    math.ceil(in_features / self.group_size),
+                    out_features // 32 * self.bits,
+                ),
+                dtype=self.compression_dtype,
+            ),
+        )
+        self.register_buffer(
+            "scales",
+            torch.zeros(
+                (math.ceil(in_features / self.group_size), out_features),
+                dtype=self.float_type,
+            ),
+        )
+
+        if g_idx:
+            self.register_buffer(
+                "g_idx",
+                torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32),
+            )
+        else:
+            self.g_idx = None
+
+        self.half_indim = self.in_features // 2
+
+        self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
+
+    def forward(self, input):
+        input_dtype = input.dtype
+        output_shape = input.shape[:-1] + (self.out_features,)
+        scales = self.scales
+        qweight = self.qweight
+        zeros = self.qzeros
+        weight = torch.ops.hpu.convert_from_uint4(qweight, scales, zeros, input_dtype)
+        output = torch.matmul(input, weight)
+        output = output.to(dtype=input_dtype).reshape(
+            output_shape
+
+        ) # A cast is needed here as for some reason the vecquant2matmul_faster_old still allocate a float32 output.
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
+    def pack(self, int_weight, scales, zp, bias=None, g_idx=None):
+        logger.debug(f"Packing for HPU")
+
+        scales = scales.T.contiguous()
+        qzeros = zp.T.contiguous()
+        qweight = int_weight.T.contiguous()
+
+        self.scales = scales.to(dtype=torch.bfloat16)
+
+        # weights and zp are on device from unpack, need to load to cpu for packing
+        self.qweight = qweight.cpu()
+        new_qweight = self.pack_tensor(self.qweight)
+        self.qweight = new_qweight.to("hpu")
+
+        self.qzeros = qzeros.cpu()
+        new_qzeros = self.pack_tensor(self.qzeros)
+        self.qzeros = new_qzeros.to("hpu")
+
+        if bias is not None:
+            self.bias = bias.to("hpu").to(torch.bfloat16)
+
+    def unpack(self):
+        logger.debug(f"Unpacking from HPU")
+        self.qweight = self.qweight.cpu()
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
+            self.wf.unsqueeze(-1),
+        ).to(torch.int16 if self.bits == 8 else torch.int8)
+        weight = torch.bitwise_and(weight, (2**self.bits) - 1)
+        weight = weight.reshape((weight.shape[0]*weight.shape[1], weight.shape[2]))
+        self.qweight = self.qweight.to(self.device)
+
+        zeros = torch.bitwise_right_shift(
+            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
+            self.wf.unsqueeze(0),
+        ).to(torch.int16 if self.bits == 8 else torch.int8)
+
+        zeros = torch.bitwise_and(
+            zeros, (2**self.bits) - 1
+        ).to(self.scales.dtype) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
+        zeros = zeros + 1
+        zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
+        return weight, zeros
+
+    def pack_tensor(self, input, bits = 4):
+        normal = input.to(torch.int32)
+        q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
+        i = 0
+        col = 0
+        while col < q.shape[1]:
+            for j in range(i, i + (32 // bits)):
+                q[:, col] |= normal[:, j] << (bits * (j - i))
+            i += 32 // bits
+            col += 1
+        q = q.to(torch.int32)
+        return q

 class FakeAffineTensorQuantFunction(Function):
     """Fake version of affine quantization."""
