Skip to content

Commit 063c936

Browse files
jerryzh168pytorchmergebot
authored andcommitted
[quant] follow up fixes for prepare_fx/prepare_qat_fx calls in classyvision (#105) (pytorch#78660)
Summary: X-link: https://github.com/fairinternal/ClassyVision/pull/105 As follow up for pytorch#76496, we fixes the TODOs in quantization tests by providing correct example_inputs in the tests Test Plan: classyvision sandcastle and ossci **Static Docs Preview: classyvision** |[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D36818665/V1/classyvision/)| |**Modified Pages**| Differential Revision: D36818665 Pull Request resolved: pytorch#78660 Approved by: https://github.com/vkuzo
1 parent ad1bff1 commit 063c936

File tree

1 file changed

+23
-29
lines changed

1 file changed

+23
-29
lines changed

torch/ao/quantization/utils.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -495,35 +495,29 @@ def forward(self, x, key1=3, key2=3):
495495
root = model
496496
fqn_to_example_inputs = {}
497497

498-
class InterceptionModule(type(model)): # type: ignore[misc]
499-
def __call__(self, *args, **kwargs):
500-
orig_module_call = torch.nn.Module.__call__
501-
502-
def _patched_module_call(self, *args, **kwargs):
503-
submodule_example_inputs = list(args).copy()
504-
normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
505-
# minus 1 to skipping counting `self`
506-
num_args = _get_num_pos_args(self.forward) - 1
507-
num_to_pop = num_args - len(submodule_example_inputs)
508-
while num_to_pop and normalized_kwargs:
509-
normalized_kwargs.popitem(last=False)
510-
num_to_pop -= 1
511-
submodule_example_inputs.extend(normalized_kwargs.values())
512-
submodule_example_inputs_tuple = tuple(submodule_example_inputs)
513-
fqn = _get_path_of_module(root, self)
514-
if fqn is not None:
515-
fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
516-
return orig_module_call(self, *args, **kwargs)
517-
518-
torch.nn.Module.__call__ = _patched_module_call
519-
super().__call__(*args, **kwargs)
520-
torch.nn.Module.__call__ = orig_module_call
521-
522-
original_class = model.__class__
523-
model.__class__ = InterceptionModule
524-
model(*example_inputs)
525-
model.__class__ = original_class
526-
498+
def _patched_module_call(self, *args, **kwargs):
499+
submodule_example_inputs = list(args).copy()
500+
normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
501+
# minus 1 to skipping counting `self`
502+
num_args = _get_num_pos_args(self.forward) - 1
503+
num_to_pop = num_args - len(submodule_example_inputs)
504+
while num_to_pop and normalized_kwargs:
505+
normalized_kwargs.popitem(last=False)
506+
num_to_pop -= 1
507+
submodule_example_inputs.extend(normalized_kwargs.values())
508+
submodule_example_inputs_tuple = tuple(submodule_example_inputs)
509+
fqn = _get_path_of_module(root, self)
510+
if fqn is not None:
511+
fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
512+
return orig_module_call(self, *args, **kwargs)
513+
514+
orig_module_call = torch.nn.Module.__call__
515+
torch.nn.Module.__call__ = _patched_module_call
516+
try:
517+
model(*example_inputs)
518+
finally:
519+
# restore the module call even if there is an exception
520+
torch.nn.Module.__call__ = orig_module_call
527521
return fqn_to_example_inputs
528522

529523

0 commit comments

Comments
 (0)