Use FX to have a more robust intermediate feature extraction #3597
@@ -1,7 +1,10 @@
from collections import OrderedDict

import torch
+import torch.fx

from torch import nn
-from typing import Dict
+from typing import Dict, Any, Callable, Tuple, Optional


class IntermediateLayerGetter(nn.ModuleDict):
@@ -53,7 +56,7 @@ def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
            if not return_layers:
                break

-        super(IntermediateLayerGetter, self).__init__(layers)
+        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
@@ -64,3 +67,127 @@ def forward(self, x):
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

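# Illustrative usage of the existing IntermediateLayerGetter above (a sketch, not part
# of the diff), for contrast with the FX-based approach added below. It can only return
# the outputs of modules that are direct children of the model:
#
#   import torch, torchvision
#   backbone = torchvision.models.resnet18()
#   getter = IntermediateLayerGetter(backbone, return_layers={'layer1': 'feat1', 'layer3': 'feat2'})
#   feats = getter(torch.rand(1, 3, 224, 224))  # OrderedDict with keys 'feat1' and 'feat2'
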
# taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py
# with slight modifications
class ModulePathTracer(torch.fx.Tracer):
    """
    ModulePathTracer is an FX tracer that--for each operation--also records
    the qualified name of the Module from which the operation originated.
    """

    # The current qualified name of the Module being traced. The top-level
    # module is signified by empty string. This is updated when entering
    # call_module and restored when exiting call_module
    current_module_qualified_name : str = ''
    # A map from FX Node to the qualname of the Module from which it
    # originated. This is recorded by `create_proxy` when recording an
    # operation
    node_to_originating_module : Dict[torch.fx.Node, str] = {}

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any],
                    args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Override of Tracer.call_module (see
        https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.call_module).
        This override:
        1) Stores away the qualified name of the caller for restoration later
        2) Installs the qualified name of the caller in `current_module_qualified_name`
           for retrieval by `create_proxy`
        3) Delegates into the normal Tracer.call_module method
        4) Restores the caller's qualified name into current_module_qualified_name
        """
        old_qualname = self.current_module_qualified_name
        try:
            module_qualified_name = self.path_of_module(m)
            self.current_module_qualified_name = module_qualified_name
            if not self.is_leaf_module(m, module_qualified_name):
                out = forward(*args, **kwargs)
                self.node_to_originating_module[out.node] = module_qualified_name
                return out
            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
        finally:
            self.current_module_qualified_name = old_qualname

    def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...],
                     kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None):
        """
        Override of `Tracer.create_proxy`. This override intercepts the recording
        of every operation and stores away the current traced module's qualified
        name in `node_to_originating_module`
        """
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
        self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
        return proxy

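# Illustrative usage sketch of ModulePathTracer (not part of the diff): the tracer can
# be used standalone to see which module each traced node originated from, e.g.:
#
#   import torchvision
#   tracer = ModulePathTracer()
#   graph = tracer.trace(torchvision.models.resnet18())
#   for node in graph.nodes:
#       qualname = tracer.node_to_originating_module.get(node, '')
#       print(node.name, '->', qualname or '<top level>')
#
# which prints entries along the lines of "layer1_0_conv1 -> layer1.0.conv1".
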
def get_intermediate_layers(model: nn.Module, return_layers: Dict[str, str]) -> nn.Module:
    """
    Creates a new FX-based module that returns intermediate layers from a given model.
    This is achieved by re-writing the computation graph of the model via FX to return
    the requested layers.

    All unused layers are removed, together with their corresponding parameters.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.get_intermediate_layers(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """
    # TODO come up with a better name for this
    # TODO have duplicate nodes but full coverage for module names
    return_layers = {str(k): str(v) for k, v in return_layers.items()}

    # Instantiate our ModulePathTracer and use that to trace the model
    tracer = ModulePathTracer()
    graph = tracer.trace(model)

    name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
    m = torch.fx.GraphModule(tracer.root, graph, name)

    # check that all outputs in return_layers are present in the model
    if not set(return_layers).issubset(tracer.node_to_originating_module.values()):
        raise ValueError("return_layers are not present in model")

    # Get output node
    orig_output_node: Optional[torch.fx.Node] = None
    for n in reversed(m.graph.nodes):
        if n.op == "output":
Review comment: What happens in cases where we have multiple outputs (for example, a model whose forward returns a tuple; see the sketch after the end of this function)? You can see this by replacing this input on your test.
            orig_output_node = n
            break
    assert orig_output_node
    # and remove it
    m.graph.erase_node(orig_output_node)

    # find output nodes corresponding to return_layers
    nodes = [n for n in m.graph.nodes]
    output_node = OrderedDict()
    for n in nodes:
        module_qualname = tracer.node_to_originating_module.get(n)
        if module_qualname in return_layers:
            output_node[return_layers[module_qualname]] = n

    # and add them in the end of the graph
    with m.graph.inserting_after(nodes[-1]):
        m.graph.output(output_node)

    m.graph.eliminate_dead_code()
    m.recompile()

    # remove unused modules / parameters
    m = torch.fx.GraphModule(m, m.graph, name)
    return m
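A minimal sketch (hypothetical, not from the PR or the review) of the multiple-outputs case raised in the review comment inside the function above: a module whose forward returns a tuple of tensors.

import torch
from torch import nn

class TwoHeadedNet(nn.Module):
    """Hypothetical two-output model, for illustration only."""
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3)
        self.head_a = nn.Conv2d(8, 8, kernel_size=1)
        self.head_b = nn.Conv2d(8, 8, kernel_size=1)

    def forward(self, x):
        x = self.stem(x)
        return self.head_a(x), self.head_b(x)

outs = TwoHeadedNet()(torch.rand(1, 3, 16, 16))  # a tuple of two tensors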
Review comment: FYI, it fails when we include the final output in the return layers.
Reply: Thanks for the catch! I need to check more carefully, but the old implementation doesn't work in this case because of the torch.flatten call (which is not an nn.Module); I believe this should work with the new implementation, though. To be verified.
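A sketch of the case discussed in this exchange, assuming torchvision's resnet18, whose forward applies torch.flatten between avgpool and fc as a plain function call rather than a submodule. This is illustrative only, not the snippet attached to the review.

import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter, get_intermediate_layers

m = torchvision.models.resnet18()

# Old module-based getter: it replays named_children() in order, so `fc` receives the
# 4-D avgpool output (the functional torch.flatten in ResNet.forward is skipped, since
# it is not a child module) and the Linear layer should fail with a shape mismatch.
old_getter = IntermediateLayerGetter(m, return_layers={'fc': 'out'})
out = old_getter(torch.rand(1, 3, 224, 224))

# New FX-based rewrite: the flatten call is captured in the traced graph, so the same
# request should become expressible here once the failure reported above is resolved.
new_m = get_intermediate_layers(m, {'fc': 'out'})
out = new_m(torch.rand(1, 3, 224, 224))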