@@ -495,35 +495,29 @@ def forward(self, x, key1=3, key2=3):
495
495
root = model
496
496
fqn_to_example_inputs = {}
497
497
498
- class InterceptionModule (type (model )): # type: ignore[misc]
499
- def __call__ (self , * args , ** kwargs ):
500
- orig_module_call = torch .nn .Module .__call__
501
-
502
- def _patched_module_call (self , * args , ** kwargs ):
503
- submodule_example_inputs = list (args ).copy ()
504
- normalized_kwargs = _normalize_kwargs (self .forward , kwargs )
505
- # minus 1 to skipping counting `self`
506
- num_args = _get_num_pos_args (self .forward ) - 1
507
- num_to_pop = num_args - len (submodule_example_inputs )
508
- while num_to_pop and normalized_kwargs :
509
- normalized_kwargs .popitem (last = False )
510
- num_to_pop -= 1
511
- submodule_example_inputs .extend (normalized_kwargs .values ())
512
- submodule_example_inputs_tuple = tuple (submodule_example_inputs )
513
- fqn = _get_path_of_module (root , self )
514
- if fqn is not None :
515
- fqn_to_example_inputs [fqn ] = submodule_example_inputs_tuple
516
- return orig_module_call (self , * args , ** kwargs )
517
-
518
- torch .nn .Module .__call__ = _patched_module_call
519
- super ().__call__ (* args , ** kwargs )
520
- torch .nn .Module .__call__ = orig_module_call
521
-
522
- original_class = model .__class__
523
- model .__class__ = InterceptionModule
524
- model (* example_inputs )
525
- model .__class__ = original_class
526
-
498
+ def _patched_module_call (self , * args , ** kwargs ):
499
+ submodule_example_inputs = list (args ).copy ()
500
+ normalized_kwargs = _normalize_kwargs (self .forward , kwargs )
501
+ # minus 1 to skipping counting `self`
502
+ num_args = _get_num_pos_args (self .forward ) - 1
503
+ num_to_pop = num_args - len (submodule_example_inputs )
504
+ while num_to_pop and normalized_kwargs :
505
+ normalized_kwargs .popitem (last = False )
506
+ num_to_pop -= 1
507
+ submodule_example_inputs .extend (normalized_kwargs .values ())
508
+ submodule_example_inputs_tuple = tuple (submodule_example_inputs )
509
+ fqn = _get_path_of_module (root , self )
510
+ if fqn is not None :
511
+ fqn_to_example_inputs [fqn ] = submodule_example_inputs_tuple
512
+ return orig_module_call (self , * args , ** kwargs )
513
+
514
+ orig_module_call = torch .nn .Module .__call__
515
+ torch .nn .Module .__call__ = _patched_module_call
516
+ try :
517
+ model (* example_inputs )
518
+ finally :
519
+ # restore the module call even if there is an exception
520
+ torch .nn .Module .__call__ = orig_module_call
527
521
return fqn_to_example_inputs
528
522
529
523
0 commit comments