Models and pre-trained weights
##############################

The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

General information on pre-trained weights
==========================================

TorchVision offers pre-trained weights for every provided architecture, using
the PyTorch :mod:`torch.hub`. Instantiating a pre-trained model will download its
weights to a cache directory. This directory can be set using the `TORCH_HOME`
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
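
For example, to redirect the cache to a custom location, set the variable before
instantiating any model (the path below is illustrative):

.. code:: python

    import os

    # Pre-trained weights will be cached under this directory
    os.environ["TORCH_HOME"] = "/path/to/torch/cache"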

.. note::

    The pre-trained models provided in this library may have their own licenses or
    terms and conditions derived from the dataset used for training. It is your
    responsibility to determine whether you have permission to use the models for
    your use case.

.. note::
    Backward compatibility is guaranteed for loading a serialized
    ``state_dict`` to a model created using an old PyTorch version.
    On the contrary, loading entire saved models or serialized
    ``ScriptModules`` (serialized using older versions of PyTorch)
    may not preserve the historic behaviour. Refer to the following
    `documentation
    <https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
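
For instance, a minimal sketch of the recommended ``state_dict`` round-trip
(the file name is illustrative):

.. code:: python

    import torch
    from torchvision.models import resnet50, ResNet50_Weights

    model = resnet50(weights=ResNet50_Weights.DEFAULT)

    # Save only the state_dict, not the pickled model object
    torch.save(model.state_dict(), "resnet50.pth")

    # Later, possibly under a newer PyTorch: rebuild, then load the weights
    model = resnet50()
    model.load_state_dict(torch.load("resnet50.pth"))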


Initializing pre-trained models
-------------------------------

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
for loading different weights to the existing model builder methods:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Old weights with accuracy 76.130%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # New weights with accuracy 80.858%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Best available weights (currently alias for IMAGENET1K_V2)
    # Note that these weights may change across versions
    resnet50(weights=ResNet50_Weights.DEFAULT)

    # Strings are also supported
    resnet50(weights="IMAGENET1K_V2")

    # No weights - random initialization
    resnet50(weights=None)

Migrating to the new API is very straightforward. The following method calls between the 2 APIs are all equivalent:

.. code:: python

    # Using pretrained weights:
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    resnet50(weights="IMAGENET1K_V1")
    resnet50(pretrained=True)  # deprecated
    resnet50(True)  # deprecated

    # Using no weights:
    resnet50(weights=None)
    resnet50()
    resnet50(pretrained=False)  # deprecated
    resnet50(False)  # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.

Using the pre-trained models
----------------------------

Before using the pre-trained models, one must preprocess the image
(resize with the right resolution/interpolation, apply inference transforms,
rescale the values, etc.). There is no standard way to do this, as it depends on
how a given model was trained. It can vary across model families, variants, or
even weight versions. Using the correct preprocessing method is critical, and
failing to do so may lead to decreased accuracy or incorrect outputs.

All the necessary information for the inference transforms of each pre-trained
model is provided in its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:
+
99
+ .. code :: python
100
+
101
+ # Initialize the Weight Transforms
102
+ weights = ResNet50_Weights.DEFAULT
103
+ preprocess = weights.transforms()
104
+
105
+ # Apply it to the input image
106
+ img_transformed = preprocess(img)


Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.

.. code:: python

    # Initialize model
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)

    # Set model to eval mode
    model.eval()
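
In addition, inference is usually run without gradient tracking. A minimal sketch,
continuing the snippets above (``preprocess`` and ``img`` as in the earlier example):

.. code:: python

    import torch

    # Disable autograd bookkeeping for faster, memory-friendly inference
    with torch.inference_mode():
        prediction = model(preprocess(img).unsqueeze(0))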

Classification
==============

Here is an example of how to use the pre-trained image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
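
For example, a quick way to inspect the label set (the first few ImageNet-1K names
are shown for illustration):

.. code:: python

    print(weights.meta["categories"][:5])
    # e.g. ['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead']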

Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_table.rst

Quantized models
----------------

.. currentmodule:: torchvision.models.quantization

The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

Here is an example of how to use the pre-trained quantized image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_QuantizedWeights.DEFAULT
    model = resnet50(weights=weights, quantize=True)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_quant_table.rst

Semantic Segmentation
=====================

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.

Table of all available semantic segmentation weights
----------------------------------------------------

All models are evaluated on a subset of COCO val2017, on the 20 categories that are
present in the Pascal VOC dataset:

.. include:: generated/segmentation_table.rst


Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================

The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision. The models expect a list of ``Tensor[C, H, W]``.
Check the constructor of the models for more information.
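
For illustration, a minimal sketch of this input convention (the builder and weights
enum are one example; the image sizes are arbitrary):

.. code:: python

    import torch
    from torchvision.models.detection import (
        fasterrcnn_resnet50_fpn,
        FasterRCNN_ResNet50_FPN_Weights,
    )

    model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
    model.eval()

    # A list of 3-channel images; the images may have different sizes
    images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    predictions = model(images)  # one dict with boxes/labels/scores per image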

Object Detection
----------------

Here is an example of how to use the pre-trained object detection models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image
    from torchvision.models.detection import (
        fasterrcnn_resnet50_fpn_v2,
        FasterRCNN_ResNet50_FPN_V2_Weights,
    )

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box MAPs are reported on COCO val2017:

.. include:: generated/detection_table.rst

Instance Segmentation
---------------------

The following instance segmentation models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/mask_rcnn

|

For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Mask MAPs are reported on COCO val2017:

.. include:: generated/instance_segmentation_table.rst

Keypoint Detection
------------------

.. currentmodule:: torchvision.models.detection

The following person keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/keypoint_rcnn

|

The classes of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`keypoint_output`.
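
For example, a small sketch that prints the keypoint names, reusing the weights enum
of the Keypoint R-CNN model listed above:

.. code:: python

    from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    print(weights.meta["keypoint_names"])  # e.g. ['nose', 'left_eye', 'right_eye', ...]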

Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Keypoint MAPs are reported on COCO val2017:

.. include:: generated/detection_keypoint_table.rst

Video Classification
====================

Here is an example of how to use the pre-trained video classification models:

.. code:: python

    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
    vid = vid[:32]  # optionally shorten duration

    # Step 1: Initialize model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    category_name = weights.meta["categories"][label]
    print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available video classification weights
---------------------------------------------------

Accuracies are reported on Kinetics-400 using single crops for clip length 16:

.. include:: generated/video_table.rst


Using models from Hub
=====================

Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:

.. code:: python

    import torch

    # Option 1: passing weights param as string
    model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

    # Option 2: passing weights param as enum
    weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
    model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

The only exception to the above are the detection models included in
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.
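
To discover which model builders are exposed through the Hub entry point, the
available callables can be listed (a small sketch):

.. code:: python

    import torch

    # Lists every callable exposed by torchvision's hubconf
    print(torch.hub.list("pytorch/vision"))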