From 395c0531b43aeb783a662835eb630f8007a74351 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Tue, 23 Mar 2021 17:03:49 +0100
Subject: [PATCH 1/4] Use FX to have a more robust intermediate feature extraction

---
 test/test_backbone_utils.py  |  47 +++++++++++++++-
 torchvision/models/_utils.py | 103 ++++++++++++++++++++++++++++++++++-
 2 files changed, 146 insertions(+), 4 deletions(-)

diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py
index 7ee1aed1459..a60802c5de2 100644
--- a/test/test_backbone_utils.py
+++ b/test/test_backbone_utils.py
@@ -1,9 +1,11 @@
 import unittest
 
-
 import torch
+import torchvision
 from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
 
+from torchvision.models._utils import IntermediateLayerGetter, IntermediateLayerGetter2
+
 
 class ResnetFPNBackboneTester(unittest.TestCase):
     @classmethod
@@ -23,3 +25,46 @@ def test_resnet50_fpn_backbone(self):
         resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False)
         y = resnet50_fpn(x)
         self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
+
+
+class IntermediateLayerGetterTester(unittest.TestCase):
+    def test_old_new_match(self):
+        model = torchvision.models.resnet18(pretrained=False)
+
+        return_layers = {'layer2': '5', 'layer4': 'pool'}
+
+        old_model = IntermediateLayerGetter2(model, return_layers).eval()
+        new_model = IntermediateLayerGetter(model, return_layers).eval()
+
+        # check that we have same parameters
+        for (n1, p1), (n2, p2) in zip(old_model.named_parameters(), new_model.named_parameters()):
+            self.assertEqual(n1, n2)
+            self.assertTrue(p1.equal(p2))
+
+        # and state_dict matches
+        for (n1, p1), (n2, p2) in zip(old_model.state_dict().items(), new_model.state_dict().items()):
+            self.assertEqual(n1, n2)
+            self.assertTrue(p1.equal(p2))
+
+        # check that we actually compute the same thing
+        x = torch.rand(2, 3, 224, 224)
+        old_out = old_model(x)
+        new_out = new_model(x)
+        self.assertEqual(old_out.keys(), new_out.keys())
+        for k in old_out.keys():
+            o1 = old_out[k]
+            o2 = new_out[k]
+            self.assertTrue(o1.equal(o2))
+
+        # check torchscriptability
+        script_new_model = torch.jit.script(new_model)
+        new_out_script = script_new_model(x)
+        self.assertEqual(old_out.keys(), new_out_script.keys())
+        for k in old_out.keys():
+            o1 = old_out[k]
+            o2 = new_out_script[k]
+            self.assertTrue(o1.equal(o2))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py
index c8faf12786c..e67e9dbd0c1 100644
--- a/torchvision/models/_utils.py
+++ b/torchvision/models/_utils.py
@@ -1,11 +1,12 @@
 from collections import OrderedDict
 
 import torch
+import torch.fx
 from torch import nn
-from typing import Dict
+from typing import Dict, Any, Callable, Tuple, Optional
 
 
-class IntermediateLayerGetter(nn.ModuleDict):
+class IntermediateLayerGetter2(nn.ModuleDict):
     """
     Module wrapper that returns intermediate layers from a model
 
@@ -54,7 +55,7 @@ def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
             if not return_layers:
                 break
 
-        super(IntermediateLayerGetter, self).__init__(layers)
+        super().__init__(layers)
         self.return_layers = orig_return_layers
 
     def forward(self, x):
@@ -65,3 +66,99 @@ def forward(self, x):
                 out_name = self.return_layers[name]
                 out[out_name] = x
         return out
+
+
+# taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py
+# with slight modifications
+class ModulePathTracer(torch.fx.Tracer):
+    """
+    ModulePathTracer is an FX tracer that--for each operation--also records
+    the qualified name of the Module from which the operation originated.
+    """
+
+    # The current qualified name of the Module being traced. The top-level
+    # module is signified by empty string. This is updated when entering
+    # call_module and restored when exiting call_module
+    current_module_qualified_name : str = ''
+    # A map from FX Node to the qualname of the Module from which it
+    # originated. This is recorded by `create_proxy` when recording an
+    # operation
+    node_to_originating_module : Dict[torch.fx.Node, str] = {}
+
+    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any],
+                    args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
+        """
+        Override of Tracer.call_module (see
+        https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module).
+        This override:
+        1) Stores away the qualified name of the caller for restoration later
+        2) Installs the qualified name of the caller in `current_module_qualified_name`
+           for retrieval by `create_proxy`
+        3) Delegates into the normal Tracer.call_module method
+        4) Restores the caller's qualified name into current_module_qualified_name
+        """
+        old_qualname = self.current_module_qualified_name
+        try:
+            module_qualified_name = self.path_of_module(m)
+            self.current_module_qualified_name = module_qualified_name
+            if not self.is_leaf_module(m, module_qualified_name):
+                out = forward(*args, **kwargs)
+                self.node_to_originating_module[out.node] = module_qualified_name
+                return out
+            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+        finally:
+            self.current_module_qualified_name = old_qualname
+
+    def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...],
+                     kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None):
+        """
+        Override of `Tracer.create_proxy`. This override intercepts the recording
+        of every operation and stores away the current traced module's qualified
+        name in `node_to_originating_module`
+        """
+        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
+        self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
+        return proxy
+
+
+def IntermediateLayerGetter(model: nn.Module, return_layers: Dict[str, str]) -> nn.Module:
+    return_layers = {str(k): str(v) for k, v in return_layers.items()}
+
+    # Instantiate our ModulePathTracer and use that to trace the model
+    tracer = ModulePathTracer()
+    graph = tracer.trace(model)
+
+    name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
+    m = torch.fx.GraphModule(tracer.root, graph, name)
+
+    # Get output node
+    orig_output_node: Optional[torch.fx.Node] = None
+    for n in reversed(m.graph.nodes):
+        if n.op == "output":
+            orig_output_node = n
+            break
+    assert orig_output_node
+    # and remove it
+    m.graph.erase_node(orig_output_node)
+
+    # find output nodes corresponding to return_layers
+    nodes = [n for n in m.graph.nodes]
+    output_node = OrderedDict()
+    for n in nodes:
+        module_qualname = tracer.node_to_originating_module.get(n)
+        if module_qualname in return_layers:
+            output_node[return_layers[module_qualname]] = n
+
+    # TODO raise error if some of return layers don't exist
+    # TODO have duplicate nodes but full coverage for module names
+
+    # and add them in the end of the graph
+    with m.graph.inserting_after(nodes[-1]):
+        m.graph.output(output_node)
+
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+    # remove unused modules / parameters
+    m = torch.fx.GraphModule(m, m.graph)
+    return m

From 2cd8e848655ff44783c2bd950439e727752f1782 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Wed, 14 Apr 2021 13:52:23 +0200
Subject: [PATCH 2/4] Raise error if requested output is not present

---
 test/test_backbone_utils.py  |  4 ++++
 torchvision/models/_utils.py | 13 +++++++++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py
index a60802c5de2..aede206f031 100644
--- a/test/test_backbone_utils.py
+++ b/test/test_backbone_utils.py
@@ -65,6 +65,10 @@ def test_old_new_match(self):
             o2 = new_out_script[k]
             self.assertTrue(o1.equal(o2))
 
+        # check assert that non-existing keys raise error
+        with self.assertRaises(ValueError):
+            _ = IntermediateLayerGetter(model, {'layer5': '0'})
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py
index e67e9dbd0c1..07bad38c346 100644
--- a/torchvision/models/_utils.py
+++ b/torchvision/models/_utils.py
@@ -122,6 +122,8 @@ def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any,
 
 
 def IntermediateLayerGetter(model: nn.Module, return_layers: Dict[str, str]) -> nn.Module:
+    # TODO come up with a better name for this
+    # TODO have duplicate nodes but full coverage for module names
     return_layers = {str(k): str(v) for k, v in return_layers.items()}
 
     # Instantiate our ModulePathTracer and use that to trace the model
@@ -131,6 +133,12 @@ def IntermediateLayerGetter(model: nn.Module, return_layers: Dict[str, str]) ->
     name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
     m = torch.fx.GraphModule(tracer.root, graph, name)
 
+
+    # check that all outputs in return_layers are present in the model
+    if not set(return_layers).issubset(tracer.node_to_originating_module.values()):
+        raise ValueError("return_layers are not present in model")
+
+
     # Get output node
     orig_output_node: Optional[torch.fx.Node] = None
     for n in reversed(m.graph.nodes):
@@ -149,9 +157,6 @@ def IntermediateLayerGetter(model: nn.Module, return_layers: Dict[str, str]) ->
         if module_qualname in return_layers:
             output_node[return_layers[module_qualname]] = n
 
-    # TODO raise error if some of return layers don't exist
-    # TODO have duplicate nodes but full coverage for module names
-
     # and add them in the end of the graph
     with m.graph.inserting_after(nodes[-1]):
         m.graph.output(output_node)
@@ -160,5 +165,5 @@ def IntermediateLayerGetter(model: nn.Module, return_layers: Dict[str, str]) ->
     m.recompile()
 
     # remove unused modules / parameters
-    m = torch.fx.GraphModule(m, m.graph)
+    m = torch.fx.GraphModule(m, m.graph, name)
     return m

From 6187770dfdd9884f47aa44edb52bcf1c754823d7 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Wed, 14 Apr 2021 13:55:12 +0200
Subject: [PATCH 3/4] Rename to not replace old impl for now

---
 test/test_backbone_utils.py  | 6 +++---
 torchvision/models/_utils.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py
index aede206f031..7c21a50f70f 100644
--- a/test/test_backbone_utils.py
+++ b/test/test_backbone_utils.py
@@ -33,8 +33,8 @@ def test_old_new_match(self):
 
         return_layers = {'layer2': '5', 'layer4': 'pool'}
 
-        old_model = IntermediateLayerGetter2(model, return_layers).eval()
-        new_model = IntermediateLayerGetter(model, return_layers).eval()
+        old_model = IntermediateLayerGetter(model, return_layers).eval()
+        new_model = IntermediateLayerGetter2(model, return_layers).eval()
 
         # check that we have same parameters
         for (n1, p1), (n2, p2) in zip(old_model.named_parameters(), new_model.named_parameters()):
@@ -67,7 +67,7 @@ def test_old_new_match(self):
 
         # check assert that non-existing keys raise error
         with self.assertRaises(ValueError):
-            _ = IntermediateLayerGetter(model, {'layer5': '0'})
+            _ = IntermediateLayerGetter2(model, {'layer5': '0'})
 
 
 if __name__ == "__main__":
diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py
index 07bad38c346..feb41c4c7ce 100644
--- a/torchvision/models/_utils.py
+++ b/torchvision/models/_utils.py
@@ -6,7 +6,7 @@
 from typing import Dict, Any, Callable, Tuple, Optional
 
 
-class IntermediateLayerGetter2(nn.ModuleDict):
+class IntermediateLayerGetter(nn.ModuleDict):
     """
     Module wrapper that returns intermediate layers from a model
 
@@ -121,7 +121,7 @@ def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any,
         return proxy
 
 
-def IntermediateLayerGetter(model: nn.Module, return_layers: Dict[str, str]) -> nn.Module:
+def IntermediateLayerGetter2(model: nn.Module, return_layers: Dict[str, str]) -> nn.Module:
     # TODO come up with a better name for this
     # TODO have duplicate nodes but full coverage for module names
     return_layers = {str(k): str(v) for k, v in return_layers.items()}

From d72f89ec082bf2981687d39a7def86de67e0fa04 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Sat, 15 May 2021 11:03:43 +0200
Subject: [PATCH 4/4] Dump commit to work on something else

---
 test/test_backbone_utils.py  |  6 +++---
 torchvision/models/_utils.py | 29 ++++++++++++++++++++++++++---
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py
index 7c21a50f70f..1acf3c25ad2 100644
--- a/test/test_backbone_utils.py
+++ b/test/test_backbone_utils.py
@@ -4,7 +4,7 @@
 import torchvision
 from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
 
-from torchvision.models._utils import IntermediateLayerGetter, IntermediateLayerGetter2
+from torchvision.models._utils import IntermediateLayerGetter, get_intermediate_layers
 
 
 class ResnetFPNBackboneTester(unittest.TestCase):
@@ -34,7 +34,7 @@ def test_old_new_match(self):
         return_layers = {'layer2': '5', 'layer4': 'pool'}
 
         old_model = IntermediateLayerGetter(model, return_layers).eval()
-        new_model = IntermediateLayerGetter2(model, return_layers).eval()
+        new_model = get_intermediate_layers(model, return_layers).eval()
 
         # check that we have same parameters
         for (n1, p1), (n2, p2) in zip(old_model.named_parameters(), new_model.named_parameters()):
@@ -67,7 +67,7 @@ def test_old_new_match(self):
 
         # check assert that non-existing keys raise error
         with self.assertRaises(ValueError):
-            _ = IntermediateLayerGetter2(model, {'layer5': '0'})
+            _ = get_intermediate_layers(model, {'layer5': '0'})
 
 
 if __name__ == "__main__":
diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py
index e303607b9cb..ee5519870a4 100644
--- a/torchvision/models/_utils.py
+++ b/torchvision/models/_utils.py
@@ -122,7 +122,32 @@ def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any,
     return proxy
 
 
-def IntermediateLayerGetter2(model: nn.Module, return_layers: Dict[str, str]) -> nn.Module:
+def get_intermediate_layers(model: nn.Module, return_layers: Dict[str, str]) -> nn.Module:
+    """
+    Creates a new FX-based module that returns intermediate layers from a given model.
+    This is achieved by re-writing the computation graph of the model via FX to return
+    the requested layers.
+
+    All unused layers are removed, together with their corresponding parameters.
+
+    Args:
+        model (nn.Module): model on which we will extract the features
+        return_layers (Dict[name, new_name]): a dict containing the names
+            of the modules for which the activations will be returned as
+            the key of the dict, and the value of the dict is the name
+            of the returned activation (which the user can specify).
+
+    Examples::
+
+        >>> m = torchvision.models.resnet18(pretrained=True)
+        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
+        >>> new_m = torchvision.models._utils.get_intermediate_layers(m,
+        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
+        >>> out = new_m(torch.rand(1, 3, 224, 224))
+        >>> print([(k, v.shape) for k, v in out.items()])
+        >>> [('feat1', torch.Size([1, 64, 56, 56])),
+        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
+    """
     # TODO come up with a better name for this
     # TODO have duplicate nodes but full coverage for module names
     return_layers = {str(k): str(v) for k, v in return_layers.items()}
@@ -134,12 +159,10 @@ def get_intermediate_layers(model: nn.Module, return_layers: Dict[str, str]) ->
     name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
     m = torch.fx.GraphModule(tracer.root, graph, name)
 
-
     # check that all outputs in return_layers are present in the model
     if not set(return_layers).issubset(tracer.node_to_originating_module.values()):
         raise ValueError("return_layers are not present in model")
 
-
     # Get output node
     orig_output_node: Optional[torch.fx.Node] = None
     for n in reversed(m.graph.nodes):
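
Usage sketch (illustrative, not part of the patch series): the snippet below assumes the four patches above are applied, so that get_intermediate_layers is importable from torchvision.models._utils as in the final patch. The chosen layers ('layer1', 'layer3') and the output names ('feat1', 'feat2') are example values taken from the docstring, not anything the patches require.

    import torch
    import torchvision
    from torchvision.models._utils import get_intermediate_layers

    # Rewrite the graph of a resnet18 so that forward() returns the requested
    # intermediate activations as a dict keyed by the user-chosen names.
    model = torchvision.models.resnet18(pretrained=False)
    extractor = get_intermediate_layers(model, {'layer1': 'feat1', 'layer3': 'feat2'}).eval()

    out = extractor(torch.rand(1, 3, 224, 224))
    print([(k, v.shape) for k, v in out.items()])
    # per the docstring example:
    # [('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]

    # Since patch 2, requesting a module that does not exist raises ValueError.
    try:
        get_intermediate_layers(model, {'layer5': 'feat'})
    except ValueError as exc:
        print('rejected:', exc)

    # The rewritten module remains torchscriptable, as exercised by the test.
    scripted = torch.jit.script(extractor)
    scripted_out = scripted(torch.rand(1, 3, 224, 224))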