Skip to content

Commit 7434662

Browse files
authored
Updating code on the article.
1 parent 7e43226 commit 7434662

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

_posts/2021-12-22-introducing-torchvision-new-multi-weight-support-api.md

+19-19
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ from torchvision.prototype import models as PM
7474
img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
7575

7676
# Step 1: Initialize model
77-
weights = PM.ResNet50_Weights.ImageNet1K_V1
77+
weights = PM.ResNet50_Weights.IMAGENET1K_V1
7878
model = PM.resnet50(weights=weights)
7979
model.eval()
8080

@@ -96,21 +96,21 @@ As we can see the new API eliminates the aforementioned limitations. Let’s exp
9696

9797
### Multi-weight support
9898

99-
At the heart of the new API, we have the ability to define multiple different weights for the same model variant. Each model building method (eg `resnet50`) has an associated Enum class (eg `ResNet50_Weights`) which has as many entries as the number of pre-trained weights available. Additionally, each Enum class has a `default` alias which points to the best available weights for the specific model. This allows the users who want to always use the best available weights to do so without modifying their code.
99+
At the heart of the new API, we have the ability to define multiple different weights for the same model variant. Each model building method (eg `resnet50`) has an associated Enum class (eg `ResNet50_Weights`) which has as many entries as the number of pre-trained weights available. Additionally, each Enum class has a `DEFAULT` alias which points to the best available weights for the specific model. This allows the users who want to always use the best available weights to do so without modifying their code.
100100

101101
Here is an example of initializing models with different weights:
102102

103103
```python
104104
from torchvision.prototype.models import resnet50, ResNet50_Weights
105105

106106
# Legacy weights with accuracy 76.130%
107-
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V1)
107+
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
108108

109-
# New weights with accuracy 80.674%
110-
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V2)
109+
# New weights with accuracy 80.858%
110+
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
111111

112-
# Best available weights (currently alias for ImageNet1K_V2)
113-
model = resnet50(weights=ResNet50_Weights.default)
112+
# Best available weights (currently alias for IMAGENET1K_V2)
113+
model = resnet50(weights=ResNet50_Weights.DEFAULT)
114114

115115
# No weights - random initialization
116116
model = resnet50(weights=None)
@@ -124,10 +124,10 @@ The weights of each model are associated with meta-data. The type of information
124124
from torchvision.prototype.models import ResNet50_Weights
125125

126126
# Accessing a single record
127-
size = ResNet50_Weights.ImageNet1K_V2.meta["size"]
127+
size = ResNet50_Weights.IMAGENET1K_V2.meta["size"]
128128

129129
# Iterating the items of the meta-data dictionary
130-
for k, v in ResNet50_Weights.ImageNet1K_V2.meta.items():
130+
for k, v in ResNet50_Weights.IMAGENET1K_V2.meta.items():
131131
print(k, v)
132132
```
133133

@@ -137,10 +137,10 @@ Additionally, each weights entry is associated with the necessary preprocessing
137137
from torchvision.prototype.models import ResNet50_Weights
138138

139139
# Initializing preprocessing at standard 224x224 resolution
140-
preprocess = ResNet50_Weights.ImageNet1K.transforms()
140+
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()
141141

142142
# Initializing preprocessing at 400x400 resolution
143-
preprocess = ResNet50_Weights.ImageNet1K.transforms(crop_size=400, resize_size=400)
143+
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms(crop_size=400, resize_size=400)
144144

145145
# Once initialized the callable can accept the image data:
146146
# img_preprocessed = preprocess(img)
@@ -156,11 +156,11 @@ The ability to link directly the weights with their properties (meta data, prepr
156156
from torchvision.prototype.models import get_weight
157157

158158
# Weights can be retrieved by name:
159-
assert get_weight("ResNet50_Weights.ImageNet1K_V1") == ResNet50_Weights.ImageNet1K_V1
160-
assert get_weight("ResNet50_Weights.ImageNet1K_V2") == ResNet50_Weights.ImageNet1K_V2
159+
assert get_weight("ResNet50_Weights.IMAGENET1K_V1") == ResNet50_Weights.IMAGENET1K_V1
160+
assert get_weight("ResNet50_Weights.IMAGENET1K_V2") == ResNet50_Weights.IMAGENET1K_V2
161161

162-
# Including using the default alias:
163-
assert get_weight("ResNet50_Weights.default") == ResNet50_Weights.ImageNet1K_V2
162+
# Including using the DEFAULT alias:
163+
assert get_weight("ResNet50_Weights.DEFAULT") == ResNet50_Weights.IMAGENET1K_V2
164164
```
165165

166166
## Deprecations
@@ -172,8 +172,8 @@ In the new API the boolean `pretrained` and `pretrained_backbone` parameters, wh
172172
UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead.
173173
UserWarning:
174174
Arguments other than a weight enum or `None` for 'weights' are deprecated.
175-
The current behavior is equivalent to passing `weights=ResNet50_Weights.ImageNet1K_V1`.
176-
You can also use `weights=ResNet50_Weights.default` to get the most up-to-date weights.
175+
The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`.
176+
You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
177177
```
178178

179179
Additionally the builder methods require using keyword parameters. The use of positional parameter is deprecated and using them emits the following warning:
@@ -191,7 +191,7 @@ Migrating to the new API is very straightforward. The following method calls bet
191191

192192
```
193193
# Using pretrained weights:
194-
torchvision.prototype.models.resnet50(weights=ResNet50_Weights.ImageNet1K_V1)
194+
torchvision.prototype.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
195195
torchvision.models.resnet50(pretrained=True)
196196
torchvision.models.resnet50(True)
197197

@@ -237,7 +237,7 @@ If you are still unconvinced about giving a try to the new API, here is one more
237237
|RegNet Y 8gf |80.032 |82.828 |
238238
|RegNet Y 16gf |80.424 |82.89 |
239239
|RegNet Y 32gf |80.878 |83.366 |
240-
|ResNet50 |76.13 |80.674 |
240+
|ResNet50 |76.13 |80.858 |
241241
|ResNet101 |77.374 |81.886 |
242242
|ResNet152 |78.312 |82.284 |
243243
|ResNeXt50 32x4d |77.618 |81.198 |

0 commit comments

Comments
 (0)