
Use FX to have a more robust intermediate feature extraction #3597


Closed
wants to merge 6 commits
51 changes: 50 additions & 1 deletion test/test_backbone_utils.py
@@ -1,9 +1,11 @@
import unittest


import torch
import torchvision
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

from torchvision.models._utils import IntermediateLayerGetter, get_intermediate_layers


class ResnetFPNBackboneTester(unittest.TestCase):
@classmethod
@@ -23,3 +25,50 @@ def test_resnet50_fpn_backbone(self):
resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False)
y = resnet50_fpn(x)
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])


class IntermediateLayerGetterTester(unittest.TestCase):
def test_old_new_match(self):
model = torchvision.models.resnet18(pretrained=False)

return_layers = {'layer2': '5', 'layer4': 'pool'}
Contributor

FYI, it fails when we include the final output in the return layers:

Suggested change
return_layers = {'layer2': '5', 'layer4': 'pool'}
return_layers = {'layer2': '5', 'layer4': 'pool', 'fc': 'fc1'}

with:

E       RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x1 and 512x1000)

Member Author

Thanks for the catch! I need to check more carefully, but the old implementation doesn't work in this case because of the torch.flatten call (which is not an nn.Module). I believe this should work with the new implementation; to be verified.
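
A minimal sketch of the failure mode described above (illustrative only, not part of the patch; it only uses names that already appear in this diff):

    import torch
    import torchvision
    from torchvision.models._utils import IntermediateLayerGetter

    # IntermediateLayerGetter replays the model's direct children in order, so
    # the functional torch.flatten(x, 1) call in ResNet.forward is never
    # executed and `fc` receives a (N, 512, 1, 1) tensor instead of (N, 512).
    model = torchvision.models.resnet18(pretrained=False)
    getter = IntermediateLayerGetter(model, {'fc': 'fc1'}).eval()
    try:
        getter(torch.rand(2, 3, 224, 224))
    except RuntimeError as e:
        print(e)  # e.g. "mat1 and mat2 shapes cannot be multiplied (1024x1 and 512x1000)"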


old_model = IntermediateLayerGetter(model, return_layers).eval()
new_model = get_intermediate_layers(model, return_layers).eval()

# check that we have the same parameters
for (n1, p1), (n2, p2) in zip(old_model.named_parameters(), new_model.named_parameters()):
self.assertEqual(n1, n2)
self.assertTrue(p1.equal(p2))

# and state_dict matches
for (n1, p1), (n2, p2) in zip(old_model.state_dict().items(), new_model.state_dict().items()):
self.assertEqual(n1, n2)
self.assertTrue(p1.equal(p2))

# check that we actually compute the same thing
x = torch.rand(2, 3, 224, 224)
old_out = old_model(x)
new_out = new_model(x)
self.assertEqual(old_out.keys(), new_out.keys())
for k in old_out.keys():
o1 = old_out[k]
o2 = new_out[k]
self.assertTrue(o1.equal(o2))

# check torchscriptability
script_new_model = torch.jit.script(new_model)
new_out_script = script_new_model(x)
self.assertEqual(old_out.keys(), new_out_script.keys())
for k in old_out.keys():
o1 = old_out[k]
o2 = new_out_script[k]
self.assertTrue(o1.equal(o2))

# check that non-existing keys raise an error
with self.assertRaises(ValueError):
_ = get_intermediate_layers(model, {'layer5': '0'})


if __name__ == "__main__":
unittest.main()
131 changes: 129 additions & 2 deletions torchvision/models/_utils.py
@@ -1,7 +1,10 @@
from collections import OrderedDict

import torch
import torch.fx

from torch import nn
from typing import Dict
from typing import Dict, Any, Callable, Tuple, Optional


class IntermediateLayerGetter(nn.ModuleDict):
@@ -53,7 +56,7 @@ def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
if not return_layers:
break

super(IntermediateLayerGetter, self).__init__(layers)
super().__init__(layers)
self.return_layers = orig_return_layers

def forward(self, x):
@@ -64,3 +67,127 @@ def forward(self, x):
out_name = self.return_layers[name]
out[out_name] = x
return out


# taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py
# with slight modifications
class ModulePathTracer(torch.fx.Tracer):
"""
ModulePathTracer is an FX tracer that--for each operation--also records
the qualified name of the Module from which the operation originated.
"""

# The current qualified name of the Module being traced. The top-level
# module is signified by empty string. This is updated when entering
# call_module and restored when exiting call_module
current_module_qualified_name : str = ''
# A map from FX Node to the qualname of the Module from which it
# originated. This is recorded by `create_proxy` when recording an
# operation
node_to_originating_module : Dict[torch.fx.Node, str] = {}

def call_module(self, m: torch.nn.Module, forward: Callable[..., Any],
args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
"""
Override of Tracer.call_module (see
https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module).
This override:
1) Stores away the qualified name of the caller for restoration later
2) Installs the qualified name of the caller in `current_module_qualified_name`
for retrieval by `create_proxy`
3) Delegates into the normal Tracer.call_module method
4) Restores the caller's qualified name into current_module_qualified_name
"""
old_qualname = self.current_module_qualified_name
try:
module_qualified_name = self.path_of_module(m)
self.current_module_qualified_name = module_qualified_name
if not self.is_leaf_module(m, module_qualified_name):
out = forward(*args, **kwargs)
self.node_to_originating_module[out.node] = module_qualified_name
return out
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
finally:
self.current_module_qualified_name = old_qualname

def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...],
kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None):
"""
Override of `Tracer.create_proxy`. This override intercepts the recording
of every operation and stores away the current traced module's qualified
name in `node_to_originating_module`
"""
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
return proxy


def get_intermediate_layers(model: nn.Module, return_layers: Dict[str, str]) -> nn.Module:
"""
Creates a new FX-based module that returns intermediate layers from a given model.
This is achieved by re-writing the computation graph of the model via FX to return
the requested layers.

All unused layers are removed, together with their corresponding parameters.

Args:
model (nn.Module): model from which we will extract the features
return_layers (Dict[name, new_name]): a dict whose keys are the names
of the modules for which the activations will be returned, and whose
values are the (user-specified) names given to the returned activations.

Examples::

>>> m = torchvision.models.resnet18(pretrained=True)
>>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
>>> new_m = torchvision.models._utils.get_intermediate_layers(m,
>>> {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = new_m(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
"""
# TODO come up with a better name for this
Contributor

I think return_layers is fine. I understand that it remaps the names, but it still stores the mapping of the returned layers.

# TODO have duplicate nodes but full coverage for module names
return_layers = {str(k): str(v) for k, v in return_layers.items()}

# Instantiate our ModulePathTracer and use that to trace the model
tracer = ModulePathTracer()
graph = tracer.trace(model)

name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
m = torch.fx.GraphModule(tracer.root, graph, name)

# check that all outputs in return_layers are present in the model
if not set(return_layers).issubset(tracer.node_to_originating_module.values()):
raise ValueError("return_layers are not present in model")

# Get output node
orig_output_node: Optional[torch.fx.Node] = None
for n in reversed(m.graph.nodes):
if n.op == "output":
Contributor

What happens in cases where we have multiple outputs (for example Inception3, which has auxiliary outputs)? It seems that FX has another node called inception_outputs:

>>> list(reversed(m.graph.nodes))
[output, inception_outputs, fc, flatten_1, dropout, ....]

You can see this by replacing the input in your test with:

        model = torchvision.models.inception_v3(pretrained=False)
        return_layers = {'Mixed_7c': '0', 'avgpool': '1'}

orig_output_node = n
break
assert orig_output_node
# and remove it
m.graph.erase_node(orig_output_node)

# find output nodes corresponding to return_layers
nodes = [n for n in m.graph.nodes]
output_node = OrderedDict()
for n in nodes:
module_qualname = tracer.node_to_originating_module.get(n)
if module_qualname in return_layers:
output_node[return_layers[module_qualname]] = n

# and add them at the end of the graph
with m.graph.inserting_after(nodes[-1]):
m.graph.output(output_node)

m.graph.eliminate_dead_code()
m.recompile()

# remove unused modules / parameters
m = torch.fx.GraphModule(m, m.graph, name)
return m
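
A usage sketch of the two pieces added above (illustrative only; it assumes this PR's branch is installed so that ModulePathTracer and get_intermediate_layers are importable from torchvision.models._utils):

    import torch
    import torchvision
    from torchvision.models._utils import ModulePathTracer, get_intermediate_layers

    model = torchvision.models.resnet18(pretrained=False)

    # Every traced node is mapped back to the qualified name of the module it
    # originated from; the top-level module is recorded as the empty string.
    tracer = ModulePathTracer()
    graph = tracer.trace(model)
    for node in graph.nodes:
        print(node.name, '->', tracer.node_to_originating_module.get(node, ''))

    # get_intermediate_layers keys the requested outputs by those qualified names.
    body = get_intermediate_layers(model, {'layer1': 'feat1', 'layer3': 'feat2'})
    out = body(torch.rand(1, 3, 224, 224))
    print({k: v.shape for k, v in out.items()})
    # expected: {'feat1': (1, 64, 56, 56), 'feat2': (1, 256, 14, 14)}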