
Add more info on new models.rst #6025


Merged (13 commits) on May 16, 2022
10 changes: 5 additions & 5 deletions docs/source/conf.py
@@ -347,10 +347,6 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
metrics = meta.pop("metrics", {})
meta_with_metrics = dict(meta, **metrics)

# We don't want to document these, they can be too long
for k in ["categories", "keypoint_names"]:
meta_with_metrics.pop(k, None)

custom_docs = meta_with_metrics.pop("_docs", None) # Custom per-Weights docs
if custom_docs is not None:
lines += [custom_docs, ""]
@@ -360,14 +356,18 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
v = f"`link <{v}>`__"
elif k == "min_size":
v = f"height={v[0]}, width={v[1]}"
elif k in {"categories", "keypoint_names"} and isinstance(v, list):
max_visible = 3
v_sample = ", ".join(v[:max_visible])
v = f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample
table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst")
lines += [".. rst-class:: table-weights"] # Custom CSS class, see custom_torchvision.css
lines += [".. table::", ""]
lines += textwrap.indent(table, " " * 4).split("\n")
lines.append("")
lines.append(
f"The inference transforms are available at ``{str(field)}.transforms`` and "
f"The preprocessing/inference transforms are available at ``{str(field)}.transforms`` and "
f"perform the following operations: {field.transforms().describe()}"
)
lines.append("")
152 changes: 126 additions & 26 deletions docs/source/models_new.rst
@@ -3,30 +3,42 @@
Models and pre-trained weights - New
####################################

.. note::

These are the new models docs, documenting the new multi-weight API.
TODO: Once all is done, remove the "- New" part in the title above, and
rename this file as models.rst


The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

General information on pre-trained weights
==========================================

TorchVision offers pre-trained weights for every provided architecture, using
the PyTorch :mod:`torch.hub`. Instantiating a pre-trained model will download its
weights to a cache directory. This directory can be set using the `TORCH_HOME`
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
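
For example, to redirect the weight cache to a custom location, you can set
``TORCH_HOME`` before instantiating a model (a minimal sketch; the path below
is only a placeholder):

.. code:: python

    import os

    # Hypothetical cache location; any writable directory works
    os.environ["TORCH_HOME"] = "/path/to/my/cache"

    from torchvision.models import resnet50, ResNet50_Weights

    # The weights are downloaded to (and reused from) TORCH_HOME
    model = resnet50(weights=ResNet50_Weights.DEFAULT)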

.. note::

The pre-trained models provided in this library may have their own licenses or
terms and conditions derived from the dataset used for training. It is your
responsibility to determine whether you have permission to use the models for
your use case.

.. note ::
Backward compatibility is guaranteed for loading a serialized
``state_dict`` to the model created using old PyTorch version.
On the contrary, loading entire saved models or serialized
``ScriptModules`` (seralized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
Backward compatibility is guaranteed for loading a serialized
``state_dict`` into a model created with an older version of PyTorch.
In contrast, loading entire saved models or serialized
``ScriptModules`` (serialized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_.
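
For example, the backward-compatible way to persist a model is to save and
reload its ``state_dict`` rather than the pickled model object (a minimal
sketch; the file name is only a placeholder):

.. code:: python

    import torch
    from torchvision.models import resnet50, ResNet50_Weights

    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    torch.save(model.state_dict(), "resnet50_weights.pth")

    # Later, possibly with a newer PyTorch version: rebuild the model
    # and load the serialized state_dict
    model = resnet50(weights=None)
    model.load_state_dict(torch.load("resnet50_weights.pth"))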


Initializing pre-trained models
-------------------------------

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_ for loading different weights to the
existing model builder methods:
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
for loading different weights to the existing model builder methods:

.. code:: python

@@ -46,7 +58,7 @@ existing model builder methods:
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None) # or resnet50()
resnet50(weights=None)


Migrating to the new API is very straightforward. The following method calls between the 2 APIs are all equivalent:
@@ -57,16 +69,57 @@ Migrating to the new API is very straightforward. The following method calls bet

# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True) # deprecated
resnet50(True) # deprecated

# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False) # deprecated
resnet50(False) # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.

Using the pre-trained models
----------------------------

Before using the pre-trained models, one must preprocess the image
(resize with the right resolution/interpolation, apply inference transforms,
rescale the values, etc.). There is no standard way to do this as it depends on
how a given model was trained. It can vary across model families, variants or
even weight versions. Using the correct preprocessing method is critical and
failing to do so may lead to decreased accuracy or incorrect outputs.

All the necessary information for the inference transforms of each pre-trained
model is provided in its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:

.. code:: python

# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Apply it to the input image
img_transformed = preprocess(img)


Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.

.. code:: python

# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# Set model to eval mode
model.eval()


Classification
==============
@@ -128,10 +181,12 @@ Here is an example of how to use the pre-trained image classification models:
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet
Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_table.rst

@@ -140,7 +195,7 @@ Quantized models

.. currentmodule:: torchvision.models.quantization

The following quantized classification models are available, with or without
The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:

.. toctree::
@@ -181,11 +236,13 @@ Here is an example of how to use the pre-trained quantized image classification
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.


Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Accuracies are reported on ImageNet
Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_quant_table.rst

@@ -234,11 +291,14 @@ Here is an example of how to use the pre-trained semantic segmentation models:
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.


Table of all available semantic segmentation weights
----------------------------------------------------

All models are evaluated on COCO val2017:
All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:

.. include:: generated/segmentation_table.rst

@@ -247,6 +307,11 @@
Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================

The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision. The models expect a list of ``Tensor[C, H, W]``.
Check the constructor of the models for more information.
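
As a rough sketch of that input contract (the model builder and weights enum
shown here are documented further below; the image paths are only placeholders),
inference takes a list of images that may have different sizes:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn(weights=weights)
    model.eval()

    preprocess = weights.transforms()
    images = [preprocess(read_image("img1.jpg")), preprocess(read_image("img2.jpg"))]  # list of Tensor[C, H, W]
    predictions = model(images)  # one dict per image, with "boxes", "labels" and "scores"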

Object Detection
----------------

@@ -299,10 +364,13 @@ Here is an example of how to use the pre-trained object detection models:
im = to_pil_image(box.detach())
im.show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box MAPs are reported on COCO
Box MAPs are reported on COCO val2017:

.. include:: generated/detection_table.rst

@@ -319,10 +387,15 @@ weights:

models/mask_rcnn

|


For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Mask MAPs are reported on COCO
Box and Mask MAPs are reported on COCO val2017:

.. include:: generated/instance_segmentation_table.rst

@@ -331,18 +404,23 @@ Keypoint Detection

.. currentmodule:: torchvision.models.detection

The following keypoint detection models are available, with or without
The following person keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
:maxdepth: 1

models/keypoint_rcnn

|

The keypoint names of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the keypoints of the models, you may refer to :ref:`keypoint_output`.
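
A rough sketch of running the person keypoint detector and reading its
predictions (the image path is only a placeholder):

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights)
    model.eval()

    img = read_image("person.jpg")
    prediction = model([weights.transforms()(img)])[0]
    keypoints = prediction["keypoints"]  # Tensor[num_instances, num_keypoints, 3] as (x, y, visibility)
    print(weights.meta["keypoint_names"])  # names for the keypoint dimension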

Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Keypoint MAPs are reported on COCO:
Box and Keypoint MAPs are reported on COCO val2017:

.. include:: generated/detection_keypoint_table.rst

@@ -391,10 +469,32 @@ Here is an example of how to use the pre-trained video classification models:
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.


Table of all available video classification weights
---------------------------------------------------

Accuracies are reported on Kinetics-400
Accuracies are reported on Kinetics-400 using single crops for clip length 16:

.. include:: generated/video_table.rst

Using models from Hub
=====================

Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:

.. code:: python

import torch

# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

The only exceptions to the above are the detection models included in
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.
2 changes: 2 additions & 0 deletions gallery/plot_visualization_utils.py
@@ -379,6 +379,8 @@ def show(imgs):
# instance with class 15 (which corresponds to 'bench') was not selected.

#####################################
# .. _keypoint_output:
#
# Visualizing keypoints
# ------------------------------
# The :func:`~torchvision.utils.draw_keypoints` function can be used to
11 changes: 6 additions & 5 deletions torchvision/transforms/_presets.py
@@ -71,8 +71,8 @@ def __repr__(self) -> str:
def describe(self) -> str:
return (
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
)


@@ -127,8 +127,8 @@ def __repr__(self) -> str:
def describe(self) -> str:
return (
f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
)


@@ -168,7 +168,7 @@ def __repr__(self) -> str:
def describe(self) -> str:
return (
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Then the values are rescaled to ``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``."
)

