@@ -108,6 +108,9 @@ def set_attrs_from_orig_model(cls_instance, mod, mod_extra_config, *func_names):
108
108
cls_instance .class_name_org = mod .__class__ .__name__
109
109
cls_instance ._mod_extra_config = mod_extra_config
110
110
cls_instance .quantization_mode = config .cfg ["mode" ]
111
+ # store original module in order to invoke its functions during measurements.
112
+ # this may be omitted of torch remove the related validation from dynamo. see SW-187731.
113
+ cls_instance .__dict__ ["orig_mod" ] = mod
111
114
cls_instance .forward_orig = mod .forward
112
115
if func_names is not None :
113
116
for func in func_names :
@@ -164,7 +167,7 @@ def forward(self, input, other):
164
167
165
168
def forward_measure (self , input , other ):
166
169
measure_input ((input , other ), observer = self ._mod_extra_config .inputs )
167
- output = self .forward_orig (input , other )
170
+ output = self .orig_mod (input , other )
168
171
measure_output ((output ,), self ._mod_extra_config .outputs )
169
172
return output
170
173
@@ -210,7 +213,7 @@ def forward_quant(self, input):
210
213
211
214
def forward_measure (self , input ):
212
215
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
213
- output = self .forward_orig (input )
216
+ output = self .orig_mod (input )
214
217
measure_output ((output ,), self ._mod_extra_config .outputs )
215
218
return output
216
219
@@ -372,15 +375,15 @@ def forward(self, input):
372
375
)
373
376
dqoutput = self .quant_output (output )
374
377
if self .gather_output :
375
- dqoutput = self .collective_func (dqoutput )
378
+ dqoutput = self .orig_mod . collective_func (dqoutput )
376
379
return self .post_all_reduce (dqoutput )
377
380
378
381
def forward_measure (self , input ):
379
382
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
380
383
output = torch .matmul (input , self .weight .transpose (- 1 , - 2 ))
381
384
measure_output ((output ,), self ._mod_extra_config .outputs )
382
385
if self .gather_output :
383
- output = self .collective_func (output )
386
+ output = self .orig_mod . collective_func (output )
384
387
return self .post_all_reduce (output )
385
388
386
389
def post_all_reduce (self , output ):
@@ -563,7 +566,7 @@ def forward(self, input):
563
566
564
567
def forward_measure (self , input ):
565
568
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
566
- output = self .forward_orig (input )
569
+ output = self .orig_mod (input )
567
570
measure_output ((output ,), self ._mod_extra_config .outputs )
568
571
return output
569
572
@@ -593,7 +596,7 @@ def forward(self, x, dim=None, invAttnHead=None):
593
596
594
597
def forward_measure (self , x , dim = None , invAttnHead = None ):
595
598
measure_input ((x ,), observer = self ._mod_extra_config .inputs )
596
- output = self .forward_orig (x , dim , invAttnHead )
599
+ output = self .orig_mod (x , dim , invAttnHead )
597
600
measure_output ((output ,), self ._mod_extra_config .outputs )
598
601
return output
599
602
@@ -634,7 +637,7 @@ def forward(self, input, scale: float = 1.0):
634
637
635
638
def forward_measure (self , input , scale : float = 1.0 ):
636
639
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
637
- output = self .forward_orig (input , scale )
640
+ output = self .orig_mod (input , scale )
638
641
measure_output ((output ,), self ._mod_extra_config .outputs )
639
642
return output
640
643
@@ -682,7 +685,7 @@ def forward(self, input, scale: float = 1.0):
682
685
683
686
def forward_measure (self , input , scale : float = 1.0 ):
684
687
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
685
- output = self .forward_orig (input , scale )
688
+ output = self .orig_mod (input , scale )
686
689
measure_output ((output ,), self ._mod_extra_config .outputs )
687
690
return output
688
691
0 commit comments