Skip to content

Commit 1ed690c

Browse files
committed
[SW-187731] Save orig module as member of patched module
This allows direct usage of the original module methods, which solves a torch compile issue. Change-Id: I464d8bd1bacdfc3cd1f128a67114e1e43f092632
1 parent adfe13b commit 1ed690c

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def set_attrs_from_orig_model(cls_instance, mod, mod_extra_config, *func_names):
108108
cls_instance.class_name_org = mod.__class__.__name__
109109
cls_instance._mod_extra_config = mod_extra_config
110110
cls_instance.quantization_mode = config.cfg["mode"]
111+
# store original module in order to invoke its functions during measurements.
112+
# this may be omitted if torch removes the related validation from dynamo. see SW-187731.
113+
cls_instance.__dict__["orig_mod"] = mod
111114
cls_instance.forward_orig = mod.forward
112115
if func_names is not None:
113116
for func in func_names:
@@ -164,7 +167,7 @@ def forward(self, input, other):
164167

165168
def forward_measure(self, input, other):
166169
measure_input((input, other), observer=self._mod_extra_config.inputs)
167-
output = self.forward_orig(input, other)
170+
output = self.orig_mod(input, other)
168171
measure_output((output,), self._mod_extra_config.outputs)
169172
return output
170173

@@ -210,7 +213,7 @@ def forward_quant(self, input):
210213

211214
def forward_measure(self, input):
212215
measure_input((input,), observer=self._mod_extra_config.inputs)
213-
output = self.forward_orig(input)
216+
output = self.orig_mod(input)
214217
measure_output((output,), self._mod_extra_config.outputs)
215218
return output
216219

@@ -372,15 +375,15 @@ def forward(self, input):
372375
)
373376
dqoutput = self.quant_output(output)
374377
if self.gather_output:
375-
dqoutput = self.collective_func(dqoutput)
378+
dqoutput = self.orig_mod.collective_func(dqoutput)
376379
return self.post_all_reduce(dqoutput)
377380

378381
def forward_measure(self, input):
379382
measure_input((input,), observer=self._mod_extra_config.inputs)
380383
output = torch.matmul(input, self.weight.transpose(-1, -2))
381384
measure_output((output,), self._mod_extra_config.outputs)
382385
if self.gather_output:
383-
output = self.collective_func(output)
386+
output = self.orig_mod.collective_func(output)
384387
return self.post_all_reduce(output)
385388

386389
def post_all_reduce(self, output):
@@ -563,7 +566,7 @@ def forward(self, input):
563566

564567
def forward_measure(self, input):
565568
measure_input((input,), observer=self._mod_extra_config.inputs)
566-
output = self.forward_orig(input)
569+
output = self.orig_mod(input)
567570
measure_output((output,), self._mod_extra_config.outputs)
568571
return output
569572

@@ -593,7 +596,7 @@ def forward(self, x, dim=None, invAttnHead=None):
593596

594597
def forward_measure(self, x, dim=None, invAttnHead=None):
595598
measure_input((x,), observer=self._mod_extra_config.inputs)
596-
output = self.forward_orig(x, dim, invAttnHead)
599+
output = self.orig_mod(x, dim, invAttnHead)
597600
measure_output((output,), self._mod_extra_config.outputs)
598601
return output
599602

@@ -634,7 +637,7 @@ def forward(self, input, scale: float = 1.0):
634637

635638
def forward_measure(self, input, scale: float = 1.0):
636639
measure_input((input,), observer=self._mod_extra_config.inputs)
637-
output = self.forward_orig(input, scale)
640+
output = self.orig_mod(input, scale)
638641
measure_output((output,), self._mod_extra_config.outputs)
639642
return output
640643

@@ -682,7 +685,7 @@ def forward(self, input, scale: float = 1.0):
682685

683686
def forward_measure(self, input, scale: float = 1.0):
684687
measure_input((input,), observer=self._mod_extra_config.inputs)
685-
output = self.forward_orig(input, scale)
688+
output = self.orig_mod(input, scale)
686689
measure_output((output,), self._mod_extra_config.outputs)
687690
return output
688691

0 commit comments

Comments
 (0)