diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py
index 3f70f8dd25a..617778116b4 100644
--- a/torchvision/models/_utils.py
+++ b/torchvision/models/_utils.py
@@ -5,7 +5,7 @@
 from torch.jit.annotations import Dict
 
 
-class IntermediateLayerGetter(nn.Module):
+class IntermediateLayerGetter(nn.ModuleDict):
     """
     Module wrapper that returns intermediate layers from a model
 
@@ -45,8 +45,6 @@ class IntermediateLayerGetter(nn.Module):
     def __init__(self, model, return_layers):
         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
             raise ValueError("return_layers are not present in model")
-        super(IntermediateLayerGetter, self).__init__()
-
         orig_return_layers = return_layers
         return_layers = {k: v for k, v in return_layers.items()}
         layers = OrderedDict()
@@ -57,33 +55,14 @@ def __init__(self, model, return_layers):
             if not return_layers:
                 break
 
-        self.layers = nn.ModuleDict(layers)
+        super(IntermediateLayerGetter, self).__init__(layers)
         self.return_layers = orig_return_layers
 
     def forward(self, x):
         out = OrderedDict()
-        for name, module in self.layers.items():
+        for name, module in self.items():
             x = module(x)
             if name in self.return_layers:
                 out_name = self.return_layers[name]
                 out[out_name] = x
         return out
-
-    @torch.jit.ignore
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if (version is None or version < 2):
-            # now we have a new nesting level for torchscript support
-            for new_key in self.state_dict().keys():
-                # remove prefix "layers."
-                old_key = new_key[len("layers."):]
-                old_key = prefix + old_key
-                new_key = prefix + new_key
-                if old_key in state_dict:
-                    value = state_dict[old_key]
-                    del state_dict[old_key]
-                    state_dict[new_key] = value
-        super(IntermediateLayerGetter, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py
index 06bbc3653b7..d4dc283a5f5 100644
--- a/torchvision/models/densenet.py
+++ b/torchvision/models/densenet.py
@@ -90,13 +90,12 @@ def forward(self, input):  # noqa: F811
         return new_features
 
 
-class _DenseBlock(nn.Module):
+class _DenseBlock(nn.ModuleDict):
     _version = 2
     __constants__ = ['layers']
 
     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
         super(_DenseBlock, self).__init__()
-        self.layers = nn.ModuleDict()
         for i in range(num_layers):
             layer = _DenseLayer(
                 num_input_features + i * growth_rate,
@@ -105,34 +104,15 @@ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_ra
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient,
             )
-            self.layers['denselayer%d' % (i + 1)] = layer
+            self.add_module('denselayer%d' % (i + 1), layer)
 
     def forward(self, init_features):
         features = [init_features]
-        for name, layer in self.layers.items():
+        for name, layer in self.items():
             new_features = layer(features)
             features.append(new_features)
         return torch.cat(features, 1)
 
-    @torch.jit.ignore
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if (version is None or version < 2):
-            # now we have a new nesting level for torchscript support
-            for new_key in self.state_dict().keys():
-                # remove prefix "layers."
-                old_key = new_key[len("layers."):]
-                old_key = prefix + old_key
-                new_key = prefix + new_key
-                if old_key in state_dict:
-                    value = state_dict[old_key]
-                    del state_dict[old_key]
-                    state_dict[new_key] = value
-        super(_DenseBlock, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
-
 
 class _Transition(nn.Sequential):
     def __init__(self, num_input_features, num_output_features):
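
For context, a minimal sketch of how the refactored `IntermediateLayerGetter` behaves after this change (the `resnet18` backbone and the layer names below are illustrative, not part of the diff). Because the class now subclasses `nn.ModuleDict` directly instead of holding a nested `self.layers` dict, its submodules register at the top level, so `state_dict()` keys lose the `layers.` prefix and line up with checkpoints saved before that nesting was introduced; that is why both `_load_from_state_dict` compatibility shims can be dropped.

```python
import torch
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter

# Illustrative backbone and layer names; any model with named children works.
backbone = resnet18()
getter = IntermediateLayerGetter(backbone, return_layers={'layer1': 'feat1', 'layer3': 'feat3'})

out = getter(torch.rand(1, 3, 224, 224))
print([(name, feat.shape) for name, feat in out.items()])
# [('feat1', torch.Size([1, 64, 56, 56])), ('feat3', torch.Size([1, 256, 14, 14]))]

# With ModuleDict subclassing, parameter keys sit at the top level,
# e.g. 'layer1.0.conv1.weight' rather than 'layers.layer1.0.conv1.weight',
# so checkpoints predating the 'layers.' nesting load without remapping.
print([k for k in getter.state_dict() if k.startswith('layer1')][:1])
# ['layer1.0.conv1.weight']
```

The `_DenseBlock` change follows the same pattern: the dense layers register directly on the block, so DenseNet checkpoints keep the original `denseblockN.denselayerM.*` key layout without any remapping shim.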