Commit 769ae13

datumbox and NicolasHug authored
Add more info on new models.srt (#6025)
* Minor updates on model examples.
* Improving wording of auto-generated docs.
* Add general info for pre-trained weights.
* Updating torch hub
* Minor updates
* Make lengthy meta-data partially visible
* Adding meta-data and reference info.
* Minor corrections
* Update docs/source/models_new.rst
* Moving Torch hub section at the end

Co-authored-by: Nicolas Hug <[email protected]>
1 parent 44252c8 commit 769ae13

File tree

4 files changed: +139 -36 lines changed

docs/source/conf.py

Lines changed: 5 additions & 5 deletions
@@ -347,10 +347,6 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
         metrics = meta.pop("metrics", {})
         meta_with_metrics = dict(meta, **metrics)
 
-        # We don't want to document these, they can be too long
-        for k in ["categories", "keypoint_names"]:
-            meta_with_metrics.pop(k, None)
-
         custom_docs = meta_with_metrics.pop("_docs", None)  # Custom per-Weights docs
         if custom_docs is not None:
             lines += [custom_docs, ""]
@@ -360,14 +356,18 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
                 v = f"`link <{v}>`__"
             elif k == "min_size":
                 v = f"height={v[0]}, width={v[1]}"
+            elif k in {"categories", "keypoint_names"} and isinstance(v, list):
+                max_visible = 3
+                v_sample = ", ".join(v[:max_visible])
+                v = f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample
             table.append((str(k), str(v)))
         table = tabulate(table, tablefmt="rst")
         lines += [".. rst-class:: table-weights"]  # Custom CSS class, see custom_torchvision.css
         lines += [".. table::", ""]
         lines += textwrap.indent(table, " " * 4).split("\n")
         lines.append("")
         lines.append(
-            f"The inference transforms are available at ``{str(field)}.transforms`` and "
+            f"The preprocessing/inference transforms are available at ``{str(field)}.transforms`` and "
             f"perform the following operations: {field.transforms().describe()}"
         )
         lines.append("")
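
For illustration, the truncation rule added in the second hunk can be run standalone. A minimal sketch, using a made-up category list rather than real weight metadata:

    # Standalone sketch of the list-truncation logic above; `categories` is
    # hypothetical sample data standing in for a weight's meta entry.
    categories = ["tench", "goldfish", "great white shark", "tiger shark"]

    max_visible = 3
    v_sample = ", ".join(categories[:max_visible])
    v = f"{v_sample}, ... ({len(categories) - max_visible} omitted)" if len(categories) > max_visible else v_sample
    print(v)  # tench, goldfish, great white shark, ... (1 omitted)

With three or fewer entries the full list is shown; longer lists are cut at three items plus a count of the omitted rest, which is what makes the previously hidden ``categories`` and ``keypoint_names`` metadata partially visible in the generated tables.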

docs/source/models_new.rst

Lines changed: 126 additions & 26 deletions
@@ -3,30 +3,42 @@
 Models and pre-trained weights - New
 ####################################
 
-.. note::
-
-    These are the new models docs, documenting the new multi-weight API.
-    TODO: Once all is done, remove the "- New" part in the title above, and
-    rename this file as models.rst
-
-
 The ``torchvision.models`` subpackage contains definitions of models for addressing
 different tasks, including: image classification, pixelwise semantic
 segmentation, object detection, instance segmentation, person
 keypoint detection, video classification, and optical flow.
 
+General information on pre-trained weights
+==========================================
+
+TorchVision offers pre-trained weights for every provided architecture, using
+the PyTorch :mod:`torch.hub`. Instancing a pre-trained model will download its
+weights to a cache directory. This directory can be set using the `TORCH_HOME`
+environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
+
+.. note::
+
+    The pre-trained models provided in this library may have their own licenses or
+    terms and conditions derived from the dataset used for training. It is your
+    responsibility to determine whether you have permission to use the models for
+    your use case.
+
 .. note ::
-    Backward compatibility is guaranteed for loading a serialized
-    ``state_dict`` to the model created using old PyTorch version.
-    On the contrary, loading entire saved models or serialized
-    ``ScriptModules`` (seralized using older versions of PyTorch)
-    may not preserve the historic behaviour. Refer to the following
-    `documentation
-    <https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
+    Backward compatibility is guaranteed for loading a serialized
+    ``state_dict`` to the model created using old PyTorch version.
+    On the contrary, loading entire saved models or serialized
+    ``ScriptModules`` (serialized using older versions of PyTorch)
+    may not preserve the historic behaviour. Refer to the following
+    `documentation
+    <https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
+
+
+Initializing pre-trained models
+-------------------------------
 
 As of v0.13, TorchVision offers a new `Multi-weight support API
-<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_ for loading different weights to the
-existing model builder methods:
+<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
+for loading different weights to the existing model builder methods:
 
 .. code:: python
 
@@ -46,7 +58,7 @@ existing model builder methods:
     resnet50(weights="IMAGENET1K_V2")
 
     # No weights - random initialization
-    resnet50(weights=None)  # or resnet50()
+    resnet50(weights=None)
 
 
 Migrating to the new API is very straightforward. The following method calls between the 2 APIs are all equivalent:
@@ -57,16 +69,57 @@ Migrating to the new API is very straightforward. The following method calls bet
 
     # Using pretrained weights:
     resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+    resnet50(weights="IMAGENET1K_V1")
     resnet50(pretrained=True)  # deprecated
     resnet50(True)  # deprecated
 
     # Using no weights:
     resnet50(weights=None)
+    resnet50()
     resnet50(pretrained=False)  # deprecated
     resnet50(False)  # deprecated
 
 Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings and it will be removed in v0.15.
 
+Using the pre-trained models
+----------------------------
+
+Before using the pre-trained models, one must preprocess the image
+(resize with the right resolution/interpolation, apply inference transforms,
+rescale the values, etc.). There is no standard way to do this as it depends on
+how a given model was trained. It can vary across model families, variants or
+even weight versions. Using the correct preprocessing method is critical and
+failing to do so may lead to decreased accuracy or incorrect outputs.
+
+All the necessary information for the inference transforms of each pre-trained
+model is provided in its weights documentation. To simplify inference, TorchVision
+bundles the necessary preprocessing transforms into each model weight. These are
+accessible via the ``weight.transforms`` attribute:
+
+.. code:: python
+
+    # Initialize the Weight Transforms
+    weights = ResNet50_Weights.DEFAULT
+    preprocess = weights.transforms()
+
+    # Apply it to the input image
+    img_transformed = preprocess(img)
+
+
+Some models use modules which have different training and evaluation
+behavior, such as batch normalization. To switch between these modes, use
+``model.train()`` or ``model.eval()`` as appropriate. See
+:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
+
+.. code:: python
+
+    # Initialize model
+    weights = ResNet50_Weights.DEFAULT
+    model = resnet50(weights=weights)
+
+    # Set model to eval mode
+    model.eval()
+
 
 Classification
 ==============
@@ -128,10 +181,12 @@ Here is an example of how to use the pre-trained image classification models:
     category_name = weights.meta["categories"][class_id]
     print(f"{category_name}: {100 * score:.1f}%")
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+
 Table of all available classification weights
 ---------------------------------------------
 
-Accuracies are reported on ImageNet
+Accuracies are reported on ImageNet-1K using single crops:
 
 .. include:: generated/classification_table.rst
 
@@ -140,7 +195,7 @@ Quantized models
 
 .. currentmodule:: torchvision.models.quantization
 
-The following quantized classification models are available, with or without
+The following architectures provide support for INT8 quantized models, with or without
 pre-trained weights:
 
 .. toctree::
@@ -181,11 +236,13 @@ Here is an example of how to use the pre-trained quantized image classification
     category_name = weights.meta["categories"][class_id]
     print(f"{category_name}: {100 * score}%")
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+
 
 Table of all available quantized classification weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Accuracies are reported on ImageNet
+Accuracies are reported on ImageNet-1K using single crops:
 
 .. include:: generated/classification_quant_table.rst
 
@@ -234,11 +291,14 @@ Here is an example of how to use the pre-trained semantic segmentation models:
     mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+The output format of the models is illustrated in :ref:`semantic_seg_output`.
+
 
 Table of all available semantic segmentation weights
 ----------------------------------------------------
 
-All models are evaluated on COCO val2017:
+All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:
 
 .. include:: generated/segmentation_table.rst
 
@@ -247,6 +307,11 @@ All models are evaluated on COCO val2017:
 Object Detection, Instance Segmentation and Person Keypoint Detection
 =====================================================================
 
+The pre-trained models for detection, instance segmentation and
+keypoint detection are initialized with the classification models
+in torchvision. The models expect a list of ``Tensor[C, H, W]``.
+Check the constructor of the models for more information.
+
 Object Detection
 ----------------
 
@@ -299,10 +364,13 @@ Here is an example of how to use the pre-trained object detection models:
     im = to_pil_image(box.detach())
     im.show()
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.
+
 Table of all available Object detection weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Box MAPs are reported on COCO
+Box MAPs are reported on COCO val2017:
 
 .. include:: generated/detection_table.rst
 
@@ -319,10 +387,15 @@ weights:
 
     models/mask_rcnn
 
+|
+
+
+For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.
+
 Table of all available Instance segmentation weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Box and Mask MAPs are reported on COCO
+Box and Mask MAPs are reported on COCO val2017:
 
 .. include:: generated/instance_segmentation_table.rst
 
@@ -331,18 +404,23 @@ Keypoint Detection
 
 .. currentmodule:: torchvision.models.detection
 
-The following keypoint detection models are available, with or without
+The following person keypoint detection models are available, with or without
 pre-trained weights:
 
 .. toctree::
     :maxdepth: 1
 
     models/keypoint_rcnn
 
+|
+
+The classes of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
+For details on how to plot the bounding boxes of the models, you may refer to :ref:`keypoint_output`.
+
 Table of all available Keypoint detection weights
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Box and Keypoint MAPs are reported on COCO:
+Box and Keypoint MAPs are reported on COCO val2017:
 
 .. include:: generated/detection_keypoint_table.rst
 
@@ -391,10 +469,32 @@ Here is an example of how to use the pre-trained video classification models:
     category_name = weights.meta["categories"][label]
     print(f"{category_name}: {100 * score}%")
 
+The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
+
 
 Table of all available video classification weights
 ---------------------------------------------------
 
-Accuracies are reported on Kinetics-400
+Accuracies are reported on Kinetics-400 using single crops for clip length 16:
 
 .. include:: generated/video_table.rst
+
+Using models from Hub
+=====================
+
+Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
+
+.. code:: python
+
+    import torch
+
+    # Option 1: passing weights param as string
+    model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
+
+    # Option 2: passing weights param as enum
+    weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
+    model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
+
+The only exceptions to the above are the detection models included in
+:mod:`torchvision.models.detection`. These models require TorchVision
+to be installed because they depend on custom C++ operators.
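
As a small aside on the cache directory mentioned in the new "General information" section: `TORCH_HOME` takes effect when the first download is triggered, so it should be set beforehand. A minimal sketch, with a placeholder path:

    import os

    # Hypothetical cache location; set before instantiating any pre-trained model
    os.environ["TORCH_HOME"] = "/path/to/torch/cache"

    from torchvision.models import resnet50, ResNet50_Weights
    model = resnet50(weights=ResNet50_Weights.DEFAULT)  # weights land under $TORCH_HOME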

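Putting the pieces of the new "Using the pre-trained models" section together, a minimal end-to-end classification sketch could look like this (the image path is a placeholder and the printed category depends on the input):

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("path/to/image.jpg")  # placeholder path

    # Initialize the model with its default weights and switch to eval mode
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Apply the preprocessing transforms bundled with the weights
    preprocess = weights.transforms()
    batch = preprocess(img).unsqueeze(0)

    # Map the top prediction to a human-readable category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")
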
gallery/plot_visualization_utils.py

Lines changed: 2 additions & 0 deletions
@@ -379,6 +379,8 @@ def show(imgs):
 # instance with class 15 (which corresponds to 'bench') was not selected.
 
 #####################################
+# .. _keypoint_output:
+#
 # Visualizing keypoints
 # ------------------------------
 # The :func:`~torchvision.utils.draw_keypoints` function can be used to
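
The new ``keypoint_output`` anchor is what the models docs now link to for plotting keypoints. A minimal ``draw_keypoints`` sketch with random stand-in data (real code would take the keypoints from Keypoint R-CNN output):

    import torch
    from torchvision.transforms.functional import to_pil_image
    from torchvision.utils import draw_keypoints

    # Hypothetical inputs: a uint8 CHW image and one instance with 17
    # (x, y) keypoints inside the image bounds
    image = torch.randint(0, 256, (3, 400, 300), dtype=torch.uint8)
    keypoints = torch.rand(1, 17, 2) * torch.tensor([300.0, 400.0])

    res = draw_keypoints(image, keypoints, colors="blue", radius=3)
    to_pil_image(res).show()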

torchvision/transforms/_presets.py

Lines changed: 6 additions & 5 deletions
@@ -71,8 +71,8 @@ def __repr__(self) -> str:
     def describe(self) -> str:
         return (
             f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
-            f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
-            f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
+            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
         )
 
 
@@ -127,8 +127,8 @@ def __repr__(self) -> str:
     def describe(self) -> str:
         return (
             f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
-            f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
-            f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
+            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
         )
 
 
@@ -168,7 +168,8 @@ def __repr__(self) -> str:
     def describe(self) -> str:
         return (
             f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
-            f"Then the values are rescaled to ``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
+            f"``std={self.std}``."
         )
 
 
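The reworded ``describe()`` strings above are what ``conf.py`` injects into each weight's generated docs via ``field.transforms().describe()``. They can also be printed directly; a small sketch (the reported sizes and statistics depend on the chosen weights):

    from torchvision.models import ResNet50_Weights

    weights = ResNet50_Weights.IMAGENET1K_V2
    print(weights.transforms().describe())
    # -> "The images are resized to ``resize_size=...`` ... Finally the values
    #    are first rescaled to ``[0.0, 1.0]`` and then normalized using
    #    ``mean=...`` and ``std=...``."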