remove BC-breaking changes #1560

Merged: 5 commits, Nov 15, 2019
torchvision/models/_utils.py: 27 changes (3 additions, 24 deletions)

```diff
@@ -5,7 +5,7 @@
 from torch.jit.annotations import Dict
 
 
-class IntermediateLayerGetter(nn.Module):
+class IntermediateLayerGetter(nn.ModuleDict):
     """
     Module wrapper that returns intermediate layers from a model
 
@@ -45,8 +45,6 @@ class IntermediateLayerGetter(nn.Module):
     def __init__(self, model, return_layers):
         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
             raise ValueError("return_layers are not present in model")
-        super(IntermediateLayerGetter, self).__init__()
-
         orig_return_layers = return_layers
         return_layers = {k: v for k, v in return_layers.items()}
         layers = OrderedDict()
@@ -57,33 +55,14 @@ def __init__(self, model, return_layers):
             if not return_layers:
                 break
 
-        self.layers = nn.ModuleDict(layers)
+        super(IntermediateLayerGetter, self).__init__(layers)
         self.return_layers = orig_return_layers
 
     def forward(self, x):
         out = OrderedDict()
-        for name, module in self.layers.items():
+        for name, module in self.items():
             x = module(x)
             if name in self.return_layers:
                 out_name = self.return_layers[name]
                 out[out_name] = x
         return out
-
-    @torch.jit.ignore
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if (version is None or version < 2):
-            # now we have a new nesting level for torchscript support
-            for new_key in self.state_dict().keys():
-                # remove prefix "layers."
-                old_key = new_key[len("layers."):]
-                old_key = prefix + old_key
-                new_key = prefix + new_key
-                if old_key in state_dict:
-                    value = state_dict[old_key]
-                    del state_dict[old_key]
-                    state_dict[new_key] = value
-        super(IntermediateLayerGetter, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
```
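
For context, the deleted `_load_from_state_dict` override existed only to remap old checkpoint keys onto the `layers.`-prefixed layout that the earlier torchscript change introduced. A minimal sketch (not from the PR; `Wrapped` and `Subclassed` are illustrative names) of why subclassing `nn.ModuleDict` makes that shim unnecessary: children registered on the subclass land at the top level of the `state_dict`, matching checkpoints saved against the original layout.

```python
from collections import OrderedDict

import torch.nn as nn


class Wrapped(nn.Module):
    """Nested layout: submodules live under a 'layers.' attribute."""
    def __init__(self):
        super(Wrapped, self).__init__()
        self.layers = nn.ModuleDict(OrderedDict([('conv1', nn.Conv2d(3, 8, 3))]))


class Subclassed(nn.ModuleDict):
    """Restored layout: submodules are registered directly on the module."""
    def __init__(self):
        super(Subclassed, self).__init__(OrderedDict([('conv1', nn.Conv2d(3, 8, 3))]))


print(list(Wrapped().state_dict().keys()))
# ['layers.conv1.weight', 'layers.conv1.bias']
print(list(Subclassed().state_dict().keys()))
# ['conv1.weight', 'conv1.bias']
```

With the original key layout restored, previously saved checkpoints load directly, so the version-gated key rewriting above becomes dead code and can be deleted along with the wrapper attribute.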

torchvision/models/densenet.py: 26 changes (3 additions, 23 deletions)

```diff
@@ -90,13 +90,12 @@ def forward(self, input):  # noqa: F811
         return new_features
 
 
-class _DenseBlock(nn.Module):
+class _DenseBlock(nn.ModuleDict):
     _version = 2
-    __constants__ = ['layers']
 
     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
         super(_DenseBlock, self).__init__()
-        self.layers = nn.ModuleDict()
         for i in range(num_layers):
             layer = _DenseLayer(
                 num_input_features + i * growth_rate,
@@ -105,34 +104,15 @@ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient,
             )
-            self.layers['denselayer%d' % (i + 1)] = layer
+            self.add_module('denselayer%d' % (i + 1), layer)
 
     def forward(self, init_features):
         features = [init_features]
-        for name, layer in self.layers.items():
+        for name, layer in self.items():
             new_features = layer(features)
             features.append(new_features)
         return torch.cat(features, 1)
-
-    @torch.jit.ignore
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if (version is None or version < 2):
-            # now we have a new nesting level for torchscript support
-            for new_key in self.state_dict().keys():
-                # remove prefix "layers."
-                old_key = new_key[len("layers."):]
-                old_key = prefix + old_key
-                new_key = prefix + new_key
-                if old_key in state_dict:
-                    value = state_dict[old_key]
-                    del state_dict[old_key]
-                    state_dict[new_key] = value
-        super(_DenseBlock, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
 
 
 class _Transition(nn.Sequential):
     def __init__(self, num_input_features, num_output_features):
```
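
The densenet change follows the same pattern: `_DenseBlock` becomes an `nn.ModuleDict` subclass that registers children with `add_module` and iterates itself via `self.items()` in `forward`. A runnable sketch of that pattern under simplified assumptions (plain `Conv2d` layers stand in for `_DenseLayer`, which normally handles the concatenation internally; all names here are illustrative):

```python
import torch
import torch.nn as nn


class TinyDenseBlock(nn.ModuleDict):
    """Toy dense block: each layer consumes all previous feature maps."""

    def __init__(self, num_layers, in_channels, growth_rate):
        super(TinyDenseBlock, self).__init__()
        for i in range(num_layers):
            # layer i receives in_channels + i * growth_rate input channels
            layer = nn.Conv2d(in_channels + i * growth_rate, growth_rate,
                              kernel_size=3, padding=1)
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        # iterate the module's own children, as in the patched _DenseBlock
        for name, layer in self.items():
            new_features = layer(torch.cat(features, 1))
            features.append(new_features)
        return torch.cat(features, 1)


block = TinyDenseBlock(num_layers=3, in_channels=16, growth_rate=4)
out = block(torch.randn(1, 16, 8, 8))
print(out.shape)  # torch.Size([1, 28, 8, 8]): 16 + 3 * 4 channels
print(list(block.state_dict().keys())[:2])
# ['denselayer1.weight', 'denselayer1.bias']: no 'layers.' prefix
```

Because the children are registered directly on the block, the parameter keys keep the flat `denselayerN.*` form that densenet checkpoints used before the nesting was introduced, which is what lets this diff drop both `__constants__ = ['layers']` and the `_load_from_state_dict` shim.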