diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py
index 1bcdb0b429d..e95018a970a 100644
--- a/torchvision/models/densenet.py
+++ b/torchvision/models/densenet.py
@@ -1,3 +1,4 @@
+import re
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -25,7 +26,20 @@ def densenet121(pretrained=False, **kwargs):
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                      **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet121']))
+        # '.'s are no longer allowed in module names, but previous _DenseLayer
+        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+        # They are also in the checkpoints in model_urls. This pattern is used
+        # to find such keys.
+        pattern = re.compile(
+            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+        state_dict = model_zoo.load_url(model_urls['densenet121'])
+        for key in list(state_dict.keys()):
+            res = pattern.match(key)
+            if res:
+                new_key = res.group(1) + res.group(2)
+                state_dict[new_key] = state_dict[key]
+                del state_dict[key]
+        model.load_state_dict(state_dict)
     return model
 
 
@@ -39,7 +53,20 @@ def densenet169(pretrained=False, **kwargs):
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                      **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet169']))
+        # '.'s are no longer allowed in module names, but previous _DenseLayer
+        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+        # They are also in the checkpoints in model_urls. This pattern is used
+        # to find such keys.
+        pattern = re.compile(
+            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+        state_dict = model_zoo.load_url(model_urls['densenet169'])
+        for key in list(state_dict.keys()):
+            res = pattern.match(key)
+            if res:
+                new_key = res.group(1) + res.group(2)
+                state_dict[new_key] = state_dict[key]
+                del state_dict[key]
+        model.load_state_dict(state_dict)
     return model
 
 
@@ -53,7 +80,20 @@ def densenet201(pretrained=False, **kwargs):
     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                      **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet201']))
+        # '.'s are no longer allowed in module names, but previous _DenseLayer
+        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+        # They are also in the checkpoints in model_urls. This pattern is used
+        # to find such keys.
+        pattern = re.compile(
+            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+        state_dict = model_zoo.load_url(model_urls['densenet201'])
+        for key in list(state_dict.keys()):
+            res = pattern.match(key)
+            if res:
+                new_key = res.group(1) + res.group(2)
+                state_dict[new_key] = state_dict[key]
+                del state_dict[key]
+        model.load_state_dict(state_dict)
     return model
 
 
@@ -67,20 +107,33 @@ def densenet161(pretrained=False, **kwargs):
     model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                      **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['densenet161']))
+        # '.'s are no longer allowed in module names, but previous _DenseLayer
+        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+        # They are also in the checkpoints in model_urls. This pattern is used
+        # to find such keys.
+        pattern = re.compile(
+            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+        state_dict = model_zoo.load_url(model_urls['densenet161'])
+        for key in list(state_dict.keys()):
+            res = pattern.match(key)
+            if res:
+                new_key = res.group(1) + res.group(2)
+                state_dict[new_key] = state_dict[key]
+                del state_dict[key]
+        model.load_state_dict(state_dict)
     return model
 
 
 class _DenseLayer(nn.Sequential):
     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
         super(_DenseLayer, self).__init__()
-        self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
-        self.add_module('relu.1', nn.ReLU(inplace=True)),
-        self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
+        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
+        self.add_module('relu1', nn.ReLU(inplace=True)),
+        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                         growth_rate, kernel_size=1, stride=1, bias=False)),
-        self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
-        self.add_module('relu.2', nn.ReLU(inplace=True)),
-        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
+        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
+        self.add_module('relu2', nn.ReLU(inplace=True)),
+        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                         kernel_size=3, stride=1, padding=1, bias=False)),
         self.drop_rate = drop_rate
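
For review purposes, here is a minimal standalone sketch of what the key-remapping regex in the hunks above does; the sample key below is an assumed example of the old checkpoint layout, not taken from the diff itself.

import re

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

old_key = 'features.denseblock1.denselayer1.norm.1.weight'  # hypothetical old-style checkpoint key
res = pattern.match(old_key)
if res:
    # Drop the extra '.' so e.g. 'norm.1.weight' becomes 'norm1.weight',
    # matching the renamed submodules in _DenseLayer.
    new_key = res.group(1) + res.group(2)
    print(old_key, '->', new_key)
    # features.denseblock1.denselayer1.norm.1.weight -> features.denseblock1.denselayer1.norm1.weight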