Commit 14868c0

Refactor GPTQ to support running on Gaudi (#1700)
* gptq support for gaudi Signed-off-by: n1ck-guo <[email protected]>
1 parent 9d7a052 commit 14868c0

File tree

2 files changed: +33 -4 lines changed
neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 17 additions & 4 deletions
@@ -34,6 +34,7 @@
 from .modules import WeightOnlyLinear

 DEBUG = False
+accelerator = auto_detect_accelerator()


 # ================ device related ===================
@@ -542,8 +543,10 @@ def forward(layer, *args, **kwargs):
         if self.run_fn:
             if self.run_args:
                 self.run_fn(self.model, *self.run_args)
+                accelerator.mark_step()
             else:
                 self.run_fn(self.model)
+                accelerator.mark_step()
         else:
             for batch in tqdm(self.dataloader):
                 if not self.use_layer_wise:
@@ -663,6 +666,7 @@ def tmp(_, inp, out):
         for j in range(batch_num):
             cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
             cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
+            accelerator.mark_step()
             out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
             out = self.track_hidden_states(out)
         self.cache_key_arguments["batch_num"] = batch_num
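
For context: on Gaudi, PyTorch runs in lazy mode and `htcore.mark_step()` is the point where the accumulated graph is flushed and executed, so the hunks above call `accelerator.mark_step()` after each calibration forward to execute one graph per batch. A minimal sketch of that pattern, assuming a Gaudi host with `habana_frameworks` installed and a `model`/`dataloader` pair already prepared (placeholder names, not the repository's code):

import torch
import habana_frameworks.torch.core as htcore  # only available in a Gaudi/HPU environment

def calibrate(model, dataloader):
    # Run calibration forwards on HPU, flushing the lazy graph after every batch.
    model = model.to("hpu").eval()
    with torch.no_grad():
        for batch in dataloader:
            model(batch.to("hpu"))
            htcore.mark_step()  # execute the graph accumulated for this batch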
@@ -682,6 +686,9 @@ def tmp(_, inp, out):
                     W = load_value(self.model, full_layer_name + ".weight", model_path)
                 else:
                     W = sub_layers[layer_name].weight.data.clone()
+                accelerator.mark_step()
+                if "hpu" in self.device:
+                    W = W.to("cpu")
                 scale, zp, Q = gptq_for_this_block[layer_name].fasterquant(
                     W,
                     blocksize=weight_config_this_layer["block_size"],
@@ -854,6 +861,8 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
             self.quantizer.find_params(W, weight=True)

         H = self.H
+        if "hpu" in self.device:
+            H = H.to("cpu")
         del self.H
         dead = torch.diag(H) == 0
         H[dead, dead] = 1
@@ -958,6 +967,10 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
             zero.append(self.quantizer.zero)
         scale = torch.cat(scale, dim=1)
         zero = torch.cat(zero, dim=1)
+        if "hpu" in self.device:
+            scale = scale.to(self.device)
+            zero = zero.to(self.device)
+            Q = Q.to(self.device)
         return scale, zero, Q

     def free(self):
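
Taken together, the three hunks above implement a CPU round trip for the heavy GPTQ math on Gaudi: the weight `W` and the Hessian `H` are moved to CPU before `fasterquant`, and the resulting `scale`, `zero`, and `Q` are moved back to the HPU afterwards. A minimal sketch of the same round-trip shape, with a hypothetical `solve_fn` standing in for the Hessian-based solve (not the repository's API):

import torch

def quantize_on_cpu(W, H, device, solve_fn):
    # Hypothetical helper: run the dense GPTQ solve on CPU, return results on `device`.
    if "hpu" in device:
        W, H = W.to("cpu"), H.to("cpu")        # offload inputs before the linear algebra
    scale, zero, Q = solve_fn(W, H)            # e.g. the Cholesky-based fasterquant math
    if "hpu" in device:
        scale, zero, Q = scale.to(device), zero.to(device), Q.to(device)
    return scale, zero, Q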
@@ -973,25 +986,25 @@ def free(self):
 class Quantizer(nn.Module):
     def __init__(self, shape=1):
         super(Quantizer, self).__init__()
-        self.register_buffer("maxq", torch.tensor(0))
+        self.maxq = 0
         self.register_buffer("scale", torch.zeros(shape))
         self.register_buffer("zero", torch.zeros(shape))

     def configure(self, weight_config_this_layer, norm=2.4, grid=100, maxshrink=0.8, trits=False):
         for k, v in weight_config_this_layer.items():
             setattr(self, k, v)
-        self.maxq = torch.tensor(2**self.bits - 1)
+        # self.maxq = torch.tensor(2**self.bits - 1)
+        self.maxq = 2**self.bits - 1
         self.scheme = "sym" if self.sym else "asym"
         self.double_quant_scheme = "sym" if self.double_quant_sym else "asym"
         self.norm = norm
         self.grid = grid
         self.maxshrink = maxshrink
         if trits:
-            self.maxq = torch.tensor(-1)
+            self.maxq = -1

     def find_params(self, x, weight=False):
         dev = x.device
-        self.maxq = self.maxq.to(dev)
         # NF4 FP4
         if self.dtype != "int":
             from .utility import quant_tensor
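
The `Quantizer` change above turns `maxq` from a registered tensor buffer into a plain Python int, which removes the explicit `self.maxq.to(dev)` in `find_params`: an int broadcasts against tensors on any device. A toy illustration of the clamp step with an integer `maxq` (illustrative values only, not the repository's quantization path):

import torch

bits = 4
maxq = 2**bits - 1                                       # plain int, device agnostic
x = torch.randn(8)
scale = (x.max() - x.min()) / maxq
zero = torch.round(-x.min() / scale)
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)  # int bounds broadcast fine
dq = scale * (q - zero)                                  # dequantized approximation of x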

neural_compressor/torch/utils/auto_accelerator.py

Lines changed: 16 additions & 0 deletions
@@ -29,8 +29,11 @@

 import torch

+from neural_compressor.common.utils import LazyImport
 from neural_compressor.torch.utils import logger

+htcore = LazyImport("habana_frameworks.torch.core")
+
 PRIORITY_HPU = 100
 PRIORITY_CUDA = 95
 PRIORITY_CPU = 90
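
`habana_frameworks.torch.core` is wrapped in `LazyImport` so that auto_accelerator.py still imports cleanly on hosts without the Habana stack; the module is only resolved when `htcore.mark_step()` is actually called on an HPU. A rough sketch of the deferred-import idea (a simplified stand-in, not neural_compressor's `LazyImport` implementation):

import importlib

class _LazyModule:
    # Defer importing `name` until an attribute is first accessed.
    def __init__(self, name):
        self._name, self._module = name, None

    def __getattr__(self, attr):
        if self._module is None:
            self._module = importlib.import_module(self._name)  # import on first use
        return getattr(self._module, attr)

htcore = _LazyModule("habana_frameworks.torch.core")
# htcore.mark_step() would trigger the import only here, i.e. only on a Gaudi host.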
@@ -133,6 +136,10 @@ def empty_cache(self):
     def synchronize(self):
         pass

+    @abstractmethod
+    def mark_step(self):
+        pass
+

 @register_accelerator(name="cpu", priority=PRIORITY_CPU)
 class CPU_Accelerator(Auto_Accelerator):
@@ -167,6 +174,9 @@ def empty_cache(self):
     def synchronize(self):
         pass

+    def mark_step(self):
+        pass
+

 @register_accelerator(name="cuda", priority=PRIORITY_CUDA)
 class CUDA_Accelerator(Auto_Accelerator):
@@ -203,6 +213,9 @@ def device(self, device_index=None):
     def empty_cache(self):
         return torch.cuda.empty_cache()

+    def mark_step(self):
+        pass
+

 @register_accelerator(name="hpu", priority=PRIORITY_HPU)
 class HPU_Accelerator(Auto_Accelerator):
@@ -244,6 +257,9 @@ def device(self, device_index=None):
     def empty_cache(self):
         return torch.hpu.empty_cache()

+    def mark_step(self):
+        return htcore.mark_step()
+

 def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
     # Force use the cpu on node has both cpu and gpu: `FORCE_DEVICE=cpu` python main.py ...
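
With `mark_step()` part of the `Auto_Accelerator` interface, device-agnostic callers such as gptq.py can invoke it unconditionally: it is a no-op on CPU and CUDA and a lazy-graph flush on HPU. A minimal usage sketch (the `run_calibration` function and its arguments are illustrative, not part of the library):

from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

accelerator = auto_detect_accelerator()   # picks hpu > cuda > cpu by registration priority

def run_calibration(model, batches):
    # The identical loop runs on CPU, CUDA, or Gaudi.
    for batch in batches:
        model(batch)
        accelerator.mark_step()           # no-op on cpu/cuda, graph flush on hpu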
