@@ -34,7 +34,7 @@ def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None
34
34
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
35
35
)
36
36
37
- state_dict = weights .state_dict (progress = progress )
37
+ state_dict = weights .get_state_dict (progress = progress )
38
38
for key in list (state_dict .keys ()):
39
39
res = pattern .match (key )
40
40
if res :
@@ -63,11 +63,11 @@ def _densenet(
63
63
return model
64
64
65
65
66
- _common_meta = {
66
+ _COMMON_META = {
67
67
"size" : (224 , 224 ),
68
68
"categories" : _IMAGENET_CATEGORIES ,
69
69
"interpolation" : InterpolationMode .BILINEAR ,
70
- "recipe" : None , # weights ported from LuaTorch
70
+ "recipe" : None , # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch
71
71
}
72
72
73
73
@@ -76,7 +76,7 @@ class DenseNet121Weights(Weights):
76
76
url = "https://download.pytorch.org/models/densenet121-a639ec97.pth" ,
77
77
transforms = partial (ImageNetEval , crop_size = 224 ),
78
78
meta = {
79
- ** _common_meta ,
79
+ ** _COMMON_META ,
80
80
"acc@1" : 74.434 ,
81
81
"acc@5" : 91.972 ,
82
82
},
@@ -88,7 +88,7 @@ class DenseNet161Weights(Weights):
88
88
url = "https://download.pytorch.org/models/densenet161-8d451a50.pth" ,
89
89
transforms = partial (ImageNetEval , crop_size = 224 ),
90
90
meta = {
91
- ** _common_meta ,
91
+ ** _COMMON_META ,
92
92
"acc@1" : 77.138 ,
93
93
"acc@5" : 93.560 ,
94
94
},
@@ -100,7 +100,7 @@ class DenseNet169Weights(Weights):
100
100
url = "https://download.pytorch.org/models/densenet169-b2777c0a.pth" ,
101
101
transforms = partial (ImageNetEval , crop_size = 224 ),
102
102
meta = {
103
- ** _common_meta ,
103
+ ** _COMMON_META ,
104
104
"acc@1" : 75.600 ,
105
105
"acc@5" : 92.806 ,
106
106
},
@@ -112,7 +112,7 @@ class DenseNet201Weights(Weights):
112
112
url = "https://download.pytorch.org/models/densenet201-c1103571.pth" ,
113
113
transforms = partial (ImageNetEval , crop_size = 224 ),
114
114
meta = {
115
- ** _common_meta ,
115
+ ** _COMMON_META ,
116
116
"acc@1" : 76.896 ,
117
117
"acc@5" : 93.370 ,
118
118
},
@@ -121,7 +121,7 @@ class DenseNet201Weights(Weights):
121
121
122
122
def densenet121 (weights : Optional [DenseNet121Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
123
123
if "pretrained" in kwargs :
124
- warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
124
+ warnings .warn ("The parameter pretrained is deprecated, please use weights instead." )
125
125
weights = DenseNet121Weights .ImageNet1K_Community if kwargs .pop ("pretrained" ) else None
126
126
weights = DenseNet121Weights .verify (weights )
127
127
@@ -130,7 +130,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T
130
130
131
131
def densenet161 (weights : Optional [DenseNet161Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
132
132
if "pretrained" in kwargs :
133
- warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
133
+ warnings .warn ("The parameter pretrained is deprecated, please use weights instead." )
134
134
weights = DenseNet161Weights .ImageNet1K_Community if kwargs .pop ("pretrained" ) else None
135
135
weights = DenseNet161Weights .verify (weights )
136
136
@@ -139,7 +139,7 @@ def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T
139
139
140
140
def densenet169 (weights : Optional [DenseNet169Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
141
141
if "pretrained" in kwargs :
142
- warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
142
+ warnings .warn ("The parameter pretrained is deprecated, please use weights instead." )
143
143
weights = DenseNet169Weights .ImageNet1K_Community if kwargs .pop ("pretrained" ) else None
144
144
weights = DenseNet169Weights .verify (weights )
145
145
@@ -148,7 +148,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T
148
148
149
149
def densenet201 (weights : Optional [DenseNet201Weights ] = None , progress : bool = True , ** kwargs : Any ) -> DenseNet :
150
150
if "pretrained" in kwargs :
151
- warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
151
+ warnings .warn ("The parameter pretrained is deprecated, please use weights instead." )
152
152
weights = DenseNet201Weights .ImageNet1K_Community if kwargs .pop ("pretrained" ) else None
153
153
weights = DenseNet201Weights .verify (weights )
154
154
0 commit comments