Commit f014841

Add revamped docs for video classification models (#5894)
* Add revamped docs for video classification models
* EOL
1 parent 36c4635 commit f014841

4 files changed (+96, -21 lines changed)

docs/source/conf.py

+1
@@ -379,6 +379,7 @@ def generate_weights_table(module, table_name, metrics):
 generate_weights_table(
     module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
 )
+generate_weights_table(module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])


 def setup(app):
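
The added call registers a table for the video models with Acc@1 and Acc@5 columns. The body of `generate_weights_table` is not shown in this hunk, so the following is only a rough standard-library sketch of how a helper of this shape might turn per-weight metrics into an RST `list-table`. The function `render_weights_table` and the sample entries are hypothetical, the numbers are placeholders, and none of this is torchvision's actual implementation.

```python
from typing import Dict, List, Tuple


def render_weights_table(
    rows: Dict[str, Dict[str, float]],
    metrics: List[Tuple[str, str]],
) -> str:
    """Render an RST ``list-table`` with one row per weight entry.

    ``metrics`` pairs a metadata key with its column title, mirroring the
    ``metrics=[("acc@1", "Acc@1"), ...]`` argument seen in the diff.
    """
    lines = [".. list-table::", "    :header-rows: 1", ""]
    header = ["Weight"] + [title for _, title in metrics]
    body = [
        [name] + [str(values[key]) for key, _ in metrics]
        for name, values in rows.items()
    ]
    for row in [header] + body:
        # The first cell opens a new table row; the rest continue it.
        lines.append(f"    * - {row[0]}")
        lines.extend(f"      - {cell}" for cell in row[1:])
    return "\n".join(lines)


# Placeholder entries: names and numbers are invented for illustration only.
example_rows = {
    "Example_Weights.DEMO_A": {"acc@1": 1.0, "acc@5": 2.0},
    "Example_Weights.DEMO_B": {"acc@1": 3.0, "acc@5": 4.0},
}
table = render_weights_table(example_rows, [("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
print(table)
```

A generated file such as `generated/video_table.rst` could then be pulled into a page with an `.. include::` directive, which is how the documentation pages in this commit reference their tables.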

docs/source/models/video_resnet.rst

+26
@@ -0,0 +1,26 @@
+Video ResNet
+============
+
+.. currentmodule:: torchvision.models.video
+
+The VideoResNet model is based on the `A Closer Look at Spatiotemporal
+Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__ paper.
+
+
+Model builders
+--------------
+
+The following model builders can be used to instantiate a VideoResNet model, with or
+without pre-trained weights. All the model builders internally rely on the
+``torchvision.models.video.resnet.VideoResNet`` base class. Please refer to the `source
+code
+<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_ for
+more details about this class.
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    r3d_18
+    mc3_18
+    r2plus1d_18

docs/source/models_new.rst

+21
@@ -101,3 +101,24 @@ Table of all available detection weights
 Box MAPs are reported on COCO

 .. include:: generated/detection_table.rst
+
+
+Video Classification
+====================
+
+.. currentmodule:: torchvision.models.video
+
+The following video classification models are available, with or without
+pre-trained weights:
+
+.. toctree::
+   :maxdepth: 1
+
+   models/video_resnet
+
+Table of all available video classification weights
+---------------------------------------------------
+
+Accuracies are reported on Kinetics-400
+
+.. include:: generated/video_table.rst

torchvision/models/video/resnet.py

+48-21
@@ -365,15 +365,24 @@ class R2Plus1D_18_Weights(WeightsEnum):

 @handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
 def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
-    """Construct 18 layer Resnet3D model as in
-    https://arxiv.org/abs/1711.11248
+    """Construct 18 layer Resnet3D model.

-    Args:
-        weights (R3D_18_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

-    Returns:
-        VideoResNet: R3D-18 network
+    Args:
+        weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.video.R3D_18_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
+            Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.video.R3D_18_Weights
+        :members:
     """
     weights = R3D_18_Weights.verify(weights)

@@ -390,15 +399,24 @@ def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, *

 @handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
 def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
-    """Constructor for 18 layer Mixed Convolution network as in
-    https://arxiv.org/abs/1711.11248
+    """Construct 18 layer Mixed Convolution network as in

-    Args:
-        weights (MC3_18_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

-    Returns:
-        VideoResNet: MC3 Network definition
+    Args:
+        weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.video.MC3_18_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
+            Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.video.MC3_18_Weights
+        :members:
     """
     weights = MC3_18_Weights.verify(weights)

@@ -415,15 +433,24 @@ def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, *

 @handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
 def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
-    """Constructor for the 18 layer deep R(2+1)D network as in
-    https://arxiv.org/abs/1711.11248
+    """Construct 18 layer deep R(2+1)D network as in

-    Args:
-        weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

-    Returns:
-        VideoResNet: R(2+1)D-18 network
+    Args:
+        weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.video.R2Plus1D_18_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
+            Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.video.R2Plus1D_18_Weights
+        :members:
     """
     weights = R2Plus1D_18_Weights.verify(weights)
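
Each builder in this file is wrapped in `@handle_legacy_interface(weights=("pretrained", ...KINETICS400_V1))` and starts by calling `verify` on its `weights` argument. The decorator's implementation is not part of this diff, so the following is a simplified standard-library sketch of how such a decorator could map a legacy `pretrained=True` keyword onto the `weights` parameter. `DemoWeights`, `legacy_interface`, and `build_model` are hypothetical names chosen for illustration; this is not torchvision's code.

```python
import enum
import functools
import warnings
from typing import Any, Callable, Optional


class DemoWeights(enum.Enum):
    # Hypothetical stand-in for an enum such as R3D_18_Weights.
    KINETICS400_V1 = "demo-checkpoint"

    @classmethod
    def verify(cls, obj: Any) -> Optional["DemoWeights"]:
        # Accept None, a member, or a member name given as a string.
        if obj is None or isinstance(obj, cls):
            return obj
        if isinstance(obj, str):
            return cls[obj]
        raise TypeError(
            f"Expected {cls.__name__}, str, or None, got {type(obj).__name__}"
        )


def legacy_interface(legacy_name: str, default: enum.Enum) -> Callable:
    """Translate ``legacy_name=True/False`` into the ``weights`` keyword."""

    def decorator(builder: Callable) -> Callable:
        @functools.wraps(builder)
        def wrapper(**kwargs: Any) -> Any:
            if legacy_name in kwargs:
                flag = kwargs.pop(legacy_name)
                warnings.warn(
                    f"'{legacy_name}' is deprecated; pass 'weights' instead.",
                    DeprecationWarning,
                )
                # An explicit ``weights`` argument wins over the legacy flag.
                kwargs.setdefault("weights", default if flag else None)
            return builder(**kwargs)

        return wrapper

    return decorator


@legacy_interface("pretrained", DemoWeights.KINETICS400_V1)
def build_model(
    *, weights: Optional[DemoWeights] = None, progress: bool = True
) -> Optional[DemoWeights]:
    # A real builder would construct the network; this sketch only returns
    # the resolved weights so the mapping is easy to inspect.
    return DemoWeights.verify(weights)
```

In this sketch, `build_model(pretrained=True)` resolves to `DemoWeights.KINETICS400_V1`, while `build_model()` resolves to `None`, matching the documented default of using no pre-trained weights.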
